// Copyright (C) 2008  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_MATRIx_DEFAULT_MULTIPLY_
#define DLIB_MATRIx_DEFAULT_MULTIPLY_

#include "../geometry/rectangle.h"
#include "matrix.h"
#include "matrix_utilities.h"
#include "../enable_if.h"

namespace dlib
{

// ------------------------------------------------------------------------------------

    namespace ma
    {
        /*!
            matrix_is_vector<EXP>::value is a compile time boolean.  It is true when
            EXP::NR == 1 or EXP::NC == 1 and false otherwise.  The second template
            argument exists only so the specialization below can be switched on and
            off via enable_if_c; users are expected to leave it at its default.
        !*/

        // Primary template:  used whenever the specialization below is not selected.
        template <
            typename EXP,
            typename sfinae_slot = void
            >
        struct matrix_is_vector
        {
            static const bool value = false;
        };

        // Specialization for expressions that have exactly one row or exactly one
        // column at compile time.  (This assumes enable_if_c<cond>::type names void
        // only when cond is true, which is the usual enable_if contract.)
        template <
            typename EXP
            >
        struct matrix_is_vector<
            EXP,
            typename enable_if_c<(EXP::NR == 1) || (EXP::NC == 1)>::type
            >
        {
            static const bool value = true;
        };
    }

// ------------------------------------------------------------------------------------

    /*!  This file defines the default_matrix_multiply() function.  It is a function 
         that conforms to the following definition:

        template <
            typename matrix_dest_type,
            typename EXP1,
            typename EXP2
            >
        void default_matrix_multiply (
            matrix_dest_type& dest,
            const EXP1& lhs,
            const EXP2& rhs
        );
            requires
                - (lhs*rhs).destructively_aliases(dest) == false
                - dest.nr() == (lhs*rhs).nr()
                - dest.nc() == (lhs*rhs).nc()
            ensures
                - #dest == dest + lhs*rhs
    !*/

// ------------------------------------------------------------------------------------

    /*!
        Overload of default_matrix_multiply() that takes part in overload resolution
        when ma::matrix_is_vector is true for EXP1 or for EXP2 (i.e. at least one
        operand has a compile time dimension of 1).

        No blocking is attempted in this case.  The whole product expression lhs*rhs
        is simply handed to matrix_assign_default().
    !*/
    template <
        typename matrix_dest_type,
        typename EXP1,
        typename EXP2
        >
    typename enable_if_c<ma::matrix_is_vector<EXP1>::value == true || ma::matrix_is_vector<EXP2>::value == true>::type 
    default_matrix_multiply (
        matrix_dest_type& dest,
        const EXP1& lhs,
        const EXP2& rhs
    )
    {
        // NOTE(review): the trailing arguments (1, true) are presumably a scale factor
        // of 1 and an "add to dest" flag, which would match the documented contract
        // #dest == dest + lhs*rhs -- confirm against matrix_assign_default().
        matrix_assign_default(dest, lhs*rhs, 1, true);
    }

// ------------------------------------------------------------------------------------

    /*!
        Overload of default_matrix_multiply() that takes part in overload resolution
        when ma::matrix_is_vector is false for both EXP1 and EXP2.

        Accumulates the product of lhs and rhs into dest, i.e. for every (r,c) it
        performs dest(r,c) += sum over k of lhs(r,k)*rhs(k,c).

        - Small operands (any dimension <= 2, or both operands holding at most
          10*bs elements) are forwarded to matrix_assign_default().
        - Larger operands are processed in tiles of at most bs x bs elements so
          that the working set of the inner loops stays small.

        Each elementary product is formed as (lhs element)*(rhs element), in that
        order, so the result is also right for element types whose multiplication
        is not commutative.
    !*/
    template <
        typename matrix_dest_type,
        typename EXP1,
        typename EXP2
        >
    typename enable_if_c<ma::matrix_is_vector<EXP1>::value == false && ma::matrix_is_vector<EXP2>::value == false>::type 
    default_matrix_multiply (
        matrix_dest_type& dest,
        const EXP1& lhs,
        const EXP2& rhs
    )
    {
        // Maximum edge length of a tile used by the blocked algorithm.
        const long bs = 90;

        // if the matrices are small enough then just use the simple multiply algorithm
        if (lhs.nc() <= 2 || rhs.nc() <= 2 || lhs.nr() <= 2 || rhs.nr() <= 2 || (lhs.size() <= bs*10 && rhs.size() <= bs*10) )
        {
            // NOTE(review): (1, true) is presumably "scale by 1, add to dest" -- confirm
            // against matrix_assign_default().
            matrix_assign_default(dest, lhs*rhs, 1, true);
            return;
        }

        // The lhs and rhs matrices are big enough that we should use a cache friendly
        // algorithm that computes the matrix multiply in blocks.  All ranges below are
        // half open:  [begin, end).

        // Loop over all the blocks in the lhs matrix
        for (long row_begin = 0; row_begin < lhs.nr(); row_begin += bs)
        {
            const long row_end = std::min(row_begin+bs, lhs.nr());

            for (long k_begin = 0; k_begin < lhs.nc(); k_begin += bs)
            {
                // [k_begin, k_end) indexes columns of lhs and, equally, rows of rhs.
                const long k_end = std::min(k_begin+bs, lhs.nc());

                // now loop over all the rhs blocks we have to multiply with the current lhs block
                for (long col_begin = 0; col_begin < rhs.nc(); col_begin += bs)
                {
                    const long col_end = std::min(col_begin+bs, rhs.nc());

                    // This loop nest is optimized assuming that the data is laid out in 
                    // row major order in memory:  the innermost loop walks along a row
                    // of rhs and a row of dest.
                    for (long r = row_begin; r < row_end; ++r)
                    {
                        for (long k = k_begin; k < k_end; ++k)
                        {
                            // lhs(r,k) is reused for a whole row segment of rhs, so read
                            // it once.  It is an element of lhs and therefore has type
                            // EXP1::type.
                            const typename EXP1::type temp = lhs(r,k);
                            for (long c = col_begin; c < col_end; ++c)
                            {
                                dest(r,c) += temp*rhs(k,c);
                            }
                        }
                    }
                }
            }
        }
    }

// ------------------------------------------------------------------------------------

}

#endif // DLIB_MATRIx_DEFAULT_MULTIPLY_