ein0p · 2 years ago
Hmm: a framework with ~3% market share, barely any ecosystem, and single-vendor accelerators (JAX on TPU) vs. a framework with ~60% market share, an insanely rich ecosystem, and the ability to debug code on your own workstation (PyTorch on GPU)? In my informed opinion, most people should use the latter unless they like wasting time on shiny things.
_ntka · 2 years ago
JAX is used by almost every large genAI player (Anthropic, Cohere, DeepMind, Midjourney, Character.ai, xAI, Apple, etc.). Its actual market share in foundation-model development is something like 80%.

Also, JAX is not just for TPU; it's mainly for GPU. It's usually 2-3x faster than PyTorch on GPU: https://keras.io/getting_started/benchmarks/

Far more industry users of JAX use it on GPU compared to TPU.
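To illustrate the "not just for TPU" point: a minimal sketch (assuming `jax` is installed) showing that the same jit-compiled function is backend-agnostic, since XLA compiles it for whatever accelerator is present. The `mse` function and the array shapes here are made up for the example.

```python
import jax
import jax.numpy as jnp

@jax.jit
def mse(w, x, y):
    # Plain numpy-style code; XLA compiles it for the available backend.
    return jnp.mean((x @ w - y) ** 2)

x = jnp.ones((8, 4))
w = jnp.zeros((4,))
y = jnp.ones((8,))

# Reports "cpu", "gpu", or "tpu" depending on the machine; the code is identical.
print(jax.default_backend())
print(float(mse(w, x, y)))  # 1.0
```

The same script runs unchanged on a CPU-only laptop, a CUDA box, or a TPU VM, which is the portability being claimed here.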

kaycebasques · 2 years ago
Are there any resources going into detail about why the big players prefer JAX? I've heard this before but have never seen explanations of why/how this happened.
ein0p · 2 years ago
Are you on one of those (usually small) teams? No? Then it’s probably not a good choice for you.
phyalow · 2 years ago
For smaller-scale projects, it's still basically a no-brainer to use PyTorch.
lern_too_spel · 2 years ago
The parts of your comment that have any truth in them could have been said of PyTorch when it came out. People wasting time on shiny things is how we get better tools.
ein0p · 2 years ago
Nope. When PyTorch came out, it was the only option that was easy to use and debug. Your alternative was TF1, which sucked so badly that people dropped it like it had syphilis, and Google later had to add eager mode in TF2, ruining performance in the process. I would know; I was one of those people. It really was a watershed moment in AI research productivity.
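For readers who missed the TF1 era, the "easy to use and debug" point is concrete: PyTorch's eager mode executes each op immediately, so you can print or set a breakpoint mid-model, with no graph construction or `session.run()`. A hedged sketch (assuming `torch` is installed; the tensors are made up for illustration):

```python
import torch

x = torch.randn(3, 4)
w = torch.randn(4, 2, requires_grad=True)

# Every line runs immediately: intermediate values are inspectable right here,
# which is what graph-mode TF1 made painful.
h = x @ w
print(h.shape)  # torch.Size([3, 2])

loss = h.pow(2).mean()
loss.backward()
print(w.grad.shape)  # torch.Size([4, 2])
```

You could drop `breakpoint()` between any two of those lines and poke at live tensors, which is the workflow that drove the research migration described above.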
drdirk · 2 years ago
JAX uses the XLA compiler, which targets GPU and CPU as well as TPU.
ein0p · 2 years ago
Yes, but its performance on GPU leaves much to be desired, and 20 times as much research comes out on PyTorch. Would you rather just build on that, or laboriously port and debug the models and their weights, losses, dataset readers, training regimes, and so on?