Compare commits
4 Commits
| Author | SHA1 | Date |
|---|---|---|
|
|
71f38ef56b | |
|
|
c600d30161 | |
|
|
602c480e49 | |
|
|
e989dc9927 |
|
|
@ -0,0 +1,143 @@
|
||||||
|
# Created by https://www.toptal.com/developers/gitignore/api/python
|
||||||
|
# Edit at https://www.toptal.com/developers/gitignore?templates=python
|
||||||
|
|
||||||
|
# --------------------
|
||||||
|
# Python
|
||||||
|
# --------------------
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*.pyo
|
||||||
|
*.pyd
|
||||||
|
*.so
|
||||||
|
*$py.class
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# --------------------
|
||||||
|
# Packaging / builds
|
||||||
|
# --------------------
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
dist/
|
||||||
|
develop-eggs/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
share/python-wheels/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
*.egg
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
MANIFEST
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# --------------------
|
||||||
|
# Logs / caches / temp
|
||||||
|
# --------------------
|
||||||
|
*.log
|
||||||
|
*.tmp
|
||||||
|
*.temp
|
||||||
|
*.out
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
.cache
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
*.cover
|
||||||
|
*.py,cover
|
||||||
|
cover/
|
||||||
|
htmlcov/
|
||||||
|
|
||||||
|
# --------------------
|
||||||
|
# Databases / translations
|
||||||
|
# --------------------
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
db.sqlite3
|
||||||
|
db.sqlite3-journal
|
||||||
|
|
||||||
|
# --------------------
|
||||||
|
# Framework / tool artifacts
|
||||||
|
# --------------------
|
||||||
|
local_settings.py
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
.scrapy
|
||||||
|
docs/_build/
|
||||||
|
.pybuilder/
|
||||||
|
target/
|
||||||
|
celerybeat-schedule
|
||||||
|
celerybeat.pid
|
||||||
|
/site
|
||||||
|
|
||||||
|
# --------------------
|
||||||
|
# Virtual environments
|
||||||
|
# --------------------
|
||||||
|
JAX-venv/
|
||||||
|
.venv/
|
||||||
|
venv/
|
||||||
|
env/
|
||||||
|
ENV/
|
||||||
|
.env
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# --------------------
|
||||||
|
# Python packaging helpers
|
||||||
|
# --------------------
|
||||||
|
.pdm.toml
|
||||||
|
poetry.toml
|
||||||
|
__pypackages__/
|
||||||
|
|
||||||
|
# --------------------
|
||||||
|
# Notebooks / REPL
|
||||||
|
# --------------------
|
||||||
|
.ipynb_checkpoints/
|
||||||
|
*.nb.py
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# --------------------
|
||||||
|
# Type checking / analysis
|
||||||
|
# --------------------
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
.pyre/
|
||||||
|
.pytype/
|
||||||
|
.ruff_cache/
|
||||||
|
pyrightconfig.json
|
||||||
|
|
||||||
|
# --------------------
|
||||||
|
# Editors / IDEs
|
||||||
|
# --------------------
|
||||||
|
.vscode/
|
||||||
|
.history/
|
||||||
|
*.code-workspace
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# --------------------
|
||||||
|
# Operating system
|
||||||
|
# --------------------
|
||||||
|
.DS_Store
|
||||||
|
Thumbs.db
|
||||||
|
|
||||||
|
# --------------------
|
||||||
|
# Editor temp files
|
||||||
|
# --------------------
|
||||||
|
*.swp
|
||||||
|
*.swo
|
||||||
|
*~
|
||||||
|
|
||||||
|
|
||||||
1
AUTHORS
1
AUTHORS
|
|
@ -1 +1,2 @@
|
||||||
Lucas Frérot <lucas.frerot@sorbonne-universite.fr> Sorbonne Université, CNRS, Institut Jean Le Rond d'Alembert, F-75005 Paris, France
|
Lucas Frérot <lucas.frerot@sorbonne-universite.fr> Sorbonne Université, CNRS, Institut Jean Le Rond d'Alembert, F-75005 Paris, France
|
||||||
|
Zichen Li <zichen.li@sorbonne-universite.fr> Sorbonne Université, CNRS, Institut Jean Le Rond d'Alembert, F-75005 Paris, France
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,59 @@
|
||||||
|
Here is a repo for beginners in JAX. We recommand to start with the documentation in [1].
|
||||||
|
|
||||||
|
We are motivated by the article of Mohit and David(2025)[2], especially the automatic differentiation[3] and just-in-time compilation[4].
|
||||||
|
|
||||||
|
### Install JAX
|
||||||
|
GPU programming is a future trend for our open-source project. The codes that we write in CPU and GPU version of JAX are the same. The difference is that pip will install a jaxlib wheel for GPU version depending on NVIDIA driver version and CUDA version.
|
||||||
|
|
||||||
|
First we can check our NVIDIA driver version and CUDA version
|
||||||
|
```bash
|
||||||
|
nvidia-smi
|
||||||
|
```
|
||||||
|
CUDA 12 requires driver version ≥ 525, which is already a mainstream and stable combination, supported by almost all frameworks. We will install the JAX GPU version suitable for CUDA 12.
|
||||||
|
```bash
|
||||||
|
python3 -m venv JAX-venv
|
||||||
|
source JAX-venv/bin/activate
|
||||||
|
(JAX-venv) pip install --upgrade pip
|
||||||
|
(JAX-venv) pip install ipython
|
||||||
|
(JAX-venv) pip install --upgrade "jax[cuda12]"
|
||||||
|
(JAX-venv) ipython # /path/to/JAX-venv/bin/ipython
|
||||||
|
```
|
||||||
|
Test script in Ipython:
|
||||||
|
```
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import jaxlib
|
||||||
|
|
||||||
|
print("jax:", jax.__version__)
|
||||||
|
print("jaxlib:", jaxlib.__version__)
|
||||||
|
|
||||||
|
print("devices:", jax.devices())
|
||||||
|
|
||||||
|
x = jnp.arange(5.)
|
||||||
|
print("x.device:", x.device)
|
||||||
|
```
|
||||||
|
Reference output:
|
||||||
|
```
|
||||||
|
jax: 0.6.2
|
||||||
|
jaxlib: 0.6.2
|
||||||
|
devices: [CudaDevice(id=0)]
|
||||||
|
x.device: cuda:0
|
||||||
|
```
|
||||||
|
|
||||||
|
Structure of the project:
|
||||||
|
```
|
||||||
|
JAX/
|
||||||
|
├─ JAX-venv/
|
||||||
|
├─ src/ # Python codes
|
||||||
|
├─ notebooks/ # Experimental notebook
|
||||||
|
├─ tests/ # Unit tests
|
||||||
|
└─ pyproject.toml # Or requirements.txt
|
||||||
|
```
|
||||||
|
### Source
|
||||||
|
[1] https://uvadlc-notebooks.readthedocs.io/en/latest/
|
||||||
|
|
||||||
|
[2] https://www.sciencedirect.com/science/article/pii/S0045782524008260?via%3Dihub
|
||||||
|
|
||||||
|
[3] https://docs.jax.dev/en/latest/automatic-differentiation.html
|
||||||
|
|
||||||
|
[4] https://docs.jax.dev/en/latest/jit-compilation.html
|
||||||
|
|
@ -0,0 +1,33 @@
|
||||||
|
'''
|
||||||
|
This is a test file for JAX functionalities. Codes are from Jax documentation.(https://jax.readthedocs.io/en/latest/index.html)
|
||||||
|
'''
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
from jax import grad, jit
|
||||||
|
|
||||||
|
# Example: JIT compilation
|
||||||
|
def selu(x, alpha=1.67, lambda_=1.05):
|
||||||
|
return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
|
||||||
|
|
||||||
|
x = jnp.arange(1000000)
|
||||||
|
|
||||||
|
import time
|
||||||
|
# measure a single execution and ensure JAX computation completes
|
||||||
|
start = time.perf_counter()
|
||||||
|
selu(x).block_until_ready()
|
||||||
|
end = time.perf_counter()
|
||||||
|
print("Elapsed:", end - start)
|
||||||
|
|
||||||
|
selu_jit = jit(selu)
|
||||||
|
# measure a single execution and ensure JAX computation completes
|
||||||
|
start = time.perf_counter()
|
||||||
|
selu_jit(x).block_until_ready()
|
||||||
|
end = time.perf_counter()
|
||||||
|
print("Elapsed with JIT:", end - start)
|
||||||
|
|
||||||
|
# Example: Automatic differentiation\
|
||||||
|
def f(x):
|
||||||
|
return jnp.sin(x) + 0.5 * x ** 2
|
||||||
|
df = grad(f)
|
||||||
|
x = 2.0
|
||||||
|
print("f'(2.0) =", df(x)) # Should print the derivative of f at x=2.0
|
||||||
Loading…
Reference in New Issue