```// Copyright (C) 2011  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#include <sstream>
#include <string>
#include <cstdlib>
#include <ctime>
#include <dlib/optimization.h>
#include <dlib/unordered_pair.h>
#include <dlib/rand.h>

#include "tester.h"

namespace
{
using namespace test;
using namespace dlib;
using namespace std;

// Logger used by every test in this file.
logger dlog("test.find_max_factor_graph_nmplp");

// ----------------------------------------------------------------------------------------

// Shared random number generator.  Every map problem constructor below draws its
// factor weights from it, and perform_test() calls rnd.clear() before any problem
// is built, so the generated problems depend on the order in which they are built.
dlib::rand rnd;

template <bool fully_connected>
class map_problem
{
    /*
        A small 8 node test problem in which every node has 2 states.

        When fully_connected is false the graph is two disjoint 4-node cycles:
        nodes 0-3 form one cycle and nodes 4-7 form the other.  When
        fully_connected is true it's a fully connected 8 node graph.
    */

public:

    // Pairwise factor values: weights[{n1,n2}][{s1,s2}].  The constructor fills an
    // entry for every node pair i <= j (self pairs included, although no neighbor
    // iterator ever reports a node as its own neighbor).  It is mutable because the
    // const basic_factor_value() reads it through std::map::operator[].
    mutable std::map<unordered_pair<int>,std::map<std::pair<int,int>,double> > weights;

    // Fills weights with Gaussian samples drawn from the shared rnd generator.
    map_problem()
    {
        for (int i = 0; i < 8; ++i)
        {
            for (int j = i; j < 8; ++j)
            {
                weights[make_unordered_pair(i,j)][make_pair(0,0)] = rnd.get_random_gaussian();
                weights[make_unordered_pair(i,j)][make_pair(0,1)] = rnd.get_random_gaussian();
                weights[make_unordered_pair(i,j)][make_pair(1,0)] = rnd.get_random_gaussian();
                weights[make_unordered_pair(i,j)][make_pair(1,1)] = rnd.get_random_gaussian();
            }
        }
    }

    // Walks node ids 0..7.  The end() iterator holds nid == 8.
    struct node_iterator
    {
        node_iterator() {}
        node_iterator(unsigned long nid_): nid(nid_) {}
        bool operator== (const node_iterator& item) const { return item.nid == nid; }
        bool operator!= (const node_iterator& item) const { return item.nid != nid; }

        node_iterator& operator++()
        {
            ++nid;
            return *this;
        }

        unsigned long nid;
    };

    // Walks the neighbors of home_node.  Two iterators compare equal when they
    // refer to the same node id; node_id() == 8 means "one past the end".
    struct neighbor_iterator
    {
        neighbor_iterator() : count(0) {}

        bool operator== (const neighbor_iterator& item) const { return item.node_id() == node_id(); }
        bool operator!= (const neighbor_iterator& item) const { return item.node_id() != node_id(); }
        neighbor_iterator& operator++()
        {
            ++count;
            return *this;
        }

        // Returns the id of the count-th neighbor of home_node, or 8 once the
        // neighbors are exhausted.
        unsigned long node_id () const
        {
            if (fully_connected)
            {
                // Every other node is a neighbor: enumerate 0..7, skipping home_node.
                if (count < home_node)
                    return count;
                else
                    return count+1;
            }

            // Two disjoint 4-cycles: the neighbors are the next and previous node
            // (mod 4) inside the home node's own group of four.
            if (home_node < 4)
            {
                if (count == 0)
                    return (home_node + 4 + 1)%4;
                else if (count == 1)
                    return (home_node + 4 - 1)%4;
                else
                    return 8; // one past the end
            }
            else
            {
                if (count == 0)
                    return (home_node + 4 + 1)%4 + 4;
                else if (count == 1)
                    return (home_node + 4 - 1)%4 + 4;
                else
                    return 8; // one past the end
            }
        }

        unsigned long home_node;
        unsigned long count;
    };

    unsigned long number_of_nodes (
    ) const
    {
        return 8;
    }

    node_iterator begin(
    ) const
    {
        node_iterator temp;
        temp.nid = 0;
        return temp;
    }

    node_iterator end(
    ) const
    {
        node_iterator temp;
        temp.nid = 8;
        return temp;
    }

    // Neighbors of the node it points at.
    neighbor_iterator begin(
        const node_iterator& it
    ) const
    {
        neighbor_iterator temp;
        temp.home_node = it.nid;
        return temp;
    }

    // Neighbors of the node a neighbor iterator currently points at.
    neighbor_iterator begin(
        const neighbor_iterator& it
    ) const
    {
        neighbor_iterator temp;
        temp.home_node = it.node_id();
        return temp;
    }

    // End sentinel: home_node = 9 and count = 8 make node_id() evaluate to 8 in
    // both the fully connected and the two-cycle case.
    neighbor_iterator end(
        const node_iterator&
    ) const
    {
        neighbor_iterator temp;
        temp.home_node = 9;
        temp.count = 8;
        return temp;
    }

    // Same end sentinel as above.
    neighbor_iterator end(
        const neighbor_iterator&
    ) const
    {
        neighbor_iterator temp;
        temp.home_node = 9;
        temp.count = 8;
        return temp;
    }

    unsigned long node_id (
        const node_iterator& it
    ) const
    {
        return it.nid;
    }

    unsigned long node_id (
        const neighbor_iterator& it
    ) const
    {
        return it.node_id();
    }

    // Every node is binary.
    unsigned long num_states (
        const node_iterator&
    ) const
    {
        return 2;
    }

    unsigned long num_states (
        const neighbor_iterator&
    ) const
    {
        return 2;
    }

    // Factor value for node it1 in state s1 and node it2 in state s2.  All four
    // overloads forward to basic_factor_value() with plain node ids.
    double factor_value (const node_iterator& it1, const node_iterator& it2, unsigned long s1, unsigned long s2) const
    { return basic_factor_value(it1.nid, it2.nid, s1, s2); }
    double factor_value (const neighbor_iterator& it1, const node_iterator& it2, unsigned long s1, unsigned long s2) const
    { return basic_factor_value(it1.node_id(), it2.nid, s1, s2); }
    double factor_value (const node_iterator& it1, const neighbor_iterator& it2, unsigned long s1, unsigned long s2) const
    { return basic_factor_value(it1.nid, it2.node_id(), s1, s2); }
    double factor_value (const neighbor_iterator& it1, const neighbor_iterator& it2, unsigned long s1, unsigned long s2) const
    { return basic_factor_value(it1.node_id(), it2.node_id(), s1, s2); }

private:

    // Orders the pair so that n1 <= n2 (swapping the states along with the nodes,
    // so the factor is symmetric), then reads the weight table.  operator[] would
    // insert 0 for a pair the constructor did not fill; for ids 0-7 every pair is filled.
    double basic_factor_value (
        unsigned long n1,
        unsigned long n2,
        unsigned long s1,
        unsigned long s2
    ) const
    {
        if (n1 > n2)
        {
            swap(n1,n2);
            swap(s1,s2);
        }
        return weights[make_unordered_pair(n1,n2)][make_pair(s1,s2)];
    }

};

// ----------------------------------------------------------------------------------------

class map_problem_chain
{
    /*
        This is a chain structured 8 node graph (so no cycles).  Node i is
        connected to node i-1 and node i+1 whenever those exist, and every
        node has 2 states.
    */

public:

    // Pairwise factor values: weights[{n1,n2}][{s1,s2}].  Only the 7 chain edges
    // (i,i+1) are filled.  It is mutable because the const basic_factor_value()
    // reads it through std::map::operator[].
    mutable std::map<unordered_pair<int>,std::map<std::pair<int,int>,double> > weights;

    // Fills the 7 chain edges with Gaussian samples drawn from the shared rnd generator.
    map_problem_chain()
    {
        for (int i = 0; i < 7; ++i)
        {
            weights[make_unordered_pair(i,i+1)][make_pair(0,0)] = rnd.get_random_gaussian();
            weights[make_unordered_pair(i,i+1)][make_pair(0,1)] = rnd.get_random_gaussian();
            weights[make_unordered_pair(i,i+1)][make_pair(1,0)] = rnd.get_random_gaussian();
            weights[make_unordered_pair(i,i+1)][make_pair(1,1)] = rnd.get_random_gaussian();
        }
    }

    // Walks node ids 0..7.  The end() iterator holds nid == 8.
    struct node_iterator
    {
        node_iterator() {}
        node_iterator(unsigned long nid_): nid(nid_) {}
        bool operator== (const node_iterator& item) const { return item.nid == nid; }
        bool operator!= (const node_iterator& item) const { return item.nid != nid; }

        node_iterator& operator++()
        {
            ++nid;
            return *this;
        }

        unsigned long nid;
    };

    // Walks at most two precomputed neighbor ids.  A slot holding 8, or a count
    // past the second slot, means "one past the end".  Iterators compare equal
    // when they refer to the same node id.
    struct neighbor_iterator
    {
        neighbor_iterator() : count(0) {}

        bool operator== (const neighbor_iterator& item) const { return item.node_id() == node_id(); }
        bool operator!= (const neighbor_iterator& item) const { return item.node_id() != node_id(); }
        neighbor_iterator& operator++()
        {
            ++count;
            return *this;
        }

        unsigned long node_id () const
        {
            if (count >= 2)
                return 8;
            return nid[count];
        }

        unsigned long nid[2];
        unsigned int count;
    };

    unsigned long number_of_nodes (
    ) const
    {
        return 8;
    }

    node_iterator begin(
    ) const
    {
        node_iterator temp;
        temp.nid = 0;
        return temp;
    }

    node_iterator end(
    ) const
    {
        node_iterator temp;
        temp.nid = 8;
        return temp;
    }

    // Neighbors of the node it points at.
    neighbor_iterator begin(
        const node_iterator& it
    ) const
    {
        return neighbors_of(it.nid);
    }

    // Neighbors of the node a neighbor iterator currently points at.
    neighbor_iterator begin(
        const neighbor_iterator& it
    ) const
    {
        return neighbors_of(it.node_id());
    }

    neighbor_iterator end(
        const node_iterator&
    ) const
    {
        return end_sentinel();
    }

    neighbor_iterator end(
        const neighbor_iterator&
    ) const
    {
        return end_sentinel();
    }

    unsigned long node_id (
        const node_iterator& it
    ) const
    {
        return it.nid;
    }

    unsigned long node_id (
        const neighbor_iterator& it
    ) const
    {
        return it.node_id();
    }

    // Every node is binary.
    unsigned long num_states (
        const node_iterator&
    ) const
    {
        return 2;
    }

    unsigned long num_states (
        const neighbor_iterator&
    ) const
    {
        return 2;
    }

    // Factor value for node it1 in state s1 and node it2 in state s2.  All four
    // overloads forward to basic_factor_value() with plain node ids.
    double factor_value (const node_iterator& it1, const node_iterator& it2, unsigned long s1, unsigned long s2) const
    { return basic_factor_value(it1.nid, it2.nid, s1, s2); }
    double factor_value (const neighbor_iterator& it1, const node_iterator& it2, unsigned long s1, unsigned long s2) const
    { return basic_factor_value(it1.node_id(), it2.nid, s1, s2); }
    double factor_value (const node_iterator& it1, const neighbor_iterator& it2, unsigned long s1, unsigned long s2) const
    { return basic_factor_value(it1.nid, it2.node_id(), s1, s2); }
    double factor_value (const neighbor_iterator& it1, const neighbor_iterator& it2, unsigned long s1, unsigned long s2) const
    { return basic_factor_value(it1.node_id(), it2.node_id(), s1, s2); }

private:

    // Builds a neighbor iterator over the chain neighbors of node nid.  The end
    // nodes 0 and 7 have a single neighbor, so their unused slot holds 8 (the end
    // marker).  Shared by both begin() overloads so the chain topology lives in one place.
    static neighbor_iterator neighbors_of (
        const unsigned long nid
    )
    {
        neighbor_iterator temp;
        if (nid == 0)
        {
            temp.nid[0] = nid+1;
            temp.nid[1] = 8;
        }
        else if (nid == 7)
        {
            temp.nid[0] = nid-1;
            temp.nid[1] = 8;
        }
        else
        {
            temp.nid[0] = nid-1;
            temp.nid[1] = nid+1;
        }
        return temp;
    }

    // The one-past-the-end neighbor iterator: both slots are 8, so node_id() is 8.
    static neighbor_iterator end_sentinel (
    )
    {
        neighbor_iterator temp;
        temp.nid[0] = 8;
        temp.nid[1] = 8;
        return temp;
    }

    // Orders the pair so that n1 <= n2 (swapping the states along with the nodes,
    // so the factor is symmetric), then reads the weight table.  Only chain edges
    // (i,i+1) are filled; operator[] would insert 0 for any other pair.
    double basic_factor_value (
        unsigned long n1,
        unsigned long n2,
        unsigned long s1,
        unsigned long s2
    ) const
    {
        if (n1 > n2)
        {
            swap(n1,n2);
            swap(s1,s2);
        }
        return weights[make_unordered_pair(n1,n2)][make_pair(s1,s2)];
    }

};

// ----------------------------------------------------------------------------------------

class map_problem2
{
    /*
        This is a simple tree structured graph.  In particular, it is a star made
        up of 6 nodes: leaves 0-4 are each connected only to the center node 5.
        Every node has 3 states.
    */
public:
    // 5x3 table of leaf scores: numbers(leaf, state).  One row per leaf, one
    // column per state.  Drawn from the shared rnd generator.
    matrix<double> numbers;

    map_problem2()
    {
        numbers = randm(5,3,rnd);
    }

    // Walks node ids 0..5.  The end() iterator holds nid == 6.
    struct node_iterator
    {
        node_iterator() {}
        node_iterator(unsigned long nid_): nid(nid_) {}
        bool operator== (const node_iterator& item) const { return item.nid == nid; }
        bool operator!= (const node_iterator& item) const { return item.nid != nid; }

        node_iterator& operator++()
        {
            ++nid;
            return *this;
        }

        unsigned long nid;
    };

    // Walks the neighbors of home_node.  Two iterators compare equal when they
    // refer to the same node id; node_id() == 6 means "one past the end".
    struct neighbor_iterator
    {
        neighbor_iterator() : count(0) {}

        bool operator== (const neighbor_iterator& item) const { return item.node_id() == node_id(); }
        bool operator!= (const neighbor_iterator& item) const { return item.node_id() != node_id(); }
        neighbor_iterator& operator++()
        {
            ++count;
            return *this;
        }

        // Returns the id of the count-th neighbor of home_node, or 6 once the
        // neighbors are exhausted.
        unsigned long node_id () const
        {
            // home_node == 6 is the end sentinel built by the end() functions.
            if (home_node == 6)
                return 6;

            if (home_node < 5)
            {
                // all the nodes are connected to node 5 and nothing else
                if (count == 0)
                    return 5;
                else
                    return 6; // the number returned by the end() functions.
            }
            else if (count < 5)
            {
                // home_node is the center: its neighbors are the leaves 0..4.
                return count;
            }
            else
            {
                return 6;
            }

        }

        unsigned long home_node;
        unsigned long count;
    };

    unsigned long number_of_nodes (
    ) const
    {
        return 6;
    }

    node_iterator begin(
    ) const
    {
        node_iterator temp;
        temp.nid = 0;
        return temp;
    }

    node_iterator end(
    ) const
    {
        node_iterator temp;
        temp.nid = 6;
        return temp;
    }

    // Neighbors of the node it points at.
    neighbor_iterator begin(
        const node_iterator& it
    ) const
    {
        neighbor_iterator temp;
        temp.home_node = it.nid;
        return temp;
    }

    // Neighbors of the node a neighbor iterator currently points at.
    neighbor_iterator begin(
        const neighbor_iterator& it
    ) const
    {
        neighbor_iterator temp;
        temp.home_node = it.node_id();
        return temp;
    }

    // End sentinel: home_node == 6 makes node_id() return 6.
    neighbor_iterator end(
        const node_iterator&
    ) const
    {
        neighbor_iterator temp;
        temp.home_node = 6;
        return temp;
    }

    // Same end sentinel as above.
    neighbor_iterator end(
        const neighbor_iterator&
    ) const
    {
        neighbor_iterator temp;
        temp.home_node = 6;
        return temp;
    }

    unsigned long node_id (
        const node_iterator& it
    ) const
    {
        return it.nid;
    }

    unsigned long node_id (
        const neighbor_iterator& it
    ) const
    {
        return it.node_id();
    }

    // Every node has 3 states.
    unsigned long num_states (
        const node_iterator&
    ) const
    {
        return 3;
    }

    unsigned long num_states (
        const neighbor_iterator&
    ) const
    {
        return 3;
    }

    // Factor value for node it1 in state s1 and node it2 in state s2.  All four
    // overloads forward to basic_factor_value() with plain node ids.
    double factor_value (const node_iterator& it1, const node_iterator& it2, unsigned long s1, unsigned long s2) const
    { return basic_factor_value(it1.nid, it2.nid, s1, s2); }
    double factor_value (const neighbor_iterator& it1, const node_iterator& it2, unsigned long s1, unsigned long s2) const
    { return basic_factor_value(it1.node_id(), it2.nid, s1, s2); }
    double factor_value (const node_iterator& it1, const neighbor_iterator& it2, unsigned long s1, unsigned long s2) const
    { return basic_factor_value(it1.nid, it2.node_id(), s1, s2); }
    double factor_value (const neighbor_iterator& it1, const neighbor_iterator& it2, unsigned long s1, unsigned long s2) const
    { return basic_factor_value(it1.node_id(), it2.node_id(), s1, s2); }

private:

    double basic_factor_value (
        unsigned long n1,
        unsigned long n2,
        unsigned long s1,
        unsigned long s2
    ) const
    {
        // Order the pair so n1 <= n2.  Every edge joins a leaf (0-4) with the
        // center (5), so afterwards n1 is the leaf and n2 is the center.
        if (n1 > n2)
        {
            swap(n1,n2);
            swap(s1,s2);
        }

        // The value is the leaf's own score numbers(leaf, leaf_state), plus a
        // bonus of 1 when the center node is in state 1.  The center's state does
        // not otherwise interact with the leaf, so the MAP puts the center in
        // state 1 and each leaf in the state with the largest entry of its row.
        if (s2 == 1)
            return numbers(n1,s1) + 1;
        else
            return numbers(n1,s1);
    }

};

// ----------------------------------------------------------------------------------------

template <typename map_problem>
double find_total_score (
    const map_problem& prob,
    const std::vector<unsigned long>& map_assignment
)
{
    /*
        Returns the sum of prob.factor_value() over every (node, neighbor) pair
        visited by prob's iterators, with each node in the state given by
        map_assignment[node id].  Each undirected edge is reached once from each
        endpoint, so it contributes twice to the total.
    */
    typedef typename map_problem::node_iterator     node_iter;
    typedef typename map_problem::neighbor_iterator nbr_iter;

    double total = 0;
    const node_iter last_node = prob.end();
    for (node_iter node = prob.begin(); node != last_node; ++node)
    {
        const unsigned long node_state = map_assignment[prob.node_id(node)];

        const nbr_iter last_nbr = prob.end(node);
        for (nbr_iter nbr = prob.begin(node); nbr != last_nbr; ++nbr)
        {
            const unsigned long nbr_state = map_assignment[prob.node_id(nbr)];
            total += prob.factor_value(node, nbr, node_state, nbr_state);
        }
    }

    return total;
}

// ----------------------------------------------------------------------------------------

template <
    typename map_problem
    >
void brute_force_find_max_factor_graph_nmplp (
    const map_problem& prob,
    std::vector<unsigned long>& map_assignment
)
{
    /*
        Exhaustively finds the assignment with the largest find_total_score() and
        stores it in map_assignment.  Every combination of states is tried: node k
        ranges over 0..num_states-1 and node 0 is the fastest-changing digit.  For
        8 binary nodes this visits the same order as counting i = 0..255 with bit k
        giving node k's state.  Ties keep the first assignment found.
        map_assignment is left untouched only if no score beats -infinity (e.g. all NaN).

        The previous implementation looped i < 255 and so never evaluated the
        all-ones assignment.  It also hard-coded 8 binary nodes, writing out of
        bounds for smaller problems.
    */
    typedef typename map_problem::node_iterator node_iterator;

    const unsigned long num_nodes = prob.number_of_nodes();

    // Number of states of each node, indexed by node id.
    std::vector<unsigned long> states_per_node(num_nodes, 0);
    for (node_iterator i = prob.begin(); i != prob.end(); ++i)
        states_per_node[prob.node_id(i)] = prob.num_states(i);

    std::vector<unsigned long> temp_assignment(num_nodes, 0);
    double best_score = -std::numeric_limits<double>::infinity();

    while (true)
    {
        const double score = find_total_score(prob, temp_assignment);
        if (score > best_score)
        {
            best_score = score;
            map_assignment = temp_assignment;
        }

        // Advance to the next assignment like an odometer.  Once every digit has
        // wrapped back to zero, all assignments have been visited.
        unsigned long k = 0;
        for (; k < num_nodes; ++k)
        {
            if (++temp_assignment[k] < states_per_node[k])
                break;
            temp_assignment[k] = 0;
        }
        if (k == num_nodes)
            break;
    }
}

// ----------------------------------------------------------------------------------------

template <typename map_problem>
void do_test(
)
{
    /*
        Builds one random map_problem and checks that NMPLP finds the same MAP
        assignment (and score) as exhaustive search.
    */
    print_spinner();

    map_problem prob;

    // Approximate inference via NMPLP (up to 1000 iterations, eps = 1e-8).
    std::vector<unsigned long> nmplp_assignment;
    find_max_factor_graph_nmplp(prob, nmplp_assignment, 1000, 1e-8);
    const double nmplp_score = find_total_score(prob, nmplp_assignment);

    // Exact reference answer via brute force.
    std::vector<unsigned long> exact_assignment;
    brute_force_find_max_factor_graph_nmplp(prob, exact_assignment);
    const double exact_score = find_total_score(prob, exact_assignment);

    dlog << LINFO << "score NMPLP: " << nmplp_score;
    dlog << LINFO << "score MAP:   " << exact_score;

    DLIB_TEST(std::abs(nmplp_score - exact_score) < 1e-10);
    DLIB_TEST(mat(nmplp_assignment) == mat(exact_assignment));
}

// ----------------------------------------------------------------------------------------

template <typename map_problem>
void do_test2(
)
{
    /*
        Builds one random star-shaped problem and checks NMPLP against the MAP
        assignment that can be read directly off prob.numbers.
    */
    print_spinner();

    map_problem prob;

    std::vector<unsigned long> nmplp_assignment;
    find_max_factor_graph_nmplp(prob, nmplp_assignment, 10, 1e-8);
    const double nmplp_score = find_total_score(prob, nmplp_assignment);

    // Expected MAP: each leaf 0..4 takes the state with the largest entry in its
    // row of prob.numbers, and the center node 5 takes state 1.
    std::vector<unsigned long> expected_assignment(6);
    for (long leaf = 0; leaf < 5; ++leaf)
        expected_assignment[leaf] = index_of_max(rowm(prob.numbers,leaf));
    expected_assignment[5] = 1;
    const double expected_score = find_total_score(prob, expected_assignment);

    dlog << LINFO << "score NMPLP: " << nmplp_score;
    dlog << LINFO << "score MAP:   " << expected_score;
    dlog << LINFO << "MAP assignment: "<< trans(mat(nmplp_assignment));

    DLIB_TEST(std::abs(nmplp_score - expected_score) < 1e-10);
    DLIB_TEST(mat(nmplp_assignment) == mat(expected_assignment));
}

// ----------------------------------------------------------------------------------------

class test_find_max_factor_graph_nmplp : public tester
{
    /*
        Registers the find_max_factor_graph_nmplp tests with the test harness.
    */
public:
    test_find_max_factor_graph_nmplp (
    ) :
        tester ("test_find_max_factor_graph_nmplp",
                "Runs tests on the find_max_factor_graph_nmplp routine.")
    {}

    void perform_test (
    )
    {
        // Reset the shared generator before any problem is built.
        rnd.clear();

        // Tree-shaped (acyclic) problems, compared against brute force.
        dlog << LINFO << "test on a chain structured graph";
        for (int trial = 0; trial < 30; ++trial)
        {
            do_test<map_problem_chain>();
        }

        // Two disjoint 4-node cycles, compared against brute force.
        dlog << LINFO << "test on a 2 cycle graph";
        for (int trial = 0; trial < 30; ++trial)
        {
            do_test<map_problem<false> >();
        }

        // Fully connected 8 node graph, compared against brute force.
        dlog << LINFO << "test on a fully connected graph";
        for (int trial = 0; trial < 5; ++trial)
        {
            do_test<map_problem<true> >();
        }

        // Star-shaped problems with a closed-form MAP.
        dlog << LINFO << "test on a tree structured graph";
        for (int trial = 0; trial < 10; ++trial)
        {
            do_test2<map_problem2>();
        }
    }
} a;

}

```