// Copyright (C) 2016 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_DNn_UTILITIES_H_
#define DLIB_DNn_UTILITIES_H_
#include "../cuda/tensor.h"
#include "utilities_abstract.h"
#include "../geometry.h"
#include <fstream>
namespace dlib
{
// ----------------------------------------------------------------------------------------
// Overwrites every element of params with a random value scaled for layer
// initialization, following formula (16) of "Understanding the difficulty of
// training deep feedforward neural networks" by Xavier Glorot and Yoshua Bengio.
//
// Each element becomes (2*u - 1) * sqrt(6/num_inputs_and_outputs), where u is a
// fresh draw from rnd.get_random_float().  Assuming that draw lies in [0,1]
// (TODO confirm against dlib::rand), the result lies in
// [-sqrt(6/num_inputs_and_outputs), +sqrt(6/num_inputs_and_outputs)].
inline void randomize_parameters (
    tensor& params,
    unsigned long num_inputs_and_outputs,
    dlib::rand& rnd
)
{
    // The scale depends only on the layer size, so it is the same for every element.
    const double scale = std::sqrt(6.0/(num_inputs_and_outputs));

    for (auto& weight : params)
    {
        // Re-center the random draw around zero, then stretch it by the scale.
        weight = 2*rnd.get_random_float()-1;
        weight *= scale;
    }
}
// ----------------------------------------------------------------------------------------
// A label bundled with a floating point weight.
//
// A default constructed weighted_label holds a value-initialized label and a
// weight of 1.
template <typename label_type>
struct weighted_label
{
    label_type label{};
    float weight = 1.f;

    // Leaves both members at the defaults given above.
    weighted_label()
    {}

    // Not explicit, so a bare label_type converts implicitly into a
    // weighted_label carrying the default weight of 1.
    weighted_label(label_type label_, float weight_ = 1.f)
        : label(label_), weight(weight_)
    {}
};
// ----------------------------------------------------------------------------------------
// Returns log(1 + exp(x)), evaluated piecewise so that exp() never overflows
// for large x and no precision is thrown away for very negative x.
//
// The branches are:
//   x <= -37        : exp(x) is so small that log(1 + exp(x)) is exp(x) to
//                     double precision.
//   -37 < x <= 18   : the direct formula log1p(exp(x)) is accurate.
//   18 < x <= 33.3  : log(1 + exp(x)) = x + log(1 + exp(-x)), and exp(-x) is
//                     small enough that log(1 + exp(-x)) is just exp(-x).
//   x > 33.3        : exp(-x) is below the rounding error of x, so the answer
//                     is x itself.
// A NaN input fails every comparison and is therefore returned unchanged.
inline double log1pexp(double x)
{
    using std::exp;
    // Make log1p visible through a using-directive rather than relying on <cmath>
    // exposing ::log1p in the global namespace, which the standard does not
    // guarantee.  A directive (instead of "using std::log1p;") still compiles on
    // toolchains that only provide the global C version.
    using namespace std;

    if (x <= -37)
        return exp(x);
    if (x <= 18)
        return log1p(exp(x));
    if (x <= 33.3)
        return x + exp(-x);
    return x;
}
// ----------------------------------------------------------------------------------------
// Returns the natural logarithm of input, except that any input smaller than
// epsilon is replaced by epsilon first.  This keeps the result finite when
// input is zero, negative, or only barely positive.
template <typename T>
T safe_log(T input, T epsilon = 1e-10)
{
    // Same selection std::max would make: epsilon wins only when input < epsilon.
    const T floored = (input < epsilon) ? epsilon : input;
    return std::log(floored);
}
// ----------------------------------------------------------------------------------------
// Flattens the 4D coordinate (sample, k, r, c) into a single offset, using
// t.k(), t.nr() and t.nc() as the extents of the three innermost dimensions.
// c varies fastest, then r, then k, then sample.  No bounds checking is
// performed on any of the coordinates.
inline size_t tensor_index(
    const tensor& t,
    const long sample,
    const long k,
    const long r,
    const long c
)
{
    // Build the offset one dimension at a time, outermost first.
    const auto channel_offset = sample * t.k() + k;
    const auto row_offset = channel_offset * t.nr() + r;
    return row_offset * t.nc() + c;
}
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_DNn_UTILITIES_H_