Clarification on understanding TorchScript and JIT in PyTorch

I want to check my understanding of how the JIT and TorchScript work, and get some clarity on one particular example.

If I'm not mistaken, torch.jit.script converts my function or module to TorchScript. I can use the compiled TorchScript module in an environment outside Python, but I can also use it from within Python, supposedly with performance improvements and optimizations. torch.jit.trace follows roughly the same idea, except that instead of compiling the source code it records (traces) the operations that run when the module is called on an example input.
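A minimal sketch of the script/trace difference described above, using a hypothetical `Gate` module with a data-dependent `if`. `torch.jit.script` compiles both branches, so the branch is chosen at run time; `torch.jit.trace` only records the branch that executed for the example input, and it emits a `TracerWarning` to say so.

```python
import torch
import torch.nn as nn


class Gate(nn.Module):
    # Hypothetical toy module whose output depends on a data-dependent branch.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.sum() > 0:
            return x * 2
        else:
            return x - 1


model = Gate()
pos = torch.ones(3)
neg = torch.full((3,), -3.0)

# torch.jit.script compiles the Python source, so both branches are kept.
scripted = torch.jit.script(model)

# torch.jit.trace runs the module once on `pos` and records only the ops that
# actually executed (the `x * 2` branch); it warns about the tensor-to-bool step.
traced = torch.jit.trace(model, pos)

print(scripted(neg))  # branch chosen at run time: x - 1
print(traced(neg))    # recorded path replayed: x * 2, although the sum is negative
```

For a straight chain of layers like `TrialC`, there is no control flow, so scripting and tracing should produce equivalent graphs.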

If this is the case, the TorchScript module should, in general, be at least as fast as regular inference through the Python interpreter. When experimenting, however, I observed that it is usually slower than regular inference. After reading up a bit, I found that a TorchScript module apparently needs to be "warmed up" with a few calls before it reaches its best performance. Doing so helped somewhat, but not enough to count as an improvement over the usual approach (plain eager-mode Python). I also tried a third-party library, torch_tvm, which, when enabled, supposedly halves the inference time for any way of JIT-ing the module.
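To see the warm-up effect directly, the sketch below times each of the first few calls separately for an eager model and its scripted copy. Assumptions: a smaller hypothetical two-layer `nn.Sequential` stands in for `TrialC`, `per_call_times` is a made-up helper, and `time.perf_counter` is used because it is Python's clock for measuring short intervals. In recent PyTorch versions the default (profiling) executor spends the first call or two of a scripted module collecting profiling information and optimizing the graph, so those calls are expected to be slower; how many calls this takes depends on the version.

```python
import time

import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical, smaller stand-in for TrialC so the sketch runs quickly.
model = nn.Sequential(nn.Linear(256, 512), nn.Linear(512, 256)).eval()
scripted = torch.jit.script(model)
x = torch.randn(1, 256)


def per_call_times(fn, inp, n_calls=5):
    """Return the wall-clock seconds taken by each of the first n_calls calls."""
    times = []
    with torch.no_grad():  # inference only, no autograd bookkeeping
        for _ in range(n_calls):
            start = time.perf_counter()
            fn(inp)
            times.append(time.perf_counter() - start)
    return times


eager_times = per_call_times(model, x)
script_times = per_call_times(scripted, x)
for i, (e, s) in enumerate(zip(eager_times, script_times), start=1):
    print(f"call {i}: eager {e * 1e3:.3f} ms | scripted {s * 1e3:.3f} ms")
```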

None of these speedups has materialized so far, and I can't work out why.

Here is my sample code, in case I've done something wrong:

```python
import time

import torch
import torch.nn as nn
import torch_tvm  # third-party package from the pytorch/tvm project


class TrialC(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(1024, 2048)
        self.l2 = nn.Linear(2048, 4096)
        self.l3 = nn.Linear(4096, 4096)
        self.l4 = nn.Linear(4096, 2048)
        self.l5 = nn.Linear(2048, 1024)

    def forward(self, input):
        out = self.l1(input)
        out = self.l2(out)
        out = self.l3(out)
        out = self.l4(out)
        out = self.l5(out)
        return out


if __name__ == '__main__':
    # Trial inference input
    TrialC_input = torch.randn(1, 1024)
    warmup = 10

    # Record the time of a single eager-mode inference (no warm-up)
    model = TrialC()
    start = time.time()
    model_out = model(TrialC_input)
    elapsed = time.time() - start

    # Record the 10th inference time of the scripted model; each loop
    # iteration overwrites elapsed_2, so the first 9 calls act as warm-up
    script_model = torch.jit.script(TrialC())
    for i in range(warmup):
        start_2 = time.time()
        model_out_check_2 = script_model(TrialC_input)
        elapsed_2 = time.time() - start_2

    # Record the 10th inference time of the traced model with TVM enabled
    torch_tvm.enable()
    script_model_2 = torch.jit.trace(TrialC(), torch.randn(1, 1024))
    for i in range(warmup):
        start_3 = time.time()
        model_out_check_3 = script_model_2(TrialC_input)
        elapsed_3 = time.time() - start_3

    print("Regular model inference time: {}s\nJIT compiler inference time: {}s\nJIT Compiler with TVM: {}s".format(elapsed, elapsed_2, elapsed_3))
```

Running the code above on my CPU gives the following results:

Regular model inference time: 0.10335588455200195s
JIT compiler inference time: 0.11449170112609863s
JIT Compiler with TVM: 0.10834860801696777s
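The benchmark above compares one cold eager-mode call with the 10th call of each JIT model, and each figure is a single measurement, which run-to-run noise can noticeably affect. A steadier comparison warms up both models equally and reports a median over many runs. Below is a sketch using `torch.utils.benchmark.Timer`, which ships with PyTorch; the reduced layer sizes are an assumption to keep it quick, torch_tvm is left out, and `num_threads` is set explicitly because `Timer` defaults to a single thread.

```python
import torch
import torch.nn as nn
import torch.utils.benchmark as benchmark

torch.manual_seed(0)

# Hypothetical, smaller stand-in for TrialC.
model = nn.Sequential(nn.Linear(1024, 2048), nn.Linear(2048, 1024)).eval()
scripted = torch.jit.script(model)
x = torch.randn(1, 1024)

results = {}
for label, fn in [("eager", model), ("scripted", scripted)]:
    with torch.no_grad():
        # Same warm-up for both models (covers the scripted profiling runs).
        for _ in range(10):
            fn(x)
        timer = benchmark.Timer(
            stmt="fn(x)",
            globals={"fn": fn, "x": x},
            num_threads=torch.get_num_threads(),
        )
        # Repeats the statement until at least min_run_time seconds elapse.
        m = timer.blocked_autorange(min_run_time=0.2)
    results[label] = m.median
    print(f"{label}: median {m.median * 1e3:.3f} ms over {len(m.times)} measurements")
```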

Any help or clarity on this would really be appreciated!
