// Copyright (C) 2009  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.

#include "tester.h"
#include <dlib/svm.h>
#include <dlib/rand.h>
#include <dlib/string.h>
#include <vector>
#include <sstream>
#include <ctime>

namespace  
{
    // Pull in the namespaces used throughout this file so that their names
    // can be written unqualified in the tests below.
    using namespace test;
    using namespace dlib;
    using namespace std;

    // Logger for this test.  perform_test() writes the random seed to it at
    // LINFO level so a run can be reproduced from its log output.
    dlib::logger dlog("test.discriminant_pca");

    // NOTE(review): presumably this using-declaration is here so that an
    // unqualified equal() resolves to dlib's matrix comparison rather than
    // being ambiguous with std::equal -- confirm before removing or moving it.
    using dlib::equal;

    class discriminant_pca_tester : public tester
    {
        /*!
            WHAT THIS OBJECT REPRESENTS
                This object represents a unit test.  When it is constructed
                it adds itself into the testing framework.
        !*/
    public:
        discriminant_pca_tester (
        ) :
            tester (
                "test_discriminant_pca",       // the command line argument name for this test
                "Run tests on the discriminant_pca object.", // the command line argument description
                0                     // the number of command line arguments for this test
            )
        {
            thetime = 1407805946;// time(0);
        }

        time_t thetime;
        dlib::rand rnd;

        template <typename dpca_type>
        void test1()
        {

            dpca_type dpca, dpca2, dpca3;

            DLIB_TEST(dpca.in_vector_size() == 0);
            DLIB_TEST(dpca.between_class_weight() == 1);
            DLIB_TEST(dpca.within_class_weight() == 1);

            // generate a bunch of 4 dimensional vectors and compute the normal PCA transformation matrix
            // and just make sure it is a unitary matrix as it should be.
            for (int i = 0; i < 5000; ++i)
            {
                dpca.add_to_total_variance(randm(4,1,rnd));
                DLIB_TEST(dpca.in_vector_size() == 4);
            }


            matrix<double> mat = dpca.dpca_matrix(1);

            DLIB_TEST(equal(mat*trans(mat), identity_matrix<double>(4)));

            mat = dpca.dpca_matrix(0.9);
            DLIB_TEST(equal(mat*trans(mat), identity_matrix<double>(mat.nr())));

            matrix<double> eig;
            dpca.dpca_matrix(mat, eig, 1);
            DLIB_TEST(equal(mat*trans(mat), identity_matrix<double>(4)));
            // check that all eigen values are grater than 0
            DLIB_TEST(min(eig > 0) == 1);
            DLIB_TEST(eig.size() == mat.nr());
            DLIB_TEST(is_col_vector(eig));
            // check that the eigenvalues are sorted
            double last = eig(0);
            for (long i = 1; i < eig.size(); ++i)
            {
                DLIB_TEST(last >= eig(i));
            }

            {
                matrix<double> mat = dpca.dpca_matrix_of_size(4);
                DLIB_TEST(equal(mat*trans(mat), identity_matrix<double>(4)));
            }
            {
                matrix<double> mat = dpca.dpca_matrix_of_size(3);
                DLIB_TEST(equal(mat*trans(mat), identity_matrix<double>(3)));
            }


            dpca.set_within_class_weight(5);
            dpca.set_between_class_weight(6);

            DLIB_TEST(dpca.in_vector_size() == 4);
            DLIB_TEST(dpca.within_class_weight() == 5);
            DLIB_TEST(dpca.between_class_weight() == 6);


            ostringstream sout;
            serialize(dpca, sout);
            istringstream sin(sout.str());
            deserialize(dpca2, sin);

            // now make sure the serialization worked
            DLIB_TEST(dpca.in_vector_size() == 4);
            DLIB_TEST(dpca.within_class_weight() == 5);
            DLIB_TEST(dpca.between_class_weight() == 6);
            DLIB_TEST(dpca2.in_vector_size() == 4);
            DLIB_TEST(dpca2.within_class_weight() == 5);
            DLIB_TEST(dpca2.between_class_weight() == 6);
            DLIB_TEST(equal(dpca.dpca_matrix(), dpca2.dpca_matrix(), 1e-10));
            DLIB_TEST(equal(mat, dpca2.dpca_matrix(1), 1e-10));
            DLIB_TEST(equal(dpca.dpca_matrix(1), mat, 1e-10));

            // now test swap
            dpca2.swap(dpca3);
            DLIB_TEST(dpca2.in_vector_size() == 0);
            DLIB_TEST(dpca2.between_class_weight() == 1);
            DLIB_TEST(dpca2.within_class_weight() == 1);

            DLIB_TEST(dpca3.in_vector_size() == 4);
            DLIB_TEST(dpca3.within_class_weight() == 5);
            DLIB_TEST(dpca3.between_class_weight() == 6);
            DLIB_TEST(equal(mat, dpca3.dpca_matrix(1), 1e-10));
            DLIB_TEST((dpca3 + dpca3).in_vector_size() == 4);
            DLIB_TEST((dpca3 + dpca3).within_class_weight() == 5);
            DLIB_TEST((dpca3 + dpca3).between_class_weight() == 6);

            dpca.clear();

            DLIB_TEST(dpca.in_vector_size() == 0);
            DLIB_TEST(dpca.between_class_weight() == 1);
            DLIB_TEST(dpca.within_class_weight() == 1);
        }

        template <typename dpca_type>
        void test2()
        {
            dpca_type dpca, dpca2, dpca3;

            typename dpca_type::column_matrix samp1(4), samp2(4);

            for (int i = 0; i < 5000; ++i)
            {
                dpca.add_to_total_variance(randm(4,1,rnd));
                DLIB_TEST(dpca.in_vector_size() == 4);

                // do this to subtract out the variance along the 3rd axis 
                samp1 = 0,0,0,0;
                samp2 = 0,0,1,0;
                dpca.add_to_within_class_variance(samp1, samp2);
            }

            matrix<double> mat;

            dpca.set_within_class_weight(0);
            mat = dpca.dpca_matrix(1);
            DLIB_TEST(equal(mat*trans(mat), identity_matrix<double>(4)));
            DLIB_TEST(dpca.dpca_matrix(1).nr() == 4);
            dpca.set_within_class_weight(1000);
            DLIB_TEST(dpca.dpca_matrix(1).nr() == 3);

            // the 3rd column of the transformation matrix should be all zero since
            // we killed all the variation long the 3rd axis
            DLIB_TEST(sum(abs(colm(dpca.dpca_matrix(1),2))) < 1e-5);

            mat = dpca.dpca_matrix(1);
            DLIB_TEST(equal(mat*trans(mat), identity_matrix<double>(3)));


        }

        template <typename dpca_type>
        void test3()
        {
            dpca_type dpca, dpca2, dpca3;

            typename dpca_type::column_matrix samp1(4), samp2(4);

            for (int i = 0; i < 5000; ++i)
            {
                dpca.add_to_total_variance(randm(4,1,rnd));
                DLIB_TEST(dpca.in_vector_size() == 4);

                // do this to subtract out the variance along the 3rd axis 
                samp1 = 0,0,0,0;
                samp2 = 0,0,1,0;
                dpca.add_to_within_class_variance(samp1, samp2);

                // do this to subtract out the variance along the 1st axis 
                samp1 = 0,0,0,0;
                samp2 = 1,0,0,0;
                dpca.add_to_within_class_variance(samp1, samp2);
            }

            matrix<double> mat;

            dpca.set_within_class_weight(0);
            mat = dpca.dpca_matrix(1);
            DLIB_TEST(equal(mat*trans(mat), identity_matrix<double>(4)));
            DLIB_TEST(dpca.dpca_matrix(1).nr() == 4);
            dpca.set_within_class_weight(10000);
            DLIB_TEST(dpca.dpca_matrix(1).nr() == 2);

            // the 1st and 3rd columns of the transformation matrix should be all zero since
            // we killed all the variation long the 1st and 3rd axes
            DLIB_TEST(sum(abs(colm(dpca.dpca_matrix(1),2))) < 1e-5);
            DLIB_TEST(sum(abs(colm(dpca.dpca_matrix(1),0))) < 1e-5);

            mat = dpca.dpca_matrix(1);
            DLIB_TEST(equal(mat*trans(mat), identity_matrix<double>(2)));


        }

        template <typename dpca_type>
        void test4()
        {
            dpca_type dpca, dpca2, dpca3;

            dpca_type add_dpca1, add_dpca2, add_dpca3, add_dpca4, sum_dpca;

            typename dpca_type::column_matrix samp1(4), samp2(4), samp;

            for (int i = 0; i < 5000; ++i)
            {
                samp = randm(4,1,rnd);
                dpca.add_to_total_variance(samp);
                add_dpca4.add_to_total_variance(samp);
                DLIB_TEST(dpca.in_vector_size() == 4);

                // do this to subtract out the variance along the 3rd axis 
                samp1 = 0,0,0,0;
                samp2 = 0,0,1,0;
                dpca.add_to_within_class_variance(samp1, samp2);
                add_dpca1.add_to_within_class_variance(samp1, samp2);

                // do this to subtract out the variance along the 1st axis 
                samp1 = 0,0,0,0;
                samp2 = 1,0,0,0;
                dpca.add_to_within_class_variance(samp1, samp2);
                add_dpca2.add_to_within_class_variance(samp1, samp2);

                // do this to add the variance along the 3rd axis back in
                samp1 = 0,0,0,0;
                samp2 = 0,0,1,0;
                dpca.add_to_between_class_variance(samp1, samp2);
                add_dpca3.add_to_between_class_variance(samp1, samp2);
            }

            matrix<double> mat, mat2;

            sum_dpca += dpca_type() + dpca_type() + add_dpca1 + dpca_type() + add_dpca2 + add_dpca3 + add_dpca4;
            dpca.set_within_class_weight(0);
            dpca.set_between_class_weight(0);
            sum_dpca.set_within_class_weight(0);
            sum_dpca.set_between_class_weight(0);
            mat = dpca.dpca_matrix(1);
            DLIB_TEST(equal(mat, sum_dpca.dpca_matrix(1), 1e-10));
            DLIB_TEST(equal(mat*trans(mat), identity_matrix<double>(4)));
            DLIB_TEST(dpca.dpca_matrix(1).nr() == 4);
            dpca.set_within_class_weight(10000);
            sum_dpca.set_within_class_weight(10000);
            DLIB_TEST(dpca.dpca_matrix(1).nr() == 2);

            // the 1st and 3rd columns of the transformation matrix should be all zero since
            // we killed all the variation long the 1st and 3rd axes
            DLIB_TEST(sum(abs(colm(dpca.dpca_matrix(1),2))) < 1e-4);
            DLIB_TEST(sum(abs(colm(dpca.dpca_matrix(1),0))) < 1e-4);

            mat = dpca.dpca_matrix(1);
            DLIB_TEST(equal(mat*trans(mat), identity_matrix<double>(2)));
            DLIB_TEST_MSG(equal(mat, mat2=sum_dpca.dpca_matrix(1), 1e-9), max(abs(mat - mat2)));


            // now add the variance back in using the between class weight
            dpca.set_within_class_weight(0);
            dpca.set_between_class_weight(1);
            mat = dpca.dpca_matrix(1);
            DLIB_TEST(equal(mat*trans(mat), identity_matrix<double>(4)));
            DLIB_TEST(dpca.dpca_matrix(1).nr() == 4);
            dpca.set_within_class_weight (10000);
            dpca.set_between_class_weight(100000);
            sum_dpca.set_within_class_weight (10000);
            sum_dpca.set_between_class_weight(100000);
            DLIB_TEST(dpca.dpca_matrix(1).nr() == 3);

            // the first column should be all zeros
            DLIB_TEST(sum(abs(colm(dpca.dpca_matrix(1),0))) < 1e-5);

            mat = dpca.dpca_matrix(1);
            DLIB_TEST(equal(mat*trans(mat), identity_matrix<double>(3)));
            DLIB_TEST(equal(mat, sum_dpca.dpca_matrix(1)));


        }

        template <typename dpca_type>
        void test5()
        {
            dpca_type dpca, dpca2;
            typename dpca_type::column_matrix samp1(4), samp2(4);

            samp1 = 0,0,0,0;
            samp2 = 0,0,1,0;

            for (int i = 0; i < 5000; ++i)
            {
                dpca.add_to_between_class_variance(samp1, samp2);
                dpca2.add_to_total_variance(samp1);
                dpca2.add_to_total_variance(samp2);
            }

            matrix<double> mat, eig;
            dpca.dpca_matrix(mat, eig, 1);

            // make sure the eigenvalues come out the way they should for this simple data set
            DLIB_TEST(eig.size() == 1);
            DLIB_TEST_MSG(abs(eig(0) - 1) < 1e-10, abs(eig(0) - 1));

            dpca2.dpca_matrix(mat, eig, 1);

            // make sure the eigenvalues come out the way they should for this simple data set
            DLIB_TEST(eig.size() == 1);
            DLIB_TEST(abs(eig(0) - 0.25) < 1e-10);

        }

        void perform_test (
        )
        {
            ++thetime;
            typedef matrix<double,0,1> sample_type;
            typedef discriminant_pca<sample_type> dpca_type;

            dlog << LINFO << "time seed: " << thetime;
            rnd.set_seed(cast_to_string(thetime));

            test5<dpca_type>();

            for (int i = 0; i < 10; ++i)
            {
                print_spinner();
                test1<dpca_type>();
                print_spinner();
                test2<dpca_type>();
                print_spinner();
                test3<dpca_type>();
                print_spinner();
                test4<dpca_type>();
            }
        }
    };

    // Create an instance of the tester object.  Doing this causes this test
    // to be automatically inserted into the testing framework whenever this cpp
    // file is linked into the project.  Note that since we are inside an unnamed
    // namespace we won't get any linker errors about the symbol `a` being defined
    // multiple times, even if other test files declare an object with the same name.
    discriminant_pca_tester a;

}