Readit News
srush · 2 years ago
PyTorch is a generationally important project. I've never seen a tool that is so in line with how researchers learn and internalize a subject. Teaching Machine Learning before and after its adoption has been a completely different experience. It can never be said enough how cool it is that Meta fosters and supports it.

Viva PyTorch! (Jax rocks too)

deepsquirrelnet · 2 years ago
This is exactly why I gravitated to it so quickly. The first time I looked at pytorch code it was immediately obvious what the abstractions meant and how to use them to write a model architecture.

Jax looks like something completely different to me. Maybe I’m dumb and probably not the target audience, but it occurs to me that very few people are. When I read about using Jax, I find recommendations for a handful of other libraries that make it more usable. Which of those to learn is not entirely obvious, because they all seem to create a very fragmented ecosystem with code that isn’t portable.

I’m still not sure why I’d spend my time learning Jax, especially when it seems like most of the author's complaints don’t really separate training from inference, which don’t necessarily need to happen in the same framework.

6gvONxR4sf7o · 2 years ago
Honestly, when I turn to JAX, I generally do it without a framework. To me, it’s like asking for a framework to wrap numpy. Just JAX plus optax is sufficient for the cases where I turn to it.
PostOnce · 2 years ago
Torch was originally a Lua project, which is why PyTorch is called PyTorch and not just Torch.

In another timeline AI would have made Lua popular.

The best part is that it trampled TensorFlow, which I personally find obtuse.

n7g · 2 years ago
> In another timeline AI would have made Lua popular.

I wonder if it'd have been hated more than Python is - especially with the 1-based indexing...

pjmlp · 2 years ago
Additionally, nowadays it also has Java and C++ bindings to the same native libraries, so others can enjoy performance without having to rewrite their research afterwards.
smhx · 2 years ago
The author got a couple of things wrong that are worth pointing out:

1. PyTorch is going all-in on torch.compile -- Dynamo is the frontend, Inductor is the backend -- with a strong default Inductor codegen powered by OpenAI Triton (which now has CPU, NVIDIA GPU and AMD GPU backends). The author's view that PyTorch is building towards a multi-backend future isn't really where things are going. PyTorch supports extensibility of backends (including XLA), but disproportionate effort goes into the default path. torch.compile is 2 years old, XLA is 7 years old. Compilers take a few years to mature. torch.compile will get there (and we have reasonable measures showing that the compiler is on track to maturity).

2. PyTorch/XLA exists mainly to drive a TPU backend for PyTorch, as Google gives no other real way to access the TPU. Trying to shoehorn XLA in as a PyTorch backend isn't great, as XLA fundamentally doesn't have the flexibility that PyTorch supports by default (especially dynamic shapes). PyTorch on TPUs is unlikely to ever match the experience of JAX on TPUs, almost by definition.

3. JAX was developed at Google, not at Deepmind.
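To make the frontend/backend split in point 1 concrete, here is a minimal sketch using `torch.compile`'s support for user-supplied backends: Dynamo captures the Python function into a `torch.fx` graph and hands it to whatever backend callable you pass. `inspecting_backend` is a hypothetical toy that only records the captured ops and runs the graph unoptimized; with the default `backend="inductor"`, this is roughly where Inductor would instead generate fused kernels (Triton kernels on GPU).

```python
import torch

captured_ops = []


def inspecting_backend(gm, example_inputs):
    """Toy backend: `gm` is the torch.fx.GraphModule that Dynamo (the frontend) captured."""
    for node in gm.graph.nodes:
        if node.op == "call_function":
            captured_ops.append(str(node.target))
    # Run the captured graph as-is; a real backend such as Inductor would compile it here.
    return gm.forward


def mlp(x, w1, w2):
    return torch.relu(x @ w1) @ w2


compiled_mlp = torch.compile(mlp, backend=inspecting_backend)

x, w1, w2 = torch.randn(8, 4), torch.randn(4, 16), torch.randn(16, 2)
out = compiled_mlp(x, w1, w2)
print(captured_ops)  # the ops Dynamo traced out of the Python function
```

Because the backend is just a callable over an FX graph, alternative backends (including XLA-based ones) plug in at the same point; the thread's disagreement is about how well they work there, not whether the hook exists.
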

n7g · 2 years ago
Hey, thanks for actually engaging with the blog's points instead of "Google kills everything it touches" :)

1. I'm well aware of the PyTorch stack, but this point:

> PyTorch is building towards a multi-backend future isn't really where things are going

>PyTorch supports extensibility of backends (including XLA)

Is my problem. Those backends just never integrate well, as I mentioned in the blogpost. I'm not sure if you've ever gone into the weeds, but there are so many (often undocumented) sharp edges when using different backends that they never really work well. Take, for example, how bad Torch/XLA is, with its nightmare-inducing bugs and errors.

> torch.compile is 2 years old, XLA is 7 years old. Compilers take a few years to mature

That was one of my major points - I don't think leaning on torch.compile is the best idea. A compiler would inherently place restrictions that you have to work-around.

This is not dynamic, nor flexible - and it flies in the face of torch's core philosophies just so they can offer more performance to the big labs using PyTorch. For various reasons, I dislike pandering to the rich guy instead of being an independent, open-source entity.

2. Torch/XLA is indeed primarily meant for TPUs - as in the quoted announcement, where they declare they're ditching TF:XLA in favour of OpenXLA. But there's still a very real effort to get it working on GPUs - in fact, a lab on Twitter declared that they're using Torch/XLA on GPUs and will soon™ release details.

XLA's GPU support is great, it's compatible across different hardware, it's optimized and mature. In short, it's a great alternative to the often buggy torch.compile stack - if you fix the torch integration.

So I won't be surprised if they lean on XLA in the long term. Whether that's a good direction or not is up to the devs to decide, unfortunately - not the community.

3. Thank you for pointing that out. I'm not sure about the history of JAX (it might make a good blogpost for the JAX devs to write someday), but it seems that it was indeed developed at Google Research, though also heavily supported and maintained by DeepMind.

Appreciate you giving the time to comment here though :)

smhx · 2 years ago
If you're the author, unfortunately I have to say that the blog is not well-written -- it's misinformed in some of its claims and has a repugnant click-baity title. You're getting the attention and clicks, but probably losing a lot of trust among people. I didn't engage out of choice, but out of a duty to respond to FUD.

> > torch.compile is 2 years old, XLA is 7 years old. Compilers take a few years to mature

> That was one of my major points - I don't think leaning on torch.compile is the best idea. A compiler would inherently place restrictions that you have to work-around.

There are plenty of compilers that place restrictions that you barely notice. gcc, clang, nvcc -- they're fairly flexible, and "dynamic". Adding constraints doesn't mean you have to give up on important flexibility.

> This is not dynamic, nor flexible - and it flies in the face of torch's core philosophies just so they can offer more performance to the big labs using PyTorch. For various reasons, I dislike pandering to the rich guy instead of being an independent, open-source entity.

I think this is an assumption you've made largely without evidence, and I'm not entirely sure what your point is. The way torch.compile's success is measured publicly (even in the announcement blogpost and conference keynote, link https://pytorch.org/get-started/pytorch-2.0/ ) is by benchmarking a bunch of popular PyTorch-based GitHub repos in the wild + popular HuggingFace models + the TIMM vision benchmark. They're curated here: https://github.com/pytorch/benchmark . Your claim that it's mainly meant to favor large labs is pretty puzzling.

torch.compile is both dynamic and flexible because: 1. it supports dynamic shapes, and 2. it allows incremental compilation (you don't need to compile the parts that you wish to keep in uncompilable Python -- probably using random arbitrary Python packages, etc.). There is a trade-off between dynamism/flexibility and performance, i.e. more dynamic and flexible means we don't have enough information to extract better performance, but that's an acceptable trade-off when you need the flexibility to express your ideas more than you need the speed.

> XLA's GPU support is great, it's compatible across different hardware, it's optimized and mature. In short, it's a great alternative to the often buggy torch.compile stack - if you fix the torch integration.

If you are an XLA maximalist, that's fine. I am not. There isn't evidence to prove out either opinion. PyTorch will never be nicely compatible with XLA as long as XLA has significant constraints that are incompatible with PyTorch's user experience model. The PyTorch devs have given clear, written-down feedback to the XLA project on what it takes for XLA+PyTorch to get better; it's been a few years, and the XLA project prioritizes other things.

lunaticd · 2 years ago
3. The project started under a Harvard-affiliated GitHub org during the course of its creators' PhDs. These same people later joined Google, where it continued to be developed and, over time, was adopted more and more in place of TensorFlow.
logicchains · 2 years ago
PyTorch beat TensorFlow because it was much easier to use for research. Jax is much harder to use for exploratory research than PyTorch, due to requiring a fixed-shape computation graph, which makes implementing many custom model architectures very difficult.

Jax's advantages shine when it comes to parallelizing a new architecture across multiple GPUs/TPUs, which it makes much easier than PyTorch (no need for custom CUDA/networking code). Needing to scale up a new architecture across many GPUs is, however, not a common use-case, and most teams that have the resources for large-scale multi-GPU training also have the resources for specialised engineers to do it in PyTorch.
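
As a rough sketch of what "no custom CUDA/networking code" looks like in JAX: you describe how arrays are laid out over a device mesh, and `jax.jit` (via XLA's SPMD partitioner) inserts whatever communication is needed. The mesh axis name `"data"` and the toy `forward` function are illustrative assumptions; on a single-device machine the mesh has one device, so this only exercises the API.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 1-D mesh over every visible device (GPUs, TPUs, or just the CPU).
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))

# Batch rows are split across devices; weights are replicated on each one.
batch = jax.device_put(jnp.ones((8 * devices.size, 4)), NamedSharding(mesh, P("data", None)))
weights = jax.device_put(jnp.full((4, 2), 0.5), NamedSharding(mesh, P()))


@jax.jit
def forward(w, x):
    # Written as single-device code; XLA partitions it according to the input shardings.
    return jnp.tanh(x @ w)


out = forward(weights, batch)
print(out.sharding)  # inspect how XLA laid out the result
```

The model code itself never mentions devices; only the `device_put` calls change when moving from one accelerator to many.
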

big-chungus4 · 2 years ago
Doesn't JAX support dynamic graphs as well?
logicchains · 2 years ago
Nope, changing graph shape requires recompilation: https://github.com/google/jax/discussions/17191
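
The recompilation behaviour is easy to observe. In this minimal sketch, a Python side effect inside a `jax.jit`-decorated function runs only while JAX traces it, so the hypothetical `trace_count` counter tallies compilations: each new input shape triggers a fresh trace, while a repeated shape is served from the cache.

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def normalize(x):
    global trace_count
    trace_count += 1  # Python code runs only during tracing, not on cached calls
    return (x - x.mean()) / (x.std() + 1e-6)


for n in (4, 4, 8, 16, 8):
    normalize(jnp.arange(n, dtype=jnp.float32))

print(trace_count)  # 3: one trace per distinct shape (4,), (8,), (16,)
```

A common workaround is to pad inputs to a small set of bucket sizes (with a mask), trading some wasted compute for fewer recompilations.
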
ianbutler · 2 years ago
From an eng/industry perspective, back in 2016/2017 I watched the real-time decline of TensorFlow in favor of PyTorch.

The issue was that TF had too many interfaces to accomplish the same thing, and each one was rough in its own way. There was also some complexity in serving and in experiment logging via TensorBoard, but that wasn’t as bad, at least for me.

Keras was integrated in an attempt to help, but ultimately it wasn’t enough and people started using Torch more and more even against the perception that TF was for prod workloads and Torch was for research.

TFA mentions the interface complexity as starting to be a problem with Torch, but I don’t think we’re anywhere near the critical point that would cause people to abandon it in favor of JAX.

Additionally with JAX you’re just shoving the portability problems mentioned down to XLA which brings its own issues and gotchas even if it hides the immediate reality of said problems from the end user.

I think the Torch maintainers should be careful not to repeat the mistakes of TF, but I think there's a long way to go before JAX is a serious contender. It’s been years, and JAX usage has stayed relatively small.

markeroon · 2 years ago
IMO the biggest issue (from memory) was that TensorFlow used a static computation graph. PyTorch was so much easier to work with.
lostmsu · 2 years ago
I honestly think the static graph was much better, in a similar way to how Vulkan/DX12 are better than OpenGL/DX9. It is harder to program, but gives more explicit control over important things. E.g. who would use PyTorch if they knew that for optimal performance they'd need to record CUDA graphs?
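
For context on what "recording CUDA graphs" involves in PyTorch, here is a sketch of the capture-and-replay pattern from PyTorch's CUDA Graphs documentation: warm up on a side stream, capture once into static input/output buffers, then copy each new batch into the input buffer and replay. The helper name `replay_with_cuda_graph` is hypothetical, and it falls back to plain eager execution when no GPU is available.

```python
import copy

import torch


def replay_with_cuda_graph(model, batches):
    """Run `model` over same-shaped `batches`, replaying one captured CUDA graph on GPU."""
    with torch.no_grad():
        if not torch.cuda.is_available():
            return [model(b) for b in batches]  # CPU fallback: ordinary eager execution

        gpu_model = copy.deepcopy(model).cuda()
        static_in = batches[0].clone().cuda()  # the graph always reads from this buffer

        # Warm up on a side stream before capture, as the CUDA Graphs docs recommend.
        side = torch.cuda.Stream()
        side.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(side):
            for _ in range(3):
                gpu_model(static_in)
        torch.cuda.current_stream().wait_stream(side)

        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            static_out = gpu_model(static_in)  # the graph always writes to this buffer

        outputs = []
        for b in batches:
            static_in.copy_(b)  # new data must be copied into the captured input buffer
            graph.replay()
            outputs.append(static_out.clone().cpu())
        return outputs


model = torch.nn.Linear(4, 2)
batches = [torch.randn(3, 4) for _ in range(3)]
outputs = replay_with_cuda_graph(model, batches)
```

The fixed buffers and fixed shapes are exactly the "explicit control" trade-off: replay avoids per-kernel launch overhead, but only for inputs that fit the captured graph.
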
ianbutler · 2 years ago
Yup that was also definitely one among the many issues
wafngar · 2 years ago
PyTorch is developed by multiple companies/stakeholders, while jax is Google-only, with internal tooling they don’t share with the world. This alone is a major reason not to use jax. Also, I think it is more the other way around: with torch.compile the main advantage of jax is disappearing.
dauertewigkeit · 2 years ago
It's the age-old question in programming: do you use a highly constrained paradigm that allows easy automatic optimization, or a very flexible and more intuitive paradigm that makes automatic optimization harder?

If the future is going to be better more intelligent compilers, then that settles the question in my opinion.

n7g · 2 years ago
> with torch.compile the main advantage of jax is disappearing.

Interesting take - I agree here somewhat.

But also, wouldn't you think a framework designed from the ground up around a specific, mature compiler stack would be better able to integrate compilers stably than one that shoe-horns static compilers into a very dynamic framework? ;)

wafngar · 2 years ago
Depends. PyTorch, on the other hand, has a large user base and a well-defined, tested API. So it should be doable, and it is already progressing at rapid speed.
anon389r58r58 · 2 years ago
So the answer is not Jax?

Because JAX is not designed around a mature compiler stack. The history of JAX is more that it matured alongside the compiler...

0x19241217 · 2 years ago
Pushback notwithstanding, this article is 100% correct in all its PyTorch criticisms. PyTorch was a platform for fast experimentation with eager evaluation; now they shoehorn "compilers" into it. I say "compilers" because a lot of the work is done by g++ and Triton.

It is a messy and quickly expanding codebase with many surprises like segfaults and leaks.

Is scientific experimentation really sped up by these frameworks? Everyone uses the Transformer model and uses the same algorithms over and over again.

If researchers wrote directly in C or Fortran, perhaps they'd get new ideas. The core inference code (see Karpathy's llama2.c) is ridiculously small. Core training does not seem much larger either.

dunefox · 2 years ago
> If researchers wrote directly in C or Fortran...

... then they would get nothing done.

TheRealKing · 2 years ago
Fortran does not belong in the same low-productivity category as C.
pjmlp · 2 years ago
Apparently we could do research work before Python was invented.
pklausler · 2 years ago
Why not?
sundarurfriend · 2 years ago
Can we get the title changed to the actual title of the post? "The future of Deep Learning frameworks" sounds like a neutral and far wider-reaching article, and ends up being clickbait here (even if unintentionally).

"PyTorch is dead. Long live JAX." conveys exactly what the article is about, and is a much better title.

funks_ · 2 years ago
I wish dex-lang [1] had gotten more traction. It’s JAX without the limitations that come from being a Python DSL. But ML researchers apparently don’t want to touch anything that doesn’t look exactly like Python.

[1]: https://github.com/google-research/dex-lang

cherryteastain · 2 years ago
It's very rare that an ML project is _only_ the ML parts. A significant chunk of the engineering effort goes into data pipelines and other plumbing. Having access to a widely used general purpose language with plenty of libraries in addition to all the ML libraries is the real reason why everyone goes for Python for ML.
hatmatrix · 2 years ago
It seems like an experimental research language.

Julia also competes in this domain from a more practical standpoint and has fewer limitations than JAX as I understand it, but is less mature and still working on getting wider traction.

funks_ · 2 years ago
The Julia AD ecosystem is very interesting in that the community is trying to make the entire language differentiable, which is much broader in scope than what Torch and JAX are doing. But unlike Dex, Julia is not a language built from the ground up for automatic differentiation.

Shameless plug for one of my talks at JuliaCon 2024: https://www.youtube.com/live/ZKt0tiG5ajw?t=19747s. The comparison between Python and Julia starts at 5:31:44.

mccoyb · 2 years ago
Dex is also missing user authored composable program transformations, which is one of JAX’s hidden superpowers.

So not quite “JAX without limitations” — but certainly without some of the limitations.

6gvONxR4sf7o · 2 years ago
This is both its strength and its weakness. As soon as you write a jaxpr interpreter, you lose all the tooling that makes the Python interpreter so mature. For example, stack traces and debugging become black holes. If JAX made it easy to write these transformations without losing Python’s benefits, it would be incredible.
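
For readers who haven't seen one, here is a minimal sketch of a user-written jaxpr interpreter, modeled on the "Writing custom Jaxpr interpreters" pattern in the JAX docs: trace a function to a jaxpr with `jax.make_jaxpr`, then walk its equations and re-bind each primitive yourself. The toy `eval_and_count` transformation (a hypothetical name) evaluates the function while tallying primitives. It relies on JAX's internal jaxpr data structures, whose import locations have moved between releases, and any error raised during evaluation will point into this loop rather than at the user's original line of Python.

```python
from collections import Counter

import jax
import jax.numpy as jnp

try:  # Literal's public location has moved between JAX releases.
    from jax.extend.core import Literal
except ImportError:
    from jax.core import Literal


def eval_and_count(fn, *args):
    """Evaluate fn by interpreting its jaxpr, counting how often each primitive runs."""
    closed = jax.make_jaxpr(fn)(*args)
    jaxpr, counts, env = closed.jaxpr, Counter(), {}

    def read(var):
        return var.val if isinstance(var, Literal) else env[var]

    for var, val in zip(jaxpr.constvars, closed.consts):
        env[var] = val
    for var, val in zip(jaxpr.invars, args):
        env[var] = val

    for eqn in jaxpr.eqns:
        counts[eqn.primitive.name] += 1
        outvals = eqn.primitive.bind(*map(read, eqn.invars), **eqn.params)
        if not eqn.primitive.multiple_results:
            outvals = [outvals]
        for var, val in zip(eqn.outvars, outvals):
            env[var] = val

    return [read(v) for v in jaxpr.outvars], counts


def f(x):
    return jax.lax.sin(x) * 2.0 + 1.0


(result,), counts = eval_and_count(f, jnp.arange(3.0))
print(counts)  # primitive name -> number of times it ran
```
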
funks_ · 2 years ago
Are you talking about custom VJPs/JVPs?
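
For reference, this is what a custom VJP looks like in JAX, using the documented `jax.custom_vjp` decorator and `defvjp`. The example, a gradient-clipping identity named `clip_gradient` here for illustration, leaves the forward value untouched but replaces the backward rule. It overrides the rule for one function, which is narrower than writing a whole new program transformation over jaxprs.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def clip_gradient(x):
    return x  # identity on the forward pass


def clip_gradient_fwd(x):
    return x, None  # no residuals need to be saved for the backward pass


def clip_gradient_bwd(residuals, g):
    return (jnp.clip(g, -1.0, 1.0),)  # one cotangent per primal input, as a tuple


clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)

value = clip_gradient(3.0)
grad = jax.grad(lambda x: 5.0 * clip_gradient(x))(2.0)
print(value, grad)  # 3.0 1.0: forward unchanged, incoming gradient 5.0 clipped to 1.0
```
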
hedgehog · 2 years ago
It's not about the syntax, it's all the knowledge, tools, existing code, etc that make Python so attractive.
funks_ · 2 years ago
I don't doubt that, but I'm specifically talking about new languages. I've seen far more enthusiasm from ML researchers for Mojo, which doesn't even do automatic differentiation, than for Dex. And to recycle an old HN comment of mine, people are much more eager to learn a functional programming language if it looks like NumPy (I'm talking about JAX here).