Here is a repository for beginners in JAX. We recommend starting with the tutorials in [1]. Our motivation comes from the article by Mohit and David (2025) [2], in particular its use of automatic differentiation [3] and just-in-time compilation [4].

### Install JAX

GPU computing is a future direction for this open-source project. The code we write is the same for the CPU and GPU versions of JAX; only the installation differs. For the GPU version, pip installs CUDA-enabled wheels alongside jaxlib, and these require a sufficiently recent NVIDIA driver.

First, check the NVIDIA driver version and the highest CUDA version that driver supports:

```bash
nvidia-smi
```

CUDA 12 requires driver version ≥ 525. This is a mainstream, stable combination supported by almost all major frameworks, so we install the JAX GPU version built for CUDA 12:

```bash
python3 -m venv JAX-venv
source JAX-venv/bin/activate          # the shell prompt now shows (JAX-venv)
pip install --upgrade pip
pip install ipython
pip install --upgrade "jax[cuda12]"
which ipython                         # should print /path/to/JAX-venv/bin/ipython
ipython
```

Test script in IPython:

```python
import jax
import jax.numpy as jnp
import jaxlib

print("jax:", jax.__version__)
print("jaxlib:", jaxlib.__version__)
print("devices:", jax.devices())   # [CpuDevice(id=0)] means the GPU build is not in use

x = jnp.arange(5.)
print("x.device:", x.device)       # device that holds the array
```

Reference output (version numbers and device IDs depend on the machine):

```
jax: 0.6.2
jaxlib: 0.6.2
devices: [CudaDevice(id=0)]
x.device: cuda:0
```

Structure of the project:

```
JAX/
├─ JAX-venv/
├─ src/            # Python source code
├─ notebooks/      # Experimental notebooks
├─ tests/          # Unit tests
└─ pyproject.toml  # Or requirements.txt
```

### References

[1] UvA Deep Learning Tutorials: https://uvadlc-notebooks.readthedocs.io/en/latest/

[2] Mohit and David (2025): https://www.sciencedirect.com/science/article/pii/S0045782524008260?via%3Dihub

[3] JAX documentation, "Automatic differentiation": https://docs.jax.dev/en/latest/automatic-differentiation.html

[4] JAX documentation, "Just-in-time compilation": https://docs.jax.dev/en/latest/jit-compilation.html
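### A first example: `grad` and `jit`

Once the installation check passes, a small experiment with the two features cited above, automatic differentiation [3] and just-in-time compilation [4], is a natural next step. The sketch below is a hypothetical toy example (a mean-squared-error loss for a linear model `y ≈ w * x + b`), not the method of [2]; the names `loss`, `grad_loss`, and `fast_grad_loss` are illustrative. It uses only `jax.grad` and `jax.jit`, and runs unchanged on CPU or GPU.

```python
import jax
import jax.numpy as jnp


# Hypothetical toy loss: mean squared error of a linear model pred = w * x + b.
def loss(params, x, y):
    w, b = params
    pred = w * x + b
    return jnp.mean((pred - y) ** 2)


# jax.grad differentiates with respect to the first argument (params) by default
# and returns gradients with the same structure as params (here a tuple (dw, db)).
grad_loss = jax.grad(loss)

# jax.jit traces the function once per input shape/dtype and compiles it with XLA.
fast_grad_loss = jax.jit(grad_loss)

x = jnp.array([0.0, 1.0, 2.0, 3.0])
y = 2.0 * x + 1.0                                  # data generated with w = 2, b = 1
params = (jnp.array(0.0), jnp.array(0.0))          # start from w = 0, b = 0

gw, gb = fast_grad_loss(params, x, y)
print("dL/dw:", gw, "dL/db:", gb)                  # dL/dw = -17, dL/db = -8 for this data
```

Calling `fast_grad_loss` again with arrays of the same shape and dtype reuses the compiled version; a new shape triggers a fresh trace and compilation. The gradients can be checked by hand: at `w = b = 0` the residuals are `[-1, -3, -5, -7]`, so `dL/dw = 2 * mean(residual * x) = -17` and `dL/db = 2 * mean(residual) = -8`.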