RGA 1D (b‑grid) — GPU‑only README
Greedy solver for the penalized 1D Poisson energy on $[0,\pi]$ with a selectable dictionary $g_{w,b}(x) = \phi(wx+b)$ and a fixed $b$‑grid argmax per step (no Newton). The code is JAX‑JITed and intended for NVIDIA GPUs (CUDA 12). It produces a 3×2 summary figure and can run headless.
Top‑left panel now shows only the $L^1$ error $|u_k-u^*|_{L^1}$.
Requirements (GPU)
- NVIDIA GPU with CUDA 12.x runtime
- Python 3.10–3.12 (tested on 3.12)
- Internet access to fetch JAX CUDA wheels
JAX’s CUDA wheels are hosted outside PyPI. Use the special index URL.
requirements.txt
(GPU)
jax[cuda12_pip]==0.4.35
numpy==1.26.4
matplotlib==3.8.4
Create venv & install (Linux)
python3 -m venv .venv
source .venv/bin/activate
pip install -U pip wheel
# IMPORTANT: use JAX CUDA index URL
pip install -r requirements.txt \
-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Recommended for servers/headless:
export MPLBACKEND=Agg
export XLA_PYTHON_CLIENT_PREALLOCATE=false
export XLA_PYTHON_CLIENT_MEM_FRACTION=0.85
Verify GPU:
python - <<'PY'
import jax; print(jax.devices()) # expect [cuda(id=0), ...]
PY
Usage
The entry point is main.py.
Choose the activation / dictionary
--activation {sin, relu, relu2, relu3, tanh} selects $\phi$ in $g_{w,b}(x)=\phi(wx+b)$:
- sin: $\phi(z)=\sin z$
- tanh: $\phi(z)=\tanh z$
- relu: $\phi(z)=\max\{z,0\}$
- relu2: $\phi(z)=\max\{z,0\}^2$
- relu3: $\phi(z)=\max\{z,0\}^3$
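As a minimal sketch (plain NumPy for readability; the actual implementation is JAX‑JITed, and the function names here are illustrative), the dictionary atoms corresponding to the --activation choices can be written as:

```python
import numpy as np

# Activation table mirroring the --activation choices above.
ACTIVATIONS = {
    "sin":   np.sin,
    "tanh":  np.tanh,
    "relu":  lambda z: np.maximum(z, 0.0),
    "relu2": lambda z: np.maximum(z, 0.0) ** 2,
    "relu3": lambda z: np.maximum(z, 0.0) ** 3,
}

def atom(x, w, b, activation="sin"):
    """Evaluate the dictionary atom g_{w,b}(x) = phi(w*x + b) on a grid x."""
    return ACTIVATIONS[activation](w * x + b)
```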
Normalization & line search
--norm_atoms normalizes each atom in the problem’s $H_\delta$ norm. --line_search uses the exact 1‑D minimizer along the chosen atom; otherwise the relaxed RGA update with budget --M is used.
With --norm_atoms, the variation‑norm proxy (the $\ell^1$ norm of the selected coefficients) is comparable to the number of selected atoms. With --line_search, step sizes are chosen analytically and the budget M only serves as a reference in the plot.
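A NumPy sketch of what --norm_atoms does, under the assumption (consistent with the penalty term in the energy below) that the $H_\delta$ inner product is $\langle u,v\rangle_{H_\delta}=\int_0^\pi u'v'\,dx + \tfrac1\delta\,(u(0)v(0)+u(\pi)v(\pi))$; the function names are illustrative, not the script's API:

```python
import numpy as np

def h_delta_inner(u, v, delta=1e-2):
    """Discrete H_delta inner product on a uniform grid over [0, pi]:
    int u'v' dx + (1/delta) * (u(0)v(0) + u(pi)v(pi))."""
    h = np.pi / (len(u) - 1)       # grid spacing
    du = np.diff(u) / h            # forward differences for u'
    dv = np.diff(v) / h
    bulk = np.sum(du * dv) * h     # midpoint quadrature of int u'v'
    bdry = (u[0] * v[0] + u[-1] * v[-1]) / delta
    return bulk + bdry

def normalize_atom(g, delta=1e-2):
    """Scale g to unit H_delta norm, as with --norm_atoms."""
    return g / np.sqrt(h_delta_inner(g, g, delta))
```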
Typical runs
# 1) Sine dictionary, budgeted RGA, no normalization
python -u main.py \
--activation sin --iters 1000000 --log_every 10000 \
--save plots/sine_rga_summary
# 2) ReLU^2, normalized atoms + line search
python -u main.py \
--activation relu2 --norm_atoms --line_search \
--iters 800000 --log_every 10000 \
--save plots/relu2_ls
# 3) tanh, normalized atoms, budget M=5
python -u main.py \
--activation tanh --norm_atoms --M 5 \
--iters 400000 --save plots/tanh_norm_M5
CLI reference (subset)
- --iters: total iterations (default 500000)
- --N: spatial grid size (default 2001)
- --Nb: $b$‑grid resolution on $[-\pi,\pi]$ (default 1200)
- --M: budget for relaxed RGA (default 10.0)
- --delta: boundary penalty (default 1e-2)
- --b_chunk: static chunk size for argmax over $b$ (default 256)
- --log_every: logging stride (default 1000)
- --norm_atoms: (flag) normalize each atom in $H_\delta$
- --line_search: (flag) use exact 1‑D line search
- --activation: dictionary choice (see above)
- --save: output path prefix (no extension)
Outputs
If --save outputs/run_1/summary is given, the script writes
outputs/run_1/summary.png
outputs/run_1/summary.svg
The 3×2 figure contains:
- Top‑left: $L^1$ error $|u_k-u^*|_{L^1}$ (log–log)
- Variation‑norm proxy vs. budget $M$
- $L^2$ error with slope guides ($n^{-1}$, $n^{-1/2}$, $n^{-1/4}$)
- $u_k$ vs. exact $u^*(x)=\tfrac12 x(\pi-x)$
- Argmax trace of $w$
- Argmax trace of $b$ (with $\pm\pi$ lines)
Math (very short)
We minimize the penalized Poisson energy
\mathcal R_\delta(u) = \tfrac12\int_0^\pi |u'|^2\,dx - \int_0^\pi f\,u\,dx + \tfrac{1}{2\delta}\left(u(0)^2+u(\pi)^2\right),\quad f\equiv 1,
whose minimizer is $u^*(x)=\tfrac12x(\pi-x)$. Each RGA step selects $g_{w,b}$ by scanning a fixed $b$‑grid and updates either
- Relaxed RGA: $u_k=(1-\alpha_k)u_{k-1}-\alpha_k M\,\mathrm{sgn}\,\langle g,\nabla\mathcal R\rangle\, g$ with $\alpha_k=\min(1,2/k)$, or
- Line search: $u_k=u_{k-1}-\lambda_k g$ with $\lambda_k=\langle g,\nabla\mathcal R\rangle/\|g\|_{H_\delta}^2$.
Under the standard assumptions (convex, $K$‑smooth; symmetric dictionary bounded in $H_\delta$), the optimization error toward the best element in the balanced hull decays like $\mathcal O(1/n)$; for $L^2$ one typically gets $|u_k-u^*|_{L^2}=\mathcal O(n^{-1/2})$ provided the discretization error is below the optimization regime.
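The two update rules can be sketched as follows (NumPy, with the directional derivative $\langle g,\nabla\mathcal R\rangle$ assumed precomputed and passed in as a scalar; this is an illustration, not the author's code):

```python
import numpy as np

def relaxed_step(u_prev, g, dirderiv, k, M=10.0):
    """Relaxed RGA: u_k = (1 - a_k) u_{k-1} - a_k * M * sgn(<g, grad R>) * g
    with a_k = min(1, 2/k)."""
    alpha = min(1.0, 2.0 / k)
    return (1.0 - alpha) * u_prev - alpha * M * np.sign(dirderiv) * g

def line_search_step(u_prev, g, dirderiv, g_norm_sq):
    """Exact 1-D line search: u_k = u_{k-1} - lam_k * g
    with lam_k = <g, grad R> / ||g||_{H_delta}^2."""
    lam = dirderiv / g_norm_sq
    return u_prev - lam * g
```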
Performance tips
- Keep b_chunk around 256–512 for good tiling.
- Increase N and Nb gradually; both raise per‑step cost.
- Use --log_every to reduce host traffic.
- Double precision is enabled via jax_enable_x64=True in the script.
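The role of b_chunk can be sketched as a chunked argmax over the $b$‑grid (NumPy sketch under the assumption that the scan evaluates scores in fixed-size tiles; in the real code the chunk size is static so JAX can JIT-compile it, and the names here are illustrative):

```python
import numpy as np

def chunked_argmax(scores_fn, b_grid, b_chunk=256):
    """Find argmax_b scores_fn(b) over b_grid, evaluating at most
    b_chunk candidates at a time to bound peak memory."""
    best_val, best_b = -np.inf, float(b_grid[0])
    for start in range(0, len(b_grid), b_chunk):
        chunk = b_grid[start:start + b_chunk]
        vals = scores_fn(chunk)            # vectorized over the chunk
        i = int(np.argmax(vals))
        if vals[i] > best_val:
            best_val, best_b = float(vals[i]), float(chunk[i])
    return best_b, best_val
```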
Troubleshooting
NVCC / cuda_nvcc error
Do not install nvidia-cuda-nvcc-cu12. If present, remove it and pin JAX wheels exactly:
pip uninstall -y nvidia-cuda-nvcc-cu12
pip install --force-reinstall --no-deps \
"jax==0.4.35" "jaxlib==0.4.35" \
"jax-cuda12-plugin==0.4.35" "jax-cuda12-pjrt==0.4.35"
OOM or crash on start
Lower memory preallocation (env vars above) and/or reduce N, Nb, b_chunk.
No figures appear
Always pass --save <prefix> on servers (headless). The code writes both PNG and SVG.
Copy results from server
# from your laptop
scp -r USER@SERVER:/path/to/project/plots ./plots_local
Acknowledgement
Implements a standard Relaxed Greedy Algorithm with a symmetric dictionary and optional line search, JAX‑JITed for NVIDIA GPUs.