forked from lfrerot/good_simulation_practices
34 lines
938 B
Python
34 lines
938 B
Python
'''
|
|
This is a test file for JAX functionality. The code examples are taken from the JAX documentation (https://jax.readthedocs.io/en/latest/index.html).
|
|
'''
|
|
import jax
|
|
import jax.numpy as jnp
|
|
from jax import grad, jit
|
|
|
|
# Example: JIT compilation
|
|
def selu(x, alpha=1.67, lambda_=1.05):
    """Apply the scaled exponential linear unit (SELU) elementwise.

    Computes ``lambda_ * x`` where ``x > 0`` and
    ``lambda_ * alpha * (exp(x) - 1)`` elsewhere.

    Args:
        x: Input scalar or array. Integer inputs are promoted to floating point.
        alpha: Scale of the negative (exponential) branch. The default is a
            rounded form of the usual SELU constant (~1.6733).
        lambda_: Overall output scale. The default is a rounded form of the
            usual SELU constant (~1.0507). The trailing underscore avoids the
            ``lambda`` keyword.

    Returns:
        A floating-point array with the same shape as ``x``.
    """
    # jnp.where evaluates BOTH branches. Feeding exp the raw x would overflow
    # to inf for large positive x. Its gradient would then be 0 * inf = NaN
    # even though that branch is not selected. Masking the input to 0 there
    # keeps everything finite. The gradient at x <= 0 is unchanged.
    non_positive_x = jnp.where(x > 0, 0, x)
    # expm1(t) == exp(t) - 1, but it is more accurate for t close to 0.
    negative_branch = alpha * jnp.expm1(non_positive_x)
    return lambda_ * jnp.where(x > 0, x, negative_branch)
|
|
|
|
import time


def _time_after_warmup(fn, arg):
    """Return the wall-clock seconds taken by one call of ``fn(arg)``.

    ``fn`` is called once before the timed call. This excludes one-time costs
    such as tracing and XLA compilation of a ``jit``-wrapped function, plus any
    other first-call overhead. Without it, the "JIT" number is dominated by
    compilation and the eager-vs-JIT comparison is misleading.

    ``block_until_ready()`` waits for the JAX computation to finish, so the
    timer covers the work itself rather than just its dispatch.
    ``fn(arg)`` must return a JAX array.
    """
    fn(arg).block_until_ready()  # warm-up: absorb compilation / first-call cost
    start = time.perf_counter()
    fn(arg).block_until_ready()
    return time.perf_counter() - start


x = jnp.arange(1000000)

# Eager (op-by-op) execution of selu.
print("Elapsed:", _time_after_warmup(selu, x))

# The same function compiled with jit.
selu_jit = jit(selu)
print("Elapsed with JIT:", _time_after_warmup(selu_jit, x))
|
|
|
|
# Example: Automatic differentiation
|
|
def f(x):
    """Toy scalar function f(x) = sin(x) + x**2 / 2.

    Its derivative is f'(x) = cos(x) + x, which makes the result of
    ``grad(f)`` easy to check by hand.
    """
    oscillation = jnp.sin(x)
    parabola = 0.5 * x ** 2
    return oscillation + parabola
|
|
# Differentiate f with respect to its argument. Analytically,
# f'(x) = cos(x) + x, so the printed value should be
# cos(2.0) + 2.0, approximately 1.5839.
x = 2.0
df = grad(f)
print("f'(2.0) =", df(x))
|