// Copyright (C) 2011  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_STRUCTURAL_OBJECT_DETECTION_TRAiNER_Hh_
#define DLIB_STRUCTURAL_OBJECT_DETECTION_TRAiNER_Hh_

#include "structural_object_detection_trainer_abstract.h"
#include "../algs.h"
#include "../optimization.h"
#include "structural_svm_object_detection_problem.h"
#include "../image_processing/object_detector.h"
#include "../image_processing/box_overlap_testing.h"
#include "../image_processing/full_object_detection.h"


namespace dlib
{

// ----------------------------------------------------------------------------------------

    // Generic fallback hook for configuring a structural SVM problem from a scanner.
    //
    // structural_object_detection_trainer::train_impl() calls this function unqualified
    // as configure_nuclear_norm_regularizer(scanner, svm_prob) just before solving.  A
    // scanner type that needs extra regularization set up on the problem object can
    // supply a more specific overload; for every other scanner this template is selected
    // and intentionally leaves both arguments untouched.
    template <
        typename image_scanner_type,
        typename svm_struct_prob_type
        >
    void configure_nuclear_norm_regularizer (
        const image_scanner_type& /* scanner (unused here) */,
        svm_struct_prob_type&     /* prob    (unused here) */
    )
    {
        // Intentionally empty: nothing to configure in the generic case.
    }

// ----------------------------------------------------------------------------------------

    template <
        typename image_scanner_type
        >
    class structural_object_detection_trainer : noncopyable
    {
        /*!
            WHAT THIS OBJECT REPRESENTS
                A trainer that produces an object_detector<image_scanner_type>.  Each
                call to train() builds a structural_svm_object_detection_problem from
                the supplied images and truth detections, configures it with the
                settings stored in this object, runs the stored oca solver on it, and
                wraps the resulting weight vector in an object_detector together with
                this trainer's scanner and the problem's overlap tester.
        !*/

    public:
        typedef double scalar_type;
        typedef default_memory_manager mem_manager_type;
        typedef object_detector<image_scanner_type> trained_function_type;


        // Creates a trainer whose scanner has the configuration of scanner_.
        // requires: scanner_.get_num_detection_templates() > 0
        // Defaults: C = 1, epsilon = 0.1, match_eps = 0.5, 2 threads, cache size 5,
        // both per-error losses = 1, quiet mode, overlap tester chosen automatically.
        explicit structural_object_detection_trainer (
            const image_scanner_type& scanner_
        ) :
            // Listed in member declaration order (see the private section).
            C(1),
            eps(0.1),
            match_eps(0.5),
            verbose(false),
            num_threads(2),
            max_cache_size(5),
            loss_per_missed_target(1),
            loss_per_false_alarm(1),
            auto_overlap_tester(true)
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(scanner_.get_num_detection_templates() > 0,
                "\t structural_object_detection_trainer::structural_object_detection_trainer(scanner_)"
                << "\n\t You can't have zero detection templates"
                << "\n\t this: " << this
                );

            // Only the scanner's configuration is copied, not the scanner object itself.
            scanner.copy_configuration(scanner_);
        }

        // Returns the scanner whose configuration is used for training and which is
        // embedded in the detectors returned by train().
        const image_scanner_type& get_scanner (
        ) const
        {
            return scanner;
        }

        // True when no overlap tester has been supplied via set_overlap_tester().  In
        // that case the detector returned by train() uses whatever tester the problem
        // object reports (svm_prob.get_overlap_tester()).
        bool auto_set_overlap_tester (
        ) const 
        { 
            return auto_overlap_tester; 
        }

        // Stores a user-supplied overlap tester and turns off automatic selection.
        // Both the tester and the flag are handed to the problem object in train().
        void set_overlap_tester (
            const test_box_overlap& tester
        )
        {
            overlap_tester = tester;
            auto_overlap_tester = false;
        }

        // Returns the tester stored by set_overlap_tester().
        // requires: auto_set_overlap_tester() == false
        test_box_overlap get_overlap_tester (
        ) const
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(auto_set_overlap_tester() == false,
                "\t test_box_overlap structural_object_detection_trainer::get_overlap_tester()"
                << "\n\t You can't call this function if the overlap tester is generated dynamically."
                << "\n\t this: " << this
                );

            return overlap_tester;
        }

        // Thread count passed to the structural_svm_object_detection_problem constructor.
        void set_num_threads (
            unsigned long num
        )
        {
            num_threads = num;
        }

        unsigned long get_num_threads (
        ) const
        {
            return num_threads;
        }

        // Solver stopping tolerance, forwarded to svm_prob.set_epsilon() in train().
        // requires: eps_ > 0
        void set_epsilon (
            scalar_type eps_
        )
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(eps_ > 0,
                "\t void structural_object_detection_trainer::set_epsilon()"
                << "\n\t eps_ must be greater than 0"
                << "\n\t eps_: " << eps_ 
                << "\n\t this: " << this
                );

            eps = eps_;
        }

        scalar_type get_epsilon (
        ) const { return eps; }

        // Forwarded directly to the stored oca solver.  Note that a later call to
        // set_oca() replaces the whole solver, including this setting.
        void set_max_runtime (
            const std::chrono::nanoseconds& max_runtime
        ) 
        {
            solver.set_max_runtime(max_runtime);
        }

        std::chrono::nanoseconds get_max_runtime (
        ) const
        {
            return solver.get_max_runtime();
        }

        // Forwarded to svm_prob.set_max_cache_size() in train().
        void set_max_cache_size (
            unsigned long max_size
        )
        {
            max_cache_size = max_size;
        }

        unsigned long get_max_cache_size (
        ) const
        {
            return max_cache_size; 
        }

        // When enabled, train() calls svm_prob.be_verbose() on the problem object.
        void be_verbose (
        )
        {
            verbose = true;
        }

        void be_quiet (
        )
        {
            verbose = false;
        }

        // Replaces the optimizer used by train() (this also overwrites any value set
        // earlier with set_max_runtime(), since that setting lives inside the solver).
        void set_oca (
            const oca& item
        )
        {
            solver = item;
        }

        // Returns a copy of the optimizer used by train().
        const oca get_oca (
        ) const
        {
            return solver;
        }

        // SVM C parameter, forwarded to svm_prob.set_c() in train().
        // requires: C_ > 0
        void set_c (
            scalar_type C_ 
        )
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(C_ > 0,
                "\t void structural_object_detection_trainer::set_c()"
                << "\n\t C_ must be greater than 0"
                << "\n\t C_:    " << C_ 
                << "\n\t this: " << this
                );

            C = C_;
        }

        scalar_type get_c (
        ) const
        {
            return C;
        }

        // Forwarded to svm_prob.set_match_eps() in train().
        // requires: 0 < match_eps_ < 1
        // (The parameter is deliberately not called `eps`: that would shadow the
        // solver-epsilon data member of the same name.)
        void set_match_eps (
            double match_eps_
        )
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(0 < match_eps_ && match_eps_ < 1, 
                "\t void structural_object_detection_trainer::set_match_eps(eps)"
                << "\n\t Invalid inputs were given to this function "
                << "\n\t eps:  " << match_eps_ 
                << "\n\t this: " << this
                );

            match_eps = match_eps_;
        }

        double get_match_eps (
        ) const
        {
            return match_eps;
        }

        double get_loss_per_missed_target (
        ) const
        {
            return loss_per_missed_target;
        }

        // Forwarded to svm_prob.set_loss_per_missed_target() in train().
        // requires: loss > 0
        void set_loss_per_missed_target (
            double loss
        )
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(loss > 0, 
                "\t void structural_object_detection_trainer::set_loss_per_missed_target(loss)"
                << "\n\t Invalid inputs were given to this function "
                << "\n\t loss: " << loss
                << "\n\t this: " << this
                );

            loss_per_missed_target = loss;
        }

        double get_loss_per_false_alarm (
        ) const
        {
            return loss_per_false_alarm;
        }

        // Forwarded to svm_prob.set_loss_per_false_alarm() in train().
        // requires: loss > 0
        void set_loss_per_false_alarm (
            double loss
        )
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(loss > 0, 
                "\t void structural_object_detection_trainer::set_loss_per_false_alarm(loss)"
                << "\n\t Invalid inputs were given to this function "
                << "\n\t loss: " << loss
                << "\n\t this: " << this
                );

            loss_per_false_alarm = loss;
        }

        // Trains on images with full (part-annotated) truth detections and no ignore
        // regions.  truth_object_detections[i] holds the truth for images[i].
        template <
            typename image_array_type
            >
        const trained_function_type train (
            const image_array_type& images,
            const std::vector<std::vector<full_object_detection> >& truth_object_detections
        ) const
        {
            // One (empty) ignore list per image, as train_impl() requires.
            std::vector<std::vector<rectangle> > empty_ignore(images.size());
            return train_impl(images, truth_object_detections, empty_ignore, test_box_overlap());
        }

        // Same as above, plus per-image ignore rectangles.  ignore[i] and
        // ignore_overlap_tester are handed unchanged to the problem object; requires
        // ignore.size() == images.size().
        template <
            typename image_array_type
            >
        const trained_function_type train (
            const image_array_type& images,
            const std::vector<std::vector<full_object_detection> >& truth_object_detections,
            const std::vector<std::vector<rectangle> >& ignore,
            const test_box_overlap& ignore_overlap_tester = test_box_overlap()
        ) const
        {
            return train_impl(images, truth_object_detections, ignore, ignore_overlap_tester);
        }

        // Trains on images whose truth is given as plain rectangles, with no ignore
        // regions.
        template <
            typename image_array_type
            >
        const trained_function_type train (
            const image_array_type& images,
            const std::vector<std::vector<rectangle> >& truth_object_detections
        ) const
        {
            std::vector<std::vector<rectangle> > empty_ignore(images.size());
            return train(images, truth_object_detections, empty_ignore, test_box_overlap());
        }

        // Trains on rectangle truth plus ignore rectangles.  Each truth rectangle is
        // wrapped in a full_object_detection before training.  NOTE(review): train_impl()
        // asserts that num_parts() matches the scanner's movable-component count, so this
        // presumably only suits scanners without movable parts — confirm against
        // full_object_detection's rectangle constructor.
        template <
            typename image_array_type
            >
        const trained_function_type train (
            const image_array_type& images,
            const std::vector<std::vector<rectangle> >& truth_object_detections,
            const std::vector<std::vector<rectangle> >& ignore,
            const test_box_overlap& ignore_overlap_tester = test_box_overlap()
        ) const
        {
            std::vector<std::vector<full_object_detection> > truth_dets(truth_object_detections.size());
            for (unsigned long i = 0; i < truth_object_detections.size(); ++i)
            {
                // The output size is known up front; avoid repeated reallocation.
                truth_dets[i].reserve(truth_object_detections[i].size());
                for (unsigned long j = 0; j < truth_object_detections[i].size(); ++j)
                {
                    truth_dets[i].push_back(full_object_detection(truth_object_detections[i][j]));
                }
            }

            return train_impl(images, truth_dets, ignore, ignore_overlap_tester);
        }

    private:

        // Shared implementation behind every train() overload: validates the inputs
        // (when asserts are enabled), builds and configures the structural SVM problem,
        // optimizes the weight vector w, and packages the result as a detector.
        template <
            typename image_array_type
            >
        const trained_function_type train_impl (
            const image_array_type& images,
            const std::vector<std::vector<full_object_detection> >& truth_object_detections,
            const std::vector<std::vector<rectangle> >& ignore,
            const test_box_overlap& ignore_overlap_tester
        ) const
        {
#ifdef ENABLE_ASSERTS
            // make sure requires clause is not broken
            DLIB_ASSERT(is_learning_problem(images,truth_object_detections) == true && images.size() == ignore.size(),
                "\t trained_function_type structural_object_detection_trainer::train()"
                << "\n\t invalid inputs were given to this function"
                << "\n\t images.size():      " << images.size()
                << "\n\t ignore.size():      " << ignore.size()
                << "\n\t truth_object_detections.size(): " << truth_object_detections.size()
                << "\n\t is_learning_problem(images,truth_object_detections): " << is_learning_problem(images,truth_object_detections)
                << "\n\t this: " << this
                );
            // Every truth detection must carry exactly the number of parts the scanner
            // expects, and all of its parts must satisfy all_parts_in_rect().
            for (unsigned long i = 0; i < truth_object_detections.size(); ++i)
            {
                for (unsigned long j = 0; j < truth_object_detections[i].size(); ++j)
                {
                    DLIB_ASSERT(truth_object_detections[i][j].num_parts() == get_scanner().get_num_movable_components_per_detection_template() &&
                                all_parts_in_rect(truth_object_detections[i][j]) == true,
                        "\t trained_function_type structural_object_detection_trainer::train()"
                        << "\n\t invalid inputs were given to this function"
                        << "\n\t truth_object_detections["<<i<<"]["<<j<<"].num_parts():                " << 
                            truth_object_detections[i][j].num_parts()
                        << "\n\t get_scanner().get_num_movable_components_per_detection_template(): " << 
                            get_scanner().get_num_movable_components_per_detection_template()
                        << "\n\t all_parts_in_rect(truth_object_detections["<<i<<"]["<<j<<"]): " << all_parts_in_rect(truth_object_detections[i][j])
                        << "\n\t this: " << this
                    );
                }
            }
#endif

            structural_svm_object_detection_problem<image_scanner_type,image_array_type > 
                svm_prob(scanner, overlap_tester, auto_overlap_tester, images,
                    truth_object_detections, ignore, ignore_overlap_tester, num_threads);

            if (verbose)
                svm_prob.be_verbose();

            svm_prob.set_c(C);
            svm_prob.set_epsilon(eps);
            svm_prob.set_max_cache_size(max_cache_size);
            svm_prob.set_match_eps(match_eps);
            svm_prob.set_loss_per_missed_target(loss_per_missed_target);
            svm_prob.set_loss_per_false_alarm(loss_per_false_alarm);
            // Unqualified call so scanner-specific overloads are found; the generic
            // version does nothing.
            configure_nuclear_norm_regularizer(scanner, svm_prob);
            matrix<double,0,1> w;

            // Run the optimizer to find the optimal w.
            solver(svm_prob,w);

            // report the results of the training.  The overlap tester comes from the
            // problem object so that an automatically chosen tester is used when
            // auto_overlap_tester is true.
            return object_detector<image_scanner_type>(scanner, svm_prob.get_overlap_tester(), w);
        }

        // NOTE: the constructor's initializer list relies on this declaration order.
        image_scanner_type scanner;          // configuration copied from the ctor argument
        test_box_overlap overlap_tester;     // only meaningful when !auto_overlap_tester

        double C;                            // forwarded to svm_prob.set_c()
        oca solver;                          // optimizer run by train_impl()
        double eps;                          // solver epsilon, forwarded to svm_prob.set_epsilon()
        double match_eps;                    // forwarded to svm_prob.set_match_eps()
        bool verbose;                        // if true, svm_prob.be_verbose() is called
        unsigned long num_threads;           // passed to the problem constructor
        unsigned long max_cache_size;        // forwarded to svm_prob.set_max_cache_size()
        double loss_per_missed_target;       // forwarded to the problem object
        double loss_per_false_alarm;         // forwarded to the problem object
        bool auto_overlap_tester;            // true until set_overlap_tester() is called

    }; 

// ----------------------------------------------------------------------------------------

}

#endif // DLIB_STRUCTURAL_OBJECT_DETECTION_TRAiNER_Hh_