I'm trying to understand how Haiku achieved a 2x speedup when training ResNet50 on ImageNet (https://github.com/deepmind/dm-haiku/tree/main/examples/imagenet) using DeepMind's JMP library (https://github.com/deepmind/jmp), and how to replicate this with other networks.
1- If we compare the time needed for a matrix multiplication in float16 versus float32 on a GPU, we barely see a 5% speedup. How can this become a 2x speedup as we scale the number of operations? (See the benchmark sketch below for what I mean.)
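For reference, here is a minimal sketch of the kind of micro-benchmark I mean; the matrix size and iteration count are arbitrary choices of mine, not anything from the Haiku example:

```python
import time

import jax
import jax.numpy as jnp


def bench_matmul(dtype, n=4096, iters=100):
    """Time an n x n matmul in the given dtype on the default device."""
    key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
    a = jax.random.normal(key_a, (n, n), dtype=dtype)
    b = jax.random.normal(key_b, (n, n), dtype=dtype)
    matmul = jax.jit(jnp.matmul)
    matmul(a, b).block_until_ready()  # warm-up: trigger compilation first
    start = time.perf_counter()
    for _ in range(iters):
        out = matmul(a, b)
    out.block_until_ready()  # wait for JAX's async dispatch to finish
    return (time.perf_counter() - start) / iters


print("float32:", bench_matmul(jnp.float32))
print("float16:", bench_matmul(jnp.float16))
```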
2- Why do we need to apply a mixed-precision policy to the network itself? If your data and parameters are in float16, aren't all the ops inside the neural network in float16 too?
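For context, this is roughly how I understand the linked ResNet50 example wires up JMP policies in Haiku; treat it as a sketch of my understanding, not the exact example code:

```python
import haiku as hk
import jmp

# Keep parameters in full precision, run the compute in half precision,
# and return full-precision outputs.
policy = jmp.get_policy("params=float32,compute=float16,output=float32")
hk.mixed_precision.set_policy(hk.nets.ResNet50, policy)

# Batch norm is typically kept in float32 for numerical stability.
bn_policy = jmp.get_policy("params=float32,compute=float32,output=float16")
hk.mixed_precision.set_policy(hk.BatchNorm, bn_policy)
```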
3- Can I hope to see any speedup with a small fully connected network? With a deep fully connected network? Or only with big vision networks specifically optimized for this?