// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
/*

    This is an example illustrating the use of the dlib machine learning tools for 
    learning to solve the assignment problem. 
    
    Many tasks in computer vision or natural language processing can be thought of 
    as assignment problems.  For example, in a computer vision application where 
    you are trying to track objects moving around in video, you likely need to solve
    an association problem every time you get a new video frame.  That is, each new 
    frame will contain objects (e.g. people, cars, etc.) and you will want to 
    determine which of these objects are actually things you have seen in previous 
    frames.  
   
    The assignment problem can be optimally solved using the well known Hungarian 
    algorithm.  However, this algorithm requires the user to supply some function 
    which measures the "goodness" of an individual association.  In many cases the 
    best way to measure this goodness isn't obvious and therefore machine learning 
    methods are used.  
    
    The remainder of this example program will show you how to learn a goodness 
    function which is optimal, in a certain sense, for use with the Hungarian 
    algorithm.  To do this, we will make a simple dataset of example associations 
    and use them to train a supervised machine learning method. 

    Finally, note that there is a whole example program dedicated to assignment learning
    problems where you are trying to make an object tracker.  So if that is what you are
    interested in then read the learning_to_track_ex.cpp example program.
*/


#include <iostream>
#include <dlib/svm_threaded.h>

using namespace std;
using namespace dlib;


// ----------------------------------------------------------------------------------------

/*
    In an association problem, we will talk about the "Left Hand Set" (LHS) and the
    "Right Hand Set" (RHS).  The task will be to learn to map all elements of LHS to 
    unique elements of RHS.  If an element of LHS can't be mapped to a unique element of 
    RHS for some reason (e.g. LHS is bigger than RHS) then it can also be mapped to the 
    special -1 output, indicating no mapping to RHS.

    So the first step is to define the type of elements in each of these sets.  In the
    code below we will use column vectors in both LHS and RHS.  However, in general,
    they can each contain any type you like.  LHS can even contain a different type 
    than RHS.
*/

typedef dlib::matrix<double,0,1> column_vector;

// This type represents a pair of LHS and RHS.  That is, sample_type::first
// contains a left hand set and sample_type::second contains a right hand set.
typedef std::pair<std::vector<column_vector>, std::vector<column_vector> > sample_type;

// This type will contain the association information between LHS and RHS.  That is,
// it will determine which elements of LHS map to which elements of RHS.
typedef std::vector<long> label_type;

// In this example, all our LHS and RHS elements will be 3-dimensional vectors.
const unsigned long num_dims = 3;

void make_data (
    std::vector<sample_type>& samples,
    std::vector<label_type>& labels
);
/*!
    ensures
        - This function creates a training dataset of 5 example associations.  
        - #samples.size() == 5
        - #labels.size() == 5
        - for all valid i:
            - #samples[i].first == a left hand set
            - #samples[i].second == a right hand set
            - #labels[i] == a set of integers indicating how to map LHS to RHS.  To be
              precise:  
                - #samples[i].first.size() == #labels[i].size()
                - for all valid j:
                    -1 <= #labels[i][j] < #samples[i].second.size()
                    (A value of -1 indicates that #samples[i].first[j] isn't associated with anything.
                    All other values indicate the associating element of #samples[i].second)
                - All elements of #labels[i] which are not equal to -1 are unique.  That is,
                  multiple elements of #samples[i].first can't associate to the same element
                  in #samples[i].second.
!*/

// ----------------------------------------------------------------------------------------

struct feature_extractor
{
    /*!
        Recall that our task is to learn the "goodness of assignment" function for
        use with the Hungarian algorithm.  The dlib tools assume this function
        can be written as:
            match_score(l,r) == dot(w, PSI(l,r))
        where l is an element of LHS, r is an element of RHS, w is a parameter vector,
        and PSI() is a user supplied feature extractor.

        This feature_extractor is where we implement PSI().  How you implement this
        is highly problem dependent.  
    !*/

    // The type of feature vector returned from get_features().  This must be either
    // a dlib::matrix or a sparse vector.
    typedef column_vector feature_vector_type;

    // The types of elements in the LHS and RHS sets
    typedef column_vector lhs_element;
    typedef column_vector rhs_element;


    unsigned long num_features() const
    {
        // Return the dimensionality of feature vectors produced by get_features()
        return num_dims + 1;
    }

    void get_features (
        const lhs_element& left,
        const rhs_element& right,
        feature_vector_type& feats
    ) const
    /*!
        ensures
            - #feats == PSI(left,right)
              (i.e. This function computes a feature vector which, in some sense, 
              captures information useful for deciding if matching left to right 
              is "good").
    !*/
    {
        // We will have: 
        //   - feats(i) == std::pow(left(i) - right(i), 2.0)
        // Except for the last element of feats which will be equal to 1 and
        // therefore function as a bias term.  Again, how you define this feature
        // extractor is highly problem dependent.    
        feats = join_cols(squared(left - right), ones_matrix<double>(1,1));
    }

};

// We need to define serialize() and deserialize() for our feature extractor if we want 
// to be able to serialize and deserialize our learned models.  In this case the 
// implementation is empty since our feature_extractor doesn't have any state.  But you 
// might define more complex feature extractors which have state that needs to be saved.
void serialize   (const feature_extractor& , std::ostream& ) {}
void deserialize (feature_extractor&       , std::istream& ) {}

// ----------------------------------------------------------------------------------------

int main()
{
    try
    {
        // Get a small bit of training data.
        std::vector<sample_type> samples;
        std::vector<label_type> labels;
        make_data(samples, labels);


        structural_assignment_trainer<feature_extractor> trainer;
        // This is the common SVM C parameter.  Larger values encourage the
        // trainer to attempt to fit the data exactly but might overfit. 
        // In general, you determine this parameter by cross-validation.
        trainer.set_c(10);
        // This trainer can use multiple CPU cores to speed up the training.  
        // So set this to the number of available CPU cores. 
        trainer.set_num_threads(4);

        // Do the training and save the results in assigner.
        assignment_function<feature_extractor> assigner = trainer.train(samples, labels);


        // Test the assigner on our data.  The output will indicate that it makes the
        // correct associations on all samples.
        cout << "Test the learned assignment function: " << endl;
        for (unsigned long i = 0; i < samples.size(); ++i)
        {
            // Predict the assignments for the LHS and RHS in samples[i].   
            std::vector<long> predicted_assignments = assigner(samples[i]);
            cout << "true labels:      " << trans(mat(labels[i]));
            cout << "predicted labels: " << trans(mat(predicted_assignments)) << endl;
        }

        // We can also use this tool to compute the percentage of assignments predicted correctly.
        cout << "training accuracy: " << test_assignment_function(assigner, samples, labels) << endl;


        // Since testing on your training data is a really bad idea, we can also do 5-fold cross validation.
        // Happily, this also indicates that all associations were made correctly.
        randomize_samples(samples, labels);
        cout << "cv accuracy:       " << cross_validate_assignment_trainer(trainer, samples, labels, 5) << endl;



        // Finally, the assigner can be serialized to disk just like most dlib objects.
        ofstream fout("assigner.dat", ios::binary);
        serialize(assigner, fout);
        fout.close();

        // recall from disk
        ifstream fin("assigner.dat", ios::binary);
        deserialize(assigner, fin);
    }
    catch (std::exception& e)
    {
        cout << "EXCEPTION THROWN" << endl;
        cout << e.what() << endl;
    }
}

// ----------------------------------------------------------------------------------------

void make_data (
    std::vector<sample_type>& samples,
    std::vector<label_type>& labels
)
{
    // Make four different vectors.  We will use them to make example assignments.
    column_vector A(num_dims), B(num_dims), C(num_dims), D(num_dims);
    A = 1,0,0;
    B = 0,1,0;
    C = 0,0,1;
    D = 0,1,1;

    std::vector<column_vector> lhs;
    std::vector<column_vector> rhs;
    label_type mapping;

    // In all the assignments to follow, we will only say an element of the LHS 
    // matches an element of the RHS if the two are equal.  So A matches with A, 
    // B with B, etc.  But never A with C, for example.
    // ------------------------

    lhs.resize(3);
    lhs[0] = A;
    lhs[1] = B;
    lhs[2] = C;

    rhs.resize(3);
    rhs[0] = B;
    rhs[1] = A;
    rhs[2] = C;

    mapping.resize(3);
    mapping[0] = 1;  // lhs[0] matches rhs[1]
    mapping[1] = 0;  // lhs[1] matches rhs[0]
    mapping[2] = 2;  // lhs[2] matches rhs[2]

    samples.push_back(make_pair(lhs,rhs));
    labels.push_back(mapping);

    // ------------------------

    lhs[0] = C;
    lhs[1] = A;
    lhs[2] = B;

    rhs[0] = A;
    rhs[1] = B;
    rhs[2] = D;

    mapping[0] = -1;  // The -1 indicates that lhs[0] doesn't match anything in rhs.
    mapping[1] = 0;   // lhs[1] matches rhs[0]
    mapping[2] = 1;   // lhs[2] matches rhs[1]

    samples.push_back(make_pair(lhs,rhs));
    labels.push_back(mapping);

    // ------------------------

    lhs[0] = A;
    lhs[1] = B;
    lhs[2] = C;

    rhs.resize(4);
    rhs[0] = C;
    rhs[1] = B;
    rhs[2] = A;
    rhs[3] = D;

    mapping[0] = 2;
    mapping[1] = 1;
    mapping[2] = 0;

    samples.push_back(make_pair(lhs,rhs));
    labels.push_back(mapping);

    // ------------------------

    lhs.resize(2);
    lhs[0] = B;
    lhs[1] = C;

    rhs.resize(3);
    rhs[0] = C;
    rhs[1] = A;
    rhs[2] = D;

    mapping.resize(2);
    mapping[0] = -1;
    mapping[1] = 0;

    samples.push_back(make_pair(lhs,rhs));
    labels.push_back(mapping);

    // ------------------------

    lhs.resize(3);
    lhs[0] = D;
    lhs[1] = B;
    lhs[2] = C;

    // rhs will be empty.  So none of the items in lhs can match anything.
    rhs.resize(0);

    mapping.resize(3);
    mapping[0] = -1;
    mapping[1] = -1;
    mapping[2] = -1;

    samples.push_back(make_pair(lhs,rhs));
    labels.push_back(mapping);

}