forked from lfrerot/good_simulation_practices
Compare commits
No commits in common. "main" and "main" have entirely different histories.
|
|
@ -1,143 +0,0 @@
|
|||
# Created by https://www.toptal.com/developers/gitignore/api/python
|
||||
# Edit at https://www.toptal.com/developers/gitignore?templates=python
|
||||
|
||||
# --------------------
|
||||
# Python
|
||||
# --------------------
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*.pyo
|
||||
*.pyd
|
||||
*.so
|
||||
*$py.class
|
||||
*.sage.py
|
||||
|
||||
# --------------------
|
||||
# Packaging / builds
|
||||
# --------------------
|
||||
.Python
|
||||
build/
|
||||
dist/
|
||||
develop-eggs/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
share/python-wheels/
|
||||
var/
|
||||
wheels/
|
||||
*.egg
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
MANIFEST
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# --------------------
|
||||
# Logs / caches / temp
|
||||
# --------------------
|
||||
*.log
|
||||
*.tmp
|
||||
*.temp
|
||||
*.out
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
.cache
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
*.cover
|
||||
*.py,cover
|
||||
cover/
|
||||
htmlcov/
|
||||
|
||||
# --------------------
|
||||
# Databases / translations
|
||||
# --------------------
|
||||
*.mo
|
||||
*.pot
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# --------------------
|
||||
# Framework / tool artifacts
|
||||
# --------------------
|
||||
local_settings.py
|
||||
instance/
|
||||
.webassets-cache
|
||||
.scrapy
|
||||
docs/_build/
|
||||
.pybuilder/
|
||||
target/
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
/site
|
||||
|
||||
# --------------------
|
||||
# Virtual environments
|
||||
# --------------------
|
||||
JAX-venv/
|
||||
.venv/
|
||||
venv/
|
||||
env/
|
||||
ENV/
|
||||
.env
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# --------------------
|
||||
# Python packaging helpers
|
||||
# --------------------
|
||||
.pdm.toml
|
||||
poetry.toml
|
||||
__pypackages__/
|
||||
|
||||
# --------------------
|
||||
# Notebooks / REPL
|
||||
# --------------------
|
||||
.ipynb_checkpoints/
|
||||
*.nb.py
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# --------------------
|
||||
# Type checking / analysis
|
||||
# --------------------
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
.pyre/
|
||||
.pytype/
|
||||
.ruff_cache/
|
||||
pyrightconfig.json
|
||||
|
||||
# --------------------
|
||||
# Editors / IDEs
|
||||
# --------------------
|
||||
.vscode/
|
||||
.history/
|
||||
*.code-workspace
|
||||
.spyderproject
|
||||
.spyproject
|
||||
.ropeproject
|
||||
|
||||
# --------------------
|
||||
# Operating system
|
||||
# --------------------
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
|
||||
# --------------------
|
||||
# Editor temp files
|
||||
# --------------------
|
||||
*.swp
|
||||
*.swo
|
||||
*~
|
||||
|
||||
|
||||
1
AUTHORS
1
AUTHORS
|
|
@ -1,2 +1 @@
|
|||
Lucas Frérot <lucas.frerot@sorbonne-universite.fr> Sorbonne Université, CNRS, Institut Jean Le Rond d'Alembert, F-75005 Paris, France
|
||||
Zichen Li <zichen.li@sorbonne-universite.fr> Sorbonne Université, CNRS, Institut Jean Le Rond d'Alembert, F-75005 Paris, France
|
||||
|
|
|
|||
|
|
@ -1,59 +0,0 @@
|
|||
Here is a repo for beginners in JAX. We recommand to start with the documentation in [1].
|
||||
|
||||
We are motivated by the article of Mohit and David(2025)[2], especially the automatic differentiation[3] and just-in-time compilation[4].
|
||||
|
||||
### Install JAX
|
||||
GPU programming is a future trend for our open-source project. The codes that we write in CPU and GPU version of JAX are the same. The difference is that pip will install a jaxlib wheel for GPU version depending on NVIDIA driver version and CUDA version.
|
||||
|
||||
First we can check our NVIDIA driver version and CUDA version
|
||||
```bash
|
||||
nvidia-smi
|
||||
```
|
||||
CUDA 12 requires driver version ≥ 525, which is already a mainstream and stable combination, supported by almost all frameworks. We will install the JAX GPU version suitable for CUDA 12.
|
||||
```bash
|
||||
python3 -m venv JAX-venv
|
||||
source JAX-venv/bin/activate
|
||||
(JAX-venv) pip install --upgrade pip
|
||||
(JAX-venv) pip install ipython
|
||||
(JAX-venv) pip install --upgrade "jax[cuda12]"
|
||||
(JAX-venv) ipython # /path/to/JAX-venv/bin/ipython
|
||||
```
|
||||
Test script in Ipython:
|
||||
```
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import jaxlib
|
||||
|
||||
print("jax:", jax.__version__)
|
||||
print("jaxlib:", jaxlib.__version__)
|
||||
|
||||
print("devices:", jax.devices())
|
||||
|
||||
x = jnp.arange(5.)
|
||||
print("x.device:", x.device)
|
||||
```
|
||||
Reference output:
|
||||
```
|
||||
jax: 0.6.2
|
||||
jaxlib: 0.6.2
|
||||
devices: [CudaDevice(id=0)]
|
||||
x.device: cuda:0
|
||||
```
|
||||
|
||||
Structure of the project:
|
||||
```
|
||||
JAX/
|
||||
├─ JAX-venv/
|
||||
├─ src/ # Python codes
|
||||
├─ notebooks/ # Experimental notebook
|
||||
├─ tests/ # Unit tests
|
||||
└─ pyproject.toml # Or requirements.txt
|
||||
```
|
||||
### Source
|
||||
[1] https://uvadlc-notebooks.readthedocs.io/en/latest/
|
||||
|
||||
[2] https://www.sciencedirect.com/science/article/pii/S0045782524008260?via%3Dihub
|
||||
|
||||
[3] https://docs.jax.dev/en/latest/automatic-differentiation.html
|
||||
|
||||
[4] https://docs.jax.dev/en/latest/jit-compilation.html
|
||||
|
|
@ -1,33 +0,0 @@
|
|||
'''
|
||||
This is a test file for JAX functionalities. Codes are from Jax documentation.(https://jax.readthedocs.io/en/latest/index.html)
|
||||
'''
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import grad, jit
|
||||
|
||||
# Example: JIT compilation
|
||||
def selu(x, alpha=1.67, lambda_=1.05):
|
||||
return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
|
||||
|
||||
x = jnp.arange(1000000)
|
||||
|
||||
import time
|
||||
# measure a single execution and ensure JAX computation completes
|
||||
start = time.perf_counter()
|
||||
selu(x).block_until_ready()
|
||||
end = time.perf_counter()
|
||||
print("Elapsed:", end - start)
|
||||
|
||||
selu_jit = jit(selu)
|
||||
# measure a single execution and ensure JAX computation completes
|
||||
start = time.perf_counter()
|
||||
selu_jit(x).block_until_ready()
|
||||
end = time.perf_counter()
|
||||
print("Elapsed with JIT:", end - start)
|
||||
|
||||
# Example: Automatic differentiation\
|
||||
def f(x):
|
||||
return jnp.sin(x) + 0.5 * x ** 2
|
||||
df = grad(f)
|
||||
x = 2.0
|
||||
print("f'(2.0) =", df(x)) # Should print the derivative of f at x=2.0
|
||||
Loading…
Reference in New Issue