// Copyright (C) 2008  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.

#include "../matrix/matrix_abstract.h"
#include "../algs.h"
#include "../svm.h"

namespace dlib

// ----------------------------------------------------------------------------------------

    template <
        typename trainer_type,
        typename in_sample_vector_type,
        typename in_scalar_vector_type
    const matrix<double, 1, 2, typename trainer_type::mem_manager_type> 
    cross_validate_trainer_threaded (
        const trainer_type& trainer,
        const in_sample_vector_type& x,
        const in_scalar_vector_type& y,
        const long folds,
        const long num_threads
            - is_binary_classification_problem(x,y) == true
            - 1 < folds <= std::min(sum(y>0),sum(y<0))
              (e.g. There must be at least as many examples of each class as there are folds)
            - trainer_type == some kind of trainer object (e.g. svm_nu_trainer)
            - num_threads > 0
            - It must be safe for multiple trainer objects to access the elements of x from
              multiple threads at the same time.  Note that all trainers and kernels in
              dlib are thread safe in this regard since they do not mutate the elements of x.
            - performs k-fold cross validation by using the given trainer to solve the
              given binary classification problem for the given number of folds.
              Each fold is tested using the output of the trainer and the average 
              classification accuracy from all folds is returned.  
            - uses num_threads threads of execution in doing the cross validation.  
            - The accuracy is returned in a row vector, let us call it R.  Both 
              quantities in R are numbers between 0 and 1 which represent the fraction 
              of examples correctly classified.  R(0) is the fraction of +1 examples 
              correctly classified and R(1) is the fraction of -1 examples correctly 
            - The number of folds used is given by the folds argument.
            - any exceptions thrown by trainer.train()
            - std::bad_alloc

// ----------------------------------------------------------------------------------------