"""Test script for basic JAX functionality.

The examples are adapted from the JAX documentation
(https://jax.readthedocs.io/en/latest/index.html).
"""

import time

import jax
import jax.numpy as jnp
from jax import grad, jit


def selu(x, alpha=1.67, lambda_=1.05):
    """Apply the scaled exponential linear unit elementwise.

    Returns ``lambda_ * x`` where ``x > 0`` and
    ``lambda_ * (alpha * exp(x) - alpha)`` everywhere else.
    """
    return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)


# Compiled version of selu; compilation happens lazily on the first call.
selu_jit = jit(selu)


def f(x):
    """Return sin(x) + 0.5 * x**2 (its derivative is cos(x) + x)."""
    return jnp.sin(x) + 0.5 * x ** 2


# Derivative of f, built by automatic differentiation.
df = grad(f)


def _time_call(fn, *args):
    """Return the wall-clock seconds taken by one call of ``fn(*args)``.

    ``block_until_ready`` is required because JAX dispatches work
    asynchronously; without it only the dispatch would be timed.
    """
    start = time.perf_counter()
    fn(*args).block_until_ready()
    return time.perf_counter() - start


def main():
    """Run the JIT-compilation and automatic-differentiation examples."""
    # Example: JIT compilation
    x = jnp.arange(1000000)

    # Warm up once so one-time per-op compilation/dispatch cost is excluded.
    selu(x).block_until_ready()
    print("Elapsed:", _time_call(selu, x))

    # The first jitted call traces and compiles the function; timing it would
    # measure compilation rather than execution, so warm up before timing.
    selu_jit(x).block_until_ready()
    print("Elapsed with JIT:", _time_call(selu_jit, x))

    # Example: Automatic differentiation
    point = 2.0
    print("f'(2.0) =", df(point))  # Should print cos(2.0) + 2.0 ~= 1.5839


if __name__ == "__main__":
    main()