a simple script with jit and AD

This commit is contained in:
Zichen LI 2025-11-28 14:08:18 +01:00
parent 602c480e49
commit c600d30161
1 changed files with 33 additions and 0 deletions

33
JAX/tests/test.py Normal file
View File

@ -0,0 +1,33 @@
'''
This is a test file for JAX functionalities. Codes are from Jax documentation.(https://jax.readthedocs.io/en/latest/index.html)
'''
import jax
import jax.numpy as jnp
from jax import grad, jit
# Example: JIT compilation
def selu(x, alpha=1.67, lambda_=1.05):
return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
x = jnp.arange(1000000)
import time
# measure a single execution and ensure JAX computation completes
start = time.perf_counter()
selu(x).block_until_ready()
end = time.perf_counter()
print("Elapsed:", end - start)
selu_jit = jit(selu)
# measure a single execution and ensure JAX computation completes
start = time.perf_counter()
selu_jit(x).block_until_ready()
end = time.perf_counter()
print("Elapsed with JIT:", end - start)
# Example: Automatic differentiation\
def f(x):
return jnp.sin(x) + 0.5 * x ** 2
df = grad(f)
x = 2.0
print("f'(2.0) =", df(x)) # Should print the derivative of f at x=2.0