// Copyright (C) 2012  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_BSP_CPph_
#define DLIB_BSP_CPph_

#include "bsp.h"
#include <memory>
#include <stack>

// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------

namespace dlib
{

    namespace impl1
    {

        // Replaces the contents of cons with one freshly opened connection per entry of
        // hosts.  The connection made to the k-th host (counting from 1) is stored under
        // id k.  Our own node_id is serialized into every new connection's stream so the
        // remote end knows who is calling; the stream is not flushed here.  Exceptions
        // thrown while connecting or serializing propagate to the caller.
        void connect_all (
            map_id_to_con& cons,
            const std::vector<network_address>& hosts,
            unsigned long node_id
        )
        {
            cons.clear();

            unsigned long ids_assigned = 0;
            for (const network_address& address : hosts)
            {
                std::unique_ptr<bsp_con> link(new bsp_con(address));

                // announce ourselves to the remote end
                dlib::serialize(node_id, link->stream);

                // ids follow the order of hosts, starting at 1.  add() is handed a
                // dedicated local for the key, just like the connection pointer.
                unsigned long peer_id = ++ids_assigned;
                cons.add(peer_id, link);
            }
        }

        // Replaces the contents of cons with one connection per entry of hosts, each
        // stored under the node_id recorded in that entry.  Our own node_id is written to
        // each new connection and flushed right away.
        //
        // This function does not let std::exception escape.  On the first host for which
        // connecting or sending fails, error_string is set to a message naming that
        // host's address and the function stops, leaving the connections made so far in
        // cons.  error_string is left untouched when every host succeeds.
        void connect_all_hostinfo (
            map_id_to_con& cons,
            const std::vector<hostinfo>& hosts,
            unsigned long node_id,
            std::string& error_string 
        )
        {
            cons.clear();

            for (const hostinfo& target : hosts)
            {
                try
                {
                    std::unique_ptr<bsp_con> link(new bsp_con(target.addr));

                    // announce ourselves to the remote end
                    dlib::serialize(node_id, link->stream);
                    link->stream.flush();

                    unsigned long peer_id = target.node_id;
                    cons.add(peer_id, link);
                }
                catch (const std::exception&)
                {
                    std::ostringstream report;
                    report << "Could not connect to " << target.addr;
                    error_string = report.str();

                    // give up on the remaining hosts
                    return;
                }
            }
        }


        // Sends the start-up instructions over the connections in cons.  The connection
        // for hosts[k] is expected to be stored in cons under id k+1.
        //
        // Every connection first receives the id it is stored under.  Then, walking hosts
        // in order, each node is sent (1) the list of nodes that come before it in hosts
        // and (2) the count of nodes that come after it, which is the number of incoming
        // connections it should expect.  Each node's stream is flushed once its orders
        // have been written.
        void send_out_connection_orders (
            map_id_to_con& cons,
            const std::vector<network_address>& hosts
        )
        {
            // step 1: every node learns its own id
            for (cons.reset(); cons.move_next(); )
            {
                bsp_con& link = *cons.element().value();
                dlib::serialize(cons.element().key(), link.stream);
            }

            // step 2: every node learns who to connect to
            std::vector<hostinfo> earlier_nodes; 
            for (unsigned long idx = 0; idx < hosts.size(); ++idx)
            {
                const hostinfo current(hosts[idx], idx+1);
                bsp_con& link = *cons[current.node_id];

                // the list is sent before the current node is appended to it, so a node
                // is never told to connect to itself
                dlib::serialize(earlier_nodes, link.stream);
                earlier_nodes.push_back(current);

                // the nodes still to come are the ones that will connect to this one
                const unsigned long expected_incoming = hosts.size()-earlier_nodes.size();
                dlib::serialize(expected_incoming, link.stream);
                link.stream.flush();
            }
        }

    // ------------------------------------------------------------------------------------


    }

// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------

    namespace impl2
    {
        // These control bytes are sent before each message between nodes.  Note that many
        // of these are only sent between the control node (node 0) and the other nodes.
        // This is because the controller node is responsible for handling the
        // synchronization that needs to happen when all nodes block on calls to
        // receive_data() at the same time.
        //
        // Note that the value 4 is not assigned to any of the control bytes below.

        // Denotes a normal content message.  It is followed on the wire by the sender's
        // epoch and then the message payload.
        const static char MESSAGE_HEADER            = 0; 

        // Sent to the controller node when someone receives a message via receive_data().
        const static char GOT_MESSAGE               = 1; 

        // Sent to the controller node when someone sends a message via send().
        const static char SENT_MESSAGE              = 2; 

        // Sent to the controller node when someone enters a call to receive_data().
        const static char IN_WAITING_STATE          = 3; 

        // Broadcast when a node terminates itself.  It is also the byte that makes a
        // read_thread() stop reading from its connection.
        const static char NODE_TERMINATE            = 5; 

        // Broadcast by the controller node when it determines that all nodes are blocked
        // on calls to receive_data() and there aren't any messages in flight.  This is also
        // what makes us go to the next epoch.
        const static char SEE_ALL_IN_WAITING_STATE  = 6; 

        // This isn't ever transmitted between nodes.  It is used internally to indicate
        // that an error occurred: read_thread() places a message of this type in the
        // message buffer when reading from a connection throws.
        const static char READ_ERROR                = 7;

    // ------------------------------------------------------------------------------------

        // Body of the reader thread that services a single connection.
        //
        //   con        - the connection to read from.
        //   node_id    - id of the node running this thread (used only in error reports).
        //   sender_id  - id of the node at the other end of con; stamped on every message.
        //   msg_buffer - queue that receives every message read from con.
        //
        // The loop reads one control byte at a time.  A MESSAGE_HEADER byte is followed by
        // the sender's epoch and the payload, which are read as well.  Each message is
        // then pushed into msg_buffer.  The loop ends after a NODE_TERMINATE message has
        // been pushed.
        //
        // If reading throws, a READ_ERROR message whose payload is a human readable
        // description of the failure is pushed into msg_buffer instead and the thread
        // returns.
        void read_thread (
            impl1::bsp_con* con,
            unsigned long node_id,
            unsigned long sender_id,
            impl1::thread_safe_message_queue& msg_buffer
        )
        {
            // Builds the READ_ERROR message and queues it.  what is the text of the
            // caught exception, or null when the exception was not a std::exception and
            // therefore carries no text we can report.
            const auto report_read_error = [&](const char* what)
            {
                impl1::msg_data msg;
                msg.data.reset(new std::vector<char>);
                vectorstream sout(*msg.data);
                sout << "An exception was thrown while attempting to receive a message from processing node " << sender_id << ".\n";
                sout << "  Sending processing node address:   " << con->con->get_foreign_ip() << ":" << con->con->get_foreign_port() << std::endl;
                sout << "  Receiving processing node address: " << con->con->get_local_ip() << ":" << con->con->get_local_port() << std::endl;
                sout << "  Receiving processing node id:      " << node_id << std::endl;
                if (what)
                    sout << "  Error message in the exception:    " << what << std::endl;

                msg.sender_id = sender_id;
                msg.msg_type = READ_ERROR;

                msg_buffer.push_and_consume(msg);
            };

            try
            {
                while(true)
                {
                    impl1::msg_data msg;
                    deserialize(msg.msg_type, con->stream);
                    msg.sender_id = sender_id;

                    // only content messages carry an epoch and a payload
                    if (msg.msg_type == MESSAGE_HEADER)
                    {
                        msg.data.reset(new std::vector<char>);
                        deserialize(msg.epoch, con->stream);
                        deserialize(*msg.data, con->stream);
                    }

                    msg_buffer.push_and_consume(msg);

                    // the other end has told us it is dropping the connection on
                    // purpose, so there is nothing more to read
                    if (msg.msg_type == NODE_TERMINATE)
                        break;
                }
            }
            catch (const std::exception& e)
            {
                report_read_error(e.what());
            }
            catch (...)
            {
                report_read_error(nullptr);
            }
        }

    // ------------------------------------------------------------------------------------

    }
    
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
//                          IMPLEMENTATION OF bsp_context OBJECT MEMBERS
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------

    void bsp_context::
    close_all_connections_gracefully(
    )
    {
        if (node_id() != 0)
        {
            _cons.reset();
            while (_cons.move_next())
            {
                // tell the other end that we are intentionally dropping the connection
                serialize(impl2::NODE_TERMINATE,_cons.element().value()->stream);
                _cons.element().value()->stream.flush();
            }
        }

        impl1::msg_data msg;
        // now wait for all the other nodes to terminate
        while (num_terminated_nodes < _cons.size() )
        {
            if (node_id() == 0 && num_waiting_nodes + num_terminated_nodes == _cons.size() && outstanding_messages == 0)
            {
                num_waiting_nodes = 0;
                broadcast_byte(impl2::SEE_ALL_IN_WAITING_STATE);
                ++current_epoch;
            }

            if (!msg_buffer.pop(msg))
                throw dlib::socket_error("Error reading from msg_buffer in dlib::bsp_context.");

            if (msg.msg_type == impl2::NODE_TERMINATE)
            {
                ++num_terminated_nodes;
                _cons[msg.sender_id]->terminated = true;
            }
            else if (msg.msg_type == impl2::READ_ERROR)
            {
                throw dlib::socket_error(msg.data_to_string());
            }
            else if (msg.msg_type == impl2::MESSAGE_HEADER)
            {
                throw dlib::socket_error("A BSP node received a message after it has terminated.");
            }
            else if (msg.msg_type == impl2::GOT_MESSAGE)
            {
                --num_waiting_nodes;
                --outstanding_messages;
            }
            else if (msg.msg_type == impl2::SENT_MESSAGE)
            {
                ++outstanding_messages;
            }
            else if (msg.msg_type == impl2::IN_WAITING_STATE)
            {
                ++num_waiting_nodes;
            }
        }

        if (node_id() == 0)
        {
            _cons.reset();
            while (_cons.move_next())
            {
                // tell the other end that we are intentionally dropping the connection
                serialize(impl2::NODE_TERMINATE,_cons.element().value()->stream);
                _cons.element().value()->stream.flush();
            }

            if (outstanding_messages != 0)
            {
                std::ostringstream sout;
                sout << "A BSP job was allowed to terminate before all sent messages have been received.\n";
                sout << "There are at least " << outstanding_messages << " messages still in flight.   Make sure all sent messages\n";
                sout << "have a corresponding call to receive().";
                throw dlib::socket_error(sout.str());
            }
        }
    }

// ----------------------------------------------------------------------------------------

    // Tears the context down: shuts down every connection, disables the message
    // buffer, and then destroys the reader threads.
    bsp_context::
    ~bsp_context()
    {
        // NOTE(review): shutting the connections down first presumably unblocks any
        // reader thread still waiting on its socket -- confirm against connection docs.
        for (_cons.reset(); _cons.move_next(); )
        {
            _cons.element().value()->con->shutdown();
        }

        msg_buffer.disable();

        // clearing the container waits for all the threads to terminate
        threads.clear();
    }

// ----------------------------------------------------------------------------------------

    // Builds a context for node node_id_ on top of the already established connections
    // in cons_.  All counters start at zero and the epoch starts at 1.  The context
    // keeps a reference to cons_ rather than a copy, and starts one reader thread
    // (impl2::read_thread) for each connection in it.
    bsp_context::
    bsp_context(
        unsigned long node_id_,
        impl1::map_id_to_con& cons_
    ) :
        outstanding_messages(0),
        num_waiting_nodes(0),
        num_terminated_nodes(0),
        current_epoch(1),
        _cons(cons_),
        _node_id(node_id_)
    {
        for (_cons.reset(); _cons.move_next(); )
        {
            // each reader gets the raw connection pointer, our id, the id of the node
            // on the other end, and the shared message buffer to deliver into
            impl1::bsp_con* const link = _cons.element().value().get();
            const unsigned long peer_id = _cons.element().key();

            std::unique_ptr<thread_function> reader(
                new thread_function(&impl2::read_thread, link, _node_id, peer_id, ref(msg_buffer)));
            threads.push_back(reader);
        }
    }

// ----------------------------------------------------------------------------------------

    bool bsp_context::
    receive_data (
        std::shared_ptr<std::vector<char> >& item,
        unsigned long& sending_node_id
    ) 
    {
        notify_control_node(impl2::IN_WAITING_STATE);

        while (true)
        {
            // If there aren't any nodes left to give us messages then return right now.
            // We need to check the msg_buffer size to make sure there aren't any
            // unprocessed message there.  Recall that this can happen because status
            // messages always jump to the front of the message buffer.  So we might have
            // learned about the node terminations before processing their messages for us.
            if (num_terminated_nodes == _cons.size() && msg_buffer.size() == 0)
            {
                return false;
            }

            // if all running nodes are currently blocking forever on receive_data()
            if (node_id() == 0 && outstanding_messages == 0 && num_terminated_nodes + num_waiting_nodes == _cons.size())
            {
                num_waiting_nodes = 0;
                broadcast_byte(impl2::SEE_ALL_IN_WAITING_STATE);

                // Note that the reason we have this epoch counter is so we can tell if a
                // sent message is from before or after one of these "all nodes waiting"
                // synchronization events.  If we didn't have the epoch count we would have
                // a race condition where one node gets the SEE_ALL_IN_WAITING_STATE
                // message before others and then sends out a message to another node
                // before that node got the SEE_ALL_IN_WAITING_STATE message.  Then that
                // node would think the normal message came before SEE_ALL_IN_WAITING_STATE
                // which would be bad.
                ++current_epoch;
                return false;
            }

            impl1::msg_data data;
            if (!msg_buffer.pop(data, current_epoch))
                throw dlib::socket_error("Error reading from msg_buffer in dlib::bsp_context.");


            switch(data.msg_type)
            {
                case impl2::MESSAGE_HEADER: {
                    item = data.data;
                    sending_node_id = data.sender_id;
                    notify_control_node(impl2::GOT_MESSAGE);
                    return true;
                } break;

                case impl2::IN_WAITING_STATE: {
                    ++num_waiting_nodes;
                } break;

                case impl2::GOT_MESSAGE: {
                    --outstanding_messages;
                    --num_waiting_nodes;
                } break;

                case impl2::SENT_MESSAGE: {
                    ++outstanding_messages;
                } break;

                case impl2::NODE_TERMINATE: {
                    ++num_terminated_nodes;
                    _cons[data.sender_id]->terminated = true;
                } break;

                case impl2::SEE_ALL_IN_WAITING_STATE: {
                    ++current_epoch;
                    return false;
                } break;

                case impl2::READ_ERROR: {
                    throw dlib::socket_error(data.data_to_string());
                } break;

                default: {
                    throw dlib::socket_error("Unknown message received by dlib::bsp_context");
                } break;
            } // end switch()
        } // end while (true)
    }

// ----------------------------------------------------------------------------------------

    void bsp_context::
    notify_control_node (
        char val
    )
    {
        if (node_id() == 0)
        {
            using namespace impl2;
            switch(val)
            {
                case SENT_MESSAGE: {
                    ++outstanding_messages;
                } break;

                case GOT_MESSAGE: {
                    --outstanding_messages;
                } break;

                case IN_WAITING_STATE: {
                    // nothing to do in this case
                } break;

                default:
                    DLIB_CASSERT(false,"This should never happen");
            }
        }
        else
        {
            serialize(val, _cons[0]->stream);
            _cons[0]->stream.flush();
        }
    }

// ----------------------------------------------------------------------------------------

    void bsp_context::
    broadcast_byte (
        char val
    )
    {
        for (unsigned long i = 0; i < number_of_nodes(); ++i)
        {
            // don't send to yourself or to terminated nodes
            if (i == node_id() || _cons[i]->terminated)
                continue;

            serialize(val, _cons[i]->stream);
            _cons[i]->stream.flush();
        }
    }

// ----------------------------------------------------------------------------------------

    void bsp_context::
    send_data(
        const std::vector<char>& item,
        unsigned long target_node_id
    ) 
    {
        using namespace impl2;
        if (_cons[target_node_id]->terminated)
            throw socket_error("Attempt to send a message to a node that has terminated.");

        serialize(MESSAGE_HEADER, _cons[target_node_id]->stream);
        serialize(current_epoch, _cons[target_node_id]->stream);
        serialize(item, _cons[target_node_id]->stream);
        _cons[target_node_id]->stream.flush();

        notify_control_node(SENT_MESSAGE);
    }

// ----------------------------------------------------------------------------------------

}

#endif // DLIB_BSP_CPph_