// Copyright (C) 2016  Davis E. King (davis@dlib.net)

#include "../algs.h"
#include "../serialize.h"
#include <cmath>
#include "../matrix.h"
#include <algorithm>

namespace dlib
{
{
public:

)
{
clear();
}

void clear(
)
{
n = 0;
R = identity_matrix<double>(2)*1e6;
w = 0;
residual_squared = 0;
}

double current_n (
) const
{
return n;
}

double y
)
{
matrix<double,2,1> x;
x = n, 1;

// Do recursive least squares computations
const double temp = 1 + trans(x)*R*x;
matrix<double,2,1> tmp = R*x;
R = R - (tmp*trans(tmp))/temp;
// R should always be symmetric.  This line improves numeric stability of this algorithm.
R = 0.5*(R + trans(R));
w = w + R*x*(y - trans(x)*w);

// Also, recursively keep track of the residual error between the given value
// and what our linear predictor outputs.
residual_squared = residual_squared + std::pow((y - trans(x)*w),2.0)*temp;

++n;
}

) const
{
// make sure requires clause is not broken
DLIB_ASSERT(current_n() > 1,
<< "\n\t You must add more values into this object before calling this function."
<< "\n\t this: " << this
);

return w(0);
}

double intercept (
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(current_n() > 0,
<< "\n\t You must add more values into this object before calling this function."
<< "\n\t this: " << this
);

return w(1);
}
double standard_error (
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(current_n() > 2,
<< "\n\t You must add more values into this object before calling this function."
<< "\n\t this: " << this
);

const double s = residual_squared/(n-2);
const double adjust = 12.0/(std::pow(current_n(),3.0) - current_n());
}

double thresh
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(current_n() > 2,
<< "\n\t You must add more values into this object before calling this function."
<< "\n\t this: " << this
);

}

double thresh
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(current_n() > 2,
<< "\n\t You must add more values into this object before calling this function."
<< "\n\t this: " << this
);

}

friend void serialize (const running_gradient& item, std::ostream& out)
{
int version = 1;
serialize(version, out);
serialize(item.n, out);
serialize(item.R, out);
serialize(item.w, out);
serialize(item.residual_squared, out);
}

friend void deserialize (running_gradient& item, std::istream& in)
{
int version = 0;
deserialize(version, in);
if (version != 1)
throw serialization_error("Unexpected version found while deserializing dlib::running_gradient.");
deserialize(item.n, in);
deserialize(item.R, in);
deserialize(item.w, in);
deserialize(item.residual_squared, in);
}

private:

static double normal_cdf(double value, double mean, double stddev)
{
if (stddev == 0)
{
if (value < mean)
return 0;
else if (value > mean)
return 1;
else
return 0.5;
}
value = (value-mean)/stddev;
return 0.5 * std::erfc(-value / std::sqrt(2.0));
}

double n;
matrix<double,2,2> R;
matrix<double,2,1> w;
double residual_squared;
};

// ----------------------------------------------------------------------------------------

template <
typename T
>
const T& container,
double thresh
)
{
for(auto&& v : container)

// make sure requires clause is not broken
DLIB_ASSERT(g.current_n() > 2,
<< "\n\t You need more than 2 elements in the given container to call this function."
);
}

template <
typename T
>
const T& container,
double thresh
)
{
for(auto&& v : container)

// make sure requires clause is not broken
DLIB_ASSERT(g.current_n() > 2,
<< "\n\t You need more than 2 elements in the given container to call this function."
);
}

// ----------------------------------------------------------------------------------------

// Returns the element of container_ found at position round((size-1)*(1-quantile))
// of its ascending order, so roughly a fraction `quantile` of the values lie at
// or above the result (e.g. quantile == 0.1 picks a value near the 90th
// percentile).  The caller's container is not modified: its values are copied
// into a temporary std::vector<double> which is partially sorted instead.
// requires: 0 <= quantile <= 1 and container_ is non-empty (checked with DLIB_CASSERT).
template <
    typename T
    >
double find_upper_quantile (
    const T& container_,
    double quantile
)
{
    DLIB_CASSERT(0 <= quantile && quantile <= 1.0);

    // Private copy so that partial sorting never reorders the caller's data.
    std::vector<double> container(container_.begin(), container_.end());

    DLIB_CASSERT(container.size() > 0);

    // Ascending-order index of the value being selected.
    const size_t target = std::round((container.size()-1)*(1-quantile));
    const auto where = container.begin() + target;

    // Only the element at `where` needs to end up in its sorted position.
    std::nth_element(container.begin(), where, container.end());
    return *where;
}

// ----------------------------------------------------------------------------------------

template <
typename T
>
size_t count_steps_without_decrease (
const T& container,
double probability_of_decrease = 0.51
)
{
// make sure requires clause is not broken
DLIB_ASSERT(0.5 < probability_of_decrease && probability_of_decrease < 1,
"\t size_t count_steps_without_decrease()"
<< "\n\t probability_of_decrease: "<< probability_of_decrease
);

size_t count = 0;
size_t j = 0;
for (auto i = container.rbegin(); i != container.rend(); ++i)
{
++j;
if (g.current_n() > 2)
{
// Note that this only looks backwards because we are looping over the
// container backwards.  So here we are really checking if the gradient isn't
// decreasing.
// If we aren't confident things are decreasing.
if (prob_decreasing < probability_of_decrease)
count = j;
}
}
return count;
}

// ----------------------------------------------------------------------------------------

template <
typename T
>
size_t count_steps_without_decrease_robust (
const T& container,
double probability_of_decrease = 0.51,
)
{
// make sure requires clause is not broken
DLIB_ASSERT(0.5 < probability_of_decrease && probability_of_decrease < 1,
"\t size_t count_steps_without_decrease_robust()"
<< "\n\t probability_of_decrease: "<< probability_of_decrease
);

if (container.size() == 0)
return 0;

const auto quantile_thresh = find_upper_quantile(container, quantile_discard);

size_t count = 0;
size_t j = 0;
for (auto i = container.rbegin(); i != container.rend(); ++i)
{
++j;
// ignore values that are too large
if (*i <= quantile_thresh)

if (g.current_n() > 2)
{
// Note that this only looks backwards because we are looping over the
// container backwards.  So here we are really checking if the gradient isn't
// decreasing.
// If we aren't confident things are decreasing.
if (prob_decreasing < probability_of_decrease)
count = j;
}
}
return count;
}

// ----------------------------------------------------------------------------------------

template <
typename T
>
size_t count_steps_without_increase (
const T& container,
double probability_of_increase = 0.51
)
{
// make sure requires clause is not broken
DLIB_ASSERT(0.5 < probability_of_increase && probability_of_increase < 1,
"\t size_t count_steps_without_increase()"
<< "\n\t probability_of_increase: "<< probability_of_increase
);

size_t count = 0;
size_t j = 0;
for (auto i = container.rbegin(); i != container.rend(); ++i)
{
++j;
if (g.current_n() > 2)
{
// Note that this only looks backwards because we are looping over the
// container backwards.  So here we are really checking if the gradient isn't
// increasing.