#pragma once

#include "core/omp_config.h"
#include "utils/vector.h"
#include "utils/matrix.h"
#include "utils/matcast.h"
#include "numerics/clip.h"
#include "numerics/log.h"
#include "numerics/sub.h"
#include "Loss.h"

// NOTE(review): numerics::mul, numerics::add, numerics::neg, numerics::div and
// numerics::mean_rowwise are used below, but none of the headers above is
// obviously theirs; presumably they arrive transitively (e.g. via Loss.h) —
// confirm, or include their headers directly.

namespace neural_networks{

// NOTE(review): in this copy of the file the angle-bracketed template
// parameter/argument lists appear to be missing (`template struct ...`,
// `utils::Matrix` / `utils::Vector` with no element type, `static_cast (...)`
// with no target type, `utils::matcast(...)` with no template arguments), so
// the code does not compile as shown. Restore them from version control.
// `Td` is the numeric type used for the clipping bounds and the
// normalisation factors below.

/// Binary cross-entropy loss.
///
/// forward() returns, for each row (sample), the mean over its columns
/// (outputs) of  -( y * log(p) + (1 - y) * log(1 - p) ),  with the
/// predictions p clipped to [1e-7, 1 - 1e-7].
/// backward() writes the gradient of that loss with respect to the
/// predictions into `dinputs`, divided by both the number of outputs and the
/// number of samples.
template struct Loss_BinaryCrossentropy : Loss {
    /// Gradient with respect to the predictions; written by backward().
    utils::Matrix dinputs;
    /// Targets converted with utils::matcast by forward(); read by backward().
    utils::Matrix y_true;

    /// Compute the per-sample binary cross-entropy.
    ///
    /// @param y_pred Predicted values, one row per sample and one column per
    ///               output. They are clipped to [1e-7, 1 - 1e-7] before use.
    /// @param y_true Target values. They are converted with utils::matcast and
    ///               cached in this->y_true, which backward() later reuses.
    /// @return One loss value per row: numerics::mean_rowwise of the
    ///         element-wise losses.
    utils::Vector forward(const utils::Matrix& y_pred, const utils::Matrix& y_true) override{
        this->y_true = utils::matcast(y_true);
        // Clip data to prevent log(0) below.
        // Clip both sides so the mean is not dragged towards either bound.
        utils::Matrix y_pred_clipped = numerics::clip(y_pred, Td{1e-7}, Td{1.0} - Td{1e-7});
        // Calculate sample-wise loss, built up in steps as
        //   -( (1 - y) * log(1 - p) + y * log(p) )
        // (assumes numerics::mul / add / sub / log act element-wise — confirm).
        utils::Matrix sample_losses_temp = numerics::log(numerics::sub(Td{1}, y_pred_clipped));
        sample_losses_temp = numerics::mul(sample_losses_temp, numerics::sub(Td{1}, this->y_true));
        sample_losses_temp = numerics::add(sample_losses_temp, numerics::mul(this->y_true, numerics::log(y_pred_clipped)));
        sample_losses_temp = numerics::neg(sample_losses_temp);
        // Average the element-wise losses across each row (sample).
        utils::Vector sample_losses = numerics::mean_rowwise(sample_losses_temp);
        // Return losses
        return sample_losses;
    }

    /// Compute the gradient of the loss with respect to the predictions.
    ///
    /// Stores  dinputs = -( y / p - (1 - y) / (1 - p) ) / outputs / samples.
    /// Here p is `dvalues` clipped to [1e-7, 1 - 1e-7], and y is the target
    /// matrix cached by the most recent forward() call.
    ///
    /// @param dvalues Values used as the predictions p, one row per sample.
    /// @param y_true  Not read by the live code. It shadows the member, and
    ///                every computation below uses this->y_true instead.
    ///                forward() must therefore have been called first, with
    ///                the same targets.
    void backward(const utils::Matrix& dvalues, const utils::Matrix& y_true) override{
        /*std::cout << "BCE backward y_true: " << y_true.rows() << " x " << y_true.cols() << std::endl;*/
        // Number of samples.
        // NOTE(review): taken from the cached targets, not from dvalues; this
        // assumes both have the same number of rows.
        const Td samples = static_cast (this->y_true.rows());
        // Number of outputs in every sample
        const Td outputs = static_cast (dvalues.cols());
        // Clip data to prevent division by 0 in y / p and (1 - y) / (1 - p).
        // Clip both sides so the mean is not dragged towards either bound.
        utils::Matrix clipped_dvalues = numerics::clip(dvalues, Td{1e-7}, Td{1.0} - Td{1e-7});
        // Calculate gradient: -( y / p - (1 - y) / (1 - p) ) / outputs
        dinputs = numerics::div(
            numerics::neg(
                numerics::sub(
                    numerics::div(this->y_true, clipped_dvalues),
                    numerics::div(numerics::sub(Td{1}, this->y_true),
                                  numerics::sub(Td{1}, clipped_dvalues)))),
            outputs);
        // Normalize gradients by the number of samples
        dinputs = numerics::div(dinputs, samples);
        /* std::cout << "BCE backward dinputs: " << dinputs.rows() << " x " << dinputs.cols() << std::endl;*/
    }
};

} // end namespace neural_networks