// Copyright (C) 2015  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_APPROXIMATE_LINEAR_MODELS_Hh_
#define DLIB_APPROXIMATE_LINEAR_MODELS_Hh_

#include "approximate_linear_models_abstract.h"
#include "../matrix.h"

namespace dlib
{

// ----------------------------------------------------------------------------------------

    template <
        typename feature_extractor
        >
    struct process_sample
    {
        typedef feature_extractor feature_extractor_type;
        typedef typename feature_extractor::state_type state_type;
        typedef typename feature_extractor::action_type action_type;

        process_sample(){}

        process_sample(
            const state_type& s,
            const action_type& a,
            const state_type& n,
            const double& r
        ) : state(s), action(a), next_state(n), reward(r) {}

        state_type  state;
        action_type action;
        state_type  next_state;
        double reward;
    };

    template < typename feature_extractor >
    void serialize (const process_sample<feature_extractor>& item, std::ostream& out)
    {
        serialize(item.state, out);
        serialize(item.action, out);
        serialize(item.next_state, out);
        serialize(item.reward, out);
    }

    template < typename feature_extractor >
    void deserialize (process_sample<feature_extractor>& item, std::istream& in)
    {
        deserialize(item.state, in);
        deserialize(item.action, in);
        deserialize(item.next_state, in);
        deserialize(item.reward, in);
    }

// ----------------------------------------------------------------------------------------

    template <
        typename feature_extractor
        >
    class policy
    {
    public:

        typedef feature_extractor feature_extractor_type;
        typedef typename feature_extractor::state_type state_type;
        typedef typename feature_extractor::action_type action_type;


        policy (
        )
        {
            w.set_size(fe.num_features());
            w = 0;
        }

        policy (
            const matrix<double,0,1>& weights_,
            const feature_extractor& fe_
        ) : w(weights_), fe(fe_) {}

        action_type operator() (
            const state_type& state
        ) const
        {
            return fe.find_best_action(state,w);
        }

        const feature_extractor& get_feature_extractor (
        ) const { return fe; }

        const matrix<double,0,1>& get_weights (
        ) const { return w; }


    private:
        matrix<double,0,1> w;
        feature_extractor fe;
    };

    template < typename feature_extractor >
    inline void serialize(const policy<feature_extractor>& item, std::ostream& out)
    {
        int version = 1;
        serialize(version, out);
        serialize(item.get_feature_extractor(), out);
        serialize(item.get_weights(), out);
    }
    template < typename feature_extractor >
    inline void deserialize(policy<feature_extractor>& item, std::istream& in)
    {
        int version = 0;
        deserialize(version, in);
        if (version != 1)
            throw serialization_error("Unexpected version found while deserializing dlib::policy object.");
        feature_extractor fe;
        matrix<double,0,1> w;
        deserialize(fe, in);
        deserialize(w, in);
        item = policy<feature_extractor>(w,fe);
    }

// ----------------------------------------------------------------------------------------

}

#endif // DLIB_APPROXIMATE_LINEAR_MODELS_Hh_