PyTorch model -> ONNX -> TensorFlow

I made a PyTorch model for machine learning and I want to convert it to a TensorFlow model. I think I have already converted the PyTorch model to ONNX, so now I want to convert the ONNX model to a TensorFlow (.pb) model.

Here is my code.

import onnx
from onnx_tf.backend import prepare
onnx_model = onnx.load("./sales_predict_model.onnx")  # load onnx model
tf_rep = prepare(onnx_model)  # prepare tf representation
tf_rep.export_graph("sales_predict_model.pb")  # export the model

This is the error I got:

AssertionError: Tried to export a function which references untracked object Tensor("1076:0", shape=(), dtype=resource).
TensorFlow objects (e.g. `tf.Variable`) captured by functions must be tracked by assigning them to an attribute of a tracked object or assigned to an attribute of the main object directly.

I am using TensorFlow version 1.14.0.

ONNX is probably version 1.7.0 (I checked with pip show onnx).

onnx-tf is version 1.6.0 (checked with pip show onnx-tf).
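
Because pip show can point at a different environment than the one that actually runs the conversion, it may help to print the versions from inside the same interpreter. This is a minimal sketch using only the standard library (importlib.metadata, Python 3.8+). The helper name report_versions is hypothetical, and the distribution names are assumptions (for example, a GPU build of TensorFlow 1.x is usually installed as tensorflow-gpu).

```python
from importlib import metadata


def report_versions(packages=("tensorflow", "onnx", "onnx-tf", "torch")):
    """Return {distribution name: installed version, or None if not installed}.

    report_versions is a hypothetical helper; the default names are assumptions.
    """
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Not installed in the environment of *this* interpreter.
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in report_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

Running this with the same python that runs the onnx-tf conversion shows which TensorFlow that script actually imports.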

Here is the code I used to convert the PyTorch model to ONNX:

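import torch
import torch.nn as nn

class LSTM(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers, output_dim):
        super(LSTM, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        # batch_first=True: input and output are (batch, seq_len, features)
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)

        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        # x.to() keeps x tied to the graph input; torch.tensor(x) on an existing
        # tensor copies it and is recorded as a constant when tracing for ONNX export.
        x = x.to(torch.float32)

        # Zero initial hidden/cell states on the input's device and dtype.
        # They need no gradient, so requires_grad_() followed by detach() is unnecessary.
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_dim, device=x.device, dtype=x.dtype)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_dim, device=x.device, dtype=x.dtype)

        # Run the LSTM over the whole sequence.
        out, (hn, cn) = self.lstm(x, (h0, c0))

        out = out[:, -1, :]  # keep only the last time step: (batch, hidden_dim)
        out = self.fc(out)   # (batch, output_dim)
        return out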

    
PATH = './model/file300input5lb1ep100drop0W_V7' 

model = torch.load(PATH + 'model.pt')
model.eval()

x = torch.randn(1, 1, 5, requires_grad=True)
torch_out = model(x)

torch.onnx.export(model,
                  x,
                  "sales_predict_model.onnx",
                  export_params=True,
                  opset_version=10,
                  do_constant_folding=True,
                  input_names=['input'],
                  output_names=['output'],  # must match the 'output' key in dynamic_axes
                  dynamic_axes={'input': {0: 'batch_size'},
                                'output': {0: 'batch_size'}})
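
In torch.onnx.export, each key of dynamic_axes has to be one of the names passed in input_names or output_names. As far as I know, a key that matches neither only produces a warning, and that axis then keeps the fixed size of the dummy input. This is a small pure-Python pre-flight check I could run before exporting. The helper unmatched_dynamic_axes is hypothetical and is not part of the PyTorch API.

```python
def unmatched_dynamic_axes(dynamic_axes, input_names, output_names):
    """Return the dynamic_axes keys that name neither an input nor an output.

    unmatched_dynamic_axes is a hypothetical helper, not a PyTorch function.
    """
    valid = set(input_names) | set(output_names)
    return sorted(key for key in dynamic_axes if key not in valid)


if __name__ == "__main__":
    axes = {"input": {0: "batch_size"}, "output": {0: "batch_size"}}
    print(unmatched_dynamic_axes(axes, ["input"], ["x"]))       # ['output']
    print(unmatched_dynamic_axes(axes, ["input"], ["output"]))  # []
```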

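To confirm which axis is the batch axis on the input and on the output, I traced the tensor shapes through the LSTM model above. This sketch assumes batch_first=True and a unidirectional nn.LSTM, which matches the class definition. The hidden_dim, num_layers and output_dim values in the usage are hypothetical placeholders, since the real ones are stored inside the saved model.

```python
def lstm_model_shapes(batch, seq_len, input_dim, hidden_dim, num_layers, output_dim):
    """Tensor shapes through the LSTM model above.

    Assumes batch_first=True and a unidirectional nn.LSTM.
    """
    return {
        "x": (batch, seq_len, input_dim),           # model input
        "h0, c0": (num_layers, batch, hidden_dim),  # initial states (not batch-first)
        "lstm out": (batch, seq_len, hidden_dim),   # every time step
        "out[:, -1, :]": (batch, hidden_dim),       # last time step only
        "fc out": (batch, output_dim),              # final prediction
    }


if __name__ == "__main__":
    # The dummy input torch.randn(1, 1, 5) fixes batch=1, seq_len=1, input_dim=5.
    # hidden_dim=32, num_layers=2, output_dim=1 are hypothetical placeholders.
    for name, shape in lstm_model_shapes(1, 1, 5, 32, 2, 1).items():
        print(f"{name}: {shape}")
```

Axis 0 is the batch dimension of both the model input and the final output, so {0: 'batch_size'} is the axis to mark as dynamic on both. The initial states carry the batch on axis 1, but they are created inside forward and are not graph inputs.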