Variational Inference with Normalizing Flows in TensorFlow Probability


Recently I've been reading a bit about using normalizing flows to improve variational inference, e.g. Link1 Link2.

TensorFlow Probability already offers RealNVP and MaskedAutoregressiveFlow in the bijectors submodule, as well as an AutoregressiveTransform layer in the layers submodule. I therefore thought it would be easy and straightforward to build a Bayesian neural network trained with variational inference, with a posterior given by a normalizing flow, using TensorFlow Probability.
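For context, the building blocks fit together roughly like this on their own (a minimal sketch along the lines of the tfb.MaskedAutoregressiveFlow documentation, independent of my model; the aliases match the full script at the bottom):

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

# Standard normal base distribution with event shape [2].
base = tfd.MultivariateNormalDiag(loc=tf.zeros(2), scale_diag=tf.ones(2))

# Masked autoregressive flow whose shift/log-scale function is a MADE network.
maf = tfb.MaskedAutoregressiveFlow(
    shift_and_log_scale_fn=tfb.AutoregressiveNetwork(
        params=2, hidden_units=[10, 10], activation="relu"
    )
)

# The transformed distribution supports both sampling and density evaluation.
flow = tfd.TransformedDistribution(distribution=base, bijector=maf)
samples = flow.sample(5)            # shape (5, 2)
log_probs = flow.log_prob(samples)  # shape (5,)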

Starting from one of the tutorials (Link) I was able to build a BNN with a mean-field posterior (posterior_mean_field in the full script below).

Then things started to get complicated. I wrote the following function, adapted from this example (Link), to produce a posterior in which a normal base distribution is transformed by a masked autoregressive flow.

def posterior_vi_maf(kernel_size, bias_size=0, dtype=None):
    n = kernel_size + bias_size
    return tf.keras.Sequential(
        [
            tfk.layers.InputLayer(input_shape=(0,), dtype=tf.float32),
            # Base distribution: standard normal with event shape [n].
            tfpl.DistributionLambda(
                lambda t: tfd.MultivariateNormalDiag(
                    loc=tf.zeros(tf.concat([tf.shape(t)[:-1], [n]], axis=0)),
                    scale_diag=tf.ones(n),
                ),
            ),
            # Transform the base distribution with a masked autoregressive flow.
            tfpl.AutoregressiveTransform(
                tfb.AutoregressiveNetwork(
                    params=2, hidden_units=[10, 10], activation="relu"
                )
            ),
        ]
    )

Comparing the shapes and outputs of posterior_vi_maf and posterior_mean_field, it seems as if everything should work from a technical point of view:

p1 = posterior_vi_maf(16, 4, dtype=tf.float32)
p2 = posterior_mean_field(16, 4, dtype=tf.float32)
assert p1(x).shape == p2(x).shape
assert isinstance(p1(x), tfd.Distribution)
assert isinstance(p2(x), tfd.Distribution)

Unfortunately, running the training script (see bottom) raises the following error message:

ValueError: Shape must be rank 1 but is rank 2 for '{{node dense_variational/BiasAdd}} = BiasAdd[T=DT_FLOAT, data_format="NHWC"](dense_variational/MatMul, dense_variational/split:1)' with input shapes: [?,?,16], [?,16].

Any suggestions as to why this happens and/or how I can fix it?
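My own debugging so far points to the posterior sample picking up an extra batch dimension: the base distribution in posterior_vi_maf takes its batch shape from the layer input via tf.shape(t)[:-1], whereas posterior_mean_field ignores its input entirely. A minimal check illustrating what I mean (the dummy input and the commented shapes are my own expectation, not verified output):

dummy = tf.zeros([8, 0])
q_maf = posterior_vi_maf(16, 4, dtype=tf.float32)
q_mf = posterior_mean_field(16, 4, dtype=tf.float32)
print(q_maf(dummy).sample().shape)  # I expect (8, 20): one weight sample per input row
print(q_mf(dummy).sample().shape)   # I expect (20,): a single rank-1 weight sample

For completeness, the full training script: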

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfk = tf.keras
tfl = tf.keras.layers
tfpl = tfp.layers
tfb = tfp.bijectors

np.random.seed(2)
tf.random.set_seed(2)
N = 100

x = tfd.Normal(loc=0, scale=1).sample(N)
y = tfd.Normal(loc=x * 0.5, scale=0.3).sample()


def negloglik(y_true, y_pred):
    # y_pred is a tfd.Distribution, so the loss is its negative log-likelihood.
    nll = -tf.reduce_mean(y_pred.log_prob(y_true))
    return nll


def prior(kernel_size, bias_size=0, dtype=None):
    n = kernel_size + bias_size
    # Fixed (non-trainable) standard normal prior over all n weights.
    return lambda t: tfd.Independent(
        tfd.Normal(loc=tf.zeros(n, dtype=dtype), scale=1),
        reinterpreted_batch_ndims=1,
    )


def posterior_mean_field(kernel_size, bias_size=0, dtype=None):
    n = kernel_size + bias_size
    c = np.log(np.expm1(1.0))
    return tf.keras.Sequential(
        [
            # 2n trainable parameters: n means and n pre-softplus scales.
            tfp.layers.VariableLayer(2 * n, dtype=dtype),
            tfp.layers.DistributionLambda(
                lambda t: tfd.Independent(
                    tfd.Normal(
                        loc=t[..., :n],
                        scale=1e-5 + 0.01 * tf.nn.softplus(c + t[..., n:]),
                    ),
                    reinterpreted_batch_ndims=1,
                )
            ),
        ]
    )


model = tf.keras.Sequential(name="small_vi_nn")
model.add(tfk.Input(shape=(1,)))
model.add(
    tfpl.DenseVariational(
        units=16,
        make_posterior_fn=posterior_vi_maf,
        make_prior_fn=prior,
        kl_weight=1 / N,
        kl_use_exact=False,
        activation="relu",
    )
)
model.add(
    tfpl.DenseVariational(
        units=2,
        make_posterior_fn=posterior_mean_field,
        make_prior_fn=prior,
        kl_weight=1 / N,
        kl_use_exact=False,
    )
)
model.add(
    tfpl.DistributionLambda(
        make_distribution_fn=lambda t: tfd.Normal(
            loc=t[:, 0],
            # the scale comes from the second output unit, not the loc again
            scale=1e-3 + tf.math.softplus(0.05 * t[:, 1]),
        )
    )
)

optimizer = tf.keras.optimizers.Adam(learning_rate=0.01, amsgrad=True)


model.compile(optimizer=optimizer, loss=negloglik)
model.fit(
    x,
    y,
    epochs=2,
    shuffle=True,
    verbose=True,
)