// Copyright (C) 2012  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_FIND_MAX_FACTOR_GRAPH_PoTTS_Hh_
#define DLIB_FIND_MAX_FACTOR_GRAPH_PoTTS_Hh_

#include "find_max_factor_graph_potts_abstract.h"
#include "../matrix.h"
#include "min_cut.h"
#include "general_potts_problem.h"
#include "../algs.h"
#include "../graph_utils.h"
#include "../array2d.h"

namespace dlib
{

// ----------------------------------------------------------------------------------------

    namespace impl
    {

        template <
            typename potts_problem,
            typename T = void
            >
        class flows_container
        {
            /*
                Sparse flow storage.  Conceptually this is a matrix whose entry (r,c)
                holds the value for the edge from node r to its c-th neighbor.  Since
                nodes may have differing numbers of neighbors, each row is kept in its
                own std::vector sized to that node's neighbor count.  This generic
                version is used unless potts_problem declares a nonzero
                max_number_of_neighbors.
            */

            typedef typename potts_problem::value_type edge_type;
            std::vector<std::vector<edge_type> > flows;
        public:

            void setup(
                const potts_problem& p
            )
            {
                // One row per node, one column per neighbor of that node.
                const unsigned long num_rows = p.number_of_nodes();
                flows.resize(num_rows);
                for (unsigned long row = 0; row != num_rows; ++row)
                    flows[row].resize(p.number_of_neighbors(row));
            }

            edge_type& operator() (
                const long r,
                const long c
            ) 
            {
                std::vector<edge_type>& row = flows[r];
                return row[c];
            }

            const edge_type& operator() (
                const long r,
                const long c
            ) const 
            {
                const std::vector<edge_type>& row = flows[r];
                return row[c];
            }
        };

// ----------------------------------------------------------------------------------------

        template <
            typename potts_problem
            >
        class flows_container<potts_problem, 
                              typename enable_if_c<potts_problem::max_number_of_neighbors!=0>::type>
        {
            /*
                Dense flow storage, selected when potts_problem declares a nonzero
                max_number_of_neighbors.  Every node gets a full row of
                max_number_of_neighbors entries inside a single matrix.  This wastes
                a little space for nodes with fewer neighbors, but keeps all values
                in one contiguous object.  setup() only sizes the matrix; it does
                not assign any entries.
            */
            typedef typename potts_problem::value_type edge_type;
            const static unsigned long max_number_of_neighbors = potts_problem::max_number_of_neighbors;
            matrix<edge_type,0,max_number_of_neighbors> flows;
        public:

            void setup(
                const potts_problem& p
            )
            {
                // rows: one per node, columns: the compile-time neighbor bound
                flows.set_size(p.number_of_nodes(), max_number_of_neighbors);
            }

            edge_type& operator() (
                const long r,
                const long c
            ) 
            {
                edge_type& entry = flows(r,c);
                return entry;
            }

            const edge_type& operator() (
                const long r,
                const long c
            ) const 
            {
                const edge_type& entry = flows(r,c);
                return entry;
            }
        };

// ----------------------------------------------------------------------------------------

        template <
            typename potts_problem 
            >
        class potts_flow_graph 
        {
        public:
            typedef typename potts_problem::value_type edge_type;
        private:
            /*!
                This is a utility class used by dlib::min_cut to convert a potts_problem
                into the kind of flow graph expected by the min_cut object's main block
                of code.

                Node numbering within this object:
                    - ids 0 through potts_problem::number_of_nodes()-1 are the nodes
                      of the wrapped potts_problem.
                    - id potts_problem::number_of_nodes() is the source node.
                    - id potts_problem::number_of_nodes()+1 is the sink node.

                Edge iteration (see node_id()): for a problem node i with num
                neighbors, iterator position cnt in [0,num) visits the neighbors of
                i in get_neighbor() order, cnt == num visits the source and
                cnt == num+1 visits the sink.  The source and sink each iterate over
                every problem node (cnt in [0,number_of_nodes())) and are not
                connected to each other.

                NOTE(review): the stored "flow" values appear to act as the residual
                capacities that min_cut reads with get_flow() and updates with
                adjust_flow(); confirm against min_cut.h.
            !*/

            potts_problem& g;

            // flows(i,j) == the flow from node id i to its jth neighbor
            flows_container<potts_problem> flows;
            // source_flows(i,0) == flow from source to node i, 
            // source_flows(i,1) == flow from node i to source
            matrix<edge_type,0,2> source_flows;

            // sink_flows(i,0) == flow from sink to node i, 
            // sink_flows(i,1) == flow from node i to sink
            matrix<edge_type,0,2> sink_flows;

            // Labels of the two terminal nodes.  Labels of problem nodes live in g
            // and are accessed through g.set_label()/g.get_label().
            node_label source_label, sink_label;
        public:

            potts_flow_graph(
                potts_problem& g_
            ) : g(g_)
            {
                // Size all storage and zero every terminal edge.
                flows.setup(g);

                source_flows.set_size(g.number_of_nodes(), 2);
                sink_flows.set_size(g.number_of_nodes(), 2);
                source_flows = 0;
                sink_flows = 0;

                source_label = FREE_NODE;
                sink_label = FREE_NODE;

                // setup flows based on factor potentials
                for (unsigned long i = 0; i < g.number_of_nodes(); ++i)
                {
                    // A negative factor value is placed (negated) on the source->i
                    // edge; a non-negative one on the i->sink edge.  So at most one
                    // terminal edge of node i starts out nonzero.
                    const edge_type temp = g.factor_value(i);
                    if (temp < 0)
                        source_flows(i,0) = -temp;
                    else
                        sink_flows(i,1) = temp;

                    // Each neighbor edge starts at the pair's disagreement penalty.
                    for (unsigned long j = 0; j < g.number_of_neighbors(i); ++j)
                    {
                        flows(i,j) = g.factor_value_disagreement(i, g.get_neighbor(i,j));
                    }
                }
            }

            // Iterates over the out edges of node idx.  cnt is decoded into a target
            // node id by node_id().  Only cnt is compared by operator!=, so an
            // iterator is meant to be compared only with iterators for the same idx.
            class out_edge_iterator
            {
                friend class potts_flow_graph;
                unsigned long idx; // base node idx
                unsigned long cnt; // count over the neighbors of idx
            public:

                out_edge_iterator(
                ):idx(0),cnt(0){}

                out_edge_iterator(
                    unsigned long idx_,
                    unsigned long cnt_
                ):idx(idx_),cnt(cnt_)
                {}

                bool operator!= (
                    const out_edge_iterator& item
                ) const { return cnt != item.cnt; }

                out_edge_iterator& operator++(
                )
                {
                    ++cnt;
                    return *this;
                }
            };

            // Same layout as out_edge_iterator, but denotes edges coming INTO idx
            // (see get_flow(const in_edge_iterator&)).
            class in_edge_iterator
            {
                friend class potts_flow_graph;
                unsigned long idx; // base node idx
                unsigned long cnt; // count over the neighbors of idx
            public:

                in_edge_iterator(
                ):idx(0),cnt(0)  
                {}


                in_edge_iterator(
                    unsigned long idx_,
                    unsigned long cnt_
                ):idx(idx_),cnt(cnt_)
                {}

                bool operator!= (
                    const in_edge_iterator& item
                ) const { return cnt != item.cnt; }

                in_edge_iterator& operator++(
                )
                {
                    ++cnt;
                    return *this;
                }
            };

            // Problem nodes plus the source and sink.
            unsigned long number_of_nodes (
            ) const { return g.number_of_nodes() + 2; }

            out_edge_iterator out_begin(
                const unsigned long& it
            ) const { return out_edge_iterator(it, 0); }

            in_edge_iterator in_begin(
                const unsigned long& it
            ) const { return in_edge_iterator(it, 0); }

            // Terminals iterate over all problem nodes; a problem node iterates
            // over its neighbors plus the two terminals.
            out_edge_iterator out_end(
                const unsigned long& it
            ) const 
            { 
                if (it >= g.number_of_nodes())
                    return out_edge_iterator(it, g.number_of_nodes()); 
                else
                    return out_edge_iterator(it, g.number_of_neighbors(it)+2); 
            }

            in_edge_iterator in_end(
                const unsigned long& it
            ) const 
            { 
                if (it >= g.number_of_nodes())
                    return in_edge_iterator(it, g.number_of_nodes()); 
                else
                    return in_edge_iterator(it, g.number_of_neighbors(it)+2); 
            }


            // Returns the id of the node at the other end of the edge denoted by it.
            template <typename iterator_type>
            unsigned long node_id (
                const iterator_type& it
            ) const 
            { 
                // if this isn't an iterator over the source or sink nodes
                if (it.idx < g.number_of_nodes())
                {
                    const unsigned long num = g.number_of_neighbors(it.idx);
                    if (it.cnt < num)
                        return g.get_neighbor(it.idx, it.cnt); 
                    else if (it.cnt == num)
                        return g.number_of_nodes();
                    else
                        return g.number_of_nodes()+1;
                }
                else
                {
                    return it.cnt;
                }
            }


            // Returns the stored value for the edge it1 -> it2.  When neither id is
            // a terminal, it2 must be a neighbor of it1 (get_neighbor_idx() is used
            // to locate the entry).
            edge_type get_flow (
                const unsigned long& it1,     
                const unsigned long& it2
            ) const
            {
                if (it1 >= g.number_of_nodes())
                {
                    // if it1 is the source
                    if (it1 == g.number_of_nodes())
                    {
                        return source_flows(it2,0);
                    }
                    else // if it1 is the sink
                    {
                        return sink_flows(it2,0);
                    }
                }
                else if (it2 >= g.number_of_nodes())
                {
                    // if it2 is the source
                    if (it2 == g.number_of_nodes())
                    {
                        return source_flows(it1,1);
                    }
                    else // if it2 is the sink
                    {
                        return sink_flows(it1,1);
                    }
                }
                else
                {
                    return flows(it1, g.get_neighbor_idx(it1, it2));
                }

            }

            // Value on the out edge it.idx -> node_id(it), looked up directly from
            // the iterator position (no get_neighbor_idx() search needed).
            edge_type get_flow (
                const out_edge_iterator& it
            ) const
            {
                if (it.idx < g.number_of_nodes())
                {
                    const unsigned long num = g.number_of_neighbors(it.idx);
                    if (it.cnt < num)
                        return flows(it.idx, it.cnt);
                    else if (it.cnt == num)
                        return source_flows(it.idx,1);
                    else
                        return sink_flows(it.idx,1);
                }
                else
                {
                    // if it.idx is the source
                    if (it.idx == g.number_of_nodes())
                    {
                        return source_flows(it.cnt,0);
                    }
                    else // if it.idx is the sink
                    {
                        return sink_flows(it.cnt,0);
                    }
                }
            }

            // Value on the in edge node_id(it) -> it.idx.
            edge_type get_flow (
                const in_edge_iterator& it
            ) const
            {
                return get_flow(node_id(it), it.idx); 
            }

            // Adds value to the it1 -> it2 entry and subtracts the same amount from
            // the reverse it2 -> it1 entry.
            void adjust_flow (
                const unsigned long& it1,     
                const unsigned long& it2,     
                const edge_type& value
            )
            {
                if (it1 >= g.number_of_nodes())
                {
                    // if it1 is the source
                    if (it1 == g.number_of_nodes())
                    {
                        source_flows(it2,0) += value;
                        source_flows(it2,1) -= value;
                    }
                    else // if it1 is the sink
                    {
                        sink_flows(it2,0) += value;
                        sink_flows(it2,1) -= value;
                    }
                }
                else if (it2 >= g.number_of_nodes())
                {
                    // if it2 is the source
                    if (it2 == g.number_of_nodes())
                    {
                        source_flows(it1,1) += value;
                        source_flows(it1,0) -= value;
                    }
                    else // if it2 is the sink
                    {
                        sink_flows(it1,1) += value;
                        sink_flows(it1,0) -= value;
                    }
                }
                else
                {
                    flows(it1, g.get_neighbor_idx(it1, it2)) += value;
                    flows(it2, g.get_neighbor_idx(it2, it1)) -= value;
                }

            }

            // Problem-node labels are forwarded to g; terminal labels are kept here.
            void set_label (
                const unsigned long& it,
                node_label value
            )
            {
                if (it < g.number_of_nodes())
                    g.set_label(it, value);
                else if (it == g.number_of_nodes())
                    source_label = value;
                else 
                    sink_label = value;
            }

            node_label get_label (
                const unsigned long& it
            ) const
            {
                if (it < g.number_of_nodes())
                    return g.get_label(it);
                if (it == g.number_of_nodes())
                    return source_label;
                else
                    return sink_label;
            }

        };

// ----------------------------------------------------------------------------------------

        template <
            typename label_image_type,
            typename image_potts_model
            >
        class potts_grid_problem 
        {
            /*
                Adapts a label image plus an image_potts_model (which supplies nr(),
                nc(), factor_value(idx) and factor_value_disagreement(idx1,idx2)) to
                the potts_problem interface used by find_max_factor_graph_potts().

                Node idx is the pixel at linear offset idx from &label_img[0][0]
                (NOTE(review): this assumes the label image's rows are stored
                contiguously).  Every node has exactly 4 neighbors: idx+1, idx-1,
                idx+nc and idx-nc, each taken modulo number_of_nodes().  So the grid
                wraps around at its borders, and the +/-1 steps cross row boundaries.
            */
            label_image_type& label_img;
            long nc;
            long num_nodes;
            unsigned char* labels;
            const image_potts_model& model;

        public:
            const static unsigned long max_number_of_neighbors = 4;

            potts_grid_problem (
                label_image_type& label_img_,
                const image_potts_model& image_potts_model_
            ) : 
                label_img(label_img_),
                model(image_potts_model_)
            {
                num_nodes = model.nr()*model.nc();
                nc = model.nc();
                // Only take the address of the first pixel when the label image has
                // one.  Indexing row 0 of an empty image is invalid.  Callers size
                // label_img to match the model, so an empty label image means there
                // are no nodes and labels is never dereferenced.
                if (label_img.nr() > 0 && label_img.nc() > 0)
                    labels = &label_img[0][0];
                else
                    labels = 0;
            }

            unsigned long number_of_nodes (
            ) const { return num_nodes; }

            unsigned long number_of_neighbors (
                unsigned long 
            ) const 
            { 
                return 4;
            }

            // Inverse of get_neighbor(): maps the neighbor node_id2 of node_id1
            // back to its neighbor slot (0: +1, 1: -1, 2: +nc, 3: -nc).
            unsigned long get_neighbor_idx (
                long node_id1,
                long node_id2
            ) const
            {
                long diff = node_id2-node_id1;
                // undo the modulo-number_of_nodes() wrap applied by get_neighbor()
                if (diff > nc)
                    diff -= (long)number_of_nodes();
                else if (diff < -nc)
                    diff += (long)number_of_nodes();

                if (diff == 1) 
                    return 0;
                else if (diff == -1)
                    return 1;
                else if (diff == nc)
                    return 2;
                else
                    return 3;
            }

            // Returns the idx-th neighbor of node_id, wrapping modulo number_of_nodes().
            unsigned long get_neighbor (
                long node_id,
                long idx
            ) const
            {
                switch(idx)
                {
                    case 0: 
                        {
                            long temp = node_id+1;
                            if (temp < (long)number_of_nodes())
                                return temp;
                            else
                                return temp - (long)number_of_nodes();
                        }
                    case 1: 
                        {
                            long temp = node_id-1;
                            if (node_id >= 1)
                                return temp;
                            else
                                return temp + (long)number_of_nodes();
                        }
                    case 2: 
                        {
                            long temp = node_id+nc;
                            if (temp < (long)number_of_nodes())
                                return temp;
                            else
                                return temp - (long)number_of_nodes();
                        }
                    case 3: 
                        {
                            long temp = node_id-nc;
                            if (node_id >= nc)
                                return temp;
                            else
                                return temp + (long)number_of_nodes();
                        }
                }
                return 0;
            }

            void set_label (
                const unsigned long& idx,
                node_label value
            )
            {
                *(labels+idx) = value;
            }

            node_label get_label (
                const unsigned long& idx
            ) const
            {
                return *(labels+idx);
            }

            typedef typename image_potts_model::value_type value_type;

            value_type factor_value (unsigned long idx) const
            {
                return model.factor_value(idx);
            }

            value_type factor_value_disagreement (unsigned long idx1, unsigned long idx2) const
            {
                return model.factor_value_disagreement(idx1,idx2);
            }

        };

    }

// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------

    template <
        typename potts_model
        >
    typename potts_model::value_type potts_model_score (
        const potts_model& prob
    )
    {
        /*
            Evaluates the Potts objective for the labeling currently stored in prob:
            the sum of factor_value(i) over every node with a non-zero label, minus
            factor_value_disagreement(i,j) for every pair of neighbors whose labels
            disagree.  Each neighbor relation is visited from both of its endpoints,
            so a pair is only charged from its lower-indexed end.
        */
#ifdef ENABLE_ASSERTS
        for (unsigned long i = 0; i < prob.number_of_nodes(); ++i)
        {
            for (unsigned long jj = 0; jj < prob.number_of_neighbors(i); ++jj)
            {
                unsigned long j = prob.get_neighbor(i,jj);
                DLIB_ASSERT(prob.factor_value_disagreement(i,j) >= 0,
                    "\t value_type potts_model_score(prob)"
                    << "\n\t Invalid inputs were given to this function." 
                    << "\n\t i: " << i 
                    << "\n\t j: " << j 
                    << "\n\t prob.factor_value_disagreement(i,j): " << prob.factor_value_disagreement(i,j)
                    );
                DLIB_ASSERT(prob.factor_value_disagreement(i,j) == prob.factor_value_disagreement(j,i),
                    "\t value_type potts_model_score(prob)"
                    << "\n\t Invalid inputs were given to this function." 
                    << "\n\t i: " << i 
                    << "\n\t j: " << j 
                    << "\n\t prob.factor_value_disagreement(i,j): " << prob.factor_value_disagreement(i,j)
                    << "\n\t prob.factor_value_disagreement(j,i): " << prob.factor_value_disagreement(j,i)
                    );
            }
        }
#endif 

        typedef typename potts_model::value_type value_type;
        value_type total = 0;

        // Unary terms first (kept as a separate pass so the accumulation order
        // matches the pairwise pass below).
        for (unsigned long node = 0; node < prob.number_of_nodes(); ++node)
        {
            if (prob.get_label(node) != 0)
                total += prob.factor_value(node);
        }

        // Pairwise disagreement penalties.
        for (unsigned long node = 0; node < prob.number_of_nodes(); ++node)
        {
            const bool node_on = (prob.get_label(node) != 0);
            for (unsigned long k = 0; k < prob.number_of_neighbors(node); ++k)
            {
                const unsigned long other = prob.get_neighbor(node, k);
                if (node >= other)
                    continue;   // charged when visiting from the lower-indexed end
                const bool other_on = (prob.get_label(other) != 0);
                if (node_on != other_on)
                    total -= prob.factor_value_disagreement(node, other);
            }
        }

        return total;
    }

// ----------------------------------------------------------------------------------------

    template <
        typename graph_type 
        >
    typename graph_type::edge_type potts_model_score (
        const graph_type& g,
        const std::vector<node_label>& labels
    )
    {
        /*
            Potts objective of a dlib graph under the given labeling: the sum of
            g.node(i).data over nodes whose label is non-zero, minus the weight of
            every edge whose endpoints have disagreeing labels.  Edges are visited
            from both endpoints, so each one is charged only from its lower-indexed
            end.
        */
        DLIB_ASSERT(graph_contains_length_one_cycle(g) == false,
                    "\t edge_type potts_model_score(g,labels)"
                    << "\n\t Invalid inputs were given to this function." 
                    );
        typedef typename graph_type::edge_type weight_type;
        typedef typename graph_type::type node_weight_type;

        // The edges and nodes have to use the same type to represent factor weights!
        COMPILE_TIME_ASSERT((is_same_type<weight_type, node_weight_type>::value == true));

#ifdef ENABLE_ASSERTS
        for (unsigned long i = 0; i < g.number_of_nodes(); ++i)
        {
            const unsigned long num_neighbors = g.node(i).number_of_neighbors();
            for (unsigned long n = 0; n < num_neighbors; ++n)
            {
                const unsigned long j = g.node(i).neighbor(n).index();
                DLIB_ASSERT(edge(g,i,j) >= 0,
                    "\t edge_type potts_model_score(g,labels)"
                    << "\n\t Invalid inputs were given to this function." 
                    << "\n\t i: " << i 
                    << "\n\t j: " << j 
                    << "\n\t edge(g,i,j): " << edge(g,i,j)
                    );
            }
        }
#endif 

        weight_type total = 0;

        // Unary terms.
        for (unsigned long i = 0; i < g.number_of_nodes(); ++i)
        {
            if (labels[i] != 0)
                total += g.node(i).data;
        }

        // Pairwise disagreement penalties.
        for (unsigned long i = 0; i < g.number_of_nodes(); ++i)
        {
            const bool on_i = (labels[i] != 0);
            for (unsigned long n = 0; n < g.node(i).number_of_neighbors(); ++n)
            {
                const unsigned long j = g.node(i).neighbor(n).index();
                if (i < j && on_i != (labels[j] != 0))
                    total -= g.node(i).edge(n);
            }
        }

        return total;
    }

// ----------------------------------------------------------------------------------------

    template <
        typename potts_grid_problem,
        typename mem_manager
        >
    typename potts_grid_problem::value_type potts_model_score (
        const potts_grid_problem& prob,
        const array2d<node_label,mem_manager>& labels
    )
    {
        /*
            Scores the labeling held in labels against the grid problem prob by
            wrapping both in an impl::potts_grid_problem and delegating to the
            generic potts_model_score().  labels must have the same dimensions as
            prob.
        */
        DLIB_ASSERT(prob.nr() == labels.nr() && prob.nc() == labels.nc(),
            "\t value_type potts_model_score(prob,labels)"
            << "\n\t Invalid inputs were given to this function." 
            << "\n\t prob.nr():   " << prob.nr()
            << "\n\t prob.nc():   " << prob.nc()
            << "\n\t labels.nr(): " << labels.nr()
            << "\n\t labels.nc(): " << labels.nc()
            );
        typedef array2d<node_label,mem_manager> image_type;
        // This const_cast is ok because the model object won't actually modify labels:
        // potts_model_score() only reads node labels through get_label().
        dlib::impl::potts_grid_problem<image_type,potts_grid_problem> model(const_cast<image_type&>(labels),prob);
        return potts_model_score(model);
    }

// ----------------------------------------------------------------------------------------

    template <
        typename potts_model
        >
    void find_max_factor_graph_potts (
        potts_model& prob
    )
    {
        /*
            Finds a labeling of prob's nodes that maximizes the Potts objective (see
            potts_model_score()) by running min_cut on the flow graph built by
            impl::potts_flow_graph.  The labels are presumably written back by
            min_cut through potts_flow_graph::set_label(), which forwards to
            prob.set_label().

            When ENABLE_ASSERTS is defined, prob is validated first:
            get_neighbor() and get_neighbor_idx() must be mutual inverses, and
            factor_value_disagreement() must be non-negative and symmetric.
        */
#ifdef ENABLE_ASSERTS
        for (unsigned long node_i = 0; node_i < prob.number_of_nodes(); ++node_i)
        {
            for (unsigned long jj = 0; jj < prob.number_of_neighbors(node_i); ++jj)
            {
                unsigned long node_j = prob.get_neighbor(node_i,jj);
                // node_i must appear in node_j's neighbor list (undirected graph).
                DLIB_ASSERT(prob.get_neighbor_idx(node_j,node_i) < prob.number_of_neighbors(node_j),
                    "\t void find_max_factor_graph_potts(prob)"
                    << "\n\t The supplied potts problem defines an invalid graph." 
                    << "\n\t node_i: " << node_i 
                    << "\n\t node_j: " << node_j 
                    << "\n\t prob.get_neighbor_idx(node_j,node_i): " << prob.get_neighbor_idx(node_j,node_i)
                    << "\n\t prob.number_of_neighbors(node_j):     " << prob.number_of_neighbors(node_j)
                            );

                DLIB_ASSERT(prob.get_neighbor_idx(node_i,prob.get_neighbor(node_i,jj)) == jj,
                    "\t void find_max_factor_graph_potts(prob)"
                    << "\n\t The get_neighbor_idx() and get_neighbor() functions must be inverses of each other." 
                    << "\n\t node_i: " << node_i 
                    << "\n\t jj:     " << jj
                    << "\n\t prob.get_neighbor(node_i,jj): " << prob.get_neighbor(node_i,jj)
                    << "\n\t prob.get_neighbor_idx(node_i,prob.get_neighbor(node_i,jj)): " << prob.get_neighbor_idx(node_i,node_j)
                            );

                DLIB_ASSERT(prob.get_neighbor(node_j,prob.get_neighbor_idx(node_j,node_i))==node_i,
                    "\t void find_max_factor_graph_potts(prob)"
                    << "\n\t The get_neighbor_idx() and get_neighbor() functions must be inverses of each other." 
                    << "\n\t node_i: " << node_i 
                    << "\n\t node_j: " << node_j 
                    << "\n\t prob.get_neighbor_idx(node_j,node_i): " << prob.get_neighbor_idx(node_j,node_i)
                    << "\n\t prob.get_neighbor(node_j,prob.get_neighbor_idx(node_j,node_i)): " << prob.get_neighbor(node_j,prob.get_neighbor_idx(node_j,node_i))
                            );

                // Disagreement penalties become edge values in the flow graph, so
                // they must be non-negative and the same in both directions.
                DLIB_ASSERT(prob.factor_value_disagreement(node_i,node_j) >= 0,
                    "\t void find_max_factor_graph_potts(prob)"
                    << "\n\t Invalid inputs were given to this function." 
                    << "\n\t node_i: " << node_i 
                    << "\n\t node_j: " << node_j 
                    << "\n\t prob.factor_value_disagreement(node_i,node_j): " << prob.factor_value_disagreement(node_i,node_j)
                    );
                DLIB_ASSERT(prob.factor_value_disagreement(node_i,node_j) == prob.factor_value_disagreement(node_j,node_i),
                    "\t void find_max_factor_graph_potts(prob)"
                    << "\n\t Invalid inputs were given to this function." 
                    << "\n\t node_i: " << node_i 
                    << "\n\t node_j: " << node_j 
                    << "\n\t prob.factor_value_disagreement(node_i,node_j): " << prob.factor_value_disagreement(node_i,node_j)
                    << "\n\t prob.factor_value_disagreement(node_j,node_i): " << prob.factor_value_disagreement(node_j,node_i)
                    );
            }
        }
#endif 
        // potts_flow_graph negates negative factor values and min_cut adds and
        // subtracts flow, so the value type must be signed.
        COMPILE_TIME_ASSERT(is_signed_type<typename potts_model::value_type>::value);
        min_cut mc;
        dlib::impl::potts_flow_graph<potts_model> pfg(prob);
        // In potts_flow_graph's numbering, node number_of_nodes() is the source
        // and number_of_nodes()+1 is the sink.
        mc(pfg, prob.number_of_nodes(), prob.number_of_nodes()+1);
    }

// ----------------------------------------------------------------------------------------

    template <
        typename graph_type 
        >
    void find_max_factor_graph_potts (
        const graph_type& g,
        std::vector<node_label>& labels
    )
    {
        DLIB_ASSERT(graph_contains_length_one_cycle(g) == false,
                    "\t void find_max_factor_graph_potts(g,labels)"
                    << "\n\t Invalid inputs were given to this function." 
                    );
        typedef typename graph_type::edge_type edge_type;
        typedef typename graph_type::type type;

        // The edges and node's have to use the same type to represent factor weights!
        COMPILE_TIME_ASSERT((is_same_type<edge_type, type>::value == true));
        COMPILE_TIME_ASSERT(is_signed_type<edge_type>::value);

#ifdef ENABLE_ASSERTS
        for (unsigned long i = 0; i < g.number_of_nodes(); ++i)
        {
            for (unsigned long jj = 0; jj < g.node(i).number_of_neighbors(); ++jj)
            {
                unsigned long j = g.node(i).neighbor(jj).index();
                DLIB_ASSERT(edge(g,i,j) >= 0,
                    "\t void find_max_factor_graph_potts(g,labels)"
                    << "\n\t Invalid inputs were given to this function." 
                    << "\n\t i: " << i 
                    << "\n\t j: " << j 
                    << "\n\t edge(g,i,j): " << edge(g,i,j)
                    );
            }
        }
#endif 

        dlib::impl::general_potts_problem<graph_type> gg(g, labels);
        find_max_factor_graph_potts(gg);

    }

// ----------------------------------------------------------------------------------------

    template <
        typename potts_grid_problem,
        typename mem_manager
        >
    void find_max_factor_graph_potts (
        const potts_grid_problem& prob,
        array2d<node_label,mem_manager>& labels
    )
    {
        typedef array2d<node_label,mem_manager> image_type;
        labels.set_size(prob.nr(), prob.nc());
        dlib::impl::potts_grid_problem<image_type,potts_grid_problem> model(labels,prob);
        find_max_factor_graph_potts(model);
    }

// ---------------------------------------------------------------------------------------- 

    namespace impl
    {
        template <
            typename pixel_type1,
            typename pixel_type2,
            typename model_type
            >
        struct potts_grid_image_pair_model
        {
            /*
                Adapts a two-image model to the grid interface consumed by
                potts_grid_problem.  Node idx is the pixel at linear offset idx from
                the first pixel of each image (NOTE(review): assumes contiguous row
                storage).  factor_value(idx) passes the idx-th pixel of img1 and
                img2 to the model.  factor_value_disagreement() compares two pixels
                of img1 only.  The model and images are referenced, not copied, so
                they must outlive this object.
            */
            const pixel_type1* data1;
            const pixel_type2* data2;
            const model_type& model;
            const long nr_;
            const long nc_;
            template <typename image_type1, typename image_type2>
            potts_grid_image_pair_model(
                const model_type& model_,
                const image_type1& img1,
                const image_type2& img2
            ) :
                model(model_),
                nr_(img1.nr()),
                nc_(img1.nc())
            {
                // Only take pixel addresses when an image has a pixel.  Indexing
                // row 0 of an empty image is invalid, and with no pixels there are
                // no nodes, so the pointers are never dereferenced.
                if (img1.nr() > 0 && img1.nc() > 0)
                    data1 = &img1[0][0];
                else
                    data1 = 0;

                if (img2.nr() > 0 && img2.nc() > 0)
                    data2 = &img2[0][0];
                else
                    data2 = 0;
            }

            typedef typename model_type::value_type value_type;

            long nr() const { return nr_; }
            long nc() const { return nc_; }

            value_type factor_value (
                unsigned long idx
            ) const 
            {
                return model.factor_value(*(data1 + idx), *(data2 + idx));
            }

            value_type factor_value_disagreement (
                unsigned long idx1,
                unsigned long idx2
            ) const 
            {
                return model.factor_value_disagreement(*(data1 + idx1), *(data1 + idx2));
            }
        };

    // ---------------------------------------------------------------------------------------- 

        template <
            typename image_type,
            typename model_type
            >
        struct potts_grid_image_single_model
        {
            /*
                Adapts a single-image model to the grid interface consumed by
                potts_grid_problem.  Node idx is the pixel at linear offset idx from
                the image's first pixel (NOTE(review): assumes contiguous row
                storage).  The model and image are referenced, not copied, so they
                must outlive this object.
            */
            const typename image_type::type* data1;
            const model_type& model;
            const long nr_;
            const long nc_;
            potts_grid_image_single_model(
                const model_type& model_,
                const image_type& img1
            ) :
                model(model_),
                nr_(img1.nr()),
                nc_(img1.nc())
            {
                // Indexing row 0 of an empty image is invalid.  With no pixels
                // there are no nodes, so data1 is never dereferenced.
                if (nr_ > 0 && nc_ > 0)
                    data1 = &img1[0][0];
                else
                    data1 = 0;
            }

            typedef typename model_type::value_type value_type;

            long nr() const { return nr_; }
            long nc() const { return nc_; }

            value_type factor_value (
                unsigned long idx
            ) const 
            {
                return model.factor_value(*(data1 + idx));
            }

            value_type factor_value_disagreement (
                unsigned long idx1,
                unsigned long idx2
            ) const 
            {
                return model.factor_value_disagreement(*(data1 + idx1), *(data1 + idx2));
            }
        };

    }

// ---------------------------------------------------------------------------------------- 

    template <
        typename pair_image_model,
        typename pixel_type1,
        typename pixel_type2,
        typename mem_manager
        >
    impl::potts_grid_image_pair_model<pixel_type1, pixel_type2, pair_image_model> make_potts_grid_problem (
        const pair_image_model& model,
        const array2d<pixel_type1,mem_manager>& img1,
        const array2d<pixel_type2,mem_manager>& img2
    )
    {
        // Both images describe the same grid, so they must have identical extents.
        DLIB_ASSERT(get_rect(img1) == get_rect(img2),
            "\t potts_grid_problem make_potts_grid_problem()"
            << "\n\t Invalid inputs were given to this function." 
            << "\n\t get_rect(img1): " << get_rect(img1)
            << "\n\t get_rect(img2): " << get_rect(img2)
            );
        // The returned adapter refers to model, img1 and img2 rather than copying them.
        return impl::potts_grid_image_pair_model<pixel_type1, pixel_type2, pair_image_model>(model, img1, img2);
    }

// ---------------------------------------------------------------------------------------- 

    template <
        typename single_image_model,
        typename pixel_type,
        typename mem_manager
        >
    impl::potts_grid_image_single_model<array2d<pixel_type,mem_manager>, single_image_model> make_potts_grid_problem (
        const single_image_model& model,
        const array2d<pixel_type,mem_manager>& img
    )
    {
        // The returned adapter refers to model and img rather than copying them,
        // so both must outlive it.
        typedef array2d<pixel_type,mem_manager> image_type;
        return impl::potts_grid_image_single_model<image_type, single_image_model>(model, img);
    }

// ---------------------------------------------------------------------------------------- 

}

#endif // DLIB_FIND_MAX_FACTOR_GRAPH_PoTTS_Hh_