Legendre curve checkpointing demo
Checkpointing reduces memory usage when training with high-degree polynomials, at the expense of some speed. This notebook demonstrates how to use checkpointing and shows that the results with and without checkpointing are identical.
[1]:
import torch
import torchcurves.functional as tcf
[2]:
# --- Execution target and numeric precision --------------------------------
# Prefer the GPU when one is usable; otherwise everything runs on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
dtype = torch.float64

# --- Reproducibility --------------------------------------------------------
# Seed the default generator, and every CUDA generator when running on GPU,
# so that `coeffs_base` below is the same on each Restart & Run All.
torch.manual_seed(0)
if device.type == "cuda":
    torch.cuda.manual_seed_all(0)

# --- Problem size -----------------------------------------------------------
batch_size = 512
num_curves = 4
dim = 3
degree = 255
n_coeffs = degree + 1  # a degree-d polynomial has d + 1 coefficients

# --- Base tensors shared by every experiment below --------------------------
# Evaluation points: `batch_size` evenly spaced values in [-1, 1], copied into
# one column per curve, giving shape (batch_size, num_curves).
x_base = torch.linspace(-1, 1, batch_size, device=device, dtype=dtype)[:, None].repeat(1, num_curves)

# Random coefficients with shape (n_coeffs, num_curves, dim).
coeffs_base = torch.randn((n_coeffs, num_curves, dim), device=device, dtype=dtype)
/home/alex/git/torchcurves/.venv/lib/python3.12/site-packages/torch/cuda/__init__.py:174: UserWarning: CUDA initialization: The NVIDIA driver on your system is too old (found version 11040). Please update your GPU driver by downloading and installing a new version from the URL: http://www.nvidia.com/Download/index.aspx Alternatively, go to: https://pytorch.org to install a PyTorch version that has been compiled with your version of the CUDA driver. (Triggered internally at /pytorch/c10/cuda/CUDAFunctions.cpp:109.)
return torch._C._cuda_getDeviceCount() > 0
[3]:
def run(checkpoint_segments):
    """Run one forward/backward pass of the Legendre curves.

    Fresh clones of the module-level ``x_base`` and ``coeffs_base`` are turned
    into gradient-tracking leaves on every call, so gradients from one call
    never accumulate into the next.

    Args:
        checkpoint_segments: Passed through unchanged as the
            ``checkpoint_segments`` keyword of ``tcf.legendre_curves``.
            This notebook calls it with ``None`` (presumably: checkpointing
            disabled -- confirm against the torchcurves docs) and with ``4``.

    Returns:
        A tuple ``(y, x_grad, coeffs_grad)``: the curve output and the
        gradients of ``(y ** 2).sum()`` with respect to the cloned inputs
        and the cloned coefficients.
    """
    points = x_base.clone().requires_grad_(True)
    weights = coeffs_base.clone().requires_grad_(True)

    y = tcf.legendre_curves(points, weights, checkpoint_segments=checkpoint_segments)

    # Sum-of-squares scalar loss; backward() fills `.grad` on both leaves.
    y.pow(2).sum().backward()

    return y, points.grad, weights.grad
# Same computation twice: once without checkpointing, once with 4 segments.
y_no, gx_no, gc_no = run(None)
y_ckpt, gx_ckpt, gc_ckpt = run(4)

# Output, input gradient and coefficient gradient must each agree within
# torch's default tolerances for the dtype; any mismatch raises AssertionError.
for tensor_a, tensor_b in (
    (y_no, y_ckpt),
    (gx_no, gx_ckpt),
    (gc_no, gc_ckpt),
):
    torch.testing.assert_close(tensor_a, tensor_b)

print("Outputs and gradients match.")
Outputs and gradients match.
[4]:
def peak_memory(checkpoint_segments):
    """Measure the peak CUDA memory allocated during one call to ``run``.

    The CUDA caching allocator is emptied and its peak statistics are reset
    for the module-level ``device`` before the run, so the returned figure
    reflects this call rather than earlier work in the kernel session.
    Only call this when ``device`` is a CUDA device.

    Args:
        checkpoint_segments: Forwarded unchanged to ``run``.

    Returns:
        The value of ``torch.cuda.max_memory_allocated(device)`` after the
        run has finished, in bytes.
    """
    cuda = torch.cuda

    # Start from a clean slate: release cached blocks, zero the peak counter.
    cuda.empty_cache()
    cuda.reset_peak_memory_stats(device)

    run(checkpoint_segments)

    # Wait for all queued GPU work before reading the statistic.
    cuda.synchronize(device)
    peak_bytes = cuda.max_memory_allocated(device)
    return peak_bytes
# Peak-memory statistics exist only for CUDA devices, so handle the CPU case
# first and report the comparison otherwise.
if device.type != "cuda":
    print("CUDA not available; skipping peak memory comparison.")
else:
    peak_no = peak_memory(None)  # baseline: no checkpointing
    peak_ckpt = peak_memory(4)   # checkpointed with 4 segments
    # Byte counts are divided by 1024 ** 2 to report MiB.
    print(f"Peak without checkpointing: {peak_no / 1024 ** 2:.2f} MiB")
    print(f"Peak with checkpointing: {peak_ckpt / 1024 ** 2:.2f} MiB")
CUDA not available; skipping peak memory comparison.
[ ]: