Hmm, a 3% market share framework with barely any ecosystem and single-vendor accelerators (JAX on TPU) vs. a 60% market share framework with an insanely rich ecosystem and the ability to debug code on your own workstation (PyTorch on GPU)? In my informed opinion, most people should use the latter unless they like wasting time on shiny things.
JAX is used by almost every large genAI player (Anthropic, Cohere, DeepMind, Midjourney, Character.ai, xAI, Apple, etc.). Its actual market share in foundation model development is something like 80%.
Are there any resources going into detail about why the big players prefer JAX? I've heard this before but have never seen explanations of why/how this happened.
The parts of your comment that have any truth in them could have been said of PyTorch when it came out. People wasting time on shiny things is how we get better tools.
Nope. When PyTorch came out, it was the only option that was easy to use and debug. Your alternative was TF1, which sucked so bad people dropped it like it had syphilis, and Google later had to add eager mode in TF2, ruining performance in the process. I would know, I was one of those people. It really was a watershed moment in AI research productivity.
Yes, but its performance on GPU leaves much to be desired, and 20 times as much research comes out on PyTorch. Would you rather just build on that, or laboriously port and debug the models and their weights, losses, dataset readers, training regimes, etc.?
Also, JAX is not just for TPU. It's mainly for GPU. It's usually 2-3x faster than torch on GPU: https://keras.io/getting_started/benchmarks/
Far more industry users of JAX use it on GPU compared to TPU.