# Forked from lfrerot/good_simulation_practices
# (original listing: 349 lines, 9.5 KiB, Python).
"""JAX generalized Maxwell contact solver without Python loops.

This variant removes the explicit Python time-stepping loop from
`JAX_GMM.py` by relying on `jax.lax.scan`, which keeps all temporally
coupled computations staged inside JAX's computation graph. The contact
solver keeps the same structure but is written with `jax.lax.while_loop`
so that it is compatible with scanning; the entire transient is therefore
solved as a single JIT-compiled program once the graph is traced.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import time
|
|
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
from matplotlib.animation import FuncAnimation
|
|
|
|
import jax
|
|
import jax.numpy as jnp
|
|
from jax import lax
|
|
|
|
|
|
# Switch JAX to 64-bit floats/ints for arrays created from here on. The solver
# below works with small constants (tol=1e-6, 1e-12 regularisers), which
# presumably motivates double precision — NOTE(review): confirm.
jax.config.update("jax_enable_x64", True)
|
|
|
|
|
|
def build_fourier_kernel(n: int, m: int, L: float, E_star: float) -> jnp.ndarray:
    """Return the Fourier-space surface compliance ``2 / (E_star * |q|)``.

    Wavenumbers follow ``jnp.fft.fftfreq`` ordering on a periodic ``L x L``
    domain sampled by ``n`` points along x (columns) and ``m`` points along y
    (rows). The zero-wavenumber (mean) mode is set to ``0``.

    Args:
        n: number of grid points along x.
        m: number of grid points along y.
        L: side length of the periodic square domain.
        E_star: effective modulus entering the compliance.

    Returns:
        Real array of shape ``(m, n)``.
    """
    two_pi = 2.0 * jnp.pi
    wave_x = two_pi * jnp.fft.fftfreq(n, d=L / n)
    wave_y = two_pi * jnp.fft.fftfreq(m, d=L / m)
    # A row of x-wavenumbers broadcast against a column of y-wavenumbers gives
    # the same (m, n) layout as meshgrid(..., indexing="xy").
    magnitude = jnp.sqrt(wave_x[jnp.newaxis, :] ** 2 + wave_y[:, jnp.newaxis] ** 2)
    return jnp.where(magnitude > 0.0, 2.0 / (E_star * magnitude), 0.0)
|
|
|
|
|
|
@jax.jit
def displacement_from_pressure(kernel_fourier: jnp.ndarray, pressure: jnp.ndarray) -> jnp.ndarray:
    """Apply the Fourier-space compliance kernel to a pressure field.

    Both transforms use ``norm="ortho"``, so the forward/inverse pair carries
    no net scale factor. Only the real part of the inverse transform is
    returned.

    Args:
        kernel_fourier: kernel sampled on the same grid as ``pressure``.
        pressure: real pressure field.

    Returns:
        Real displacement field with the shape of ``pressure``.
    """
    spectrum = kernel_fourier * jnp.fft.fft2(pressure, norm="ortho")
    return jnp.real(jnp.fft.ifft2(spectrum, norm="ortho"))
|
|
|
|
|
|
def elastic_energy(kernel_fourier: jnp.ndarray, h_profile: jnp.ndarray, pressure: jnp.ndarray) -> jnp.ndarray:
    """Return the energy functional minimised by the contact solver.

    Value: ``0.5 * sum(p * u(p)) - sum(p * h)``, where ``u(p)`` comes from
    ``displacement_from_pressure``.

    Args:
        kernel_fourier: Fourier-space compliance kernel.
        h_profile: surface profile on the pressure grid.
        pressure: pressure field (the differentiated argument below).

    Returns:
        Scalar energy.
    """
    u = displacement_from_pressure(kernel_fourier, pressure)
    stored = 0.5 * jnp.sum(pressure * u)
    external = jnp.sum(pressure * h_profile)
    return stored - external


# Jitted ``(energy, d energy / d pressure)``; argnums=2 selects ``pressure``.
value_and_grad_energy = jax.jit(
    jax.value_and_grad(elastic_energy, argnums=2),
)
|
|
|
|
|
|
@jax.jit
def project_total_load(pressure: jnp.ndarray, W: float, L: float) -> jnp.ndarray:
    """Rescale ``pressure`` so its mean equals ``W / L**2``.

    If the current mean is not positive, the field is replaced by a uniform
    pressure at the target mean.

    Args:
        pressure: pressure field.
        W: total load.
        L: side length of the square domain (area ``L**2``).

    Returns:
        Pressure field with the prescribed mean.
    """
    target = W / (L**2)
    current = jnp.mean(pressure)
    is_loaded = current > 0.0
    factor = jnp.where(is_loaded, target / current, 0.0)
    rescaled = pressure * factor
    return jnp.where(is_loaded, rescaled, jnp.full_like(pressure, target))
|
|
|
|
|
|
@jax.jit
def masked_mean(values: jnp.ndarray, mask: jnp.ndarray) -> jnp.ndarray:
    """Mean of ``values`` over entries where ``mask`` is true (0 if none)."""
    n_selected = mask.sum()
    selected_total = jnp.where(mask, values, 0.0).sum()
    return jnp.where(n_selected > 0, selected_total / n_selected, 0.0)
|
|
|
|
|
|
@jax.jit
def compute_error(pressure: jnp.ndarray, gradient: jnp.ndarray, h_rms: float) -> jnp.ndarray:
    """Return the normalised complementarity error used as stopping criterion.

    Computes ``|sum(p * (g - min g))| / (sum(p) * h_rms + 1e-12)``. For a
    non-negative pressure it vanishes when pressure is only carried where the
    gradient attains its minimum.

    Args:
        pressure: pressure field.
        gradient: gradient (gap-like) field on the same grid.
        h_rms: normalisation length.

    Returns:
        Non-negative scalar error.
    """
    shifted = gradient - gradient.min()
    overlap = jnp.vdot(jnp.ravel(pressure), shifted)
    scale = pressure.sum() * h_rms + 1e-12
    return jnp.abs(overlap / scale)
|
|
|
|
|
|
@jax.jit
def update_search_direction(
    gradient: jnp.ndarray,
    direction: jnp.ndarray,
    contact_mask: jnp.ndarray,
    delta: float,
    g_norm: float,
    g_old: float,
) -> jnp.ndarray:
    """Return the next conjugate-gradient direction restricted to contact.

    ``t = g + beta * t_old`` with ``beta = delta * g_norm / g_old`` (``0`` when
    ``g_old <= 0``); entries outside ``contact_mask`` are zeroed. ``delta``
    set to ``0`` restarts the recurrence as steepest descent.
    """
    has_history = g_old > 0.0
    ratio = delta * g_norm / (g_old + 1e-12)
    beta_cg = jnp.where(has_history, ratio, 0.0)
    candidate = gradient + beta_cg * direction
    return jnp.where(contact_mask, candidate, 0.0)
|
|
|
|
|
|
def contact_solver_autodiff(
    kernel_fourier: jnp.ndarray,
    h_profile: jnp.ndarray,
    W: float,
    L: float,
    tol: float = 1e-6,
    iter_max: int = 200,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Solve the normal contact problem with a constrained conjugate gradient.

    Polonsky–Keer style iteration staged in ``lax.while_loop`` so it can be
    traced inside ``lax.scan``. The gradient of ``elastic_energy`` with
    respect to the pressure, centred on the current contact set, plays the
    role of the gap.

    Args:
        kernel_fourier: Fourier-space compliance kernel on the ``h_profile`` grid.
        h_profile: surface profile.
        W: total normal load to carry.
        L: side length of the square periodic domain.
        tol: stop once ``compute_error`` drops to this value or below.
        iter_max: maximum number of iterations.

    Returns:
        ``(displacement, pressure, iterations, error)``, where ``iterations``
        is the number of loop iterations performed and ``error`` the last
        value returned by ``compute_error``.
    """
    h_rms = jnp.std(h_profile)
    initial_pressure = jnp.full_like(h_profile, W / (L**2))
    initial_direction = jnp.zeros_like(initial_pressure)
    iter_max_jnp = jnp.array(iter_max)

    def cond_fun(state):
        _, _, _, _, k, error = state
        return jnp.logical_and(error > tol, k < iter_max_jnp)

    def body_fun(state):
        pressure, direction, g_old, delta, k, _ = state

        _, grad_energy = value_and_grad_energy(kernel_fourier, h_profile, pressure)
        contact_mask = pressure > 0.0

        # Shift the gap by its mean over the contact set (rigid-body approach).
        grad_mean = masked_mean(grad_energy, contact_mask)
        grad_centered = grad_energy - grad_mean
        grad_contact = jnp.where(contact_mask, grad_centered, 0.0)

        g_norm = jnp.sum(grad_contact * grad_contact)
        search_dir = update_search_direction(grad_contact, direction, contact_mask, delta, g_norm, g_old)

        # Elastic response to the search direction, centred on the contact set.
        displacement_dir = displacement_from_pressure(kernel_fourier, search_dir)
        disp_mean = masked_mean(displacement_dir, contact_mask)
        response = displacement_dir - disp_mean

        # Exact line-search step along the search direction.
        tau_num = jnp.sum(jnp.where(contact_mask, grad_centered * search_dir, 0.0))
        tau_den = jnp.sum(jnp.where(contact_mask, response * search_dir, 0.0))
        tau = tau_num / (tau_den + 1e-12)

        pressure_new = jnp.maximum(pressure - tau * search_dir, 0.0)

        # Points with zero pressure but negative gap are penetrating. As in
        # Polonsky & Keer, give them pressure ``-tau * g`` so they re-enter the
        # contact set. Without this, a node whose pressure reaches 0 can never
        # become active again (search_dir and the load projection both keep it
        # at 0), so the contact area could only shrink.
        inadmissible = jnp.logical_and(pressure_new == 0.0, grad_centered < 0.0)
        pressure_new = jnp.where(inadmissible, pressure_new - tau * grad_centered, pressure_new)
        # Restart the CG recurrence (steepest descent) whenever the set changed.
        delta_new = jnp.where(jnp.sum(inadmissible) == 0, 1.0, 0.0)

        pressure_projected = project_total_load(pressure_new, W, L)
        error_new = compute_error(pressure_projected, grad_centered, h_rms)

        return (
            pressure_projected,
            search_dir,
            # Keep the previous norm when the new one vanishes, so the next
            # beta does not divide by zero.
            jnp.where(g_norm > 0.0, g_norm, g_old),
            delta_new,
            k + 1,
            error_new,
        )

    final_state = lax.while_loop(
        cond_fun,
        body_fun,
        (
            initial_pressure,
            initial_direction,
            jnp.array(1.0),
            jnp.array(0.0),
            jnp.array(0),
            jnp.array(jnp.inf),
        ),
    )

    pressure, _, _, _, iterations, error = final_state
    displacement = displacement_from_pressure(kernel_fourier, pressure)
    return displacement, pressure, iterations, error
|
|
|
|
|
|
def run_simulation(*, time_steps: int = 50, n: int = 300, m: int = 300):
    """Run the generalized Maxwell transient for a parabolic indenter.

    The surface profile is ``-r**2 / (2 * radius)`` on a periodic ``L x L``
    grid. Each time step solves one contact problem with
    ``contact_solver_autodiff``. The time loop is expressed with ``lax.scan``.

    Keyword Args:
        time_steps: number of steps on ``[t0, t1] = [0, 1]`` (default 50).
        n: grid points along x (default 300).
        m: grid points along y (default 300).

    Returns:
        Dict with keys:

        * ``"x"``: x-coordinate grid, shape ``(m, n)``.
        * ``"midlines"``: pressure along grid row ``m // 2`` per step,
          shape ``(time_steps, n)``.
        * ``"contact_areas"``, ``"iterations"``, ``"residuals"``: one entry
          per step.
        * ``"params"``: scalars needed for plotting and the Hertz reference
          solutions.
    """
    t0 = 0.0
    t1 = 1.0
    dt = (t1 - t0) / time_steps

    W = 1.0

    L = 2.0
    radius = 0.5
    S = L**2

    x_vals = jnp.linspace(0.0, L, n, endpoint=False)
    y_vals = jnp.linspace(0.0, L, m, endpoint=False)
    # indexing="xy" -> arrays of shape (m, n): rows follow y, columns follow x.
    # This matches the layout of build_fourier_kernel(n, m, ...).
    x, y = jnp.meshgrid(x_vals, y_vals, indexing="xy")

    x0 = 1.0
    y0 = 1.0

    E = 3.0
    nu = 0.5
    E_star = E / (1.0 - nu**2)

    r = jnp.sqrt((x - x0) ** 2 + (y - y0) ** 2)
    h_profile = -(r**2) / (2.0 * radius)

    kernel_fourier = build_fourier_kernel(n, m, L, E_star)

    # Generalized Maxwell parameters: long-term modulus G_inf, plus one
    # modulus and one relaxation time per Maxwell branch.
    G_inf = 2.75
    G_branches = jnp.array([2.75, 2.75])
    tau_branches = jnp.array([0.1, 1.0])

    gamma = tau_branches / (tau_branches + dt)
    G_tilde = jnp.sum(gamma * G_branches)
    alpha = G_inf + G_tilde
    beta = G_tilde

    surface = h_profile
    # The state must live on the (m, n) grid. Deriving it from h_profile keeps
    # shapes consistent when n != m; the previous (n, m) only worked for n == m.
    U0 = jnp.zeros_like(h_profile)
    M0 = jnp.zeros((G_branches.shape[0],) + h_profile.shape)

    def scan_step(carry, _):
        U, M = carry

        M_maxwell = jnp.tensordot(gamma, M, axes=1)
        H_new = alpha * surface - beta * U + M_maxwell

        displacement, pressure, iterations, residual = contact_solver_autodiff(
            kernel_fourier,
            H_new,
            W,
            L,
            tol=1e-6,
            iter_max=200,
        )

        U_new = (displacement - M_maxwell + beta * U) / alpha
        delta_U = U_new - U
        M_new = gamma[:, None, None] * (M + G_branches[:, None, None] * delta_U)

        # Rows index y, so the middle row is m // 2, not n // 2.
        midline = pressure[m // 2]
        contact_area = jnp.mean(pressure > 0.0) * S

        return (U_new, M_new), (midline, contact_area, iterations, residual)

    (_, _), outputs = lax.scan(
        scan_step,
        (U0, M0),
        xs=None,
        length=time_steps,
    )

    midlines, contact_areas, iterations, residuals = outputs

    # Hertz reference with all branches active (instantaneous modulus).
    G_maxwell_t0 = jnp.sum(G_branches)
    G_effective_t0 = G_inf + G_maxwell_t0
    E_effective_t0 = 2.0 * G_effective_t0 * (1.0 + nu) / (1.0 - nu**2)
    p0_t0 = (6.0 * W * (E_effective_t0**2) / (jnp.pi**3 * radius**2)) ** (1.0 / 3.0)
    a_t0 = (3.0 * W * radius / (4.0 * E_effective_t0)) ** (1.0 / 3.0)

    # Hertz reference with only the long-term modulus (fully relaxed).
    E_effective_inf = 2.0 * G_inf * (1.0 + nu) / (1.0 - nu**2)
    p0_t_inf = (6.0 * W * (E_effective_inf**2) / (jnp.pi**3 * radius**2)) ** (1.0 / 3.0)
    a_t_inf = (3.0 * W * radius / (4.0 * E_effective_inf)) ** (1.0 / 3.0)

    return {
        "x": x,
        "midlines": midlines,
        "contact_areas": contact_areas,
        "iterations": iterations,
        "residuals": residuals,
        "params": {
            "t0": t0,
            "dt": dt,
            "L": L,
            "radius": radius,
            "x0": x0,
            "p0_t0": p0_t0,
            "p0_t_inf": p0_t_inf,
            "a_t0": a_t0,
            "a_t_inf": a_t_inf,
            "S": S,
        },
    }
|
|
|
|
|
|
def main():
    """Run the simulation, report timing and convergence, and plot results.

    Animates the pressure along the grid's middle row against the Hertz
    solutions for the t=0 and t=inf moduli. Prints per-step solver iterations
    and residuals and the analytical and numerical contact areas, then plots
    contact area versus time.
    """
    start_time = time.perf_counter()
    # JAX dispatches work asynchronously: wait for every array in the result
    # so the timer measures the computation, not just its dispatch.
    results = jax.block_until_ready(run_simulation())
    total_time = time.perf_counter() - start_time
    print("Simulation time:", total_time, "seconds")

    x = jax.device_get(results["x"])
    midlines = jax.device_get(results["midlines"])
    contact_areas = jax.device_get(results["contact_areas"])
    iterations = jax.device_get(results["iterations"]).astype(int)
    residuals = jax.device_get(results["residuals"])

    params = results["params"]
    t0 = float(params["t0"])
    dt = float(params["dt"])
    L = float(params["L"])
    x0 = float(params["x0"])
    p0_t0 = float(params["p0_t0"])
    p0_t_inf = float(params["p0_t_inf"])
    a_t0 = float(params["a_t0"])
    a_t_inf = float(params["a_t_inf"])

    n_frames = midlines.shape[0]
    time_axis_np = t0 + dt * np.arange(n_frames)

    # The x coordinates and Hertz reference curves are the same for every
    # frame, so build them once instead of on every animation update.
    x_mid = x[x.shape[0] // 2]
    hertz_t0 = p0_t0 * np.sqrt(np.maximum(0.0, 1.0 - (x_mid - x0) ** 2 / a_t0**2))
    hertz_t_inf = p0_t_inf * np.sqrt(np.maximum(0.0, 1.0 - (x_mid - x0) ** 2 / a_t_inf**2))

    fig, ax = plt.subplots()

    def update(frame):
        ax.clear()
        ax.set_xlim(0, L)
        ax.set_ylim(0, 1.1 * p0_t0)
        ax.grid(True)

        ax.plot(x_mid, hertz_t0, "g--", label="Hertz t=0")
        ax.plot(x_mid, hertz_t_inf, "b--", label="Hertz t=inf")
        ax.plot(x_mid, midlines[frame], "r-", label="Numerical")
        ax.set_title(f"Time = {t0 + frame * dt:.2f}s")
        ax.set_xlabel("x")
        ax.set_ylabel("Pressure distribution")
        ax.legend(loc="upper right")

    # Keep a reference: the animation stops if this object is garbage-collected.
    ani = FuncAnimation(fig, update, frames=n_frames, repeat=False)  # noqa: F841
    plt.show()

    Ac_hertz_t0 = np.pi * a_t0**2
    Ac_hertz_t_inf = np.pi * a_t_inf**2

    print("Iterations and residuals per step:")
    for idx, (its, res) in enumerate(zip(iterations, residuals)):
        print(f" step {idx:02d}: {its:3d} iterations, residual={res:.3e}")

    print("Analytical contact area at t0:", float(Ac_hertz_t0))
    print("Analytical contact area at t_inf:", float(Ac_hertz_t_inf))
    print("Numerical contact area at t0:", float(contact_areas[0]))
    print("Numerical contact area at t_inf:", float(contact_areas[-1]))

    plt.figure()
    plt.plot(time_axis_np, contact_areas)
    plt.axhline(Ac_hertz_t0, color="red", linestyle="dotted")
    plt.axhline(Ac_hertz_t_inf, color="blue", linestyle="dotted")
    plt.xlabel("Time(s)")
    plt.ylabel("Contact area($m^2$)")
    plt.legend(["Numerical", "Hertz at t=0", "Hertz at t=inf"])
    plt.title("Contact area vs time for multi-branch Generalized Maxwell model")
    plt.show()


if __name__ == "__main__":
    main()
|
|
|