// Copyright (C) 2011  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_STRUCTURAL_SVM_DISTRIBUTeD_Hh_
#define DLIB_STRUCTURAL_SVM_DISTRIBUTeD_Hh_

#include <memory>
#include <iostream>
#include <vector>

#include "structural_svm_distributed_abstract.h"
#include "structural_svm_problem.h"
#include "../bridge.h"
#include "../misc_api.h"
#include "../statistics.h"
#include "../threads.h"
#include "../pipe.h"
#include "../type_safe_union.h"


namespace dlib
{

// ----------------------------------------------------------------------------------------

    namespace impl
    {

        template <typename matrix_type>
        struct oracle_response
        {
            /*!
                Message a processing node sends back to the controller after running the
                separation oracle over its samples.
                    - subgradient: un-normalized subgradient sum for those samples
                    - loss:        un-normalized loss sum for those samples
                    - num:         how many samples contributed to the two sums above
            !*/
            using scalar_type = typename matrix_type::type;

            matrix_type subgradient;
            scalar_type loss;
            long num;

            // Member-wise exchange of two responses.
            friend void swap (oracle_response& lhs, oracle_response& rhs)
            {
                lhs.subgradient.swap(rhs.subgradient);
                std::swap(lhs.num, rhs.num);
                std::swap(lhs.loss, rhs.loss);
            }

            // Wire format: subgradient, then loss, then num.  deserialize() below must
            // read the fields back in exactly this order.
            friend void serialize (const oracle_response& resp, std::ostream& stream)
            {
                serialize(resp.subgradient, stream);
                dlib::serialize(resp.loss, stream);
                dlib::serialize(resp.num, stream);
            }

            friend void deserialize (oracle_response& resp, std::istream& stream)
            {
                deserialize(resp.subgradient, stream);
                dlib::deserialize(resp.loss, stream);
                dlib::deserialize(resp.num, stream);
            }
        };

    // ----------------------------------------------------------------------------------------

        template <typename matrix_type>
        struct oracle_request
        {
            /*!
                Message the controller sends to every processing node to ask it to run
                the separation oracle at the given solution.
                    - current_solution:       the weight vector to evaluate
                    - saved_current_risk_gap,
                      skip_cache, converged:  copies of the controller problem's cache
                                              control state, forwarded unchanged to
                                              separation_oracle_cached() on the node
            !*/
            using scalar_type = typename matrix_type::type;

            matrix_type current_solution;
            scalar_type saved_current_risk_gap;
            bool skip_cache;
            bool converged;

            // Member-wise exchange of two requests.
            friend void swap (oracle_request& lhs, oracle_request& rhs)
            {
                lhs.current_solution.swap(rhs.current_solution);
                std::swap(lhs.converged, rhs.converged);
                std::swap(lhs.skip_cache, rhs.skip_cache);
                std::swap(lhs.saved_current_risk_gap, rhs.saved_current_risk_gap);
            }

            // Wire format: current_solution, saved_current_risk_gap, skip_cache,
            // converged.  deserialize() below must read them in the same order.
            friend void serialize (const oracle_request& req, std::ostream& stream)
            {
                serialize(req.current_solution, stream);
                dlib::serialize(req.saved_current_risk_gap, stream);
                dlib::serialize(req.skip_cache, stream);
                dlib::serialize(req.converged, stream);
            }

            friend void deserialize (oracle_request& req, std::istream& stream)
            {
                deserialize(req.current_solution, stream);
                dlib::deserialize(req.saved_current_risk_gap, stream);
                dlib::deserialize(req.skip_cache, stream);
                dlib::deserialize(req.converged, stream);
            }
        };

    }

// ----------------------------------------------------------------------------------------

    /*!
        WHAT THIS OBJECT REPRESENTS
            A network server that exposes one structural_svm_problem to a remote
            svm_struct_controller_node.  It listens on a port through a dlib bridge
            (listen_on_port).  When a peer connects it replies with the problem's
            dimensionality.  For every oracle request it receives, it runs the
            separation oracle over all of the problem's samples in parallel on a
            thread_pool.  It then replies with the un-normalized loss and subgradient
            sums plus the number of samples used.

            The problem is stored by reference (see node_type::problem), so the
            caller must keep it alive for the lifetime of this object.
    !*/
    class svm_struct_processing_node : noncopyable
    {
    public:

        template <
            typename T,
            typename U 
            >
        svm_struct_processing_node (
            const structural_svm_problem<T,U>& problem,
            unsigned short port,
            unsigned short num_threads
        )
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(port != 0 && problem.get_num_samples() != 0 &&
                        problem.get_num_dimensions() != 0,
                "\t svm_struct_processing_node()"
                << "\n\t Invalid arguments were given to this function"
                << "\n\t port: " << port 
                << "\n\t problem.get_num_samples():    " << problem.get_num_samples() 
                << "\n\t problem.get_num_dimensions(): " << problem.get_num_dimensions() 
                << "\n\t this: " << this
                );

            // node_type's constructor calls start(), so serving begins as soon as
            // this returns.
            the_problem.reset(new node_type<T,U>(problem, port, num_threads));
        }

    private:

        // Non-template handle so the outer class does not need template parameters;
        // the concrete node_type<T,U> is owned through the_problem.
        struct base
        {
            virtual ~base(){}
        };

        // Type-specific implementation.  It owns the bridge and pipes, plus a
        // background thread (threaded_object) that services incoming messages.
        template <
            typename matrix_type,
            typename feature_vector_type 
            >
        class node_type : public base, threaded_object
        {
        public:
            typedef typename matrix_type::type scalar_type;

            node_type(
                const structural_svm_problem<matrix_type,feature_vector_type>& prob,
                unsigned short port,
                unsigned long num_threads
            ) : in(3),out(3), problem(prob), tp(num_threads)
            {
                // Messages received from the network arrive in `in`; anything
                // enqueued on `out` is transmitted to the connected peer.
                b.reconfigure(listen_on_port(port), receive(in), transmit(out));

                start();
            }

            ~node_type()
            {
                // Disable the pipes so thread()'s in.dequeue() loop terminates, then
                // wait for the thread to finish before members are destroyed.
                in.disable();
                out.disable();
                wait();
            }

        private:

            // Background message loop.  It handles connection notifications and
            // oracle requests until `in` is disabled.
            void thread()
            {
                using namespace impl;
                tsu_in msg; 
                // Outgoing message object, reused for every reply.
                tsu_out temp;

                // Wall-clock timing of the two accumulation strategies in
                // binder::call_oracle, used to pick the faster one adaptively.
                timestamper ts;
                running_stats<double> with_buffer_time;
                running_stats<double> without_buffer_time;
                unsigned long num_iterations_executed = 0;

                while (in.dequeue(msg))
                {
                    // initialize the cache and compute psi_true.
                    // This happens lazily on the first message of any kind.  psi_true
                    // holds the negated sum of every sample's truth joint feature
                    // vector, and it seeds each response's subgradient.
                    if (cache.size() == 0)
                    {
                        cache.resize(problem.get_num_samples());
                        for (unsigned long i = 0; i < cache.size(); ++i)
                            cache[i].init(&problem,i);

                        psi_true.set_size(problem.get_num_dimensions(),1);
                        psi_true = 0;

                        const unsigned long num = problem.get_num_samples();
                        feature_vector_type ftemp;
                        for (unsigned long i = 0; i < num; ++i)
                        {
                            cache[i].get_truth_joint_feature_vector_cached(ftemp);

                            subtract_from(psi_true, ftemp);
                        }
                    }


                    // A peer just connected: tell it the problem dimensionality (sent
                    // as a long).  Disconnect notifications fall through both branches
                    // and are ignored.
                    if (msg.template contains<bridge_status>() && 
                        msg.template get<bridge_status>().is_connected)
                    {
                        temp = problem.get_num_dimensions();
                        out.enqueue(temp);

                    }
                    else if (msg.template contains<oracle_request<matrix_type> >())
                    {
                        ++num_iterations_executed;

                        const oracle_request<matrix_type>& req = msg.template get<oracle_request<matrix_type> >();

                        // Switch temp to holding a response.  Every field is assigned
                        // below, so whatever temp held before is irrelevant.
                        oracle_response<matrix_type>& data = temp.template get<oracle_response<matrix_type> >();

                        data.subgradient = psi_true;
                        data.loss = 0;

                        data.num = problem.get_num_samples();

                        const uint64 start_time = ts.get_timestamp();

                        // pick fastest buffering strategy
                        bool buffer_subgradients_locally = with_buffer_time.mean() < without_buffer_time.mean();

                        // every 50 iterations we should try to flip the buffering scheme to see if
                        // doing it the other way might be better.  
                        if ((num_iterations_executed%50) == 0)
                        {
                            buffer_subgradients_locally = !buffer_subgradients_locally;
                        }

                        // Run the oracle over samples [0, data.num) in parallel blocks;
                        // each block adds its results into `data`.
                        binder b(*this, req, data, buffer_subgradients_locally);
                        parallel_for_blocked(tp, 0, data.num, b, &binder::call_oracle);

                        // Record the elapsed time under the strategy that was used.
                        const uint64 stop_time = ts.get_timestamp();
                        if (buffer_subgradients_locally)
                            with_buffer_time.add(stop_time-start_time);
                        else
                            without_buffer_time.add(stop_time-start_time);

                        out.enqueue(temp);
                    }
                }
            }

            // Functor handed to parallel_for_blocked().  Each call_oracle(begin,end)
            // call handles one index range and adds its losses and feature vectors
            // into the shared response `data`.  Those additions are guarded by
            // self.accum_mutex.
            struct binder
            {
                binder (
                    const node_type& self_,
                    const impl::oracle_request<matrix_type>& req_,
                    impl::oracle_response<matrix_type>& data_,
                    bool buffer_subgradients_locally_
                ) : self(self_), req(req_), data(data_),
                    buffer_subgradients_locally(buffer_subgradients_locally_) {}

                void call_oracle (
                    long begin,
                    long end
                ) 
                {
                    // If we are only going to call the separation oracle once then don't
                    // run the slightly more complex for loop version of this code.  Or if
                    // we just don't want to run the complex buffering one.  The code later
                    // on decides if we should do the buffering based on how long it takes
                    // to execute.  We do this because, when the subgradient is really high
                    // dimensional it can take a lot of time to add them together.  So we
                    // might want to avoid doing that.
                    if (end-begin <= 1 || !buffer_subgradients_locally)
                    {
                        // Unbuffered: lock and add into `data` once per sample.
                        scalar_type loss;
                        feature_vector_type ftemp;
                        for (long i = begin; i < end; ++i)
                        {
                            self.cache[i].separation_oracle_cached(req.converged, 
                                                                   req.skip_cache, 
                                                                   req.saved_current_risk_gap,
                                                                   req.current_solution,
                                                                   loss,
                                                                   ftemp);

                            auto_mutex lock(self.accum_mutex);
                            data.loss += loss;
                            add_to(data.subgradient, ftemp);
                        }
                    }
                    else
                    {
                        // Buffered: accumulate into a dense local vector, then lock
                        // and add into `data` once for the whole range.
                        scalar_type loss = 0;
                        matrix_type faccum(data.subgradient.size(),1);
                        faccum = 0;

                        feature_vector_type ftemp;

                        for (long i = begin; i < end; ++i)
                        {
                            scalar_type loss_temp;
                            self.cache[i].separation_oracle_cached(req.converged,
                                                                   req.skip_cache, 
                                                                   req.saved_current_risk_gap,
                                                                   req.current_solution,
                                                                   loss_temp,
                                                                   ftemp);
                            loss += loss_temp;
                            add_to(faccum, ftemp);
                        }

                        auto_mutex lock(self.accum_mutex);
                        data.loss += loss;
                        add_to(data.subgradient, faccum);
                    }
                }

                const node_type& self;
                const impl::oracle_request<matrix_type>& req;
                impl::oracle_response<matrix_type>& data;
                bool buffer_subgradients_locally;
            };



            // Incoming messages: oracle requests, or bridge connection status changes.
            typedef type_safe_union<impl::oracle_request<matrix_type>, bridge_status> tsu_in;
            // Outgoing messages: oracle responses, or the dimensionality (long) sent
            // when a peer connects.
            typedef type_safe_union<impl::oracle_response<matrix_type> , long> tsu_out;

            pipe<tsu_in> in;
            pipe<tsu_out> out;
            bridge b;

            mutable matrix_type psi_true;
            // Held by reference: the problem passed to the constructor must outlive
            // this node.
            const structural_svm_problem<matrix_type,feature_vector_type>& problem;
            // One cache element per sample.  binder reaches these through a const
            // node_type&, which is presumably why the member is mutable.
            mutable std::vector<cache_element_structural_svm<structural_svm_problem<matrix_type,feature_vector_type> > > cache;

            mutable thread_pool tp;
            // Serializes updates to the shared oracle_response inside binder.
            mutex accum_mutex;
        };


        std::unique_ptr<base> the_problem;
    };

// ----------------------------------------------------------------------------------------

    class svm_struct_controller_node : noncopyable
    {
    public:

        svm_struct_controller_node (
        ) :
            eps(0.001),
            max_iterations(10000),
            cache_based_eps(std::numeric_limits<double>::infinity()),
            verbose(false),
            C(1)
        {}

        double get_cache_based_epsilon (
        ) const
        {
            return cache_based_eps;
        }

        void set_cache_based_epsilon (
            double eps_
        )
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(eps_ > 0,
                "\t void svm_struct_controller_node::set_cache_based_epsilon()"
                << "\n\t eps_ must be greater than 0"
                << "\n\t eps_: " << eps_ 
                << "\n\t this: " << this
                );

            cache_based_eps = eps_;
        }

        void set_epsilon (
            double eps_
        )
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(eps_ > 0,
                "\t void svm_struct_controller_node::set_epsilon()"
                << "\n\t eps_ must be greater than 0"
                << "\n\t eps_: " << eps_ 
                << "\n\t this: " << this
                );

            eps = eps_;
        }

        double get_epsilon (
        ) const { return eps; }

        unsigned long get_max_iterations (
        ) const { return max_iterations; }

        void set_max_iterations (
            unsigned long max_iter
        ) 
        {
            max_iterations = max_iter;
        }

        void be_verbose (
        ) 
        {
            verbose = true;
        }

        void be_quiet(
        )
        {
            verbose = false;
        }

        void add_nuclear_norm_regularizer (
            long first_dimension,
            long rows,
            long cols,
            double regularization_strength
        )
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(0 <= first_dimension  &&
                0 <= rows && 0 <= cols && 
                0 < regularization_strength,
                "\t void svm_struct_controller_node::add_nuclear_norm_regularizer()"
                << "\n\t Invalid arguments were given to this function."
                << "\n\t first_dimension:         " << first_dimension 
                << "\n\t rows:                    " << rows 
                << "\n\t cols:                    " << cols 
                << "\n\t regularization_strength: " << regularization_strength 
                << "\n\t this: " << this
                );

            impl::nuclear_norm_regularizer temp;
            temp.first_dimension = first_dimension;
            temp.nr = rows;
            temp.nc = cols;
            temp.regularization_strength = regularization_strength;
            nuclear_norm_regularizers.push_back(temp);
        }

        unsigned long num_nuclear_norm_regularizers (
        ) const { return nuclear_norm_regularizers.size(); }

        void clear_nuclear_norm_regularizers (
        ) { nuclear_norm_regularizers.clear(); }


        double get_c (
        ) const { return C; }

        void set_c (
            double C_
        ) 
        { 
            // make sure requires clause is not broken
            DLIB_ASSERT(C_ > 0,
                "\t void svm_struct_controller_node::set_c()"
                << "\n\t C_ must be greater than 0"
                << "\n\t C_:    " << C_ 
                << "\n\t this: " << this
                );

            C = C_; 
        }

        void add_processing_node (
            const network_address& addr
        )
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(addr.port != 0,
                "\t void svm_struct_controller_node::add_processing_node()"
                << "\n\t Invalid inputs were given to this function"
                << "\n\t addr.host_address:   " << addr.host_address 
                << "\n\t addr.port: " << addr.port
                << "\n\t this: " << this
                );

            // check if this address is already registered
            for (unsigned long i = 0; i < nodes.size(); ++i)
            {
                if (nodes[i] == addr)
                {
                    return;
                }
            }
            
            nodes.push_back(addr);
        }

        void add_processing_node (
            const std::string& ip_or_hostname,
            unsigned short port
        )
        {
            add_processing_node(network_address(ip_or_hostname,port));
        }

        unsigned long get_num_processing_nodes (
        ) const
        {
            return nodes.size();
        }

        void remove_processing_nodes (
        ) 
        {
            nodes.clear();
        }

        template <typename matrix_type>
        double operator() (
            const oca& solver,
            matrix_type& w
        ) const
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(get_num_processing_nodes() != 0,
                        "\t double svm_struct_controller_node::operator()"
                        << "\n\t You must add some processing nodes before calling this function."
                        << "\n\t this: " << this
            );

            problem_type<matrix_type> problem(nodes);
            problem.set_cache_based_epsilon(cache_based_eps);
            problem.set_epsilon(eps);
            problem.set_max_iterations(max_iterations);
            if (verbose)
                problem.be_verbose();
            problem.set_c(C);
            for (unsigned long i = 0; i < nuclear_norm_regularizers.size(); ++i)
            {
                problem.add_nuclear_norm_regularizer(
                    nuclear_norm_regularizers[i].first_dimension,
                    nuclear_norm_regularizers[i].nr,
                    nuclear_norm_regularizers[i].nc,
                    nuclear_norm_regularizers[i].regularization_strength);
            }

            return solver(problem, w);
        }

        class invalid_problem : public error
        {
        public:
            invalid_problem(
                const std::string& a
            ): error(a) {}
        };


    private:

        template <typename matrix_type_>
        class problem_type : public structural_svm_problem<matrix_type_>
        {
        public:
            typedef typename matrix_type_::type scalar_type;
            typedef matrix_type_ matrix_type;

            problem_type (
                const std::vector<network_address>& nodes_
            ) :
                nodes(nodes_),
                in(3),
                num_dims(0)
            {

                // initialize all the transmit pipes
                out_pipes.resize(nodes.size());
                for (unsigned long i = 0; i < out_pipes.size(); ++i)
                {
                    out_pipes[i].reset(new pipe<tsu_out>(3));
                }

                // make bridges that connect to all our remote processing nodes
                bridges.resize(nodes.size());
                for (unsigned long i = 0; i< bridges.size(); ++i)
                {
                    bridges[i].reset(new bridge(connect_to(nodes[i]), 
                                                receive(in), transmit(*out_pipes[i])));
                }



                // The remote processing nodes are supposed to all send the problem dimensionality
                // upon connection. So get that and make sure everyone agrees on what it's supposed to be.
                tsu_in temp;
                unsigned long responses = 0;
                bool seen_dim = false;
                while (responses < nodes.size())
                {
                    in.dequeue(temp);
                    if (temp.template contains<long>())
                    {
                        ++responses;
                        // if this new dimension doesn't match what we have seen previously
                        if (seen_dim && num_dims != temp.template get<long>())
                        {
                            throw invalid_problem("remote hosts disagree on the number of dimensions!");
                        }
                        seen_dim = true;
                        num_dims = temp.template get<long>();
                    }
                }
            }

            // These functions are just here because the structural_svm_problem requires
            // them, but since we are overloading get_risk() they are never called so they
            // don't matter.
            virtual long get_num_samples () const {return 0;}
            virtual void get_truth_joint_feature_vector ( long , matrix_type&  ) const {}
            virtual void separation_oracle ( const long , const matrix_type& , scalar_type& , matrix_type& ) const {}

            virtual long get_num_dimensions (
            ) const
            {
                return num_dims;
            }

            virtual void get_risk (
                matrix_type& w,
                scalar_type& risk,
                matrix_type& subgradient
            ) const 
            {
                using namespace impl;
                subgradient.set_size(w.size(),1);
                subgradient = 0;

                // send out all the oracle requests
                tsu_out temp_out;
                for (unsigned long i = 0; i < out_pipes.size(); ++i)
                {
                    temp_out.template get<oracle_request<matrix_type> >().current_solution = w;
                    temp_out.template get<oracle_request<matrix_type> >().saved_current_risk_gap = this->saved_current_risk_gap;
                    temp_out.template get<oracle_request<matrix_type> >().skip_cache = this->skip_cache;
                    temp_out.template get<oracle_request<matrix_type> >().converged = this->converged;
                    out_pipes[i]->enqueue(temp_out);
                }

                // collect all the oracle responses  
                long num = 0;
                scalar_type total_loss = 0;
                tsu_in temp_in;
                unsigned long responses = 0;
                while (responses < out_pipes.size())
                {
                    in.dequeue(temp_in);
                    if (temp_in.template contains<oracle_response<matrix_type> >())
                    {
                        ++responses;
                        const oracle_response<matrix_type>& data = temp_in.template get<oracle_response<matrix_type> >();
                        subgradient += data.subgradient; 
                        total_loss += data.loss;
                        num += data.num;
                    }
                }

                subgradient /= num;
                total_loss /= num;
                risk = total_loss + dot(subgradient,w);

                if (this->nuclear_norm_regularizers.size() != 0)
                {
                    matrix_type grad; 
                    double obj;
                    this->compute_nuclear_norm_parts(w, grad, obj);
                    risk += obj;
                    subgradient += grad;
                }
            }

            std::vector<network_address> nodes;

            typedef type_safe_union<impl::oracle_request<matrix_type> > tsu_out;
            typedef type_safe_union<impl::oracle_response<matrix_type>, long> tsu_in;

            std::vector<std::shared_ptr<pipe<tsu_out> > > out_pipes;
            mutable pipe<tsu_in> in;
            std::vector<std::shared_ptr<bridge> > bridges;
            long num_dims;
        };

        std::vector<network_address> nodes;
        double eps;
        unsigned long max_iterations;
        double cache_based_eps;
        bool verbose;
        double C;
        std::vector<impl::nuclear_norm_regularizer> nuclear_norm_regularizers;
    };

// ----------------------------------------------------------------------------------------

}

#endif // DLIB_STRUCTURAL_SVM_DISTRIBUTeD_Hh_