I am following the Pyro introductory tutorial on forecasting. When I try to access the learned parameters after training the model, different access methods give different results for some of them (while giving identical results for others).
Here is the stripped-down reproducible code from the tutorial:
```python
import torch
import pyro
import pyro.distributions as dist
from pyro.contrib.examples.bart import load_bart_od
from pyro.contrib.forecast import ForecastingModel, Forecaster

pyro.enable_validation(True)
pyro.clear_param_store()

pyro.__version__
# '1.3.1'
torch.__version__
# '1.5.0+cu101'

# import & prepare the data
dataset = load_bart_od()
T, O, D = dataset["counts"].shape
data = dataset["counts"][:T // (24 * 7) * 24 * 7].reshape(T // (24 * 7), -1).sum(-1).log()
data = data.unsqueeze(-1)
T0 = 0              # beginning
T2 = data.size(-2)  # end
T1 = T2 - 52        # train/test split

# define the model class
class Model1(ForecastingModel):
    def model(self, zero_data, covariates):
        data_dim = zero_data.size(-1)
        feature_dim = covariates.size(-1)
        bias = pyro.sample("bias", dist.Normal(0, 10).expand([data_dim]).to_event(1))
        weight = pyro.sample("weight", dist.Normal(0, 0.1).expand([feature_dim]).to_event(1))
        prediction = bias + (weight * covariates).sum(-1, keepdim=True)
        assert prediction.shape[-2:] == zero_data.shape
        noise_scale = pyro.sample("noise_scale", dist.LogNormal(-5, 5).expand([1]).to_event(1))
        noise_dist = dist.Normal(0, noise_scale)
        self.predict(noise_dist, prediction)

# fit the model (training runs inside the Forecaster constructor)
pyro.set_rng_seed(1)
pyro.clear_param_store()
time = torch.arange(float(T2)) / 365
covariates = torch.stack([time], dim=-1)
forecaster = Forecaster(Model1(), data[:T1], covariates[:T1], learning_rate=0.1)
```
So far so good. Now I want to inspect the learned latent parameters stored in the ParamStore. There seems to be more than one way to do this. Using the `get_all_param_names()` method:
```python
for name in pyro.get_param_store().get_all_param_names():
    print(name, pyro.param(name).data.numpy())
```

I get:

```
AutoNormal.locs.bias [14.585433]
AutoNormal.scales.bias [0.00631594]
AutoNormal.locs.weight [0.11947815]
AutoNormal.scales.weight [0.00922901]
AutoNormal.locs.noise_scale [-2.0719821]
AutoNormal.scales.noise_scale [0.03469057]
```
But using the `named_parameters()` method:

```python
pyro.get_param_store().named_parameters()
```

gives the same values for the location (`locs`) parameters, but different values for all the `scales` ones:

```
dict_items([
('AutoNormal.locs.bias', Parameter containing: tensor([14.5854], requires_grad=True)),
('AutoNormal.scales.bias', Parameter containing: tensor([-5.0647], requires_grad=True)),
('AutoNormal.locs.weight', Parameter containing: tensor([0.1195], requires_grad=True)),
('AutoNormal.scales.weight', Parameter containing: tensor([-4.6854], requires_grad=True)),
('AutoNormal.locs.noise_scale', Parameter containing: tensor([-2.0720], requires_grad=True)),
('AutoNormal.scales.noise_scale', Parameter containing: tensor([-3.3613], requires_grad=True))
])
```
How is this possible? According to the documentation, the ParamStore is a simple key-value store, and there are only these six keys in it:

```python
pyro.get_param_store().get_all_param_names()  # the .keys() method gives an identical result
# result:
# dict_keys([
#     'AutoNormal.locs.bias',
#     'AutoNormal.scales.bias',
#     'AutoNormal.locs.weight',
#     'AutoNormal.scales.weight',
#     'AutoNormal.locs.noise_scale',
#     'AutoNormal.scales.noise_scale'])
```
So there is no way that one method accesses one set of items and the other a different one.
Am I missing something here?
Here is the situation, as revealed in the GitHub thread I opened in parallel with this question.
The ParamStore is no longer just a simple key-value store. It also performs constraint transformations, as a Pyro developer explained in the above thread.

As a consequence, `pyro.param()` returns results in the constrained (user-facing) space. The older method `named_parameters()` returns the unconstrained values, which are meant for internal use only. This explains the apparent discrepancy.

It's not difficult to verify that the `scales` values returned by the two methods above are related by a logarithmic transformation. Each value from `named_parameters()` is the natural logarithm of the corresponding value from `pyro.param()`:

- log(0.00631594) ≈ -5.0647
- log(0.00922901) ≈ -4.6854
- log(0.03469057) ≈ -3.3613

The positive constraint is implemented via an exponential map, so the store keeps the log of each scale.

Why does this discrepancy affect only the `scales` parameters? `scales` are standard deviations, so they are by definition constrained to be positive. That doesn't hold for `locs`, which are means and are not constrained, so the two representations coincide for them.

As a result of the question above, a new bullet giving a relevant hint has been added to the ParamStore documentation. A note was also added to the documentation of the `named_parameters()` method of the old interface.