// Copyright (C) 2010  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#undef DLIB_CROSS_VALIDATE_REGRESSION_TRaINER_ABSTRACT_Hh_
#ifdef DLIB_CROSS_VALIDATE_REGRESSION_TRaINER_ABSTRACT_Hh_

#include <vector>
#include "../matrix.h"

namespace dlib
{

// ----------------------------------------------------------------------------------------

    template <
        typename reg_funct_type,
        typename sample_type,
        typename label_type
        >
    matrix<double,1,4>
    test_regression_function (
        reg_funct_type& reg_funct,
        const std::vector<sample_type>& x_test,
        const std::vector<label_type>& y_test
    );
    /*!
        requires
            - is_learning_problem(x_test, y_test)
            - reg_funct_type == some kind of regression function object 
              (e.g. a decision_function created by the svr_trainer )
        ensures
            - Tests reg_funct against the given samples in x_test and target values in 
              y_test and returns a matrix M summarizing the results.  Specifically:
                - M(0) == the mean squared error.  
                  The MSE is given by: sum over i: pow(reg_funct(x_test[i]) - y_test[i], 2.0)
                - M(1) == the correlation between reg_funct(x_test[i]) and y_test[i].
                  This is a number between -1 and 1.
                - M(2) == the mean absolute error.  
                  This is given by: sum over i: abs(reg_funct(x_test[i]) - y_test[i])
                - M(3) == the standard deviation of the absolute error.
    !*/

// ----------------------------------------------------------------------------------------

    template <
        typename trainer_type,
        typename sample_type,
        typename label_type 
        >
    matrix<double,1,4>
    cross_validate_regression_trainer (
        const trainer_type& trainer,
        const std::vector<sample_type>& x,
        const std::vector<label_type>& y,
        const long folds
    );
    /*!
        requires
            - is_learning_problem(x,y)
            - 1 < folds <= x.size()
            - trainer_type == some kind of regression trainer object (e.g. svr_trainer)
        ensures
            - Performs k-fold cross validation by using the given trainer to solve a 
              regression problem for the given number of folds.  Each fold is tested using 
              the output of the trainer.  A matrix M summarizing the results is returned.  
              Specifically:
                - M(0) == the mean squared error.  
                  The MSE is given by: sum over i: pow(reg_funct(x[i]) - y[i], 2.0)
                - M(1) == the correlation between a predicted y value and its true value.
                  This is a number between -1 and 1.
                - M(2) == the mean absolute error.  
                  This is given by: sum over i: abs(reg_funct(x_test[i]) - y_test[i])
                - M(3) == the standard deviation of the absolute error.
    !*/

}

// ----------------------------------------------------------------------------------------

#endif // DLIB_CROSS_VALIDATE_REGRESSION_TRaINER_ABSTRACT_Hh_