Calculating the log-sum-exp function in C++


Are there any functions in the C++ Standard Library that calculate the log of the sum of exponentials? If not, how should I go about writing my own? Do you have any suggestions not mentioned in this wiki article?

I am specifically concerned that the summands will underflow, which happens when you exponentiate negative numbers with a large absolute value. I am using C++11.
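A minimal sketch of the failure mode, assuming IEEE 754 `double`. The helper `naive_log_sum_exp` is hypothetical and simply evaluates the formula directly. For the input {-1000, -1000} the exact result is -1000 + log 2 ≈ -999.307, but exp(-1000) is far below the smallest positive `double` (about 4.9e-324), so every summand becomes 0.0 and the function returns negative infinity.

```cpp
#include <cmath>   // std::exp, std::log
#include <vector>

// Hypothetical helper: evaluates log(sum_i exp(x_i)) directly, with no
// protection against underflow.
double naive_log_sum_exp(const std::vector<double>& xs)
{
  double sum = 0.0;
  for (double x : xs) {
    sum += std::exp(x);   // exp(-1000.0) underflows to 0.0 in double precision
  }
  return std::log(sum);   // log(0.0) yields -infinity
}
```

Calling `naive_log_sum_exp({-1000.0, -1000.0})` therefore produces -infinity instead of roughly -999.307.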

MSalters (best answer)

(C++11 variant of shuvro's code, using the Standard Library as per the question.)

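#include <algorithm>  // std::max_element
#include <cmath>      // std::exp, std::log
#include <iterator>   // std::iterator_traits
#include <numeric>    // std::accumulate

template <typename Iter>
typename std::iterator_traits<Iter>::value_type
log_sum_exp(Iter begin, Iter end)
{
  using VT = typename std::iterator_traits<Iter>::value_type;
  if (begin == end) return VT{};  // empty range: returns 0 by convention
  using std::exp;  // unqualified calls below can still find overloads via ADL
  using std::log;
  // Shift by the maximum so the largest summand is exp(0) == 1.
  auto max_elem = *std::max_element(begin, end);
  auto sum = std::accumulate(begin, end, VT{},
     [max_elem](VT a, VT b) { return a + exp(b - max_elem); });
  return max_elem + log(sum);
}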

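This version is more generic: it works with any value type and any iterator range (so any container), as long as the value type supports the operations used (comparison, +, -, and value-initialization to zero). In particular, because of the using-declarations, the unqualified calls resolve to std::exp and std::log unless argument-dependent lookup finds better-matching overloads for the value type.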
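Subtracting `max_elem` relies on the standard log-sum-exp identity, written here with m standing for `max_elem`:

$$
\log \sum_{i=1}^{n} e^{x_i}
= \log\Big(e^{m} \sum_{i=1}^{n} e^{x_i - m}\Big)
= m + \log \sum_{i=1}^{n} e^{x_i - m},
\qquad m = \max_i x_i .
$$

Every exponent x_i - m is at most 0, so each summand lies in (0, 1] (or underflows to 0), and the summand for the maximum is exactly e^0 = 1. The sum therefore lies in [1, n]: it can neither overflow nor collapse to 0, so the log is always finite. For `double`, any summand that does underflow is far smaller than the rounding error of a sum that is at least 1, so dropping it does not change the result.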

To be really robust against underflow, even for unknown numeric types, it would probably be beneficial to sort the values. If you sort the inputs in descending order, the very first value will be max_elem, so the first term of the sum will be exp(VT{0}), which presumably is VT{1}. That is clearly free from underflow.
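The sorting idea could be sketched as follows. This assumes `double` values in a `std::vector`, and `log_sum_exp_sorted` is a hypothetical name, not part of the answer. A copy is sorted in descending order with `std::greater`, so the first element is the maximum and the first summand is exp(0) == 1.

```cpp
#include <algorithm>   // std::sort
#include <cmath>       // std::exp, std::log
#include <functional>  // std::greater
#include <vector>

// Hypothetical sketch: the vector is taken by value, so the caller's data is
// not reordered.
double log_sum_exp_sorted(std::vector<double> xs)
{
  if (xs.empty()) return 0.0;  // same empty-range convention as the code above
  // Descending order: xs.front() is the maximum.
  std::sort(xs.begin(), xs.end(), std::greater<double>());
  const double max_elem = xs.front();
  double sum = 0.0;
  for (double x : xs) {
    sum += std::exp(x - max_elem);  // first iteration adds exp(0) == 1
  }
  return max_elem + std::log(sum);
}
```

Sorting costs O(n log n), compared with the single O(n) pass of `std::max_element` in the generic version.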

Nazmul Khan

If you want smaller code, this implementation could do the job:

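#include <cmath>  // std::exp, std::log

double log_sum_exp(const double arr[], int count)
{
   if (count > 0) {
      // First pass: find the largest element.
      double maxVal = arr[0];
      for (int i = 1; i < count; i++) {
         if (arr[i] > maxVal) {
            maxVal = arr[i];
         }
      }

      // Second pass: sum exp(x - max); the maximum contributes exp(0) == 1.
      double sum = 0;
      for (int i = 0; i < count; i++) {
         sum += std::exp(arr[i] - maxVal);
      }
      return std::log(sum) + maxVal;
   }
   else
   {
      return 0.0;  // empty input: 0 by convention
   }
}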
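If reading the data only once matters, the two loops can be merged. The sketch below is a hypothetical single-pass variant that is not taken from either answer, and it assumes a `std::vector<double>` input. It keeps a running maximum and, whenever a larger element appears, rescales the running sum from the old maximum to the new one, so no summand ever exceeds 1.

```cpp
#include <cmath>    // std::exp, std::log
#include <cstddef>  // std::size_t
#include <vector>

// Hypothetical single-pass log-sum-exp with a running maximum.
double log_sum_exp_one_pass(const std::vector<double>& xs)
{
  if (xs.empty()) return 0.0;  // same empty-input convention as the code above
  double maxVal = xs[0];
  double sum = 1.0;            // exp(xs[0] - maxVal) == exp(0) == 1
  for (std::size_t i = 1; i < xs.size(); ++i) {
    if (xs[i] > maxVal) {
      // Re-express the old sum relative to the new maximum, then add exp(0) == 1.
      sum = sum * std::exp(maxVal - xs[i]) + 1.0;
      maxVal = xs[i];
    } else {
      sum += std::exp(xs[i] - maxVal);
    }
  }
  return maxVal + std::log(sum);
}
```

The trade-off is an extra `std::exp` call each time the maximum increases, so ascending input costs more exponentials than the two-pass version.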

You can find a more robust implementation on Takeda 25's blog (in Japanese; an English version is available via Google Translate).