# File: good_simulation_practices/JAX/tests/JAX_GMM.py
# (Python source; listed by the repository viewer as 360 lines, 10 KiB)

"""JAX implementation of the generalized Maxwell contact solver.

This script mirrors the NumPy-based reference implementation in
`Multi_branches_generalized_Maxwell.py`, but uses JAX for automatic
differentiation and JIT compilation. The gradient of the elastic energy,
obtained by automatic differentiation, drives the constrained
conjugate-gradient contact solver.

Running this file produces the same diagnostic plots as the reference
implementation while keeping all heavy lifting on the accelerator-enabled JAX
backend.
"""
from __future__ import annotations
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import jax
import jax.numpy as jnp
from jax import lax
import time
# Enable double precision (64-bit floats) for improved numerical stability.
# This is set at import time, before any of the arrays below are created.
jax.config.update("jax_enable_x64", True)
def build_fourier_kernel(n: int, m: int, L: float, E_star: float) -> jnp.ndarray:
    """Build the Fourier-space kernel of the half-space Green's function.

    Args:
        n: Number of grid points along x.
        m: Number of grid points along y.
        L: Side length of the periodic domain (sample spacings are L/n and L/m).
        E_star: Contact modulus appearing in the kernel denominator.

    Returns:
        Array of shape (m, n) holding ``2 / (E_star * |q|)`` for every non-zero
        wavevector ``q`` and ``0`` for the zero-frequency entry.
    """
    two_pi = 2.0 * np.pi
    wave_x = two_pi * jnp.fft.fftfreq(n, d=L / n)
    wave_y = two_pi * jnp.fft.fftfreq(m, d=L / m)
    # "xy" indexing puts the y axis first, so both grids have shape (m, n).
    grid_x, grid_y = jnp.meshgrid(wave_x, wave_y, indexing="xy")
    magnitude = jnp.sqrt(grid_x**2 + grid_y**2)
    # The zero-frequency mode has no finite response; it is set to zero.
    return jnp.where(magnitude > 0.0, 2.0 / (E_star * magnitude), 0.0)
@jax.jit
def displacement_from_pressure(
    kernel_fourier: jnp.ndarray, pressure: jnp.ndarray
) -> jnp.ndarray:
    """Compute the surface displacement produced by a pressure field.

    The pressure is transformed with an orthonormal 2-D FFT, multiplied
    element-wise by ``kernel_fourier``, and transformed back; only the real
    part of the inverse transform is returned.

    Args:
        kernel_fourier: Fourier-domain kernel, broadcastable against the
            transformed pressure.
        pressure: Real-space pressure field.

    Returns:
        Real-space displacement field with the same shape as ``pressure``.
    """
    spectrum = jnp.fft.fft2(pressure, norm="ortho")
    filtered = spectrum * kernel_fourier
    return jnp.real(jnp.fft.ifft2(filtered, norm="ortho"))
def elastic_energy(
    kernel_fourier: jnp.ndarray, h_profile: jnp.ndarray, pressure: jnp.ndarray
) -> jnp.ndarray:
    """Evaluate the elastic energy functional of a pressure field.

    The value is ``0.5 * sum(p * u) - sum(p * h)``, where ``u`` is the
    displacement induced by ``p``. Its gradient with respect to the pressure
    yields the gap field used by the contact solver.

    Args:
        kernel_fourier: Fourier-domain kernel mapping pressure to displacement.
        h_profile: Surface height profile.
        pressure: Pressure field at which the energy is evaluated.

    Returns:
        Scalar energy value.
    """
    induced = displacement_from_pressure(kernel_fourier, pressure)
    strain_energy = 0.5 * jnp.sum(pressure * induced)
    external_work = jnp.sum(pressure * h_profile)
    return strain_energy - external_work
# JIT-compiled callable returning the energy together with its gradient with
# respect to the pressure argument (positional index 2 of `elastic_energy`).
value_and_grad_energy = jax.jit(jax.value_and_grad(elastic_energy, argnums=2))
@jax.jit
def project_total_load(pressure: jnp.ndarray, W: float, L: float) -> jnp.ndarray:
    """Rescale the pressure so that its mean equals ``W / L**2``.

    Args:
        pressure: Pressure field to rescale.
        W: Total load to enforce.
        L: Side length of the square domain.

    Returns:
        ``pressure`` multiplied by a uniform factor when its mean is positive;
        otherwise a field filled uniformly with the target mean pressure.
    """
    target_mean = W / (L**2)
    current_mean = jnp.mean(pressure)
    has_load = current_mean > 0.0
    # The factor is forced to zero in the degenerate case so the division
    # result is never used when the mean is non-positive.
    factor = jnp.where(has_load, target_mean / current_mean, 0.0)
    uniform = jnp.full_like(pressure, target_mean)
    return jnp.where(has_load, pressure * factor, uniform)
@jax.jit
def masked_mean(values: jnp.ndarray, mask: jnp.ndarray) -> jnp.ndarray:
    """Average ``values`` over the entries selected by ``mask``.

    Args:
        values: Array of values.
        mask: Boolean selector with the same shape as ``values``.

    Returns:
        Mean of the selected entries, or ``0.0`` when nothing is selected.
    """
    selected_sum = jnp.sum(jnp.where(mask, values, 0.0))
    n_selected = jnp.sum(mask)
    # An empty selection would divide by zero; report 0.0 in that case.
    return jnp.where(n_selected > 0, selected_sum / n_selected, 0.0)
@jax.jit
def compute_error(
    pressure: jnp.ndarray,
    gradient: jnp.ndarray,
    h_rms: float,
) -> jnp.ndarray:
    """Return the scaled complementarity error used as stopping criterion.

    The error is ``|sum(p * (g - min(g)))| / (sum(p) * h_rms + 1e-12)``.

    Args:
        pressure: Pressure field.
        gradient: Gradient (gap) field; it is shifted so its minimum is zero.
        h_rms: Height scale used to normalise the error.

    Returns:
        Non-negative scalar error.
    """
    shifted = gradient - jnp.min(gradient)
    overlap = jnp.vdot(jnp.ravel(pressure), shifted)
    # The small constant keeps the ratio finite when the denominator vanishes.
    normaliser = jnp.sum(pressure) * h_rms + 1e-12
    return jnp.abs(overlap / normaliser)
@jax.jit
def update_search_direction(
    gradient: jnp.ndarray,
    direction: jnp.ndarray,
    contact_mask: jnp.ndarray,
    delta: float,
    g_norm: float,
    g_old: float,
) -> jnp.ndarray:
    """Build the next conjugate-gradient search direction on the contact set.

    Args:
        gradient: Current gradient field.
        direction: Previous search direction.
        contact_mask: Boolean mask of points currently in contact.
        delta: Multiplier on the conjugate term; ``0`` reduces the update to
            the plain gradient.
        g_norm: Squared gradient norm of the current iteration.
        g_old: Squared gradient norm of the previous iteration.

    Returns:
        ``gradient + beta * direction`` inside the contact mask and ``0``
        outside, with ``beta = delta * g_norm / g_old`` when ``g_old > 0``
        and ``beta = 0`` otherwise.
    """
    # The epsilon only guards the division; the ratio is discarded below when
    # g_old is not strictly positive.
    ratio = delta * g_norm / (g_old + 1e-12)
    beta = jnp.where(g_old > 0.0, ratio, 0.0)
    candidate = gradient + beta * direction
    return jnp.where(contact_mask, candidate, 0.0)
@jax.jit
def _contact_cg_loop(kernel_fourier, h_profile, W, L, tol, iter_max):
    """Run the constrained CG iteration as a single compiled program.

    All inputs are passed as arguments (nothing is captured from an enclosing
    Python scope), so the traced loop is compiled once per input shape/dtype
    and reused on subsequent calls instead of being rebuilt every time.

    Returns:
        The final loop state
        ``(pressure, direction, g_old, delta, iterations, error)``.
    """
    h_rms = jnp.std(h_profile)
    # Start from the uniform pressure that already carries the total load.
    initial_pressure = jnp.full_like(h_profile, W / (L**2))
    initial_direction = jnp.zeros_like(initial_pressure)

    def cond_fun(state):
        _, _, _, _, k, error = state
        return jnp.logical_and(error > tol, k < iter_max)

    def body_fun(state):
        pressure, direction, g_old, delta, k, _ = state

        # Gradient of the energy with respect to the pressure (gap field).
        _, grad_energy = value_and_grad_energy(kernel_fourier, h_profile, pressure)
        contact_mask = pressure > 0.0

        # Centre the gradient on the contact set and restrict it to that set.
        grad_mean = masked_mean(grad_energy, contact_mask)
        grad_centered = grad_energy - grad_mean
        grad_contact = jnp.where(contact_mask, grad_centered, 0.0)
        g_norm = jnp.sum(grad_contact * grad_contact)

        search_dir = update_search_direction(
            grad_contact,
            direction,
            contact_mask,
            delta,
            g_norm,
            g_old,
        )

        # Step length from the response of the elastic operator to the
        # search direction, centred on the contact set.
        displacement_dir = displacement_from_pressure(kernel_fourier, search_dir)
        disp_mean = masked_mean(displacement_dir, contact_mask)
        response = displacement_dir - disp_mean
        tau_num = jnp.sum(jnp.where(contact_mask, grad_centered * search_dir, 0.0))
        tau_den = jnp.sum(jnp.where(contact_mask, response * search_dir, 0.0))
        tau = tau_num / (tau_den + 1e-12)

        # Take the step and clip negative pressures to zero.
        pressure_new = pressure - tau * search_dir
        pressure_new = jnp.where(pressure_new > 0.0, pressure_new, 0.0)

        # If any zero-pressure point has a negative centred gradient, the
        # conjugate term is dropped on the next iteration (delta = 0).
        inadmissible = jnp.logical_and(pressure_new == 0.0, grad_centered < 0.0)
        delta_new = jnp.where(jnp.sum(inadmissible) == 0, 1.0, 0.0)

        pressure_projected = project_total_load(pressure_new, W, L)
        error_new = compute_error(pressure_projected, grad_centered, h_rms)
        return (
            pressure_projected,
            search_dir,
            # Keep the previous norm when the current one vanishes.
            jnp.where(g_norm > 0.0, g_norm, g_old),
            delta_new,
            k + 1,
            error_new,
        )

    initial_state = (
        initial_pressure,
        initial_direction,
        jnp.array(1.0),  # g_old
        jnp.array(0.0),  # delta: first step is a plain gradient step
        jnp.array(0),  # iteration counter
        jnp.array(jnp.inf),  # error: guarantees at least one iteration
    )
    return lax.while_loop(cond_fun, body_fun, initial_state)


def contact_solver_autodiff(
    kernel_fourier: jnp.ndarray,
    h_profile: jnp.ndarray,
    W: float,
    L: float,
    tol: float = 1e-6,
    iter_max: int = 200,
):
    """Solve the constrained contact problem via autodiff-powered CG iterations.

    The iteration itself lives in the jitted helper ``_contact_cg_loop`` so
    that repeated calls with same-shaped inputs reuse one compiled loop.

    Args:
        kernel_fourier: Fourier-domain kernel mapping pressure to displacement.
        h_profile: Surface height profile driving the contact.
        W: Total load; the mean pressure is constrained to ``W / L**2``.
        L: Side length of the square domain.
        tol: Iteration stops once the complementarity error is at or below
            this value.
        iter_max: Maximum number of iterations.

    Returns:
        Tuple ``(displacement, pressure, iterations, error)`` where
        ``displacement`` and ``pressure`` are arrays shaped like
        ``h_profile``, ``iterations`` is a Python ``int`` and ``error`` is a
        Python ``float`` holding the last computed error.
    """
    pressure, _, _, _, iterations, error = _contact_cg_loop(
        kernel_fourier, h_profile, W, L, tol, iter_max
    )
    displacement = displacement_from_pressure(kernel_fourier, pressure)
    # Converting to Python scalars synchronises with the device.
    return displacement, pressure, int(iterations), float(error)
def main():
    """Run the generalized-Maxwell contact simulation and show diagnostics.

    Steps through time, solving one elastic contact problem per step with
    ``contact_solver_autodiff``, records the midline pressure and the contact
    area, then animates the pressure against the Hertz references, prints a
    summary, and plots the contact area over time.
    """
    # Time discretization
    t0 = 0.0
    t1 = 1.0
    time_steps = 50
    dt = (t1 - t0) / time_steps

    # Total load
    W = 1.0

    # Geometry
    L = 2.0
    radius = 0.5
    S = L**2

    # Grid
    n = 300
    m = 300
    x_vals = jnp.linspace(0.0, L, n, endpoint=False)
    y_vals = jnp.linspace(0.0, L, m, endpoint=False)
    x, y = jnp.meshgrid(x_vals, y_vals, indexing="xy")
    x0 = 1.0
    y0 = 1.0
    E = 3.0
    nu = 0.5
    E_star = E / (1.0 - nu**2)
    r = jnp.sqrt((x - x0) ** 2 + (y - y0) ** 2)
    h_profile = -(r**2) / (2.0 * radius)
    kernel_fourier = build_fourier_kernel(n, m, L, E_star)

    # With "xy" meshgrid indexing every field has shape (m, n): rows follow y,
    # columns follow x. Derive shapes and the midline row from the profile so
    # the script stays consistent when n != m.
    field_shape = h_profile.shape
    mid_row = field_shape[0] // 2

    # Maxwell model parameters
    G_inf = 2.75
    G_branches = jnp.array([2.75, 2.75])
    tau_branches = jnp.array([0.1, 1.0])
    gamma = tau_branches / (tau_branches + dt)
    G_tilde = jnp.sum(gamma * G_branches)
    alpha = G_inf + G_tilde
    beta = G_tilde
    surface = h_profile
    U = jnp.zeros(field_shape)
    M = jnp.zeros((G_branches.shape[0],) + tuple(field_shape))

    # Hertzian references (plain Python floats so they can be handed to
    # Matplotlib and NumPy directly).
    G_maxwell_t0 = float(jnp.sum(G_branches))
    G_effective_t0 = G_inf + G_maxwell_t0
    E_effective_t0 = 2.0 * G_effective_t0 * (1.0 + nu) / (1.0 - nu**2)
    p0_t0 = (6.0 * W * (E_effective_t0**2) / (np.pi**3 * radius**2)) ** (1.0 / 3.0)
    a_t0 = (3.0 * W * radius / (4.0 * E_effective_t0)) ** (1.0 / 3.0)
    E_effective_inf = 2.0 * G_inf * (1.0 + nu) / (1.0 - nu**2)
    p0_t_inf = (6.0 * W * (E_effective_inf**2) / (np.pi**3 * radius**2)) ** (1.0 / 3.0)
    a_t_inf = (3.0 * W * radius / (4.0 * E_effective_inf)) ** (1.0 / 3.0)

    pressure_distributions = []
    contact_areas = []
    iteration_log = []

    start_time = time.perf_counter()
    # NOTE: the time loop stays in Python because every step copies
    # diagnostics back to the host; the heavy work happens inside the solver.
    for step in range(time_steps):
        M_maxwell = jnp.tensordot(gamma, M, axes=1)
        H_new = alpha * surface - beta * U + M_maxwell
        displacement, pressure, iterations, residual = contact_solver_autodiff(
            kernel_fourier,
            H_new,
            W,
            L,
            tol=1e-6,
            iter_max=200,
        )
        U_new = (displacement - M_maxwell + beta * U) / alpha
        delta_U = U_new - U
        M = gamma[:, None, None] * (M + G_branches[:, None, None] * delta_U)

        area_ratio = jnp.mean(pressure > 0.0)
        contact_area = float(area_ratio * S)
        contact_areas.append(contact_area)

        pressure_midline = np.array(jax.device_get(pressure[mid_row]))
        pressure_distributions.append(pressure_midline)
        iteration_log.append((iterations, residual))
        U = U_new
    end_time = time.perf_counter()
    print("Simulation time:", end_time - start_time, "seconds")

    x_np = np.array(jax.device_get(x))
    x_mid = x_np[mid_row]

    def update(frame):
        """Redraw the midline pressure and Hertz references for one frame."""
        ax.clear()
        ax.set_xlim(0, L)
        ax.set_ylim(0, 1.1 * p0_t0)
        ax.grid(True)
        ax.plot(
            x_mid,
            p0_t0 * np.sqrt(np.maximum(0.0, 1.0 - (x_mid - x0) ** 2 / a_t0**2)),
            "g--",
            label="Hertz t=0",
        )
        ax.plot(
            x_mid,
            p0_t_inf * np.sqrt(np.maximum(0.0, 1.0 - (x_mid - x0) ** 2 / a_t_inf**2)),
            "b--",
            label="Hertz t=inf",
        )
        ax.plot(x_mid, pressure_distributions[frame], "r-", label="Numerical")
        ax.set_title(f"Time = {t0 + frame * dt:.2f}s")
        ax.set_xlabel("x")
        ax.set_ylabel("Pressure distribution")
        ax.legend(loc="upper right")

    fig, ax = plt.subplots()
    # Keep a reference to the animation so it is not garbage-collected
    # before the window is shown.
    ani = FuncAnimation(fig, update, frames=len(pressure_distributions), repeat=False)
    plt.show()

    Ac_hertz_t0 = np.pi * a_t0**2
    Ac_hertz_t_inf = np.pi * a_t_inf**2

    print("Iterations and residuals per step:")
    for idx, (iterations, residual) in enumerate(iteration_log):
        print(f" step {idx:02d}: {iterations:3d} iterations, residual={residual:.3e}")
    print("Analytical contact area radius at t0:", float(a_t0))
    print("Analytical contact area radius at t_inf:", float(a_t_inf))
    print("Analytical maximum pressure at t0:", float(p0_t0))
    print("Analytical maximum pressure at t_inf:", float(p0_t_inf))
    print("Numerical contact area at t0:", contact_areas[0])
    print("Numerical contact area at t_inf:", contact_areas[-1])
    print("Analytical contact area at t0:", float(Ac_hertz_t0))
    print("Analytical contact area at t_inf:", float(Ac_hertz_t_inf))

    # Build the axis from the step count so its length always matches
    # contact_areas; a float-step arange can be off by one element.
    time_axis = t0 + dt * np.arange(time_steps)
    plt.figure()
    plt.plot(time_axis, contact_areas)
    plt.axhline(Ac_hertz_t0, color="red", linestyle="dotted")
    plt.axhline(Ac_hertz_t_inf, color="blue", linestyle="dotted")
    plt.xlabel("Time(s)")
    plt.ylabel("Contact area($m^2$)")
    plt.legend(["Numerical", "Hertz at t=0", "Hertz at t=inf"])
    plt.title("Contact area vs time for multi-branch Generalized Maxwell model")
    plt.show()
# Run the simulation only when this file is executed as a script.
if __name__ == "__main__":
    main()