Compare commits
5 Commits
| Author | SHA1 | Date |
|---|---|---|
|
|
cc14b7b2b1 | |
|
|
71f38ef56b | |
|
|
c600d30161 | |
|
|
602c480e49 | |
|
|
e989dc9927 |
|
|
@ -0,0 +1,143 @@
|
||||||
|
# Created by https://www.toptal.com/developers/gitignore/api/python
|
||||||
|
# Edit at https://www.toptal.com/developers/gitignore?templates=python
|
||||||
|
|
||||||
|
# --------------------
|
||||||
|
# Python
|
||||||
|
# --------------------
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*.pyo
|
||||||
|
*.pyd
|
||||||
|
*.so
|
||||||
|
*$py.class
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# --------------------
|
||||||
|
# Packaging / builds
|
||||||
|
# --------------------
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
dist/
|
||||||
|
develop-eggs/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
share/python-wheels/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
*.egg
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
MANIFEST
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# --------------------
|
||||||
|
# Logs / caches / temp
|
||||||
|
# --------------------
|
||||||
|
*.log
|
||||||
|
*.tmp
|
||||||
|
*.temp
|
||||||
|
*.out
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
.cache
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
*.cover
|
||||||
|
*.py,cover
|
||||||
|
cover/
|
||||||
|
htmlcov/
|
||||||
|
|
||||||
|
# --------------------
|
||||||
|
# Databases / translations
|
||||||
|
# --------------------
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
db.sqlite3
|
||||||
|
db.sqlite3-journal
|
||||||
|
|
||||||
|
# --------------------
|
||||||
|
# Framework / tool artifacts
|
||||||
|
# --------------------
|
||||||
|
local_settings.py
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
.scrapy
|
||||||
|
docs/_build/
|
||||||
|
.pybuilder/
|
||||||
|
target/
|
||||||
|
celerybeat-schedule
|
||||||
|
celerybeat.pid
|
||||||
|
/site
|
||||||
|
|
||||||
|
# --------------------
|
||||||
|
# Virtual environments
|
||||||
|
# --------------------
|
||||||
|
JAX-venv/
|
||||||
|
.venv/
|
||||||
|
venv/
|
||||||
|
env/
|
||||||
|
ENV/
|
||||||
|
.env
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# --------------------
|
||||||
|
# Python packaging helpers
|
||||||
|
# --------------------
|
||||||
|
.pdm.toml
|
||||||
|
poetry.toml
|
||||||
|
__pypackages__/
|
||||||
|
|
||||||
|
# --------------------
|
||||||
|
# Notebooks / REPL
|
||||||
|
# --------------------
|
||||||
|
.ipynb_checkpoints/
|
||||||
|
*.nb.py
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# --------------------
|
||||||
|
# Type checking / analysis
|
||||||
|
# --------------------
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
.pyre/
|
||||||
|
.pytype/
|
||||||
|
.ruff_cache/
|
||||||
|
pyrightconfig.json
|
||||||
|
|
||||||
|
# --------------------
|
||||||
|
# Editors / IDEs
|
||||||
|
# --------------------
|
||||||
|
.vscode/
|
||||||
|
.history/
|
||||||
|
*.code-workspace
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# --------------------
|
||||||
|
# Operating system
|
||||||
|
# --------------------
|
||||||
|
.DS_Store
|
||||||
|
Thumbs.db
|
||||||
|
|
||||||
|
# --------------------
|
||||||
|
# Editor temp files
|
||||||
|
# --------------------
|
||||||
|
*.swp
|
||||||
|
*.swo
|
||||||
|
*~
|
||||||
|
|
||||||
|
|
||||||
1
AUTHORS
1
AUTHORS
|
|
@ -1 +1,2 @@
|
||||||
Lucas Frérot <lucas.frerot@sorbonne-universite.fr> Sorbonne Université, CNRS, Institut Jean Le Rond d'Alembert, F-75005 Paris, France
|
Lucas Frérot <lucas.frerot@sorbonne-universite.fr> Sorbonne Université, CNRS, Institut Jean Le Rond d'Alembert, F-75005 Paris, France
|
||||||
|
Zichen Li <zichen.li@sorbonne-universite.fr> Sorbonne Université, CNRS, Institut Jean Le Rond d'Alembert, F-75005 Paris, France
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,59 @@
|
||||||
|
Here is a repo for beginners in JAX. We recommand to start with the documentation in [1].
|
||||||
|
|
||||||
|
We are motivated by the article of Mohit and David(2025)[2], especially the automatic differentiation[3] and just-in-time compilation[4].
|
||||||
|
|
||||||
|
### Install JAX
|
||||||
|
GPU programming is a future trend for our open-source project. The codes that we write in CPU and GPU version of JAX are the same. The difference is that pip will install a jaxlib wheel for GPU version depending on NVIDIA driver version and CUDA version.
|
||||||
|
|
||||||
|
First we can check our NVIDIA driver version and CUDA version
|
||||||
|
```bash
|
||||||
|
nvidia-smi
|
||||||
|
```
|
||||||
|
CUDA 12 requires driver version ≥ 525, which is already a mainstream and stable combination, supported by almost all frameworks. We will install the JAX GPU version suitable for CUDA 12.
|
||||||
|
```bash
|
||||||
|
python3 -m venv JAX-venv
|
||||||
|
source JAX-venv/bin/activate
|
||||||
|
(JAX-venv) pip install --upgrade pip
|
||||||
|
(JAX-venv) pip install ipython
|
||||||
|
(JAX-venv) pip install --upgrade "jax[cuda12]"
|
||||||
|
(JAX-venv) ipython # /path/to/JAX-venv/bin/ipython
|
||||||
|
```
|
||||||
|
Test script in Ipython:
|
||||||
|
```
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import jaxlib
|
||||||
|
|
||||||
|
print("jax:", jax.__version__)
|
||||||
|
print("jaxlib:", jaxlib.__version__)
|
||||||
|
|
||||||
|
print("devices:", jax.devices())
|
||||||
|
|
||||||
|
x = jnp.arange(5.)
|
||||||
|
print("x.device:", x.device)
|
||||||
|
```
|
||||||
|
Reference output:
|
||||||
|
```
|
||||||
|
jax: 0.6.2
|
||||||
|
jaxlib: 0.6.2
|
||||||
|
devices: [CudaDevice(id=0)]
|
||||||
|
x.device: cuda:0
|
||||||
|
```
|
||||||
|
|
||||||
|
Structure of the project:
|
||||||
|
```
|
||||||
|
JAX/
|
||||||
|
├─ JAX-venv/
|
||||||
|
├─ src/ # Python codes
|
||||||
|
├─ notebooks/ # Experimental notebook
|
||||||
|
├─ tests/ # Unit tests
|
||||||
|
└─ pyproject.toml # Or requirements.txt
|
||||||
|
```
|
||||||
|
### Source
|
||||||
|
[1] https://uvadlc-notebooks.readthedocs.io/en/latest/
|
||||||
|
|
||||||
|
[2] https://www.sciencedirect.com/science/article/pii/S0045782524008260?via%3Dihub
|
||||||
|
|
||||||
|
[3] https://docs.jax.dev/en/latest/automatic-differentiation.html
|
||||||
|
|
||||||
|
[4] https://docs.jax.dev/en/latest/jit-compilation.html
|
||||||
|
|
@ -0,0 +1,359 @@
|
||||||
|
"""JAX implementation of the generalized Maxwell contact solver.
|
||||||
|
|
||||||
|
This script mirrors the NumPy-based reference in
|
||||||
|
`Multi_branches_generalized_Maxwell.py`, but leverages JAX for automatic
|
||||||
|
differentiation and JIT compilation. The automatic gradient of the elastic
|
||||||
|
energy drives the constrained conjugate-gradient contact solver.
|
||||||
|
|
||||||
|
Running this file produces the same diagnostic plots as the reference
|
||||||
|
implementation while keeping all heavy lifting on the accelerator-enabled JAX
|
||||||
|
backend.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from matplotlib.animation import FuncAnimation
|
||||||
|
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
from jax import lax
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
# Enable double precision for improved numerical stability.
|
||||||
|
jax.config.update("jax_enable_x64", True)
|
||||||
|
|
||||||
|
|
||||||
|
def build_fourier_kernel(n: int, m: int, L: float, E_star: float) -> jnp.ndarray:
|
||||||
|
"""Assemble the Fourier-domain kernel for the half-space Green's function."""
|
||||||
|
|
||||||
|
q_x = 2.0 * np.pi * jnp.fft.fftfreq(n, d=L / n)
|
||||||
|
q_y = 2.0 * np.pi * jnp.fft.fftfreq(m, d=L / m)
|
||||||
|
QX, QY = jnp.meshgrid(q_x, q_y, indexing="xy")
|
||||||
|
q_norm = jnp.sqrt(QX**2 + QY**2)
|
||||||
|
kernel = jnp.where(q_norm > 0.0, 2.0 / (E_star * q_norm), 0.0)
|
||||||
|
return kernel
|
||||||
|
|
||||||
|
|
||||||
|
@jax.jit
|
||||||
|
def displacement_from_pressure(
|
||||||
|
kernel_fourier: jnp.ndarray, pressure: jnp.ndarray
|
||||||
|
) -> jnp.ndarray:
|
||||||
|
"""Return the surface displacement induced by the supplied pressure field."""
|
||||||
|
|
||||||
|
pressure_fft = jnp.fft.fft2(pressure, norm="ortho")
|
||||||
|
displacement_fft = pressure_fft * kernel_fourier
|
||||||
|
displacement = jnp.fft.ifft2(displacement_fft, norm="ortho").real
|
||||||
|
return displacement
|
||||||
|
|
||||||
|
|
||||||
|
def elastic_energy(
|
||||||
|
kernel_fourier: jnp.ndarray, h_profile: jnp.ndarray, pressure: jnp.ndarray
|
||||||
|
) -> jnp.ndarray:
|
||||||
|
"""Elastic energy functional; its gradient yields the gap field."""
|
||||||
|
|
||||||
|
displacement = displacement_from_pressure(kernel_fourier, pressure)
|
||||||
|
stored = 0.5 * jnp.sum(pressure * displacement)
|
||||||
|
work = jnp.sum(pressure * h_profile)
|
||||||
|
return stored - work
|
||||||
|
|
||||||
|
|
||||||
|
value_and_grad_energy = jax.jit(jax.value_and_grad(elastic_energy, argnums=2))
|
||||||
|
|
||||||
|
|
||||||
|
@jax.jit
|
||||||
|
def project_total_load(pressure: jnp.ndarray, W: float, L: float) -> jnp.ndarray:
|
||||||
|
"""Project the pressure field onto the admissible set enforcing total load."""
|
||||||
|
|
||||||
|
mean_pressure = jnp.mean(pressure)
|
||||||
|
target = W / (L**2)
|
||||||
|
scale = jnp.where(mean_pressure > 0.0, target / mean_pressure, 0.0)
|
||||||
|
projected = jnp.where(mean_pressure > 0.0, pressure * scale, jnp.full_like(pressure, target))
|
||||||
|
return projected
|
||||||
|
|
||||||
|
|
||||||
|
@jax.jit
|
||||||
|
def masked_mean(values: jnp.ndarray, mask: jnp.ndarray) -> jnp.ndarray:
|
||||||
|
"""Compute the mean over the masked region, guarding against empty sets."""
|
||||||
|
|
||||||
|
count = jnp.sum(mask)
|
||||||
|
total = jnp.sum(jnp.where(mask, values, 0.0))
|
||||||
|
return jnp.where(count > 0, total / count, 0.0)
|
||||||
|
|
||||||
|
|
||||||
|
@jax.jit
|
||||||
|
def compute_error(
|
||||||
|
pressure: jnp.ndarray,
|
||||||
|
gradient: jnp.ndarray,
|
||||||
|
h_rms: float,
|
||||||
|
) -> jnp.ndarray:
|
||||||
|
"""Scaled complementarity error used as stopping criterion."""
|
||||||
|
|
||||||
|
num = jnp.vdot(pressure.reshape(-1), gradient - jnp.min(gradient))
|
||||||
|
denom = jnp.sum(pressure) * h_rms + 1e-12
|
||||||
|
return jnp.abs(num / denom)
|
||||||
|
|
||||||
|
|
||||||
|
@jax.jit
|
||||||
|
def update_search_direction(
|
||||||
|
gradient: jnp.ndarray,
|
||||||
|
direction: jnp.ndarray,
|
||||||
|
contact_mask: jnp.ndarray,
|
||||||
|
delta: float,
|
||||||
|
g_norm: float,
|
||||||
|
g_old: float,
|
||||||
|
) -> jnp.ndarray:
|
||||||
|
"""Conjugate-gradient style update with projection onto the contact set."""
|
||||||
|
|
||||||
|
beta_cg = jnp.where(g_old > 0.0, delta * g_norm / (g_old + 1e-12), 0.0)
|
||||||
|
updated = gradient + beta_cg * direction
|
||||||
|
return jnp.where(contact_mask, updated, 0.0)
|
||||||
|
|
||||||
|
|
||||||
|
def contact_solver_autodiff(
|
||||||
|
kernel_fourier: jnp.ndarray,
|
||||||
|
h_profile: jnp.ndarray,
|
||||||
|
W: float,
|
||||||
|
L: float,
|
||||||
|
tol: float = 1e-6,
|
||||||
|
iter_max: int = 200,
|
||||||
|
):
|
||||||
|
"""Solve the constrained contact problem via autodiff-powered CG iterations."""
|
||||||
|
|
||||||
|
h_rms = jnp.std(h_profile)
|
||||||
|
initial_pressure = jnp.full_like(h_profile, W / (L**2))
|
||||||
|
initial_direction = jnp.zeros_like(initial_pressure)
|
||||||
|
|
||||||
|
iter_max_jnp = jnp.array(iter_max)
|
||||||
|
|
||||||
|
def cond_fun(state):
|
||||||
|
_, _, _, _, k, error = state
|
||||||
|
return jnp.logical_and(error > tol, k < iter_max_jnp)
|
||||||
|
|
||||||
|
def body_fun(state):
|
||||||
|
pressure, direction, g_old, delta, k, _ = state
|
||||||
|
|
||||||
|
_, grad_energy = value_and_grad_energy(kernel_fourier, h_profile, pressure)
|
||||||
|
contact_mask = pressure > 0.0
|
||||||
|
|
||||||
|
grad_mean = masked_mean(grad_energy, contact_mask)
|
||||||
|
grad_centered = grad_energy - grad_mean
|
||||||
|
grad_contact = jnp.where(contact_mask, grad_centered, 0.0)
|
||||||
|
|
||||||
|
g_norm = jnp.sum(grad_contact * grad_contact)
|
||||||
|
|
||||||
|
search_dir = update_search_direction(
|
||||||
|
grad_contact,
|
||||||
|
direction,
|
||||||
|
contact_mask,
|
||||||
|
delta,
|
||||||
|
g_norm,
|
||||||
|
g_old,
|
||||||
|
)
|
||||||
|
|
||||||
|
displacement_dir = displacement_from_pressure(kernel_fourier, search_dir)
|
||||||
|
disp_mean = masked_mean(displacement_dir, contact_mask)
|
||||||
|
response = displacement_dir - disp_mean
|
||||||
|
|
||||||
|
tau_num = jnp.sum(jnp.where(contact_mask, grad_centered * search_dir, 0.0))
|
||||||
|
tau_den = jnp.sum(jnp.where(contact_mask, response * search_dir, 0.0))
|
||||||
|
tau = tau_num / (tau_den + 1e-12)
|
||||||
|
|
||||||
|
pressure_new = pressure - tau * search_dir
|
||||||
|
pressure_new = jnp.where(pressure_new > 0.0, pressure_new, 0.0)
|
||||||
|
|
||||||
|
inadmissible = jnp.logical_and(pressure_new == 0.0, grad_centered < 0.0)
|
||||||
|
delta_new = jnp.where(jnp.sum(inadmissible) == 0, 1.0, 0.0)
|
||||||
|
|
||||||
|
pressure_projected = project_total_load(pressure_new, W, L)
|
||||||
|
error_new = compute_error(pressure_projected, grad_centered, h_rms)
|
||||||
|
|
||||||
|
return (
|
||||||
|
pressure_projected,
|
||||||
|
search_dir,
|
||||||
|
jnp.where(g_norm > 0.0, g_norm, g_old),
|
||||||
|
delta_new,
|
||||||
|
k + 1,
|
||||||
|
error_new,
|
||||||
|
)
|
||||||
|
|
||||||
|
final_state = lax.while_loop(
|
||||||
|
cond_fun,
|
||||||
|
body_fun,
|
||||||
|
(
|
||||||
|
initial_pressure,
|
||||||
|
initial_direction,
|
||||||
|
jnp.array(1.0),
|
||||||
|
jnp.array(0.0),
|
||||||
|
jnp.array(0),
|
||||||
|
jnp.array(jnp.inf),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
pressure, _, _, _, iterations, error = final_state
|
||||||
|
displacement = displacement_from_pressure(kernel_fourier, pressure)
|
||||||
|
return displacement, pressure, int(iterations), float(error)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# Time discretization
|
||||||
|
t0 = 0.0
|
||||||
|
t1 = 1.0
|
||||||
|
time_steps = 50
|
||||||
|
dt = (t1 - t0) / time_steps
|
||||||
|
|
||||||
|
# Total load
|
||||||
|
W = 1.0
|
||||||
|
|
||||||
|
# Geometry
|
||||||
|
L = 2.0
|
||||||
|
radius = 0.5
|
||||||
|
S = L**2
|
||||||
|
|
||||||
|
# Grid
|
||||||
|
n = 300
|
||||||
|
m = 300
|
||||||
|
x_vals = jnp.linspace(0.0, L, n, endpoint=False)
|
||||||
|
y_vals = jnp.linspace(0.0, L, m, endpoint=False)
|
||||||
|
x, y = jnp.meshgrid(x_vals, y_vals, indexing="xy")
|
||||||
|
|
||||||
|
x0 = 1.0
|
||||||
|
y0 = 1.0
|
||||||
|
|
||||||
|
E = 3.0
|
||||||
|
nu = 0.5
|
||||||
|
E_star = E / (1.0 - nu**2)
|
||||||
|
|
||||||
|
r = jnp.sqrt((x - x0) ** 2 + (y - y0) ** 2)
|
||||||
|
h_profile = -(r**2) / (2.0 * radius)
|
||||||
|
|
||||||
|
kernel_fourier = build_fourier_kernel(n, m, L, E_star)
|
||||||
|
|
||||||
|
# Maxwell model parameters
|
||||||
|
G_inf = 2.75
|
||||||
|
G_branches = jnp.array([2.75, 2.75])
|
||||||
|
tau_branches = jnp.array([0.1, 1.0])
|
||||||
|
eta_branches = G_branches * tau_branches
|
||||||
|
|
||||||
|
gamma = tau_branches / (tau_branches + dt)
|
||||||
|
G_tilde = jnp.sum(gamma * G_branches)
|
||||||
|
alpha = G_inf + G_tilde
|
||||||
|
beta = G_tilde
|
||||||
|
|
||||||
|
surface = h_profile
|
||||||
|
U = jnp.zeros((n, m))
|
||||||
|
M = jnp.zeros((G_branches.shape[0], n, m))
|
||||||
|
|
||||||
|
# Hertzian references
|
||||||
|
G_maxwell_t0 = jnp.sum(G_branches)
|
||||||
|
G_effective_t0 = G_inf + G_maxwell_t0
|
||||||
|
E_effective_t0 = 2.0 * G_effective_t0 * (1.0 + nu) / (1.0 - nu**2)
|
||||||
|
p0_t0 = (6.0 * W * (E_effective_t0**2) / (np.pi**3 * radius**2)) ** (1.0 / 3.0)
|
||||||
|
a_t0 = (3.0 * W * radius / (4.0 * E_effective_t0)) ** (1.0 / 3.0)
|
||||||
|
|
||||||
|
E_effective_inf = 2.0 * G_inf * (1.0 + nu) / (1.0 - nu**2)
|
||||||
|
p0_t_inf = (6.0 * W * (E_effective_inf**2) / (np.pi**3 * radius**2)) ** (1.0 / 3.0)
|
||||||
|
a_t_inf = (3.0 * W * radius / (4.0 * E_effective_inf)) ** (1.0 / 3.0)
|
||||||
|
|
||||||
|
pressure_distributions = []
|
||||||
|
contact_areas = []
|
||||||
|
iteration_log = []
|
||||||
|
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
|
# I think I should avoid using for loops in JAX
|
||||||
|
for step in range(time_steps):
|
||||||
|
M_maxwell = jnp.tensordot(gamma, M, axes=1)
|
||||||
|
H_new = alpha * surface - beta * U + M_maxwell
|
||||||
|
|
||||||
|
displacement, pressure, iterations, residual = contact_solver_autodiff(
|
||||||
|
kernel_fourier,
|
||||||
|
H_new,
|
||||||
|
W,
|
||||||
|
L,
|
||||||
|
tol=1e-6,
|
||||||
|
iter_max=200,
|
||||||
|
)
|
||||||
|
|
||||||
|
U_new = (displacement - M_maxwell + beta * U) / alpha
|
||||||
|
delta_U = U_new - U
|
||||||
|
M = gamma[:, None, None] * (M + G_branches[:, None, None] * delta_U)
|
||||||
|
|
||||||
|
area_ratio = jnp.mean(pressure > 0.0)
|
||||||
|
contact_area = float(area_ratio * S)
|
||||||
|
contact_areas.append(contact_area)
|
||||||
|
|
||||||
|
pressure_midline = np.array(jax.device_get(pressure[n // 2]))
|
||||||
|
pressure_distributions.append(pressure_midline)
|
||||||
|
|
||||||
|
iteration_log.append((iterations, residual))
|
||||||
|
U = U_new
|
||||||
|
|
||||||
|
end_time = time.perf_counter()
|
||||||
|
print("Simulation time:", end_time - start_time, "seconds")
|
||||||
|
|
||||||
|
x_np = np.array(jax.device_get(x))
|
||||||
|
|
||||||
|
def update(frame):
|
||||||
|
ax.clear()
|
||||||
|
ax.set_xlim(0, L)
|
||||||
|
ax.set_ylim(0, 1.1 * p0_t0)
|
||||||
|
ax.grid(True)
|
||||||
|
|
||||||
|
ax.plot(
|
||||||
|
x_np[n // 2],
|
||||||
|
p0_t0 * np.sqrt(np.maximum(0.0, 1.0 - (x_np[n // 2] - x0) ** 2 / a_t0**2)),
|
||||||
|
"g--",
|
||||||
|
label="Hertz t=0",
|
||||||
|
)
|
||||||
|
ax.plot(
|
||||||
|
x_np[n // 2],
|
||||||
|
p0_t_inf * np.sqrt(np.maximum(0.0, 1.0 - (x_np[n // 2] - x0) ** 2 / a_t_inf**2)),
|
||||||
|
"b--",
|
||||||
|
label="Hertz t=inf",
|
||||||
|
)
|
||||||
|
ax.plot(x_np[n // 2], pressure_distributions[frame], "r-", label="Numerical")
|
||||||
|
ax.set_title(f"Time = {t0 + frame * dt:.2f}s")
|
||||||
|
ax.set_xlabel("x")
|
||||||
|
ax.set_ylabel("Pressure distribution")
|
||||||
|
ax.legend(loc="upper right")
|
||||||
|
|
||||||
|
fig, ax = plt.subplots()
|
||||||
|
ani = FuncAnimation(fig, update, frames=len(pressure_distributions), repeat=False)
|
||||||
|
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
Ac_hertz_t0 = np.pi * a_t0**2
|
||||||
|
Ac_hertz_t_inf = np.pi * a_t_inf**2
|
||||||
|
|
||||||
|
print("Iterations and residuals per step:")
|
||||||
|
for idx, (iterations, residual) in enumerate(iteration_log):
|
||||||
|
print(f" step {idx:02d}: {iterations:3d} iterations, residual={residual:.3e}")
|
||||||
|
|
||||||
|
print("Analytical contact area radius at t0:", float(a_t0))
|
||||||
|
print("Analytical contact area radius at t_inf:", float(a_t_inf))
|
||||||
|
print("Analytical maximum pressure at t0:", float(p0_t0))
|
||||||
|
print("Analytical maximum pressure at t_inf:", float(p0_t_inf))
|
||||||
|
print("Numerical contact area at t0:", contact_areas[0])
|
||||||
|
print("Numerical contact area at t_inf:", contact_areas[-1])
|
||||||
|
print("Analytical contact area at t0:", float(Ac_hertz_t0))
|
||||||
|
print("Analytical contact area at t_inf:", float(Ac_hertz_t_inf))
|
||||||
|
|
||||||
|
time_axis = np.arange(t0, t1, dt)
|
||||||
|
plt.figure()
|
||||||
|
plt.plot(time_axis, contact_areas)
|
||||||
|
plt.axhline(Ac_hertz_t0, color="red", linestyle="dotted")
|
||||||
|
plt.axhline(Ac_hertz_t_inf, color="blue", linestyle="dotted")
|
||||||
|
plt.xlabel("Time(s)")
|
||||||
|
plt.ylabel("Contact area($m^2$)")
|
||||||
|
plt.legend(["Numerical", "Hertz at t=0", "Hertz at t=inf"])
|
||||||
|
plt.title("Contact area vs time for multi-branch Generalized Maxwell model")
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
||||||
|
|
@ -0,0 +1,348 @@
|
||||||
|
"""JAX generalized Maxwell contact solver without Python loops.
|
||||||
|
|
||||||
|
This variant removes the explicit Python time-stepping loop from
|
||||||
|
`JAX_GMM.py` by relying on `jax.lax.scan`, which keeps all temporally
|
||||||
|
coupled computations staged inside JAX's computation graph. The contact
|
||||||
|
solver remains identical but is compatible with scanning so the entire
|
||||||
|
transient solves as a single JIT-compiled program once the graph is
|
||||||
|
traced.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from matplotlib.animation import FuncAnimation
|
||||||
|
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
from jax import lax
|
||||||
|
|
||||||
|
|
||||||
|
jax.config.update("jax_enable_x64", True)
|
||||||
|
|
||||||
|
|
||||||
|
def build_fourier_kernel(n: int, m: int, L: float, E_star: float) -> jnp.ndarray:
|
||||||
|
q_x = 2.0 * jnp.pi * jnp.fft.fftfreq(n, d=L / n)
|
||||||
|
q_y = 2.0 * jnp.pi * jnp.fft.fftfreq(m, d=L / m)
|
||||||
|
QX, QY = jnp.meshgrid(q_x, q_y, indexing="xy")
|
||||||
|
q_norm = jnp.sqrt(QX**2 + QY**2)
|
||||||
|
return jnp.where(q_norm > 0.0, 2.0 / (E_star * q_norm), 0.0)
|
||||||
|
|
||||||
|
|
||||||
|
@jax.jit
|
||||||
|
def displacement_from_pressure(kernel_fourier: jnp.ndarray, pressure: jnp.ndarray) -> jnp.ndarray:
|
||||||
|
pressure_fft = jnp.fft.fft2(pressure, norm="ortho")
|
||||||
|
displacement_fft = pressure_fft * kernel_fourier
|
||||||
|
return jnp.fft.ifft2(displacement_fft, norm="ortho").real
|
||||||
|
|
||||||
|
|
||||||
|
def elastic_energy(kernel_fourier: jnp.ndarray, h_profile: jnp.ndarray, pressure: jnp.ndarray) -> jnp.ndarray:
|
||||||
|
displacement = displacement_from_pressure(kernel_fourier, pressure)
|
||||||
|
stored = 0.5 * jnp.sum(pressure * displacement)
|
||||||
|
work = jnp.sum(pressure * h_profile)
|
||||||
|
return stored - work
|
||||||
|
|
||||||
|
|
||||||
|
value_and_grad_energy = jax.jit(jax.value_and_grad(elastic_energy, argnums=2))
|
||||||
|
|
||||||
|
|
||||||
|
@jax.jit
|
||||||
|
def project_total_load(pressure: jnp.ndarray, W: float, L: float) -> jnp.ndarray:
|
||||||
|
mean_pressure = jnp.mean(pressure)
|
||||||
|
target = W / (L**2)
|
||||||
|
scale = jnp.where(mean_pressure > 0.0, target / mean_pressure, 0.0)
|
||||||
|
return jnp.where(mean_pressure > 0.0, pressure * scale, jnp.full_like(pressure, target))
|
||||||
|
|
||||||
|
|
||||||
|
@jax.jit
|
||||||
|
def masked_mean(values: jnp.ndarray, mask: jnp.ndarray) -> jnp.ndarray:
|
||||||
|
count = jnp.sum(mask)
|
||||||
|
total = jnp.sum(jnp.where(mask, values, 0.0))
|
||||||
|
return jnp.where(count > 0, total / count, 0.0)
|
||||||
|
|
||||||
|
|
||||||
|
@jax.jit
|
||||||
|
def compute_error(pressure: jnp.ndarray, gradient: jnp.ndarray, h_rms: float) -> jnp.ndarray:
|
||||||
|
num = jnp.vdot(pressure.reshape(-1), gradient - jnp.min(gradient))
|
||||||
|
denom = jnp.sum(pressure) * h_rms + 1e-12
|
||||||
|
return jnp.abs(num / denom)
|
||||||
|
|
||||||
|
|
||||||
|
@jax.jit
|
||||||
|
def update_search_direction(
|
||||||
|
gradient: jnp.ndarray,
|
||||||
|
direction: jnp.ndarray,
|
||||||
|
contact_mask: jnp.ndarray,
|
||||||
|
delta: float,
|
||||||
|
g_norm: float,
|
||||||
|
g_old: float,
|
||||||
|
) -> jnp.ndarray:
|
||||||
|
beta_cg = jnp.where(g_old > 0.0, delta * g_norm / (g_old + 1e-12), 0.0)
|
||||||
|
updated = gradient + beta_cg * direction
|
||||||
|
return jnp.where(contact_mask, updated, 0.0)
|
||||||
|
|
||||||
|
|
||||||
|
def contact_solver_autodiff(
|
||||||
|
kernel_fourier: jnp.ndarray,
|
||||||
|
h_profile: jnp.ndarray,
|
||||||
|
W: float,
|
||||||
|
L: float,
|
||||||
|
tol: float = 1e-6,
|
||||||
|
iter_max: int = 200,
|
||||||
|
):
|
||||||
|
h_rms = jnp.std(h_profile)
|
||||||
|
initial_pressure = jnp.full_like(h_profile, W / (L**2))
|
||||||
|
initial_direction = jnp.zeros_like(initial_pressure)
|
||||||
|
iter_max_jnp = jnp.array(iter_max)
|
||||||
|
|
||||||
|
def cond_fun(state):
|
||||||
|
_, _, _, _, k, error = state
|
||||||
|
return jnp.logical_and(error > tol, k < iter_max_jnp)
|
||||||
|
|
||||||
|
def body_fun(state):
|
||||||
|
pressure, direction, g_old, delta, k, _ = state
|
||||||
|
|
||||||
|
_, grad_energy = value_and_grad_energy(kernel_fourier, h_profile, pressure)
|
||||||
|
contact_mask = pressure > 0.0
|
||||||
|
|
||||||
|
grad_mean = masked_mean(grad_energy, contact_mask)
|
||||||
|
grad_centered = grad_energy - grad_mean
|
||||||
|
grad_contact = jnp.where(contact_mask, grad_centered, 0.0)
|
||||||
|
|
||||||
|
g_norm = jnp.sum(grad_contact * grad_contact)
|
||||||
|
search_dir = update_search_direction(grad_contact, direction, contact_mask, delta, g_norm, g_old)
|
||||||
|
|
||||||
|
displacement_dir = displacement_from_pressure(kernel_fourier, search_dir)
|
||||||
|
disp_mean = masked_mean(displacement_dir, contact_mask)
|
||||||
|
response = displacement_dir - disp_mean
|
||||||
|
|
||||||
|
tau_num = jnp.sum(jnp.where(contact_mask, grad_centered * search_dir, 0.0))
|
||||||
|
tau_den = jnp.sum(jnp.where(contact_mask, response * search_dir, 0.0))
|
||||||
|
tau = tau_num / (tau_den + 1e-12)
|
||||||
|
|
||||||
|
pressure_new = jnp.maximum(pressure - tau * search_dir, 0.0)
|
||||||
|
|
||||||
|
inadmissible = jnp.logical_and(pressure_new == 0.0, grad_centered < 0.0)
|
||||||
|
delta_new = jnp.where(jnp.sum(inadmissible) == 0, 1.0, 0.0)
|
||||||
|
|
||||||
|
pressure_projected = project_total_load(pressure_new, W, L)
|
||||||
|
error_new = compute_error(pressure_projected, grad_centered, h_rms)
|
||||||
|
|
||||||
|
return (
|
||||||
|
pressure_projected,
|
||||||
|
search_dir,
|
||||||
|
jnp.where(g_norm > 0.0, g_norm, g_old),
|
||||||
|
delta_new,
|
||||||
|
k + 1,
|
||||||
|
error_new,
|
||||||
|
)
|
||||||
|
|
||||||
|
final_state = lax.while_loop(
|
||||||
|
cond_fun,
|
||||||
|
body_fun,
|
||||||
|
(
|
||||||
|
initial_pressure,
|
||||||
|
initial_direction,
|
||||||
|
jnp.array(1.0),
|
||||||
|
jnp.array(0.0),
|
||||||
|
jnp.array(0),
|
||||||
|
jnp.array(jnp.inf),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
pressure, _, _, _, iterations, error = final_state
|
||||||
|
displacement = displacement_from_pressure(kernel_fourier, pressure)
|
||||||
|
return displacement, pressure, iterations, error
|
||||||
|
|
||||||
|
|
||||||
|
def run_simulation():
|
||||||
|
t0 = 0.0
|
||||||
|
t1 = 1.0
|
||||||
|
time_steps = 50
|
||||||
|
dt = (t1 - t0) / time_steps
|
||||||
|
|
||||||
|
W = 1.0
|
||||||
|
|
||||||
|
L = 2.0
|
||||||
|
radius = 0.5
|
||||||
|
S = L**2
|
||||||
|
|
||||||
|
n = 300
|
||||||
|
m = 300
|
||||||
|
x_vals = jnp.linspace(0.0, L, n, endpoint=False)
|
||||||
|
y_vals = jnp.linspace(0.0, L, m, endpoint=False)
|
||||||
|
x, y = jnp.meshgrid(x_vals, y_vals, indexing="xy")
|
||||||
|
|
||||||
|
x0 = 1.0
|
||||||
|
y0 = 1.0
|
||||||
|
|
||||||
|
E = 3.0
|
||||||
|
nu = 0.5
|
||||||
|
E_star = E / (1.0 - nu**2)
|
||||||
|
|
||||||
|
r = jnp.sqrt((x - x0) ** 2 + (y - y0) ** 2)
|
||||||
|
h_profile = -(r**2) / (2.0 * radius)
|
||||||
|
|
||||||
|
kernel_fourier = build_fourier_kernel(n, m, L, E_star)
|
||||||
|
|
||||||
|
G_inf = 2.75
|
||||||
|
G_branches = jnp.array([2.75, 2.75])
|
||||||
|
tau_branches = jnp.array([0.1, 1.0])
|
||||||
|
|
||||||
|
gamma = tau_branches / (tau_branches + dt)
|
||||||
|
G_tilde = jnp.sum(gamma * G_branches)
|
||||||
|
alpha = G_inf + G_tilde
|
||||||
|
beta = G_tilde
|
||||||
|
|
||||||
|
surface = h_profile
|
||||||
|
U0 = jnp.zeros((n, m))
|
||||||
|
M0 = jnp.zeros((G_branches.shape[0], n, m))
|
||||||
|
|
||||||
|
def scan_step(carry, _):
|
||||||
|
U, M = carry
|
||||||
|
|
||||||
|
M_maxwell = jnp.tensordot(gamma, M, axes=1)
|
||||||
|
H_new = alpha * surface - beta * U + M_maxwell
|
||||||
|
|
||||||
|
displacement, pressure, iterations, residual = contact_solver_autodiff(
|
||||||
|
kernel_fourier,
|
||||||
|
H_new,
|
||||||
|
W,
|
||||||
|
L,
|
||||||
|
tol=1e-6,
|
||||||
|
iter_max=200,
|
||||||
|
)
|
||||||
|
|
||||||
|
U_new = (displacement - M_maxwell + beta * U) / alpha
|
||||||
|
delta_U = U_new - U
|
||||||
|
M_new = gamma[:, None, None] * (M + G_branches[:, None, None] * delta_U)
|
||||||
|
|
||||||
|
midline = pressure[n // 2]
|
||||||
|
contact_area = jnp.mean(pressure > 0.0) * S
|
||||||
|
|
||||||
|
return (U_new, M_new), (midline, contact_area, iterations, residual)
|
||||||
|
|
||||||
|
(_, _), outputs = lax.scan(
|
||||||
|
scan_step,
|
||||||
|
(U0, M0),
|
||||||
|
xs=None,
|
||||||
|
length=time_steps,
|
||||||
|
)
|
||||||
|
|
||||||
|
midlines, contact_areas, iterations, residuals = outputs
|
||||||
|
|
||||||
|
G_maxwell_t0 = jnp.sum(G_branches)
|
||||||
|
G_effective_t0 = G_inf + G_maxwell_t0
|
||||||
|
E_effective_t0 = 2.0 * G_effective_t0 * (1.0 + nu) / (1.0 - nu**2)
|
||||||
|
p0_t0 = (6.0 * W * (E_effective_t0**2) / (jnp.pi**3 * radius**2)) ** (1.0 / 3.0)
|
||||||
|
a_t0 = (3.0 * W * radius / (4.0 * E_effective_t0)) ** (1.0 / 3.0)
|
||||||
|
|
||||||
|
E_effective_inf = 2.0 * G_inf * (1.0 + nu) / (1.0 - nu**2)
|
||||||
|
p0_t_inf = (6.0 * W * (E_effective_inf**2) / (jnp.pi**3 * radius**2)) ** (1.0 / 3.0)
|
||||||
|
a_t_inf = (3.0 * W * radius / (4.0 * E_effective_inf)) ** (1.0 / 3.0)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"x": x,
|
||||||
|
"midlines": midlines,
|
||||||
|
"contact_areas": contact_areas,
|
||||||
|
"iterations": iterations,
|
||||||
|
"residuals": residuals,
|
||||||
|
"params": {
|
||||||
|
"t0": t0,
|
||||||
|
"dt": dt,
|
||||||
|
"L": L,
|
||||||
|
"radius": radius,
|
||||||
|
"x0": x0,
|
||||||
|
"p0_t0": p0_t0,
|
||||||
|
"p0_t_inf": p0_t_inf,
|
||||||
|
"a_t0": a_t0,
|
||||||
|
"a_t_inf": a_t_inf,
|
||||||
|
"S": S,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
results = run_simulation()
|
||||||
|
total_time = time.perf_counter() - start_time
|
||||||
|
print("Simulation time:", total_time, "seconds")
|
||||||
|
|
||||||
|
x = jax.device_get(results["x"])
|
||||||
|
midlines = jax.device_get(results["midlines"])
|
||||||
|
contact_areas = jax.device_get(results["contact_areas"])
|
||||||
|
iterations = jax.device_get(results["iterations"]).astype(int)
|
||||||
|
residuals = jax.device_get(results["residuals"])
|
||||||
|
|
||||||
|
params = results["params"]
|
||||||
|
t0 = float(params["t0"])
|
||||||
|
dt = float(params["dt"])
|
||||||
|
L = float(params["L"])
|
||||||
|
x0 = float(params["x0"])
|
||||||
|
p0_t0 = float(params["p0_t0"])
|
||||||
|
p0_t_inf = float(params["p0_t_inf"])
|
||||||
|
a_t0 = float(params["a_t0"])
|
||||||
|
a_t_inf = float(params["a_t_inf"])
|
||||||
|
S = float(params["S"])
|
||||||
|
|
||||||
|
time_axis = t0 + dt * jnp.arange(midlines.shape[0])
|
||||||
|
time_axis_np = jax.device_get(time_axis)
|
||||||
|
|
||||||
|
def update(frame):
|
||||||
|
ax.clear()
|
||||||
|
ax.set_xlim(0, L)
|
||||||
|
ax.set_ylim(0, 1.1 * p0_t0)
|
||||||
|
ax.grid(True)
|
||||||
|
|
||||||
|
x_mid = x[int(x.shape[0] / 2)]
|
||||||
|
ax.plot(
|
||||||
|
x_mid,
|
||||||
|
p0_t0 * np.sqrt(np.maximum(0.0, 1.0 - (x_mid - x0) ** 2 / a_t0**2)),
|
||||||
|
"g--",
|
||||||
|
label="Hertz t=0",
|
||||||
|
)
|
||||||
|
ax.plot(
|
||||||
|
x_mid,
|
||||||
|
p0_t_inf * np.sqrt(np.maximum(0.0, 1.0 - (x_mid - x0) ** 2 / a_t_inf**2)),
|
||||||
|
"b--",
|
||||||
|
label="Hertz t=inf",
|
||||||
|
)
|
||||||
|
ax.plot(x_mid, midlines[frame], "r-", label="Numerical")
|
||||||
|
ax.set_title(f"Time = {t0 + frame * dt:.2f}s")
|
||||||
|
ax.set_xlabel("x")
|
||||||
|
ax.set_ylabel("Pressure distribution")
|
||||||
|
ax.legend(loc="upper right")
|
||||||
|
|
||||||
|
fig, ax = plt.subplots()
|
||||||
|
ani = FuncAnimation(fig, update, frames=midlines.shape[0], repeat=False)
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
Ac_hertz_t0 = jnp.pi * a_t0**2
|
||||||
|
Ac_hertz_t_inf = jnp.pi * a_t_inf**2
|
||||||
|
|
||||||
|
print("Iterations and residuals per step:")
|
||||||
|
for idx, (its, res) in enumerate(zip(iterations, residuals)):
|
||||||
|
print(f" step {idx:02d}: {its:3d} iterations, residual={res:.3e}")
|
||||||
|
|
||||||
|
print("Analytical contact area at t0:", float(Ac_hertz_t0))
|
||||||
|
print("Analytical contact area at t_inf:", float(Ac_hertz_t_inf))
|
||||||
|
print("Numerical contact area at t0:", float(contact_areas[0]))
|
||||||
|
print("Numerical contact area at t_inf:", float(contact_areas[-1]))
|
||||||
|
|
||||||
|
plt.figure()
|
||||||
|
plt.plot(time_axis_np, contact_areas)
|
||||||
|
plt.axhline(Ac_hertz_t0, color="red", linestyle="dotted")
|
||||||
|
plt.axhline(Ac_hertz_t_inf, color="blue", linestyle="dotted")
|
||||||
|
plt.xlabel("Time(s)")
|
||||||
|
plt.ylabel("Contact area($m^2$)")
|
||||||
|
plt.legend(["Numerical", "Hertz at t=0", "Hertz at t=inf"])
|
||||||
|
plt.title("Contact area vs time for multi-branch Generalized Maxwell model")
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
||||||
|
|
@ -0,0 +1,294 @@
|
||||||
|
### This script is for the Maxwell multi-branch model.
|
||||||
|
### Deduce process is in generalized_Maxwell_backward_Euler.ipynb
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from matplotlib.animation import FuncAnimation
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
#define input parameters
|
||||||
|
##time
|
||||||
|
t0 = 0
|
||||||
|
t1 = 1
|
||||||
|
time_steps = 50
|
||||||
|
dt = (t1 - t0)/time_steps
|
||||||
|
##load(constant)
|
||||||
|
W = 1e0 # Total load
|
||||||
|
|
||||||
|
#domain size
|
||||||
|
#R = 1 # Radius of demi-sphere
|
||||||
|
L = 2 # Domain size
|
||||||
|
Radius = 0.5
|
||||||
|
S = L**2 # Domain area
|
||||||
|
|
||||||
|
# Generate a 2D coordinate space
|
||||||
|
n = 300
|
||||||
|
m = 300
|
||||||
|
|
||||||
|
x, y = np.meshgrid(np.linspace(0, L, n, endpoint=False), np.linspace(0, L, m, endpoint=False))
|
||||||
|
|
||||||
|
x0 = 1
|
||||||
|
y0 = 1
|
||||||
|
|
||||||
|
E = 3 # Young's modulus
|
||||||
|
nu = 0.5
|
||||||
|
E_star = E / (1 - nu**2) # Plane strain modulus
|
||||||
|
|
||||||
|
##################################################################
|
||||||
|
#####First just apply for demi-sphere and compare with Hertz######
|
||||||
|
##################################################################
|
||||||
|
|
||||||
|
# We define the distance from the center of the sphere
|
||||||
|
r = np.sqrt((x-x0)**2 + (y-y0)**2)
|
||||||
|
|
||||||
|
# Define the kernel in the Fourier domain
|
||||||
|
q_x = 2 * np.pi * np.fft.fftfreq(n, d=L/n)
|
||||||
|
q_y = 2 * np.pi * np.fft.fftfreq(m, d=L/m)
|
||||||
|
QX, QY = np.meshgrid(q_x, q_y)
|
||||||
|
|
||||||
|
kernel_fourier = np.zeros_like(QX)
|
||||||
|
kernel_fourier = 2 / (E_star * np.sqrt(QX**2 + QY**2))
|
||||||
|
kernel_fourier[0, 0] = 0 # Avoid division by zero at the zero frequency
|
||||||
|
|
||||||
|
h_profile = -(r**2)/(2*Radius)
|
||||||
|
|
||||||
|
def apply_integration_operator(Origin, kernel_fourier, h_profile):
|
||||||
|
# Compute the Fourier transform of the input image
|
||||||
|
Origin2fourier = np.fft.fft2(Origin, norm='ortho')
|
||||||
|
|
||||||
|
Middle_fourier = Origin2fourier * kernel_fourier
|
||||||
|
|
||||||
|
Middle = np.fft.ifft2(Middle_fourier, norm='ortho').real
|
||||||
|
|
||||||
|
Gradient = Middle - h_profile
|
||||||
|
|
||||||
|
return Gradient, Origin2fourier#true gradient
|
||||||
|
|
||||||
|
##define our elastic solver with constrained conjuagte gradient method
|
||||||
|
def contact_solver(n, m, W, S, h_profile, tol=1e-6, iter_max=200):
|
||||||
|
|
||||||
|
|
||||||
|
# Initial pressure distribution
|
||||||
|
P = np.full((n, m), W / S) # Initial guess for the pressure
|
||||||
|
|
||||||
|
#initialize the search direction
|
||||||
|
T = np.zeros((n, m))
|
||||||
|
|
||||||
|
#set the norm of surface(to normalze the error)
|
||||||
|
h_rms = np.std(h_profile)
|
||||||
|
|
||||||
|
#initialize G_norm and G_old
|
||||||
|
G_norm = 0
|
||||||
|
G_old = 1
|
||||||
|
|
||||||
|
#initialize delta
|
||||||
|
delta = 0
|
||||||
|
|
||||||
|
# Initialize variables for the iteration
|
||||||
|
k = 0 # Iteration counter
|
||||||
|
error = np.inf # Initialize error
|
||||||
|
h_rms = np.std(h_profile)
|
||||||
|
|
||||||
|
while np.abs(error) > tol and k < iter_max:
|
||||||
|
S = P > 0
|
||||||
|
|
||||||
|
G, P_fourier = apply_integration_operator(P, kernel_fourier, h_profile)
|
||||||
|
|
||||||
|
G -= G[S].mean()
|
||||||
|
|
||||||
|
G_norm = np.linalg.norm(G[S])**2
|
||||||
|
|
||||||
|
# Calculate the search direction
|
||||||
|
T[S] = G[S] + delta * G_norm / G_old * T[S]
|
||||||
|
T[~S] = 0 ## out of contact area, dont need to update
|
||||||
|
|
||||||
|
# Update G_old
|
||||||
|
G_old = G_norm
|
||||||
|
|
||||||
|
# Set R
|
||||||
|
R, T_fourier = apply_integration_operator(T, kernel_fourier, h_profile)
|
||||||
|
R += h_profile
|
||||||
|
R -= R[S].mean()
|
||||||
|
|
||||||
|
# Calculate the step size tau
|
||||||
|
tau = np.vdot(G[S], T[S]) / np.vdot(R[S], T[S])
|
||||||
|
|
||||||
|
# Update P
|
||||||
|
P -= tau * T
|
||||||
|
P *= P > 0
|
||||||
|
|
||||||
|
# identify the inadmissible points
|
||||||
|
R = (P == 0) & (G < 0)
|
||||||
|
|
||||||
|
if R.sum() == 0:
|
||||||
|
delta = 1
|
||||||
|
else:
|
||||||
|
delta = 0#change the contact point set and need to do conjugate gradient again
|
||||||
|
|
||||||
|
# Enforce the applied force constraint
|
||||||
|
P = W * P / np.mean(P) / L**2
|
||||||
|
|
||||||
|
# Calculate the error for convergence checking
|
||||||
|
error = np.vdot(P, (G - np.min(G))) / (P.sum()*h_rms)
|
||||||
|
# print(delta, error, k, np.mean(P), np.mean(P>0), tau)
|
||||||
|
|
||||||
|
k += 1 # Increment the iteration counter
|
||||||
|
|
||||||
|
# Ensure a positive gap by updating G
|
||||||
|
G = G - np.min(G)
|
||||||
|
|
||||||
|
displacement_fourier = P_fourier * kernel_fourier
|
||||||
|
displacement = np.fft.ifft2(displacement_fourier, norm='ortho').real
|
||||||
|
|
||||||
|
return displacement, P
|
||||||
|
|
||||||
|
##################################################################
|
||||||
|
#####shear modulus for multi-branch Maxwell model###################
|
||||||
|
##################################################################
|
||||||
|
G_inf = 2.75 #elastic branch
|
||||||
|
#G = [2.75, 2, 0.25, 10, 2.5] #viscoelastic branch
|
||||||
|
G = [2.75, 2.75]
|
||||||
|
|
||||||
|
print('G_inf:', G_inf, ' G: ' + str(G))
|
||||||
|
|
||||||
|
# Define the relaxation times
|
||||||
|
#tau = [0.1, 0.5, 1, 2, 10] # relaxation times
|
||||||
|
tau = [0.1, 1]
|
||||||
|
#tau = [0, 0, 0, 0, 0]
|
||||||
|
#tau = [1e6,1e6,1e6,1e6,1e6]
|
||||||
|
eta = [g * t for g, t in zip(G, tau)]
|
||||||
|
|
||||||
|
print('tau:', tau, ' eta:', eta)
|
||||||
|
|
||||||
|
##################################################################
|
||||||
|
#####define G_tilde for one-branch Maxwell model #################
|
||||||
|
##################################################################
|
||||||
|
G_tilde = 0
|
||||||
|
for k in range(len(G)):
|
||||||
|
G_tilde += tau[k] / (tau[k] + dt) * G[k]
|
||||||
|
|
||||||
|
|
||||||
|
# Define parameters for updating the surface profile
|
||||||
|
alpha = G_inf + G_tilde
|
||||||
|
beta = G_tilde
|
||||||
|
|
||||||
|
gamma = []
|
||||||
|
for k in range(len(G)):
|
||||||
|
gamma.append(tau[k]/(tau[k] + dt))
|
||||||
|
|
||||||
|
Surface = h_profile
|
||||||
|
|
||||||
|
U = np.zeros((n, m))
|
||||||
|
M = np.zeros((len(G), n, m))
|
||||||
|
|
||||||
|
Ac=[]
|
||||||
|
M_maxwell = np.zeros_like(U)
|
||||||
|
|
||||||
|
#######################################
|
||||||
|
###Hertzian contact theory reference
|
||||||
|
#######################################
|
||||||
|
##Hertz solution at t0
|
||||||
|
G_maxwell_t0 = 0
|
||||||
|
for k in range(len(G)):
|
||||||
|
G_maxwell_t0 += G[k]
|
||||||
|
G_effective_t0 = G_inf + G_maxwell_t0
|
||||||
|
E_effective_t0 = 2*G_effective_t0*(1+nu)/(1-nu**2)
|
||||||
|
|
||||||
|
p0_t0 = (6*W*(E_effective_t0)**2/(np.pi**3*Radius**2))**(1/3)
|
||||||
|
a_t0 = (3*W*Radius/(4*(E_effective_t0)))**(1/3)
|
||||||
|
##Hertz solution at t_inf
|
||||||
|
E_effective_inf = 2*G_inf*(1+nu)/(1-nu**2)
|
||||||
|
|
||||||
|
p0_t_inf = (6*W*(E_effective_inf)**2/(np.pi**3*Radius**2))**(1/3)
|
||||||
|
a_t_inf = (3*W*Radius/(4*(E_effective_inf)))**(1/3)
|
||||||
|
|
||||||
|
|
||||||
|
# define the update function for the animation
|
||||||
|
def update(frame):
|
||||||
|
ax.clear()
|
||||||
|
ax.set_xlim(0, L)
|
||||||
|
ax.set_ylim(0, 1.1*p0_t0)
|
||||||
|
ax.grid()
|
||||||
|
|
||||||
|
# draw Hertzian contact theory reference
|
||||||
|
ax.plot(x[n//2], p0_t0*np.sqrt(1 - (x[n//2]-x0)**2 / a_t0**2), 'g--', label='Hertz at t=0')
|
||||||
|
ax.plot(x[n//2], p0_t_inf*np.sqrt(1 - (x[n//2]-x0)**2 / a_t_inf**2), 'b--', label='Hertz at t=inf')
|
||||||
|
|
||||||
|
# draw numerical solution at current time step
|
||||||
|
ax.plot(x[n//2], pressure_distributions[frame], 'r-', label='Numerical')
|
||||||
|
ax.set_title(f"Time = {t0 + frame * dt:.2f}s")
|
||||||
|
plt.xlabel("x")
|
||||||
|
plt.ylabel("Pressure distribution")
|
||||||
|
plt.legend()
|
||||||
|
|
||||||
|
|
||||||
|
start = time.perf_counter()
|
||||||
|
# collect pressure distributions at each time step
|
||||||
|
pressure_distributions = []
|
||||||
|
for t in np.arange(t0, t1, dt):
|
||||||
|
#Update the surface profile
|
||||||
|
M_maxwell[:] = 0
|
||||||
|
for k in range(len(G)):
|
||||||
|
M_maxwell += gamma[k]*M[k]
|
||||||
|
H_new = alpha*Surface - beta*U + M_maxwell
|
||||||
|
|
||||||
|
#main step1: Compute $P_{t+\Delta t}^{\prime}$
|
||||||
|
#M_new, P = contact_solver(n, m, W, S, H_new, tol=1e-6, iter_max=200)
|
||||||
|
M_new, P = contact_solver(n, m, W, S, H_new, tol=1e-6, iter_max=200)
|
||||||
|
|
||||||
|
##Sanity check??
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
##main step2: Update global displacement
|
||||||
|
U_new = (1/alpha)*(M_new - M_maxwell + beta*U)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#main step3: Update the pressure
|
||||||
|
for k in range(len(G)):
|
||||||
|
M[k] = gamma[k]*(M[k] + G[k]*(U_new - U))
|
||||||
|
#only maxwell branch, see algorithm formula 1 in the notebook
|
||||||
|
|
||||||
|
|
||||||
|
Ac.append(np.mean(P > 0)*S)
|
||||||
|
|
||||||
|
#main step4: Update the total displacement field
|
||||||
|
U = U_new
|
||||||
|
|
||||||
|
pressure_distributions.append(P[n//2].copy()) # store the pressure distribution at each time step
|
||||||
|
|
||||||
|
end = time.perf_counter()
|
||||||
|
print("Simulation time:", end - start, "seconds")
|
||||||
|
|
||||||
|
# create a figure and axis
|
||||||
|
fig, ax = plt.subplots()
|
||||||
|
|
||||||
|
# create an animation
|
||||||
|
ani = FuncAnimation(fig, update, frames=len(pressure_distributions), repeat=False)
|
||||||
|
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
Ac_hertz_t0 = np.pi*a_t0**2
|
||||||
|
Ac_hertz_t_inf = np.pi*a_t_inf**2
|
||||||
|
|
||||||
|
print("Analytical contact area radius at t0:", a_t0)
|
||||||
|
print("Analytical contact area radius at t_inf:", a_t_inf)
|
||||||
|
print("Analytical maximum pressure at t0:", p0_t0)
|
||||||
|
print("Analytical maximum pressure at t_inf:", p0_t_inf)
|
||||||
|
print("Numerical contact area at t0:", Ac[0])
|
||||||
|
print("Numerical contact area at t_inf", Ac[-1])
|
||||||
|
print("Analyical contact area at t0:", Ac_hertz_t0)
|
||||||
|
print("Analyical contact area at t_inf:", Ac_hertz_t_inf)
|
||||||
|
plt.plot(np.arange(t0, t1, dt), Ac)
|
||||||
|
plt.axhline(Ac_hertz_t0, color='red', linestyle='dotted')
|
||||||
|
plt.axhline(Ac_hertz_t_inf, color='blue', linestyle='dotted')
|
||||||
|
plt.xlabel("Time(s)")
|
||||||
|
plt.ylabel("Contact area($m^2$)")
|
||||||
|
plt.legend(["Numerical", "Hertz at t=0", "Hertz at t=inf"])
|
||||||
|
#define a title that can read parameter tau_0
|
||||||
|
plt.title("Contact area vs time for multi-branch Generalized Maxwell model")
|
||||||
|
#plt.axhline(Ac_hertz_t_inf, color='blue')
|
||||||
|
plt.show()
|
||||||
|
|
@ -0,0 +1,70 @@
|
||||||
|
'''
|
||||||
|
Here we test a Hertzian contact on a generalized Maxwell material using Tamaas.
|
||||||
|
Contact with rough surfaces needs to be tested.
|
||||||
|
'''
|
||||||
|
import tamaas as tm
|
||||||
|
import time
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
# Set-up of the model
|
||||||
|
L = 2
|
||||||
|
Radius = 0.5
|
||||||
|
S = L**2
|
||||||
|
|
||||||
|
# discretization
|
||||||
|
n = m = 300
|
||||||
|
x = np.linspace(0, L, n, endpoint=False, dtype=tm.dtype)
|
||||||
|
y = np.linspace(0, L, m, endpoint=False, dtype=tm.dtype)
|
||||||
|
xx, yy = np.meshgrid(x, y, indexing="ij")
|
||||||
|
# Define the surface
|
||||||
|
surface = surface = -((xx - L / 2) ** 2 + (yy - L / 2) ** 2) / (2 * Radius)
|
||||||
|
# Create the model
|
||||||
|
model = tm.Model(tm.model_type.basic_2d, [L, L], [n, m])
|
||||||
|
|
||||||
|
# Defining the elastic branch (i.e. the behavior at t = ∞)
|
||||||
|
model.E = 3
|
||||||
|
model.nu = 0.5
|
||||||
|
|
||||||
|
# Characteristic times of the relaxation function
|
||||||
|
times = [0.1, 1]
|
||||||
|
|
||||||
|
# Shear moduli for each branch of the model
|
||||||
|
shear_moduli = [2.75, 2.75]
|
||||||
|
|
||||||
|
t0 = 0
|
||||||
|
t1 = 1
|
||||||
|
time_steps = 50
|
||||||
|
# Time step
|
||||||
|
Δt = (t1 - t0) / time_steps
|
||||||
|
|
||||||
|
# Applied load
|
||||||
|
W = 1.0
|
||||||
|
load = W / S
|
||||||
|
|
||||||
|
# Solver instanciation
|
||||||
|
solver = tm.MaxwellViscoelastic(model, surface, 1e-10,
|
||||||
|
time_step=Δt,
|
||||||
|
shear_moduli=shear_moduli,
|
||||||
|
characteristic_times=times)
|
||||||
|
|
||||||
|
# Solve one timestep with given load
|
||||||
|
start = time.perf_counter()
|
||||||
|
solver.solve(load)
|
||||||
|
end = time.perf_counter()
|
||||||
|
print(f'Simulation time for one step: {end - start} seconds')
|
||||||
|
|
||||||
|
# plot like ub Multi_branches_generalized_Maxwell.py
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
displacement = model.displacement[:]
|
||||||
|
pressure = model.traction[:]
|
||||||
|
plt.figure(figsize=(12, 5))
|
||||||
|
plt.subplot(1, 2, 1)
|
||||||
|
plt.imshow(displacement, extent=(0, L, 0, L), origin='lower')
|
||||||
|
plt.title('Displacement field')
|
||||||
|
plt.colorbar()
|
||||||
|
plt.subplot(1, 2, 2)
|
||||||
|
plt.imshow(pressure, extent=(0, L, 0, L), origin='lower')
|
||||||
|
plt.title('Pressure field')
|
||||||
|
plt.colorbar()
|
||||||
|
plt.show()
|
||||||
|
|
@ -0,0 +1,33 @@
|
||||||
|
'''
|
||||||
|
This is a test file for JAX functionalities. Codes are from Jax documentation.(https://jax.readthedocs.io/en/latest/index.html)
|
||||||
|
'''
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
from jax import grad, jit
|
||||||
|
|
||||||
|
# Example: JIT compilation
|
||||||
|
def selu(x, alpha=1.67, lambda_=1.05):
|
||||||
|
return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
|
||||||
|
|
||||||
|
x = jnp.arange(1000000)
|
||||||
|
|
||||||
|
import time
|
||||||
|
# measure a single execution and ensure JAX computation completes
|
||||||
|
start = time.perf_counter()
|
||||||
|
selu(x).block_until_ready()
|
||||||
|
end = time.perf_counter()
|
||||||
|
print("Elapsed:", end - start)
|
||||||
|
|
||||||
|
selu_jit = jit(selu)
|
||||||
|
# measure a single execution and ensure JAX computation completes
|
||||||
|
start = time.perf_counter()
|
||||||
|
selu_jit(x).block_until_ready()
|
||||||
|
end = time.perf_counter()
|
||||||
|
print("Elapsed with JIT:", end - start)
|
||||||
|
|
||||||
|
# Example: Automatic differentiation\
|
||||||
|
def f(x):
|
||||||
|
return jnp.sin(x) + 0.5 * x ** 2
|
||||||
|
df = grad(f)
|
||||||
|
x = 2.0
|
||||||
|
print("f'(2.0) =", df(x)) # Should print the derivative of f at x=2.0
|
||||||
Loading…
Reference in New Issue