forked from lfrerot/good_simulation_practices
a simple script with jit and AD
This commit is contained in:
parent
602c480e49
commit
c600d30161
|
|
@ -0,0 +1,33 @@
|
||||||
|
'''
|
||||||
|
This is a test file for JAX functionalities. Codes are from Jax documentation.(https://jax.readthedocs.io/en/latest/index.html)
|
||||||
|
'''
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
from jax import grad, jit
|
||||||
|
|
||||||
|
# Example: JIT compilation
|
||||||
|
def selu(x, alpha=1.67, lambda_=1.05):
|
||||||
|
return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
|
||||||
|
|
||||||
|
x = jnp.arange(1000000)
|
||||||
|
|
||||||
|
import time
|
||||||
|
# measure a single execution and ensure JAX computation completes
|
||||||
|
start = time.perf_counter()
|
||||||
|
selu(x).block_until_ready()
|
||||||
|
end = time.perf_counter()
|
||||||
|
print("Elapsed:", end - start)
|
||||||
|
|
||||||
|
selu_jit = jit(selu)
|
||||||
|
# measure a single execution and ensure JAX computation completes
|
||||||
|
start = time.perf_counter()
|
||||||
|
selu_jit(x).block_until_ready()
|
||||||
|
end = time.perf_counter()
|
||||||
|
print("Elapsed with JIT:", end - start)
|
||||||
|
|
||||||
|
# Example: Automatic differentiation\
|
||||||
|
def f(x):
|
||||||
|
return jnp.sin(x) + 0.5 * x ** 2
|
||||||
|
df = grad(f)
|
||||||
|
x = 2.0
|
||||||
|
print("f'(2.0) =", df(x)) # Should print the derivative of f at x=2.0
|
||||||
Loading…
Reference in New Issue