From cc14b7b2b13b21c512d2650f9d010c532a5d515b Mon Sep 17 00:00:00 2001 From: Zichen LI Date: Fri, 5 Dec 2025 11:39:38 +0100 Subject: [PATCH] With an example of Hertz viscoelastic contact, we compare the computation efficiency of JAX and Tamaas - more tests need to be done - rough surfaces need to be considered --- JAX/tests/JAX_GMM.py | 359 ++++++++++++++++++ JAX/tests/JAX_GMM_without_for.py | 348 +++++++++++++++++ .../Multi_branches_generalized_Maxwell.py | 294 ++++++++++++++ JAX/tests/Tamaas_GMM.py | 70 ++++ 4 files changed, 1071 insertions(+) create mode 100644 JAX/tests/JAX_GMM.py create mode 100644 JAX/tests/JAX_GMM_without_for.py create mode 100644 JAX/tests/Multi_branches_generalized_Maxwell.py create mode 100644 JAX/tests/Tamaas_GMM.py diff --git a/JAX/tests/JAX_GMM.py b/JAX/tests/JAX_GMM.py new file mode 100644 index 0000000..d1b30bd --- /dev/null +++ b/JAX/tests/JAX_GMM.py @@ -0,0 +1,359 @@ +"""JAX implementation of the generalized Maxwell contact solver. + +This script mirrors the NumPy-based reference in +`Multi_branches_generalized_Maxwell.py`, but leverages JAX for automatic +differentiation and JIT compilation. The automatic gradient of the elastic +energy drives the constrained conjugate-gradient contact solver. + +Running this file produces the same diagnostic plots as the reference +implementation while keeping all heavy lifting on the accelerator-enabled JAX +backend. +""" + +from __future__ import annotations + +import numpy as np +import matplotlib.pyplot as plt +from matplotlib.animation import FuncAnimation + +import jax +import jax.numpy as jnp +from jax import lax + +import time + + +# Enable double precision for improved numerical stability. +jax.config.update("jax_enable_x64", True) + + +def build_fourier_kernel(n: int, m: int, L: float, E_star: float) -> jnp.ndarray: + """Assemble the Fourier-domain kernel for the half-space Green's function.""" + + q_x = 2.0 * np.pi * jnp.fft.fftfreq(n, d=L / n) + q_y = 2.0 * np.pi * jnp.fft.fftfreq(m, d=L / m) + QX, QY = jnp.meshgrid(q_x, q_y, indexing="xy") + q_norm = jnp.sqrt(QX**2 + QY**2) + kernel = jnp.where(q_norm > 0.0, 2.0 / (E_star * q_norm), 0.0) + return kernel + + +@jax.jit +def displacement_from_pressure( + kernel_fourier: jnp.ndarray, pressure: jnp.ndarray +) -> jnp.ndarray: + """Return the surface displacement induced by the supplied pressure field.""" + + pressure_fft = jnp.fft.fft2(pressure, norm="ortho") + displacement_fft = pressure_fft * kernel_fourier + displacement = jnp.fft.ifft2(displacement_fft, norm="ortho").real + return displacement + + +def elastic_energy( + kernel_fourier: jnp.ndarray, h_profile: jnp.ndarray, pressure: jnp.ndarray +) -> jnp.ndarray: + """Elastic energy functional; its gradient yields the gap field.""" + + displacement = displacement_from_pressure(kernel_fourier, pressure) + stored = 0.5 * jnp.sum(pressure * displacement) + work = jnp.sum(pressure * h_profile) + return stored - work + + +value_and_grad_energy = jax.jit(jax.value_and_grad(elastic_energy, argnums=2)) + + +@jax.jit +def project_total_load(pressure: jnp.ndarray, W: float, L: float) -> jnp.ndarray: + """Project the pressure field onto the admissible set enforcing total load.""" + + mean_pressure = jnp.mean(pressure) + target = W / (L**2) + scale = jnp.where(mean_pressure > 0.0, target / mean_pressure, 0.0) + projected = jnp.where(mean_pressure > 0.0, pressure * scale, jnp.full_like(pressure, target)) + return projected + + +@jax.jit +def masked_mean(values: jnp.ndarray, mask: jnp.ndarray) -> jnp.ndarray: + """Compute the mean over the masked region, guarding against empty sets.""" + + count = jnp.sum(mask) + total = jnp.sum(jnp.where(mask, values, 0.0)) + return jnp.where(count > 0, total / count, 0.0) + + +@jax.jit +def compute_error( + pressure: jnp.ndarray, + gradient: jnp.ndarray, + h_rms: float, +) -> jnp.ndarray: + """Scaled complementarity error used as stopping criterion.""" + + num = jnp.vdot(pressure.reshape(-1), gradient - jnp.min(gradient)) + denom = jnp.sum(pressure) * h_rms + 1e-12 + return jnp.abs(num / denom) + + +@jax.jit +def update_search_direction( + gradient: jnp.ndarray, + direction: jnp.ndarray, + contact_mask: jnp.ndarray, + delta: float, + g_norm: float, + g_old: float, +) -> jnp.ndarray: + """Conjugate-gradient style update with projection onto the contact set.""" + + beta_cg = jnp.where(g_old > 0.0, delta * g_norm / (g_old + 1e-12), 0.0) + updated = gradient + beta_cg * direction + return jnp.where(contact_mask, updated, 0.0) + + +def contact_solver_autodiff( + kernel_fourier: jnp.ndarray, + h_profile: jnp.ndarray, + W: float, + L: float, + tol: float = 1e-6, + iter_max: int = 200, +): + """Solve the constrained contact problem via autodiff-powered CG iterations.""" + + h_rms = jnp.std(h_profile) + initial_pressure = jnp.full_like(h_profile, W / (L**2)) + initial_direction = jnp.zeros_like(initial_pressure) + + iter_max_jnp = jnp.array(iter_max) + + def cond_fun(state): + _, _, _, _, k, error = state + return jnp.logical_and(error > tol, k < iter_max_jnp) + + def body_fun(state): + pressure, direction, g_old, delta, k, _ = state + + _, grad_energy = value_and_grad_energy(kernel_fourier, h_profile, pressure) + contact_mask = pressure > 0.0 + + grad_mean = masked_mean(grad_energy, contact_mask) + grad_centered = grad_energy - grad_mean + grad_contact = jnp.where(contact_mask, grad_centered, 0.0) + + g_norm = jnp.sum(grad_contact * grad_contact) + + search_dir = update_search_direction( + grad_contact, + direction, + contact_mask, + delta, + g_norm, + g_old, + ) + + displacement_dir = displacement_from_pressure(kernel_fourier, search_dir) + disp_mean = masked_mean(displacement_dir, contact_mask) + response = displacement_dir - disp_mean + + tau_num = jnp.sum(jnp.where(contact_mask, grad_centered * search_dir, 0.0)) + tau_den = jnp.sum(jnp.where(contact_mask, response * search_dir, 0.0)) + tau = tau_num / (tau_den + 1e-12) + + pressure_new = pressure - tau * search_dir + pressure_new = jnp.where(pressure_new > 0.0, pressure_new, 0.0) + + inadmissible = jnp.logical_and(pressure_new == 0.0, grad_centered < 0.0) + delta_new = jnp.where(jnp.sum(inadmissible) == 0, 1.0, 0.0) + + pressure_projected = project_total_load(pressure_new, W, L) + error_new = compute_error(pressure_projected, grad_centered, h_rms) + + return ( + pressure_projected, + search_dir, + jnp.where(g_norm > 0.0, g_norm, g_old), + delta_new, + k + 1, + error_new, + ) + + final_state = lax.while_loop( + cond_fun, + body_fun, + ( + initial_pressure, + initial_direction, + jnp.array(1.0), + jnp.array(0.0), + jnp.array(0), + jnp.array(jnp.inf), + ), + ) + + pressure, _, _, _, iterations, error = final_state + displacement = displacement_from_pressure(kernel_fourier, pressure) + return displacement, pressure, int(iterations), float(error) + + +def main(): + # Time discretization + t0 = 0.0 + t1 = 1.0 + time_steps = 50 + dt = (t1 - t0) / time_steps + + # Total load + W = 1.0 + + # Geometry + L = 2.0 + radius = 0.5 + S = L**2 + + # Grid + n = 300 + m = 300 + x_vals = jnp.linspace(0.0, L, n, endpoint=False) + y_vals = jnp.linspace(0.0, L, m, endpoint=False) + x, y = jnp.meshgrid(x_vals, y_vals, indexing="xy") + + x0 = 1.0 + y0 = 1.0 + + E = 3.0 + nu = 0.5 + E_star = E / (1.0 - nu**2) + + r = jnp.sqrt((x - x0) ** 2 + (y - y0) ** 2) + h_profile = -(r**2) / (2.0 * radius) + + kernel_fourier = build_fourier_kernel(n, m, L, E_star) + + # Maxwell model parameters + G_inf = 2.75 + G_branches = jnp.array([2.75, 2.75]) + tau_branches = jnp.array([0.1, 1.0]) + eta_branches = G_branches * tau_branches + + gamma = tau_branches / (tau_branches + dt) + G_tilde = jnp.sum(gamma * G_branches) + alpha = G_inf + G_tilde + beta = G_tilde + + surface = h_profile + U = jnp.zeros((n, m)) + M = jnp.zeros((G_branches.shape[0], n, m)) + + # Hertzian references + G_maxwell_t0 = jnp.sum(G_branches) + G_effective_t0 = G_inf + G_maxwell_t0 + E_effective_t0 = 2.0 * G_effective_t0 * (1.0 + nu) / (1.0 - nu**2) + p0_t0 = (6.0 * W * (E_effective_t0**2) / (np.pi**3 * radius**2)) ** (1.0 / 3.0) + a_t0 = (3.0 * W * radius / (4.0 * E_effective_t0)) ** (1.0 / 3.0) + + E_effective_inf = 2.0 * G_inf * (1.0 + nu) / (1.0 - nu**2) + p0_t_inf = (6.0 * W * (E_effective_inf**2) / (np.pi**3 * radius**2)) ** (1.0 / 3.0) + a_t_inf = (3.0 * W * radius / (4.0 * E_effective_inf)) ** (1.0 / 3.0) + + pressure_distributions = [] + contact_areas = [] + iteration_log = [] + + start_time = time.perf_counter() + + # I think I should avoid using for loops in JAX + for step in range(time_steps): + M_maxwell = jnp.tensordot(gamma, M, axes=1) + H_new = alpha * surface - beta * U + M_maxwell + + displacement, pressure, iterations, residual = contact_solver_autodiff( + kernel_fourier, + H_new, + W, + L, + tol=1e-6, + iter_max=200, + ) + + U_new = (displacement - M_maxwell + beta * U) / alpha + delta_U = U_new - U + M = gamma[:, None, None] * (M + G_branches[:, None, None] * delta_U) + + area_ratio = jnp.mean(pressure > 0.0) + contact_area = float(area_ratio * S) + contact_areas.append(contact_area) + + pressure_midline = np.array(jax.device_get(pressure[n // 2])) + pressure_distributions.append(pressure_midline) + + iteration_log.append((iterations, residual)) + U = U_new + + end_time = time.perf_counter() + print("Simulation time:", end_time - start_time, "seconds") + + x_np = np.array(jax.device_get(x)) + + def update(frame): + ax.clear() + ax.set_xlim(0, L) + ax.set_ylim(0, 1.1 * p0_t0) + ax.grid(True) + + ax.plot( + x_np[n // 2], + p0_t0 * np.sqrt(np.maximum(0.0, 1.0 - (x_np[n // 2] - x0) ** 2 / a_t0**2)), + "g--", + label="Hertz t=0", + ) + ax.plot( + x_np[n // 2], + p0_t_inf * np.sqrt(np.maximum(0.0, 1.0 - (x_np[n // 2] - x0) ** 2 / a_t_inf**2)), + "b--", + label="Hertz t=inf", + ) + ax.plot(x_np[n // 2], pressure_distributions[frame], "r-", label="Numerical") + ax.set_title(f"Time = {t0 + frame * dt:.2f}s") + ax.set_xlabel("x") + ax.set_ylabel("Pressure distribution") + ax.legend(loc="upper right") + + fig, ax = plt.subplots() + ani = FuncAnimation(fig, update, frames=len(pressure_distributions), repeat=False) + + plt.show() + + Ac_hertz_t0 = np.pi * a_t0**2 + Ac_hertz_t_inf = np.pi * a_t_inf**2 + + print("Iterations and residuals per step:") + for idx, (iterations, residual) in enumerate(iteration_log): + print(f" step {idx:02d}: {iterations:3d} iterations, residual={residual:.3e}") + + print("Analytical contact area radius at t0:", float(a_t0)) + print("Analytical contact area radius at t_inf:", float(a_t_inf)) + print("Analytical maximum pressure at t0:", float(p0_t0)) + print("Analytical maximum pressure at t_inf:", float(p0_t_inf)) + print("Numerical contact area at t0:", contact_areas[0]) + print("Numerical contact area at t_inf:", contact_areas[-1]) + print("Analytical contact area at t0:", float(Ac_hertz_t0)) + print("Analytical contact area at t_inf:", float(Ac_hertz_t_inf)) + + time_axis = np.arange(t0, t1, dt) + plt.figure() + plt.plot(time_axis, contact_areas) + plt.axhline(Ac_hertz_t0, color="red", linestyle="dotted") + plt.axhline(Ac_hertz_t_inf, color="blue", linestyle="dotted") + plt.xlabel("Time(s)") + plt.ylabel("Contact area($m^2$)") + plt.legend(["Numerical", "Hertz at t=0", "Hertz at t=inf"]) + plt.title("Contact area vs time for multi-branch Generalized Maxwell model") + plt.show() + + +if __name__ == "__main__": + main() + diff --git a/JAX/tests/JAX_GMM_without_for.py b/JAX/tests/JAX_GMM_without_for.py new file mode 100644 index 0000000..8802461 --- /dev/null +++ b/JAX/tests/JAX_GMM_without_for.py @@ -0,0 +1,348 @@ +"""JAX generalized Maxwell contact solver without Python loops. + +This variant removes the explicit Python time-stepping loop from +`JAX_GMM.py` by relying on `jax.lax.scan`, which keeps all temporally +coupled computations staged inside JAX's computation graph. The contact +solver remains identical but is compatible with scanning so the entire +transient solves as a single JIT-compiled program once the graph is +traced. +""" + +from __future__ import annotations + +import time + +import numpy as np +import matplotlib.pyplot as plt +from matplotlib.animation import FuncAnimation + +import jax +import jax.numpy as jnp +from jax import lax + + +jax.config.update("jax_enable_x64", True) + + +def build_fourier_kernel(n: int, m: int, L: float, E_star: float) -> jnp.ndarray: + q_x = 2.0 * jnp.pi * jnp.fft.fftfreq(n, d=L / n) + q_y = 2.0 * jnp.pi * jnp.fft.fftfreq(m, d=L / m) + QX, QY = jnp.meshgrid(q_x, q_y, indexing="xy") + q_norm = jnp.sqrt(QX**2 + QY**2) + return jnp.where(q_norm > 0.0, 2.0 / (E_star * q_norm), 0.0) + + +@jax.jit +def displacement_from_pressure(kernel_fourier: jnp.ndarray, pressure: jnp.ndarray) -> jnp.ndarray: + pressure_fft = jnp.fft.fft2(pressure, norm="ortho") + displacement_fft = pressure_fft * kernel_fourier + return jnp.fft.ifft2(displacement_fft, norm="ortho").real + + +def elastic_energy(kernel_fourier: jnp.ndarray, h_profile: jnp.ndarray, pressure: jnp.ndarray) -> jnp.ndarray: + displacement = displacement_from_pressure(kernel_fourier, pressure) + stored = 0.5 * jnp.sum(pressure * displacement) + work = jnp.sum(pressure * h_profile) + return stored - work + + +value_and_grad_energy = jax.jit(jax.value_and_grad(elastic_energy, argnums=2)) + + +@jax.jit +def project_total_load(pressure: jnp.ndarray, W: float, L: float) -> jnp.ndarray: + mean_pressure = jnp.mean(pressure) + target = W / (L**2) + scale = jnp.where(mean_pressure > 0.0, target / mean_pressure, 0.0) + return jnp.where(mean_pressure > 0.0, pressure * scale, jnp.full_like(pressure, target)) + + +@jax.jit +def masked_mean(values: jnp.ndarray, mask: jnp.ndarray) -> jnp.ndarray: + count = jnp.sum(mask) + total = jnp.sum(jnp.where(mask, values, 0.0)) + return jnp.where(count > 0, total / count, 0.0) + + +@jax.jit +def compute_error(pressure: jnp.ndarray, gradient: jnp.ndarray, h_rms: float) -> jnp.ndarray: + num = jnp.vdot(pressure.reshape(-1), gradient - jnp.min(gradient)) + denom = jnp.sum(pressure) * h_rms + 1e-12 + return jnp.abs(num / denom) + + +@jax.jit +def update_search_direction( + gradient: jnp.ndarray, + direction: jnp.ndarray, + contact_mask: jnp.ndarray, + delta: float, + g_norm: float, + g_old: float, +) -> jnp.ndarray: + beta_cg = jnp.where(g_old > 0.0, delta * g_norm / (g_old + 1e-12), 0.0) + updated = gradient + beta_cg * direction + return jnp.where(contact_mask, updated, 0.0) + + +def contact_solver_autodiff( + kernel_fourier: jnp.ndarray, + h_profile: jnp.ndarray, + W: float, + L: float, + tol: float = 1e-6, + iter_max: int = 200, +): + h_rms = jnp.std(h_profile) + initial_pressure = jnp.full_like(h_profile, W / (L**2)) + initial_direction = jnp.zeros_like(initial_pressure) + iter_max_jnp = jnp.array(iter_max) + + def cond_fun(state): + _, _, _, _, k, error = state + return jnp.logical_and(error > tol, k < iter_max_jnp) + + def body_fun(state): + pressure, direction, g_old, delta, k, _ = state + + _, grad_energy = value_and_grad_energy(kernel_fourier, h_profile, pressure) + contact_mask = pressure > 0.0 + + grad_mean = masked_mean(grad_energy, contact_mask) + grad_centered = grad_energy - grad_mean + grad_contact = jnp.where(contact_mask, grad_centered, 0.0) + + g_norm = jnp.sum(grad_contact * grad_contact) + search_dir = update_search_direction(grad_contact, direction, contact_mask, delta, g_norm, g_old) + + displacement_dir = displacement_from_pressure(kernel_fourier, search_dir) + disp_mean = masked_mean(displacement_dir, contact_mask) + response = displacement_dir - disp_mean + + tau_num = jnp.sum(jnp.where(contact_mask, grad_centered * search_dir, 0.0)) + tau_den = jnp.sum(jnp.where(contact_mask, response * search_dir, 0.0)) + tau = tau_num / (tau_den + 1e-12) + + pressure_new = jnp.maximum(pressure - tau * search_dir, 0.0) + + inadmissible = jnp.logical_and(pressure_new == 0.0, grad_centered < 0.0) + delta_new = jnp.where(jnp.sum(inadmissible) == 0, 1.0, 0.0) + + pressure_projected = project_total_load(pressure_new, W, L) + error_new = compute_error(pressure_projected, grad_centered, h_rms) + + return ( + pressure_projected, + search_dir, + jnp.where(g_norm > 0.0, g_norm, g_old), + delta_new, + k + 1, + error_new, + ) + + final_state = lax.while_loop( + cond_fun, + body_fun, + ( + initial_pressure, + initial_direction, + jnp.array(1.0), + jnp.array(0.0), + jnp.array(0), + jnp.array(jnp.inf), + ), + ) + + pressure, _, _, _, iterations, error = final_state + displacement = displacement_from_pressure(kernel_fourier, pressure) + return displacement, pressure, iterations, error + + +def run_simulation(): + t0 = 0.0 + t1 = 1.0 + time_steps = 50 + dt = (t1 - t0) / time_steps + + W = 1.0 + + L = 2.0 + radius = 0.5 + S = L**2 + + n = 300 + m = 300 + x_vals = jnp.linspace(0.0, L, n, endpoint=False) + y_vals = jnp.linspace(0.0, L, m, endpoint=False) + x, y = jnp.meshgrid(x_vals, y_vals, indexing="xy") + + x0 = 1.0 + y0 = 1.0 + + E = 3.0 + nu = 0.5 + E_star = E / (1.0 - nu**2) + + r = jnp.sqrt((x - x0) ** 2 + (y - y0) ** 2) + h_profile = -(r**2) / (2.0 * radius) + + kernel_fourier = build_fourier_kernel(n, m, L, E_star) + + G_inf = 2.75 + G_branches = jnp.array([2.75, 2.75]) + tau_branches = jnp.array([0.1, 1.0]) + + gamma = tau_branches / (tau_branches + dt) + G_tilde = jnp.sum(gamma * G_branches) + alpha = G_inf + G_tilde + beta = G_tilde + + surface = h_profile + U0 = jnp.zeros((n, m)) + M0 = jnp.zeros((G_branches.shape[0], n, m)) + + def scan_step(carry, _): + U, M = carry + + M_maxwell = jnp.tensordot(gamma, M, axes=1) + H_new = alpha * surface - beta * U + M_maxwell + + displacement, pressure, iterations, residual = contact_solver_autodiff( + kernel_fourier, + H_new, + W, + L, + tol=1e-6, + iter_max=200, + ) + + U_new = (displacement - M_maxwell + beta * U) / alpha + delta_U = U_new - U + M_new = gamma[:, None, None] * (M + G_branches[:, None, None] * delta_U) + + midline = pressure[n // 2] + contact_area = jnp.mean(pressure > 0.0) * S + + return (U_new, M_new), (midline, contact_area, iterations, residual) + + (_, _), outputs = lax.scan( + scan_step, + (U0, M0), + xs=None, + length=time_steps, + ) + + midlines, contact_areas, iterations, residuals = outputs + + G_maxwell_t0 = jnp.sum(G_branches) + G_effective_t0 = G_inf + G_maxwell_t0 + E_effective_t0 = 2.0 * G_effective_t0 * (1.0 + nu) / (1.0 - nu**2) + p0_t0 = (6.0 * W * (E_effective_t0**2) / (jnp.pi**3 * radius**2)) ** (1.0 / 3.0) + a_t0 = (3.0 * W * radius / (4.0 * E_effective_t0)) ** (1.0 / 3.0) + + E_effective_inf = 2.0 * G_inf * (1.0 + nu) / (1.0 - nu**2) + p0_t_inf = (6.0 * W * (E_effective_inf**2) / (jnp.pi**3 * radius**2)) ** (1.0 / 3.0) + a_t_inf = (3.0 * W * radius / (4.0 * E_effective_inf)) ** (1.0 / 3.0) + + return { + "x": x, + "midlines": midlines, + "contact_areas": contact_areas, + "iterations": iterations, + "residuals": residuals, + "params": { + "t0": t0, + "dt": dt, + "L": L, + "radius": radius, + "x0": x0, + "p0_t0": p0_t0, + "p0_t_inf": p0_t_inf, + "a_t0": a_t0, + "a_t_inf": a_t_inf, + "S": S, + }, + } + + +def main(): + start_time = time.perf_counter() + results = run_simulation() + total_time = time.perf_counter() - start_time + print("Simulation time:", total_time, "seconds") + + x = jax.device_get(results["x"]) + midlines = jax.device_get(results["midlines"]) + contact_areas = jax.device_get(results["contact_areas"]) + iterations = jax.device_get(results["iterations"]).astype(int) + residuals = jax.device_get(results["residuals"]) + + params = results["params"] + t0 = float(params["t0"]) + dt = float(params["dt"]) + L = float(params["L"]) + x0 = float(params["x0"]) + p0_t0 = float(params["p0_t0"]) + p0_t_inf = float(params["p0_t_inf"]) + a_t0 = float(params["a_t0"]) + a_t_inf = float(params["a_t_inf"]) + S = float(params["S"]) + + time_axis = t0 + dt * jnp.arange(midlines.shape[0]) + time_axis_np = jax.device_get(time_axis) + + def update(frame): + ax.clear() + ax.set_xlim(0, L) + ax.set_ylim(0, 1.1 * p0_t0) + ax.grid(True) + + x_mid = x[int(x.shape[0] / 2)] + ax.plot( + x_mid, + p0_t0 * np.sqrt(np.maximum(0.0, 1.0 - (x_mid - x0) ** 2 / a_t0**2)), + "g--", + label="Hertz t=0", + ) + ax.plot( + x_mid, + p0_t_inf * np.sqrt(np.maximum(0.0, 1.0 - (x_mid - x0) ** 2 / a_t_inf**2)), + "b--", + label="Hertz t=inf", + ) + ax.plot(x_mid, midlines[frame], "r-", label="Numerical") + ax.set_title(f"Time = {t0 + frame * dt:.2f}s") + ax.set_xlabel("x") + ax.set_ylabel("Pressure distribution") + ax.legend(loc="upper right") + + fig, ax = plt.subplots() + ani = FuncAnimation(fig, update, frames=midlines.shape[0], repeat=False) + plt.show() + + Ac_hertz_t0 = jnp.pi * a_t0**2 + Ac_hertz_t_inf = jnp.pi * a_t_inf**2 + + print("Iterations and residuals per step:") + for idx, (its, res) in enumerate(zip(iterations, residuals)): + print(f" step {idx:02d}: {its:3d} iterations, residual={res:.3e}") + + print("Analytical contact area at t0:", float(Ac_hertz_t0)) + print("Analytical contact area at t_inf:", float(Ac_hertz_t_inf)) + print("Numerical contact area at t0:", float(contact_areas[0])) + print("Numerical contact area at t_inf:", float(contact_areas[-1])) + + plt.figure() + plt.plot(time_axis_np, contact_areas) + plt.axhline(Ac_hertz_t0, color="red", linestyle="dotted") + plt.axhline(Ac_hertz_t_inf, color="blue", linestyle="dotted") + plt.xlabel("Time(s)") + plt.ylabel("Contact area($m^2$)") + plt.legend(["Numerical", "Hertz at t=0", "Hertz at t=inf"]) + plt.title("Contact area vs time for multi-branch Generalized Maxwell model") + plt.show() + + +if __name__ == "__main__": + main() + diff --git a/JAX/tests/Multi_branches_generalized_Maxwell.py b/JAX/tests/Multi_branches_generalized_Maxwell.py new file mode 100644 index 0000000..5af0ff6 --- /dev/null +++ b/JAX/tests/Multi_branches_generalized_Maxwell.py @@ -0,0 +1,294 @@ +### This script is for the Maxwell multi-branch model. +### Deduce process is in generalized_Maxwell_backward_Euler.ipynb + +import numpy as np +import matplotlib.pyplot as plt +from matplotlib.animation import FuncAnimation + +import time + +#define input parameters +##time +t0 = 0 +t1 = 1 +time_steps = 50 +dt = (t1 - t0)/time_steps +##load(constant) +W = 1e0 # Total load + +#domain size +#R = 1 # Radius of demi-sphere +L = 2 # Domain size +Radius = 0.5 +S = L**2 # Domain area + +# Generate a 2D coordinate space +n = 300 +m = 300 + +x, y = np.meshgrid(np.linspace(0, L, n, endpoint=False), np.linspace(0, L, m, endpoint=False)) + +x0 = 1 +y0 = 1 + +E = 3 # Young's modulus +nu = 0.5 +E_star = E / (1 - nu**2) # Plane strain modulus + +################################################################## +#####First just apply for demi-sphere and compare with Hertz###### +################################################################## + +# We define the distance from the center of the sphere +r = np.sqrt((x-x0)**2 + (y-y0)**2) + +# Define the kernel in the Fourier domain +q_x = 2 * np.pi * np.fft.fftfreq(n, d=L/n) +q_y = 2 * np.pi * np.fft.fftfreq(m, d=L/m) +QX, QY = np.meshgrid(q_x, q_y) + +kernel_fourier = np.zeros_like(QX) +kernel_fourier = 2 / (E_star * np.sqrt(QX**2 + QY**2)) +kernel_fourier[0, 0] = 0 # Avoid division by zero at the zero frequency + +h_profile = -(r**2)/(2*Radius) + +def apply_integration_operator(Origin, kernel_fourier, h_profile): + # Compute the Fourier transform of the input image + Origin2fourier = np.fft.fft2(Origin, norm='ortho') + + Middle_fourier = Origin2fourier * kernel_fourier + + Middle = np.fft.ifft2(Middle_fourier, norm='ortho').real + + Gradient = Middle - h_profile + + return Gradient, Origin2fourier#true gradient + +##define our elastic solver with constrained conjuagte gradient method +def contact_solver(n, m, W, S, h_profile, tol=1e-6, iter_max=200): + + + # Initial pressure distribution + P = np.full((n, m), W / S) # Initial guess for the pressure + + #initialize the search direction + T = np.zeros((n, m)) + + #set the norm of surface(to normalze the error) + h_rms = np.std(h_profile) + + #initialize G_norm and G_old + G_norm = 0 + G_old = 1 + + #initialize delta + delta = 0 + + # Initialize variables for the iteration + k = 0 # Iteration counter + error = np.inf # Initialize error + h_rms = np.std(h_profile) + + while np.abs(error) > tol and k < iter_max: + S = P > 0 + + G, P_fourier = apply_integration_operator(P, kernel_fourier, h_profile) + + G -= G[S].mean() + + G_norm = np.linalg.norm(G[S])**2 + + # Calculate the search direction + T[S] = G[S] + delta * G_norm / G_old * T[S] + T[~S] = 0 ## out of contact area, dont need to update + + # Update G_old + G_old = G_norm + + # Set R + R, T_fourier = apply_integration_operator(T, kernel_fourier, h_profile) + R += h_profile + R -= R[S].mean() + + # Calculate the step size tau + tau = np.vdot(G[S], T[S]) / np.vdot(R[S], T[S]) + + # Update P + P -= tau * T + P *= P > 0 + + # identify the inadmissible points + R = (P == 0) & (G < 0) + + if R.sum() == 0: + delta = 1 + else: + delta = 0#change the contact point set and need to do conjugate gradient again + + # Enforce the applied force constraint + P = W * P / np.mean(P) / L**2 + + # Calculate the error for convergence checking + error = np.vdot(P, (G - np.min(G))) / (P.sum()*h_rms) + # print(delta, error, k, np.mean(P), np.mean(P>0), tau) + + k += 1 # Increment the iteration counter + + # Ensure a positive gap by updating G + G = G - np.min(G) + + displacement_fourier = P_fourier * kernel_fourier + displacement = np.fft.ifft2(displacement_fourier, norm='ortho').real + + return displacement, P + +################################################################## +#####shear modulus for multi-branch Maxwell model################### +################################################################## +G_inf = 2.75 #elastic branch +#G = [2.75, 2, 0.25, 10, 2.5] #viscoelastic branch +G = [2.75, 2.75] + +print('G_inf:', G_inf, ' G: ' + str(G)) + +# Define the relaxation times +#tau = [0.1, 0.5, 1, 2, 10] # relaxation times +tau = [0.1, 1] +#tau = [0, 0, 0, 0, 0] +#tau = [1e6,1e6,1e6,1e6,1e6] +eta = [g * t for g, t in zip(G, tau)] + +print('tau:', tau, ' eta:', eta) + +################################################################## +#####define G_tilde for one-branch Maxwell model ################# +################################################################## +G_tilde = 0 +for k in range(len(G)): + G_tilde += tau[k] / (tau[k] + dt) * G[k] + + +# Define parameters for updating the surface profile +alpha = G_inf + G_tilde +beta = G_tilde + +gamma = [] +for k in range(len(G)): + gamma.append(tau[k]/(tau[k] + dt)) + +Surface = h_profile + +U = np.zeros((n, m)) +M = np.zeros((len(G), n, m)) + +Ac=[] +M_maxwell = np.zeros_like(U) + +####################################### +###Hertzian contact theory reference +####################################### +##Hertz solution at t0 +G_maxwell_t0 = 0 +for k in range(len(G)): + G_maxwell_t0 += G[k] +G_effective_t0 = G_inf + G_maxwell_t0 +E_effective_t0 = 2*G_effective_t0*(1+nu)/(1-nu**2) + +p0_t0 = (6*W*(E_effective_t0)**2/(np.pi**3*Radius**2))**(1/3) +a_t0 = (3*W*Radius/(4*(E_effective_t0)))**(1/3) +##Hertz solution at t_inf +E_effective_inf = 2*G_inf*(1+nu)/(1-nu**2) + +p0_t_inf = (6*W*(E_effective_inf)**2/(np.pi**3*Radius**2))**(1/3) +a_t_inf = (3*W*Radius/(4*(E_effective_inf)))**(1/3) + + +# define the update function for the animation +def update(frame): + ax.clear() + ax.set_xlim(0, L) + ax.set_ylim(0, 1.1*p0_t0) + ax.grid() + + # draw Hertzian contact theory reference + ax.plot(x[n//2], p0_t0*np.sqrt(1 - (x[n//2]-x0)**2 / a_t0**2), 'g--', label='Hertz at t=0') + ax.plot(x[n//2], p0_t_inf*np.sqrt(1 - (x[n//2]-x0)**2 / a_t_inf**2), 'b--', label='Hertz at t=inf') + + # draw numerical solution at current time step + ax.plot(x[n//2], pressure_distributions[frame], 'r-', label='Numerical') + ax.set_title(f"Time = {t0 + frame * dt:.2f}s") + plt.xlabel("x") + plt.ylabel("Pressure distribution") + plt.legend() + + +start = time.perf_counter() +# collect pressure distributions at each time step +pressure_distributions = [] +for t in np.arange(t0, t1, dt): + #Update the surface profile + M_maxwell[:] = 0 + for k in range(len(G)): + M_maxwell += gamma[k]*M[k] + H_new = alpha*Surface - beta*U + M_maxwell + + #main step1: Compute $P_{t+\Delta t}^{\prime}$ + #M_new, P = contact_solver(n, m, W, S, H_new, tol=1e-6, iter_max=200) + M_new, P = contact_solver(n, m, W, S, H_new, tol=1e-6, iter_max=200) + + ##Sanity check?? + + + + ##main step2: Update global displacement + U_new = (1/alpha)*(M_new - M_maxwell + beta*U) + + + + #main step3: Update the pressure + for k in range(len(G)): + M[k] = gamma[k]*(M[k] + G[k]*(U_new - U)) + #only maxwell branch, see algorithm formula 1 in the notebook + + + Ac.append(np.mean(P > 0)*S) + + #main step4: Update the total displacement field + U = U_new + + pressure_distributions.append(P[n//2].copy()) # store the pressure distribution at each time step + +end = time.perf_counter() +print("Simulation time:", end - start, "seconds") + +# create a figure and axis +fig, ax = plt.subplots() + +# create an animation +ani = FuncAnimation(fig, update, frames=len(pressure_distributions), repeat=False) + +plt.show() + + +Ac_hertz_t0 = np.pi*a_t0**2 +Ac_hertz_t_inf = np.pi*a_t_inf**2 + +print("Analytical contact area radius at t0:", a_t0) +print("Analytical contact area radius at t_inf:", a_t_inf) +print("Analytical maximum pressure at t0:", p0_t0) +print("Analytical maximum pressure at t_inf:", p0_t_inf) +print("Numerical contact area at t0:", Ac[0]) +print("Numerical contact area at t_inf", Ac[-1]) +print("Analyical contact area at t0:", Ac_hertz_t0) +print("Analyical contact area at t_inf:", Ac_hertz_t_inf) +plt.plot(np.arange(t0, t1, dt), Ac) +plt.axhline(Ac_hertz_t0, color='red', linestyle='dotted') +plt.axhline(Ac_hertz_t_inf, color='blue', linestyle='dotted') +plt.xlabel("Time(s)") +plt.ylabel("Contact area($m^2$)") +plt.legend(["Numerical", "Hertz at t=0", "Hertz at t=inf"]) +#define a title that can read parameter tau_0 +plt.title("Contact area vs time for multi-branch Generalized Maxwell model") +#plt.axhline(Ac_hertz_t_inf, color='blue') +plt.show() \ No newline at end of file diff --git a/JAX/tests/Tamaas_GMM.py b/JAX/tests/Tamaas_GMM.py new file mode 100644 index 0000000..547227d --- /dev/null +++ b/JAX/tests/Tamaas_GMM.py @@ -0,0 +1,70 @@ +''' +Here we test a Hertzian contact on a generalized Maxwell material using Tamaas. +Contact with rough surfaces needs to be tested. +''' +import tamaas as tm +import time +import numpy as np + + +# Set-up of the model +L = 2 +Radius = 0.5 +S = L**2 + +# discretization +n = m = 300 +x = np.linspace(0, L, n, endpoint=False, dtype=tm.dtype) +y = np.linspace(0, L, m, endpoint=False, dtype=tm.dtype) +xx, yy = np.meshgrid(x, y, indexing="ij") +# Define the surface +surface = surface = -((xx - L / 2) ** 2 + (yy - L / 2) ** 2) / (2 * Radius) +# Create the model +model = tm.Model(tm.model_type.basic_2d, [L, L], [n, m]) + +# Defining the elastic branch (i.e. the behavior at t = ∞) +model.E = 3 +model.nu = 0.5 + +# Characteristic times of the relaxation function +times = [0.1, 1] + +# Shear moduli for each branch of the model +shear_moduli = [2.75, 2.75] + +t0 = 0 +t1 = 1 +time_steps = 50 +# Time step +Δt = (t1 - t0) / time_steps + +# Applied load +W = 1.0 +load = W / S + +# Solver instanciation +solver = tm.MaxwellViscoelastic(model, surface, 1e-10, + time_step=Δt, + shear_moduli=shear_moduli, + characteristic_times=times) + +# Solve one timestep with given load +start = time.perf_counter() +solver.solve(load) +end = time.perf_counter() +print(f'Simulation time for one step: {end - start} seconds') + +# plot like ub Multi_branches_generalized_Maxwell.py +import matplotlib.pyplot as plt +displacement = model.displacement[:] +pressure = model.traction[:] +plt.figure(figsize=(12, 5)) +plt.subplot(1, 2, 1) +plt.imshow(displacement, extent=(0, L, 0, L), origin='lower') +plt.title('Displacement field') +plt.colorbar() +plt.subplot(1, 2, 2) +plt.imshow(pressure, extent=(0, L, 0, L), origin='lower') +plt.title('Pressure field') +plt.colorbar() +plt.show() \ No newline at end of file