// Copyright (C) 2006  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_SOCKETS_EXTENSIONs_CPP
#define DLIB_SOCKETS_EXTENSIONs_CPP

#include <string>
#include <sstream>
#include "../sockets.h"
#include "../error.h"
#include "sockets_extensions.h"
#include "../timer.h"
#include "../algs.h"
#include "../timeout.h"
#include "../misc_api.h"
#include "../serialize.h"
#include "../string.h"

namespace dlib
{

// ----------------------------------------------------------------------------------------

    network_address::
    network_address(
        const std::string& full_address
    )
    {
        // Parse "host:port" using operator>> and insist that the whole string was
        // consumed: a failed parse or any leftover characters (including trailing
        // whitespace) is rejected.
        std::istringstream parser(full_address);
        parser >> *this;

        const bool accepted = parser && parser.peek() == EOF;
        if (!accepted)
            throw invalid_network_address("invalid network address: " + full_address);
    }

// ----------------------------------------------------------------------------------------

    void serialize(
        const network_address& item,
        std::ostream& out
    )
    {
        // Wire format: host_address first, then port, each written with the
        // matching serialize() overload. deserialize() reads them in this order.
        const auto& host = item.host_address;
        const auto& port = item.port;
        serialize(host, out);
        serialize(port, out);
    }

// ----------------------------------------------------------------------------------------

    void deserialize(
        network_address& item,
        std::istream& in 
    )
    {
        // Read the fields back in the same order serialize() writes them:
        // host_address, then port.
        auto& host = item.host_address;
        auto& port = item.port;
        deserialize(host, in);
        deserialize(port, in);
    }

// ----------------------------------------------------------------------------------------

    std::ostream& operator<< (
        std::ostream& out,
        const network_address& item
    )
    {
        // Emit "host:port", the same shape operator>> splits on its last ':'.
        return out << item.host_address << ":" << item.port;
    }

// ----------------------------------------------------------------------------------------

    // Reads one whitespace-delimited token of the form "host:port" into item.
    // The token is split on its LAST ':' so everything before it becomes the
    // host. If there is no ':' or the port text cannot be converted, badbit is
    // set on the stream and item is left unmodified.
    std::istream& operator>> (
        std::istream& in,
        network_address& item
    )
    {
        std::string temp;
        in >> temp;

        const std::string::size_type pos = temp.find_last_of(':');
        if (pos == std::string::npos)
        {
            in.setstate(std::ios::badbit);
            return in;
        }

        // Convert the port into a local first so that a bad port does not leave
        // item half-updated (host changed, port stale).
        unsigned short port = 0;
        try
        {
            port = sa = temp.substr(pos+1);
        }
        catch (std::exception& )
        {
            in.setstate(std::ios::badbit);
            return in;
        }

        // Everything parsed; commit both fields together.
        item.host_address = temp.substr(0, pos);
        item.port = port;
        return in;
    }

// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------

    connection* connect (
        const std::string& host_or_ip,
        unsigned short port
    )
    {
        // Resolve the name to an IP unless the caller already handed us one.
        // Throws socket_error(ERESOLVE, ...) if resolution fails and
        // socket_error if the TCP connection cannot be established.
        std::string ip;
        if (!is_ip_address(host_or_ip))
        {
            if (hostname_to_ip(host_or_ip, ip))
                throw socket_error(ERESOLVE,"unable to resolve '" + host_or_ip + "' in connect()");
        }
        else
        {
            ip = host_or_ip;
        }

        connection* con = 0;
        if (create_connection(con, port, ip) != 0)
        {
            std::ostringstream msg;
            msg << "unable to connect to '" << host_or_ip << ":" << port << "'";
            throw socket_error(msg.str()); 
        }

        return con;
    }

// ----------------------------------------------------------------------------------------

    connection* connect (
        const network_address& addr
    )
    {
        // Convenience overload: unpack the address and defer to connect(host, port).
        const std::string& host = addr.host_address;
        const unsigned short port = addr.port;
        return connect(host, port);
    }

// ----------------------------------------------------------------------------------------

    namespace connect_timeout_helpers
    {
        mutex connect_mutex;
        signaler connect_signaler(connect_mutex);
        timestamper ts;
        long outstanding_connects = 0;

        struct thread_data
        {
            std::string host_or_ip;
            unsigned short port;
            connection* con;
            bool connect_ended;
            bool error_occurred;
        };

        void thread(void* param)
        {
            thread_data p = *static_cast<thread_data*>(param);
            try
            {
                p.con = connect(p.host_or_ip, p.port); 
            }
            catch (...)
            {
                p.error_occurred = true;
            }

            auto_mutex M(connect_mutex);
            // report the results back to the connect() call that spawned this
            // thread.
            static_cast<thread_data*>(param)->con = p.con;
            static_cast<thread_data*>(param)->error_occurred = p.error_occurred;
            connect_signaler.broadcast();

            // wait for the call to connect() that spawned this thread to terminate
            // before we delete the thread_data struct.
            while (static_cast<thread_data*>(param)->connect_ended == false)
                connect_signaler.wait();

            connect_signaler.broadcast();
            --outstanding_connects;
            delete static_cast<thread_data*>(param);
        }
    }

    // Connects to host_or_ip:port, giving up after timeout milliseconds. The
    // blocking connect runs on a helper thread while this call waits on
    // connect_signaler. Throws socket_error on failure or timeout.
    connection* connect (
        const std::string& host_or_ip,
        unsigned short port,
        unsigned long timeout
    )
    {
        using namespace connect_timeout_helpers;

        auto_mutex M(connect_mutex);

        // Widen before scaling to microseconds. unsigned long is only 32 bits on
        // some platforms (e.g. Windows), where timeout*1000 would wrap for
        // timeouts longer than about 71 minutes and yield a bogus deadline.
        const uint64 end_time = ts.get_timestamp() + static_cast<uint64>(timeout)*1000;

        // Throttle: wait until there are no more than 100 outstanding connect
        // threads, giving up if the deadline passes first.
        while (outstanding_connects > 100)
        {
            const uint64 cur_time = ts.get_timestamp();
            if (end_time > cur_time)
            {
                timeout = static_cast<unsigned long>((end_time - cur_time)/1000);
            }
            else
            {
                throw socket_error("unable to connect to '" + host_or_ip + "' because connect timed out"); 
            }
            
            connect_signaler.wait_or_timeout(timeout);
        }

        // The worker thread becomes responsible for deleting data once we set
        // connect_ended.
        thread_data* data = new thread_data;
        data->host_or_ip = host_or_ip;
        data->port = port;
        data->con = 0;
        data->connect_ended = false;
        data->error_occurred = false;

        if (create_new_thread(thread, data) == false)
        {
            delete data;
            throw socket_error("unable to connect to '" + host_or_ip + "'"); 
        }

        ++outstanding_connects;

        // wait until we have a connection object, the worker reports an error,
        // or the deadline passes
        while (data->con == 0)
        {
            const uint64 cur_time = ts.get_timestamp();
            if (end_time > cur_time && data->error_occurred == false)
            {
                timeout = static_cast<unsigned long>((end_time - cur_time)/1000);
            }
            else
            {
                // let the thread know that it should terminate
                data->connect_ended = true;
                connect_signaler.broadcast();
                if (data->error_occurred)
                    throw socket_error("unable to connect to '" + host_or_ip + "'"); 
                else
                    throw socket_error("unable to connect to '" + host_or_ip + "' because connect timed out"); 
            }

            connect_signaler.wait_or_timeout(timeout);
        }

        // Grab the result before releasing the worker; it deletes data once
        // connect_ended is set and the mutex is released.
        connection* const result = data->con;

        // let the thread know that it should terminate
        data->connect_ended = true;
        connect_signaler.broadcast();
        return result;
    }

// ----------------------------------------------------------------------------------------

    // Returns true if ip is a dotted-decimal IPv4 address: exactly four
    // non-empty components separated by single '.' characters, each made up
    // only of decimal digits and with a value of at most 255.
    // Whitespace, signs, empty components and extra/missing components are
    // rejected. The old stream-based parser accepted "1 2 3 4", " 1.2.3.4",
    // "+1.2.3.4" and "1..2.3.4".
    bool is_ip_address (
        std::string ip
    )
    {
        int components = 0;
        std::string::size_type start = 0;
        for (;;)
        {
            const std::string::size_type dot = ip.find('.', start);
            const std::string::size_type end = (dot == std::string::npos) ? ip.size() : dot;

            // every component must contain at least one digit
            if (end == start)
                return false;

            unsigned int value = 0;
            for (std::string::size_type i = start; i < end; ++i)
            {
                const char c = ip[i];
                if (c < '0' || c > '9')
                    return false;
                value = value*10 + static_cast<unsigned int>(c - '0');
                // bail out early so long digit runs cannot overflow
                if (value > 255)
                    return false;
            }

            ++components;
            if (dot == std::string::npos)
                break;
            // a '.' after the fourth component means there are too many parts
            if (components == 4)
                return false;
            start = dot + 1;
        }

        return components == 4;
    }

// ----------------------------------------------------------------------------------------

    void close_gracefully (
        connection* con,
        unsigned long timeout 
    )
    {
        // Adopt the raw pointer so it is deleted even if the unique_ptr overload
        // throws, then forward to that overload to do the actual shutdown.
        std::unique_ptr<connection> owner(con);
        close_gracefully(owner, timeout);
    }

// ----------------------------------------------------------------------------------------

    void close_gracefully (
        std::unique_ptr<connection>& con,
        unsigned long timeout 
    )
    {
        // Half-close the outgoing side, drain whatever the peer still sends
        // until it closes (or the timeout forces connection::shutdown), then
        // destroy the connection. con is always reset on return, whether
        // normally or by exception.
        if (con)
        {
            // A nonzero result from shutdown_outgoing() means failure: skip the
            // drain and go straight to the reset below.
            if (!con->shutdown_outgoing())
            {
                try
                {
                    // The watchdog must be destroyed before the connection is;
                    // it goes out of scope at the end of this try block, ahead of
                    // both reset() calls.
                    dlib::timeout watchdog(*con,&connection::shutdown,timeout);

                    char drain_buf[100];
                    for (;;)
                    {
                        if (con->read(drain_buf,sizeof(drain_buf)) <= 0)
                            break;
                    }
                }
                catch (...)
                {
                    con.reset();
                    throw;
                }
            }

            con.reset();
        }
    }

// ----------------------------------------------------------------------------------------

}

#endif // DLIB_SOCKETS_EXTENSIONs_CPP