#pragma once #include "./core/omp_config.h" #include "./utils/vector.h" #include "./utils/matrix.h" #include "./numerics/matmul.h" #include namespace neural_networks{ template struct Optimizer_Adam{ T learning_rate = T{1}; T current_learning_rate = learning_rate; T decay = T{0}; T epsilon = T{1e-7}; T beta_1 = T{0.9}; T beta_2 = T{0.999}; uint64_t iterations = 0; utils::Matrix weight_momentums_corrected; utils::Vector bias_momentums_corrected; utils::Matrix weight_cache_corrected; utils::Vector bias_cache_corrected; // Default Constructor Optimizer_Adam() = default; // Constructor explicit Optimizer_Adam(const T lr, const T lr_decay, const T epsilons, const T beta1, const T beta2): learning_rate(lr), current_learning_rate{lr}, decay(lr_decay), epsilon(epsilons), beta_1(beta1), beta_2(beta2) {} void pre_update_params(){ if(decay){ current_learning_rate = learning_rate * (T{1}/(T{1}+(decay*iterations))); //std::cout << current_learning_rate << std::endl; } } template void update_params(Layer& layer){ // if layer does not contain cache arrays, create them filled with zeros. if ((layer.weight_cache.rows() != layer.weights.rows()) || (layer.weight_cache.cols() != layer.weights.cols())){ layer.weight_momentums.resize(layer.weights.rows(), layer.weights.cols(), T{0}); layer.weight_cache.resize(layer.weights.rows(), layer.weights.cols(), T{0}); } if (layer.bias_cache.size() != layer.biases.size()){ layer.bias_momentums.resize(layer.biases.size(), T{0}); layer.bias_cache.resize(layer.biases.size(), T{0}); } // Update momentum with current gradients for (uint64_t i = 0; i < layer.weights.rows(); ++i){ for (uint64_t j = 0; j < layer.weights.cols(); ++j){ layer.weight_momentums(i,j) = (beta_1 * layer.weight_momentums(i,j)) + ((T{1} - beta_1) * layer.dweights(i,j)); } } for (uint64_t i = 0; i < layer.biases.size(); ++i){ layer.bias_momentums[i] = (beta_1 * layer.bias_momentums[i]) + ((T{1} - beta_1) * layer.dbiases[i]); } // Get corrected momentum // interation is 0 at first pass // and we need to start with 1 here weight_momentums_corrected.resize(layer.weights.rows(),layer.weights.cols()); // can be optimized out later for (uint64_t i = 0; i < layer.weights.rows(); ++i){ for (uint64_t j = 0; j < layer.weights.cols(); ++j){ weight_momentums_corrected(i,j) = layer.weight_momentums(i,j) / (T{1} - std::pow(beta_1, iterations+1)); } } bias_momentums_corrected.resize(layer.biases.size()); // can be optimized out later for (uint64_t i = 0; i < layer.biases.size(); ++i){ bias_momentums_corrected[i] = layer.bias_momentums[i] / (T{1} - std::pow(beta_1, iterations+1)); } // Update cache with squared current gradients for (uint64_t i = 0; i < layer.weights.rows(); ++i){ for (uint64_t j = 0; j < layer.weights.cols(); ++j){ layer.weight_cache(i,j) = (beta_2*layer.weight_cache(i,j)) + ((T{1}-beta_2) * (layer.dweights(i,j)*layer.dweights(i,j))); } } for (uint64_t i = 0; i < layer.biases.size(); ++i){ // can maybe be included when updating weights (saves time) layer.bias_cache[i] = (beta_2*layer.bias_cache[i]) + ((T{1}-beta_2) * (layer.dbiases[i]*layer.dbiases[i])); } // Get corrected cache // interation is 0 at first pass // and we need to start with 1 here weight_cache_corrected.resize(layer.weights.rows(),layer.weights.cols()); // can be optimized out later for (uint64_t i = 0; i < layer.weights.rows(); ++i){ for (uint64_t j = 0; j < layer.weights.cols(); ++j){ weight_cache_corrected(i,j) = layer.weight_cache(i,j) / (T{1} - std::pow(beta_2, iterations+1)); } } bias_cache_corrected.resize(layer.biases.size()); // can be optimized out later for (uint64_t i = 0; i < layer.biases.size(); ++i){ bias_cache_corrected[i] = layer.bias_cache[i] / (T{1} - std::pow(beta_2, iterations+1)); } // Vanilla SGD parameter update + normalization with squared rooted cache for (uint64_t i = 0; i < layer.weights.rows(); ++i){ for (uint64_t j = 0; j < layer.weights.cols(); ++j){ layer.weights(i,j) -= (current_learning_rate*weight_momentums_corrected(i,j)) / (std::sqrt(weight_cache_corrected(i,j)) + epsilon); } } for (uint64_t i = 0; i < layer.biases.size(); ++i){ layer.biases[i] -= (current_learning_rate*bias_momentums_corrected[i]) / (std::sqrt(bias_cache_corrected[i]) + epsilon); } } void post_update_params(){ iterations++; } }; } // end namespace neural_networks