// Copyright (C) 2010  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_ONE_VS_ONE_DECISION_FUnCTION_Hh_
#define DLIB_ONE_VS_ONE_DECISION_FUnCTION_Hh_

#include "one_vs_one_decision_function_abstract.h"

#include "../serialize.h"
#include "../type_safe_union.h"
#include <iostream>
#include <sstream>
#include <set>
#include <map>
#include "../any.h"
#include "../unordered_pair.h"
#include "null_df.h"

namespace dlib
{

// ----------------------------------------------------------------------------------------

    template <
        typename one_vs_one_trainer,
        typename DF1 = null_df, typename DF2 = null_df, typename DF3 = null_df,
        typename DF4 = null_df, typename DF5 = null_df, typename DF6 = null_df,
        typename DF7 = null_df, typename DF8 = null_df, typename DF9 = null_df,
        typename DF10 = null_df
        >
    class one_vs_one_decision_function
    {
    public:

        typedef typename one_vs_one_trainer::label_type result_type;
        typedef typename one_vs_one_trainer::sample_type sample_type;
        typedef typename one_vs_one_trainer::scalar_type scalar_type;
        typedef typename one_vs_one_trainer::mem_manager_type mem_manager_type;

        typedef std::map<unordered_pair<result_type>, any_decision_function<sample_type, scalar_type> > binary_function_table;

        one_vs_one_decision_function() :num_classes(0) {}

        explicit one_vs_one_decision_function(
            const binary_function_table& dfs_
        ) : dfs(dfs_)
        {
#ifdef ENABLE_ASSERTS
            {
                const std::vector<unordered_pair<result_type> > missing_pairs = find_missing_pairs(dfs_);
                if (missing_pairs.size() != 0)
                {
                    std::ostringstream sout;
                    for (unsigned long i = 0; i < missing_pairs.size(); ++i)
                    {
                        sout << "\t      (" << missing_pairs[i].first << ", " << missing_pairs[i].second << ")\n";
                    }
                    DLIB_ASSERT(missing_pairs.size() == 0, 
                        "\t void one_vs_one_decision_function::one_vs_one_decision_function()"
                        << "\n\t The supplied set of binary decision functions is incomplete."
                        << "\n\t this: " << this
                        << "\n\t Classifiers are missing for the following label pairs: \n" << sout.str()
                                );
                }
            }
#endif

            // figure out how many labels are covered by this set of binary decision functions
            std::set<result_type> labels;
            for (typename binary_function_table::const_iterator i = dfs.begin(); i != dfs.end(); ++i)
            {
                labels.insert(i->first.first);
                labels.insert(i->first.second);
            }
            num_classes = labels.size();
        }

        const binary_function_table& get_binary_decision_functions (
        ) const
        {
            return dfs;
        }

        const std::vector<result_type> get_labels (
        ) const
        {
            std::set<result_type> labels;
            for (typename binary_function_table::const_iterator i = dfs.begin(); i != dfs.end(); ++i)
            {
                labels.insert(i->first.first);
                labels.insert(i->first.second);
            }
            return std::vector<result_type>(labels.begin(), labels.end());
        }


        template <
            typename df1, typename df2, typename df3, typename df4, typename df5,
            typename df6, typename df7, typename df8, typename df9, typename df10
            >
        one_vs_one_decision_function (
            const one_vs_one_decision_function<one_vs_one_trainer, 
                                               df1, df2, df3, df4, df5,
                                               df6, df7, df8, df9, df10>& item
        ) : dfs(item.get_binary_decision_functions()), num_classes(item.number_of_classes()) {}

        unsigned long number_of_classes (
        ) const
        {
            return num_classes;
        }

        result_type operator() (
            const sample_type& sample
        ) const
        {
            DLIB_ASSERT(number_of_classes() != 0, 
                "\t void one_vs_one_decision_function::operator()"
                << "\n\t You can't make predictions with an empty decision function."
                << "\n\t this: " << this
                );

            std::map<result_type,int> votes;

            // run all the classifiers over the sample
            for(typename binary_function_table::const_iterator i = dfs.begin(); i != dfs.end(); ++i)
            {
                const scalar_type score = i->second(sample);

                if (score > 0)
                    votes[i->first.first] += 1;
                else
                    votes[i->first.second] += 1;
            }

            // now figure out who had the most votes
            result_type best_label = result_type();
            int best_votes = 0;
            for (typename std::map<result_type,int>::iterator i = votes.begin(); i != votes.end(); ++i)
            {
                if (i->second > best_votes)
                {
                    best_votes = i->second;
                    best_label = i->first;
                }
            }

            return best_label;
        }



    private:
        binary_function_table dfs;
        unsigned long num_classes;

    };

// ----------------------------------------------------------------------------------------

    template <
        typename T,
        typename DF1, typename DF2, typename DF3,
        typename DF4, typename DF5, typename DF6,
        typename DF7, typename DF8, typename DF9,
        typename DF10 
        >
    void serialize(
        const one_vs_one_decision_function<T,DF1,DF2,DF3,DF4,DF5,DF6,DF7,DF8,DF9,DF10>& item, 
        std::ostream& out
    )
    {
        try
        {
            type_safe_union<DF1,DF2,DF3,DF4,DF5,DF6,DF7,DF8,DF9,DF10> temp;
            typedef typename T::label_type result_type;
            typedef typename T::sample_type sample_type;
            typedef typename T::scalar_type scalar_type;
            typedef std::map<unordered_pair<result_type>, any_decision_function<sample_type, scalar_type> > binary_function_table;

            const unsigned long version = 1;
            serialize(version, out);

            const unsigned long size = item.get_binary_decision_functions().size();
            serialize(size, out);

            for(typename binary_function_table::const_iterator i = item.get_binary_decision_functions().begin(); 
                i != item.get_binary_decision_functions().end(); ++i)
            {
                serialize(i->first, out);

                if      (i->second.template contains<DF1>()) temp.template get<DF1>() = any_cast<DF1>(i->second);
                else if (i->second.template contains<DF2>()) temp.template get<DF2>() = any_cast<DF2>(i->second);
                else if (i->second.template contains<DF3>()) temp.template get<DF3>() = any_cast<DF3>(i->second);
                else if (i->second.template contains<DF4>()) temp.template get<DF4>() = any_cast<DF4>(i->second);
                else if (i->second.template contains<DF5>()) temp.template get<DF5>() = any_cast<DF5>(i->second);
                else if (i->second.template contains<DF6>()) temp.template get<DF6>() = any_cast<DF6>(i->second);
                else if (i->second.template contains<DF7>()) temp.template get<DF7>() = any_cast<DF7>(i->second);
                else if (i->second.template contains<DF8>()) temp.template get<DF8>() = any_cast<DF8>(i->second);
                else if (i->second.template contains<DF9>()) temp.template get<DF9>() = any_cast<DF9>(i->second);
                else if (i->second.template contains<DF10>()) temp.template get<DF10>() = any_cast<DF10>(i->second);
                else throw serialization_error("Can't serialize one_vs_one_decision_function.  Not all decision functions defined.");

                serialize(temp,out);
            }
        }
        catch (serialization_error& e)
        {
            throw serialization_error(e.info + "\n   while serializing an object of type one_vs_one_decision_function");
        }

    }

// ----------------------------------------------------------------------------------------

    namespace impl
    {
        template <typename sample_type, typename scalar_type>
        struct copy_to_df_helper
        {
            copy_to_df_helper(any_decision_function<sample_type, scalar_type>& target_) : target(target_) {}

            any_decision_function<sample_type, scalar_type>& target;

            template <typename T>
            void operator() (
                const T& item
            ) const
            {
                target = item;
            }
        };
    }

    template <
        typename T,
        typename DF1, typename DF2, typename DF3,
        typename DF4, typename DF5, typename DF6,
        typename DF7, typename DF8, typename DF9,
        typename DF10 
        >
    void deserialize(
        one_vs_one_decision_function<T,DF1,DF2,DF3,DF4,DF5,DF6,DF7,DF8,DF9,DF10>& item, 
        std::istream& in 
    )
    {
        try
        {
            type_safe_union<DF1,DF2,DF3,DF4,DF5,DF6,DF7,DF8,DF9,DF10> temp;
            typedef typename T::label_type result_type;
            typedef typename T::sample_type sample_type;
            typedef typename T::scalar_type scalar_type;
            typedef impl::copy_to_df_helper<sample_type, scalar_type> copy_to;

            unsigned long version;
            deserialize(version, in);

            if (version != 1)
                throw serialization_error("Can't deserialize one_vs_one_decision_function.  Wrong version.");

            unsigned long size;
            deserialize(size, in);

            typedef std::map<unordered_pair<result_type>, any_decision_function<sample_type, scalar_type> > binary_function_table;
            binary_function_table dfs;

            unordered_pair<result_type> p;
            for (unsigned long i = 0; i < size; ++i)
            {
                deserialize(p, in);
                deserialize(temp, in);
                if (temp.template contains<null_df>())
                    throw serialization_error("A sub decision function of unknown type was encountered.");

                temp.apply_to_contents(copy_to(dfs[p]));
            }

            item = one_vs_one_decision_function<T,DF1,DF2,DF3,DF4,DF5,DF6,DF7,DF8,DF9,DF10>(dfs);
        }
        catch (serialization_error& e)
        {
            throw serialization_error(e.info + "\n   while deserializing an object of type one_vs_one_decision_function");
        }
    }

// ----------------------------------------------------------------------------------------

}

#endif // DLIB_ONE_VS_ONE_DECISION_FUnCTION_Hh_