Zichen LI c600d30161 a simple script with jit and AD 2025-11-28 14:08:18 +01:00


This repository is intended for JAX beginners. We recommend starting with the tutorials in [1].

We are motivated by the article by Mohit and David (2025) [2], in particular by its use of automatic differentiation [3] and just-in-time compilation [4].
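As a first taste of these two features, here is a minimal sketch (not taken from [2]) that differentiates a scalar function with `jax.grad` and compiles the derivative with `jax.jit`. The function `f` is a hypothetical example, chosen only because its derivative is easy to check by hand.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical scalar function: f(x) = x^3 + 2x, so f'(x) = 3x^2 + 2.
    return x**3 + 2.0 * x


# jax.grad returns a new function that computes df/dx by automatic differentiation.
df = jax.grad(f)

# jax.jit traces the function once per input shape/dtype and compiles it with XLA.
df_jit = jax.jit(df)

x = 2.0
print("f(2)      =", f(x))       # 2^3 + 2*2 = 12
print("f'(2)     =", df(x))      # 3*2^2 + 2 = 14
print("jit f'(2) =", df_jit(x))  # same value, computed by the compiled function
```

The first call to `df_jit` triggers compilation; later calls with inputs of the same shape and dtype reuse the compiled code.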

Install JAX

GPU programming is a direction we intend to pursue in this open-source project. The code we write is the same for the CPU and GPU versions of JAX; the difference lies only in the installation. For the GPU version, pip installs additional CUDA-specific packages alongside `jaxlib`, and these must be compatible with the NVIDIA driver and CUDA version on the machine.
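To illustrate that the same code runs on either backend, here is a minimal sketch: the hypothetical function `saxpy` is written once, and `jax.device_put` decides where the data lives. On a CPU-only installation, the default backend is simply the CPU and the code runs unchanged.

```python
import jax
import jax.numpy as jnp


@jax.jit
def saxpy(a, x, y):
    # Device-agnostic: nothing here mentions CPU or GPU.
    return a * x + y


x = jnp.arange(4.0)
y = jnp.ones(4)

# Default backend: "gpu" if a CUDA device is available, otherwise "cpu".
print("default backend:", jax.default_backend())
out_default = saxpy(2.0, x, y)

# Explicitly place the inputs on the CPU; the same compiled-on-demand function runs there.
cpu = jax.devices("cpu")[0]
out_cpu = saxpy(2.0, jax.device_put(x, cpu), jax.device_put(y, cpu))

print(out_default, out_default.devices())
print(out_cpu, out_cpu.devices())
```

Placing data explicitly is rarely needed in practice; JAX uses the first device of the default backend unless told otherwise.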

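First, check the NVIDIA driver version and the CUDA version (the "CUDA Version" shown by `nvidia-smi` is the highest CUDA version the driver supports):

```shell
nvidia-smi
```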

CUDA 12 requires NVIDIA driver version ≥ 525, a mainstream and stable combination supported by almost all major frameworks. We therefore install the JAX GPU version built for CUDA 12.

```shell
python3 -m venv JAX-venv
source JAX-venv/bin/activate
(JAX-venv) pip install --upgrade pip
(JAX-venv) pip install ipython
(JAX-venv) pip install --upgrade "jax[cuda12]"
(JAX-venv) ipython # /path/to/JAX-venv/bin/ipython
```

Test script in IPython:

```python
import jax
import jax.numpy as jnp
import jaxlib

print("jax:", jax.__version__)
print("jaxlib:", jaxlib.__version__)

print("devices:", jax.devices())

x = jnp.arange(5.)
print("x.device:", x.device)
```

Reference output:

```text
jax: 0.6.2
jaxlib: 0.6.2
devices: [CudaDevice(id=0)]
x.device: cuda:0
```
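Beyond the printed device list, a slightly stronger sanity check is to run a small jitted computation and compare it with NumPy on the host. This is a minimal sketch; the matrix size and the helper `gram` are arbitrary, hypothetical choices. On GPUs, float32 matrix products may use reduced internal precision by default, so expect small differences rather than exact agreement.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Which backend does JAX use by default? Expect "gpu" after a successful CUDA install.
backend = jax.default_backend()
print("default backend:", backend)
if backend != "gpu":
    print("Warning: JAX is not using a GPU; check the driver and the jax[cuda12] install.")

# A small random matrix generated on the default device.
key = jax.random.PRNGKey(0)
a = jax.random.normal(key, (256, 256))


@jax.jit
def gram(m):
    # Gram matrix m^T m.
    return m.T @ m


# JAX dispatches work asynchronously; block_until_ready waits for the result.
g = gram(a).block_until_ready()

# Reference result computed by NumPy on the host.
g_ref = np.asarray(a).T @ np.asarray(a)
print("max abs diff vs NumPy:", float(jnp.max(jnp.abs(g - g_ref))))
print("result device:", g.devices())
```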

Structure of the project:

```text
JAX/
  ├─ JAX-venv/        # Virtual environment
  ├─ src/             # Python code
  ├─ notebooks/       # Experimental notebooks
  ├─ tests/           # Unit tests
  └─ pyproject.toml   # or requirements.txt
```
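For the `tests/` folder, here is a minimal sketch of a unit test, assuming pytest as the test runner. The file name `tests/test_autodiff.py` and the function `energy` are hypothetical. The test checks a `jax.grad` result against the analytic gradient and checks that `jax.jit` does not change the result.

```python
# tests/test_autodiff.py (hypothetical file name)
import jax
import jax.numpy as jnp


def energy(u):
    # Hypothetical quadratic "energy": E(u) = 0.5 * sum(u_i^2), so dE/du = u.
    return 0.5 * jnp.sum(u**2)


def test_grad_matches_analytic():
    u = jnp.array([1.0, -2.0, 3.0])
    g = jax.grad(energy)(u)
    # The analytic gradient of E is u itself.
    assert jnp.allclose(g, u)


def test_jit_gives_same_result():
    u = jnp.linspace(0.0, 1.0, 5)
    # Compiling with jax.jit should not change the numerical result.
    assert jnp.allclose(jax.jit(energy)(u), energy(u))
```

With pytest installed in the virtual environment, the tests can be run from the project root with `pytest tests/`.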

References

[1] https://uvadlc-notebooks.readthedocs.io/en/latest/

[2] https://www.sciencedirect.com/science/article/pii/S0045782524008260

[3] https://docs.jax.dev/en/latest/automatic-differentiation.html

[4] https://docs.jax.dev/en/latest/jit-compilation.html