good_simulation_practices/JAX/tests/test.py

34 lines
938 B
Python

'''
Test script exercising basic JAX functionality (JIT compilation and automatic differentiation). The examples are adapted from the JAX documentation (https://jax.readthedocs.io/en/latest/index.html).
'''
import jax
import jax.numpy as jnp
from jax import grad, jit
# Example: JIT compilation
def selu(x, alpha=1.67, lambda_=1.05):
    """Apply the scaled exponential linear unit elementwise.

    Computes ``lambda_ * x`` where ``x > 0`` and
    ``lambda_ * alpha * (exp(x) - 1)`` elsewhere.

    Args:
        x: Input scalar or array.
        alpha: Scale of the exponential (non-positive) branch.
        lambda_: Overall output scale.

    Returns:
        A JAX array with the broadcast shape of ``x``.
    """
    # jnp.where differentiates *both* branches. Evaluating exp() on large
    # positive inputs would overflow to inf, and the masked-out gradient
    # (inf * 0) would become NaN. Feeding the exponential branch only
    # non-positive values keeps it finite everywhere. expm1 is also more
    # accurate than exp(x) - 1 for inputs close to zero.
    positive = x > 0
    safe_x = jnp.where(positive, 0.0, x)
    return lambda_ * jnp.where(positive, x, alpha * jnp.expm1(safe_x))
import time

# Benchmark input: the integers 0 .. 999999.
x = jnp.arange(1000000)


def _time_call(fn, arg, *, repeats=5):
    """Return the fastest wall-clock time, in seconds, of ``fn(arg)``.

    The first call is an untimed warm-up. For a ``jit``-wrapped function, that
    call triggers tracing and XLA compilation, which would otherwise dominate
    a single-call measurement. ``block_until_ready`` makes each timed call
    wait until the JAX computation has actually completed.

    Args:
        fn: Callable returning a JAX array.
        arg: Argument passed to ``fn``.
        repeats: Number of timed calls. The minimum is reported because it is
            the measurement least affected by background noise.

    Returns:
        Best observed elapsed time in seconds.
    """
    fn(arg).block_until_ready()  # warm-up: keep compile time out of the timing
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        fn(arg).block_until_ready()
        best = min(best, time.perf_counter() - start)
    return best


print("Elapsed:", _time_call(selu, x))

selu_jit = jit(selu)
print("Elapsed with JIT:", _time_call(selu_jit, x))
# Example: Automatic differentiation
def f(x):
    """Return ``sin(x) + 0.5 * x**2``; its derivative is ``cos(x) + x``."""
    quadratic = 0.5 * x ** 2
    return quadratic + jnp.sin(x)
# jax.grad(f) is a new function evaluating df/dx = cos(x) + x.
df = jax.grad(f)
x = 2.0
derivative_at_x = df(x)
print("f'(2.0) =", derivative_at_x)  # analytically cos(2.0) + 2.0 ≈ 1.5839