Your Name 40b714677c 10e5
2025-09-10 22:00:42 +00:00
runs Update: RGA main.py, Shapes, Line-Search etc. 2025-09-06 11:09:11 +00:00
.DS_Store Update README, code & plots 2025-08-31 17:48:54 +02:00
.gitignore chore: ignore .venv 2025-08-30 19:36:43 +00:00
main.py budget clipping 2025-09-10 22:06:05 +02:00
r elu2_line_search_no_10e4.png 10e5 2025-09-10 22:00:42 +00:00
r elu2_line_search_no_10e4.svg 10e5 2025-09-10 22:00:42 +00:00
readme.md Update README, code & plots 2025-08-31 17:48:54 +02:00
relu2_line_search_10e3.png lineaerch 2025-09-10 21:14:06 +00:00
relu2_line_search_10e3.svg lineaerch 2025-09-10 21:14:06 +00:00
relu2_line_search_budgetclipping.png linesearch plots budget clippinmg 2025-09-10 20:32:13 +00:00
relu2_line_search_budgetclipping.svg linesearch plots budget clippinmg 2025-09-10 20:32:13 +00:00
relu2_line_search_budgetclipping_500 --M 500 --delta 1e-3.png 1000 2025-09-10 20:56:51 +00:00
relu2_line_search_budgetclipping_500 --M 500 --delta 1e-3.svg 1000 2025-09-10 20:56:51 +00:00
relu2_line_search_budgetclipping_500_10e3.png lineaerch 2025-09-10 21:14:06 +00:00
relu2_line_search_budgetclipping_500_10e3.svg lineaerch 2025-09-10 21:14:06 +00:00
relu2_line_search_budgetclipping_1000 --M 1000 --delta 1e-3.png 1000 2025-09-10 20:56:51 +00:00
relu2_line_search_budgetclipping_1000 --M 1000 --delta 1e-3.svg 1000 2025-09-10 20:56:51 +00:00
relu2_line_search_budgetclipping_1000_10e3.png lineaerch 2025-09-10 21:14:06 +00:00
relu2_line_search_budgetclipping_1000_10e3.svg lineaerch 2025-09-10 21:14:06 +00:00
relu2_line_search_budgetclipping_M100.png linesearch plots budget clippinmg 2025-09-10 20:32:13 +00:00
relu2_line_search_budgetclipping_M100.svg linesearch plots budget clippinmg 2025-09-10 20:32:13 +00:00
relu2_line_search_budgetclipping_M500.png linesearch plots budget clippinmg 2025-09-10 20:32:13 +00:00
relu2_line_search_budgetclipping_M500.svg linesearch plots budget clippinmg 2025-09-10 20:32:13 +00:00
relu2_line_search_no_10e4.png 10e5 2025-09-10 22:00:42 +00:00
relu2_line_search_no_10e4.svg 10e5 2025-09-10 22:00:42 +00:00
relu2_line_search_no_10e5.png small delta line search 2025-09-10 21:50:11 +00:00
relu2_line_search_no_10e5.svg small delta line search 2025-09-10 21:50:11 +00:00
relu2_line_search_no_wpos.png lien search like before ? 2025-09-10 21:39:04 +00:00
relu2_line_search_no_wpos.svg lien search like before ? 2025-09-10 21:39:04 +00:00
relu2_line_search_nobudgetclipping_10e3_w6.png morew 2025-09-10 21:30:35 +00:00
relu2_line_search_nobudgetclipping_10e3_w6.svg morew 2025-09-10 21:30:35 +00:00
requirements.txt for GPU 2025-08-30 19:34:07 +02:00

RGA 1D (b-grid): GPU-only README

Greedy solver for the penalized 1D Poisson energy on $[0,\pi]$ with a selectable dictionary $g_{w,b}(x) = \phi(wx+b)$ and a fixed $b$-grid argmax per step (no Newton). The code is JAX-JITed and intended for NVIDIA GPUs (CUDA 12). It produces a 3×2 summary figure and can run headless.

Top-left panel now shows only the $L^1$ error $\|u_k-u^*\|_{L^1}$.


Requirements (GPU)

  • NVIDIA GPU with CUDA 12.x runtime
  • Python 3.10-3.12 (tested on 3.12)
  • Internet access to fetch JAX CUDA wheels

JAX's CUDA wheels are hosted outside PyPI. Use the special index URL.

requirements.txt (GPU)

jax[cuda12_pip]==0.4.35
numpy==1.26.4
matplotlib==3.8.4

Create venv & install (Linux)

python3 -m venv .venv
source .venv/bin/activate
pip install -U pip wheel
# IMPORTANT: use JAX CUDA index URL
pip install -r requirements.txt \
  -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Recommended for servers/headless:

export MPLBACKEND=Agg
export XLA_PYTHON_CLIENT_PREALLOCATE=false
export XLA_PYTHON_CLIENT_MEM_FRACTION=0.85
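
The same settings can also be applied from Python, which may be convenient when a launcher script is used instead of a shell profile. The following is a minimal sketch, not part of main.py: the helper name apply_headless_defaults is hypothetical, and the call has to happen before jax or matplotlib.pyplot is imported for the values to take effect.

```python
import os

# Settings recommended above for headless GPU servers.
HEADLESS_DEFAULTS = {
    "MPLBACKEND": "Agg",                       # non-interactive Matplotlib backend
    "XLA_PYTHON_CLIENT_PREALLOCATE": "false",  # do not grab GPU memory up front
    "XLA_PYTHON_CLIENT_MEM_FRACTION": "0.85",  # memory fraction setting for the XLA client
}


def apply_headless_defaults(env=os.environ):
    """Fill in missing settings without overriding values that are already set."""
    for key, value in HEADLESS_DEFAULTS.items():
        env.setdefault(key, value)
    return env


# Call this at the very top of a launcher, before importing jax or pyplot.
apply_headless_defaults()
```

Using setdefault keeps any value exported in the shell, so the shell configuration above still takes precedence.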

Verify GPU:

python - <<'PY'
import jax; print(jax.devices())  # expect [cuda(id=0), ...]
PY

Usage

The entry point is main.py.

Choose the activation / dictionary

--activation {sin, relu, relu2, relu3, tanh} selects $\phi$ in $g_{w,b}(x)=\phi(wx+b)$.

  • sin    $\phi(z)=\sin z$
  • tanh  $\phi(z)=\tanh z$
  • relu  $\phi(z)=\max\{z,0\}$
  • relu2 $\phi(z)=\max\{z,0\}^2$
  • relu3 $\phi(z)=\max\{z,0\}^3$
  • --norm_atoms normalizes atoms in the problem's $H_\delta$ norm.
  • --line_search uses the exact 1D minimizer along the chosen atom; otherwise the relaxed RGA update with budget --M is used.
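
To make the dictionary concrete, the activations listed above can be written as plain scalar functions. This is an illustrative sketch in standard-library Python, not the JAX code in main.py; the names ACTIVATIONS, relu_power and atom are hypothetical.

```python
import math


def relu_power(z, p):
    """max(z, 0) ** p, the family behind relu, relu2 and relu3."""
    return max(z, 0.0) ** p


# Map each --activation choice to its scalar function phi.
ACTIVATIONS = {
    "sin": math.sin,
    "tanh": math.tanh,
    "relu": lambda z: relu_power(z, 1),
    "relu2": lambda z: relu_power(z, 2),
    "relu3": lambda z: relu_power(z, 3),
}


def atom(name, w, b):
    """Return the dictionary element x -> phi(w * x + b)."""
    phi = ACTIVATIONS[name]
    return lambda x: phi(w * x + b)


g = atom("relu2", w=1.0, b=-1.0)
print(g(0.5), g(3.0))  # 0.0 4.0
```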

With --norm_atoms, the variation-norm proxy k1 is comparable to the number of selected atoms. With --line_search, step sizes are chosen analytically and the budget M only serves as a reference in the plot.

Typical runs

# 1) Sine dictionary, budgeted RGA, no normalization
python -u main.py \
  --activation sin --iters 1000000 --log_every 10000 \
  --save plots/sine_rga_summary

# 2) ReLU^2, normalized atoms + line search
python -u main.py \
  --activation relu2 --norm_atoms --line_search \
  --iters 800000 --log_every 10000 \
  --save plots/relu2_ls

# 3) tanh, normalized atoms, budget M=5
python -u main.py \
  --activation tanh --norm_atoms --M 5 \
  --iters 400000 --save plots/tanh_norm_M5

CLI reference (subset)

  • --iters total iterations (default 500000)
  • --N spatial grid size (default 2001)
  • --Nb $b$-grid resolution on $[-\pi,\pi]$ (default 1200)
  • --M budget for relaxed RGA (default 10.0)
  • --delta boundary penalty (default 1e-2)
  • --b_chunk static chunk size for argmax over $b$ (default 256)
  • --log_every logging stride (default 1000)
  • --norm_atoms (flag) normalize each atom in $H_\delta$
  • --line_search (flag) use exact 1D line search
  • --activation dictionary choice (see above)
  • --save output path prefix (no extension)
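
For orientation, the flags and defaults listed above could be declared with argparse roughly as follows. This is a sketch reconstructed from the list only; the actual parser in main.py may differ, and the absence of a default for --activation is an assumption, since the README does not state one.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the documented flags and defaults."""
    p = argparse.ArgumentParser(description="RGA 1D (b-grid), CLI sketch")
    p.add_argument("--iters", type=int, default=500_000, help="total iterations")
    p.add_argument("--N", type=int, default=2001, help="spatial grid size")
    p.add_argument("--Nb", type=int, default=1200, help="b-grid resolution on [-pi, pi]")
    p.add_argument("--M", type=float, default=10.0, help="budget for relaxed RGA")
    p.add_argument("--delta", type=float, default=1e-2, help="boundary penalty")
    p.add_argument("--b_chunk", type=int, default=256, help="chunk size for argmax over b")
    p.add_argument("--log_every", type=int, default=1000, help="logging stride")
    p.add_argument("--norm_atoms", action="store_true", help="normalize atoms in H_delta")
    p.add_argument("--line_search", action="store_true", help="exact 1D line search")
    # Assumption: no default is documented, so None is used here.
    p.add_argument("--activation", choices=["sin", "relu", "relu2", "relu3", "tanh"],
                   default=None, help="dictionary choice")
    p.add_argument("--save", type=str, default=None, help="output path prefix (no extension)")
    return p


args = build_parser().parse_args(["--activation", "tanh", "--norm_atoms", "--M", "5"])
print(args.activation, args.norm_atoms, args.M)  # tanh True 5.0
```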

Outputs

If --save outputs/run_1/summary is given, the script writes

outputs/run_1/summary.png
outputs/run_1/summary.svg
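
The mapping from prefix to file names can be sketched as below. The helper output_paths is hypothetical and only illustrates the naming rule; it appends the extension as a string rather than using Path.with_suffix, on the assumption that a prefix may itself contain dots that should be kept.

```python
from pathlib import Path


def output_paths(prefix, extensions=(".png", ".svg")):
    """Append each extension to the prefix without touching dots inside it."""
    return [Path(str(prefix) + ext) for ext in extensions]


for path in output_paths("outputs/run_1/summary"):
    print(path.as_posix())
# outputs/run_1/summary.png
# outputs/run_1/summary.svg
```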

The 3×2 figure contains:

  1. Top-left: $L^1$ error $\|u_k-u^*\|_{L^1}$ (log-log)
  2. Variation-norm proxy k1 vs. budget M
  3. $L^2$ error with slope guides ($n^{-1}$, $n^{-1/2}$, $n^{-1/4}$)
  4. $u_k$ vs. exact $u^*(x)=\tfrac12 x(\pi-x)$
  5. Argmax trace of $w$
  6. Argmax trace of $b$ (with $\pm\pi$ lines)
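
To compare a logged error curve against the slope guides in panel 3, one option is a least-squares fit in log-log coordinates. The sketch below is not part of main.py; the function name loglog_slope and the synthetic data are assumptions for illustration.

```python
import math


def loglog_slope(ns, errors):
    """Least-squares slope of log(error) against log(n)."""
    xs = [math.log(n) for n in ns]
    ys = [math.log(e) for e in errors]
    mean_x = sum(xs) / len(xs)
    mean_y = sum(ys) / len(ys)
    num = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    den = sum((x - mean_x) ** 2 for x in xs)
    return num / den


# Synthetic curve decaying exactly like n^(-1/2).
ns = [10 ** k for k in range(1, 7)]
errors = [3.0 * n ** -0.5 for n in ns]
print(round(loglog_slope(ns, errors), 6))  # -0.5
```

A fitted slope near -1, -1/2 or -1/4 then corresponds to the respective guide line.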

Math (very short)

We minimize the penalized Poisson energy

$$\mathcal R_\delta(u) = \tfrac12\int_0^\pi |u'|^2\,dx - \int_0^\pi f u\,dx + \tfrac{1}{2\delta}\bigl(u(0)^2+u(\pi)^2\bigr),\quad f\equiv 1,$$

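whose minimizer is $u^*(x)+\tfrac{\delta\pi}{2}$ with $u^*(x)=\tfrac12x(\pi-x)$, so it converges to the Dirichlet solution $u^*$ as $\delta\to0$. Each RGA step selects $g_{w,b}$ by scanning a fixed $b$-grid and updates either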

  • Relaxed RGA: $u_k=(1-\alpha_k)u_{k-1}-\alpha_k M\,\mathrm{sgn}\,\langle g,\nabla\mathcal R_\delta(u_{k-1})\rangle\, g$ with $\alpha_k=\min(1,2/k)$, or
  • Line search: $u_k=u_{k-1}-\lambda_k g$ with $\lambda_k=\langle g,\nabla\mathcal R_\delta(u_{k-1})\rangle/\|g\|_{H_\delta}^2$.
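
The two update rules above can be tried on a small finite-dimensional analogue. The sketch below assumes a quadratic energy $\tfrac12 u^\top A u - f^\top u$ with a symmetric positive definite matrix $A$ standing in for the $H_\delta$ inner product, so that the pairing is $g^\top(Au-f)$ and the squared atom norm is $g^\top A g$. All names are hypothetical and none of this is taken from main.py.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def matvec(A, u):
    return [dot(row, u) for row in A]


def energy(A, f, u):
    """Quadratic model energy 1/2 u^T A u - f^T u."""
    return 0.5 * dot(u, matvec(A, u)) - dot(f, u)


def gradient(A, f, u):
    return [au - fi for au, fi in zip(matvec(A, u), f)]


def sign(t):
    return (t > 0) - (t < 0)


def relaxed_step(u, g, pairing, k, M):
    """u_k = (1 - alpha) u - alpha * M * sgn(pairing) * g, alpha = min(1, 2/k)."""
    alpha = min(1.0, 2.0 / k)
    return [(1.0 - alpha) * ui - alpha * M * sign(pairing) * gi for ui, gi in zip(u, g)]


def line_search_step(u, g, pairing, g_norm_sq):
    """u_k = u - lambda * g with lambda = pairing / ||g||^2 (exact for a quadratic)."""
    lam = pairing / g_norm_sq
    return [ui - lam * gi for ui, gi in zip(u, g)]


A = [[2.0, 0.0], [0.0, 1.0]]
f = [1.0, 1.0]
u = [0.0, 0.0]
g = [1.0, 0.0]
pairing = dot(g, gradient(A, f, u))                           # -1.0
u_ls = line_search_step(u, g, pairing, dot(g, matvec(A, g)))  # [0.5, 0.0]
u_rga = relaxed_step(u, g, pairing, k=1, M=10.0)              # [10.0, 0.0]
print(energy(A, f, u_ls))  # -0.25
```

In this toy case the line-search step lands on the exact minimizer along $g$, while the first relaxed step (with $\alpha_1=1$) jumps to $\pm M g$ regardless of the curvature.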

Under the standard assumptions (convex, $K$-smooth energy; symmetric dictionary bounded in $H_\delta$), the optimization error toward the best element in the balanced hull decays like $\mathcal O(1/n)$ after $n$ steps; for $L^2$ one typically gets $\|u_n-u^*\|_{L^2}=\mathcal O(n^{-1/2})$, provided the discretization error stays below the current optimization error.
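
As a sanity check on the energy defined in this section, it can be evaluated on the spatial grid for the reference function $u^*$. The sketch below assumes a uniform grid, cell-wise finite differences for $u'$ and the trapezoidal rule for $\int fu$; these discretization choices are assumptions and need not match main.py. Since $u^*$ vanishes at both end points, the value should be close to $-\pi^3/24$.

```python
import math


def penalized_energy(u, h, delta):
    """Discrete R_delta(u) for f = 1 on a uniform grid with spacing h."""
    # 1/2 * integral of |u'|^2: one difference quotient per grid cell.
    dirichlet = 0.5 * sum((u[i + 1] - u[i]) ** 2 for i in range(len(u) - 1)) / h
    # Integral of f * u with f = 1: trapezoidal rule.
    load = h * (sum(u) - 0.5 * (u[0] + u[-1]))
    # Boundary penalty (u(0)^2 + u(pi)^2) / (2 * delta).
    boundary = (u[0] ** 2 + u[-1] ** 2) / (2.0 * delta)
    return dirichlet - load + boundary


N = 2001                      # matches the documented default of --N
h = math.pi / (N - 1)
x = [i * h for i in range(N)]
u_star = [0.5 * xi * (math.pi - xi) for xi in x]

value = penalized_energy(u_star, h, delta=1e-2)
print(value, -math.pi ** 3 / 24)  # the two numbers should agree to several digits
```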


Performance tips

  • Keep b_chunk around 256-512 for good tiling.
  • Increase N and Nb gradually; both raise per-step cost.
  • Use --log_every to reduce host traffic.
  • Double precision is enabled via jax_enable_x64=True in the script.
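
The role of b_chunk can be illustrated with a chunked argmax over the $b$-grid: the grid is scanned in fixed-size pieces and only the best candidate so far is kept. This is a plain-Python sketch of the idea, not the JAX implementation; score_fn is a hypothetical stand-in for the pairing between an atom and the energy gradient, and on a GPU each chunk would be one batched evaluation.

```python
import math


def chunked_argmax_abs(score_fn, b_grid, b_chunk=256):
    """Return (index, value) of the largest |score_fn(b)| over b_grid, chunk by chunk."""
    best_idx, best_val = -1, -1.0
    for start in range(0, len(b_grid), b_chunk):
        chunk = b_grid[start:start + b_chunk]
        values = [abs(score_fn(b)) for b in chunk]
        local = max(range(len(values)), key=values.__getitem__)
        # Strict comparison keeps the first maximizer when values tie.
        if values[local] > best_val:
            best_idx, best_val = start + local, values[local]
    return best_idx, best_val


Nb = 1200  # documented default of --Nb
b_grid = [-math.pi + 2.0 * math.pi * j / (Nb - 1) for j in range(Nb)]
idx, val = chunked_argmax_abs(lambda b: math.exp(-(b - 0.5) ** 2), b_grid)
print(b_grid[idx])  # grid point closest to b = 0.5
```

Smaller chunks lower peak memory per evaluation at the cost of more loop iterations, which is the trade-off behind the 256-512 recommendation.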

Troubleshooting

NVCC / cuda_nvcc error: Do not install nvidia-cuda-nvcc-cu12. If present, remove it and pin JAX wheels exactly:

pip uninstall -y nvidia-cuda-nvcc-cu12
pip install --force-reinstall --no-deps \
  "jax==0.4.35" "jaxlib==0.4.35" \
  "jax-cuda12-plugin==0.4.35" "jax-cuda12-pjrt==0.4.35"

OOM or crash on start: Lower memory preallocation (env vars above) and/or reduce N, Nb, b_chunk.

No figures appear: Always pass --save <prefix> on servers (headless). The code writes both PNG and SVG.


Copy results from server

# from your laptop
scp -r USER@SERVER:/path/to/project/plots ./plots_local

Acknowledgement

Implements a standard Relaxed Greedy Algorithm with a symmetric dictionary and optional line search, JAX-JITed for NVIDIA GPUs.