// Copyright (C) 2007  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_MLp_KERNEL_1_
#define DLIB_MLp_KERNEL_1_

#include "../algs.h"
#include "../serialize.h"
#include "../matrix.h"
#include "../rand.h"
#include "mlp_kernel_abstract.h"
#include <ctime>
#include <sstream>

namespace dlib
{

    class mlp_kernel_1 : noncopyable
    {
        /*!
            INITIAL VALUE
                The network is initialized with random weights and all momentum 
                terms set to zero.

            CONVENTION
                - input_layer_nodes() == input_nodes
                - first_hidden_layer_nodes() == first_hidden_nodes
                - second_hidden_layer_nodes() == second_hidden_nodes
                - output_layer_nodes() == output_nodes
                - get_alpha() == alpha
                - get_momentum() == momentum


                - if (second_hidden_nodes == 0) then
                    - for all i and j:
                        - w1(i,j) == the weight on the link from node i in the first hidden layer 
                          to input node j
                        - w3(i,j) == the weight on the link from node i in the output layer 
                          to first hidden layer node j
                    - for all i and j:
                        - w1m(i,j) == the momentum term for w1(i,j) from the previous update 
                        - w3m(i,j) == the momentum term for w3(i,j) from the previous update 
                - else
                    - for all i and j:
                        - w1(i,j) == the weight on the link from node i in the first hidden layer 
                          to input node j
                        - w2(i,j) == the weight on the link from node i in the second hidden layer 
                          to first hidden layer node j
                        - w3(i,j) == the weight on the link from node i in the output layer 
                          to second hidden layer node j
                    - for all i and j:
                        - w1m(i,j) == the momentum term for w1(i,j) from the previous update 
                        - w2m(i,j) == the momentum term for w2(i,j) from the previous update 
                        - w3m(i,j) == the momentum term for w3(i,j) from the previous update 
        !*/
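
        /*
            Put another way, a forward pass through the two hidden layer network
            computes:
                output = sigmoid(w3*sigmoid(w2*sigmoid(w1*z)))
            where z is the input vector with a -1 bias element appended, and the
            last element of each hidden layer's output is overwritten with a -1
            bias before it is fed to the next layer.  (This is just a restatement
            of what operator() and train() below compute.)
        */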

    public:

        mlp_kernel_1 (
            long nodes_in_input_layer,
            long nodes_in_first_hidden_layer, 
            long nodes_in_second_hidden_layer = 0, 
            long nodes_in_output_layer = 1,
            double alpha_ = 0.1,
            double momentum_ = 0.8
        ) :
            input_nodes(nodes_in_input_layer),
            first_hidden_nodes(nodes_in_first_hidden_layer),
            second_hidden_nodes(nodes_in_second_hidden_layer),
            output_nodes(nodes_in_output_layer),
            alpha(alpha_),
            momentum(momentum_)
        {

            // seed the random number generator
            std::ostringstream sout;
            sout << time(0);
            rand_nums.set_seed(sout.str());
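            // (note that since the seed comes from time(0), two networks 
            // constructed within the same second will start from identical 
            // initial weights)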

            // note that the +1 on these sizes makes room for the -1 bias terms 
            // appended to the input and to each hidden layer's output
            w1.set_size(first_hidden_nodes+1, input_nodes+1);
            w1m.set_size(first_hidden_nodes+1, input_nodes+1);
            z.set_size(input_nodes+1,1);

            if (second_hidden_nodes != 0)
            {
                w2.set_size(second_hidden_nodes+1, first_hidden_nodes+1);
                w3.set_size(output_nodes, second_hidden_nodes+1);

                w2m.set_size(second_hidden_nodes+1, first_hidden_nodes+1);
                w3m.set_size(output_nodes, second_hidden_nodes+1);
            }
            else
            {
                w3.set_size(output_nodes, first_hidden_nodes+1);

                w3m.set_size(output_nodes, first_hidden_nodes+1);
            }

            reset();
        }

        virtual ~mlp_kernel_1 (
        ) {}

        void reset (
        ) 
        {
            // randomize the weights for the first layer
            for (long r = 0; r < w1.nr(); ++r)
                for (long c = 0; c < w1.nc(); ++c)
                    w1(r,c) = rand_nums.get_random_double();

            // randomize the weights for the second layer (this is a no-op when 
            // there is no second hidden layer since w2 is empty in that case)
            for (long r = 0; r < w2.nr(); ++r)
                for (long c = 0; c < w2.nc(); ++c)
                    w2(r,c) = rand_nums.get_random_double();

            // randomize the weights for the third layer
            for (long r = 0; r < w3.nr(); ++r)
                for (long c = 0; c < w3.nc(); ++c)
                    w3(r,c) = rand_nums.get_random_double();

            // zero all the momentum terms
            set_all_elements(w1m,0);
            set_all_elements(w2m,0);
            set_all_elements(w3m,0);
        }

        long input_layer_nodes (
        ) const { return input_nodes; }

        long first_hidden_layer_nodes (
        ) const { return first_hidden_nodes; }

        long second_hidden_layer_nodes (
        ) const { return second_hidden_nodes; }

        long output_layer_nodes (
        ) const { return output_nodes; }

        double get_alpha (
        ) const { return alpha; }

        double get_momentum (
        ) const { return momentum; }

        template <typename EXP>
        const matrix<double> operator() (
            const matrix_exp<EXP>& in 
        ) const
        {
            for (long i = 0; i < in.nr(); ++i)
                z(i) = in(i);
            // insert the bias 
            z(z.nr()-1) = -1;

            tmp1 = sigmoid(w1*z);
            // insert the bias 
            tmp1(tmp1.nr()-1) = -1;

            if (second_hidden_nodes == 0)
            {
                return sigmoid(w3*tmp1);
            }
            else
            {
                tmp2 = sigmoid(w2*tmp1);
                // insert the bias 
                tmp2(tmp2.nr()-1) = -1;

                return sigmoid(w3*tmp2);
            }
        }

        template <typename EXP1, typename EXP2>
        void train (
            const matrix_exp<EXP1>& example_in,
            const matrix_exp<EXP2>& example_out 
        )
        {
            for (long i = 0; i < example_in.nr(); ++i)
                z(i) = example_in(i);
            // insert the bias 
            z(z.nr()-1) = -1;

            tmp1 = sigmoid(w1*z);
            // insert the bias 
            tmp1(tmp1.nr()-1) = -1;


            if (second_hidden_nodes == 0)
            {
                o = sigmoid(w3*tmp1);

                // now compute the errors and propagate them backwards through the network
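                // (each sigmoid unit with output o has derivative o*(1-o), so each 
                // e term below is a layer's error scaled by that layer's sigmoid 
                // derivative)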
                e3 = pointwise_multiply(example_out-o, uniform_matrix<double>(output_nodes,1,1.0)-o, o);
                e1 = pointwise_multiply(tmp1, uniform_matrix<double>(first_hidden_nodes+1,1,1.0) - tmp1, trans(w3)*e3 );

                // compute the new weight updates
                w3m = alpha * e3*trans(tmp1) + w3m*momentum;
                w1m = alpha * e1*trans(z)    + w1m*momentum;

                // now update the weights
                w1 += w1m;
                w3 += w3m;
            }
            else
            {
                tmp2 = sigmoid(w2*tmp1);
                // insert the bias 
                tmp2(tmp2.nr()-1) = -1;

                o = sigmoid(w3*tmp2);


                // now compute the errors and propagate them backwards through the network
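                // (same sigmoid derivative logic as the single hidden layer case 
                // above, just with one more layer for the errors to pass through)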
                e3 = pointwise_multiply(example_out-o, uniform_matrix<double>(output_nodes,1,1.0)-o, o);
                e2 = pointwise_multiply(tmp2, uniform_matrix<double>(second_hidden_nodes+1,1,1.0) - tmp2, trans(w3)*e3 );
                e1 = pointwise_multiply(tmp1, uniform_matrix<double>(first_hidden_nodes+1,1,1.0) - tmp1, trans(w2)*e2 );

                // compute the new weight updates
                w3m = alpha * e3*trans(tmp2) + w3m*momentum;
                w2m = alpha * e2*trans(tmp1) + w2m*momentum;
                w1m = alpha * e1*trans(z)    + w1m*momentum;

                // now update the weights
                w1 += w1m;
                w2 += w2m;
                w3 += w3m;
            }
        }

        template <typename EXP>
        void train (
            const matrix_exp<EXP>& example_in,
            double example_out
        )
        {
            matrix<double,1,1> e_out;
            e_out(0) = example_out;
            train(example_in,e_out);
        }

        double get_average_change (
        ) const
        {
            // sum up all the weight changes
            double delta = sum(abs(w1m)) + sum(abs(w2m)) + sum(abs(w3m));

            // divide by the number of weights
            delta /=  w1m.nr()*w1m.nc() + 
                w2m.nr()*w2m.nc() + 
                w3m.nr()*w3m.nc();

            return delta;
        }
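
        /*
            A minimal usage sketch (the sample values, target value, and 1e-5 
            tolerance below are hypothetical, not part of this class's interface):

                mlp_kernel_1 net(2, 5);     // 2 inputs, 5 hidden nodes, 1 output
                matrix<double,2,1> sample;
                sample = 0.1, 0.9;
                do
                {
                    net.train(sample, 1.0); // train toward a target output of 1.0
                } while (net.get_average_change() > 1e-5);
                matrix<double> result = net(sample);
        */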

        void swap (
            mlp_kernel_1& item
        )
        {
            exchange(input_nodes, item.input_nodes);
            exchange(first_hidden_nodes, item.first_hidden_nodes);
            exchange(second_hidden_nodes, item.second_hidden_nodes);
            exchange(output_nodes, item.output_nodes);
            exchange(alpha, item.alpha);
            exchange(momentum, item.momentum);

            w1.swap(item.w1);
            w2.swap(item.w2);
            w3.swap(item.w3);

            w1m.swap(item.w1m);
            w2m.swap(item.w2m);
            w3m.swap(item.w3m);

            // even swap the temporary matrices because this may ultimately result in 
            // fewer calls to new and delete.
            e1.swap(item.e1);
            e2.swap(item.e2);
            e3.swap(item.e3);
            z.swap(item.z);
            tmp1.swap(item.tmp1);
            tmp2.swap(item.tmp2);
            o.swap(item.o);
        }


        friend void serialize (
            const mlp_kernel_1& item, 
            std::ostream& out
        );

        friend void deserialize (
            mlp_kernel_1& item, 
            std::istream& in
        );

    private:

        long input_nodes;
        long first_hidden_nodes;
        long second_hidden_nodes;
        long output_nodes;
        double alpha;
        double momentum;

        matrix<double> w1;
        matrix<double> w2;
        matrix<double> w3;

        matrix<double> w1m;
        matrix<double> w2m;
        matrix<double> w3m;


        rand rand_nums;

        // temporary storage
        mutable matrix<double> e1, e2, e3;
        mutable matrix<double> z, tmp1, tmp2, o;
    };   

    inline void swap (
        mlp_kernel_1& a, 
        mlp_kernel_1& b 
    ) { a.swap(b); }   

// ----------------------------------------------------------------------------------------

    inline void serialize (
        const mlp_kernel_1& item, 
        std::ostream& out
    )   
    {
        try
        {
            serialize(item.input_nodes, out);
            serialize(item.first_hidden_nodes, out);
            serialize(item.second_hidden_nodes, out);
            serialize(item.output_nodes, out);
            serialize(item.alpha, out);
            serialize(item.momentum, out);

            serialize(item.w1, out);
            serialize(item.w2, out);
            serialize(item.w3, out);

            serialize(item.w1m, out);
            serialize(item.w2m, out);
            serialize(item.w3m, out);
        }
        catch (serialization_error& e)
        { 
            throw serialization_error(e.info + "\n   while serializing object of type mlp_kernel_1"); 
        }
    }

    inline void deserialize (
        mlp_kernel_1& item, 
        std::istream& in
    )   
    {
        try
        {
            deserialize(item.input_nodes, in);
            deserialize(item.first_hidden_nodes, in);
            deserialize(item.second_hidden_nodes, in);
            deserialize(item.output_nodes, in);
            deserialize(item.alpha, in);
            deserialize(item.momentum, in);

            deserialize(item.w1, in);
            deserialize(item.w2, in);
            deserialize(item.w3, in);

            deserialize(item.w1m, in);
            deserialize(item.w2m, in);
            deserialize(item.w3m, in);

            // z is just scratch space and its contents aren't serialized, so 
            // set its size here to match the deserialized network
            item.z.set_size(item.input_nodes+1,1);
        }
        catch (serialization_error& e)
        { 
            // give item a reasonable value since the deserialization failed
            mlp_kernel_1(1,1).swap(item);
            throw serialization_error(e.info + "\n   while deserializing object of type mlp_kernel_1"); 
        }
    }
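
    /*
        A minimal save/load sketch (the "net.dat" file name is hypothetical and 
        <fstream> would need to be included):

            mlp_kernel_1 net(2, 5);
            {
                std::ofstream fout("net.dat", std::ios::binary);
                serialize(net, fout);
            }
            mlp_kernel_1 net2(1, 1);
            {
                std::ifstream fin("net.dat", std::ios::binary);
                deserialize(net2, fin);
            }
    */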

// ----------------------------------------------------------------------------------------

}

#endif // DLIB_MLp_KERNEL_1_