How to speed-up Log-sum-exp function over multidimensional R Arrays?

422 views Asked by At

I am working on speeding up a program I wrote in R. The code involves repeatedly computing LogSumExp over multidimensional arrays, i.e computing s_lnj = exp(u_lnj) / (1 + sum_k exp(u_lnk)). The base R version of the code I am trying to increase the speed of is the following:

log_sum_exp_func <- function(vec){
  max_vec <- max(vec)
  return(max_vec + log(sum(exp(vec-max_vec))))
}

compute_share_from_utils_func <- function(u_lnj){
  ### get dimensions
  L <- dim(u_lnj)[1]; n_poly <- dim(u_lnj)[2]; J <- dim(u_lnj)[3]
  
  ### compute denominator of share, 1 + sum exp utils
  den_ln <- 1 + exp(apply(u_lnj, c(1,2), log_sum_exp_func))
  den_lnj <- array(rep(den_ln, J), dim = c(L, n_poly, J))
  
  ### take ratio of utils and denominator
  s_lnj <- exp(u_lnj) / den_lnj
  return(s_lnj)
}

I tried to use xtensor and Rcpp to speed things up, but ran into several issues. The Rcpp code I wrote is the following

// [[Rcpp::depends(xtensor)]]
// [[Rcpp::plugins(cpp14)]]
#include <numeric>                    // Standard library import for std::accumulate
#define STRICT_R_HEADERS              // Otherwise a PI macro is defined in R
#include "xtensor/xmath.hpp"          // xtensor import for the C++ universal functions
#include "xtensor/xarray.hpp"
#include "xtensor/xio.hpp"
#include "xtensor/xview.hpp"
#include "xtensor-r/rarray.hpp"       // R bindings
#include <Rcpp.h>

using namespace Rcpp;

// [[Rcpp::export]]
double cxxlog_sum_exp_vec(xt::rarray<double>& m)
{
  auto shape_m = m.shape();
  double maxvec = xt::amax(m)[0];
  xt::rarray<double> arr_maxvec = maxvec * xt::ones<double>(shape_m);
  xt::rarray<double> vec_min_max = m - arr_maxvec;
  xt::rarray<double> exp_vec_min_max = xt::exp(vec_min_max);
  double sum_exp = xt::sum(exp_vec_min_max)[0];
  double log_sum_exp = std::log(sum_exp);
  return log_sum_exp + maxvec; 
}

// [[Rcpp::export]]
xt::rarray<double> cxxshare_from_utils(xt::rarray<double>& u_lnj)
{
  int L = u_lnj.shape(0);
  int N = u_lnj.shape(1);
  int J = u_lnj.shape(2);
  xt::rarray<double> res = xt::ones<double>({L,N,J});
  for (std::size_t l = 0; l < u_lnj.shape()[0]; ++l)
  {
    for (std::size_t n = 0; n < u_lnj.shape()[1]; ++n)
    {
      xt::rarray<double> utils_j = xt::view(u_lnj, l, n, xt::all());
      double inv_lse = 1 / (1 + std::exp(cxxlog_sum_exp_vec(utils_j)));
      for (std::size_t j = 0; j < J; ++j)
      {
        res(l, n, j) = std::exp(u_lnj(l, n, j)) * inv_lse;
      }
    }
  }
  return res;
}

The Rcpp implementation does seem to yield the same results as the base R code, however it seems to encounter problems whenever the dimensions of the input array increase. My R Session fails if I run

L <- 100
n <- 100
J <- 200
u_lnj <- array(rnorm(L*n*J,0,2), dim = c(L, n, J))
test <- cxxshare_from_utils(u_lnj)

But the code runs fine for L, n, J = 10,10,20 for instance. Moreover, the C++ implementation of log_sum_exp does not seem to outperform the base R version that much.

EDIT: I could not figure out what was the issue with the way I am using xtensor. But I did get some speed up with the following RcppArmadillo code. The drawback of this version is that is likely not as robust to overflow as the base R function relying on Log Sum Exp.

#include <RcppArmadillo.h>
// [[Rcpp::depends(RcppArmadillo)]]
// [[Rcpp::plugins(cpp14)]]
// [[Rcpp::export]]
arma::cube cxxarma_share_from_utils(arma::cube u_lnj) {
  
  // Extract the different dimensions
  
  // Normal Matrix dimensions
  unsigned int L = u_lnj.n_rows;
  unsigned int N = u_lnj.n_cols;
  
  // Depth of Array
  unsigned int J = u_lnj.n_slices;
  
  //resulting cube
  arma::cube s_lnj = arma::exp(u_lnj);
  for (unsigned int l = 0; l < L; l++) {
    
    for (unsigned int n = 0; n < N; n++) {
      
      double den = 1 / (1 + arma::accu(s_lnj.subcube(arma::span(l), arma::span(n), arma::span())));
      
      for (unsigned int j = 0; j < J; j++) {
        
        s_lnj(l, n, j) = s_lnj(l, n, j) * den;
      }
    }
  }
  return s_lnj;
}
0

There are 0 answers