XLA allocates 4G of memory for the tensor below, and the size appears to scale with the batch size. This doesn't make sense to me: it doesn't look like part of the model graph that should be stored in HBM. I am using a TPUv3.
I don't use any random operations apart from the initialization of the model. Moreover, I declared bfloat16 for all weights, yet this is a u32 tensor.
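For context, this is roughly what I mean by "bfloat16 for all weights" plus a global environment flag (a minimal sketch; torch_xla and `XLA_USE_BF16` are my assumption of the relevant setup, not code from the linked repo). The allocation report follows below.

```python
# Minimal sketch of the bfloat16 setup (assumption: torch_xla on TPU; not code from the repo).
import os
os.environ["XLA_USE_BF16"] = "1"          # global flag: float32 tensors are kept as bfloat16 on the TPU

import torch.nn as nn
import torch_xla.core.xla_model as xm

device = xm.xla_device()                  # the TPUv3 device
model = nn.Linear(1024, 1024).to(device)  # stand-in for the PixelSNAIL model; weights live as bfloat16 in HBM
```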
```
Largest program allocations in hbm:
1. Size: 4.00G
Shape: u32[128,8,1024,1024]{3,2,1,0:T(8,128)}
Unpadded size: 4.00G
XLA label: %rng-bit-generator = (u32[2,128]{1,0:T(2,128)}, u32[128,8,1024,1024]{3,2,1,0:T(8,128)}) rng-bit-generator(u32[2,128]{1,0:T(2,128)} %fusion.2446), algorithm=rng_default
Allocation type: HLO temp
==========================
```
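For reference, the 4.00G figure is exactly what that shape implies for a u32 tensor, and (assuming the leading dimension of 128 is the batch) it would grow linearly with the batch size, which matches what I observe:

```python
# Quick size check for the reported u32[128, 8, 1024, 1024] allocation
# (assumption: the leading dimension 128 is the batch size).
batch, dim1, dim2, dim3 = 128, 8, 1024, 1024
bytes_per_elem = 4                                    # u32 = 4 bytes
size_gib = batch * dim1 * dim2 * dim3 * bytes_per_elem / 2**30
print(f"{size_gib:.2f} GiB")                          # 4.00 GiB -> matches the report, linear in batch
```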
What could be the reason for the above allocation? I am using the PixelSNAIL implementation from https://github.com/kamenbliznashki/pixel_models
Questions:
- Why does this tensor have a type of `u32`, when all my weight / model definitions (including a global environment flag) use BF16?
- Why is `rng-bit-generator` used?