I built a PyTorch machine-learning model and want to convert it to a TensorFlow model. I believe I have already converted the PyTorch model to ONNX, so now I want to convert the ONNX model to a TensorFlow (.pb) model.
Here is my code.
import onnx
from onnx_tf.backend import prepare
# Convert the exported ONNX model into a TensorFlow graph via the onnx-tf backend.
onnx_model = onnx.load("./sales_predict_model.onnx")  # load onnx model
# Validate the ONNX graph before handing it to the TF backend, so a malformed
# export fails here with a clear checker message instead of deep inside onnx-tf.
onnx.checker.check_model(onnx_model)
tf_rep = prepare(onnx_model)  # prepare tf representation
# NOTE(review): what export_graph writes depends on the installed onnx-tf
# version (a single frozen .pb file vs. a SavedModel directory) -- confirm
# before relying on the ".pb" suffix.
tf_rep.export_graph("sales_predict_model.pb")  # export the model
Running it fails with the following error:
AssertionError: Tried to export a function which references untracked object Tensor("1076:0", shape=(), dtype=resource).
TensorFlow objects (e.g. `tf.Variable`) captured by functions must be tracked by assigning them to an attribute of a tracked object or assigned to an attribute of the main object directly.
I am using TensorFlow version 1.14.0.
I believe the ONNX version is 1.7.0 (checked with `pip show onnx`).
The onnx-tf version is 1.6.0 (checked with `pip show onnx-tf`).
Below is the code I used to convert the PyTorch model to ONNX.
class LSTM(nn.Module):
    """Stacked LSTM followed by a linear layer applied to the last time step.

    Args:
        input_dim: Number of features per time step.
        hidden_dim: Hidden-state size of each LSTM layer.
        num_layers: Number of stacked LSTM layers.
        output_dim: Size of the final linear output.

    The LSTM is built with ``batch_first=True``, so ``forward`` expects
    input shaped ``(batch, seq_len, input_dim)``.
    """

    def __init__(self, input_dim, hidden_dim, num_layers, output_dim):
        super(LSTM, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Run the LSTM over ``x`` and project the output of the last time step.

        Args:
            x: Input tensor of shape ``(batch, seq_len, input_dim)``; it is
                cast to float32.

        Returns:
            A float32 tensor of shape ``(batch, output_dim)``.
        """
        # Cast with ``.to`` instead of ``torch.tensor(x, ...)``: torch.tensor on an
        # existing tensor makes a detached copy, which the JIT tracer (and thus
        # torch.onnx.export) records as a constant -- the exported graph would
        # then ignore its input entirely.
        x = x.to(torch.float32)
        # Zero initial states on the same device/dtype as the input. They never
        # need gradients, so no requires_grad_()/detach() is involved.
        state_shape = (self.num_layers, x.size(0), self.hidden_dim)
        h0 = torch.zeros(state_shape, dtype=x.dtype, device=x.device)
        c0 = torch.zeros(state_shape, dtype=x.dtype, device=x.device)
        # A single LSTM pass; the final (h, c) states are not needed.
        out, _ = self.lstm(x, (h0, c0))
        # Keep only the last time step's features for the linear layer.
        return self.fc(out[:, -1, :])
# Load the trained PyTorch model and export it to ONNX.
# NOTE(review): PATH + 'model.pt' concatenates without a separator, i.e. it
# loads './model/file300input5lb1ep100drop0W_V7model.pt'. If 'model.pt' is
# meant to live inside that directory, use os.path.join(PATH, 'model.pt').
PATH = './model/file300input5lb1ep100drop0W_V7'
# Presumably a whole pickled module, so the LSTM class must be importable here.
model = torch.load(PATH + 'model.pt')
model.eval()  # switch to inference mode before tracing

# Dummy input used only to trace the graph. With batch_first=True this is
# (batch=1, seq_len=1, features=5) -- assumes the model's input_dim is 5.
x = torch.randn(1, 1, 5, requires_grad=True)
torch_out = model(x)  # eager reference output, e.g. to compare with the converted model

torch.onnx.export(
    model,
    x,
    "sales_predict_model.onnx",
    export_params=True,  # store the trained weights inside the .onnx file
    opset_version=10,
    do_constant_folding=True,
    input_names=['input'],
    # Must match the 'output' key in dynamic_axes; otherwise that entry is
    # ignored and the output's batch dimension is exported as a fixed size.
    output_names=['output'],
    dynamic_axes={'input': {0: 'batch_size'},
                  'output': {0: 'batch_size'}},
)