Readit News
marmaduke · 7 months ago
Looks like a nice overview. I've implemented neural ODEs in JAX for low-dimensional problems and it works well, but I keep looking for a good, fast, CPU-first implementation: something suited to models that fit in cache and don't need a GPU or the big Torch/TF machinery.
sitkack · 7 months ago
yberreby · 7 months ago
Anecdotally, I used diffrax (and equinox) throughout last year after jumping between a few differential equation solvers in Python, for a project based on Dynamic Field Theory [1]. I only scratched the surface, but so far, it's been a pleasure to use, and it's quite fast. It also introduced me to equinox [2], by the same author, which I'm using to get the JAX-friendly equivalent of dataclasses.

`vmap`-able differential equation solving is really cool.

[1]: https://dynamicfieldtheory.org/ [2]: https://github.com/patrick-kidger/equinox

marmaduke · 7 months ago
No, I wrote it by hand to pair with my own Heun implementation, since it's used within stochastic delay systems.

JAX is fun but not as effective as I'd like on CPU.
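
For context, the Heun scheme mentioned is a predictor-corrector step; a minimal numpy sketch for additive noise (no delay term shown; `f`, `sigma`, and `dW` are illustrative placeholders, not the parent commenter's code):

```python
import numpy as np

def heun_step(f, t, y, dt, sigma, dW):
    """One stochastic Heun step for dy = f(t, y) dt + sigma dW (additive noise).

    The same noise increment dW is used in the predictor and the corrector.
    """
    drift = f(t, y)
    y_pred = y + drift * dt + sigma * dW                      # Euler predictor
    drift_pred = f(t + dt, y_pred)
    return y + 0.5 * (drift + drift_pred) * dt + sigma * dW   # trapezoidal corrector

# With sigma = 0 this reduces to deterministic Heun (2nd-order Runge-Kutta):
f = lambda t, y: -y
y, dt = 1.0, 0.01
for k in range(100):
    y = heun_step(f, k * dt, y, dt, sigma=0.0, dW=0.0)
```

A delayed system would additionally pass a history buffer into `f` so the drift can read the state at `t - tau`; that part is omitted here.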

barrenko · 7 months ago
How would you describe what a neural ODE is in the simplest possible terms? Let's say I know what an NN and a DE are :).
kk58 · 7 months ago
A classic NN (think ResNet) pushes a vector through a stack of layers, updating the hidden state one discrete step at a time: h_{k+1} = h_k + f(h_k, θ_k). Backprop adjusts the weights θ until the predictions are right.

A neural ODE takes the continuum limit of those layers: treat layer depth as continuous time, so instead of N discrete updates the hidden state evolves as dh/dt = f(h(t), t, θ). The forward pass is then just an ODE solve from t=0 to t=1 with an off-the-shelf solver, and you train θ by differentiating through the solve (e.g. via the adjoint method). You get a "continuous-depth" network whose accuracy/compute trade-off is controlled by the solver rather than by a fixed layer count.
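
To make the "solver as forward pass" idea concrete, here's a toy sketch (fixed-step Euler and a hypothetical one-layer vector field `f`; real implementations use adaptive solvers and the adjoint trick for gradients):

```python
import jax.numpy as jnp

def f(h, t, w):
    # The learned vector field: a tiny one-layer "network" (placeholder).
    return jnp.tanh(w @ h)

def neural_ode_forward(h0, w, t0=0.0, t1=1.0, n_steps=100):
    # Euler discretization of dh/dt = f(h, t, w): each step plays the role
    # of one "infinitesimal layer" of the continuous-depth network.
    dt = (t1 - t0) / n_steps
    h = h0
    for k in range(n_steps):
        h = h + dt * f(h, t0 + k * dt, w)
    return h

h0 = jnp.array([1.0, -1.0])
w = jnp.zeros((2, 2))   # with zero weights, f == 0 and the state is unchanged
h1 = neural_ode_forward(h0, w)
```

Note how the Euler update `h + dt * f(h, t, w)` is exactly the ResNet residual update with a step size attached; swapping Euler for a better solver is what libraries like diffrax handle for you.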