// Copyright (C) 2012  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_RLS_FiLTER_Hh_
#define DLIB_RLS_FiLTER_Hh_

#include "rls_filter_abstract.h"
#include "../svm/rls.h"
#include <vector>
#include "../matrix.h"
#include "../sliding_buffer.h"

namespace dlib
{

// ----------------------------------------------------------------------------------------

    class rls_filter
    {
        /*!
            CONVENTION
                - data.size() == the number of variables in a measurement 
                - data[i].size() == data[j].size() for all i and j.  
                - data[i].size() == get_window_size() 
                - data[i][0] == most recent measurement of i-th variable given to update.
                - data[i].back() == oldest measurement of i-th variable given to update 
                  (or zero if we haven't seen this much data yet).

                - if (count <= 2) then
                    - count == number of times update(z) has been called
        !*/
    public:

        rls_filter()
        {
            size = 5;
            count = 0;
            filter = rls(0.8, 100);
        }

        explicit rls_filter (
            unsigned long size_,
            double forget_factor = 0.8,
            double C = 100
        )
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(0 < forget_factor && forget_factor <= 1 &&
                        0 < C && size_ >= 2,
                "\t rls_filter::rls_filter()"
                << "\n\t invalid arguments were given to this function"
                << "\n\t forget_factor: " << forget_factor 
                << "\n\t C:     " << C 
                << "\n\t size_: " << size_
                << "\n\t this: " << this
                );

            size = size_;
            count = 0;
            filter = rls(forget_factor, C);
        }

        double get_c(
        ) const
        {
            return filter.get_c();
        }

        double get_forget_factor(
        ) const
        {
            return filter.get_forget_factor();
        }

        unsigned long get_window_size (
        ) const
        {
            return size;
        }

        void update (
        )
        {
            if (filter.get_w().size() == 0)
                return;

            for (unsigned long i = 0; i < data.size(); ++i)
            {
                // Put old predicted value into the circular buffer as if it was 
                // the measurement we just observed.  But don't update the rls filter.
                data[i].push_front(next(i));
            }

            // predict next state
            for (long i = 0; i < next.size(); ++i)
                next(i) = filter(mat(data[i]));
        }

        template <typename EXP>
        void update (
            const matrix_exp<EXP>& z
        )
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(is_col_vector(z) == true &&
                        z.size() != 0 &&
                        (get_predicted_next_state().size()==0 || z.size()==get_predicted_next_state().size()),
                "\t void rls_filter::update(z)"
                << "\n\t invalid arguments were given to this function"
                << "\n\t is_col_vector(z): " << is_col_vector(z) 
                << "\n\t z.size():         " << z.size()
                << "\n\t get_predicted_next_state().size(): " << get_predicted_next_state().size()
                << "\n\t this: " << this
                );

            // initialize data if necessary 
            if (data.size() == 0)
            {
                data.resize(z.size());
                for (long i = 0; i < z.size(); ++i)
                    data[i].assign(size, 0);
            }


            for (unsigned long i = 0; i < data.size(); ++i)
            {
                // Once there is some stuff in the circular buffer, start
                // showing it to the rls filter so it can do its thing.
                if (count >= 2)
                {
                    filter.train(mat(data[i]), z(i));
                }

                // keep track of the measurements in our circular buffer
                data[i].push_front(z(i));
            }

            // Don't bother with the filter until we have seen two samples
            if (count >= 2)
            {
                // predict next state
                for (long i = 0; i < z.size(); ++i)
                    next(i) = filter(mat(data[i]));
            }
            else
            {
                // Use current measurement as the next state prediction
                // since we don't know any better at this point.
                ++count;
                next = matrix_cast<double>(z);
            }
        }

        const matrix<double,0,1>& get_predicted_next_state(
        ) const
        {
            return next;
        }

        friend inline void serialize(const rls_filter& item, std::ostream& out)
        {
            int version = 1;
            serialize(version, out);
            serialize(item.count, out);
            serialize(item.size, out);
            serialize(item.filter, out);
            serialize(item.next, out);
            serialize(item.data, out);
        }

        friend inline void deserialize(rls_filter& item, std::istream& in)
        {
            int version = 0;
            deserialize(version, in);
            if (version != 1)
                throw dlib::serialization_error("Unknown version number found while deserializing rls_filter object.");

            deserialize(item.count, in);
            deserialize(item.size, in);
            deserialize(item.filter, in);
            deserialize(item.next, in);
            deserialize(item.data, in);
        }

    private:

        unsigned long count;
        unsigned long size;
        rls filter;
        matrix<double,0,1> next;
        std::vector<circular_buffer<double> > data;
    };

// ----------------------------------------------------------------------------------------

}

#endif // DLIB_RLS_FiLTER_Hh_