From e989dc9927e9cbb217a021959eeaab3971d1a5cb Mon Sep 17 00:00:00 2001 From: Zichen LI Date: Fri, 28 Nov 2025 11:27:43 +0100 Subject: [PATCH] install JAX guide --- AUTHORS | 1 + JAX/README.md | 47 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+) create mode 100644 JAX/README.md diff --git a/AUTHORS b/AUTHORS index 422498e..39624bd 100644 --- a/AUTHORS +++ b/AUTHORS @@ -1 +1,2 @@ Lucas Frérot Sorbonne Université, CNRS, Institut Jean Le Rond d'Alembert, F-75005 Paris, France +Zichen Li Sorbonne Université, CNRS, Institut Jean Le Rond d'Alembert, F-75005 Paris, France diff --git a/JAX/README.md b/JAX/README.md new file mode 100644 index 0000000..3a9c5c8 --- /dev/null +++ b/JAX/README.md @@ -0,0 +1,47 @@ +Here is a repo for beginners in JAX. We recommand to start with the documentation in [1]. + +We are motivated by the article of Mohit and David(2025)[2], especially the automatic differentiation[3] and just-in-time compilation[4]. + +### Install JAX +GPU programming is a future trend for our open-source project. The codes that we write in CPU and GPU version of JAX are the same. The difference is that pip will install a jaxlib wheel for GPU version depending on NVIDIA driver version and CUDA version. + +First we can check our NVIDIA driver version and CUDA version +```bash +nvidia-smi +``` +CUDA 12 requires driver version ≥ 525, which is already a mainstream and stable combination, supported by almost all frameworks. We will install the JAX GPU version suitable for CUDA 12. +```bash +python3 -m venv JAX-venv +source JAX-venv/bin/activate +(JAX-venv) pip install --upgrade pip +(JAX-venv) pip install ipython +(JAX-venv) pip install --upgrade "jax[cuda12]" +(JAX-venv) ipython # /path/to/JAX-venv/bin/ipython +``` +Test script in Ipython: +``` +import jax +import jax.numpy as jnp +import jaxlib + +print("jax:", jax.__version__) +print("jaxlib:", jaxlib.__version__) + +print("devices:", jax.devices()) + +x = jnp.arange(5.) +print("x.device:", x.device) +``` +Reference output: +``` +jax: 0.6.2 +jaxlib: 0.6.2 +devices: [CudaDevice(id=0)] +x.device: cuda:0 +``` + +### Source +[1] https://uvadlc-notebooks.readthedocs.io/en/latest/ +[2] https://www.sciencedirect.com/science/article/pii/S0045782524008260?via%3Dihub +[3] https://docs.jax.dev/en/latest/automatic-differentiation.html +[4] https://docs.jax.dev/en/latest/jit-compilation.html