Coregionalization in an SVGP model in GPflow 2?


I am using GPflow 2.8.1 and following the coregionalization demo (https://gpflow.github.io/GPflow/2.9.0/notebooks/advanced/coregionalisation.html). I want to apply this technique to a sparse GP model (SVGP, as opposed to the VGP used in the linked demo).

In my data set, I am trying to estimate 6 processes (corresponding to 6 conditions) defined over trial numbers. On each trial, only a single process is measured, resulting in a binary response. I can successfully apply these techniques to a VGP model that matches this scenario:

import numpy as np
import gpflow as gp
import tensorflow as tf
from gpflow.ci_utils import reduce_in_tests

# X data are trial numbers. Only one condition is measured on each trial
num_trials = 200
num_cond = 6
order_cond = np.random.choice(
    np.arange(0, num_cond),
    size=(num_trials, 1),
    replace=True
)
x = np.hstack((
    np.arange(1, num_trials + 1).reshape((num_trials, 1)),
    order_cond
))

# Y data are binary responses
y = np.hstack((
    np.random.binomial(n=1, p=0.5, size=(num_trials, 1)),
    order_cond
))

# Force both `x` and `y` to have float64 type
x = tf.constant(x, dtype=tf.float64)
y = tf.constant(y, dtype=tf.float64)

# Base kernel
k = gp.kernels.Matern32(active_dims=[0])

# Coregion kernel
coreg = gp.kernels.Coregion(
    output_dim=num_cond, rank=num_cond, active_dims=[1]
)
kern = k * coreg

# Switched likelihood similar to that used in co-regionalization example:
# https://gpflow.github.io/GPflow/2.9.0/notebooks/advanced/coregionalisation.html
lik = gp.likelihoods.SwitchedLikelihood(
    [gp.likelihoods.Bernoulli()] * num_cond
)

# Build the VGP model
m = gp.models.VGP((x, y), kernel=kern, likelihood=lik)

# Fit the covariance function parameters
maxiter = reduce_in_tests(10000)
gp.optimizers.Scipy().minimize(
    m.training_loss,
    m.trainable_variables,
    options=dict(maxiter=maxiter),
    method="L-BFGS-B",
)
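
For completeness, predictions from this fitted VGP can be made in the style of the linked demo by appending the condition index as a second input column. This is only a sketch continuing from the script above (condition 0 across all trials), not part of the failing code:

# Sketch: predict the latent function for condition 0 at every trial,
# mirroring the prediction step in the linked coregionalization demo
x_test = np.hstack((
    np.arange(1, num_trials + 1).reshape((num_trials, 1)),
    np.zeros((num_trials, 1)),
))
mu, var = m.predict_f(tf.constant(x_test, dtype=tf.float64))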

However, when I change the model to an SVGP (and adjust the loss passed to the optimizer accordingly), I get an error reading Node: 'GatherV2_4' indices[0] = 1 is not in [0, 1), raised at line 106 in slice in gpflow/kernels/base.py. Here is code to reproduce that error:

import numpy as np
import gpflow as gp
import tensorflow as tf
from gpflow.ci_utils import reduce_in_tests

# X data are trial numbers. Only one condition is measured on each trial
num_trials = 200
num_cond = 6
order_cond = np.random.choice(
    np.arange(0, num_cond),
    size=(num_trials, 1),
    replace=True
)
x = np.hstack((
    np.arange(1, num_trials + 1).reshape((num_trials, 1)),
    order_cond
))

# Y data are binary responses
y = np.hstack((
    np.random.binomial(n=1, p=0.5, size=(num_trials, 1)),
    order_cond
))

# Force both `x` and `y` to have float64 type
x = tf.constant(x, dtype=tf.float64)
y = tf.constant(y, dtype=tf.float64)

# Base kernel
k = gp.kernels.Matern32(active_dims=[0])

# Coregion kernel
coreg = gp.kernels.Coregion(
    output_dim=num_cond, rank=num_cond, active_dims=[1]
)
kern = k * coreg

# Switched likelihood similar to that used in co-regionalization example:
# https://gpflow.github.io/GPflow/2.9.0/notebooks/advanced/coregionalisation.html
lik = gp.likelihoods.SwitchedLikelihood(
    [gp.likelihoods.Bernoulli()] * num_cond
)

# Build the SVGP model
m_sparse = gp.models.SVGP(
    kernel=kern,
    likelihood=lik,
    inducing_variable=np.linspace(1, num_trials, 40)[:, None]
)

# Specify loss function
data = (x, y)
loss_fn = m_sparse.training_loss_closure(data)

maxiter = reduce_in_tests(10000)
gp.optimizers.Scipy().minimize(
    loss_fn,
    m_sparse.trainable_variables,
    options=dict(maxiter=maxiter),
    method="L-BFGS-B",
)

I have a sneaking suspicion that this is not possible by default because of some difference between the SVGP and VGP implementations in GPflow 2, but I've seen old comments for GPflow 1 suggesting it is possible, e.g., GitHub issue #563. On the other hand, GitHub issue #985 reports an error message similar to mine for GPflow 1.3.0 and appears to be unresolved; it mentions an error in how active_dims was specified in Coregion, which I believe does not apply here.

I'm using tensorflow 2.10.1 and numpy 1.25.1. Here's a full error trace:

Traceback (most recent call last):
  File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\tensorflow\python\util\traceback_utils.py", line 153, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\tensorflow\python\eager\execute.py", line 54, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
tensorflow.python.framework.errors_impl.InvalidArgumentError: Graph execution error:

Detected at node 'GatherV2_4' defined at (most recent call last):
    File "C:\Program Files\JetBrains\PyCharm Community Edition 2022.3.2\plugins\python-ce\helpers\pydev\pydevd.py", line 2195, in <module>
      main()
    File "C:\Program Files\JetBrains\PyCharm Community Edition 2022.3.2\plugins\python-ce\helpers\pydev\pydevd.py", line 2177, in main
      globals = debugger.run(setup['file'], None, None, is_module)
    File "C:\Program Files\JetBrains\PyCharm Community Edition 2022.3.2\plugins\python-ce\helpers\pydev\pydevd.py", line 1489, in run
      return self._exec(is_module, entry_point_fn, module_name, file, globals, locals)
    File "C:\Program Files\JetBrains\PyCharm Community Edition 2022.3.2\plugins\python-ce\helpers\pydev\pydevd.py", line 1496, in _exec
      pydev_imports.execfile(file, globals, locals)  # execute the script
    File "C:\Program Files\JetBrains\PyCharm Community Edition 2022.3.2\plugins\python-ce\helpers\pydev\_pydev_imps\_pydev_execfile.py", line 18, in execfile
      exec(compile(contents+"\n", file, 'exec'), glob, loc)
    File "C:\Users\scw8734\Desktop\pl_gp\try_coreg.py", line 78, in <module>
      gp.optimizers.Scipy().minimize(
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\gpflow\optimizers\scipy.py", line 108, in minimize
      opt_result = scipy.optimize.minimize(
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\scipy\optimize\_minimize.py", line 696, in minimize
      res = _minimize_lbfgsb(fun, x0, args, jac, bounds,
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\scipy\optimize\_lbfgsb_py.py", line 305, in _minimize_lbfgsb
      sf = _prepare_scalar_function(fun, x0, jac=jac, args=args, epsilon=eps,
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\scipy\optimize\_optimize.py", line 332, in _prepare_scalar_function
      sf = ScalarFunction(fun, x0, args, grad, hess,
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\scipy\optimize\_differentiable_functions.py", line 158, in __init__
      self._update_fun()
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\scipy\optimize\_differentiable_functions.py", line 251, in _update_fun
      self._update_fun_impl()
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\scipy\optimize\_differentiable_functions.py", line 155, in update_fun
      self.f = fun_wrapped(self.x)
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\scipy\optimize\_differentiable_functions.py", line 137, in fun_wrapped
      fx = fun(np.copy(x), *args)
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\scipy\optimize\_optimize.py", line 76, in __call__
      self._compute_if_needed(x, *args)
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\scipy\optimize\_optimize.py", line 70, in _compute_if_needed
      fg = self.fun(x, *args)
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\gpflow\optimizers\scipy.py", line 154, in _eval
      loss, grad = _tf_eval(tf.convert_to_tensor(x))
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\gpflow\optimizers\scipy.py", line 136, in _tf_eval
      if first_call:
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\gpflow\optimizers\scipy.py", line 138, in _tf_eval
      loss, grads = _compute_loss_and_gradients(
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\gpflow\optimizers\scipy.py", line 241, in _compute_loss_and_gradients
      loss = loss_closure()
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\gpflow\models\training_mixins.py", line 142, in closure
      return training_loss(data)
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\check_shapes\integration\tf.py", line 96, in wrapped_method
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\check_shapes\decorator.py", line 119, in wrapped_function
      if not get_enable_check_shapes():
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\check_shapes\decorator.py", line 120, in wrapped_function
      return func(*args, **kwargs)
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\gpflow\models\training_mixins.py", line 107, in training_loss
      return self._training_loss(data)  # type: ignore[attr-defined]
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\check_shapes\integration\tf.py", line 96, in wrapped_method
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\check_shapes\decorator.py", line 119, in wrapped_function
      if not get_enable_check_shapes():
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\check_shapes\decorator.py", line 120, in wrapped_function
      return func(*args, **kwargs)
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\gpflow\models\model.py", line 76, in _training_loss
      return -(self.maximum_log_likelihood_objective(*args, **kwargs) + self.log_prior_density())
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\check_shapes\integration\tf.py", line 96, in wrapped_method
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\check_shapes\decorator.py", line 119, in wrapped_function
      if not get_enable_check_shapes():
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\check_shapes\decorator.py", line 120, in wrapped_function
      return func(*args, **kwargs)
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\gpflow\models\svgp.py", line 161, in maximum_log_likelihood_objective
      return self.elbo(data)
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\check_shapes\integration\tf.py", line 96, in wrapped_method
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\check_shapes\decorator.py", line 119, in wrapped_function
      if not get_enable_check_shapes():
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\check_shapes\decorator.py", line 120, in wrapped_function
      return func(*args, **kwargs)
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\gpflow\models\svgp.py", line 173, in elbo
      f_mean, f_var = self.predict_f(X, full_cov=False, full_output_cov=False)
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\check_shapes\integration\tf.py", line 96, in wrapped_method
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\check_shapes\decorator.py", line 119, in wrapped_function
      if not get_enable_check_shapes():
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\check_shapes\decorator.py", line 120, in wrapped_function
      return func(*args, **kwargs)
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\gpflow\models\svgp.py", line 254, in predict_f
      Xnew, full_cov=full_cov, full_output_cov=full_output_cov
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\check_shapes\integration\tf.py", line 96, in wrapped_method
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\check_shapes\decorator.py", line 119, in wrapped_function
      if not get_enable_check_shapes():
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\check_shapes\decorator.py", line 120, in wrapped_function
      return func(*args, **kwargs)
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\gpflow\posteriors.py", line 255, in fused_predict_f
      mean, cov = self._conditional_fused(
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\check_shapes\integration\tf.py", line 96, in wrapped_method
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\check_shapes\decorator.py", line 119, in wrapped_function
      if not get_enable_check_shapes():
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\check_shapes\decorator.py", line 120, in wrapped_function
      return func(*args, **kwargs)
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\gpflow\posteriors.py", line 835, in _conditional_fused
      Kmm = covariances.Kuu(self.X_data, self.kernel, jitter=default_jitter())  # [M, M]
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\multipledispatch\dispatcher.py", line 278, in __call__
      return func(*args, **kwargs)
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\check_shapes\decorator.py", line 119, in wrapped_function
      if not get_enable_check_shapes():
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\check_shapes\decorator.py", line 120, in wrapped_function
      return func(*args, **kwargs)
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\gpflow\covariances\kuus.py", line 32, in Kuu_kernel_inducingpoints
      Kzz = kernel(inducing_variable.Z)
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\gpflow\kernels\base.py", line 290, in __call__
      [k(X, X2, full_cov=full_cov, presliced=presliced) for k in self.kernels]
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\gpflow\kernels\base.py", line 290, in __call__
      [k(X, X2, full_cov=full_cov, presliced=presliced) for k in self.kernels]
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\check_shapes\integration\tf.py", line 96, in wrapped_method
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\check_shapes\decorator.py", line 119, in wrapped_function
      if not get_enable_check_shapes():
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\check_shapes\decorator.py", line 120, in wrapped_function
      return func(*args, **kwargs)
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\gpflow\kernels\base.py", line 206, in __call__
      if not presliced:
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\gpflow\kernels\base.py", line 207, in __call__
      X, X2 = self.slice(X, X2)
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\check_shapes\integration\tf.py", line 96, in wrapped_method
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\check_shapes\decorator.py", line 119, in wrapped_function
      if not get_enable_check_shapes():
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\check_shapes\decorator.py", line 120, in wrapped_function
      return func(*args, **kwargs)
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\gpflow\kernels\base.py", line 101, in slice
      if isinstance(dims, slice):
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\gpflow\kernels\base.py", line 105, in slice
      elif dims is not None:
    File "C:\ProgramData\miniconda3\envs\tf\lib\site-packages\gpflow\kernels\base.py", line 106, in slice
      X = tf.gather(X, dims, axis=-1)
Node: 'GatherV2_4'
indices[0] = 1 is not in [0, 1)
     [[{{node GatherV2_4}}]] [Op:__inference__tf_eval_3682]

Process finished with exit code 1
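
If I am reading the last frames of the trace correctly, the product kernel ends up slicing the inducing inputs Z (shape [40, 1]) with the Coregion kernel's active_dims=[1] inside the Kuu computation, and tf.gather then fails because Z has only one column. Here is a minimal standalone snippet (my own interpretation of the trace, not a claim about how GPflow is supposed to work) that produces the same message:

import numpy as np
import tensorflow as tf

# Inducing points as passed to SVGP above: shape [40, 1]
Z = tf.constant(np.linspace(1, 200, 40)[:, None])

# What Coregion's slice with active_dims=[1] appears to do to Z
tf.gather(Z, [1], axis=-1)  # InvalidArgumentError: indices[0] = 1 is not in [0, 1)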

I have tried different optimizers, changing the size and structure of the data (e.g., specifying only two conditions, as in the linked demo), changing whether the condition labels are included as columns in x and/or y, changing the arguments of Coregion (output_dim, rank, and active_dims), and using Gaussian-distributed responses instead of Bernoulli ones, with the switched likelihood changed accordingly (sketched below). None of these changes appears to affect the error above. In general, as noted, the VGP model works across all of these variations whereas the SVGP model does not. I'm hoping to benefit from the expertise of someone familiar with the SVGP implementation.
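
For concreteness, the Gaussian variant mentioned above looked roughly like this (a sketch, reusing kern and order_cond from the reproduction script; only the responses and the switched likelihood change):

# Sketch: continuous responses and a switched Gaussian likelihood
y_cont = np.hstack((
    np.random.normal(size=(num_trials, 1)),
    order_cond
))
y_cont = tf.constant(y_cont, dtype=tf.float64)

lik_gauss = gp.likelihoods.SwitchedLikelihood(
    [gp.likelihoods.Gaussian()] * num_cond
)

m_sparse_gauss = gp.models.SVGP(
    kernel=kern,
    likelihood=lik_gauss,
    inducing_variable=np.linspace(1, num_trials, 40)[:, None]
)
# Minimizing m_sparse_gauss.training_loss_closure((x, y_cont)) fails with
# the same GatherV2 error as above.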
