"""JAX implementation of the generalized Maxwell contact solver. This script mirrors the NumPy-based reference in `Multi_branches_generalized_Maxwell.py`, but leverages JAX for automatic differentiation and JIT compilation. The automatic gradient of the elastic energy drives the constrained conjugate-gradient contact solver. Running this file produces the same diagnostic plots as the reference implementation while keeping all heavy lifting on the accelerator-enabled JAX backend. """ from __future__ import annotations import numpy as np import matplotlib.pyplot as plt from matplotlib.animation import FuncAnimation import jax import jax.numpy as jnp from jax import lax import time # Enable double precision for improved numerical stability. jax.config.update("jax_enable_x64", True) def build_fourier_kernel(n: int, m: int, L: float, E_star: float) -> jnp.ndarray: """Assemble the Fourier-domain kernel for the half-space Green's function.""" q_x = 2.0 * np.pi * jnp.fft.fftfreq(n, d=L / n) q_y = 2.0 * np.pi * jnp.fft.fftfreq(m, d=L / m) QX, QY = jnp.meshgrid(q_x, q_y, indexing="xy") q_norm = jnp.sqrt(QX**2 + QY**2) kernel = jnp.where(q_norm > 0.0, 2.0 / (E_star * q_norm), 0.0) return kernel @jax.jit def displacement_from_pressure( kernel_fourier: jnp.ndarray, pressure: jnp.ndarray ) -> jnp.ndarray: """Return the surface displacement induced by the supplied pressure field.""" pressure_fft = jnp.fft.fft2(pressure, norm="ortho") displacement_fft = pressure_fft * kernel_fourier displacement = jnp.fft.ifft2(displacement_fft, norm="ortho").real return displacement def elastic_energy( kernel_fourier: jnp.ndarray, h_profile: jnp.ndarray, pressure: jnp.ndarray ) -> jnp.ndarray: """Elastic energy functional; its gradient yields the gap field.""" displacement = displacement_from_pressure(kernel_fourier, pressure) stored = 0.5 * jnp.sum(pressure * displacement) work = jnp.sum(pressure * h_profile) return stored - work value_and_grad_energy = jax.jit(jax.value_and_grad(elastic_energy, argnums=2)) @jax.jit def project_total_load(pressure: jnp.ndarray, W: float, L: float) -> jnp.ndarray: """Project the pressure field onto the admissible set enforcing total load.""" mean_pressure = jnp.mean(pressure) target = W / (L**2) scale = jnp.where(mean_pressure > 0.0, target / mean_pressure, 0.0) projected = jnp.where(mean_pressure > 0.0, pressure * scale, jnp.full_like(pressure, target)) return projected @jax.jit def masked_mean(values: jnp.ndarray, mask: jnp.ndarray) -> jnp.ndarray: """Compute the mean over the masked region, guarding against empty sets.""" count = jnp.sum(mask) total = jnp.sum(jnp.where(mask, values, 0.0)) return jnp.where(count > 0, total / count, 0.0) @jax.jit def compute_error( pressure: jnp.ndarray, gradient: jnp.ndarray, h_rms: float, ) -> jnp.ndarray: """Scaled complementarity error used as stopping criterion.""" num = jnp.vdot(pressure.reshape(-1), gradient - jnp.min(gradient)) denom = jnp.sum(pressure) * h_rms + 1e-12 return jnp.abs(num / denom) @jax.jit def update_search_direction( gradient: jnp.ndarray, direction: jnp.ndarray, contact_mask: jnp.ndarray, delta: float, g_norm: float, g_old: float, ) -> jnp.ndarray: """Conjugate-gradient style update with projection onto the contact set.""" beta_cg = jnp.where(g_old > 0.0, delta * g_norm / (g_old + 1e-12), 0.0) updated = gradient + beta_cg * direction return jnp.where(contact_mask, updated, 0.0) def contact_solver_autodiff( kernel_fourier: jnp.ndarray, h_profile: jnp.ndarray, W: float, L: float, tol: float = 1e-6, iter_max: int = 200, ): """Solve the constrained contact problem via autodiff-powered CG iterations.""" h_rms = jnp.std(h_profile) initial_pressure = jnp.full_like(h_profile, W / (L**2)) initial_direction = jnp.zeros_like(initial_pressure) iter_max_jnp = jnp.array(iter_max) def cond_fun(state): _, _, _, _, k, error = state return jnp.logical_and(error > tol, k < iter_max_jnp) def body_fun(state): pressure, direction, g_old, delta, k, _ = state _, grad_energy = value_and_grad_energy(kernel_fourier, h_profile, pressure) contact_mask = pressure > 0.0 grad_mean = masked_mean(grad_energy, contact_mask) grad_centered = grad_energy - grad_mean grad_contact = jnp.where(contact_mask, grad_centered, 0.0) g_norm = jnp.sum(grad_contact * grad_contact) search_dir = update_search_direction( grad_contact, direction, contact_mask, delta, g_norm, g_old, ) displacement_dir = displacement_from_pressure(kernel_fourier, search_dir) disp_mean = masked_mean(displacement_dir, contact_mask) response = displacement_dir - disp_mean tau_num = jnp.sum(jnp.where(contact_mask, grad_centered * search_dir, 0.0)) tau_den = jnp.sum(jnp.where(contact_mask, response * search_dir, 0.0)) tau = tau_num / (tau_den + 1e-12) pressure_new = pressure - tau * search_dir pressure_new = jnp.where(pressure_new > 0.0, pressure_new, 0.0) inadmissible = jnp.logical_and(pressure_new == 0.0, grad_centered < 0.0) delta_new = jnp.where(jnp.sum(inadmissible) == 0, 1.0, 0.0) pressure_projected = project_total_load(pressure_new, W, L) error_new = compute_error(pressure_projected, grad_centered, h_rms) return ( pressure_projected, search_dir, jnp.where(g_norm > 0.0, g_norm, g_old), delta_new, k + 1, error_new, ) final_state = lax.while_loop( cond_fun, body_fun, ( initial_pressure, initial_direction, jnp.array(1.0), jnp.array(0.0), jnp.array(0), jnp.array(jnp.inf), ), ) pressure, _, _, _, iterations, error = final_state displacement = displacement_from_pressure(kernel_fourier, pressure) return displacement, pressure, int(iterations), float(error) def main(): # Time discretization t0 = 0.0 t1 = 1.0 time_steps = 50 dt = (t1 - t0) / time_steps # Total load W = 1.0 # Geometry L = 2.0 radius = 0.5 S = L**2 # Grid n = 300 m = 300 x_vals = jnp.linspace(0.0, L, n, endpoint=False) y_vals = jnp.linspace(0.0, L, m, endpoint=False) x, y = jnp.meshgrid(x_vals, y_vals, indexing="xy") x0 = 1.0 y0 = 1.0 E = 3.0 nu = 0.5 E_star = E / (1.0 - nu**2) r = jnp.sqrt((x - x0) ** 2 + (y - y0) ** 2) h_profile = -(r**2) / (2.0 * radius) kernel_fourier = build_fourier_kernel(n, m, L, E_star) # Maxwell model parameters G_inf = 2.75 G_branches = jnp.array([2.75, 2.75]) tau_branches = jnp.array([0.1, 1.0]) eta_branches = G_branches * tau_branches gamma = tau_branches / (tau_branches + dt) G_tilde = jnp.sum(gamma * G_branches) alpha = G_inf + G_tilde beta = G_tilde surface = h_profile U = jnp.zeros((n, m)) M = jnp.zeros((G_branches.shape[0], n, m)) # Hertzian references G_maxwell_t0 = jnp.sum(G_branches) G_effective_t0 = G_inf + G_maxwell_t0 E_effective_t0 = 2.0 * G_effective_t0 * (1.0 + nu) / (1.0 - nu**2) p0_t0 = (6.0 * W * (E_effective_t0**2) / (np.pi**3 * radius**2)) ** (1.0 / 3.0) a_t0 = (3.0 * W * radius / (4.0 * E_effective_t0)) ** (1.0 / 3.0) E_effective_inf = 2.0 * G_inf * (1.0 + nu) / (1.0 - nu**2) p0_t_inf = (6.0 * W * (E_effective_inf**2) / (np.pi**3 * radius**2)) ** (1.0 / 3.0) a_t_inf = (3.0 * W * radius / (4.0 * E_effective_inf)) ** (1.0 / 3.0) pressure_distributions = [] contact_areas = [] iteration_log = [] start_time = time.perf_counter() # I think I should avoid using for loops in JAX for step in range(time_steps): M_maxwell = jnp.tensordot(gamma, M, axes=1) H_new = alpha * surface - beta * U + M_maxwell displacement, pressure, iterations, residual = contact_solver_autodiff( kernel_fourier, H_new, W, L, tol=1e-6, iter_max=200, ) U_new = (displacement - M_maxwell + beta * U) / alpha delta_U = U_new - U M = gamma[:, None, None] * (M + G_branches[:, None, None] * delta_U) area_ratio = jnp.mean(pressure > 0.0) contact_area = float(area_ratio * S) contact_areas.append(contact_area) pressure_midline = np.array(jax.device_get(pressure[n // 2])) pressure_distributions.append(pressure_midline) iteration_log.append((iterations, residual)) U = U_new end_time = time.perf_counter() print("Simulation time:", end_time - start_time, "seconds") x_np = np.array(jax.device_get(x)) def update(frame): ax.clear() ax.set_xlim(0, L) ax.set_ylim(0, 1.1 * p0_t0) ax.grid(True) ax.plot( x_np[n // 2], p0_t0 * np.sqrt(np.maximum(0.0, 1.0 - (x_np[n // 2] - x0) ** 2 / a_t0**2)), "g--", label="Hertz t=0", ) ax.plot( x_np[n // 2], p0_t_inf * np.sqrt(np.maximum(0.0, 1.0 - (x_np[n // 2] - x0) ** 2 / a_t_inf**2)), "b--", label="Hertz t=inf", ) ax.plot(x_np[n // 2], pressure_distributions[frame], "r-", label="Numerical") ax.set_title(f"Time = {t0 + frame * dt:.2f}s") ax.set_xlabel("x") ax.set_ylabel("Pressure distribution") ax.legend(loc="upper right") fig, ax = plt.subplots() ani = FuncAnimation(fig, update, frames=len(pressure_distributions), repeat=False) plt.show() Ac_hertz_t0 = np.pi * a_t0**2 Ac_hertz_t_inf = np.pi * a_t_inf**2 print("Iterations and residuals per step:") for idx, (iterations, residual) in enumerate(iteration_log): print(f" step {idx:02d}: {iterations:3d} iterations, residual={residual:.3e}") print("Analytical contact area radius at t0:", float(a_t0)) print("Analytical contact area radius at t_inf:", float(a_t_inf)) print("Analytical maximum pressure at t0:", float(p0_t0)) print("Analytical maximum pressure at t_inf:", float(p0_t_inf)) print("Numerical contact area at t0:", contact_areas[0]) print("Numerical contact area at t_inf:", contact_areas[-1]) print("Analytical contact area at t0:", float(Ac_hertz_t0)) print("Analytical contact area at t_inf:", float(Ac_hertz_t_inf)) time_axis = np.arange(t0, t1, dt) plt.figure() plt.plot(time_axis, contact_areas) plt.axhline(Ac_hertz_t0, color="red", linestyle="dotted") plt.axhline(Ac_hertz_t_inf, color="blue", linestyle="dotted") plt.xlabel("Time(s)") plt.ylabel("Contact area($m^2$)") plt.legend(["Numerical", "Hertz at t=0", "Hertz at t=inf"]) plt.title("Contact area vs time for multi-branch Generalized Maxwell model") plt.show() if __name__ == "__main__": main()