Tensor format issue when converting PyTorch -> ONNX -> TensorFlow


I have an issue with a TensorFlow model that was converted PyTorch -> ONNX -> TensorFlow. The converted TensorFlow model expects its input in the PyTorch layout, (batch size, number of channels, height, width), i.e. NCHW, rather than the TensorFlow layout, (batch size, height, width, number of channels), i.e. NHWC. Because of this, I cannot process the model further with Vitis AI.

Is there any way to convert this PyTorch input format to the TensorFlow format using tools from ONNX, TensorFlow 1, or others?
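For reference, the two layouts differ only in axis order, so converting a single array from one to the other is a single transpose. A minimal NumPy sketch (the helper names nhwc_to_nchw and nchw_to_nhwc are hypothetical, for illustration only):

```python
import numpy as np


def nhwc_to_nchw(x: np.ndarray) -> np.ndarray:
    """(batch, height, width, channels) -> (batch, channels, height, width)."""
    return np.transpose(x, (0, 3, 1, 2))


def nchw_to_nhwc(x: np.ndarray) -> np.ndarray:
    """(batch, channels, height, width) -> (batch, height, width, channels)."""
    return np.transpose(x, (0, 2, 3, 1))


# A small stand-in batch in TensorFlow layout: 2 images, 4x5 pixels, 3 channels.
batch_nhwc = np.arange(2 * 4 * 5 * 3, dtype=np.float32).reshape(2, 4, 5, 3)
batch_nchw = nhwc_to_nchw(batch_nhwc)
print(batch_nchw.shape)  # (2, 3, 4, 5)

# The two permutations are inverses of each other, so the round trip is lossless.
assert np.array_equal(nchw_to_nhwc(batch_nchw), batch_nhwc)
```

Transposing the data like this only changes what is fed in; the converted graph itself still declares an NCHW input, which is why the answer below moves the transpose into the model before export.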

My code is as below:

PyTorch -> ONNX

from hardnet import hardnet
import torch
import onnx

ckpt = torch.load('../hardnet.pth')
model_state_dict = ckpt['model_state_dict']
optimizer_state_dict = ckpt['optimizer_state_dict']

model = hardnet(11)
model.load_state_dict(model_state_dict)
model.eval()

# The dummy input is NCHW, so the exported ONNX input is fixed to 1x3x1080x1920.
dummy_input = torch.randn(1, 3, 1080, 1920)
input_names = ['input0']
output_names = ['output0']

output_file = 'hardnet.onnx'
torch.onnx.export(model, dummy_input, output_file, verbose=True,
    input_names=input_names, output_names=output_names,
    opset_version=11, keep_initializers_as_inputs=True)

onnx_model = onnx.load(output_file)
onnx.checker.check_model(onnx_model)
print('Passed Onnx')

ONNX -> TensorFlow 1 (using TensorFlow 1.15)

import cv2
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import onnx
from onnx_tf.backend import prepare

output_file = 'hardnet.onnx'
onnx_model = onnx.load(output_file)
tf_rep = prepare(onnx_model)  # ONNX model -> TensorFlow representation
tf_rep.export_graph('hardnet.pb')
tf.compat.v1.disable_eager_execution()

def load_pb(path_to_pb: str):
    """From: https://stackoverflow.com/questions/51278213/what-is-the-use-of-a-pb-file-in-tensorflow-and-how-does-it-work
    """
    with tf.gfile.GFile(path_to_pb, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name='')
        return graph


graph = load_pb('hardnet.pb')
input_tensor = graph.get_tensor_by_name('input0:0')
output_tensor = graph.get_tensor_by_name('output0:0')

# ImageNet mean/std (RGB order); note that cv2.imread returns channels in BGR order.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
img = cv2.imread('train_0.jpg', cv2.IMREAD_COLOR)
img = cv2.resize(img, (1920, 1080))  # dsize is (width, height)

img = img / 255
img = (img - mean) / std
img = np.expand_dims(img, -1)
# To PyTorch (NCHW) format: (H, W, C, 1) -> (1, C, H, W).
img = np.transpose(img, (3, 2, 0, 1)).astype(np.float32)

with tf.Session(graph=graph) as sess:
    pred = sess.run(output_tensor, {input_tensor: img})
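In the preprocessing above, np.expand_dims(img, -1) followed by np.transpose(img, (3, 2, 0, 1)) turns an (H, W, C) image into a (1, C, H, W) batch, which is the same as one transpose plus a leading batch axis. A NumPy-only sketch of the same normalization written that way; the function name preprocess_to_nchw is hypothetical, the mean/std values are the ones from the question, and the cast to float32 is an assumption chosen to match the FLOAT input of the exported graph. Whether a BGR-to-RGB swap is needed (cv2.imread returns BGR) depends on how the model was trained, so it is only an option here.

```python
import numpy as np

# ImageNet statistics as used in the question (per channel, RGB order).
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def preprocess_to_nchw(img_hwc: np.ndarray, bgr_to_rgb: bool = False) -> np.ndarray:
    """Normalize an (H, W, 3) uint8 image and return a (1, 3, H, W) float32 batch."""
    if bgr_to_rgb:
        img_hwc = img_hwc[..., ::-1]  # reverse channel order: BGR -> RGB
    x = img_hwc.astype(np.float32) / 255.0
    x = (x - MEAN) / STD              # broadcasts over the last (channel) axis
    x = np.transpose(x, (2, 0, 1))    # (H, W, C) -> (C, H, W)
    return x[np.newaxis]              # add batch axis -> (1, C, H, W)


# Compare against the question's expand_dims + transpose(3, 2, 0, 1) sequence.
rng = np.random.default_rng(0)
img = rng.integers(0, 256, size=(4, 6, 3)).astype(np.uint8)

old = (img / 255 - [0.485, 0.456, 0.406]) / [0.229, 0.224, 0.225]
old = np.transpose(np.expand_dims(old, -1), (3, 2, 0, 1))

new = preprocess_to_nchw(img)
print(new.shape, new.dtype)  # (1, 3, 4, 6) float32
assert np.allclose(old, new, atol=1e-5)
```

Both versions produce the same NCHW tensor; the single transpose just makes the target layout easier to read.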

Best answer (by Proko):

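You could wrap your PyTorch model in another module that performs the transpose you want as its first operation, so that the exported graph (and therefore the converted TensorFlow model) accepts the TensorFlow layout. See the following example.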

Let's say you have the following toy NN:

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.rnn = nn.LSTM(10, 20, 2)

    def forward(self, x):
        h0 = torch.zeros(2, 3, 20)
        c0 = torch.zeros(2, 3, 20)
        return self.rnn(x, (h0, c0))

The example PyTorch and TensorFlow inputs would then be:

>>> pytorch_input = torch.randn(5, 3, 10)
>>> tf_input = torch.transpose(pytorch_input, 1, 2)

>>> print("PyTorch input shape: ", pytorch_input.shape)
>>> print("TensorFlow input shape: ", tf_input.shape)

PyTorch input shape:  torch.Size([5, 3, 10])
TensorFlow input shape:  torch.Size([5, 10, 3])

Next, a wrapper that first transposes the input and then passes the transposed input to the wrapped model:

class NetTensorFlowWrapper(nn.Module):
    def __init__(self, main_module: nn.Module):
        super(NetTensorFlowWrapper, self).__init__()
        self.main_module = main_module
        
    def forward(self, x):
        x = torch.transpose(x, 1, 2)
        return self.main_module(x)

Both calls below then work and give the same result, because the wrapper transposes tf_input back before calling net:

net = Net()
net_wrapper = NetTensorFlowWrapper(net)

net(pytorch_input)
net_wrapper(tf_input)

Then, when you export both models with torch.onnx.export as before and print their graphs with the onnx package (not torch.onnx), you get:

  • For Net: input 5x3x10 and no Transpose node
    graph torch-jit-export (
      %input0[FLOAT, 5x3x10]
    {
      %76 = Shape(%input0)
      %77 = Constant[value = <Scalar Tensor []>]()
      ...
  • For NetTensorFlowWrapper: input 5x10x3 and a Transpose node
    graph torch-jit-export (
      %input0[FLOAT, 5x10x3]
    {
      %9 = Transpose[perm = [0, 2, 1]](%input0)
      %77 = Shape(%9)
      %78 = Constant[value = <Scalar Tensor []>]()
      ...
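Applied to the question's image model, the same idea needs a 4-D permute rather than a swap of two axes: the wrapper receives (batch, height, width, channels) and hands (batch, channels, height, width) to the original network. A minimal sketch, assuming a hypothetical TinyConvNet as a stand-in for hardnet; the NHWCInputWrapper name and the export file name in the commented-out line are hypothetical as well. Exporting the wrapper with an NHWC dummy input should give an ONNX graph whose input is declared as NHWC and whose first node is a Transpose with perm = [0, 3, 1, 2].

```python
import torch
import torch.nn as nn


class TinyConvNet(nn.Module):
    """Hypothetical stand-in for hardnet: expects NCHW input like any nn.Conv2d model."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)


class NHWCInputWrapper(nn.Module):
    """Accepts (N, H, W, C) input and feeds (N, C, H, W) to the wrapped model."""

    def __init__(self, main_module: nn.Module):
        super().__init__()
        self.main_module = main_module

    def forward(self, x):
        x = x.permute(0, 3, 1, 2)  # NHWC -> NCHW
        return self.main_module(x)


torch.manual_seed(0)
net = TinyConvNet().eval()
wrapped = NHWCInputWrapper(net).eval()

nchw_input = torch.randn(1, 3, 16, 24)
nhwc_input = nchw_input.permute(0, 2, 3, 1)  # same data, TensorFlow layout

with torch.no_grad():
    out_ref = net(nchw_input)
    out_wrapped = wrapped(nhwc_input)

print(nhwc_input.shape)  # torch.Size([1, 16, 24, 3])
assert torch.allclose(out_ref, out_wrapped)

# Exporting the wrapper instead of the bare model (sketch, hypothetical file name):
# torch.onnx.export(wrapped, torch.randn(1, 1080, 1920, 3), 'hardnet_nhwc.onnx',
#                   input_names=['input0'], output_names=['output0'], opset_version=11)
```

The wrapper only changes the input layout; the output of the wrapped model is still NCHW (here 1x8x16x24). If the downstream tooling also expects NHWC outputs, a permute(0, 2, 3, 1) on the result inside forward would handle that the same way.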