Compare commits

..

4 Commits
main ... main

4 changed files with 236 additions and 0 deletions

143
.gitignore vendored Normal file
View File

@ -0,0 +1,143 @@
# Created by https://www.toptal.com/developers/gitignore/api/python
# Edit at https://www.toptal.com/developers/gitignore?templates=python
# --------------------
# Python
# --------------------
__pycache__/
*.py[cod]
*.pyo
*.pyd
*.so
*$py.class
*.sage.py
# --------------------
# Packaging / builds
# --------------------
.Python
build/
dist/
develop-eggs/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
share/python-wheels/
var/
wheels/
*.egg
*.egg-info/
.installed.cfg
MANIFEST
*.manifest
*.spec
# --------------------
# Logs / caches / temp
# --------------------
*.log
*.tmp
*.temp
*.out
pip-log.txt
pip-delete-this-directory.txt
.cache
.hypothesis/
.pytest_cache/
.tox/
.nox/
.coverage
.coverage.*
*.cover
*.py,cover
cover/
htmlcov/
# --------------------
# Databases / translations
# --------------------
*.mo
*.pot
db.sqlite3
db.sqlite3-journal
# --------------------
# Framework / tool artifacts
# --------------------
local_settings.py
instance/
.webassets-cache
.scrapy
docs/_build/
.pybuilder/
target/
celerybeat-schedule
celerybeat.pid
/site
# --------------------
# Virtual environments
# --------------------
JAX-venv/
.venv/
venv/
env/
ENV/
.env
env.bak/
venv.bak/
# --------------------
# Python packaging helpers
# --------------------
.pdm.toml
poetry.toml
__pypackages__/
# --------------------
# Notebooks / REPL
# --------------------
.ipynb_checkpoints/
*.nb.py
profile_default/
ipython_config.py
# --------------------
# Type checking / analysis
# --------------------
.mypy_cache/
.dmypy.json
dmypy.json
.pyre/
.pytype/
.ruff_cache/
pyrightconfig.json
# --------------------
# Editors / IDEs
# --------------------
.vscode/
.history/
*.code-workspace
.spyderproject
.spyproject
.ropeproject
# --------------------
# Operating system
# --------------------
.DS_Store
Thumbs.db
# --------------------
# Editor temp files
# --------------------
*.swp
*.swo
*~

View File

@ -1 +1,2 @@
Lucas Frérot <lucas.frerot@sorbonne-universite.fr> Sorbonne Université, CNRS, Institut Jean Le Rond d'Alembert, F-75005 Paris, France Lucas Frérot <lucas.frerot@sorbonne-universite.fr> Sorbonne Université, CNRS, Institut Jean Le Rond d'Alembert, F-75005 Paris, France
Zichen Li <zichen.li@sorbonne-universite.fr> Sorbonne Université, CNRS, Institut Jean Le Rond d'Alembert, F-75005 Paris, France

59
JAX/README.md Normal file
View File

@ -0,0 +1,59 @@
Here is a repo for beginners in JAX. We recommand to start with the documentation in [1].
We are motivated by the article of Mohit and David(2025)[2], especially the automatic differentiation[3] and just-in-time compilation[4].
### Install JAX
GPU programming is a future trend for our open-source project. The codes that we write in CPU and GPU version of JAX are the same. The difference is that pip will install a jaxlib wheel for GPU version depending on NVIDIA driver version and CUDA version.
First we can check our NVIDIA driver version and CUDA version
```bash
nvidia-smi
```
CUDA 12 requires driver version ≥ 525, which is already a mainstream and stable combination, supported by almost all frameworks. We will install the JAX GPU version suitable for CUDA 12.
```bash
python3 -m venv JAX-venv
source JAX-venv/bin/activate
(JAX-venv) pip install --upgrade pip
(JAX-venv) pip install ipython
(JAX-venv) pip install --upgrade "jax[cuda12]"
(JAX-venv) ipython # /path/to/JAX-venv/bin/ipython
```
Test script in Ipython:
```
import jax
import jax.numpy as jnp
import jaxlib
print("jax:", jax.__version__)
print("jaxlib:", jaxlib.__version__)
print("devices:", jax.devices())
x = jnp.arange(5.)
print("x.device:", x.device)
```
Reference output:
```
jax: 0.6.2
jaxlib: 0.6.2
devices: [CudaDevice(id=0)]
x.device: cuda:0
```
Structure of the project:
```
JAX/
├─ JAX-venv/
├─ src/ # Python codes
├─ notebooks/ # Experimental notebook
├─ tests/ # Unit tests
└─ pyproject.toml # Or requirements.txt
```
### Source
[1] https://uvadlc-notebooks.readthedocs.io/en/latest/
[2] https://www.sciencedirect.com/science/article/pii/S0045782524008260?via%3Dihub
[3] https://docs.jax.dev/en/latest/automatic-differentiation.html
[4] https://docs.jax.dev/en/latest/jit-compilation.html

33
JAX/tests/test.py Normal file
View File

@ -0,0 +1,33 @@
'''
This is a test file for JAX functionalities. Codes are from Jax documentation.(https://jax.readthedocs.io/en/latest/index.html)
'''
import jax
import jax.numpy as jnp
from jax import grad, jit
# Example: JIT compilation
def selu(x, alpha=1.67, lambda_=1.05):
return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
x = jnp.arange(1000000)
import time
# measure a single execution and ensure JAX computation completes
start = time.perf_counter()
selu(x).block_until_ready()
end = time.perf_counter()
print("Elapsed:", end - start)
selu_jit = jit(selu)
# measure a single execution and ensure JAX computation completes
start = time.perf_counter()
selu_jit(x).block_until_ready()
end = time.perf_counter()
print("Elapsed with JIT:", end - start)
# Example: Automatic differentiation\
def f(x):
return jnp.sin(x) + 0.5 * x ** 2
df = grad(f)
x = 2.0
print("f'(2.0) =", df(x)) # Should print the derivative of f at x=2.0