// Copyright (C) 2010  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_AnY_TRAINER_H_
#define DLIB_AnY_TRAINER_H_

#include "any.h"

#include "any_decision_function.h"

#include "any_trainer_abstract.h"
#include <memory>
#include <utility>
#include <vector>

namespace dlib
{

// ----------------------------------------------------------------------------------------

    template <
        typename sample_type_,
        typename scalar_type_ = double
        >
    class any_trainer
    {
    public:
        typedef sample_type_ sample_type;
        typedef scalar_type_ scalar_type;
        typedef default_memory_manager mem_manager_type;
        typedef any_decision_function<sample_type, scalar_type> trained_function_type;


        any_trainer()
        {
        }

        any_trainer (
            const any_trainer& item
        )
        {
            if (item.data)
            {
                item.data->copy_to(data);
            }
        }

        template <typename T>
        any_trainer (
            const T& item
        )
        {
            typedef typename basic_type<T>::type U;
            data.reset(new derived<U>(item));
        }

        void clear (
        )
        {
            data.reset();
        }

        template <typename T>
        bool contains (
        ) const
        {
            typedef typename basic_type<T>::type U;
            return dynamic_cast<derived<U>*>(data.get()) != 0;
        }

        bool is_empty(
        ) const
        {
            return data.get() == 0;
        }

        trained_function_type train (
            const std::vector<sample_type>& samples,
            const std::vector<scalar_type>& labels
        ) const
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(is_empty() == false,
                "\t trained_function_type any_trainer::train()"
                << "\n\t You can't call train() on an empty any_trainer"
                << "\n\t this: " << this
                );

            return data->train(samples, labels);
        }

        template <typename T>
        T& cast_to(
        ) 
        {
            typedef typename basic_type<T>::type U;
            derived<U>* d = dynamic_cast<derived<U>*>(data.get());
            if (d == 0)
            {
                throw bad_any_cast();
            }

            return d->item;
        }

        template <typename T>
        const T& cast_to(
        ) const
        {
            typedef typename basic_type<T>::type U;
            derived<U>* d = dynamic_cast<derived<U>*>(data.get());
            if (d == 0)
            {
                throw bad_any_cast();
            }

            return d->item;
        }

        template <typename T>
        T& get(
        ) 
        {
            typedef typename basic_type<T>::type U;
            derived<U>* d = dynamic_cast<derived<U>*>(data.get());
            if (d == 0)
            {
                d = new derived<U>();
                data.reset(d);
            }

            return d->item;
        }

        any_trainer& operator= (
            const any_trainer& item
        )
        {
            any_trainer(item).swap(*this);
            return *this;
        }

        void swap (
            any_trainer& item
        )
        {
            data.swap(item.data);
        }

    private:

        struct base
        {
            virtual ~base() {}

            virtual trained_function_type train (
                const std::vector<sample_type>& samples,
                const std::vector<scalar_type>& labels
            ) const = 0;

            virtual void copy_to (
                std::unique_ptr<base>& dest
            ) const = 0;
        };

        template <typename T>
        struct derived : public base
        {
            T item;
            derived() {}
            derived(const T& val) : item(val) {}

            virtual void copy_to (
                std::unique_ptr<base>& dest
            ) const
            {
                dest.reset(new derived<T>(item));
            }

            virtual trained_function_type train (
                const std::vector<sample_type>& samples,
                const std::vector<scalar_type>& labels
            ) const
            {
                return item.train(samples, labels);
            }
        };

        std::unique_ptr<base> data;
    };

// ----------------------------------------------------------------------------------------

    // Non-member swap for any_trainer: exchanges the contents of a and b by
    // delegating to the member function any_trainer::swap().
    template <
        typename sample_type,
        typename scalar_type
        >
    inline void swap (
        any_trainer<sample_type,scalar_type>& a,
        any_trainer<sample_type,scalar_type>& b
    )
    {
        a.swap(b);
    }

// ----------------------------------------------------------------------------------------

    // Free-function form of any_trainer::cast_to<T>() for mutable objects:
    // returns a reference to the T held by a.  Any exception raised by
    // cast_to<T>() (bad_any_cast on a type mismatch or empty a) propagates.
    template <typename T, typename U, typename V>
    T& any_cast (
        any_trainer<U,V>& a
    )
    {
        return a.template cast_to<T>();
    }

    // Free-function form of the const any_trainer::cast_to<T>(): returns a
    // read-only reference to the T held by a.  Any exception raised by
    // cast_to<T>() (bad_any_cast on a type mismatch or empty a) propagates.
    template <typename T, typename U, typename V>
    const T& any_cast (
        const any_trainer<U,V>& a
    )
    {
        return a.template cast_to<T>();
    }

// ----------------------------------------------------------------------------------------

}


#endif // DLIB_AnY_TRAINER_H_