// Copyright (C) 2011 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_STRUCTURAL_SEQUENCE_LABELING_TRAiNER_Hh_ #define DLIB_STRUCTURAL_SEQUENCE_LABELING_TRAiNER_Hh_ #include "structural_sequence_labeling_trainer_abstract.h" #include "../algs.h" #include "../optimization.h" #include "structural_svm_sequence_labeling_problem.h" #include "num_nonnegative_weights.h" namespace dlib { // ---------------------------------------------------------------------------------------- template < typename feature_extractor > class structural_sequence_labeling_trainer { public: typedef typename feature_extractor::sequence_type sample_sequence_type; typedef std::vector<unsigned long> labeled_sequence_type; typedef sequence_labeler<feature_extractor> trained_function_type; explicit structural_sequence_labeling_trainer ( const feature_extractor& fe_ ) : fe(fe_) { set_defaults(); } structural_sequence_labeling_trainer ( ) { set_defaults(); } const feature_extractor& get_feature_extractor ( ) const { return fe; } unsigned long num_labels ( ) const { return fe.num_labels(); } void set_num_threads ( unsigned long num ) { num_threads = num; } unsigned long get_num_threads ( ) const { return num_threads; } void set_epsilon ( double eps_ ) { // make sure requires clause is not broken DLIB_ASSERT(eps_ > 0, "\t void structural_sequence_labeling_trainer::set_epsilon()" << "\n\t eps_ must be greater than 0" << "\n\t eps_: " << eps_ << "\n\t this: " << this ); eps = eps_; } double get_epsilon ( ) const { return eps; } unsigned long get_max_iterations ( ) const { return max_iterations; } void set_max_iterations ( unsigned long max_iter ) { max_iterations = max_iter; } void set_max_cache_size ( unsigned long max_size ) { max_cache_size = max_size; } unsigned long get_max_cache_size ( ) const { return max_cache_size; } void be_verbose ( ) { verbose = true; } void be_quiet ( ) { verbose = false; } void set_oca ( const oca& item ) { solver = item; } const oca get_oca ( ) const { return solver; } void set_c ( double C_ ) { // make sure requires clause is not broken DLIB_ASSERT(C_ > 0, "\t void structural_sequence_labeling_trainer::set_c()" << "\n\t C_ must be greater than 0" << "\n\t C_: " << C_ << "\n\t this: " << this ); C = C_; } double get_c ( ) const { return C; } double get_loss ( unsigned long label ) const { // make sure requires clause is not broken DLIB_ASSERT(label < num_labels(), "\t void structural_sequence_labeling_trainer::get_loss()" << "\n\t invalid inputs were given to this function" << "\n\t label: " << label << "\n\t num_labels(): " << num_labels() << "\n\t this: " << this ); return loss_values[label]; } void set_loss ( unsigned long label, double value ) { // make sure requires clause is not broken DLIB_ASSERT(label < num_labels() && value >= 0, "\t void structural_sequence_labeling_trainer::set_loss()" << "\n\t invalid inputs were given to this function" << "\n\t label: " << label << "\n\t num_labels(): " << num_labels() << "\n\t value: " << value << "\n\t this: " << this ); loss_values[label] = value; } const sequence_labeler<feature_extractor> train( const std::vector<sample_sequence_type>& x, const std::vector<labeled_sequence_type>& y ) const { // make sure requires clause is not broken DLIB_ASSERT(is_sequence_labeling_problem(x,y) == true && contains_invalid_labeling(get_feature_extractor(), x, y) == false, "\t sequence_labeler structural_sequence_labeling_trainer::train(x,y)" << "\n\t invalid inputs were given to this function" << "\n\t x.size(): " << x.size() << "\n\t is_sequence_labeling_problem(x,y): " << is_sequence_labeling_problem(x,y) << "\n\t contains_invalid_labeling(get_feature_extractor(),x,y): " << contains_invalid_labeling(get_feature_extractor(),x,y) << "\n\t this: " << this ); #ifdef ENABLE_ASSERTS for (unsigned long i = 0; i < y.size(); ++i) { for (unsigned long j = 0; j < y[i].size(); ++j) { // make sure requires clause is not broken DLIB_ASSERT(y[i][j] < num_labels(), "\t sequence_labeler structural_sequence_labeling_trainer::train(x,y)" << "\n\t The given labels in y are invalid." << "\n\t y[i][j]: " << y[i][j] << "\n\t num_labels(): " << num_labels() << "\n\t i: " << i << "\n\t j: " << j << "\n\t this: " << this ); } } #endif structural_svm_sequence_labeling_problem<feature_extractor> prob(x, y, fe, num_threads); matrix<double,0,1> weights; if (verbose) prob.be_verbose(); prob.set_epsilon(eps); prob.set_max_iterations(max_iterations); prob.set_c(C); prob.set_max_cache_size(max_cache_size); for (unsigned long i = 0; i < loss_values.size(); ++i) prob.set_loss(i,loss_values[i]); solver(prob, weights, num_nonnegative_weights(fe)); return sequence_labeler<feature_extractor>(weights,fe); } private: double C; oca solver; double eps; unsigned long max_iterations; bool verbose; unsigned long num_threads; unsigned long max_cache_size; std::vector<double> loss_values; void set_defaults () { C = 100; verbose = false; eps = 0.1; max_iterations = 10000; num_threads = 2; max_cache_size = 5; loss_values.assign(num_labels(), 1); } feature_extractor fe; }; // ---------------------------------------------------------------------------------------- } #endif // DLIB_STRUCTURAL_SEQUENCE_LABELING_TRAiNER_Hh_