mattjjatgoogle commented on JAX – NumPy on the CPU, GPU, and TPU   jax.readthedocs.io/en/lat... · Posted by u/peter_d_sherman
sampo · 2 years ago
> With its updated version of Autograd, JAX can automatically differentiate native Python and NumPy code.

Nice. When did they make this change?

Here is the old way in the docs, where to get an ordinary if-then-else you needed to define separate functions for the if-true and if-false branches and feed them to a conditional function.

https://jax.readthedocs.io/en/latest/notebooks/Common_Gotcha...

mattjjatgoogle · 2 years ago
Actually that never changed. The README has always had an example of differentiating through native Python control flow:

https://github.com/google/jax/commit/948a8db0adf233f333f3e5f...

The constraints on control flow expressions come from jax.jit (because Python control flow can't be staged out) and jax.vmap (because we can't take multiple branches of Python control flow, which we might need to do for different batch elements). But autodiff of Python-native control flow works fine!
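
To make the distinction concrete, here's a minimal sketch (the function is illustrative, not from the docs; exact error names vary by JAX version):

```python
import jax

def f(x):
    # Native Python control flow: fine under jax.grad,
    # but not under jax.jit, since the branch taken depends on x's value.
    if x > 0:
        return 2 * x ** 3
    else:
        return 3 * x

print(jax.grad(f)(2.0))   # 24.0 -- differentiating through the `if` just works
print(jax.grad(f)(-1.0))  # 3.0

# jax.jit(f)(2.0) would raise a tracer concretization error,
# because `x > 0` can't be evaluated on an abstract traced value.
```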

mattjjatgoogle commented on JAX – NumPy on the CPU, GPU, and TPU   jax.readthedocs.io/en/lat... · Posted by u/peter_d_sherman
cl3misch · 2 years ago
This is still the case afaik.

For a vanilla "if", the condition must be known at compile time. For conditions known only at runtime, you have to use "cond", "where", or "select" (which are roughly analogous).

mattjjatgoogle · 2 years ago
Actually, that's never been a constraint for JAX autodiff. JAX grew out of the original Autograd (https://github.com/hips/autograd), so differentiating through Python control flow always worked. It's jax.jit and jax.vmap which place constraints on control flow, requiring structured control flow combinators like those.
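
For concreteness, here's a small sketch of the structured combinators that jit does require (the functions are illustrative):

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def f(x):
    # lax.cond traces both branches, so the predicate may be a runtime value.
    return lax.cond(x > 0, lambda x: 2 * x, lambda x: -x, x)

print(f(3.0))   # 6.0
print(f(-3.0))  # 3.0

# jnp.where is the elementwise analogue (note it evaluates both branches):
g = jax.jit(lambda x: jnp.where(x > 0, 2 * x, -x))
print(g(jnp.array([3.0, -3.0])))  # [6. 3.]
```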
mattjjatgoogle commented on JAX – NumPy on the CPU, GPU, and TPU   jax.readthedocs.io/en/lat... · Posted by u/peter_d_sherman
matrss · 2 years ago
Meanwhile, the first sentence in their readme is this:

> JAX is Autograd and XLA, brought together for high-performance machine learning research.

That does not really convey the generality of it that well.

mattjjatgoogle · 2 years ago
You're right! Maybe we should revise that... I made https://github.com/google/jax/pull/17851, comments welcome!
mattjjatgoogle commented on JAX – NumPy on the CPU, GPU, and TPU   jax.readthedocs.io/en/lat... · Posted by u/peter_d_sherman
tehsauce · 2 years ago
JAX is super useful for scientific computing, although N-body sims might not be the best application. A naive N-body sim is very easy to implement and accelerate in JAX (here's my version: https://github.com/PWhiddy/jax-experiments/blob/main/nbody.i...), but it can be tricky to scale. This is because efficient N-body sims usually rely on either trees or spatial hashing/sorting, which are hard to implement efficiently in JAX.
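
For reference, a minimal sketch of the naive all-pairs approach (this isn't the linked code; the names and softening constant are illustrative):

```python
import jax
import jax.numpy as jnp

def acc(pos, masses, eps=1e-3):
    # Naive O(n^2) gravitational acceleration via broadcasting.
    d = pos[None, :, :] - pos[:, None, :]           # (n, n, 3) displacements
    r2 = jnp.sum(d ** 2, axis=-1) + eps ** 2        # softened squared distances
    inv_r3 = r2 ** -1.5
    # Self-interaction terms contribute zero since d[i, i] == 0.
    return jnp.sum(masses[None, :, None] * inv_r3[:, :, None] * d, axis=1)

# One jitted Euler step; XLA fuses and accelerates the whole thing.
step = jax.jit(lambda pos, vel, m, dt: (pos + dt * vel,
                                        vel + dt * acc(pos, m)))
```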
mattjjatgoogle · 2 years ago
Have you seen JAX MD? https://github.com/jax-md/jax-md
mattjjatgoogle commented on Training Deep Networks with Data Parallelism in Jax   mishalaskin.com/posts/dat... · Posted by u/sebg
6gvONxR4sf7o · 3 years ago
Thanks! One last thing, since I have your ear. The function transformation aspects of jax seem to make their way into downstream libraries like haiku, resulting in a lot of "magic" that can be difficult to examine and debug. Are there any utils you made to make jax's own transformations more transparent, which you think might be helpful to third party transformations?

Higher order functions are difficult in general, and it would be fantastic to have core patterns or tools for breaking them open.

mattjjatgoogle · 3 years ago
You're right that downstream libraries have often tended to introduce magic (some more than others), and moreover one library's magic is typically incompatible with other libraries'. It's something that we're working on but we don't have much to show for it yet. Two avenues are:

1. as you say, exposing patterns and tools for library authors to implement transformations/higher-order primitives using JAX's machinery rather than requiring each library to introduce bespoke magic to do the same;

2. adding JAX core infrastructure which directly solves the common problems that libraries tend to solve independently (and with bespoke magic).

mattjjatgoogle commented on Training Deep Networks with Data Parallelism in Jax   mishalaskin.com/posts/dat... · Posted by u/sebg
6gvONxR4sf7o · 3 years ago
> with jax.disable_jit(): ...

That's handy, and I hadn't seen it before, thanks.

It's been a bit, but I think the most frustrating errors were around mapping pytrees (like this issue https://github.com/google/jax/issues/9928). I'm not sure the exact solution, but the axis juggling and specifications were where I remember a lot of pain, and the docs (though extensive) were unclear. At times it feels like improvements are punted on in the hopes that xmap eventually fixes everything (and xmap has been in experimental for far longer than I expected).

Also the barriers where I couldn't disable jit. IIRC pmap automatically jits, so there was no way to avoid staging that part out. When it came to doing some complex jax.lax.ppermute, it felt more difficult than it needed to be to debug.

Next time I encounter something particularly opaque, I'll share on the github issue tracker.

mattjjatgoogle · 3 years ago
Thanks for taking the time to explain these.

> It's been a bit, but I think the most frustrating errors were around mapping pytrees (like this issue https://github.com/google/jax/issues/9928).

We've improved some of these pytree error messages but it seems the vmap one is still not great. Thanks for the ping on it.
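
For context, vmap's pytree-aware in_axes is usually what's wanted in these situations; a small illustrative example:

```python
import jax
import jax.numpy as jnp

params = {"w": jnp.ones((4, 3)), "b": jnp.zeros(3)}
xs = jnp.ones((8, 4))

# in_axes=(None, 0): broadcast params to every batch element,
# and map only over the leading axis of xs.
f = jax.vmap(lambda p, x: x @ p["w"] + p["b"], in_axes=(None, 0))
print(f(params, xs).shape)  # (8, 3)
```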

> Also the barriers where I couldn't disable jit. IIRC pmap automatically jits, so there was no way to avoid staging that part out.

That was indeed a longstanding issue in pmap's implementation. And since people came to expect jit to be "built in" to pmap, it wasn't easy to revise.

However, we recently (https://github.com/google/jax/pull/11854) made `jax.disable_jit()` work with pmap, in the sense that it makes pmap execute eagerly, so that you can print/pdb/etc to your heart's content. (The pmap successor, shard_map (https://jax.readthedocs.io/en/latest/jep/14273-shard-map.htm...), is eager by default. Also it has uniformly good error messages from the start!)
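
For example, here's a sketch of the kind of eager debugging this enables (the function is illustrative):

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    y = jnp.sin(x)
    print("y =", y)  # under jit this prints an abstract tracer, not values
    return 2 * y

f(jnp.arange(4.0))        # prints a Traced<ShapedArray(...)> object

with jax.disable_jit():   # run eagerly: plain print/pdb see concrete values
    f(jnp.arange(4.0))    # prints the actual array; same now works for pmap
```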

> Next time I encounter something particularly opaque, I'll share on the github issue tracker.

Thank you for the constructive feedback!

mattjjatgoogle commented on Training Deep Networks with Data Parallelism in Jax   mishalaskin.com/posts/dat... · Posted by u/sebg
6gvONxR4sf7o · 3 years ago
That's true, and is a massive part of what I love about JAX, but they also form barriers in weird parts of your code, preventing standard introspection tools, which is the single thing I hate about JAX. The errors are amazingly opaque.
mattjjatgoogle · 3 years ago
If you have any particular examples in mind, and time to share them on https://github.com/google/jax/issues, we'd love to try to improve them. Improving error messages is a priority.

About introspection tools: at least for runtime value debugging there's a fundamental challenge. Since jax.jit stages computation out of Python (jax.grad and jax.vmap don't), standard Python runtime value inspection tools, like printing and pdb, can't work under a jax.jit, because the values aren't available while the Python code is executing. You can always remove the jax.jit while debugging (or use `with jax.disable_jit(): ...`), but that's not always convenient, and we need jax.jit for good performance.

We recently added some runtime value debugging tools which work even with jax.jit-staged-out code (even in automatically parallelized code!), though they're not the standard introspection tools: see `jax.debug.print` and `jax.debug.breakpoint` on https://jax.readthedocs.io/en/latest/debugging/index.html and https://jax.readthedocs.io/en/latest/debugging/print_breakpo....
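
A tiny sketch of what that looks like (the function is illustrative):

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    y = jnp.sin(x)
    # Staged out alongside the computation, so it prints concrete runtime
    # values even under jit, unlike Python's built-in print.
    jax.debug.print("y = {y}", y=y)
    return 2 * y

f(jnp.arange(3.0))  # prints the concrete values of y
```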

If you were thinking about other kinds of introspection tooling, I'd love to hear about it!

mattjjatgoogle commented on Training Deep Networks with Data Parallelism in Jax   mishalaskin.com/posts/dat... · Posted by u/sebg
jdeaton · 3 years ago
The abstractions provided by JAX for parallelism are beautiful. JAX is an absolute masterclass in programming-interface design and a lesson in the power of composable primitive operations and FP-inspired design. An astounding amount of complexity is hidden from the user behind primitives like pmap, and the power is exposed through such a simple interface.
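
For instance, the classic data-parallel pattern is just a few lines (a hedged sketch; the names and learning rate are illustrative):

```python
from functools import partial
import jax
import jax.numpy as jnp
from jax import lax

def loss(w, x, y):
    return jnp.mean((x @ w - y) ** 2)

@partial(jax.pmap, axis_name="devices")
def update(w, x, y):
    g = jax.grad(loss)(w, x, y)
    g = lax.pmean(g, axis_name="devices")  # all-reduce: average grads across devices
    return w - 0.1 * g

# Each argument carries a leading axis of size jax.device_count();
# w is replicated along it, while x and y are sharded per device.
```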
mattjjatgoogle · 3 years ago
Thanks for the kind words! We've been doing a lot more work in this direction too, for both compiler-based automatic parallelization [0] and a work-in-progress pmap successor for 'manual' parallelism (per-device code and explicit collectives) [1] which composes seamlessly with the compiler-based stuff.

[0] https://jax.readthedocs.io/en/latest/notebooks/Distributed_a...

[1] https://jax.readthedocs.io/en/latest/jep/14273-shard-map.htm...
