// Copyright (C) 2010 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_AnY_TRAINER_H_
#define DLIB_AnY_TRAINER_H_
#include "any.h"
#include "any_decision_function.h"
#include "any_trainer_abstract.h"
#include <vector>
namespace dlib
{
// ----------------------------------------------------------------------------------------
template <
typename sample_type_,
typename scalar_type_ = double
>
class any_trainer
{
public:
using sample_type = sample_type_;
using scalar_type = scalar_type_;
using mem_manager_type = default_memory_manager;
using trained_function_type = any_decision_function<sample_type, scalar_type>;
any_trainer() = default;
any_trainer(const any_trainer& other) = default;
any_trainer& operator=(const any_trainer& other) = default;
any_trainer(any_trainer&& other) = default;
any_trainer& operator=(any_trainer&& other) = default;
template <
class T,
class T_ = std::decay_t<T>,
std::enable_if_t<!std::is_same<T_,any_trainer>::value, bool> = true
>
any_trainer (
T&& item
) : storage{std::forward<T>(item)},
train_func{[](
const void* ptr,
const std::vector<sample_type>& samples,
const std::vector<scalar_type>& labels
) -> trained_function_type {
const T_& f = *reinterpret_cast<const T_*>(ptr);
return f.train(samples, labels);
}}
{
}
template <
class T,
class T_ = std::decay_t<T>,
std::enable_if_t<!std::is_same<T_,any_trainer>::value, bool> = true
>
any_trainer& operator= (
T&& item
)
{
if (contains<T_>())
storage.unsafe_get<T_>() = std::forward<T>(item);
else
*this = std::move(any_trainer{std::forward<T>(item)});
return *this;
}
trained_function_type train (
const std::vector<sample_type>& samples,
const std::vector<scalar_type>& labels
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(is_empty() == false,
"\t trained_function_type any_trainer::train()"
<< "\n\t You can't call train() on an empty any_trainer"
<< "\n\t this: " << this
);
return train_func(storage.get_ptr(), samples, labels);
}
bool is_empty() const { return storage.is_empty(); }
void clear() { storage.clear(); }
void swap (any_trainer& item) { std::swap(*this, item); }
template <typename T> bool contains() const { return storage.contains<T>();}
template <typename T> T& cast_to() { return storage.cast_to<T>(); }
template <typename T> const T& cast_to() const { return storage.cast_to<T>(); }
template <typename T> T& get() { return storage.get<T>(); }
private:
te::storage_heap storage;
trained_function_type (*train_func) (
const void* self,
const std::vector<sample_type>& samples,
const std::vector<scalar_type>& labels
) = nullptr;
};
// ----------------------------------------------------------------------------------------
// Free-function spelling of a.cast_to<T>() for a mutable any_trainer: hands back a
// reference to the object wrapped by `a`, viewed as a T.  What happens when `a` does not
// hold a T is decided entirely by cast_to.
template <typename T, typename U, typename V>
T& any_cast (
    any_trainer<U,V>& a
)
{
    T& held = a.template cast_to<T>();
    return held;
}
// Read-only counterpart of any_cast: forwards to the const overload of a.cast_to<T>() and
// returns the wrapped object as a const T reference.  Mismatch handling is left to cast_to.
template <typename T, typename U, typename V>
const T& any_cast (
    const any_trainer<U,V>& a
)
{
    const T& held = a.template cast_to<T>();
    return held;
}
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_AnY_TRAINER_H_