LegendreCurve

class torchcurves.LegendreCurve(num_curves, dim, degree, input_map='real.rational', checkpoint_segments=None)

PyTorch module for a batch of parametrized curves expressed in a Legendre polynomial basis.

The learnable parameters are the control points (coefficients) of the Legendre series for each curve. All curves share the same degree. Inputs to this layer are first mapped to \([-1, 1]\) by input_map. Each curve is defined as:

\[\mathbf{C}_m(u) = \sum_{k=0}^{\mathrm{degree}} \mathbf{C}_{m,k} \cdot P_k(u),\]

where \(P_k\) is the \(k\)-th Legendre polynomial.
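The sum above can be sketched in plain Python for a single curve. This is a minimal illustration, not the module's implementation: the helper names `legendre_basis` and `eval_curve` are hypothetical, and the basis is computed with Bonnet's recurrence, \((k+1)P_{k+1}(u) = (2k+1)\,u\,P_k(u) - k\,P_{k-1}(u)\), which may differ from what the library does internally.

```python
def legendre_basis(u, degree):
    """Return [P_0(u), ..., P_degree(u)] using Bonnet's recurrence."""
    values = [1.0]  # P_0(u) = 1
    if degree >= 1:
        values.append(u)  # P_1(u) = u
    for k in range(1, degree):
        # (k + 1) P_{k+1}(u) = (2k + 1) u P_k(u) - k P_{k-1}(u)
        nxt = ((2 * k + 1) * u * values[k] - k * values[k - 1]) / (k + 1)
        values.append(nxt)
    return values


def eval_curve(coeffs, u):
    """Evaluate sum_k coeffs[k] * P_k(u); each coeffs[k] is a length-D list."""
    basis = legendre_basis(u, len(coeffs) - 1)
    dim = len(coeffs[0])
    return [sum(c[d] * p for c, p in zip(coeffs, basis)) for d in range(dim)]


# A degree-2 curve in 2-D needs degree + 1 = 3 coefficient vectors.
coeffs = [[1.0, 0.0], [0.0, 1.0], [2.0, -1.0]]
point = eval_curve(coeffs, 0.5)  # u is assumed to lie in [-1, 1] already
```

Here `point` has one entry per output dimension, matching the role of \(D\) in the formula.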

Parameters:
  • num_curves (int) – Number of Legendre curves to define (\(M\)).

  • dim (int) – Dimension of each curve’s output points (\(D\)).

  • degree (int) – Degree of the Legendre polynomial basis (shared by all curves). The number of coefficients per curve will be degree + 1.

  • input_map (Union[str, InputMap]) – Map from raw inputs to \([-1, 1]\). Can be a dotted preset string like “real.rational”, a map object from torchcurves.maps, or a callable with signature f(x, out_min, out_max).

  • checkpoint_segments (Optional[int]) – Optional number of segments for gradient checkpointing. Larger values save memory but increase compute. Only used when gradients are enabled.
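To illustrate the callable form of input_map, here is a hedged sketch of a function with the documented signature f(x, out_min, out_max). The name `rational_map` and the squashing formula \(x / (1 + |x|)\) are assumptions chosen for illustration; the formula behind the “real.rational” preset is not stated here and may differ. The sketch uses plain floats for brevity, whereas the module would operate on tensors.

```python
def rational_map(x, out_min=-1.0, out_max=1.0):
    """Hypothetical map from the real line into (out_min, out_max)."""
    squashed = x / (1.0 + abs(x))  # rational squashing into (-1, 1)
    unit = 0.5 * (squashed + 1.0)  # shift and scale into (0, 1)
    return out_min + (out_max - out_min) * unit  # rescale to the target range


mapped = [rational_map(x) for x in (-1000.0, -1.0, 0.0, 1.0, 1000.0)]
```

Any map of this kind is monotone and keeps every real input inside the interval on which the Legendre basis is evaluated.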

forward(u)

Evaluate the batch of Legendre curves.

Parameters:

u (Tensor) – Parameter values of shape \((B, C)\), where \(B\) is the mini-batch size and \(C\) is the number of curves. \(C\) must equal self.num_curves.

Return type:

Tensor

Returns:

Points on the Legendre curves, as a tensor of shape \((B, C, D)\), where \(D\) is the output dimension dim.
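The shape contract can be illustrated with nested lists standing in for tensors: column \(c\) of the input feeds curve \(c\), and each scalar parameter becomes a \(D\)-dimensional point. This is a sketch under assumptions, not the library's code: `forward_sketch` and `legendre_basis` are hypothetical helpers, the coefficient layout \([C][\mathrm{degree}+1][D]\) is assumed, and the input-map step is skipped by taking values already in \([-1, 1]\).

```python
def legendre_basis(u, degree):
    """Return [P_0(u), ..., P_degree(u)] using Bonnet's recurrence."""
    values = [1.0]
    if degree >= 1:
        values.append(u)
    for k in range(1, degree):
        nxt = ((2 * k + 1) * u * values[k] - k * values[k - 1]) / (k + 1)
        values.append(nxt)
    return values


def forward_sketch(coeffs, u):
    """coeffs: [C][degree + 1][D] nested lists; u: [B][C]. Returns [B][C][D]."""
    degree = len(coeffs[0]) - 1
    dim = len(coeffs[0][0])
    out = []
    for row in u:  # one row per mini-batch element
        if len(row) != len(coeffs):
            raise ValueError("u must have one column per curve")
        points = []
        for c, u_bc in enumerate(row):  # column c is evaluated on curve c
            basis = legendre_basis(u_bc, degree)
            points.append([
                sum(coeffs[c][k][d] * basis[k] for k in range(degree + 1))
                for d in range(dim)
            ])
        out.append(points)
    return out


# C = 2 curves, degree 1, D = 3 output dimensions, B = 4 batch elements.
coeffs = [
    [[0.0, 0.0, 0.0], [1.0, 2.0, 3.0]],  # curve 0: C_0(u) = u * (1, 2, 3)
    [[1.0, 1.0, 1.0], [0.0, 0.0, 0.0]],  # curve 1: constant (1, 1, 1)
]
u = [[0.0, 0.5], [1.0, -1.0], [-0.5, 0.25], [0.75, 0.0]]
out = forward_sketch(coeffs, u)
```

The result has 4 rows, each holding 2 points of 3 coordinates, mirroring \((B, C, D)\).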