Mixed Precision Training using JAX

I'm trying to understand how Haiku achieved a 2x speedup when training ResNet50 on ImageNet (https://github.com/deepmind/dm-haiku/tree/main/examples/imagenet) using the DeepMind JMP library (https://github.com/deepmind/jmp), and how to replicate this with other networks.

1- If we compare the time needed for a matrix multiplication in float32 and float16 on a GPU, we barely see a 5% speedup. How can this become a 2x speedup as we scale up the number of operations?
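
A minimal JAX sketch of the comparison described in point 1, assuming a GPU backend and square matrices of a few arbitrary sizes (the helper `time_matmul` is hypothetical). It jit-compiles the matmul, excludes compilation with a warm-up call, and calls `block_until_ready()` because JAX dispatches work asynchronously, so timing without it would mostly measure dispatch. How large the gap gets depends on the hardware (float16 matmuls get their big boost on GPUs with Tensor Cores) and on the matrix size, since small matmuls tend to be dominated by launch overhead rather than arithmetic.

```python
import time

import jax
import jax.numpy as jnp


def time_matmul(n, dtype, iters=10):
    """Mean wall-clock seconds for one jitted n x n matmul in `dtype`."""
    key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
    a = jax.random.normal(key_a, (n, n), dtype=dtype)
    b = jax.random.normal(key_b, (n, n), dtype=dtype)
    matmul = jax.jit(jnp.matmul)
    matmul(a, b).block_until_ready()  # warm-up: excludes compilation time
    start = time.perf_counter()
    for _ in range(iters):
        # Wait for the result, otherwise we only time the async dispatch.
        matmul(a, b).block_until_ready()
    return (time.perf_counter() - start) / iters


if __name__ == "__main__":
    # The comparison is only meaningful on an accelerator.
    if jax.default_backend() != "gpu":
        print("No GPU backend found; skipping benchmark.")
    else:
        for n in (256, 1024, 4096):  # arbitrary sizes, all multiples of 8
            t32 = time_matmul(n, jnp.float32)
            t16 = time_matmul(n, jnp.float16)
            print(f"n={n}: fp32 {t32 * 1e3:.2f} ms, fp16 {t16 * 1e3:.2f} ms, "
                  f"ratio {t32 / t16:.2f}x")
```

On a Tensor Core GPU one would expect the ratio to grow with `n` once the matmul becomes compute-bound. Note also that on Ampere or newer GPUs, float32 matmuls may already run on Tensor Cores in TF32 mode depending on the `jax_default_matmul_precision` setting, which would shrink the measured gap.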

2- Why do we also need to apply mixed precision to the network? If your data and parameters are in float16, aren't all ops inside the neural network in float16 too?
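
The jmp README describes a mixed precision policy as three dtypes: one for the parameters, one for computation, and one for the output. The sketch below reproduces that idea in plain JAX without jmp, so it is an illustration, not jmp's or Haiku's implementation; all names (`PARAM_DTYPE`, `forward`, `sgd_step`, ...) and the static loss scale of `2**10` are assumptions. The point it shows: weights are kept in float32 and only cast to float16 for the matmuls, because small updates are lost if the master weights themselves live in float16, and the loss is scaled so that small float16 gradients do not underflow. jmp itself also offers dynamic loss scaling (`jmp.DynamicLossScale`).

```python
import jax
import jax.numpy as jnp

# Hypothetical policy mirroring jmp's param/compute/output split.
PARAM_DTYPE = jnp.float32    # master copy of the weights
COMPUTE_DTYPE = jnp.float16  # dtype of matmuls and activations
OUTPUT_DTYPE = jnp.float32   # dtype of the network output and loss
LOSS_SCALE = 2.0 ** 10       # static loss scale (assumed value)


def init_params(key, sizes):
    """One (weight, bias) pair per layer, stored in PARAM_DTYPE."""
    keys = jax.random.split(key, len(sizes) - 1)
    params = []
    for k, m, n in zip(keys, sizes[:-1], sizes[1:]):
        w = jax.random.normal(k, (m, n), dtype=PARAM_DTYPE) / m ** 0.5
        params.append((w, jnp.zeros((n,), PARAM_DTYPE)))
    return params


def forward(params, x):
    # Cast the fp32 master weights and the inputs down to the compute dtype.
    params = jax.tree_util.tree_map(lambda p: p.astype(COMPUTE_DTYPE), params)
    h = x.astype(COMPUTE_DTYPE)
    for w, b in params[:-1]:
        h = jax.nn.relu(h @ w + b)
    w, b = params[-1]
    # Cast the logits back up before the numerically sensitive loss.
    return (h @ w + b).astype(OUTPUT_DTYPE)


def loss_fn(params, x, y):
    logits = forward(params, x)
    # Softmax and reduction run in fp32, where rounding error is much smaller.
    return -jnp.mean(jnp.sum(jax.nn.log_softmax(logits) * y, axis=-1))


@jax.jit
def sgd_step(params, x, y, lr=1e-2):
    # Scale the loss so small fp16 gradients do not flush to zero...
    scaled = jax.grad(lambda p: loss_fn(p, x, y) * LOSS_SCALE)(params)
    # ...then unscale. Grads come back in fp32 because the params are fp32.
    grads = jax.tree_util.tree_map(lambda g: g / LOSS_SCALE, scaled)
    # The update is applied to the fp32 master weights.
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)


if __name__ == "__main__":
    key_p, key_x = jax.random.split(jax.random.PRNGKey(0))
    params = init_params(key_p, [32, 64, 10])
    x = jax.random.normal(key_x, (16, 32))
    y = jax.nn.one_hot(jnp.arange(16) % 10, 10)
    for step in range(3):
        params = sgd_step(params, x, y)
        print(step, float(loss_fn(params, x, y)))
    # Why master weights stay fp32: fp16 spacing near 1.0 is 2**-10, so
    # 1 + 1e-4 rounds back to exactly 1.0 and the update is lost.
    print(jnp.float16(1.0) + jnp.float16(1e-4))
```

So even when the inputs are float16, a mixed precision setup still decides per operation which dtype to use: casts in the forward pass, float32 for sensitive reductions, and float32 for the weight update.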

3- Can I hope to see any speedup with a small fully connected network? What about a deep fully connected network? Or only with big vision networks specifically optimized for this?
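
Point 3 can be checked empirically for a given architecture. The sketch below times the jitted forward pass of a plain bias-free ReLU MLP in float32 and float16 for three made-up shapes (small, deep but narrow, and wide). The helpers `mlp_forward` and `time_mlp` and all sizes are hypothetical, and only the forward pass is timed, not a full training step.

```python
import time

import jax
import jax.numpy as jnp


def mlp_forward(weights, x):
    """Plain ReLU MLP without biases; every layer is a square matmul."""
    for w in weights:
        x = jax.nn.relu(x @ w)
    return x


def time_mlp(width, depth, batch, dtype, iters=10):
    """Mean seconds per jitted forward pass for the given shape and dtype."""
    key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
    weights = [
        jax.random.normal(k, (width, width), dtype=dtype) * width ** -0.5
        for k in jax.random.split(key_w, depth)
    ]
    x = jax.random.normal(key_x, (batch, width), dtype=dtype)
    fwd = jax.jit(mlp_forward)
    fwd(weights, x).block_until_ready()  # warm-up excludes compile time
    start = time.perf_counter()
    for _ in range(iters):
        fwd(weights, x).block_until_ready()
    return (time.perf_counter() - start) / iters


if __name__ == "__main__":
    # Hypothetical (width, depth, batch) configurations.
    configs = {"small": (128, 3, 64), "deep": (128, 48, 64),
               "wide": (4096, 4, 2048)}
    if jax.default_backend() != "gpu":
        print("No GPU backend found; skipping benchmark.")
    else:
        for name, (width, depth, batch) in configs.items():
            t32 = time_mlp(width, depth, batch, jnp.float32)
            t16 = time_mlp(width, depth, batch, jnp.float16)
            print(f"{name}: fp32/fp16 time ratio {t32 / t16:.2f}")
```

Adding depth multiplies the number of matmuls but not their size, so one would expect a deep narrow MLP to stay closer to the small case than to the wide one; the wide configuration gives each matmul enough work to be compute-bound.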
