#include "core/omp_config.h"

#include "utils/utils.h"
#include "numerics/numerics.h"
#include "decomp/decomp.h"

#include "modules/neural_networks/neural_networks.h"
#include "random/random.h"



//#include <iostream>
//#include <stdexcept>
//#include <chrono>





/**
 * Trains a small fully connected classifier on a generated dataset and then
 * evaluates it on a freshly generated validation set.
 *
 * Network layout (as wired below):
 *   dense1 (2 -> 16)  -> ReLU    -> dropout (rate argument 0.1)
 *   dense2 (16 -> 16) -> Softmax
 *   dense3 (16 -> 16) -> Sigmoid
 *   dense4 (16 -> number_of_classes) -> combined Softmax / categorical cross-entropy
 *
 * Every dense layer is constructed with L2 regularization factors of 5e-4 on
 * weights and biases and L1 factors of 0. Parameters are updated with Adam.
 *
 * @param argc Unused.
 * @param argv Unused.
 * @return 0 after the validation result has been printed.
 */
int main(int argc, char const *argv[])
{
    // Command-line arguments are not used by this example.
    (void)argc;
    (void)argv;

    uint64_t number_of_classes = 3;
    uint64_t number_of_samples = 150;
    uint64_t number_of_epochs = 500;
    uint64_t number_of_test_samples = 100;

    utils::Mf X;
    utils::Mf X_test;
    utils::Matrix<int64_t> y;
    utils::Matrix<int64_t> y_test;
    float data_loss = 0.0f;
    float regularization_loss = 0.0f;
    float loss = 0.0f;
    float accuracy = 0.0f;

    utils::Vector<uint64_t> class_targets;
    utils::Vector<uint64_t> predections;


    // Create the training dataset (X: samples, y: labels).
    // NOTE(review): "spital" is presumably "spiral"; the name comes from the library.
    neural_networks::create_spital_data<float, int64_t>(number_of_samples, number_of_classes, X, y);
    //neural_networks::create_vertical_data<float, int64_t>(number_of_samples, number_of_classes, X, y);

    // First dense layer: 2 input features, 16 output values
    neural_networks::Dense_Layer<float> dense1(
                                            2, 16,  // input/output
                                            0.0f,  // weight L1
                                            5e-4f,  // weight L2
                                            0.0f,   // bias L1
                                            5e-4f    // bias L2
                                            );

    // ReLU activation and dropout applied to the output of dense1
    neural_networks::Activation_ReLU<float> activation1;
    neural_networks::Dropout_Layer<float> dropout1(0.1);



    // Second dense layer: 16 inputs (the outputs of the previous layer)
    // and 16 output values
    neural_networks::Dense_Layer<float> dense2(
                                            16, 16,  // input/output
                                            0.0f,  // weight L1
                                            5e-4f,  // weight L2
                                            0.0f,   // bias L1
                                            5e-4f    // bias L2
                                            );
    neural_networks::Activation_Softmax<float> activation2;


    // Third dense layer: 16 inputs (the outputs of the previous layer)
    // and 16 output values
    neural_networks::Dense_Layer<float> dense3(
                                            16, 16,  // input/output
                                            0.0f,  // weight L1
                                            5e-4f,  // weight L2
                                            0.0f,   // bias L1
                                            5e-4f    // bias L2
                                            );
    neural_networks::Activation_Sigmoid<float> activation3;

    // Output dense layer: 16 inputs, one output value per class
    neural_networks::Dense_Layer<float> dense4(
                                            16, number_of_classes,  // input/output
                                            0.0f,  // weight L1
                                            5e-4f,  // weight L2
                                            0.0f,   // bias L1
                                            5e-4f    // bias L2
                                            );

    // Create a Softmax classifier's combined loss and activation
    neural_networks::Activation_Softmax_Loss_CategoricalCrossentropy<float, int64_t> loss_activation;

    // Create optimizer
    //neural_networks::Optimizer_SGD<float> optimizer(1, 1e-3, 0.5);
    //neural_networks::Optimizer_Adagrad<float> optimizer(1, 1e-3, 1e-6);
    //neural_networks::Optimizer_RMSprop<float> optimizer(1, 1e-3, 1e-6, 0.9);
    neural_networks::Optimizer_Adam<float> optimizer(
                                                    0.05,      // Learning-rate
                                                    5e-5,   // Learning-rate decay
                                                    1e-6,   // epsilon
                                                    0.9,    // beta 1
                                                    0.999   // beta 2
                                                    );



    // Train in loop. The bound is number_of_epochs + 1 so that the last epoch
    // index equals number_of_epochs and is reported by the logging below.
    for (uint64_t epoch = 0; epoch < number_of_epochs+1; ++epoch){

        // Forward pass of the training data through the network
        dense1.forward(X);
        activation1.forward(dense1.outputs);
        dropout1.forward(activation1.outputs);

        dense2.forward(dropout1.outputs);
        activation2.forward(dense2.outputs);

        dense3.forward(activation2.outputs);
        activation3.forward(dense3.outputs);

        dense4.forward(activation3.outputs);

        // Forward pass through the combined activation/loss function:
        // takes the output of the last dense layer (dense4) and returns the loss
        data_loss = loss_activation.forward(dense4.outputs, y);

        // Calculate regularization penalty over all dense layers
        regularization_loss = loss_activation.loss.regularization_loss(dense1) + 
                            loss_activation.loss.regularization_loss(dense2) + 
                            loss_activation.loss.regularization_loss(dense3) + 
                            loss_activation.loss.regularization_loss(dense4);

        loss = data_loss + regularization_loss;

        // Calculate accuracy from the outputs of loss_activation and the targets
        //predections = numerics::matargmax_row <int64_t, float>(loss_activation.outputs);
        predections = numerics::argmax_rowwise(loss_activation.outputs);

        // More than one label column is reduced with a row-wise argmax;
        // a single column is used directly as the class targets.
        if (y.cols() > 1){
            class_targets = numerics::argmax_rowwise(y);
        }else{
            class_targets = utils::veccast <uint64_t, int64_t> (y.get_col(0));
        }


        accuracy = numerics::mean( utils::veccast<float, uint64_t> (numerics::equal_elementwise_serial(predections, class_targets)));


        // Report progress every 100 epochs
        if (!(epoch%100)){
            std::cout << "epoch: " << epoch;
            std::cout << ", acc: " << accuracy;
            std::cout << ", loss: " << loss;
            std::cout << ", data_loss: " << data_loss;
            std::cout << ", regularization_loss: " << regularization_loss;
            std::cout << ", lr: " << optimizer.current_learning_rate;
            std::cout << std::endl;
        }

        // Backward pass, in reverse order of the forward pass
        loss_activation.backward(loss_activation.outputs, y);
        dense4.backward(loss_activation.dinputs);

        activation3.backward(dense4.dinputs);
        dense3.backward(activation3.dinputs);

        activation2.backward(dense3.dinputs);
        dense2.backward(activation2.dinputs);

        dropout1.backward(dense2.dinputs);
        activation1.backward(dropout1.dinputs);
        dense1.backward(activation1.dinputs);


        // Update weights and biases
        optimizer.pre_update_params();
        optimizer.update_params(dense1);
        optimizer.update_params(dense2);
        optimizer.update_params(dense3);
        optimizer.update_params(dense4);
        optimizer.post_update_params();

    }

    // Validate the model

    // Create the validation dataset
    neural_networks::create_spital_data<float, int64_t>(number_of_test_samples, number_of_classes, X_test, y_test);

    // Forward pass of the validation data through the network.
    // The dropout layer is intentionally bypassed here: activation1 feeds dense2 directly.
    dense1.forward(X_test);
    activation1.forward(dense1.outputs);
    //dropout1.forward(activation1.outputs);

    dense2.forward(activation1.outputs);
    activation2.forward(dense2.outputs);

    dense3.forward(activation2.outputs);
    activation3.forward(dense3.outputs);

    dense4.forward(activation3.outputs);

    // Forward pass through the combined activation/loss function:
    // takes the output of the last dense layer (dense4) and returns the loss
    data_loss = loss_activation.forward(dense4.outputs, y_test);

    // Calculate regularization penalty over all dense layers
    regularization_loss = loss_activation.loss.regularization_loss(dense1) + 
                        loss_activation.loss.regularization_loss(dense2) + 
                        loss_activation.loss.regularization_loss(dense3) + 
                        loss_activation.loss.regularization_loss(dense4);

    loss = data_loss + regularization_loss;

    // Calculate accuracy from the outputs of loss_activation and the targets
    predections = numerics::argmax_rowwise(loss_activation.outputs);




    // Decide how to decode the labels from the shape of the validation labels
    // themselves (y_test), not from the training labels (y).
    if (y_test.cols() > 1){
        class_targets = numerics::argmax_rowwise(y_test);
    }else{
        class_targets = utils::veccast <uint64_t, int64_t> (y_test.get_col(0));
    }


    accuracy = numerics::mean( utils::veccast<float, uint64_t> (numerics::equal_elementwise_serial(predections, class_targets)));


    std::cout << "validation, acc: " << accuracy << ", loss: " << loss << std::endl;

    return 0;
}