// Copyright (C) 2015  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_DNN_CuBLAS_CPP_
#define DLIB_DNN_CuBLAS_CPP_

#ifdef DLIB_USE_CUDA

#include "cublas_dlibapi.h"
#include "cuda_utils.h"

#include <cublas_v2.h>
#include <vector>

static const char* cublas_get_error_string(cublasStatus_t s)
{
    switch(s)
    {
        case CUBLAS_STATUS_NOT_INITIALIZED:
            return "CUDA Runtime API initialization failed.";
        case CUBLAS_STATUS_ALLOC_FAILED:
            return "CUDA Resources could not be allocated.";
        default:
            return "A call to cuBLAS failed";
    }
}

// Check the return value of a call to the cuBLAS runtime for an error condition.
#define CHECK_CUBLAS(call)                                                      \
do{                                                                             \
    const cublasStatus_t error = call;                                          \
    if (error != CUBLAS_STATUS_SUCCESS)                                         \
    {                                                                           \
        std::ostringstream sout;                                                \
        sout << "Error while calling " << #call << " in file " << __FILE__ << ":" << __LINE__ << ". ";\
        sout << "code: " << error << ", reason: " << cublas_get_error_string(error);\
        throw dlib::cublas_error(sout.str());                                   \
    }                                                                           \
}while(false)

namespace dlib
{
    namespace cuda
    {

    // -----------------------------------------------------------------------------------

        class cublas_context
        {
        public:
            // not copyable
            cublas_context(const cublas_context&) = delete;
            cublas_context& operator=(const cublas_context&) = delete;

            cublas_context()
            {
                handles.resize(16);
            }
            ~cublas_context()
            {
                for (auto h : handles)
                {
                    if (h)
                        cublasDestroy(h);
                }
            }

            cublasHandle_t get_handle (
            )
            {
                int new_device_id;
                CHECK_CUDA(cudaGetDevice(&new_device_id));
                // make room for more devices if needed
                if (new_device_id >= (long)handles.size())
                    handles.resize(new_device_id+16);

                // If we don't have a handle already for this device then make one
                if (!handles[new_device_id])
                    CHECK_CUBLAS(cublasCreate(&handles[new_device_id]));

                // Finally, return the handle for the current device
                return handles[new_device_id];
            }

        private:

            std::vector<cublasHandle_t> handles;
        };

        static cublasHandle_t context()
        {
            thread_local cublas_context c;
            return c.get_handle();
        }

    // -----------------------------------------------------------------------------------

        void gemm (
            float beta,
            tensor& dest,
            float alpha,
            const tensor& lhs,
            bool trans_lhs,
            const tensor& rhs,
            bool trans_rhs,
            operation_mode mode
        )
        {
            if (mode == operation_mode::CHANNEL_WISE)
            {
                // Recall that BLAS uses column major order so to deal with that we flip the
                // order of the lhs and rhs arguments.
                const auto transa = trans_lhs ? CUBLAS_OP_T : CUBLAS_OP_N;
                const auto transb = trans_rhs ? CUBLAS_OP_T : CUBLAS_OP_N;

                const int dest_nr = dest.num_samples();
                const int dest_nc = dest.size() / dest_nr;
                const int lhs_nr = lhs.num_samples();
                const int lhs_nc = lhs.size() / lhs_nr;
                const int rhs_nr = rhs.num_samples();
                const int rhs_nc = rhs.size() / rhs_nr;
                if (trans_lhs && trans_rhs)
                {
                    DLIB_ASSERT(dest_nr == lhs_nc &&
                                dest_nc == rhs_nr &&
                                lhs_nr == rhs_nc)
                }
                else if (!trans_lhs && trans_rhs)
                {
                    DLIB_ASSERT(dest_nr == lhs_nr &&
                                dest_nc == rhs_nr &&
                                lhs_nc == rhs_nc)
                }
                else if (trans_lhs && !trans_rhs)
                {
                    DLIB_ASSERT(dest_nr == lhs_nc &&
                                dest_nc == rhs_nc &&
                                lhs_nr == rhs_nr)
                }
                else
                {
                    DLIB_ASSERT(dest_nr == lhs_nr &&
                                dest_nc == rhs_nc &&
                                lhs_nc == rhs_nr)
                }

                const int k = trans_rhs ? rhs_nc : rhs_nr;
                CHECK_CUBLAS(cublasSgemm(context(),
                                         transb,
                                         transa,
                                         dest_nc, dest_nr, k,
                                         &alpha,
                                         rhs.device(), rhs_nc,
                                         lhs.device(), lhs_nc,
                                         &beta,
                                         dest.device(), dest_nc));
            }
            else if (mode == operation_mode::PLANE_WISE)
            {
                const auto transa = trans_lhs ? CUBLAS_OP_T : CUBLAS_OP_N;
                const auto transb = trans_rhs ? CUBLAS_OP_T : CUBLAS_OP_N;
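
                // PLANE_WISE mode multiplies each (sample, channel) plane of lhs and rhs
                // independently and writes the result into the matching plane of dest.
                // Tensors that are effectively 2-D (detected by is_matrix below) are
                // treated as a single matrix and broadcast across every plane.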
                long num_samples = std::min({ lhs.num_samples(), rhs.num_samples(), dest.num_samples() });
                long num_channels = std::min({ lhs.k(), rhs.k(), dest.k() });

                auto is_matrix = [](const auto& tensor) {
                    return ((tensor.num_samples() * tensor.k() == 1 && tensor.nr() * tensor.nc() > 1) ||
                            (tensor.num_samples() * tensor.k() > 1 && tensor.nr() * tensor.nc() == 1));
                };
                const bool lhs_is_matrix = is_matrix(lhs),
                           rhs_is_matrix = is_matrix(rhs),
                           dest_is_matrix = is_matrix(dest);

                if (lhs_is_matrix && rhs_is_matrix && dest_is_matrix)
                    num_samples = num_channels = 1;

                size_t lhs_rows = lhs.nr();
                size_t lhs_cols = lhs.nc();
                if (lhs_is_matrix && (lhs.num_samples() > 1 || lhs.k() > 1))
                {
                    lhs_rows = lhs.num_samples();
                    lhs_cols = lhs.k();
                }
                size_t rhs_rows = rhs.nr();
                size_t rhs_cols = rhs.nc();
                if (rhs_is_matrix && (rhs.num_samples() > 1 || rhs.k() > 1))
                {
                    rhs_rows = rhs.num_samples();
                    rhs_cols = rhs.k();
                }
                size_t dest_rows = dest.nr();
                size_t dest_cols = dest.nc();
                if (dest_is_matrix && (dest.num_samples() > 1 || dest.k() > 1))
                {
                    dest_rows = dest.num_samples();
                    dest_cols = dest.k();
                }

                const size_t lhs_plane_size = lhs_rows * lhs_cols;
                const size_t rhs_plane_size = rhs_rows * rhs_cols;
                const size_t dest_plane_size = dest_rows * dest_cols;

                for (long b = 0; b < num_samples; ++b)
                {
                    for (long c = 0; c < num_channels; ++c)
                    {
                        auto lhs_slice = lhs_is_matrix ? lhs.device() :
                            lhs.device() + (b * num_channels + c) * lhs_plane_size;
                        auto rhs_slice = rhs_is_matrix ? rhs.device() :
                            rhs.device() + (b * num_channels + c) * rhs_plane_size;
                        auto dest_slice = dest_is_matrix ? dest.device() :
                            dest.device() + (b * num_channels + c) * dest_plane_size;
                        const int k = trans_rhs ? rhs_cols : rhs_rows;

                        CHECK_CUBLAS(cublasSgemm(
                            context(),
                            transb,
                            transa,
                            dest_cols, dest_rows, k,
                            &alpha,
                            rhs_slice, rhs_cols,
                            lhs_slice, lhs_cols,
                            &beta,
                            dest_slice, dest_cols
                        ));
                    }
                }
            }
        }

    // ------------------------------------------------------------------------------------

    }
}

#endif // DLIB_USE_CUDA

#endif // DLIB_DNN_CuBLAS_CPP_