// Copyright (C) 2010  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_AnY_TRAINER_H_
#define DLIB_AnY_TRAINER_H_

#include "any.h"
#include "any_decision_function.h"
#include "any_trainer_abstract.h"
#include <vector>

namespace dlib
{

// ----------------------------------------------------------------------------------------

    template <
        typename sample_type_,
        typename scalar_type_ = double
    >
    class any_trainer
    {
    public:
        using sample_type = sample_type_;
        using scalar_type = scalar_type_;
        using mem_manager_type = default_memory_manager;
        using trained_function_type = any_decision_function<sample_type, scalar_type>;

        any_trainer()                                       = default;
        any_trainer(const any_trainer& other)               = default;
        any_trainer& operator=(const any_trainer& other)    = default;
        any_trainer(any_trainer&& other)                    = default;
        any_trainer& operator=(any_trainer&& other)         = default;

        template <
            class T,
            class T_ = std::decay_t<T>,
            std::enable_if_t<!std::is_same<T_,any_trainer>::value, bool> = true
        >
        any_trainer (
            T&& item
        ) : storage{std::forward<T>(item)},
            train_func{[](
                const void* ptr,
                const std::vector<sample_type>& samples,
                const std::vector<scalar_type>& labels
            ) -> trained_function_type {
                const T_& f = *reinterpret_cast<const T_*>(ptr);
                return f.train(samples, labels);
            }}
        {
        }

        template <
            class T,
            class T_ = std::decay_t<T>,
            std::enable_if_t<!std::is_same<T_,any_trainer>::value, bool> = true
        >
        any_trainer& operator= (
            T&& item
        )
        {
            if (contains<T_>())
                storage.unsafe_get<T_>() = std::forward<T>(item);
            else
                *this = std::move(any_trainer{std::forward<T>(item)});
            return *this;
        }

        trained_function_type train (
            const std::vector<sample_type>& samples,
            const std::vector<scalar_type>& labels
        ) const
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(is_empty() == false,
                "\t trained_function_type any_trainer::train()"
                << "\n\t You can't call train() on an empty any_trainer"
                << "\n\t this: " << this
                );

            return train_func(storage.get_ptr(), samples, labels);
        }

        bool is_empty() const { return storage.is_empty(); }
        void clear()          { storage.clear(); }
        void swap (any_trainer& item) { std::swap(*this, item); }

        template <typename T>     bool contains() const { return storage.contains<T>();}
        template <typename T>       T& cast_to()        { return storage.cast_to<T>(); }
        template <typename T> const T& cast_to() const  { return storage.cast_to<T>(); }
        template <typename T>       T& get()            { return storage.get<T>(); }

    private:
        te::storage_heap storage;
        trained_function_type (*train_func) (
            const void* self,
            const std::vector<sample_type>& samples,
            const std::vector<scalar_type>& labels
        ) = nullptr;
    };

// ----------------------------------------------------------------------------------------

    template <typename T, typename U, typename V> 
    T& any_cast(any_trainer<U,V>& a) { return a.template cast_to<T>(); }

    template <typename T, typename U, typename V> 
    const T& any_cast(const any_trainer<U,V>& a) { return a.template cast_to<T>(); }

// ----------------------------------------------------------------------------------------

}


#endif // DLIB_AnY_TRAINER_H_