// Copyright (C) 2012 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #include <dlib/filtering.h> #include <sstream> #include <string> #include <cstdlib> #include <ctime> #include <dlib/matrix.h> #include <dlib/rand.h> #include "tester.h" namespace { using namespace test; using namespace dlib; using namespace std; logger dlog("test.filtering"); // ---------------------------------------------------------------------------------------- template <typename filter_type> double test_filter ( filter_type kf, int size ) { // This test has a point moving in a circle around the origin. The point // also gets a random bump in a random direction at each time step. running_stats<double> rs; dlib::rand rnd; int count = 0; const dlib::vector<double,3> z(0,0,1); dlib::vector<double,2> p(10,10), temp; for (int i = 0; i < size; ++i) { // move the point around in a circle p += z.cross(p).normalize()/0.5; // randomly drop measurements if (rnd.get_random_double() < 0.7 || count < 4) { // make a random bump dlib::vector<double,2> pp; pp.x() = rnd.get_random_gaussian()/3; pp.y() = rnd.get_random_gaussian()/3; ++count; kf.update(p+pp); } else { kf.update(); dlog << LTRACE << "MISSED MEASUREMENT"; } // figure out the next position temp = (p+z.cross(p).normalize()/0.5); const double error = length(temp - rowm(kf.get_predicted_next_state(),range(0,1))); rs.add(error); dlog << LTRACE << temp << "("<< error << "): " << trans(kf.get_predicted_next_state()); // test the serialization a few times. if (count < 10) { ostringstream sout; serialize(kf, sout); istringstream sin(sout.str()); filter_type temp; deserialize(temp, sin); kf = temp; } } return rs.mean(); } // ---------------------------------------------------------------------------------------- void test_kalman_filter() { matrix<double,2,2> R; R = 0.3, 0, 0, 0.3; // the variables in the state are // x,y, x velocity, y velocity, x acceleration, and y acceleration matrix<double,6,6> A; A = 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1; // the measurements only tell us the positions matrix<double,2,6> H; H = 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0; kalman_filter<6,2> kf; kf.set_measurement_noise(R); matrix<double> pn = 0.01*identity_matrix<double,6>(); kf.set_process_noise(pn); kf.set_observation_model(H); kf.set_transition_model(A); DLIB_TEST(equal(kf.get_observation_model() , H)); DLIB_TEST(equal(kf.get_transition_model() , A)); DLIB_TEST(equal(kf.get_measurement_noise() , R)); DLIB_TEST(equal(kf.get_process_noise() , pn)); DLIB_TEST(equal(kf.get_current_estimation_error_covariance() , identity_matrix(pn))); double kf_error = test_filter(kf, 300); dlog << LINFO << "kf error: "<< kf_error; DLIB_TEST_MSG(kf_error < 0.75, kf_error); } // ---------------------------------------------------------------------------------------- void test_rls_filter() { rls_filter rls(10, 0.99, 0.1); DLIB_TEST(rls.get_window_size() == 10); DLIB_TEST(rls.get_forget_factor() == 0.99); DLIB_TEST(rls.get_c() == 0.1); double rls_error = test_filter(rls, 1000); dlog << LINFO << "rls error: "<< rls_error; DLIB_TEST_MSG(rls_error < 0.75, rls_error); } // ---------------------------------------------------------------------------------------- class filtering_tester : public tester { public: filtering_tester ( ) : tester ("test_filtering", "Runs tests on the filtering stuff (rls and kalman filters).") {} void perform_test ( ) { test_rls_filter(); test_kalman_filter(); } } a; }