"""JAX generalized Maxwell contact solver without Python loops. This variant removes the explicit Python time-stepping loop from `JAX_GMM.py` by relying on `jax.lax.scan`, which keeps all temporally coupled computations staged inside JAX's computation graph. The contact solver remains identical but is compatible with scanning so the entire transient solves as a single JIT-compiled program once the graph is traced. """ from __future__ import annotations import time import numpy as np import matplotlib.pyplot as plt from matplotlib.animation import FuncAnimation import jax import jax.numpy as jnp from jax import lax jax.config.update("jax_enable_x64", True) def build_fourier_kernel(n: int, m: int, L: float, E_star: float) -> jnp.ndarray: q_x = 2.0 * jnp.pi * jnp.fft.fftfreq(n, d=L / n) q_y = 2.0 * jnp.pi * jnp.fft.fftfreq(m, d=L / m) QX, QY = jnp.meshgrid(q_x, q_y, indexing="xy") q_norm = jnp.sqrt(QX**2 + QY**2) return jnp.where(q_norm > 0.0, 2.0 / (E_star * q_norm), 0.0) @jax.jit def displacement_from_pressure(kernel_fourier: jnp.ndarray, pressure: jnp.ndarray) -> jnp.ndarray: pressure_fft = jnp.fft.fft2(pressure, norm="ortho") displacement_fft = pressure_fft * kernel_fourier return jnp.fft.ifft2(displacement_fft, norm="ortho").real def elastic_energy(kernel_fourier: jnp.ndarray, h_profile: jnp.ndarray, pressure: jnp.ndarray) -> jnp.ndarray: displacement = displacement_from_pressure(kernel_fourier, pressure) stored = 0.5 * jnp.sum(pressure * displacement) work = jnp.sum(pressure * h_profile) return stored - work value_and_grad_energy = jax.jit(jax.value_and_grad(elastic_energy, argnums=2)) @jax.jit def project_total_load(pressure: jnp.ndarray, W: float, L: float) -> jnp.ndarray: mean_pressure = jnp.mean(pressure) target = W / (L**2) scale = jnp.where(mean_pressure > 0.0, target / mean_pressure, 0.0) return jnp.where(mean_pressure > 0.0, pressure * scale, jnp.full_like(pressure, target)) @jax.jit def masked_mean(values: jnp.ndarray, mask: jnp.ndarray) -> jnp.ndarray: count = jnp.sum(mask) total = jnp.sum(jnp.where(mask, values, 0.0)) return jnp.where(count > 0, total / count, 0.0) @jax.jit def compute_error(pressure: jnp.ndarray, gradient: jnp.ndarray, h_rms: float) -> jnp.ndarray: num = jnp.vdot(pressure.reshape(-1), gradient - jnp.min(gradient)) denom = jnp.sum(pressure) * h_rms + 1e-12 return jnp.abs(num / denom) @jax.jit def update_search_direction( gradient: jnp.ndarray, direction: jnp.ndarray, contact_mask: jnp.ndarray, delta: float, g_norm: float, g_old: float, ) -> jnp.ndarray: beta_cg = jnp.where(g_old > 0.0, delta * g_norm / (g_old + 1e-12), 0.0) updated = gradient + beta_cg * direction return jnp.where(contact_mask, updated, 0.0) def contact_solver_autodiff( kernel_fourier: jnp.ndarray, h_profile: jnp.ndarray, W: float, L: float, tol: float = 1e-6, iter_max: int = 200, ): h_rms = jnp.std(h_profile) initial_pressure = jnp.full_like(h_profile, W / (L**2)) initial_direction = jnp.zeros_like(initial_pressure) iter_max_jnp = jnp.array(iter_max) def cond_fun(state): _, _, _, _, k, error = state return jnp.logical_and(error > tol, k < iter_max_jnp) def body_fun(state): pressure, direction, g_old, delta, k, _ = state _, grad_energy = value_and_grad_energy(kernel_fourier, h_profile, pressure) contact_mask = pressure > 0.0 grad_mean = masked_mean(grad_energy, contact_mask) grad_centered = grad_energy - grad_mean grad_contact = jnp.where(contact_mask, grad_centered, 0.0) g_norm = jnp.sum(grad_contact * grad_contact) search_dir = update_search_direction(grad_contact, direction, contact_mask, delta, g_norm, g_old) displacement_dir = displacement_from_pressure(kernel_fourier, search_dir) disp_mean = masked_mean(displacement_dir, contact_mask) response = displacement_dir - disp_mean tau_num = jnp.sum(jnp.where(contact_mask, grad_centered * search_dir, 0.0)) tau_den = jnp.sum(jnp.where(contact_mask, response * search_dir, 0.0)) tau = tau_num / (tau_den + 1e-12) pressure_new = jnp.maximum(pressure - tau * search_dir, 0.0) inadmissible = jnp.logical_and(pressure_new == 0.0, grad_centered < 0.0) delta_new = jnp.where(jnp.sum(inadmissible) == 0, 1.0, 0.0) pressure_projected = project_total_load(pressure_new, W, L) error_new = compute_error(pressure_projected, grad_centered, h_rms) return ( pressure_projected, search_dir, jnp.where(g_norm > 0.0, g_norm, g_old), delta_new, k + 1, error_new, ) final_state = lax.while_loop( cond_fun, body_fun, ( initial_pressure, initial_direction, jnp.array(1.0), jnp.array(0.0), jnp.array(0), jnp.array(jnp.inf), ), ) pressure, _, _, _, iterations, error = final_state displacement = displacement_from_pressure(kernel_fourier, pressure) return displacement, pressure, iterations, error def run_simulation(): t0 = 0.0 t1 = 1.0 time_steps = 50 dt = (t1 - t0) / time_steps W = 1.0 L = 2.0 radius = 0.5 S = L**2 n = 300 m = 300 x_vals = jnp.linspace(0.0, L, n, endpoint=False) y_vals = jnp.linspace(0.0, L, m, endpoint=False) x, y = jnp.meshgrid(x_vals, y_vals, indexing="xy") x0 = 1.0 y0 = 1.0 E = 3.0 nu = 0.5 E_star = E / (1.0 - nu**2) r = jnp.sqrt((x - x0) ** 2 + (y - y0) ** 2) h_profile = -(r**2) / (2.0 * radius) kernel_fourier = build_fourier_kernel(n, m, L, E_star) G_inf = 2.75 G_branches = jnp.array([2.75, 2.75]) tau_branches = jnp.array([0.1, 1.0]) gamma = tau_branches / (tau_branches + dt) G_tilde = jnp.sum(gamma * G_branches) alpha = G_inf + G_tilde beta = G_tilde surface = h_profile U0 = jnp.zeros((n, m)) M0 = jnp.zeros((G_branches.shape[0], n, m)) def scan_step(carry, _): U, M = carry M_maxwell = jnp.tensordot(gamma, M, axes=1) H_new = alpha * surface - beta * U + M_maxwell displacement, pressure, iterations, residual = contact_solver_autodiff( kernel_fourier, H_new, W, L, tol=1e-6, iter_max=200, ) U_new = (displacement - M_maxwell + beta * U) / alpha delta_U = U_new - U M_new = gamma[:, None, None] * (M + G_branches[:, None, None] * delta_U) midline = pressure[n // 2] contact_area = jnp.mean(pressure > 0.0) * S return (U_new, M_new), (midline, contact_area, iterations, residual) (_, _), outputs = lax.scan( scan_step, (U0, M0), xs=None, length=time_steps, ) midlines, contact_areas, iterations, residuals = outputs G_maxwell_t0 = jnp.sum(G_branches) G_effective_t0 = G_inf + G_maxwell_t0 E_effective_t0 = 2.0 * G_effective_t0 * (1.0 + nu) / (1.0 - nu**2) p0_t0 = (6.0 * W * (E_effective_t0**2) / (jnp.pi**3 * radius**2)) ** (1.0 / 3.0) a_t0 = (3.0 * W * radius / (4.0 * E_effective_t0)) ** (1.0 / 3.0) E_effective_inf = 2.0 * G_inf * (1.0 + nu) / (1.0 - nu**2) p0_t_inf = (6.0 * W * (E_effective_inf**2) / (jnp.pi**3 * radius**2)) ** (1.0 / 3.0) a_t_inf = (3.0 * W * radius / (4.0 * E_effective_inf)) ** (1.0 / 3.0) return { "x": x, "midlines": midlines, "contact_areas": contact_areas, "iterations": iterations, "residuals": residuals, "params": { "t0": t0, "dt": dt, "L": L, "radius": radius, "x0": x0, "p0_t0": p0_t0, "p0_t_inf": p0_t_inf, "a_t0": a_t0, "a_t_inf": a_t_inf, "S": S, }, } def main(): start_time = time.perf_counter() results = run_simulation() total_time = time.perf_counter() - start_time print("Simulation time:", total_time, "seconds") x = jax.device_get(results["x"]) midlines = jax.device_get(results["midlines"]) contact_areas = jax.device_get(results["contact_areas"]) iterations = jax.device_get(results["iterations"]).astype(int) residuals = jax.device_get(results["residuals"]) params = results["params"] t0 = float(params["t0"]) dt = float(params["dt"]) L = float(params["L"]) x0 = float(params["x0"]) p0_t0 = float(params["p0_t0"]) p0_t_inf = float(params["p0_t_inf"]) a_t0 = float(params["a_t0"]) a_t_inf = float(params["a_t_inf"]) S = float(params["S"]) time_axis = t0 + dt * jnp.arange(midlines.shape[0]) time_axis_np = jax.device_get(time_axis) def update(frame): ax.clear() ax.set_xlim(0, L) ax.set_ylim(0, 1.1 * p0_t0) ax.grid(True) x_mid = x[int(x.shape[0] / 2)] ax.plot( x_mid, p0_t0 * np.sqrt(np.maximum(0.0, 1.0 - (x_mid - x0) ** 2 / a_t0**2)), "g--", label="Hertz t=0", ) ax.plot( x_mid, p0_t_inf * np.sqrt(np.maximum(0.0, 1.0 - (x_mid - x0) ** 2 / a_t_inf**2)), "b--", label="Hertz t=inf", ) ax.plot(x_mid, midlines[frame], "r-", label="Numerical") ax.set_title(f"Time = {t0 + frame * dt:.2f}s") ax.set_xlabel("x") ax.set_ylabel("Pressure distribution") ax.legend(loc="upper right") fig, ax = plt.subplots() ani = FuncAnimation(fig, update, frames=midlines.shape[0], repeat=False) plt.show() Ac_hertz_t0 = jnp.pi * a_t0**2 Ac_hertz_t_inf = jnp.pi * a_t_inf**2 print("Iterations and residuals per step:") for idx, (its, res) in enumerate(zip(iterations, residuals)): print(f" step {idx:02d}: {its:3d} iterations, residual={res:.3e}") print("Analytical contact area at t0:", float(Ac_hertz_t0)) print("Analytical contact area at t_inf:", float(Ac_hertz_t_inf)) print("Numerical contact area at t0:", float(contact_areas[0])) print("Numerical contact area at t_inf:", float(contact_areas[-1])) plt.figure() plt.plot(time_axis_np, contact_areas) plt.axhline(Ac_hertz_t0, color="red", linestyle="dotted") plt.axhline(Ac_hertz_t_inf, color="blue", linestyle="dotted") plt.xlabel("Time(s)") plt.ylabel("Contact area($m^2$)") plt.legend(["Numerical", "Hertz at t=0", "Hertz at t=inf"]) plt.title("Contact area vs time for multi-branch Generalized Maxwell model") plt.show() if __name__ == "__main__": main()