forked from lfrerot/good_simulation_practices
Compare commits
No commits in common. "main" and "main" have entirely different histories.
|
|
@ -1,143 +0,0 @@
|
||||||
# Created by https://www.toptal.com/developers/gitignore/api/python
|
|
||||||
# Edit at https://www.toptal.com/developers/gitignore?templates=python
|
|
||||||
|
|
||||||
# --------------------
|
|
||||||
# Python
|
|
||||||
# --------------------
|
|
||||||
__pycache__/
|
|
||||||
*.py[cod]
|
|
||||||
*.pyo
|
|
||||||
*.pyd
|
|
||||||
*.so
|
|
||||||
*$py.class
|
|
||||||
*.sage.py
|
|
||||||
|
|
||||||
# --------------------
|
|
||||||
# Packaging / builds
|
|
||||||
# --------------------
|
|
||||||
.Python
|
|
||||||
build/
|
|
||||||
dist/
|
|
||||||
develop-eggs/
|
|
||||||
downloads/
|
|
||||||
eggs/
|
|
||||||
.eggs/
|
|
||||||
lib/
|
|
||||||
lib64/
|
|
||||||
parts/
|
|
||||||
sdist/
|
|
||||||
share/python-wheels/
|
|
||||||
var/
|
|
||||||
wheels/
|
|
||||||
*.egg
|
|
||||||
*.egg-info/
|
|
||||||
.installed.cfg
|
|
||||||
MANIFEST
|
|
||||||
*.manifest
|
|
||||||
*.spec
|
|
||||||
|
|
||||||
# --------------------
|
|
||||||
# Logs / caches / temp
|
|
||||||
# --------------------
|
|
||||||
*.log
|
|
||||||
*.tmp
|
|
||||||
*.temp
|
|
||||||
*.out
|
|
||||||
pip-log.txt
|
|
||||||
pip-delete-this-directory.txt
|
|
||||||
.cache
|
|
||||||
.hypothesis/
|
|
||||||
.pytest_cache/
|
|
||||||
.tox/
|
|
||||||
.nox/
|
|
||||||
.coverage
|
|
||||||
.coverage.*
|
|
||||||
*.cover
|
|
||||||
*.py,cover
|
|
||||||
cover/
|
|
||||||
htmlcov/
|
|
||||||
|
|
||||||
# --------------------
|
|
||||||
# Databases / translations
|
|
||||||
# --------------------
|
|
||||||
*.mo
|
|
||||||
*.pot
|
|
||||||
db.sqlite3
|
|
||||||
db.sqlite3-journal
|
|
||||||
|
|
||||||
# --------------------
|
|
||||||
# Framework / tool artifacts
|
|
||||||
# --------------------
|
|
||||||
local_settings.py
|
|
||||||
instance/
|
|
||||||
.webassets-cache
|
|
||||||
.scrapy
|
|
||||||
docs/_build/
|
|
||||||
.pybuilder/
|
|
||||||
target/
|
|
||||||
celerybeat-schedule
|
|
||||||
celerybeat.pid
|
|
||||||
/site
|
|
||||||
|
|
||||||
# --------------------
|
|
||||||
# Virtual environments
|
|
||||||
# --------------------
|
|
||||||
JAX-venv/
|
|
||||||
.venv/
|
|
||||||
venv/
|
|
||||||
env/
|
|
||||||
ENV/
|
|
||||||
.env
|
|
||||||
env.bak/
|
|
||||||
venv.bak/
|
|
||||||
|
|
||||||
# --------------------
|
|
||||||
# Python packaging helpers
|
|
||||||
# --------------------
|
|
||||||
.pdm.toml
|
|
||||||
poetry.toml
|
|
||||||
__pypackages__/
|
|
||||||
|
|
||||||
# --------------------
|
|
||||||
# Notebooks / REPL
|
|
||||||
# --------------------
|
|
||||||
.ipynb_checkpoints/
|
|
||||||
*.nb.py
|
|
||||||
profile_default/
|
|
||||||
ipython_config.py
|
|
||||||
|
|
||||||
# --------------------
|
|
||||||
# Type checking / analysis
|
|
||||||
# --------------------
|
|
||||||
.mypy_cache/
|
|
||||||
.dmypy.json
|
|
||||||
dmypy.json
|
|
||||||
.pyre/
|
|
||||||
.pytype/
|
|
||||||
.ruff_cache/
|
|
||||||
pyrightconfig.json
|
|
||||||
|
|
||||||
# --------------------
|
|
||||||
# Editors / IDEs
|
|
||||||
# --------------------
|
|
||||||
.vscode/
|
|
||||||
.history/
|
|
||||||
*.code-workspace
|
|
||||||
.spyderproject
|
|
||||||
.spyproject
|
|
||||||
.ropeproject
|
|
||||||
|
|
||||||
# --------------------
|
|
||||||
# Operating system
|
|
||||||
# --------------------
|
|
||||||
.DS_Store
|
|
||||||
Thumbs.db
|
|
||||||
|
|
||||||
# --------------------
|
|
||||||
# Editor temp files
|
|
||||||
# --------------------
|
|
||||||
*.swp
|
|
||||||
*.swo
|
|
||||||
*~
|
|
||||||
|
|
||||||
|
|
||||||
1
AUTHORS
1
AUTHORS
|
|
@ -1,2 +1 @@
|
||||||
Lucas Frérot <lucas.frerot@sorbonne-universite.fr> Sorbonne Université, CNRS, Institut Jean Le Rond d'Alembert, F-75005 Paris, France
|
Lucas Frérot <lucas.frerot@sorbonne-universite.fr> Sorbonne Université, CNRS, Institut Jean Le Rond d'Alembert, F-75005 Paris, France
|
||||||
Zichen Li <zichen.li@sorbonne-universite.fr> Sorbonne Université, CNRS, Institut Jean Le Rond d'Alembert, F-75005 Paris, France
|
|
||||||
|
|
|
||||||
|
|
@ -1,59 +0,0 @@
|
||||||
Here is a repo for beginners in JAX. We recommand to start with the documentation in [1].
|
|
||||||
|
|
||||||
We are motivated by the article of Mohit and David(2025)[2], especially the automatic differentiation[3] and just-in-time compilation[4].
|
|
||||||
|
|
||||||
### Install JAX
|
|
||||||
GPU programming is a future trend for our open-source project. The codes that we write in CPU and GPU version of JAX are the same. The difference is that pip will install a jaxlib wheel for GPU version depending on NVIDIA driver version and CUDA version.
|
|
||||||
|
|
||||||
First we can check our NVIDIA driver version and CUDA version
|
|
||||||
```bash
|
|
||||||
nvidia-smi
|
|
||||||
```
|
|
||||||
CUDA 12 requires driver version ≥ 525, which is already a mainstream and stable combination, supported by almost all frameworks. We will install the JAX GPU version suitable for CUDA 12.
|
|
||||||
```bash
|
|
||||||
python3 -m venv JAX-venv
|
|
||||||
source JAX-venv/bin/activate
|
|
||||||
(JAX-venv) pip install --upgrade pip
|
|
||||||
(JAX-venv) pip install ipython
|
|
||||||
(JAX-venv) pip install --upgrade "jax[cuda12]"
|
|
||||||
(JAX-venv) ipython # /path/to/JAX-venv/bin/ipython
|
|
||||||
```
|
|
||||||
Test script in Ipython:
|
|
||||||
```
|
|
||||||
import jax
|
|
||||||
import jax.numpy as jnp
|
|
||||||
import jaxlib
|
|
||||||
|
|
||||||
print("jax:", jax.__version__)
|
|
||||||
print("jaxlib:", jaxlib.__version__)
|
|
||||||
|
|
||||||
print("devices:", jax.devices())
|
|
||||||
|
|
||||||
x = jnp.arange(5.)
|
|
||||||
print("x.device:", x.device)
|
|
||||||
```
|
|
||||||
Reference output:
|
|
||||||
```
|
|
||||||
jax: 0.6.2
|
|
||||||
jaxlib: 0.6.2
|
|
||||||
devices: [CudaDevice(id=0)]
|
|
||||||
x.device: cuda:0
|
|
||||||
```
|
|
||||||
|
|
||||||
Structure of the project:
|
|
||||||
```
|
|
||||||
JAX/
|
|
||||||
├─ JAX-venv/
|
|
||||||
├─ src/ # Python codes
|
|
||||||
├─ notebooks/ # Experimental notebook
|
|
||||||
├─ tests/ # Unit tests
|
|
||||||
└─ pyproject.toml # Or requirements.txt
|
|
||||||
```
|
|
||||||
### Source
|
|
||||||
[1] https://uvadlc-notebooks.readthedocs.io/en/latest/
|
|
||||||
|
|
||||||
[2] https://www.sciencedirect.com/science/article/pii/S0045782524008260?via%3Dihub
|
|
||||||
|
|
||||||
[3] https://docs.jax.dev/en/latest/automatic-differentiation.html
|
|
||||||
|
|
||||||
[4] https://docs.jax.dev/en/latest/jit-compilation.html
|
|
||||||
|
|
@ -1,33 +0,0 @@
|
||||||
'''
|
|
||||||
This is a test file for JAX functionalities. Codes are from Jax documentation.(https://jax.readthedocs.io/en/latest/index.html)
|
|
||||||
'''
|
|
||||||
import jax
|
|
||||||
import jax.numpy as jnp
|
|
||||||
from jax import grad, jit
|
|
||||||
|
|
||||||
# Example: JIT compilation
|
|
||||||
def selu(x, alpha=1.67, lambda_=1.05):
|
|
||||||
return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
|
|
||||||
|
|
||||||
x = jnp.arange(1000000)
|
|
||||||
|
|
||||||
import time
|
|
||||||
# measure a single execution and ensure JAX computation completes
|
|
||||||
start = time.perf_counter()
|
|
||||||
selu(x).block_until_ready()
|
|
||||||
end = time.perf_counter()
|
|
||||||
print("Elapsed:", end - start)
|
|
||||||
|
|
||||||
selu_jit = jit(selu)
|
|
||||||
# measure a single execution and ensure JAX computation completes
|
|
||||||
start = time.perf_counter()
|
|
||||||
selu_jit(x).block_until_ready()
|
|
||||||
end = time.perf_counter()
|
|
||||||
print("Elapsed with JIT:", end - start)
|
|
||||||
|
|
||||||
# Example: Automatic differentiation\
|
|
||||||
def f(x):
|
|
||||||
return jnp.sin(x) + 0.5 * x ** 2
|
|
||||||
df = grad(f)
|
|
||||||
x = 2.0
|
|
||||||
print("f'(2.0) =", df(x)) # Should print the derivative of f at x=2.0
|
|
||||||
Loading…
Reference in New Issue