p.238 in NNFS
Sync public mirror / sync (push) Failing after 24s

This commit is contained in:
2025-12-23 14:47:40 +01:00
parent 22d6ea5fad
commit bd2edea8ef
56 changed files with 4446 additions and 147 deletions
@@ -27,41 +27,45 @@ namespace neural_networks{
utils::Matrix<Td> outputs;
utils::Matrix<Td> dinputs;
utils::Vector<Td> forward(const utils::Matrix<Td>& inputs, const utils::Matrix<Ti>& y_true){
Td forward(const utils::Matrix<Td>& inputs, const utils::Matrix<Ti>& y_true){
// Output layer's activation function
activation.forward(inputs);
// Set the output
outputs = activation.outputs;
// Calculate and return loss value
Td data_loss = loss.calculate(inputs, y_true);
Td data_loss = loss.calculate(outputs, y_true);
return data_loss;
}
void backward(const utils::Matrix<Td>& dvalues, const utils::Matrix<Ti>& y_true){
// Number of samples
const uint64_t samples = y_true.rows();
const uint64_t samples = dvalues.rows();
const uint64_t cols = dvalues.cols();
// Labels may be sparse class indices (a single column) or one-hot encoded;
// each layout is handled by its own branch below, without any conversion
dinputs = dvalues; // Copy
const uint64_t rows = dvalues.rows();
const uint64_t cols = dvalues.cols();
if ((dinputs.rows() != rows) || dinputs.cols() != cols){
dinputs.resize(rows, cols);
}
for (uint64_t i = 0; i < rows; ++i){
Td dot = Td{0};
for (uint64_t j = 0; j < cols; ++j){
dot += outputs(i,j) * dvalues(i,j);
if (y_true.cols() == 1){
for (uint64_t i = 0; i < samples; ++i){
uint64_t class_idx = static_cast<uint64_t>(y_true(i, 0));
dinputs(i, class_idx) -= Td{1};
}
for (uint64_t j = 0; j < cols; ++j){
dinputs(i,j) = outputs(i,j) * (dvalues(i,j) - dot);
} else{
// one-hot: dinputs = dvalues - y_true
for (uint64_t i = 0; i < samples; ++i){
for (uint64_t j = 0; j < cols; ++j){
dinputs(i,j) -= static_cast<Td>(y_true(i,j));
}
}
}
// divide by samples
for (uint64_t i = 0; i < samples; ++i){
for (uint64_t j = 0; j < cols; ++j){
dinputs(i,j) /= static_cast<Td>(samples);
}
}
}
};
@@ -54,6 +54,7 @@ namespace neural_networks{
void backward(const utils::Matrix<Td>& dvalues, const utils::Matrix<Ti>& y_true) override{
// Number of samples
const Td samples = static_cast<Td> (y_true.rows());
// Number of labels in every sample