Load an Amazon SageMaker NTM model locally for inference


I trained a SageMaker Neural Topic Model (NTM) directly on the AWS SageMaker platform. Once training completes, the MXNet model files can be downloaded. After unpacking, they contain:

  • params
  • symbol.json
  • meta.json
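mx.model.load_checkpoint(prefix, epoch) reads <prefix>-symbol.json and <prefix>-<epoch>.params, with the epoch zero-padded to four digits, so the unpacked files have to carry those names before load_checkpoint('model_algo-1', 0) can find them. A minimal sketch of that renaming step, assuming the files sit in one directory under exactly the names listed above (the helper name to_checkpoint_layout is hypothetical):

```python
import shutil
from pathlib import Path


def to_checkpoint_layout(model_dir, prefix="model_algo-1", epoch=0):
    """Copy the unpacked SageMaker files to the names load_checkpoint expects.

    Assumption: model_dir contains files named exactly "symbol.json" and
    "params" (as listed in the question); meta.json is left untouched.
    Returns the paths of the two new files.
    """
    model_dir = Path(model_dir)
    symbol_path = model_dir / f"{prefix}-symbol.json"
    # MXNet zero-pads the epoch to four digits, e.g. epoch 0 -> "-0000.params"
    params_path = model_dir / f"{prefix}-{epoch:04d}.params"
    shutil.copyfile(model_dir / "symbol.json", symbol_path)
    shutil.copyfile(model_dir / "params", params_path)
    return symbol_path, params_path
```

Calling to_checkpoint_layout("unpacked_model") would produce model_algo-1-symbol.json and model_algo-1-0000.params, the same names used in the load_checkpoint and SymbolBlock calls in this question. Copying rather than moving keeps the original files intact.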

I followed the MXNet documentation to load the model, using the following code:

import mxnet as mx

# Reads model_algo-1-symbol.json and model_algo-1-0000.params (epoch 0)
sym, arg_params, aux_params = mx.model.load_checkpoint('model_algo-1', 0)
module_model = mx.mod.Module(symbol=sym, label_names=None, context=mx.cpu())

# VOCAB_SIZE must match the vocabulary size used during training
module_model.bind(
    for_training=False,
    data_shapes=[('data', (1, VOCAB_SIZE))]
)

# allow_missing=True is required here; otherwise set_params raises an error about a missing 'n_epoch' variable
module_model.set_params(arg_params=arg_params, aux_params=aux_params, allow_missing=True)

I then try to use the model for inference:

module_model.predict(x)  # x is a NumPy array of shape (1, VOCAB_SIZE)

The code runs, but the result is a single value, whereas I expect a distribution over topics:

[11.060672]
<NDArray 1 @cpu(0)>
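Module.predict returns whatever the symbol's head output computes, so a single scalar suggests that the head of the exported NTM graph is not the topic vector (possibly a loss or score term; that is an assumption). Since symbol.json is plain JSON listing every node plus a "heads" entry, a stdlib-only sketch can show which node predict is returning and which layers exist (the parameter names in the error further down imply layers such as enc_0, enc_1, mean, logsigma and projection). It assumes the usual MXNet symbol JSON layout; the function name describe_symbol is hypothetical.

```python
import json
from pathlib import Path


def describe_symbol(symbol_file):
    """Summarise an MXNet symbol.json: its operator nodes and its head outputs.

    Assumes the usual MXNet JSON layout: a "nodes" list (each node has "op"
    and "name"; variables and parameters use op "null") and a "heads" list of
    [node_index, output_index, version] entries naming the graph outputs.
    """
    graph = json.loads(Path(symbol_file).read_text())
    nodes = graph["nodes"]
    # Every non-variable node is an operator (layer) in the graph
    layers = [(n["name"], n["op"]) for n in nodes if n["op"] != "null"]
    # The heads are the outputs that Module.predict returns
    heads = [(nodes[i]["name"], nodes[i]["op"], out) for i, out, *_ in graph["heads"]]
    return layers, heads
```

With MXNet installed, a node found this way could then be bound on its own: sym.get_internals() exposes every intermediate output, and if mean is a FullyConnected layer its output would be listed as mean_output, so sym.get_internals()['mean_output'] could be passed to mx.mod.Module in place of sym. Whether mean is the right layer, and what post-processing (if any) turns it into topic weights, are assumptions to check against the NTM documentation.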

EDIT:

I have also tried loading it as a Gluon SymbolBlock, but still no luck:

import warnings
import mxnet as mx
from mxnet import gluon

with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    deserialized_net = gluon.nn.SymbolBlock.imports('model_algo-1-symbol.json', ['data'], 'model_algo-1-0000.params', ctx=mx.cpu())

Error:

AssertionError: Parameter 'n_epoch' is missing in file: model_algo-1-0000.params, which contains parameters: 'logsigma_bias', 'enc_0_bias', 'projection_bias', ..., 'enc_1_weight', 'enc_0_weight', 'mean_bias', 'logsigma_weight'. Please make sure source and target networks have the same prefix.
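The error says the symbol declares an input named n_epoch that has no value in the .params file (the same variable that forced allow_missing=True in the Module code). One way to reason about it: if the node you actually need does not depend on n_epoch, a sub-graph cut at that node should not require it. A minimal stdlib sketch that lists the variable ("null") nodes a given node depends on, assuming the same symbol JSON layout where each entry of a node's "inputs" starts with the source node index; the function name variable_ancestors is hypothetical.

```python
import json
from pathlib import Path


def variable_ancestors(symbol_file, node_name):
    """Return the sorted names of variable nodes (op "null") that node_name depends on.

    Walks the "inputs" edges of an MXNet symbol.json backwards from the node.
    """
    nodes = json.loads(Path(symbol_file).read_text())["nodes"]
    index = {n["name"]: i for i, n in enumerate(nodes)}
    stack, seen, found = [index[node_name]], set(), set()
    while stack:
        i = stack.pop()
        if i in seen:
            continue
        seen.add(i)
        if nodes[i]["op"] == "null":
            found.add(nodes[i]["name"])
        # Each input is [source_node_index, output_index, version]
        stack.extend(src for src, *_ in nodes[i]["inputs"])
    return sorted(found)
```

If "n_epoch" is absent from variable_ancestors("model_algo-1-symbol.json", "mean"), binding only that internal output should sidestep the missing parameter; if it is present, a value for n_epoch would have to be supplied. The node name mean is an assumption taken from the parameter names in the error.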

Any help would be great!

Answer by Yury

SageMaker does not support loading a built-in algorithm model such as NTM locally for inference. Instead, the model can be hosted on a SageMaker endpoint for online (real-time) inference, or used to make batch predictions with a batch transform job.

See more details:

  1. https://docs.aws.amazon.com/sagemaker/latest/dg/deploy-model.html
  2. https://docs.aws.amazon.com/sagemaker/latest/dg/batch-transform.html
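When the model is hosted as described above, the endpoint returns the topic distribution directly instead of it having to be extracted from the graph. A minimal sketch of building a dense JSON request body and reading the per-document weights, assuming the JSON shapes the SageMaker docs describe for built-in algorithms ({"instances": [{"features": [...]}]} in, {"predictions": [{"topic_weights": [...]}]} out for NTM); verify both against the NTM inference format page. The function names are hypothetical, and actually sending the body (for example with the SageMaker runtime invoke_endpoint call and ContentType="application/json") is left out.

```python
import json


def build_ntm_request(bag_of_words_rows):
    """Serialise dense bag-of-words vectors (one list per document) as JSON.

    Assumed request shape: {"instances": [{"features": [...]}, ...]}.
    """
    instances = [{"features": [float(v) for v in row]} for row in bag_of_words_rows]
    return json.dumps({"instances": instances})


def parse_topic_weights(response_body):
    """Extract one topic-weight list per document from an NTM JSON response.

    Assumed response shape: {"predictions": [{"topic_weights": [...]}, ...]}.
    """
    if isinstance(response_body, (bytes, bytearray)):
        response_body = response_body.decode("utf-8")
    predictions = json.loads(response_body)["predictions"]
    return [p["topic_weights"] for p in predictions]
```

Each row passed to build_ntm_request would be a vector of length VOCAB_SIZE, the same input the local predict call receives.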