I am making procedurally generated terrain, for which I used the Classic Perlin Noise given here. To calculate the terrain normals I need the derivative (gradient) of this function, so I rewrote the function in Python and used jax.grad to differentiate it. I then created a computation graph as shown here, but it was too complicated to turn into code manually.
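For context on what has to be ported: the normal of a height field y = h(x, z) is proportional to (-∂h/∂x, 1, -∂h/∂z), and gradient noise has a closed-form derivative. That means the gradient can be written by hand, with jax.grad (or a finite difference) kept only as a reference to check against. The sketch below assumes a 2D noise with Perlin's quintic fade curve. The integer hash in `grad_at` and the 8-entry `GRADS` table are hypothetical stand-ins, not the linked implementation, and would need to be replaced with whatever the original uses. It sticks to floor, dot products and multiply-adds so that each line maps directly to GLSL.

```python
import math

# Hypothetical gradient table and hash: replace with the ones from the
# Perlin implementation being ported.
GRADS = [(1.0, 1.0), (-1.0, 1.0), (1.0, -1.0), (-1.0, -1.0),
         (1.0, 0.0), (-1.0, 0.0), (0.0, 1.0), (0.0, -1.0)]


def grad_at(ix, iy):
    """Pick a pseudo-random gradient for lattice point (ix, iy) via an integer hash."""
    h = (ix * 374761393 + iy * 668265263) & 0xFFFFFFFF
    h = ((h ^ (h >> 13)) * 1274126177) & 0xFFFFFFFF
    return GRADS[h & 7]


def fade(t):
    # Quintic fade 6t^5 - 15t^4 + 10t^3 and its derivative 30t^2(t-1)^2.
    return t * t * t * (t * (t * 6.0 - 15.0) + 10.0), 30.0 * t * t * (t - 1.0) ** 2


def noise2d(x, y):
    """Return (n, dn/dx, dn/dy) for 2D Perlin-style gradient noise."""
    ix, iy = math.floor(x), math.floor(y)
    fx, fy = x - ix, y - iy
    g00, g10 = grad_at(ix, iy), grad_at(ix + 1, iy)
    g01, g11 = grad_at(ix, iy + 1), grad_at(ix + 1, iy + 1)
    # Corner contributions: dot(gradient, offset from that corner).
    n00 = g00[0] * fx + g00[1] * fy
    n10 = g10[0] * (fx - 1.0) + g10[1] * fy
    n01 = g01[0] * fx + g01[1] * (fy - 1.0)
    n11 = g11[0] * (fx - 1.0) + g11[1] * (fy - 1.0)
    u, du = fade(fx)
    v, dv = fade(fy)
    # Bilinear blend written as a polynomial in u and v.
    k = n00 - n10 - n01 + n11
    n = n00 + u * (n10 - n00) + v * (n01 - n00) + u * v * k

    def blend(a, b, c, d):
        # Same blend applied to the per-corner partial derivatives.
        return a + u * (b - a) + v * (c - a) + u * v * (a - b - c + d)

    # Chain rule: d(n_c)/dx = g_c.x, d(n_c)/dy = g_c.y, plus the fade terms.
    dndx = blend(g00[0], g10[0], g01[0], g11[0]) + du * ((n10 - n00) + v * k)
    dndy = blend(g00[1], g10[1], g01[1], g11[1]) + dv * ((n01 - n00) + u * k)
    return n, dndx, dndy


def terrain_normal(x, z):
    """Unit normal of the height field y = noise2d(x, z), assuming y is up."""
    _, dhdx, dhdz = noise2d(x, z)
    nx, ny, nz = -dhdx, 1.0, -dhdz
    length = math.sqrt(nx * nx + ny * ny + nz * nz)
    return nx / length, ny / length, nz / length
```

Comparing `noise2d` against a central finite difference (or against jax.grad of the JAX version) at a few points is a cheap way to confirm the hand-derived gradient before porting it. The 3D case follows the same pattern, with a trilinear blend and one more fade derivative.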
The closest solution I have found is the jax2tex library, but it's deprecated and no longer works.
Since I need it in a compute shader, I can't run it through an XLA runtime in C++; I need to write GLSL code from it.
My question is: is there a tool like jax2tex that can present the XLA code in an easy-to-understand form, so that I can write GLSL code from it?
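For reference, JAX can print its own intermediate representation (a jaxpr) with `jax.make_jaxpr` before anything is lowered to XLA. A jaxpr is a flat list of primitive operations such as `mul`, `sin` and `cos`, roughly one per line, which is usually much easier to read than HLO and can be transcribed statement by statement. In this sketch `height` is a hypothetical placeholder for the ported noise function; any function written with `jax.numpy` can be traced the same way.

```python
import jax
import jax.numpy as jnp


def fade(t):
    # Quintic fade curve 6t^5 - 15t^4 + 10t^3.
    return t * t * t * (t * (t * 6.0 - 15.0) + 10.0)


def height(p):
    # Hypothetical stand-in for the ported noise function.
    x, z = p[0], p[1]
    return jnp.sin(x) * fade(z - jnp.floor(z))


# Gradient of the height with respect to the 2D input point.
grad_height = jax.grad(height)

# make_jaxpr traces the function and returns a ClosedJaxpr: the list of
# primitive operations that jax.grad produced, before XLA lowering.
jaxpr = jax.make_jaxpr(grad_height)(jnp.array([0.5, 1.25]))
print(jaxpr)
```

Jaxprs produced by `jax.grad` can contain redundant intermediate terms, so transcribing one directly tends to give correct but verbose GLSL.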