// Copyright (C) 2009  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_TYPE_SAFE_UNIOn_h_ 
#define DLIB_TYPE_SAFE_UNIOn_h_

#include "type_safe_union_kernel_abstract.h"
#include <new>
#include <iostream>
#include <functional>
#include "../serialize.h"
#include "../type_traits.h"
#include "../overloaded.h"

namespace dlib
{
    // ---------------------------------------------------------------------

    /*!
        WHAT THIS OBJECT REPRESENTS
            The exception thrown by type_safe_union::cast_to<T>() when the union
            does not currently contain an object of type T.
    !*/
    class bad_type_safe_union_cast : public std::bad_cast 
    {
    public:
          // Returns a fixed, statically allocated description of the error.
          // `noexcept override` replaces the old `throw()` specification, which
          // was deprecated in C++17 and removed in C++20; `override` also makes
          // the compiler check the signature against std::bad_cast::what().
          const char * what() const noexcept override
          {
              return "bad_type_safe_union_cast";
          }
    };

    // ---------------------------------------------------------------------

    /*!
        An empty tag type whose only purpose is to carry the type T, exposed as
        the nested member `type`.  It lets a type be passed to a function as an
        ordinary (stateless) argument.
    !*/
    template <typename T>
    struct in_place_tag
    {
        typedef T type;
    };

    // ---------------------------------------------------------------------

    // Forward declaration; the class itself is defined further down.
    template <typename... Alternatives> class type_safe_union;

    /*!
        type_safe_union_size<TSU>::value is the number of alternative types of
        the type_safe_union TSU.  The primary template is intentionally left
        undefined so that only type_safe_union types (optionally cv-qualified)
        are accepted.
    !*/
    template <typename Tsu>
    struct type_safe_union_size;

    template <typename... Alternatives>
    struct type_safe_union_size<type_safe_union<Alternatives...>>
        : std::integral_constant<size_t, sizeof...(Alternatives)>
    {};

    // cv-qualified unions report the same size as the unqualified type.
    template <typename Tsu>
    struct type_safe_union_size<const Tsu> : type_safe_union_size<Tsu> {};

    template <typename Tsu>
    struct type_safe_union_size<volatile Tsu> : type_safe_union_size<Tsu> {};

    template <typename Tsu>
    struct type_safe_union_size<const volatile Tsu> : type_safe_union_size<Tsu> {};

    // ---------------------------------------------------------------------

    /*!
        type_safe_union_alternative<I, TSU>::type is the I-th (0-based) entry of
        the type list of the type_safe_union TSU, as selected by nth_type.
        When TSU is cv-qualified, the same qualifiers are added to the result.
    !*/
    template <size_t Index, typename Union>
    struct type_safe_union_alternative;

    template <size_t Index, typename... Alternatives>
    struct type_safe_union_alternative<Index, type_safe_union<Alternatives...>>
        : nth_type<Index, Alternatives...>
    {};

    // Convenience alias for the nested ::type.
    template <size_t I, typename TSU>
    using type_safe_union_alternative_t = typename type_safe_union_alternative<I, TSU>::type;

    template <size_t Index, typename Union>
    struct type_safe_union_alternative<Index, const Union>
    {
        using type = std::add_const_t<type_safe_union_alternative_t<Index, Union>>;
    };

    template <size_t Index, typename Union>
    struct type_safe_union_alternative<Index, volatile Union>
    {
        using type = std::add_volatile_t<type_safe_union_alternative_t<Index, Union>>;
    };

    template <size_t Index, typename Union>
    struct type_safe_union_alternative<Index, const volatile Union>
    {
        using type = std::add_cv_t<type_safe_union_alternative_t<Index, Union>>;
    };

    // ---------------------------------------------------------------------

    namespace detail
    {
        /*!
            type_safe_union_type_id_impl<Count, Wanted, Candidates...>::value is
            computed by walking Candidates... from left to right:
                - a match at the head yields 1,
                - otherwise the result for the tail is incremented by 1,
                - an exhausted list yields -1 - Count.
            When Count is the length of the full candidate list this gives the
            1-based position of the first match, or -1 if there is no match.
        !*/
        template <int Count, typename Wanted, typename... Rest>
        struct type_safe_union_type_id_impl
            : std::integral_constant<int, -1 - Count>
        {};

        template <int Count, typename Wanted, typename Head, typename... Tail>
        struct type_safe_union_type_id_impl<Count, Wanted, Head, Tail...>
            : std::integral_constant<
                  int,
                  std::is_same<Wanted, Head>::value
                      ? 1
                      : type_safe_union_type_id_impl<Count, Wanted, Tail...>::value + 1>
        {};

        // Entry point: seeds the recursion with the length of the candidate list.
        template <typename Wanted, typename... Candidates>
        struct type_safe_union_type_id
            : type_safe_union_type_id_impl<sizeof...(Candidates), Wanted, Candidates...>
        {};

        // An in_place_tag<T> is looked up as if it were T itself.
        template <typename Wanted, typename... Candidates>
        struct type_safe_union_type_id<in_place_tag<Wanted>, Candidates...>
            : type_safe_union_type_id<Wanted, Candidates...>
        {};
    }

    template <typename... Types>
    class type_safe_union
    {
        /*!
            WHAT THIS OBJECT REPRESENTS
                A tagged union.  At any time it is either empty or contains
                exactly one object whose type is one of Types....  The contained
                object is constructed in place inside this object's own storage.

            CONVENTION
                - is_empty() ==  (type_identity == 0)
                - contains<T>() == (type_identity == get_type_id<T>())
                - mem == the suitably sized and aligned block of storage, embedded
                  in this object, where the object in the union is stored
        !*/
    public:
        // Returns the 1-based position of T within Types..., or -1 if T is not
        // one of Types....  The value 0 is reserved for the empty state.
        template <typename T>
        static constexpr int get_type_id ()
        {
            return detail::type_safe_union_type_id<T,Types...>::value;
        }

        // Tag overload of the above: the type is taken from the tag argument.
        template <typename T>
        static constexpr int get_type_id (in_place_tag<T>)
        {
            return get_type_id<T>();
        }

    private:

        // SFINAE helper: only well-formed when T is one of Types...
        template<typename T>
        using is_valid_check = std::enable_if_t<is_any_type<T,Types...>::value, bool>;

        // The I-th (0-based) alternative type of this union.
        template <size_t I>
        using get_type_t = type_safe_union_alternative_t<I, type_safe_union>;

        typename std::aligned_union<0, Types...>::type mem;
        int type_identity = 0;

        /*!
            Finds the alternative currently stored in a union of type TSU and
            invokes a callable of type F on it.

            R is the type returned by F when called with the first alternative.
            It is the return type of every dispatch, and a value-initialized R
            (i.e. R()) is what gets returned when the union is empty.
        !*/
        template<typename F, typename TSU>
        struct dispatcher
        {
            constexpr static const std::size_t N = sizeof...(Types);
            using R = decltype(std::declval<F>()(std::declval<TSU>().template unchecked_get<get_type_t<0>>()));

            // True only if R is default constructible and invoking F is
            // noexcept for every alternative type.
            constexpr static const bool is_noexcept =
                And<std::is_default_constructible<R>::value &&
                    noexcept(std::declval<F>()(std::declval<TSU>().template unchecked_get<Types>()))...>::value;

            // End of the recursion: every alternative was tried without a match.
            template<size_t I, typename std::enable_if<I == N, bool>::type = true>
            inline R operator()(F&&, TSU&&, size_<I>)
            noexcept(is_noexcept) { return R(); }

            // Checks whether alternative I is the one stored; if not, recurses
            // to alternative I+1.
            template<size_t I, typename std::enable_if<I < N, bool>::type = true>
            inline R operator()(F&& f, TSU&& me, size_<I>)
            noexcept(is_noexcept)
            {
                if (me.is_empty())
                    return R();
                else if (me.get_current_type_id() == (I+1))   // ids are 1-based
                    return std::forward<F>(f)(me.template unchecked_get<get_type_t<I>>());
                else
                    return (*this)(std::forward<F>(f), std::forward<TSU>(me), size_<I+1>{});
            }
        };

        // Runs the dispatcher starting from the first alternative.
        template<typename F, typename TSU>
        static inline decltype(auto) dispatch(F&& f, TSU&& me)
        noexcept(noexcept(dispatcher<F&&,TSU&&>{}(std::forward<F>(f), std::forward<TSU>(me), size_<0>{}))) {
            return dispatcher<F&&,TSU&&>{}(std::forward<F>(f), std::forward<TSU>(me), size_<0>{});
        }

        // Reinterprets the storage as a T without checking type_identity.
        // Callers are responsible for making sure a T is actually stored.
        template <typename T>
        const T& unchecked_get() const noexcept
        {
            return *reinterpret_cast<const T*>(&mem);
        }

        template <typename T>
        T& unchecked_get() noexcept
        {
            return *reinterpret_cast<T*>(&mem);
        }

        // Visitor that runs the destructor of whatever object it is given.
        struct destruct_helper
        {
            template <typename T>
            void operator() (T& item) const
            {
                item.~T();
            }
        };

        // Destroys the stored object (if any) and marks the union as empty.
        void destruct ()
        {
            apply_to_contents(destruct_helper{});
            type_identity = 0;
        }

        // Destroys the current content and then constructs a T in place from
        // args.  type_identity is only updated after the constructor returns,
        // so if T's constructor throws the union is left empty.
        template <typename T, typename... Args>
        void construct (
            Args&&... args
        )
        {
            destruct();
            new(&mem) T(std::forward<Args>(args)...);
            type_identity = get_type_id<T>();
        }

        struct assign_to
        {
            /*!
                This class assigns an object to `me` using std::forward.
                If `me` already contains an object of the same (decayed) type
                the object is assigned to it, otherwise the content of `me` is
                replaced by a newly constructed object.
            !*/
            assign_to(type_safe_union& me) : _me(me) {}

            template<typename T>
            void operator()(T&& x)
            {
                using U = std::decay_t<T>;

                if (_me.type_identity != get_type_id<U>())
                {
                    _me.construct<U>(std::forward<T>(x));
                }
                else
                {
                    _me.template unchecked_get<U>() = std::forward<T>(x);
                }
            }

            type_safe_union& _me;
        };

        struct move_to
        {
            /*!
                This class move assigns an object to `me`.
                If `me` already contains a T the object is move assigned to it,
                otherwise the content of `me` is replaced by a T that is move
                constructed from the object.
            !*/
            move_to(type_safe_union& me) : _me(me) {}

            template<typename T>
            void operator()(T& x)
            {
                if (_me.type_identity != get_type_id<T>())
                {
                    _me.construct<T>(std::move(x));
                }
                else
                {
                    _me.template unchecked_get<T>() = std::move(x);
                }
            }

            type_safe_union& _me;
        };

        struct swap_to
        {
            /*!
                This class swaps an object with `me`.
            !*/
            swap_to(type_safe_union& me) : _me(me) {}

            template<typename T>
            void operator()(T& x)
            /*!
                requires
                    - _me.contains<T>() == true
            !*/
            {
                using std::swap;   // enable ADL, fall back to std::swap
                swap(_me.unchecked_get<T>(), x);
            }

            type_safe_union& _me;
        };

    public:

        // Creates an empty union.
        type_safe_union() = default;

        // Copy constructor: starts empty, then copies item's content (if any).
        type_safe_union (
            const type_safe_union& item
        )
        noexcept(are_nothrow_copy_constructible<Types...>::value)
        : type_safe_union()
        {
            item.apply_to_contents(assign_to{*this});
        }

        // Copy assignment: becomes empty if item is empty, otherwise ends up
        // holding a copy of item's content.
        type_safe_union& operator=(
            const type_safe_union& item
        )
        noexcept(are_nothrow_copy_constructible<Types...>::value &&
                 are_nothrow_copy_assignable<Types...>::value)
        {
            if (item.is_empty())
                destruct();
            else
                item.apply_to_contents(assign_to{*this});
            return *this;
        }

        // Move constructor: takes item's content (if any) and leaves item empty.
        type_safe_union (
            type_safe_union&& item
        )
        noexcept(are_nothrow_move_constructible<Types...>::value)
        : type_safe_union()
        {
            item.apply_to_contents(move_to{*this});
            item.destruct();
        }

        // Move assignment: takes item's content (if any) and leaves item empty.
        // Assigning an object to itself is a no-op.
        type_safe_union& operator= (
            type_safe_union&& item
        )
        noexcept(are_nothrow_move_constructible<Types...>::value &&
                 are_nothrow_move_assignable<Types...>::value)
        {
            // Without this guard a self-move would move the content onto
            // itself and then item.destruct() would destroy it, silently
            // emptying *this.
            if (this == &item)
                return *this;

            if (item.is_empty())
            {
                destruct();
            }
            else
            {
                item.apply_to_contents(move_to{*this});
                item.destruct();
            }
            return *this;
        }

        // Converting constructor: holds a std::decay_t<T> built from item.
        // Only participates in overload resolution when std::decay_t<T> is one
        // of Types....
        template <
            typename T,
            is_valid_check<std::decay_t<T>> = true
        >
        type_safe_union (
            T&& item
        )
        noexcept(std::is_nothrow_constructible<std::decay_t<T>, T>::value)
        : type_safe_union()
        {
            assign_to{*this}(std::forward<T>(item));
        }

        // Converting assignment: assigns to the held object when it already
        // has type std::decay_t<T>, otherwise replaces the content.
        template <
            typename T,
            is_valid_check<std::decay_t<T>> = true
        >
        type_safe_union& operator= (
            T&& item
        )
        noexcept(std::is_nothrow_constructible<std::decay_t<T>, T>::value &&
                 std::is_nothrow_assignable<std::decay_t<T>, T>::value)
        {
            assign_to{*this}(std::forward<T>(item));
            return *this;
        }

        // In-place constructor: holds a T constructed directly from args.
        template <
            typename T,
            typename... Args,
            is_valid_check<T> = true
        >
        type_safe_union (
            in_place_tag<T>,
            Args&&... args
        )
        noexcept(std::is_nothrow_constructible<T, Args...>::value)
        : type_safe_union()
        {
            construct<T>(std::forward<Args>(args)...);
        }

        ~type_safe_union()
        {
            destruct();
        }

        // Destroys the held object (if any); the union is empty afterwards.
        void clear()
        {
            destruct();
        }

        // Destroys the current content and constructs a T in place from args.
        template <
            typename T,
            typename... Args,
            is_valid_check<T> = true
        >
        void emplace(
            Args&&... args
        )
        noexcept(std::is_nothrow_constructible<T, Args...>::value)
        {
            construct<T>(std::forward<Args>(args)...);
        }

        // Calls f on the held object and returns the result.  If the union is
        // empty f is not called and a value-initialized result is returned
        // (see dispatcher).
        template <typename F>
        decltype(auto) apply_to_contents(
            F&& f
        ) noexcept(noexcept(dispatch(std::forward<F>(f), std::declval<type_safe_union&>()))) {
            return dispatch(std::forward<F>(f), *this);
        }

        // Same as above, but f receives the held object as const.
        template <typename F>
        decltype(auto) apply_to_contents(
            F&& f
        ) const noexcept(noexcept(dispatch(std::forward<F>(f), std::declval<const type_safe_union&>()))) {
            return dispatch(std::forward<F>(f), *this);
        }

        // True if the union currently holds an object of type T.
        template <typename T>
        bool contains (
        ) const noexcept
        {
            return type_identity == get_type_id<T>();
        }

        // True if the union holds no object.
        bool is_empty (
        ) const noexcept
        {
            return type_identity == 0;
        }

        // Returns 0 when empty, otherwise get_type_id<T>() of the held type T.
        int get_current_type_id() const noexcept
        {
            return type_identity;
        }

        // Returns a reference to the held T.  If the union does not currently
        // hold a T, its content is first replaced by a default constructed T.
        template <
            typename T,
            is_valid_check<T> = true
        >
        T& get(
        )
        noexcept(std::is_nothrow_default_constructible<T>::value)
        {
            if (type_identity != get_type_id<T>())
                construct<T>();
            return unchecked_get<T>();
        }

        // Tag overload of get<T>().
        template <
            typename T
        >
        T& get(
            in_place_tag<T>
        )
        noexcept(std::is_nothrow_default_constructible<T>::value)
        {
            return get<T>();
        }

        // Returns a reference to the held T.
        // Throws bad_type_safe_union_cast if the union does not hold a T.
        template <
            typename T,
            is_valid_check<T> = true
        >
        const T& cast_to (
        ) const
        {
            if (contains<T>())
                return unchecked_get<T>();
            else
                throw bad_type_safe_union_cast();
        }

        template <
            typename T,
            is_valid_check<T> = true
        >
        T& cast_to (
        )
        {
            if (contains<T>())
                return unchecked_get<T>();
            else
                throw bad_type_safe_union_cast();
        }

        // Exchanges the contents of *this and item.
        void swap(
            type_safe_union& item
        ) noexcept(std::is_nothrow_move_constructible<type_safe_union>::value &&
                   are_nothrow_swappable<Types...>::value)
        {
            if (type_identity == item.type_identity)
            {
                // Same alternative on both sides (or both empty): swap the
                // held objects directly.
                apply_to_contents(swap_to{item});
            }
            else if (is_empty())
            {
                *this = std::move(item);
            }
            else if (item.is_empty())
            {
                item = std::move(*this);
            }
            else
            {
                // Different alternatives: rotate through a temporary.
                type_safe_union tmp{std::move(*this)};
                *this = std::move(item);
                item  = std::move(tmp);
            }
        }
    };

    // Non-member swap for type_safe_union; forwards to the member function so
    // that unqualified swap(a, b) calls pick it up.
    template <typename ...Types>
    inline void swap (
        type_safe_union<Types...>& a,
        type_safe_union<Types...>& b
    ) noexcept(noexcept(a.swap(b)))
    {
        a.swap(b);
    }

    namespace detail
    {
        /*!
            Calls f(in_place_tag<A>{}, tsu) once for every alternative type A of
            the union type std::decay_t<TSU>, in the order given by the index
            sequence.  Any value returned by f is discarded.
        !*/
        template<
            typename F,
            typename TSU,
            std::size_t... Is
        >
        void for_each_type_impl(
            F&& f,
            TSU&& tsu,
            std::index_sequence<Is...>
        )
        {
            using union_type = std::decay_t<TSU>;

#ifdef __cpp_fold_expressions
            // Fold over the comma operator: one call per index, left to right.
            (std::forward<F>(f)(in_place_tag<type_safe_union_alternative_t<Is, union_type>>{},
                                std::forward<TSU>(tsu)),
             ...);
#else
            // Pre-C++17 fallback: expand the calls inside a braced initializer,
            // which also evaluates its elements left to right.  The leading 0
            // keeps the array non-empty when there are no indices.
            const int expand[] = {
                0,
                (std::forward<F>(f)(in_place_tag<type_safe_union_alternative_t<Is, union_type>>{},
                                    std::forward<TSU>(tsu)),
                 0)...
            };
            (void)expand;
#endif
        }
    }

    template<
        typename TSU,
        typename F
    >
    void for_each_type(
        F&& f,
        TSU&& tsu
    )
    {
        using Tsu = std::decay_t<TSU>;
        static constexpr std::size_t Size = type_safe_union_size<Tsu>::value;
        detail::for_each_type_impl(std::forward<F>(f), std::forward<TSU>(tsu), std::make_index_sequence<Size>{});
    }

    /*!
        Free-function spelling of tsu.apply_to_contents(f): hands f to tsu and
        returns whatever apply_to_contents returns.  It is noexcept exactly
        when that call is.
    !*/
    template<typename F, typename TSU>
    decltype(auto) visit(F&& f, TSU&& tsu)
    noexcept(noexcept(tsu.apply_to_contents(std::forward<F>(f))))
    {
        return tsu.apply_to_contents(std::forward<F>(f));
    }

    template<typename... Types>
    inline void serialize (
        const type_safe_union<Types...>& item,
        std::ostream& out
    )
    {
        try
        {
            serialize(item.get_current_type_id(), out);
            item.apply_to_contents([&](auto&& x) {
                serialize(x, out);
            });
        }
        catch (serialization_error& e)
        {
            throw serialization_error(e.info + "\n   while serializing an object of type type_safe_union");
        }
    }

    /*!
        Reads a type_safe_union from in, expecting the layout written by
        serialize(): a type id followed by the object of that type.
            - id 0 leaves item empty,
            - id k in [1, sizeof...(Types)] makes item hold the k-th
              alternative, which is then deserialized in place,
            - any other id throws serialization_error.
        A serialization_error raised along the way is rethrown with extra
        context appended to its message.
    !*/
    template<typename... Types>
    inline void deserialize (
        type_safe_union<Types...>& item,
        std::istream& in
    )
    {
        try
        {
            int index = -1;
            deserialize(index, in);

            if (index == 0)
                item.clear();
            else if (index > 0 && index <= (int)sizeof...(Types))
                for_each_type([&](auto tag, auto&& me) {
                    // Only the alternative whose id matches is touched;
                    // me.get(tag) switches the union to that type first.
                    if (index == me.get_type_id(tag))
                        deserialize(me.get(tag), in);
                }, item);
            else
                // The upper bound is inclusive: ids are 1-based and 0 means empty.
                throw serialization_error("bad index value. Should be in range [0,sizeof...(Types)]");
        }
        catch(const serialization_error& e)
        {
            throw serialization_error(e.info + "\n   while deserializing an object of type type_safe_union");
        }
    }
}

#endif // DLIB_TYPE_SAFE_UNIOn_h_