// Copyright (C) 2013  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_FIND_K_NEAREST_NEIGHBOrS_LSH_Hh_
#define DLIB_FIND_K_NEAREST_NEIGHBOrS_LSH_Hh_

#include "find_k_nearest_neighbors_lsh_abstract.h"
#include "../threads.h"
#include "../lsh/hashes.h"
#include <vector>
#include <queue>
#include "sample_pair.h"
#include "edge_list_graphs.h"

namespace dlib
{

// ----------------------------------------------------------------------------------------

    namespace impl
    {
        struct compare_sample_pair_with_distance 
        {
            inline bool operator() (const sample_pair& a, const sample_pair& b) const
            { 
                return a.distance() < b.distance();
            }
        };

        template <
            typename vector_type,
            typename hash_function_type
            >
        class hash_block
        {
        public:
            hash_block(
                const vector_type& samples_,
                const hash_function_type& hash_funct_,
                std::vector<typename hash_function_type::result_type>& hashes_
            ) : 
                samples(samples_),
                hash_funct(hash_funct_),
                hashes(hashes_)
            {}

            void operator() (long i) const
            {
                hashes[i] = hash_funct(samples[i]);
            }

            const vector_type& samples;
            const hash_function_type& hash_funct;
            std::vector<typename hash_function_type::result_type>& hashes;
        };

        template <
            typename vector_type,
            typename distance_function_type,
            typename hash_function_type,
            typename alloc
            >
        class scan_find_k_nearest_neighbors_lsh
        {
        public:
            scan_find_k_nearest_neighbors_lsh (
                const vector_type& samples_,
                const distance_function_type& dist_funct_,
                const hash_function_type& hash_funct_,
                const unsigned long k_,
                std::vector<sample_pair, alloc>& edges_,
                const unsigned long k_oversample_,
                const std::vector<typename hash_function_type::result_type>& hashes_
            ) :
                samples(samples_),
                dist_funct(dist_funct_),
                hash_funct(hash_funct_),
                k(k_),
                edges(edges_),
                k_oversample(k_oversample_),
                hashes(hashes_)
            {
                edges.clear();
                edges.reserve(samples.size()*k/2);
            }

            mutex m;
            const vector_type& samples;
            const distance_function_type& dist_funct;
            const hash_function_type& hash_funct;
            const unsigned long k;
            std::vector<sample_pair, alloc>& edges;
            const unsigned long k_oversample;
            const std::vector<typename hash_function_type::result_type>& hashes;

            void operator() (unsigned long i) const
            {
                const unsigned long k_hash = k*k_oversample;

                std::priority_queue<std::pair<unsigned long, unsigned long> > best_hashes;
                std::priority_queue<sample_pair, std::vector<sample_pair>, dlib::impl::compare_sample_pair_with_distance> best_samples;
                unsigned long worst_distance = std::numeric_limits<unsigned long>::max();
                // scan over the hashes and find the best matches for hashes[i]
                for (unsigned long j = 0; j < hashes.size(); ++j)
                {
                    if (i == j) 
                        continue;

                    const unsigned long dist = hash_funct.distance(hashes[i], hashes[j]);
                    if (dist < worst_distance || best_hashes.size() < k_hash)
                    {
                        if (best_hashes.size() >= k_hash)
                            best_hashes.pop();
                        best_hashes.push(std::make_pair(dist, j));
                        worst_distance = best_hashes.top().first;
                    }
                }

                // Now figure out which of the best_hashes are actually the k best matches
                // according to dist_funct()
                while (best_hashes.size() != 0)
                {
                    const unsigned long j = best_hashes.top().second;
                    best_hashes.pop();

                    const double dist = dist_funct(samples[i], samples[j]);
                    if (dist < std::numeric_limits<double>::infinity())
                    {
                        if (best_samples.size() >= k)
                            best_samples.pop();
                        best_samples.push(sample_pair(i,j,dist));
                    }
                }

                // Finally, now put the k best matches according to dist_funct() into edges
                auto_mutex lock(m);
                while (best_samples.size() != 0)
                {
                    edges.push_back(best_samples.top());
                    best_samples.pop();
                }
            }
        };

    }

// ----------------------------------------------------------------------------------------

    template <
        typename vector_type,
        typename hash_function_type
        >
    void hash_samples (
        const vector_type& samples,
        const hash_function_type& hash_funct,
        const unsigned long num_threads,
        std::vector<typename hash_function_type::result_type>& hashes
    )
    {
        hashes.resize(samples.size());

        typedef impl::hash_block<vector_type,hash_function_type> block_type;
        block_type temp(samples, hash_funct, hashes);
        parallel_for(num_threads, 0, samples.size(), temp);
    }

// ----------------------------------------------------------------------------------------

    template <
        typename vector_type,
        typename distance_function_type,
        typename hash_function_type,
        typename alloc
        >
    void find_k_nearest_neighbors_lsh (
        const vector_type& samples,
        const distance_function_type& dist_funct,
        const hash_function_type& hash_funct,
        const unsigned long k,
        const unsigned long num_threads,
        std::vector<sample_pair, alloc>& edges,
        const unsigned long k_oversample = 20 
    )
    {
        // make sure requires clause is not broken
        DLIB_ASSERT(k > 0 && k_oversample > 0,
            "\t void find_k_nearest_neighbors_lsh()"
            << "\n\t Invalid inputs were given to this function."
            << "\n\t samples.size(): " << samples.size()
            << "\n\t k:              " << k 
            << "\n\t k_oversample:   " << k_oversample 
            );

        edges.clear();

        if (samples.size() <= 1)
        {
            return;
        }

        typedef typename hash_function_type::result_type hash_type;
        std::vector<hash_type> hashes;
        hash_samples(samples, hash_funct, num_threads, hashes);

        typedef impl::scan_find_k_nearest_neighbors_lsh<vector_type, distance_function_type,hash_function_type,alloc> scan_type;
        scan_type temp(samples, dist_funct, hash_funct, k, edges, k_oversample, hashes);
        parallel_for(num_threads, 0, hashes.size(), temp);

        remove_duplicate_edges(edges);
    }

// ----------------------------------------------------------------------------------------

}

#endif // DLIB_FIND_K_NEAREST_NEIGHBOrS_LSH_Hh_