Effective use of enable_if with C++ templates to avoid class specialization


I am having trouble getting my code to compile: clang, g++ and icpc all give different error messages.

A bit of background before getting to the question itself:

I am currently working on a template class hierarchy for matrices. It has template parameters for the data type (either float or double) and for an "Implementation Policy"; at present the policies are regular C++ code with loops and Intel MKL versions. The following is an abridged summary (please disregard the lack of forward references, etc. in this, as that is unrelated to my question):

// Matrix.h

template <typename Type, typename IP>
class Matrix : public Matrix_Base<Type, IP>;

template <typename Matrix_Type>
class Matrix_Base{
    /* ... */

    // Matrix / Scalar addition
    template <typename T>
    Matrix_Base& operator+=(const T value_) {
      return Implementation<IP>::Plus_Equal(
          static_cast<Matrix_Type&>(*this), value_);
    }

    /* More operators and rest of code... */
};

struct CPP;
struct MKL;

template <typename IP>
struct Implementation{
    /* This struct contains static methods that do the actual operations */
};

The trouble I'm having is with the implementation of the Implementation class (no pun intended). I know that I can specialize the Implementation class template, e.g. template <> struct Implementation<MKL>{/* ... */};. However, this would result in a lot of code duplication, because for a number of operators (such as matrix-scalar addition, subtraction, ...) the generic and the specialized versions use the same code.

So, instead, I thought that I could get rid of the template specialization and just use enable_if to provide different implementations for those operators which have different implementations when using MKL (or CUDA, etc.).

This has proven more challenging than I had expected. The first one, for operator+=(T value_), works fine. I added a check just to make sure that the parameter is reasonable (this can be eliminated if it is the source of my troubles, which I doubt).

template <class Matrix_Type, typename Type, typename enable_if< 
    std::is_arithmetic<Type>::value  >::type* dummy = nullptr>
static Matrix_Type& Plus_Equal(Matrix_Type& matrix_, Type value_){
    uint64_t total_elements = matrix_.actual_dims.first * matrix_.actual_dims.second;
    //y := A + b

    #pragma parallel 
    for (uint64_t i = 0; i < total_elements; ++i)
        matrix_.Data[i] += value_; 

    return matrix_;
}

However, I am having a really hard time figuring out how to deal with operator *=(T value_). This is due to the fact that float and double have different implementations for MKL but not in the general case.

Here is the declaration. Note that the 3rd parameter is a dummy parameter and was my attempt at forcing function overloading, since I cannot use partial template function specialization:

template <class Matrix_Type, typename U, typename Type = 
    typename internal::Type_Traits< Matrix_Type>::type, typename  enable_if<
    std::is_arithmetic<Type>::value >::type* dummy = nullptr>

static Matrix_Type& Times_Equal(Matrix_Type& matrix_, U value_, Type dummy_ = 0.0);

Definition for the general case:

template<class IP>
template <class Matrix_Type, typename U, typename Type,  typename enable_if<
    std::is_arithmetic<Type>::value >::type* dummy>
Matrix_Type& Implementation<IP>::Times_Equal(Matrix_Type& matrix_, U value_, Type){

    uint64_t total_elements = matrix_.actual_dims.first * matrix_.actual_dims.second;

    //y := A * b
    #pragma parallel
    for (uint64_t i = 0; i < total_elements; ++i)
        matrix_.Data[i] *= value_;

    return matrix_;
}

The trouble starts when I try to implement a specialization for MKL:

template<>
template <class Matrix_Type, typename U, typename Type, typename enable_if<
    std::is_arithmetic<Type>::value >::type* dummy>
Matrix_Type& Implementation<implementation::MKL>::Times_Equal(
    Matrix_Type& matrix_, 
    U value_,
    typename enable_if<std::is_same<Type,float>::value,Type>::type)
{

    float value = value_;

    MKL_INT total_elements = matrix_.actual_dims.first * matrix_.actual_dims.second;
    MKL_INT const_one = 1;

    //y := a * b
    sscal(&total_elements, &value, matrix_.Data, &const_one);
    return matrix_;
}

This gives me an error in clang:

error: out-of-line definition of 'Times_Equal' does not match any declaration in 'Implementation'

and in g++ (shortened somewhat):

error: template-id `Times_Equal<>' for 'Matrix_Type& Implementation::Times_Equal(...)' does not match any template declaration.

The code compiles perfectly fine if I change the 3rd parameter to be Type, rather than having the enable_if. But when I do that, I cannot see how to have separate implementations for float and double.

Any help would be greatly appreciated.


There is 1 answer

Banan On

I think this would be very tedious to implement using std::enable_if, because the general case would need an enable_if condition that turns it on only when none of the specializations apply.

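Specifically addressing your code: in your MKL specialization the third parameter has the type typename enable_if<...>::type, whereas the declaration has plain Type. The two signatures therefore differ, which is consistent with the compilers reporting that the definition matches no declaration. Even if it did match, I don't think the compiler would be able to deduce Type from that parameter, as it is hidden away in the std::enable_if (a non-deduced context), so deduction would never select this version.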
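To illustrate why the enable_if route gets tedious, here is a minimal sketch of Times_Equal written purely with enable_if overloads. It is not your actual code: the Path enum, the use of std::vector in place of your matrix type, and the plain loops standing in for the MKL calls are all assumptions made to keep the example self-contained. Note how the general overload has to negate every special case explicitly.

```cpp
#include <cassert>
#include <type_traits>
#include <vector>

struct CPP {};
struct MKL {};

// Hypothetical labels so a caller can see which overload was selected.
enum class Path { General, MklFloat, MklDouble };

template <class IP>
struct Implementation {
    // MKL with float data; a real version would call sscal here.
    template <class Type, class U,
              typename std::enable_if<std::is_same<IP, MKL>::value &&
                                      std::is_same<Type, float>::value>::type* = nullptr>
    static Path Times_Equal(std::vector<Type>& data, U value) {
        const float factor = static_cast<float>(value);
        for (float& x : data) x *= factor;
        return Path::MklFloat;
    }

    // MKL with double data; a real version would call dscal here.
    template <class Type, class U,
              typename std::enable_if<std::is_same<IP, MKL>::value &&
                                      std::is_same<Type, double>::value>::type* = nullptr>
    static Path Times_Equal(std::vector<Type>& data, U value) {
        const double factor = static_cast<double>(value);
        for (double& x : data) x *= factor;
        return Path::MklDouble;
    }

    // General case: the condition must spell out "none of the above".
    template <class Type, class U,
              typename std::enable_if<!(std::is_same<IP, MKL>::value &&
                                        (std::is_same<Type, float>::value ||
                                         std::is_same<Type, double>::value))>::type* = nullptr>
    static Path Times_Equal(std::vector<Type>& data, U value) {
        for (Type& x : data) x *= value;
        return Path::General;
    }
};

// Usage: the policy and the element type together pick the overload.
inline void demo_times_equal() {
    std::vector<float> f(2, 1.5f);
    std::vector<double> d(2, 1.5);
    Path p1 = Implementation<MKL>::Times_Equal(f, 2);
    Path p2 = Implementation<MKL>::Times_Equal(d, 2);
    Path p3 = Implementation<CPP>::Times_Equal(f, 2);
    assert(p1 == Path::MklFloat);
    assert(p2 == Path::MklDouble);
    assert(p3 == Path::General);
    assert(f[0] == 6.0f);  // 1.5 scaled by 2 twice
}
```

Every additional special case (another data type, a CUDA policy, ...) would have to be added to the general overload's condition as well, which is the maintenance cost I mean.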

Instead of using enable_if you could perhaps do something like this:

#include<iostream>

struct CPP {};
struct MKL {};

namespace Implementation
{
   //
   // general Plus_Equal
   //
   template<class Type, class IP>
   struct Plus_Equal
   {
      template<class Matrix_Type>
      static Matrix_Type& apply(Matrix_Type& matrix_, Type value_)
      {
         std::cout << " Matrix Plus Equal General Implementation " << std::endl;
         // ... do general Plus_Equal ...
         return matrix_;
      }
   };

   //
   // specialized Plus_Equal for MKL with Type double
   //
   template<>
   struct Plus_Equal<double,MKL>
   {
      template<class Matrix_Type>
      static Matrix_Type& apply(Matrix_Type& matrix_, double value_)
      {
         std::cout << " Matrix Plus Equal MKL double Implementation " << std::endl;
         // ... do MKL/double specialized Plus_Equal ...
         return matrix_;
      }
   };
} // namespace Implementation

template <typename Type, typename IP, typename Matrix_Type>
class Matrix_Base
{  
   public:
   // ... matrix base implementation ...

   // Matrix / Scalar addition
   template <typename T>
   Matrix_Base& operator+=(const T value_) 
   { 
      return Implementation::Plus_Equal<Type,IP>::apply(static_cast<Matrix_Type&>(*this), value_);
   }

   // ...More operators and rest of code...
};

template <typename Type, typename IP>
class Matrix : public Matrix_Base<Type, IP, Matrix<Type,IP> >
{
   // ... Matrix implementation ...
};

int main()
{
   Matrix<float ,MKL> f_mkl_mat;
   Matrix<double,MKL> d_mkl_mat;

   f_mkl_mat+=2.0; // will use general plus equal
   d_mkl_mat+=2.0; // will use specialized MKL/double version

   return 0;
}

Here I used class specialization instead of std::enable_if. I found that you were very inconsistent with the IP, Type, and Matrix_Type types in your examples, so I hope I have used them correctly here.
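The same pattern could cover the Times_Equal case from the question, where float and double need different MKL routines. The following is only a sketch under assumptions: scale_stub is a hypothetical stand-in for sscal and dscal so that the example compiles without MKL, the Path enum exists only to make the selected version visible, and std::vector replaces the matrix class.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

struct CPP {};
struct MKL {};

namespace Implementation {

// Hypothetical labels so a caller can see which version ran.
enum class Path { General, MklFloat, MklDouble };

// Hypothetical stand-ins for the MKL routines sscal and dscal.
inline void scale_stub(float* x, std::size_t n, float a) {
    for (std::size_t i = 0; i < n; ++i) x[i] *= a;
}
inline void scale_stub(double* x, std::size_t n, double a) {
    for (std::size_t i = 0; i < n; ++i) x[i] *= a;
}

// General Times_Equal: used for every (Type, IP) pair without a specialization.
template <class Type, class IP>
struct Times_Equal {
    template <class U>
    static Path apply(std::vector<Type>& data, U value) {
        for (std::size_t i = 0; i < data.size(); ++i) data[i] *= value;
        return Path::General;
    }
};

// MKL with float: this is where sscal would be called.
template <>
struct Times_Equal<float, MKL> {
    template <class U>
    static Path apply(std::vector<float>& data, U value) {
        scale_stub(data.data(), data.size(), static_cast<float>(value));
        return Path::MklFloat;
    }
};

// MKL with double: this is where dscal would be called.
template <>
struct Times_Equal<double, MKL> {
    template <class U>
    static Path apply(std::vector<double>& data, U value) {
        scale_stub(data.data(), data.size(), static_cast<double>(value));
        return Path::MklDouble;
    }
};

} // namespace Implementation

// Usage: only the two MKL specializations differ, everything else is shared.
inline void demo_times_equal() {
    using namespace Implementation;
    std::vector<float> f(3, 2.0f);
    std::vector<double> d(3, 2.0);
    Path pf = Times_Equal<float, MKL>::apply(f, 3);
    Path pd = Times_Equal<double, MKL>::apply(d, 3);
    Path pg = Times_Equal<float, CPP>::apply(f, 2);
    assert(pf == Path::MklFloat);
    assert(pd == Path::MklDouble);
    assert(pg == Path::General);
    assert(f[0] == 12.0f);  // 2 * 3 * 2
    assert(d[0] == 6.0);    // 2 * 3
}
```

Operators that do not need an MKL version simply get no specialization, so the general code is written once.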

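As an aside, regarding the comments on std::enable_if: I would use the form

template<... , typename std::enable_if< some bool >::type* = nullptr> void func(...);

over

template<... , typename = typename std::enable_if< some bool >::type> void func(...);

because it lets you write several overloads that differ only in their enable_if condition. With the second form the condition is only a default template argument, which is not part of the template's signature, so two such overloads would be treated as the same template declared twice.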
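A small sketch of the difference between the two forms; the function name describe is made up for the illustration. The two overloads below differ only in their enable_if condition, which works because the condition is part of the type of a non-type template parameter. The commented-out pair shows the second form, where both declarations would have the same signature.

```cpp
#include <cassert>
#include <string>
#include <type_traits>

// First form: the condition is part of the template parameter's type,
// so these are two distinct function templates.
template <class T,
          typename std::enable_if<std::is_integral<T>::value>::type* = nullptr>
std::string describe(T) { return "integral"; }

template <class T,
          typename std::enable_if<std::is_floating_point<T>::value>::type* = nullptr>
std::string describe(T) { return "floating point"; }

// Second form: the condition is only a default template argument.
// Both declarations read as template<class T, class> std::string describe2(T),
// so the compiler rejects the second one as a redefinition.
//
// template <class T,
//           typename = typename std::enable_if<std::is_integral<T>::value>::type>
// std::string describe2(T) { return "integral"; }
//
// template <class T,
//           typename = typename std::enable_if<std::is_floating_point<T>::value>::type>
// std::string describe2(T) { return "floating point"; }  // error

// Usage: the argument type selects the overload.
inline void demo_describe() {
    assert(describe(42) == "integral");
    assert(describe(2.5) == "floating point");
}
```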

Hope you can use it :)

EDIT 20/12-13: After re-reading my post I found that I should use CRTP (Curiously Recurring Template Pattern) explicitly, which I added to the code above. I pass both Type and IP to Matrix_Base. If you find this tedious, you could instead provide a matrix traits class from which Matrix_Base can extract them.

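// Forward declaration so that the traits specialization can name Matrix
template<class Type, class IP>
class Matrix;

template<class A>
struct Matrix_Traits;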

// Specialization for Matrix class
template<class Type, class IP>
struct Matrix_Traits<Matrix<Type,IP> >
{
   using type = Type;
   using ip   = IP;
};

Matrix_Base would then take only one template argument, namely the matrix class itself, and get the types from the traits class:

template<class Matrix_Type>
class Matrix_Base
{
   public:
   // Matrix / Scalar addition
   template <typename T>
   Matrix_Base& operator+=(const T value_) 
   { 
      // We now get Type and IP from Matrix_Traits
      return Implementation::Plus_Equal<typename Matrix_Traits<Matrix_Type>::type
                                      , typename Matrix_Traits<Matrix_Type>::ip
                                      >::apply(static_cast<Matrix_Type&>(*this), value_);
   }
};
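For reference, here is how the traits-based pieces might fit together in one translation unit. This is a sketch, not a drop-in implementation: the Data vector, the constructor taking an element count, and the used_specialization flag are hypothetical additions that only make the dispatch observable, and the MKL specialization uses a plain loop where the MKL call would go.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

struct CPP {};
struct MKL {};

template <class Type, class IP> class Matrix;  // forward declaration for the traits

template <class A> struct Matrix_Traits;

template <class Type, class IP>
struct Matrix_Traits<Matrix<Type, IP> > {
    using type = Type;
    using ip   = IP;
};

namespace Implementation {

// General Plus_Equal
template <class Type, class IP>
struct Plus_Equal {
    template <class Matrix_Type>
    static Matrix_Type& apply(Matrix_Type& m, Type value) {
        for (std::size_t i = 0; i < m.Data.size(); ++i) m.Data[i] += value;
        m.used_specialization = false;
        return m;
    }
};

// Specialized Plus_Equal for MKL with Type double
template <>
struct Plus_Equal<double, MKL> {
    template <class Matrix_Type>
    static Matrix_Type& apply(Matrix_Type& m, double value) {
        for (std::size_t i = 0; i < m.Data.size(); ++i) m.Data[i] += value;  // MKL call would go here
        m.used_specialization = true;
        return m;
    }
};

} // namespace Implementation

template <class Matrix_Type>
class Matrix_Base {
public:
    template <typename T>
    Matrix_Base& operator+=(const T value_) {
        // Type and IP come from the traits of the derived class.
        typedef Matrix_Traits<Matrix_Type> Traits;
        return Implementation::Plus_Equal<typename Traits::type, typename Traits::ip>::apply(
            static_cast<Matrix_Type&>(*this), value_);
    }
};

template <class Type, class IP>
class Matrix : public Matrix_Base<Matrix<Type, IP> > {
public:
    explicit Matrix(std::size_t n) : Data(n, Type()), used_specialization(false) {}
    std::vector<Type> Data;    // hypothetical storage
    bool used_specialization;  // hypothetical flag set by Plus_Equal
};

// Usage: float falls back to the general version, double hits the specialization.
inline void demo_traits() {
    Matrix<float, MKL> f(3);
    Matrix<double, MKL> d(3);
    f += 2.0;
    d += 2.0;
    assert(!f.used_specialization);
    assert(d.used_specialization);
}
```

The forward declaration of Matrix has to come before the Matrix_Traits specialization, and the traits are only looked up inside the body of operator+=, where Matrix is already a complete type.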