```// Copyright (C) 2008  Davis E. King (davis@dlib.net)
#ifndef DLIB_OPTIMIZATIOn_LINE_SEARCH_H_
#define DLIB_OPTIMIZATIOn_LINE_SEARCH_H_

#include <cmath>
#include <limits>
#include "../matrix.h"
#include "../algs.h"
#include "optimization_line_search_abstract.h"
#include <utility>

namespace dlib
{

// ----------------------------------------------------------------------------------------

template <typename funct, typename T>
class line_search_funct
{
public:
line_search_funct(const funct& f_, const T& start_, const T& direction_)
: f(f_),start(start_), direction(direction_), matrix_r(0), scalar_r(0)
{}

line_search_funct(const funct& f_, const T& start_, const T& direction_, T& r)
: f(f_),start(start_), direction(direction_), matrix_r(&r), scalar_r(0)
{}

line_search_funct(const funct& f_, const T& start_, const T& direction_, double& r)
: f(f_),start(start_), direction(direction_), matrix_r(0), scalar_r(&r)
{}

double operator()(const double& x) const
{
return get_value(f(start + x*direction));
}

private:

double get_value (const double& r) const
{
// save a copy of this value for later
if (scalar_r)
*scalar_r = r;

return r;
}

template <typename U>
double get_value (const U& r) const
{
// U should be a matrix type
COMPILE_TIME_ASSERT(is_matrix<U>::value);

// save a copy of this value for later
if (matrix_r)
*matrix_r = r;

return dot(r,direction);
}

const funct& f;
const T& start;
const T& direction;
T* matrix_r;
double* scalar_r;
};

template <typename funct, typename T>
const line_search_funct<funct,T> make_line_search_function(const funct& f, const T& start, const T& direction)
{
COMPILE_TIME_ASSERT(is_matrix<T>::value);
DLIB_ASSERT (
is_col_vector(start) && is_col_vector(direction) && start.size() == direction.size(),
"\tline_search_funct make_line_search_function(f,start,direction)"
<< "\n\tYou have to supply column vectors to this function"
<< "\n\tstart.nc():     " << start.nc()
<< "\n\tdirection.nc(): " << direction.nc()
<< "\n\tstart.nr():     " << start.nr()
<< "\n\tdirection.nr(): " << direction.nr()
);
return line_search_funct<funct,T>(f,start,direction);
}

// ----------------------------------------------------------------------------------------

template <typename funct, typename T>
const line_search_funct<funct,T> make_line_search_function(const funct& f, const T& start, const T& direction, double& f_out)
{
COMPILE_TIME_ASSERT(is_matrix<T>::value);
DLIB_ASSERT (
is_col_vector(start) && is_col_vector(direction) && start.size() == direction.size(),
"\tline_search_funct make_line_search_function(f,start,direction)"
<< "\n\tYou have to supply column vectors to this function"
<< "\n\tstart.nc():     " << start.nc()
<< "\n\tdirection.nc(): " << direction.nc()
<< "\n\tstart.nr():     " << start.nr()
<< "\n\tdirection.nr(): " << direction.nr()
);
return line_search_funct<funct,T>(f,start,direction, f_out);
}

// ----------------------------------------------------------------------------------------

template <typename funct, typename T>
const line_search_funct<funct,T> make_line_search_function(const funct& f, const T& start, const T& direction, T& grad_out)
{
COMPILE_TIME_ASSERT(is_matrix<T>::value);
DLIB_ASSERT (
is_col_vector(start) && is_col_vector(direction) && start.size() == direction.size(),
"\tline_search_funct make_line_search_function(f,start,direction)"
<< "\n\tYou have to supply column vectors to this function"
<< "\n\tstart.nc():     " << start.nc()
<< "\n\tdirection.nc(): " << direction.nc()
<< "\n\tstart.nr():     " << start.nr()
<< "\n\tdirection.nr(): " << direction.nr()
);
}

// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------

inline double poly_min_extrap (
double f0,
double d0,
double f1,
double d1,
double limit = 1
)
{
const double n = 3*(f1 - f0) - 2*d0 - d1;
const double e = d0 + d1 - 2*(f1 - f0);

// find the minimum of the derivative of the polynomial

double temp = std::max(n*n - 3*e*d0,0.0);

if (temp < 0)
return 0.5;

temp = std::sqrt(temp);

if (std::abs(e) <= std::numeric_limits<double>::epsilon())
return 0.5;

// figure out the two possible min values
double x1 = (temp - n)/(3*e);
double x2 = -(temp + n)/(3*e);

// compute the value of the interpolating polynomial at these two points
double y1 = f0 + d0*x1 + n*x1*x1 + e*x1*x1*x1;
double y2 = f0 + d0*x2 + n*x2*x2 + e*x2*x2*x2;

// pick the best point
double x;
if (y1 < y2)
x = x1;
else
x = x2;

// now make sure the minimum is within the allowed range of [0,limit]
return put_in_range(0,limit,x);
}

// ----------------------------------------------------------------------------------------

inline double poly_min_extrap (
double f0,
double d0,
double f1
)
{
const double temp = 2*(f1 - f0 - d0);
if (std::abs(temp) <= d0*std::numeric_limits<double>::epsilon())
return 0.5;

const double alpha = -d0/temp;

// now make sure the minimum is within the allowed range of (0,1)
return put_in_range(0,1,alpha);
}

// ----------------------------------------------------------------------------------------

inline double poly_min_extrap (
double f0,
double d0,
double x1,
double f_x1,
double x2,
double f_x2
)
{
DLIB_ASSERT(0 < x1 && x1 < x2,"Invalid inputs were given to this function.\n"
<< "x1: " << x1
<< "    x2: " << x2
);
// The contents of this function follow the equations described on page 58 of the
// book Numerical Optimization by Nocedal and Wright, second edition.
matrix<double,2,2> m;
matrix<double,2,1> v;

const double aa2 = x2*x2;
const double aa1 = x1*x1;
m =  aa2,       -aa1,
-aa2*x2, aa1*x1;
v = f_x1 - f0 - d0*x1,
f_x2 - f0 - d0*x2;

double temp = aa2*aa1*(x1-x2);

// just take a guess if this happens
if (temp == 0 || std::fpclassify(temp) == FP_SUBNORMAL)
{
return x1/2.0;
}

matrix<double,2,1> temp2;
temp2 = m*v/temp;
const double a = temp2(0);
const double b = temp2(1);

temp = b*b - 3*a*d0;
if (temp < 0 || a == 0)
{
// This is probably a line so just pick the lowest point
if (f0 < f_x2)
return 0;
else
return x2;
}
temp = (-b + std::sqrt(temp))/(3*a);
return put_in_range(0, x2, temp);
}

// ----------------------------------------------------------------------------------------

inline double lagrange_poly_min_extrap (
double p1,
double p2,
double p3,
double f1,
double f2,
double f3
)
{
DLIB_ASSERT(p1 < p2 && p2 < p3 && f1 >= f2 && f2 <= f3,
"   p1: " << p1
<< "   p2: " << p2
<< "   p3: " << p3
<< "   f1: " << f1
<< "   f2: " << f2
<< "   f3: " << f3);

// This formula is out of the book Nonlinear Optimization by Andrzej Ruszczynski.  See section 5.2.
double temp1 =    f1*(p3*p3 - p2*p2) + f2*(p1*p1 - p3*p3) + f3*(p2*p2 - p1*p1);
double temp2 = 2*(f1*(p3 - p2)       + f2*(p1 - p3)       + f3*(p2 - p1) );

if (temp2 == 0)
{
return p2;
}

const double result = temp1/temp2;

// do a final sanity check to make sure the result is in the right range
if (p1 <= result && result <= p3)
{
return result;
}
else
{
return std::min(std::max(p1,result),p3);
}
}

// ----------------------------------------------------------------------------------------

template <
typename funct,
typename funct_der
>
double line_search (
const funct& f,
const double f0,
const funct_der& der,
const double d0,
double rho,
double sigma,
double min_f,
unsigned long max_iter
)
{
DLIB_ASSERT (
0 < rho && rho < sigma && sigma < 1 && max_iter > 0,
"\tdouble line_search()"
<< "\n\tYou have given invalid arguments to this function"
<< "\n\t sigma:    " << sigma
<< "\n\t rho:      " << rho
<< "\n\t max_iter: " << max_iter
);

// The bracketing phase of this function is implemented according to block 2.6.2 from
// the book Practical Methods of Optimization by R. Fletcher.   The sectioning
// phase is an implementation of 2.6.4 from the same book.

// 1 <= tau1a < tau1b. Controls the alpha jump size during the bracketing phase of
// the search.
const double tau1a = 1.4;
const double tau1b = 9;

// it must be the case that 0 < tau2 < tau3 <= 1/2 for the algorithm to function
// correctly but the specific values of tau2 and tau3 aren't super important.
const double tau2 = 1.0/10.0;
const double tau3 = 1.0/2.0;

// Stop right away and return a step size of 0 if the gradient is 0 at the starting point
if (std::abs(d0) <= std::abs(f0)*std::numeric_limits<double>::epsilon())
return 0;

// Stop right away if the current value is good enough according to min_f
if (f0 <= min_f)
return 0;

// Figure out a reasonable upper bound on how large alpha can get.
const double mu = (min_f-f0)/(rho*d0);

double alpha = 1;
if (mu < 0)
alpha = -alpha;
alpha = put_in_range(0, 0.65*mu, alpha);

double last_alpha = 0;
double last_val = f0;
double last_val_der = d0;

// The bracketing stage will find a range of points [a,b]
// that contains a reasonable solution to the line search
double a, b;

// These variables will hold the values and derivatives of f(a) and f(b)
double a_val, b_val, a_val_der, b_val_der;

// This thresh value represents the Wolfe curvature condition
const double thresh = std::abs(sigma*d0);

unsigned long itr = 0;
// do the bracketing stage to find the bracket range [a,b]
while (true)
{
++itr;
const double val = f(alpha);
const double val_der = der(alpha);

// we are done with the line search since we found a value smaller
// than the minimum f value
if (val <= min_f)
return alpha;

if (val > f0 + rho*alpha*d0 || val >= last_val)
{
a_val = last_val;
a_val_der = last_val_der;
b_val = val;
b_val_der = val_der;

a = last_alpha;
b = alpha;
break;
}

if (std::abs(val_der) <= thresh)
return alpha;

// if we are stuck not making progress then quit with the current alpha
if (last_alpha == alpha || itr >= max_iter)
return alpha;

if (val_der >= 0)
{
a_val = val;
a_val_der = val_der;
b_val = last_val;
b_val_der = last_val_der;

a = alpha;
b = last_alpha;
break;
}

const double temp = alpha;
// Pick a larger range [first, last].  We will pick the next alpha in that
// range.
double first, last;
if (mu > 0)
{
first = std::min(mu, alpha + tau1a*(alpha - last_alpha));
last  = std::min(mu, alpha + tau1b*(alpha - last_alpha));
}
else
{
first = std::max(mu, alpha + tau1a*(alpha - last_alpha));
last  = std::max(mu, alpha + tau1b*(alpha - last_alpha));
}

// pick a point between first and last by doing some kind of interpolation
if (last_alpha < alpha)
alpha = last_alpha + (alpha-last_alpha)*poly_min_extrap(last_val, last_val_der, val, val_der, 1e10);
else
alpha = alpha + (last_alpha-alpha)*poly_min_extrap(val, val_der, last_val, last_val_der, 1e10);

alpha = put_in_range(first,last,alpha);

last_alpha = temp;

last_val = val;
last_val_der = val_der;

}

// Now do the sectioning phase from 2.6.4
while (true)
{
++itr;
double first = a + tau2*(b-a);
double last = b - tau3*(b-a);

// use interpolation to pick alpha between first and last
alpha = a + (b-a)*poly_min_extrap(a_val, a_val_der, b_val, b_val_der);
alpha = put_in_range(first,last,alpha);

const double val = f(alpha);
const double val_der = der(alpha);

// we are done with the line search since we found a value smaller
// than the minimum f value or we ran out of iterations.
if (val <= min_f || itr >= max_iter)
return alpha;

// stop if the interval gets so small that it isn't shrinking any more due to rounding error
if (a == first || b == last)
{
return b;
}

// If alpha has basically become zero then just stop.  Think of it like this,
// if we take the largest possible alpha step will the objective function
// change at all?  If not then there isn't any point looking for a better
// alpha.
const double max_possible_alpha = std::max(std::abs(a),std::abs(b));
if (std::abs(max_possible_alpha*d0) <= std::abs(f0)*std::numeric_limits<double>::epsilon())
return alpha;

if (val > f0 + rho*alpha*d0 || val >= a_val)
{
b = alpha;
b_val = val;
b_val_der = val_der;
}
else
{
if (std::abs(val_der) <= thresh)
return alpha;

if ( (b-a)*val_der >= 0)
{
b = a;
b_val = a_val;
b_val_der = a_val_der;
}

a = alpha;
a_val = val;
a_val_der = val_der;
}
}
}

// ----------------------------------------------------------------------------------------

template <
typename funct
>
double backtracking_line_search (
const funct& f,
double f0,
double d0,
double alpha,
double rho,
unsigned long max_iter
)
{
DLIB_ASSERT (
0 < rho && rho < 1 && max_iter > 0,
"\tdouble backtracking_line_search()"
<< "\n\tYou have given invalid arguments to this function"
<< "\n\t rho:      " << rho
<< "\n\t max_iter: " << max_iter
);

// make sure alpha is going in the right direction.  That is, it should be opposite
// the direction of the gradient.
if ((d0 > 0 && alpha > 0) ||
(d0 < 0 && alpha < 0))
{
alpha *= -1;
}

bool have_prev_alpha = false;
double prev_alpha = 0;
double prev_val = 0;
unsigned long iter = 0;
while (true)
{
++iter;
const double val = f(alpha);
if (val <= f0 + alpha*rho*d0 || iter >= max_iter)
{
return alpha;
}
else
{
// Interpolate a new alpha.  We also make sure the step by which we
// reduce alpha is not super small.
double step;
if (!have_prev_alpha)
{
if (d0 < 0)
step = alpha*put_in_range(0.1,0.9, poly_min_extrap(f0, d0, val));
else
step = alpha*put_in_range(0.1,0.9, poly_min_extrap(f0, -d0, val));
have_prev_alpha = true;
}
else
{
if (d0 < 0)
step = put_in_range(0.1*alpha,0.9*alpha, poly_min_extrap(f0, d0, alpha, val, prev_alpha, prev_val));
else
step = put_in_range(0.1*alpha,0.9*alpha, -poly_min_extrap(f0, -d0, -alpha, val, -prev_alpha, prev_val));
}

prev_alpha = alpha;
prev_val = val;

alpha = step;
}
}
}

// ----------------------------------------------------------------------------------------

class optimize_single_variable_failure : public error {
public: optimize_single_variable_failure(const std::string& s):error(s){}
};

// ----------------------------------------------------------------------------------------

template <typename funct>
double find_min_single_variable (
const funct& f,
double& starting_point,
const double begin = -1e200,
const double end = 1e200,
const double eps = 1e-3,
const long max_iter = 100,
)
{
DLIB_CASSERT( eps > 0 &&
max_iter > 1 &&
begin <= starting_point && starting_point <= end &&
"eps: " << eps
<< "\n max_iter: "<< max_iter
<< "\n begin: "<< begin
<< "\n end:   "<< end
<< "\n starting_point: "<< starting_point
);

double p1=0, p2=0, p3=0, f1=0, f2=0, f3=0;
long f_evals = 1;

if (begin == end)
{
return f(starting_point);
}

using std::abs;
using std::min;
using std::max;

// find three bracketing points such that f1 > f2 < f3.   Do this by generating a sequence
// of points expanding away from 0.   Also note that, in the following code, it is always the
// case that p1 < p2 < p3.

// The first thing we do is get a starting set of 3 points that are inside the [begin,end] bounds
f1 = f(p1);
f3 = f(p3);

if (starting_point == p1 || starting_point == p3)
{
p2 = (p1+p3)/2;
f2 = f(p2);
}
else
{
p2 = starting_point;
f2 = f(starting_point);
}

f_evals += 2;

// Now we have 3 points on the function.  Start looking for a bracketing set such that
// f1 > f2 < f3 is the case.
while ( !(f1 > f2 && f2 < f3))
{
// check for hitting max_iter or if the interval is now too small
if (f_evals >= max_iter)
{
throw optimize_single_variable_failure(
"The max number of iterations of single variable optimization have been reached\n"
"without converging.");
}
if (p3-p1 < eps)
{
if (f1 < min(f2,f3))
{
starting_point = p1;
return f1;
}

if (f2 < min(f1,f3))
{
starting_point = p2;
return f2;
}

starting_point = p3;
return f3;
}

// If the left most points are identical in function value then expand out the
// left a bit, unless it's already at bound or we would drop that left most
// point anyway because it's bad.
if (f1==f2 && f1<f3 && p1!=begin)
{
p1 = max(p1 - search_radius, begin);
f1 = f(p1);
++f_evals;
continue;
}
if (f2==f3 && f3<f1 && p3!=end)
{
p3 = min(p3 + search_radius, end);
f3 = f(p3);
++f_evals;
continue;
}

// if f1 is small then take a step to the left
if (f1 <= f3)
{
// check if the minimum is butting up against the bounds and if so then pick
// a point between p1 and p2 in the hopes that shrinking the interval will
// be a good thing to do.  Or if p1 and p2 aren't differentiated then try and
// get them to obtain different values.
if (p1 == begin || (f1 == f2 && (end-begin) < search_radius ))
{
p3 = p2;
f3 = f2;

p2 = (p1+p2)/2.0;
f2 = f(p2);
}
else
{
// pick a new point to the left of our current bracket
p3 = p2;
f3 = f2;

p2 = p1;
f2 = f1;

p1 = max(p1 - search_radius, begin);
f1 = f(p1);

}

}
// otherwise f3 is small and we should take a step to the right
else
{
// check if the minimum is butting up against the bounds and if so then pick
// a point between p2 and p3 in the hopes that shrinking the interval will
// be a good thing to do.  Or if p2 and p3 aren't differentiated then try and
// get them to obtain different values.
if (p3 == end || (f2 == f3 && (end-begin) < search_radius))
{
p1 = p2;
f1 = f2;

p2 = (p3+p2)/2.0;
f2 = f(p2);
}
else
{
// pick a new point to the right of our current bracket
p1 = p2;
f1 = f2;

p2 = p3;
f2 = f3;

p3 = min(p3 + search_radius, end);
f3 = f(p3);

}
}

++f_evals;
}

// Loop until we have done the max allowable number of iterations or
// the bracketing window is smaller than eps.
// Within this loop we maintain the invariant that: f1 > f2 < f3 and p1 < p2 < p3
const double tau = 0.1;
while( f_evals < max_iter && p3-p1 > eps)
{
double p_min = lagrange_poly_min_extrap(p1,p2,p3, f1,f2,f3);

// make sure p_min isn't too close to the three points we already have
if (p_min < p2)
{
const double min_dist = (p2-p1)*tau;
if (abs(p1-p_min) < min_dist)
{
p_min = p1 + min_dist;
}
else if (abs(p2-p_min) < min_dist)
{
p_min = p2 - min_dist;
}
}
else
{
const double min_dist = (p3-p2)*tau;
if (abs(p2-p_min) < min_dist)
{
p_min = p2 + min_dist;
}
else if (abs(p3-p_min) < min_dist)
{
p_min = p3 - min_dist;
}
}

// make sure one side of the bracket isn't super huge compared to the other
// side.  If it is then contract it.
const double bracket_ratio = abs(p1-p2)/abs(p2-p3);
// Force p_min to be on a reasonable side.  But only if lagrange_poly_min_extrap()
// didn't put it on a good side already.
if (bracket_ratio >= 10)
{
if (p_min > p2)
p_min = (p1+p2)/2;
}
else if (bracket_ratio <= 0.1)
{
if (p_min < p2)
p_min = (p2+p3)/2;
}

const double f_min = f(p_min);

// Remove one of the endpoints of our bracket depending on where the new point falls.
if (p_min < p2)
{
if (f1 > f_min && f_min < f2)
{
p3 = p2;
f3 = f2;
p2 = p_min;
f2 = f_min;
}
else
{
p1 = p_min;
f1 = f_min;
}
}
else
{
if (f2 > f_min && f_min < f3)
{
p1 = p2;
f1 = f2;
p2 = p_min;
f2 = f_min;
}
else
{
p3 = p_min;
f3 = f_min;
}
}

++f_evals;
}

if (f_evals >= max_iter)
{
throw optimize_single_variable_failure(
"The max number of iterations of single variable optimization have been reached\n"
"without converging.");
}

starting_point = p2;
return f2;
}

// ----------------------------------------------------------------------------------------

template <typename funct>
class negate_function_object
{
public:
negate_function_object(const funct& f_) : f(f_){}

template <typename T>
double operator()(const T& x) const
{
return -f(x);
}

private:
const funct& f;
};

template <typename funct>
const negate_function_object<funct> negate_function(const funct& f) { return negate_function_object<funct>(f); }

// ----------------------------------------------------------------------------------------

template <typename funct>
double find_max_single_variable (
const funct& f,
double& starting_point,
const double begin = -1e200,
const double end = 1e200,
const double eps = 1e-3,
const long max_iter = 100,