// Copyright (C) 2012  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_GRAPH_LaBELER_Hh_
#define DLIB_GRAPH_LaBELER_Hh_

#include "graph_labeler_abstract.h"
#include "../matrix.h"
#include "../string.h"
#include <vector>
#include "find_max_factor_graph_potts.h"
#include "../svm/sparse_vector.h"
#include "../graph.h"

namespace dlib
{

// ----------------------------------------------------------------------------------------

    template <
        typename vector_type 
        >
    class graph_labeler 
    {
        /*!
            WHAT THIS OBJECT REPRESENTS
                A binary node labeler for graphs.  It stores two weight vectors,
                edge_weights and node_weights.  When applied to a graph it builds a
                graph<double,double> with the same structure, where
                    - node i's value is dot(node_weights, sample.node(i).data)
                    - each edge's value is dot(edge_weights, <edge vector>)
                and passes that graph to find_max_factor_graph_potts().  Every node the
                solver assigns a nonzero label is reported as true; all others as false.

            REQUIREMENTS
                These are checked with DLIB_ASSERT.  The per-graph checks in operator()
                are additionally compiled only when ENABLE_ASSERTS is defined.
                - edge_weights contains no negative elements (checked in the constructor).
                - The input graph contains no length-one cycles.
                - No edge vector in the input graph contains negative elements.  Combined
                  with non-negative edge_weights, every computed edge value is therefore
                  non-negative.
                - When vector_type and the graph's node/edge types are dlib matrices,
                  their sizes must match so that dot() is legal.
        !*/

    public:

        typedef std::vector<bool> label_type;
        typedef label_type result_type;

        // Both weight vectors are left default constructed.
        graph_labeler()
        {
        }

        // Stores copies of the given weight vectors.  edge_weights_ must not contain
        // negative elements.
        graph_labeler(
            const vector_type& edge_weights_,
            const vector_type& node_weights_
        ) : 
            edge_weights(edge_weights_),
            node_weights(node_weights_)
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(edge_weights.size() == 0 || min(edge_weights) >= 0,
                    "\t graph_labeler::graph_labeler()"
                    << "\n\t Invalid inputs were given to this function."
                    << "\n\t min(edge_weights): " << min(edge_weights)
                    << "\n\t this:              " << this
                    );
        }

        const vector_type& get_edge_weights (
        ) const { return edge_weights; }

        const vector_type& get_node_weights (
        ) const { return node_weights; }

        // Labels the nodes of sample.  labels is cleared and then filled with one
        // boolean per label produced by find_max_factor_graph_potts() (presumably one
        // per node of sample).  labels[k] is true iff the solver's label k is nonzero.
        template <typename graph_type>
        void operator() (
            const graph_type& sample,
            std::vector<bool>& labels 
        ) const
        {
            // make sure requires clause is not broken
#ifdef ENABLE_ASSERTS
            DLIB_ASSERT(graph_contains_length_one_cycle(sample) == false,
                        "\t void graph_labeler::operator()"
                        << "\n\t Invalid inputs were given to this function."
                        << "\n\t get_edge_weights().size(): " << get_edge_weights().size()
                        << "\n\t get_node_weights().size(): " << get_node_weights().size()
                        << "\n\t graph_contains_length_one_cycle(sample): " << graph_contains_length_one_cycle(sample)
                        << "\n\t this:                      " << this
                    );
            for (unsigned long i = 0; i < sample.number_of_nodes(); ++i)
            {
                if (is_matrix<vector_type>::value &&
                    is_matrix<typename graph_type::type>::value)
                {
                    // check that dot() is legal.
                    DLIB_ASSERT((unsigned long)get_node_weights().size() == (unsigned long)sample.node(i).data.size(),
                                "\t void graph_labeler::operator()"
                                << "\n\t The size of the node weight vector must match the one in the node."
                                << "\n\t get_node_weights().size():  " << get_node_weights().size()
                                << "\n\t sample.node(i).data.size(): " << sample.node(i).data.size()
                                << "\n\t i: " << i 
                                << "\n\t this:              " << this
                            );
                }

                for (unsigned long n = 0; n < sample.node(i).number_of_neighbors(); ++n)
                {
                    if (is_matrix<vector_type>::value &&
                        is_matrix<typename graph_type::edge_type>::value)
                    {
                        // check that dot() is legal.  Report both i and n so the
                        // offending edge can be identified.
                        DLIB_ASSERT((unsigned long)get_edge_weights().size() == (unsigned long)sample.node(i).edge(n).size(),
                                    "\t void graph_labeler::operator()"
                                    << "\n\t The size of the edge weight vector must match the one in graph's edge."
                                    << "\n\t get_edge_weights().size():  " << get_edge_weights().size()
                                    << "\n\t sample.node(i).edge(n).size(): " << sample.node(i).edge(n).size()
                                    << "\n\t i: " << i 
                                    << "\n\t n: " << n 
                                    << "\n\t this:              " << this
                        );
                    }

                    DLIB_ASSERT(sample.node(i).edge(n).size() == 0 || min(sample.node(i).edge(n)) >= 0,
                                "\t void graph_labeler::operator()"
                                << "\n\t No edge vectors are allowed to have negative elements."
                                << "\n\t min(sample.node(i).edge(n)): " << min(sample.node(i).edge(n))
                                << "\n\t i:    " << i 
                                << "\n\t n:    " << n 
                                << "\n\t this: " << this
                    );
                }
            }
#endif


            // Build a scalar-valued copy of sample's structure holding the node and
            // edge potentials.
            graph<double,double>::kernel_1a g; 
            copy_graph_structure(sample, g);
            for (unsigned long i = 0; i < g.number_of_nodes(); ++i)
            {
                g.node(i).data = dot(node_weights, sample.node(i).data);

                for (unsigned long n = 0; n < g.node(i).number_of_neighbors(); ++n)
                {
                    const unsigned long j = g.node(i).neighbor(n).index();
                    // Don't compute an edge weight more than once.  Each edge is
                    // written only from its lower-indexed endpoint.  NOTE(review):
                    // this relies on g sharing one edge value between both endpoints.
                    if (i < j)
                    {
                        g.node(i).edge(n) = dot(edge_weights, sample.node(i).edge(n));
                    }
                }

            }

            labels.clear();
            std::vector<node_label> temp;
            find_max_factor_graph_potts(g, temp);

            // Any nonzero solver label maps to true.  The final size is known, so
            // reserve it up front instead of growing labels element by element.
            labels.reserve(temp.size());
            for (unsigned long i = 0; i < temp.size(); ++i)
                labels.push_back(temp[i] != 0);
        }

        // Convenience overload that returns the labels by value.
        template <typename graph_type>
        std::vector<bool> operator() (
            const graph_type& sample 
        ) const
        {
            std::vector<bool> temp;
            (*this)(sample, temp);
            return temp;
        }

    private:

        vector_type edge_weights;
        vector_type node_weights;
    };


// ----------------------------------------------------------------------------------------

    template <
        typename vector_type 
        >
    void serialize (
        const graph_labeler<vector_type>& item,
        std::ostream& out
    )
    {
        // Stream layout: an int format version (currently 1), followed by the
        // edge weight vector and then the node weight vector.
        const vector_type& edges = item.get_edge_weights();
        const vector_type& nodes = item.get_node_weights();

        const int format_version = 1;
        serialize(format_version, out);
        serialize(edges, out);
        serialize(nodes, out);
    }

// ----------------------------------------------------------------------------------------

    template <
        typename vector_type 
        >
    void deserialize (
        graph_labeler<vector_type>& item,
        std::istream& in 
    )
    {
        // Reads the layout written by serialize(): an int format version, then the
        // edge weights, then the node weights.  Throws dlib::serialization_error if
        // the version is not 1.
        int ver = 0;
        deserialize(ver, in);
        if (ver == 1)
        {
            vector_type edge_w, node_w;
            deserialize(edge_w, in);
            deserialize(node_w, in);

            // Rebuild through the public constructor so its DLIB_ASSERT
            // requirement check on the edge weights is applied.
            item = graph_labeler<vector_type>(edge_w, node_w);
            return;
        }

        throw dlib::serialization_error("While deserializing graph_labeler, found unexpected version number of " + 
                                        cast_to_string(ver) + ".");
    }

// ----------------------------------------------------------------------------------------

}

#endif // DLIB_GRAPH_LaBELER_Hh_