forked from lfrerot/good_simulation_practices
install JAX guide
This commit is contained in:
parent bfe4edaf62
commit e989dc9927
AUTHORS (1 addition):

@@ -1 +1,2 @@
 Lucas Frérot <lucas.frerot@sorbonne-universite.fr> Sorbonne Université, CNRS, Institut Jean Le Rond d'Alembert, F-75005 Paris, France
+Zichen Li <zichen.li@sorbonne-universite.fr> Sorbonne Université, CNRS, Institut Jean Le Rond d'Alembert, F-75005 Paris, France
@@ -0,0 +1,47 @@
This is a repository for beginners in JAX. We recommend starting with the documentation in [1].

We were motivated by the article of Mohit and David (2025) [2], especially its use of automatic differentiation [3] and just-in-time compilation [4].
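As a quick taste of these two features, here is a minimal sketch; the function `f` is our own toy example, not taken from the article:

```python
import jax
import jax.numpy as jnp

def f(x):
    # toy scalar function: f(x) = x^2 + sin(x)
    return x ** 2 + jnp.sin(x)

# automatic differentiation: df/dx = 2x + cos(x)
df = jax.grad(f)
print(float(df(0.0)))     # 2*0 + cos(0) = 1.0

# just-in-time compilation: trace once, then run compiled XLA code
f_jit = jax.jit(f)
print(float(f_jit(0.0)))  # 0^2 + sin(0) = 0.0
```

`jax.grad` and `jax.jit` are ordinary function transformations, so they compose freely, e.g. `jax.jit(jax.grad(f))`.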
### Install JAX

GPU programming is a future direction for our open-source project. The code we write is identical for the CPU and GPU versions of JAX; the difference is that, for the GPU version, pip installs a jaxlib wheel that matches the NVIDIA driver and CUDA versions.

First, check the NVIDIA driver and CUDA versions:

```bash
nvidia-smi
```

CUDA 12 requires driver version ≥ 525, which is now a mainstream and stable combination supported by almost all frameworks. We will install the JAX GPU build for CUDA 12:

```bash
python3 -m venv JAX-venv
source JAX-venv/bin/activate
(JAX-venv) pip install --upgrade pip
(JAX-venv) pip install ipython
(JAX-venv) pip install --upgrade "jax[cuda12]"
(JAX-venv) ipython  # /path/to/JAX-venv/bin/ipython
```
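Before running the full test script, a short sanity check (our own sketch, not part of the official install steps) shows which backends this jaxlib build can see; it also runs on a CPU-only machine:

```python
import jax

# CPU devices are always present; jax.devices("gpu") raises
# RuntimeError when no GPU backend is available.
print("all devices:", jax.devices())
try:
    gpus = jax.devices("gpu")
except RuntimeError:
    gpus = []
print("GPU backend available:", len(gpus) > 0)
```

If this prints `GPU backend available: False` after installing `jax[cuda12]`, the usual culprit is a driver/CUDA mismatch reported by `nvidia-smi`.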

Test script in IPython:

```python
import jax
import jax.numpy as jnp
import jaxlib

print("jax:", jax.__version__)
print("jaxlib:", jaxlib.__version__)

print("devices:", jax.devices())

x = jnp.arange(5.)
print("x.device:", x.device)
```

Reference output:

```
jax: 0.6.2
jaxlib: 0.6.2
devices: [CudaDevice(id=0)]
x.device: cuda:0
```

### Source

[1] https://uvadlc-notebooks.readthedocs.io/en/latest/
|
||||||
|
[2] https://www.sciencedirect.com/science/article/pii/S0045782524008260?via%3Dihub
|
||||||
|
[3] https://docs.jax.dev/en/latest/automatic-differentiation.html
|
||||||
|
[4] https://docs.jax.dev/en/latest/jit-compilation.html
|
||||||