Once you have the matrix implementation from Step 2 (Implementation 3), it's rather straightforward to extend your N-body simulator to run on a GPU with JAX: you can just add `import jax.numpy as jnp` and replace all the `np.`s with `jnp.`s.
For a few-body system (e.g., the Solar System) this probably won't provide any speedup. But once you get to ~100 bodies you should start to see substantial speedups by running the simulator on a GPU.
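As a rough illustration of the `np.` to `jnp.` swap, here is a minimal sketch of a vectorized acceleration routine. The Step 2 code is not reproduced here, so the function name `accelerations`, the `softening` parameter, and the choice `G = 1.0` are all assumptions made for this example rather than the tutorial's actual implementation.

```python
import jax
import jax.numpy as jnp

G = 1.0  # gravitational constant in simulation units (assumed for this sketch)


@jax.jit
def accelerations(pos, mass, softening=1e-3):
    """Pairwise gravitational accelerations for N bodies.

    pos:  (N, 3) array of positions
    mass: (N,)   array of masses
    The name, signature, and softening term are hypothetical.
    """
    # diff[i, j] = pos[j] - pos[i], shape (N, N, 3)
    diff = pos[None, :, :] - pos[:, None, :]
    # Softened squared distances; the softening keeps the diagonal finite.
    dist2 = jnp.sum(diff**2, axis=-1) + softening**2
    inv_dist3 = dist2**-1.5
    # a_i = G * sum_j m_j * diff_ij / |diff_ij|^3.
    # The i == j terms vanish because diff[i, i] is zero.
    return G * jnp.einsum("ij,j,ijk->ik", inv_dist3, mass, diff)


# Two unit masses one unit apart along x.
pos = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
mass = jnp.array([1.0, 1.0])
acc = accelerations(pos, mass)
```

Apart from the `jax.jit` decorator, the body is the same array code you would write with NumPy. JAX places arrays on a GPU automatically when one is available and otherwise falls back to the CPU.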
Not true for regular GPUs like the RTX 5090. They have atrocious float64 performance compared to a CPU. You need a special GPU designed for scientific computations (many float64 cores).
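Precision is worth checking explicitly, because JAX uses 32-bit floats by default and only computes in float64 when asked to. A minimal sketch of opting in, assuming it runs at program start before any arrays are created:

```python
import jax

# Enable 64-bit types; JAX defaults to float32 unless this flag is set.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

x = jnp.ones(3)
print(x.dtype)        # float64 once x64 is enabled
print(jax.devices())  # lists the devices JAX will run on
```

If the default float32 is accurate enough for your simulation, the float64 limitation of consumer GPUs does not come into play; if you do need float64, it is worth timing the GPU against the CPU on your own hardware.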