// Copyright (C) 2011  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.

#include <sstream>
#include <string>
#include <cstdlib>
#include <ctime>
#include <dlib/svm.h>
#include <dlib/matrix.h>

#include "tester.h"

namespace  
{
    using namespace test;
    using namespace dlib;
    using namespace std;

    // Logger for this test file; its messages are tagged "test.kmeans".
    logger dlog("test.kmeans");

    // Random number generator used by run_test() to draw the noise that is
    // added to each seed center when generating samples.
    dlib::rand rnd;

    // Assigns every sample to its nearest center (by Euclidean distance) and
    // checks, via DLIB_TEST, that each center received exactly
    // expected_per_center samples.
    template <typename sample_type>
    void check_cluster_sizes(
        const std::vector<sample_type>& samples,
        const std::vector<sample_type>& centers,
        const int expected_per_center
    )
    {
        // With no centers there is nothing to assign to; indexing hits[0]
        // below would be out of bounds, so only an empty sample set is valid.
        if (centers.empty())
        {
            DLIB_TEST(samples.empty());
            return;
        }

        std::vector<int> hits(centers.size(),0);
        for (unsigned long i = 0; i < samples.size(); ++i)
        {
            unsigned long best_idx = 0;
            double best_dist = 1e100;
            for (unsigned long j = 0; j < centers.size(); ++j)
            {
                // Compute the distance once rather than twice per candidate.
                const double dist = length(samples[i] - centers[j]);
                if (dist < best_dist)
                {
                    best_dist = dist;
                    best_idx = j;
                }
            }
            hits[best_idx]++;
        }

        for (unsigned long i = 0; i < hits.size(); ++i)
        {
            DLIB_TEST(hits[i] == expected_per_center);
        }
    }

    // Generates samples_per_center noisy samples around each of seed_centers,
    // clusters them with both find_clusters_using_kmeans() and
    // find_clusters_using_angular_kmeans(), and checks that each resulting
    // center is the nearest center for exactly samples_per_center samples.
    template <typename sample_type>
    void run_test(
        const std::vector<sample_type>& seed_centers
    )
    {
        print_spinner();

        // Number of samples generated around each seed center.
        const int samples_per_center = 250;

        std::vector<sample_type> samples;
        samples.reserve(seed_centers.size()*samples_per_center);

        sample_type samp;
        for (unsigned long j = 0; j < seed_centers.size(); ++j)
        {
            for (int i = 0; i < samples_per_center; ++i)
            {
                // Subtracting 0.5 centers the noise on zero, assuming randm()
                // draws its elements from [0,1].
                samp = randm(seed_centers[0].size(),1,rnd) - 0.5;
                samples.push_back(samp + seed_centers[j]);
            }
        }

        randomize_samples(samples);

        // Plain (Euclidean) k-means.
        {
            std::vector<sample_type> centers;
            pick_initial_centers(seed_centers.size(), centers, samples, linear_kernel<sample_type>());

            find_clusters_using_kmeans(samples, centers);

            DLIB_TEST(centers.size() == seed_centers.size());
            check_cluster_sizes(samples, centers, samples_per_center);
        }
        // Angular k-means. The resulting centers are still scored by
        // Euclidean nearest-center assignment, exactly as above.
        {
            std::vector<sample_type> centers;
            pick_initial_centers(seed_centers.size(), centers, samples, linear_kernel<sample_type>());

            find_clusters_using_angular_kmeans(samples, centers);

            DLIB_TEST(centers.size() == seed_centers.size());
            check_cluster_sizes(samples, centers, samples_per_center);
        }
    }


    // Test case that exercises find_clusters_using_kmeans() and
    // find_clusters_using_angular_kmeans() (via run_test()) on several
    // sample types.
    class test_kmeans : public tester
    {
    public:
        test_kmeans (
        ) :
            tester ("test_kmeans",
                    "Runs tests on the find_clusters_using_kmeans() function.")
        {}

        // Runs run_test() three times:
        //  - four well-separated 2D seed centers with dlib::vector<double,2>
        //  - the same seed centers with dlib::vector<float,2>
        //  - three 3D seed centers stored in dlib::matrix<double,3,1>
        void perform_test (
        )
        {
            {
                dlog << LINFO << "test dlib::vector<double,2>";
                typedef dlib::vector<double,2> sample_type;
                std::vector<sample_type> seed_centers;
                seed_centers.push_back(sample_type(10,10));
                seed_centers.push_back(sample_type(10,-10));
                seed_centers.push_back(sample_type(-10,10));
                seed_centers.push_back(sample_type(-10,-10));

                run_test(seed_centers);
            }
            {
                // The log message names the float sample type this scope
                // actually tests (it previously said "double").
                dlog << LINFO << "test dlib::vector<float,2>";
                typedef dlib::vector<float,2> sample_type;
                std::vector<sample_type> seed_centers;
                seed_centers.push_back(sample_type(10,10));
                seed_centers.push_back(sample_type(10,-10));
                seed_centers.push_back(sample_type(-10,10));
                seed_centers.push_back(sample_type(-10,-10));

                run_test(seed_centers);
            }
            {
                dlog << LINFO << "test dlib::matrix<double,3,1>";
                typedef dlib::matrix<double,3,1> sample_type;
                std::vector<sample_type> seed_centers;
                sample_type samp;
                samp = 10,10,0; seed_centers.push_back(samp);
                samp = -10,10,1; seed_centers.push_back(samp);
                samp = -10,-10,2; seed_centers.push_back(samp);

                run_test(seed_centers);
            }


        }
    } a; // file-scope instance: constructing it creates this test case



}