From c600d3016130937ed5c67094d27fbf753404ee9a Mon Sep 17 00:00:00 2001 From: Zichen LI Date: Fri, 28 Nov 2025 14:08:18 +0100 Subject: [PATCH] a simple script with jit and AD --- JAX/tests/test.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 JAX/tests/test.py diff --git a/JAX/tests/test.py b/JAX/tests/test.py new file mode 100644 index 0000000..0ef4c3a --- /dev/null +++ b/JAX/tests/test.py @@ -0,0 +1,33 @@ +''' +This is a test file for JAX functionalities. Codes are from Jax documentation.(https://jax.readthedocs.io/en/latest/index.html) +''' +import jax +import jax.numpy as jnp +from jax import grad, jit + +# Example: JIT compilation +def selu(x, alpha=1.67, lambda_=1.05): + return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha) + +x = jnp.arange(1000000) + +import time +# measure a single execution and ensure JAX computation completes +start = time.perf_counter() +selu(x).block_until_ready() +end = time.perf_counter() +print("Elapsed:", end - start) + +selu_jit = jit(selu) +# measure a single execution and ensure JAX computation completes +start = time.perf_counter() +selu_jit(x).block_until_ready() +end = time.perf_counter() +print("Elapsed with JIT:", end - start) + +# Example: Automatic differentiation\ +def f(x): + return jnp.sin(x) + 0.5 * x ** 2 +df = grad(f) +x = 2.0 +print("f'(2.0) =", df(x)) # Should print the derivative of f at x=2.0