What is the main difference between flax (google) and dm-haiku (deepmind)?


What are main differences between flax and dm-haiku?

From their descriptions:

  • Flax, a neural network library for JAX
  • Haiku, a neural network library for JAX inspired by Sonnet

Question:

Which JAX-based library should I pick to implement, say, a DeepSpeech model (CNN layers + LSTM layers + fully connected layers) trained with CTC loss?


Update:

I found an explanation of the differences from a dm-haiku developer:

Flax is a bit more batteries included, and comes with optimizers, mixed precision and some training loops (I am told these are decoupled and you can use as much or as little as you want). Haiku aims to just solve NN modules and state management, it leaves other parts of the problem to other libraries (e.g. optax for optimization).

Haiku is designed to be a port of Sonnet (a TF NN library) to JAX. So Haiku is a better choice if (like DeepMind) you have a significant amount of Sonnet+TF code that you might want to use in JAX and you want migrating that code (in either direction) to be as easy as possible.

I think otherwise it comes down to personal preference. Within Alphabet there are 100s of researchers using each library so I don't think you can go wrong either way. At DeepMind we have standardised on Haiku because it makes sense for us. I would suggest taking a look at the example code provided by both libraries and seeing which matches your preferences for structuring experiments. I think you'll find that moving code from one library to another is not very complicated if you change your mind in the future.


The original question is still relevant.

Answer by Robin:

I recently faced the same question, and I chose Haiku because I think its implementation (compare Flax Dense() with Haiku Linear()) is closer to the original JAX spirit (i.e. chaining init and predict functions and keeping track of the parameters in pytrees), which makes it easier for me to modify things.
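For reference, the "original JAX spirit" mentioned here can be sketched in plain JAX with no library at all: each layer is a pair of pure functions, and the parameters are an ordinary pytree. The function names below (`dense_init`, `dense_apply`, `mlp_init`, `mlp_apply`) are hypothetical and only illustrate the pattern; Haiku's `hk.transform` produces the same kind of `init`/`apply` pair.

```python
import jax
import jax.numpy as jnp

def dense_init(rng, in_dim, out_dim):
    # Parameters are an ordinary pytree: a dict of arrays.
    return {
        "w": jax.random.normal(rng, (in_dim, out_dim)) / jnp.sqrt(in_dim),
        "b": jnp.zeros((out_dim,)),
    }

def dense_apply(params, x):
    return x @ params["w"] + params["b"]

def mlp_init(rng, sizes):
    # Chain layer initializers; the whole model's params are a list of dicts.
    keys = jax.random.split(rng, len(sizes) - 1)
    return [dense_init(k, m, n) for k, m, n in zip(keys, sizes[:-1], sizes[1:])]

def mlp_apply(params, x):
    # Chain layer apply functions, with ReLU between hidden layers.
    for layer in params[:-1]:
        x = jax.nn.relu(dense_apply(layer, x))
    return dense_apply(params[-1], x)

params = mlp_init(jax.random.PRNGKey(0), [16, 32, 10])
out = mlp_apply(params, jnp.ones((4, 16)))  # out.shape -> (4, 10)

# Because params is a pytree, generic JAX utilities work on it directly.
n_params = sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))
grads = jax.grad(lambda p: mlp_apply(p, jnp.ones((4, 16))).sum())(params)
```

The gradients come back with exactly the same tree structure as the parameters, so any update rule can be written as a `jax.tree_util.tree_map` over the two trees.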

But if you do not plan to modify things in depth, the best way to choose is to find a good blog post on CNNs + LSTMs with Flax or Haiku and stick with it. My general opinion is that the two libraries are really close, even though I prefer the more modular way the Haiku ecosystem (+ Optax + RLax + Chex + ...) is built.