// Copyright (C) 2012  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_FULL_OBJECT_DeTECTION_Hh_
#define DLIB_FULL_OBJECT_DeTECTION_Hh_

#include "../geometry.h"
#include "full_object_detection_abstract.h"
#include <vector>
#include "../serialize.h"

namespace dlib
{

// ----------------------------------------------------------------------------------------

    // Sentinel location for a part that should be treated as absent.  Both
    // coordinates are 0x7FFFFFFF; all_parts_in_rect() skips parts equal to it.
    static const point OBJECT_PART_NOT_PRESENT(0x7FFFFFFF, 0x7FFFFFFF);

// ----------------------------------------------------------------------------------------

    // A bounding rectangle together with an ordered list of part locations.
    // An individual part may hold OBJECT_PART_NOT_PRESENT to mark it absent.
    class full_object_detection
    {
    public:
        // Creates a detection with bounding box rect_ and part locations parts_.
        full_object_detection(
            const rectangle& rect_,
            const std::vector<point>& parts_
        ) : rect(rect_), parts(parts_) {}

        // Creates a detection with a default-constructed rectangle and no parts.
        full_object_detection(){}

        // Creates a detection with bounding box rect_ and no parts.
        explicit full_object_detection(
            const rectangle& rect_
        ) : rect(rect_) {}

        // Access to the bounding rectangle.
        const rectangle& get_rect() const { return rect; }
        rectangle& get_rect() { return rect; }

        // Number of stored part locations.
        unsigned long num_parts() const { return parts.size(); }

        // Returns the location of part idx.  Requires idx < num_parts()
        // (checked with DLIB_ASSERT).
        const point& part(
            unsigned long idx
        ) const 
        { 
            // make sure requires clause is not broken
            DLIB_ASSERT(idx < num_parts(),
                "\t point full_object_detection::part()"
                << "\n\t Invalid inputs were given to this function "
                << "\n\t idx:         " << idx  
                << "\n\t num_parts(): " << num_parts()  
                << "\n\t this:        " << this
                );
            return parts[idx]; 
        }

        // Mutable access to the location of part idx.  Requires idx < num_parts()
        // (checked with DLIB_ASSERT).
        point& part(
            unsigned long idx
        )  
        { 
            // make sure requires clause is not broken
            DLIB_ASSERT(idx < num_parts(),
                "\t point full_object_detection::part()"
                << "\n\t Invalid inputs were given to this function "
                << "\n\t idx:         " << idx  
                << "\n\t num_parts(): " << num_parts()  
                << "\n\t this:        " << this
                );
            return parts[idx]; 
        }

        // Writes format version 1, followed by the rectangle and the part list.
        friend void serialize (
            const full_object_detection& item,
            std::ostream& out
        )
        {
            int version = 1;
            serialize(version, out);
            serialize(item.rect, out);
            serialize(item.parts, out);
        }

        // Reads the layout written by serialize().  Throws serialization_error
        // if the stored format version is not 1.
        friend void deserialize (
            full_object_detection& item,
            std::istream& in
        )
        {
            int version = 0;
            deserialize(version, in);
            if (version != 1)
                throw serialization_error("Unexpected version encountered while deserializing dlib::full_object_detection.");

            deserialize(item.rect, in);
            deserialize(item.parts, in);
        }

        // Two detections are equal when their rectangles match and their part
        // lists have the same length with equal points in the same order.
        bool operator==(
            const full_object_detection& rhs
        ) const
        {
            // std::vector's operator== already checks size first and then
            // compares element-wise, so no hand-written loop is needed.
            return rect == rhs.rect && parts == rhs.parts;
        }

    private:
        rectangle rect;
        std::vector<point> parts;  
    };

// ----------------------------------------------------------------------------------------

    inline bool all_parts_in_rect (
        const full_object_detection& obj
    )
    {
        for (unsigned long i = 0; i < obj.num_parts(); ++i)
        {
            if (obj.get_rect().contains(obj.part(i)) == false &&
                obj.part(i) != OBJECT_PART_NOT_PRESENT)
                return false;
        }
        return true;
    }

// ----------------------------------------------------------------------------------------

    // A detection box with a confidence score, an "ignore" flag, and a string
    // label.  Implicitly constructible from, and convertible to, a rectangle.
    struct mmod_rect
    {
        mmod_rect() = default; 
        mmod_rect(const rectangle& r) : rect(r) {}
        mmod_rect(const rectangle& r, double score)
            : rect(r), detection_confidence(score) {}
        mmod_rect(const rectangle& r, double score, const std::string& label_)
            : rect(r), detection_confidence(score), label(label_) {}

        rectangle rect;
        double detection_confidence = 0;
        bool ignore = false;
        std::string label;

        // Lets an mmod_rect be passed wherever a rectangle is expected.
        operator rectangle() const { return rect; }

        // Field-wise equality over rect, detection_confidence, ignore and label.
        bool operator == (const mmod_rect& rhs) const
        { 
            const bool same_box  = rect == rhs.rect;
            const bool same_meta = detection_confidence == rhs.detection_confidence
                                   && ignore == rhs.ignore;
            return same_box && same_meta && label == rhs.label;
        }
    };

    // Builds an mmod_rect around r with its ignore flag set; the confidence and
    // label keep their default values.
    inline mmod_rect ignored_mmod_rect(const rectangle& r)
    {
        mmod_rect result(r);
        result.ignore = true;
        return result;
    }

    // Writes an mmod_rect using format version 2.  Version 2 adds the label
    // after the rect, detection_confidence and ignore fields that version 1 has.
    inline void serialize(const mmod_rect& item, std::ostream& out)
    {
        const int format_version = 2;
        serialize(format_version, out);

        serialize(item.rect, out);
        serialize(item.detection_confidence, out);
        serialize(item.ignore, out);
        serialize(item.label, out);
    }

    // Reads an mmod_rect written in format version 1 or 2; any other version
    // throws serialization_error.  Version 1 data carries no label, so the
    // label is reset to empty in that case.
    inline void deserialize(mmod_rect& item, std::istream& in)
    {
        int format_version = 0;
        deserialize(format_version, in);

        const bool supported = (format_version == 1 || format_version == 2);
        if (!supported)
            throw serialization_error("Unexpected version found while deserializing dlib::mmod_rect");

        deserialize(item.rect, in);
        deserialize(item.detection_confidence, in);
        deserialize(item.ignore, in);

        if (format_version == 1)
            item.label.clear();
        else
            deserialize(item.label, in);
    }

// ----------------------------------------------------------------------------------------

    // A detection box (floating-point rectangle) with a confidence score, an
    // "ignore" flag, a label, and a list of (score, label) pairs.
    struct yolo_rect
    {
        yolo_rect() = default;
        yolo_rect(const drectangle& r) : rect(r) {}
        yolo_rect(const drectangle& r, double score)
            : rect(r), detection_confidence(score) {}
        yolo_rect(const drectangle& r, double score, const std::string& label_)
            : rect(r), detection_confidence(score), label(label_) {}
        // Copies the shared fields of an mmod_rect; labels stays empty.
        yolo_rect(const mmod_rect& r)
            : rect(r.rect), detection_confidence(r.detection_confidence),
              ignore(r.ignore), label(r.label) {}

        drectangle rect;
        double detection_confidence = 0;
        bool ignore = false;
        std::string label;
        std::vector<std::pair<double, std::string>> labels;

        // Converts the stored drectangle to an integer rectangle.
        operator rectangle() const { return rect; }

        // Compares rect, detection_confidence, ignore and label.  Note that the
        // labels vector is NOT part of equality.
        bool operator == (const yolo_rect& rhs) const
        {
            const bool same_box  = rect == rhs.rect;
            const bool same_meta = detection_confidence == rhs.detection_confidence
                                   && ignore == rhs.ignore;
            return same_box && same_meta && label == rhs.label;
        }

        // Orders detections by detection_confidence only.
        bool operator<(const yolo_rect& rhs) const
        {
            return rhs.detection_confidence > detection_confidence;
        }
    };

    // Writes a yolo_rect using format version 1: rect, detection_confidence,
    // ignore, label, then the labels list.
    inline void serialize(const yolo_rect& item, std::ostream& out)
    {
        const int format_version = 1;
        serialize(format_version, out);

        serialize(item.rect, out);
        serialize(item.detection_confidence, out);
        serialize(item.ignore, out);
        serialize(item.label, out);
        serialize(item.labels, out);
    }

    // Reads a yolo_rect written in format version 1.  Any other stored version
    // throws serialization_error before any field of item is touched.
    inline void deserialize(yolo_rect& item, std::istream& in)
    {
        int format_version = 0;
        deserialize(format_version, in);

        if (format_version == 1)
        {
            deserialize(item.rect, in);
            deserialize(item.detection_confidence, in);
            deserialize(item.ignore, in);
            deserialize(item.label, in);
            deserialize(item.labels, in);
        }
        else
        {
            throw serialization_error("Unexpected version found while deserializing dlib::yolo_rect");
        }
    }

// ----------------------------------------------------------------------------------------

}

#endif // DLIB_FULL_OBJECT_DeTECTION_Hh_