// Copyright (C) 2014 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_CROSS_VALIDATE_TRACK_ASSOCIATION_TrAINER_Hh_ #define DLIB_CROSS_VALIDATE_TRACK_ASSOCIATION_TrAINER_Hh_ #include "cross_validate_track_association_trainer_abstract.h" #include "structural_track_association_trainer.h" namespace dlib { // ---------------------------------------------------------------------------------------- namespace impl { template < typename track_association_function, typename detection_type, typename label_type > void test_track_association_function ( const track_association_function& assoc, const std::vector<std::vector<labeled_detection<detection_type,label_type> > >& samples, unsigned long& total_dets, unsigned long& correctly_associated_dets ) { const typename track_association_function::association_function_type& f = assoc.get_assignment_function(); typedef typename detection_type::track_type track_type; using namespace impl; dlib::rand rnd; std::vector<track_type> tracks; std::map<label_type,long> track_idx; // tracks[track_idx[id]] == track with ID id. for (unsigned long j = 0; j < samples.size(); ++j) { std::vector<labeled_detection<detection_type,label_type> > dets = samples[j]; // Shuffle the order of the detections so we can be sure that there isn't // anything funny going on like the detections always coming in the same // order relative to their labels and the association function just gets // lucky by picking the same assignment ordering every time. So this way // we know the assignment function really is doing something rather than // just being lucky. randomize_samples(dets, rnd); total_dets += dets.size(); std::vector<long> assignments = f(get_unlabeled_dets(dets), tracks); std::vector<bool> updated_track(tracks.size(), false); // now update all the tracks with the detections that associated to them. for (unsigned long k = 0; k < assignments.size(); ++k) { // If the detection is associated to tracks[assignments[k]] if (assignments[k] != -1) { tracks[assignments[k]].update_track(dets[k].det); updated_track[assignments[k]] = true; // if this detection was supposed to go to this track if (track_idx.count(dets[k].label) && track_idx[dets[k].label]==assignments[k]) ++correctly_associated_dets; track_idx[dets[k].label] = assignments[k]; } else { track_type new_track; new_track.update_track(dets[k].det); tracks.push_back(new_track); // if this detection was supposed to go to a new track if (track_idx.count(dets[k].label) == 0) ++correctly_associated_dets; track_idx[dets[k].label] = tracks.size()-1; } } // Now propagate all the tracks that didn't get any detections. for (unsigned long k = 0; k < updated_track.size(); ++k) { if (!updated_track[k]) tracks[k].propagate_track(); } } } } // ---------------------------------------------------------------------------------------- template < typename track_association_function, typename detection_type, typename label_type > double test_track_association_function ( const track_association_function& assoc, const std::vector<std::vector<std::vector<labeled_detection<detection_type,label_type> > > >& samples ) { unsigned long total_dets = 0; unsigned long correctly_associated_dets = 0; for (unsigned long i = 0; i < samples.size(); ++i) { impl::test_track_association_function(assoc, samples[i], total_dets, correctly_associated_dets); } return (double)correctly_associated_dets/(double)total_dets; } // ---------------------------------------------------------------------------------------- template < typename trainer_type, typename detection_type, typename label_type > double cross_validate_track_association_trainer ( const trainer_type& trainer, const std::vector<std::vector<std::vector<labeled_detection<detection_type,label_type> > > >& samples, const long folds ) { const long num_in_test = samples.size()/folds; const long num_in_train = samples.size() - num_in_test; std::vector<std::vector<std::vector<labeled_detection<detection_type,label_type> > > > samples_train; long next_test_idx = 0; unsigned long total_dets = 0; unsigned long correctly_associated_dets = 0; for (long i = 0; i < folds; ++i) { samples_train.clear(); // load up the training samples long next = (next_test_idx + num_in_test)%samples.size(); for (long cnt = 0; cnt < num_in_train; ++cnt) { samples_train.push_back(samples[next]); next = (next + 1)%samples.size(); } const track_association_function<detection_type>& df = trainer.train(samples_train); for (long cnt = 0; cnt < num_in_test; ++cnt) { impl::test_track_association_function(df, samples[next_test_idx], total_dets, correctly_associated_dets); next_test_idx = (next_test_idx + 1)%samples.size(); } } return (double)correctly_associated_dets/(double)total_dets; } // ---------------------------------------------------------------------------------------- } #endif // DLIB_CROSS_VALIDATE_TRACK_ASSOCIATION_TrAINER_Hh_