Fast functions for dense scalar and vector fields, implemented using just-in-time compilation.
/!\ This is (still) experimental
- The GPU versions of the algorithms are written in pure CUDA and compiled just-in-time by cupy.
- The CPU versions of the algorithms are written in pure C++ and compiled just-in-time by cppyy.
Installation
Dependencies
- pytorch >= 1.8
- numpy
- cppyy
- cupy (if CUDA support required)
Conda
PyTorch, cppyy and cupy all depend heavily on system libraries and can easily end up incompatible with one another. The preferred installation method therefore relies on conda, which minimizes such issues.
conda install jitfields -c balbasty -c pytorch -c conda-forge
Note that in this case PyTorch without GPU support will get installed (unless PyTorch was already installed using conda, in which case the installed version will be preserved). To ensure that the GPU version of PyTorch gets installed (and ensure compatibility with cupy), you should instead do:
# for pytorch >= 1.13
conda install jitfields pytorch==$TORCH_VERSION pytorch-cuda=$CUDA_VERSION -c balbasty -c pytorch -c nvidia -c conda-forge
# for pytorch < 1.13
conda install jitfields pytorch==$TORCH_VERSION cudatoolkit=$CUDA_VERSION -c balbasty -c pytorch -c conda-forge
In our experience this is enough to ensure compatibility across all dependencies.
If for some reason it is not, it may be necessary to use cupy's specific cuda-version package. See:
Pip
Installation through pip should work, as jitfields is a pure python package. As stated above,
there may be inconsistencies across pytorch, cppyy and cupy. It may therefore be preferable
to pre-install these dependencies yourself, rather than relying on pip's dependency solver.
If you intend to run code on the GPU, specify the [cuda] extra tag, which ensures that cupy gets installed.
pip install jitfields[cuda]
API
Distance transforms
Distance to binary masks
def euclidean_distance_transform(x, ndim=None, vx=1, dtype=None): ...
"""Compute the Euclidean distance transform of a binary image

Parameters
----------
x : (..., *spatial) tensor
    Input tensor
ndim : int, default=`x.ndim`
    Number of spatial dimensions
vx : [sequence of] float, default=1
    Voxel size
dtype : torch.dtype
    Datatype of the distance map.
    By default, use x.dtype if it is a floating point type,
    otherwise use the default floating point type.

Returns
-------
d : (..., *spatial) tensor
    Distance map

References
----------
..[1] "Distance Transforms of Sampled Functions"
      Pedro F. Felzenszwalb & Daniel P. Huttenlocher
      Theory of Computing (2012)
      https://www.theoryofcomputing.org/articles/v008a019/v008a019.pdf
"""
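The reference cited above computes exact squared Euclidean distances one dimension at a time, as the lower envelope of parabolas. The pure-Python sketch below illustrates that 1D pass. It is not the jitfields implementation: the helper name `squared_edt_1d` is hypothetical, and treating zero-valued voxels as the sites to which distances are measured is an assumption made for illustration.

```python
import math


def squared_edt_1d(f, vx=1.0):
    """Squared distance transform of a sampled 1D function `f`.

    out[i] = min_q ((i - q) * vx)**2 + f[q]
    Hypothetical helper; `f[q] = inf` marks voxels that are not sites.
    """
    n = len(f)
    sites = [q for q in range(n) if f[q] != math.inf]
    if not sites:
        return [math.inf] * n

    v = []  # indices of the parabolas forming the lower envelope
    z = []  # z[k]: position from which parabola v[k] is the lowest
    for q in sites:
        xq = q * vx
        while True:
            if not v:
                v.append(q)
                z.append(-math.inf)
                break
            p = v[-1]
            xp = p * vx
            # Abscissa where parabolas rooted at p and q intersect
            s = ((f[q] + xq * xq) - (f[p] + xp * xp)) / (2 * (xq - xp))
            if s <= z[-1]:
                # Parabola p is hidden by q: drop it and retry
                v.pop()
                z.pop()
            else:
                v.append(q)
                z.append(s)
                break

    out = []
    k = 0
    for i in range(n):
        x = i * vx
        # Move to the parabola that is lowest at position x
        while k + 1 < len(v) and z[k + 1] < x:
            k += 1
        p = v[k]
        out.append((x - p * vx) ** 2 + f[p])
    return out


mask = [1, 1, 0, 1, 1, 1, 0]
f = [0.0 if m == 0 else math.inf for m in mask]
dist = [math.sqrt(d) for d in squared_edt_1d(f)]
print(dist)  # [2.0, 1.0, 0.0, 1.0, 2.0, 1.0, 0.0]
```

For an N-dimensional image, the same pass can be run along each spatial dimension in turn, feeding the squared distances of one pass into the next and taking the square root at the end.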
def l1_distance_transform(x, ndim=None, vx=1, dtype=None): ...
"""Compute the L1 distance transform of a binary image

Parameters
----------
x : (..., *spatial) tensor
    Input tensor
ndim : int, default=`x.ndim`
    Number of spatial dimensions
vx : [sequence of] float, default=1
    Voxel size
dtype : torch.dtype
    Datatype of the distance map.
    By default, use x.dtype if it is a floating point type,
    otherwise use the default floating point type.

Returns
-------
d : (..., *spatial) tensor
    Distance map

References
----------
..[1] "Distance Transforms of Sampled Functions"
      Pedro F. Felzenszwalb & Daniel P. Huttenlocher
      Theory of Computing (2012)
      https://www.theoryofcomputing.org/articles/v008a019/v008a019.pdf
"""
def signed_distance_transform(
    x: tensor,
    ndim: Optional[int] = None,
    vx: OneOrSeveral[float] = 1,
    dtype: Optional[torch.dtype] = None,
) -> tensor: ...
"""Compute the signed Euclidean distance transform of a binary image

Parameters
----------
x : `(..., *spatial) tensor`
    Input tensor, with shape `(..., *spatial)`.
ndim : `int`, default=`x.ndim`
    Number of spatial dimensions. Default: all.
vx : `[sequence of] float`, default=1
    Voxel size.
dtype : `torch.dtype`, optional
    Output data type. Default is same as `x` if it has a floating
    point data type, else `torch.get_default_dtype()`.

Returns
-------
d : `(..., *spatial) tensor`
    Signed distance map, with shape `(..., *spatial)`.

References
----------
..[1] "Distance Transforms of Sampled Functions"
      Pedro F. Felzenszwalb & Daniel P. Huttenlocher
      Theory of Computing (2012)
      https://www.theoryofcomputing.org/articles/v008a019/v008a019.pdf
"""
Distance to 1D splines
def spline_distance_table(
    loc: tensor,
    coeff: tensor,
    steps: Optional[Union[int, tensor]] = None,
    order: OrderType = 3,
    bound: BoundType = 'dct2',
    square: bool = False,
) -> Tuple[tensor, tensor]: ...
"""Compute the minimum distance from a set of points to a 1D spline

Parameters
----------
loc : `(..., D) tensor`
    Point set.
coeff : `(..., N, D) tensor`
    Spline coefficients encoding the location of the 1D spline.
steps : `int or (..., K) tensor`
    Number of time steps to try, or list of time steps to try.
order : {1..7}
    Spline order.
bound : `{'zero', 'replicate', 'dct1', 'dct2', 'dst1', 'dst2', 'dft'}`
    Boundary conditions of the spline.
square : bool
    Return the squared Euclidean distance.

Returns
-------
dist : `(...) tensor`
    Distance from each point in the set to its closest point on the spline
time : `(...) tensor`
    Time of the closest point on the spline
"""
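The table-based search tries a fixed set of time steps and keeps the closest one, which is also what the iterative variants use as an initialisation. The sketch below illustrates the idea in pure Python for an order-1 (piecewise-linear) curve only. The helper names `curve_point` and `distance_table` are hypothetical, and measuring time in units of control-point index, from 0 to N-1, is an assumption rather than a documented convention of jitfields.

```python
import math


def curve_point(coeff, t):
    """Evaluate a piecewise-linear curve at time t in [0, N-1].

    `coeff` is a list of N control points, each a tuple of D floats (N >= 2).
    """
    n = len(coeff)
    i = min(int(math.floor(t)), n - 2)  # index of the segment containing t
    w = t - i                           # position inside the segment
    return tuple((1 - w) * a + w * b for a, b in zip(coeff[i], coeff[i + 1]))


def distance_table(loc, coeff, steps=32):
    """Try `steps` evenly spaced times (steps >= 2) and keep the closest."""
    n = len(coeff)
    best_dist, best_time = math.inf, 0.0
    for k in range(steps):
        t = k * (n - 1) / (steps - 1)
        d = math.dist(loc, curve_point(coeff, t))
        if d < best_dist:
            best_dist, best_time = d, t
    return best_dist, best_time


coeff = [(0.0, 0.0), (1.0, 0.0), (2.0, 0.0)]
dist, time = distance_table((1.0, 1.0), coeff, steps=21)
print(dist, time)  # 1.0 1.0
```

The accuracy of such a search is limited by the spacing of the time steps, which is why a table result is a natural starting point for a local optimizer rather than a final answer.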
def spline_distance_brent(
    loc: tensor,
    coeff: tensor,
    max_iter: int = 128,
    tol: float = 1e-6,
    step_size: float = 0.01,
    order: OrderType = 3,
    bound: BoundType = 'dct2',
    square: bool = False,
    steps: Optional[Union[int, tensor]] = None,
) -> Tuple[tensor, tensor]: ...
"""Compute the minimum distance from a set of points to a 1D spline

Parameters
----------
loc : `(..., D) tensor`
    Point set.
coeff : `(..., N, D) tensor`
    Spline coefficients encoding the location of the 1D spline.
max_iter : int
    Number of optimization steps.
tol : float
    Tolerance for early stopping
step_size : float
    Initial search size.
order : {1..7}
    Spline order.
bound : `{'zero', 'replicate', 'dct1', 'dct2', 'dst1', 'dst2', 'dft'}`
    Boundary conditions of the spline.
square : bool
    Return the squared Euclidean distance.
steps : int
    Number of steps used in the table-based initialisation.

Returns
-------
dist : `(...) tensor`
    Distance from each point in the set to its closest point on the spline
time : `(...) tensor`
    Time of the closest point on the spline
"""
def spline_distance_gaussnewton(
    loc: tensor,
    coeff: tensor,
    max_iter: int = 16,
    tol: float = 1e-6,
    order: OrderType = 3,
    bound: BoundType = 'dct2',
    square: bool = False,
    steps: Optional[Union[int, tensor]] = None,
) -> Tuple[tensor, tensor]: ...
"""Compute the minimum distance from a set of points to a 1D spline

Parameters
----------
loc : `(..., D) tensor`
    Point set.
coeff : `(..., N, D) tensor`
    Spline coefficients encoding the location of the 1D spline.
max_iter : int
    Number of optimization steps.
tol : float
    Tolerance for early stopping
order : {1..7}
    Spline order.
bound : `{'zero', 'replicate', 'dct1', 'dct2', 'dst1', 'dst2', 'dft'}`
    Boundary conditions of the spline.
square : bool
    Return the squared Euclidean distance.
steps : int
    Number of steps used in the table-based initialisation.

Returns
-------
dist : `(...) tensor`
    Distance from each point in the set to its closest point on the spline
time : `(...) tensor`
    Time of the closest point on the spline
"""
def spline_distance_brent_(
    dist: tensor,
    time: tensor,
    loc: tensor,
    coeff: tensor,
    max_iter: int = 128,
    tol: float = 1e-6,
    step_size: float = 0.01,
    order: OrderType = 3,
    bound: BoundType = 'dct2',
    square: bool = False,
) -> Tuple[tensor, tensor]: ...
"""Compute the minimum distance from a set of points to a 1D spline (inplace)

Parameters
----------
dist : `(...) tensor`
    Initial distance from each point in the set to its closest point on the spline
time : `(...) tensor`
    Initial time of the closest point on the spline
loc : `(..., D) tensor`
    Point set.
coeff : `(..., N, D) tensor`
    Spline coefficients encoding the location of the 1D spline.
max_iter : int
    Number of optimization steps.
tol : float
    Tolerance for early stopping
step_size : float
    Initial search size.
order : {1..7}
    Spline order.
bound : `{'zero', 'replicate', 'dct1', 'dct2', 'dst1', 'dst2', 'dft'}`
    Boundary conditions of the spline.
square : bool
    Return the squared Euclidean distance.

Returns
-------
dist : `(...) tensor`
    Distance from each point in the set to its closest point on the spline
time : `(...) tensor`
    Time of the closest point on the spline
"""
def spline_distance_gaussnewton_(
    dist: tensor,
    time: tensor,
    loc: tensor,
    coeff: tensor,
    max_iter: int = 16,
    tol: float = 1e-6,
    order: OrderType = 3,
    bound: BoundType = 'dct2',
    square: bool = False,
) -> Tuple[tensor, tensor]: ...
"""Compute the minimum distance from a set of points to a 1D spline (inplace)

Parameters
----------
dist : `(...) tensor`
    Initial distance from each point in the set to its closest point on the spline
time : `(...) tensor`
    Initial time of the closest point on the spline
loc : `(..., D) tensor`
    Point set.
coeff : `(..., N, D) tensor`
    Spline coefficients encoding the location of the 1D spline.
max_iter : int
    Number of optimization steps.
tol : float
    Tolerance for early stopping
order : {1..7}
    Spline order.
bound : `{'zero', 'replicate', 'dct1', 'dct2', 'dst1', 'dst2', 'dft'}`
    Boundary conditions of the spline.
square : bool
    Return the squared Euclidean distance.

Returns
-------
dist : `(...) tensor`
    Distance from each point in the set to its closest point on the spline
time : `(...) tensor`
    Time of the closest point on the spline
"""
Distance to triangular meshes
def mesh_distance_signed(
    loc: tensor,
    vertices: tensor,
    faces: tensor,
    out: Optional[tensor] = None,
) -> tensor: ...
"""Compute the *signed* minimum distance from a set of points to a triangular mesh

Parameters
----------
loc : `(..., D) tensor`
    Point set.
vertices : `(N, D) tensor`
    Mesh vertices
faces : `(M, D) tensor[integer]`
    Mesh faces

Returns
-------
dist : `(...) tensor`
    Signed distance from each point in the set to its closest point on the mesh
    (negative inside, positive outside)
"""
def mesh_distance(
    loc: tensor,
    vertices: tensor,
    faces: tensor,
    out: Optional[tensor] = None,
) -> tensor: ...
"""Compute the minimum distance from a set of points to a triangular mesh

Parameters
----------
loc : `(..., D) tensor`
    Point set.
vertices : `(N, D) tensor`
    Mesh vertices
faces : `(M, D) tensor[integer]`
    Mesh faces

Returns
-------
dist : `(...) tensor`
    Distance from each point in the set to its closest point on the mesh
"""
Interpolation/Resampling
def spline_coeff(inp, order, bound='dct2', dim=-1): ...
"""Compute the interpolating spline coefficients, along a single dimension.

Parameters
----------
inp : tensor
    Input tensor
order : {0..7}
    Interpolation order.
bound : {'zero', 'replicate', 'dct1', 'dct2', 'dft'}, default='dct2'
    Boundary conditions.
dim : int, default=-1
    Dimension along which to filter

Returns
-------
coeff : tensor
    Spline coefficients

References
----------
..[1] M. Unser, A. Aldroubi and M. Eden.
      "B-Spline Signal Processing: Part I-Theory,"
      IEEE Transactions on Signal Processing 41(2):821-832 (1993).
..[2] M. Unser, A. Aldroubi and M. Eden.
      "B-Spline Signal Processing: Part II-Efficient Design and Applications,"
      IEEE Transactions on Signal Processing 41(2):834-848 (1993).
..[3] M. Unser.
      "Splines: A Perfect Fit for Signal and Image Processing,"
      IEEE Signal Processing Magazine 16(6):22-38 (1999).
"""
def spline_coeff_nd(inp, order, bound='dct2', ndim=None): ...
"""Compute the interpolating spline coefficients, along the last N dimensions.

Parameters
----------
inp : (..., *spatial) tensor
    Input tensor
order : [sequence of] {0..7}
    Interpolation order.
bound : [sequence of] {'zero', 'replicate', 'dct1', 'dct2', 'dft'}, default='dct2'
    Boundary conditions.
ndim : int, default=`inp.dim()`
    Number of spatial dimensions

Returns
-------
coeff : (..., *spatial) tensor
    Spline coefficients

References
----------
..[1] M. Unser, A. Aldroubi and M. Eden.
      "B-Spline Signal Processing: Part I-Theory,"
      IEEE Transactions on Signal Processing 41(2):821-832 (1993).
..[2] M. Unser, A. Aldroubi and M. Eden.
      "B-Spline Signal Processing: Part II-Efficient Design and Applications,"
      IEEE Transactions on Signal Processing 41(2):834-848 (1993).
..[3] M. Unser.
      "Splines: A Perfect Fit for Signal and Image Processing,"
      IEEE Signal Processing Magazine 16(6):22-38 (1999).
"""
def resize(x, factor=None, shape=None, ndim=None,
           anchor='e', order=2, bound='dct2', prefilter=True): ...
"""Resize a tensor using spline interpolation

Parameters
----------
x : (..., *inshape) tensor
    Input tensor
factor : [sequence of] float, optional
    Factor by which to resize the tensor (> 1 == bigger)
    One of factor or shape must be provided.
shape : [sequence of] int, optional
    Shape of output tensor.
    One of factor or shape must be provided.
ndim : int, optional
    Number of spatial dimensions.
    If not provided, try to guess from factor or shape.
    If guess fails, assume ndim = x.dim().
anchor : {'edge', 'center'} or None
    What feature should be aligned across the input and output tensors.
    If 'edge' or 'center', the effective scaling factor may slightly
    differ from the requested scaling factor.
    If None, the center of the (0, 0) voxel is aligned, and the
    requested factor is exactly applied.
order : [sequence of] {0..7}, default=2
    Interpolation order.
bound : [sequence of] {'zero', 'replicate', 'dct1', 'dct2', 'dst1', 'dst2', 'dft'}, default='dct2'
    How to deal with out-of-bound values.
prefilter : bool, default=True
    Whether to first compute interpolating coefficients.
    Must be true for proper interpolation, otherwise this
    function merely performs a non-interpolating "prolongation".

Returns
-------
x : (..., *shape) tensor
    Resized tensor

References
----------
..[1] M. Unser, A. Aldroubi and M. Eden.
      "B-Spline Signal Processing: Part I-Theory,"
      IEEE Transactions on Signal Processing 41(2):821-832 (1993).
..[2] M. Unser, A. Aldroubi and M. Eden.
      "B-Spline Signal Processing: Part II-Efficient Design and Applications,"
      IEEE Transactions on Signal Processing 41(2):834-848 (1993).
..[3] M. Unser.
      "Splines: A Perfect Fit for Signal and Image Processing,"
      IEEE Signal Processing Magazine 16(6):22-38 (1999).
"""
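To make the `anchor` options more concrete, the sketch below computes, for one dimension, where each output voxel center falls in input voxel coordinates. It follows the usual conventions for aligning edges or centers of the first and last voxels; the helper name `resize_coordinates` is hypothetical, and the formulas are an assumption about what 'edge' and 'center' mean here, not code taken from jitfields.

```python
def resize_coordinates(in_size, out_size, anchor='edge'):
    """Input-space coordinate of each output voxel center (1D).

    Hypothetical helper illustrating two common alignment conventions.
    """
    if anchor == 'edge':
        # Outer edges of the first and last voxels coincide
        scale = in_size / out_size
        return [(i + 0.5) * scale - 0.5 for i in range(out_size)]
    if anchor == 'center':
        # Centers of the first and last voxels coincide (out_size > 1)
        scale = (in_size - 1) / (out_size - 1)
        return [i * scale for i in range(out_size)]
    raise ValueError(f"unsupported anchor: {anchor!r}")


edge = resize_coordinates(4, 8, 'edge')      # spans -0.25 .. 3.25
center = resize_coordinates(4, 8, 'center')  # spans 0 .. 3
```

In both cases the effective scaling is fixed by the integer input and output sizes, which is one way to see why it can differ slightly from a requested non-integer factor.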
def restrict(x, factor=None, shape=None, ndim=None,
             anchor='e', order=2, bound='dct2', reduce_sum=False): ...
"""Restrict (adjoint of resize) a tensor using spline interpolation

Parameters
----------
x : (..., *inshape) tensor
    Input tensor
factor : [sequence of] float, optional
    Factor by which to resize the tensor (> 1 == smaller)
    One of factor or shape must be provided.
shape : [sequence of] int, optional
    Shape of output tensor.
    One of factor or shape must be provided.
ndim : int, optional
    Number of spatial dimensions.
    If not provided, try to guess from factor or shape.
    If guess fails, assume ndim = x.dim().
anchor : {'edge', 'center'} or None
    What feature should be aligned across the input and output tensors.
    If 'edge' or 'center', the effective scaling factor may slightly
    differ from the requested scaling factor.
    If None, the center of the (0, 0) voxel is aligned, and the
    requested factor is exactly applied.
order : [sequence of] {0..7}, default=2
    Interpolation order.
bound : [sequence of] {'zero', 'replicate', 'dct1', 'dct2', 'dst1', 'dst2', 'dft'}, default='dct2'
    How to deal with out-of-bound values.

Returns
-------
x : (..., *shape) tensor
    Restricted tensor
"""
def pull(inp, grid, order=2, bound='dct2', extrapolate=True, prefilter=False, out=None): ...
"""Sample a tensor using spline interpolation

Parameters
----------
inp : (..., *inshape, channel) tensor
    Input tensor
grid : (..., *outshape, ndim) tensor
    Tensor of coordinates into `inp`
order : [sequence of] {0..7}, default=2
    Interpolation order.
bound : [sequence of] {'zero', 'replicate', 'dct1', 'dct2', 'dst1', 'dst2', 'dft'}, default='dct2'
    How to deal with out-of-bound values.
extrapolate : bool or {'center', 'edge'}
    - True: use bound to extrapolate out-of-bound value
    - False or 'center': do not extrapolate values that fall outside
      of the centers of the first and last voxels.
    - 'edge': do not extrapolate values that fall outside
      of the edges of the first and last voxels.
prefilter : bool, default=False
    Whether to first compute interpolating coefficients.
    Must be true for proper interpolation, otherwise this
    function merely performs a non-interpolating "spline sampling".

Returns
-------
out : (..., *outshape, channel) tensor
    Pulled tensor
"""
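The `bound` names follow the discrete transforms whose implicit extension they match. As a rough illustration of how an out-of-bound index could be folded back into the field of view, here is a pure-Python sketch for four of the modes. The helper name `fold_index` is hypothetical, and the mapping (dft = circular, replicate = nearest, dct2 = reflection about the voxel edge, dct1 = reflection about the voxel center) is an assumption based on the usual meaning of these transforms, not on jitfields' source. The 'zero', 'dst1' and 'dst2' modes are left out because they also change the value (zero or sign flip), not only the index.

```python
def fold_index(i, n, bound):
    """Map any integer index `i` to a valid index in [0, n) (sketch)."""
    if bound == 'dft':
        # Circular wrapping: ... c d | a b c d | a b ...
        return i % n
    if bound == 'replicate':
        # Nearest border: ... a a | a b c d | d d ...
        return min(max(i, 0), n - 1)
    if bound == 'dct2':
        # Reflect about the edge: ... b a | a b c d | d c ...
        i = i % (2 * n)
        return i if i < n else 2 * n - 1 - i
    if bound == 'dct1':
        # Reflect about the center: ... c b | a b c d | c b ...
        if n == 1:
            return 0
        i = i % (2 * n - 2)
        return i if i < n else 2 * n - 2 - i
    raise ValueError(f"unsupported bound: {bound!r}")


for bound in ('dft', 'replicate', 'dct2', 'dct1'):
    print(bound, [fold_index(i, 4, bound) for i in range(-2, 6)])
```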
def push(inp, grid, shape=None, order=2, bound='dct2', extrapolate=True, prefilter=False, out=None): ...
"""Splat a tensor using spline interpolation

Parameters
----------
inp : (..., *inshape, channel) tensor
    Input tensor
grid : (..., *inshape, ndim) tensor
    Tensor of coordinates into the output tensor
shape : sequence[int], default=inshape
    Output spatial shape
order : [sequence of] {0..7}, default=2
    Interpolation order.
bound : [sequence of] {'zero', 'replicate', 'dct1', 'dct2', 'dst1', 'dst2', 'dft'}, default='dct2'
    How to deal with out-of-bound values.
extrapolate : bool or {'center', 'edge'}
    - True: use bound to extrapolate out-of-bound value
    - False or 'center': do not extrapolate values that fall outside
      of the centers of the first and last voxels.
    - 'edge': do not extrapolate values that fall outside
      of the edges of the first and last voxels.
prefilter : bool, default=False
    Whether to compute interpolating coefficients at the end.

Returns
-------
out : (..., *shape, channel) tensor
    Pushed tensor
"""
def count(grid, shape=None, order=2, bound='dct2', extrapolate=True, out=None): ...
"""Splat ones using spline interpolation

Parameters
----------
grid : (..., *inshape, ndim) tensor
    Tensor of coordinates
shape : sequence[int], default=inshape
    Output spatial shape
order : [sequence of] {0..7}, default=2
    Interpolation order.
bound : [sequence of] {'zero', 'replicate', 'dct1', 'dct2', 'dst1', 'dst2', 'dft'}, default='dct2'
    How to deal with out-of-bound values.
extrapolate : bool or {'center', 'edge'}
    - True: use bound to extrapolate out-of-bound value
    - False or 'center': do not extrapolate values that fall outside
      of the centers of the first and last voxels.
    - 'edge': do not extrapolate values that fall outside
      of the edges of the first and last voxels.

Returns
-------
out : (..., *shape) tensor
    Pushed ones
"""
def grad(inp, grid, order=2, bound='dct2', extrapolate=True, prefilter=False, out=None): ...
"""Sample the spatial gradients of a tensor using spline interpolation

Parameters
----------
inp : (..., *inshape, channel) tensor
    Input tensor
grid : (..., *outshape, ndim) tensor
    Tensor of coordinates into `inp`
order : [sequence of] {0..7}, default=2
    Interpolation order.
bound : [sequence of] {'zero', 'replicate', 'dct1', 'dct2', 'dst1', 'dst2', 'dft'}, default='dct2'
    How to deal with out-of-bound values.
extrapolate : bool or {'center', 'edge'}
    - True: use bound to extrapolate out-of-bound value
    - False or 'center': do not extrapolate values that fall outside
      of the centers of the first and last voxels.
    - 'edge': do not extrapolate values that fall outside
      of the edges of the first and last voxels.
prefilter : bool, default=False
    Whether to first compute interpolating coefficients.
    Must be true for proper interpolation, otherwise this
    function merely performs a non-interpolating "spline sampling".

Returns
-------
out : (..., *outshape, channel, ndim) tensor
    Pulled gradients
"""
Compact symmetric (or positive-definite) matrices
def sym_matvec(mat, vec, dtype=None, out=None): ...
"""Matrix-vector product for compact symmetric matrices

`out = mat @ vec`

Parameters
----------
mat : (..., C*(C+1)//2) tensor
    Symmetric matrix with compact storage.
    The matrix should be saved as a vector containing the diagonal
    followed by the rows of the upper triangle.
vec : (..., C) tensor
    Vector
dtype : torch.dtype, optional
    Data type used to carry the computation. By default, same as input.
out : (..., C) tensor, optional
    Output placeholder

Returns
-------
out : (..., C) tensor
    Matrix-vector product
"""
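The compact layout described above (the diagonal first, then the rows of the upper triangle) can be unpacked as in the pure-Python sketch below. The helper names `sym_to_full` and `sym_matvec_ref` are hypothetical and only meant to illustrate the storage order on a single matrix; reading "upper triangle" as the strict upper triangle is an assumption, although it is the only reading consistent with the stated length of C*(C+1)//2.

```python
def sym_to_full(mat, c):
    """Expand a compact symmetric matrix into a full C x C list of lists."""
    assert len(mat) == c * (c + 1) // 2
    full = [[0.0] * c for _ in range(c)]
    # First C entries: the diagonal
    for i in range(c):
        full[i][i] = mat[i]
    # Remaining entries: strict upper triangle, row by row
    k = c
    for i in range(c):
        for j in range(i + 1, c):
            full[i][j] = full[j][i] = mat[k]
            k += 1
    return full


def sym_matvec_ref(mat, vec):
    """Reference `mat @ vec` for one compact matrix and one vector."""
    c = len(vec)
    full = sym_to_full(mat, c)
    return [sum(full[i][j] * vec[j] for j in range(c)) for i in range(c)]


# C = 3: [a00, a11, a22, a01, a02, a12]
mat = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
print(sym_to_full(mat, 3))
print(sym_matvec_ref(mat, [1.0, 1.0, 1.0]))  # [10.0, 12.0, 14.0]
```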
def sym_addmatvec(inp, mat, vec, dtype=None, out=None): ...
"""Add a matrix-vector product for compact symmetric matrices

`out = inp + mat @ vec`

Parameters
----------
inp : (..., C) tensor
    Vector to which the matrix-vector product is added
mat : (..., C*(C+1)//2) tensor
    Symmetric matrix with compact storage.
    The matrix should be saved as a vector containing the diagonal
    followed by the rows of the upper triangle.
vec : (..., C) tensor
    Vector used in the matrix-vector product
dtype : torch.dtype, optional
    Data type used to carry the computation. By default, same as input.
out : (..., C) tensor, optional
    Output placeholder

Returns
-------
out : (..., C) tensor
    Added matrix-vector product
"""
def sym_addmatvec_(inp, mat, vec, dtype=None): ...
"""Inplace add a matrix-vector product for compact symmetric matrices

`inp += mat @ vec`

Parameters
----------
inp : (..., C) tensor
    Vector to which the matrix-vector product is added
mat : (..., C*(C+1)//2) tensor
    Symmetric matrix with compact storage.
    The matrix should be saved as a vector containing the diagonal
    followed by the rows of the upper triangle.
vec : (..., C) tensor
    Vector used in the matrix-vector product
dtype : torch.dtype, optional
    Data type used to carry the computation. By default, same as input.

Returns
-------
inp : (..., C) tensor
    Added matrix-vector product
"""
def sym_submatvec(inp, mat, vec, dtype=None, out=None): ...
"""Subtract a matrix-vector product for compact symmetric matrices

`out = inp - mat @ vec`

Parameters
----------
inp : (..., C) tensor
    Vector from which the matrix-vector product is subtracted
mat : (..., C*(C+1)//2) tensor
    Symmetric matrix with compact storage.
    The matrix should be saved as a vector containing the diagonal
    followed by the rows of the upper triangle.
vec : (..., C) tensor
    Vector used in the matrix-vector product
dtype : torch.dtype, optional
    Data type used to carry the computation. By default, same as input.
out : (..., C) tensor, optional
    Output placeholder

Returns
-------
out : (..., C) tensor
    Subtracted matrix-vector product
"""
def sym_submatvec_(inp, mat, vec, dtype=None): ...
"""Inplace subtract a matrix-vector product for compact symmetric matrices

`inp -= mat @ vec`

Parameters
----------
inp : (..., C) tensor
    Vector from which the matrix-vector product is subtracted
mat : (..., C*(C+1)//2) tensor
    Symmetric matrix with compact storage.
    The matrix should be saved as a vector containing the diagonal
    followed by the rows of the upper triangle.
vec : (..., C) tensor
    Vector used in the matrix-vector product
dtype : torch.dtype, optional
    Data type used to carry the computation. By default, same as input.

Returns
-------
inp : (..., C) tensor
    Subtracted matrix-vector product
"""
def sym_solve(mat, vec, dtype=None, out=None): ...
"""Solve the symmetric linear system

`out = mat.inverse() @ vec`

!! Does not backpropagate through `mat` !!

Parameters
----------
mat : (..., C*(C+1)//2) tensor
    Symmetric matrix with compact storage.
    The matrix should be saved as a vector containing the diagonal
    followed by the rows of the upper triangle.
vec : (..., C) tensor
    Vector
dtype : torch.dtype, optional
    Data type used to carry the computation. By default, same as input.
out : (..., C) tensor, optional
    Output placeholder

Returns
-------
out : (..., C) tensor
    Solution of the linear system
"""
def sym_solve_(mat, vec, dtype=None): ...
"""Solve the symmetric linear system in-place

`vec = mat.inverse() @ vec`

!! Does not backpropagate through `mat` !!

Parameters
----------
mat : (..., C*(C+1)//2) tensor
    Symmetric matrix with compact storage.
    The matrix should be saved as a vector containing the diagonal
    followed by the rows of the upper triangle.
vec : (..., C) tensor
    Vector
dtype : torch.dtype, optional
    Data type used to carry the computation. By default, same as input.

Returns
-------
vec : (..., C) tensor
    Solution of the linear system
"""
def sym_invert(mat, dtype=None, out=None): ...
"""Invert a compact symmetric matrix

`out = mat.inverse()`

!! Does not backpropagate through `mat` !!

Parameters
----------
mat : (..., C*(C+1)//2) tensor
    Symmetric matrix with compact storage.
    The matrix should be saved as a vector containing the diagonal
    followed by the rows of the upper triangle.
dtype : torch.dtype, optional
    Data type used to carry the computation. By default, same as input.
out : (..., C*(C+1)//2) tensor, optional
    Output placeholder

Returns
-------
mat : (..., C*(C+1)//2) tensor
    Inverse matrix
"""
def sym_invert_(mat, dtype=None): ...
"""Invert a compact symmetric matrix in-place

`mat = mat.inverse()`

!! Does not backpropagate through `mat` !!

Parameters
----------
mat : (..., C*(C+1)//2) tensor
    Symmetric matrix with compact storage.
    The matrix should be saved as a vector containing the diagonal
    followed by the rows of the upper triangle.
dtype : torch.dtype, optional
    Data type used to carry the computation. By default, same as input.

Returns
-------
mat : (..., C*(C+1)//2) tensor
    Inverse matrix
"""
Regularisers for dense flow fields
def flow_matvec(
    vel: Tensor,
    weight: Optional[Tensor] = None,
    absolute: float = 0,
    membrane: float = 0,
    bending: float = 0,
    shears: float = 0,
    div: float = 0,
    bound: list[str] = 'dft',
    voxel_size: list[float] = 1,
    out: Optional[Tensor] = None) -> Tensor: ...
"""Apply a spatial regularization matrix.

Parameters
----------
vel : (*batch, *spatial, ndim) tensor
    Input displacement field, in voxels.
weight : (*batch, *spatial) tensor, optional
    Weight map, to spatially modulate the regularization.
absolute : float
    Penalty on absolute values.
membrane : float
    Penalty on first derivatives.
bending : float
    Penalty on second derivatives.
shears : float
    Penalty on local shears.
div : float
    Penalty on local volume changes.
bound : [sequence of] {'zero', 'replicate', 'dct1', 'dct2', 'dst1', 'dst2', 'dft'}, default='dft'
    Boundary conditions.
voxel_size : [sequence of] float
    Voxel size.
out : (*batch, *spatial, ndim) tensor, optional
    Output placeholder

Returns
-------
out : (*batch, *spatial, ndim) tensor
"""

# We also implement variants that add to or subtract from an input tensor
def flow_matvec_add(inp: Tensor, ...): ...
def flow_matvec_add_(inp: Tensor, ...): ...
def flow_matvec_sub(inp: Tensor, ...): ...
def flow_matvec_sub_(inp: Tensor, ...): ...
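To give a feel for what applying such a regularization matrix means, the pure-Python sketch below handles a 1D scalar field with circular ('dft') boundaries and only the `absolute` and `membrane` penalties. The helper name `membrane_matvec_1d` is hypothetical, and the discretization (the membrane term as the matrix of squared first-order finite differences, scaled by the inverse squared voxel size) is an assumption; the exact stencils and scaling used by jitfields may differ.

```python
def membrane_matvec_1d(vel, absolute=0.0, membrane=0.0, vx=1.0):
    """Apply L = absolute * I + membrane * D'D to a 1D field (sketch).

    D is the forward finite-difference operator with circular boundary,
    so (D'D v)[i] = (2 v[i] - v[i-1] - v[i+1]) / vx**2.
    """
    n = len(vel)
    out = []
    for i in range(n):
        lap = 2 * vel[i] - vel[(i - 1) % n] - vel[(i + 1) % n]
        out.append(absolute * vel[i] + membrane * lap / (vx * vx))
    return out


# A constant field has no first derivatives, so it is not penalised
print(membrane_matvec_1d([3.0, 3.0, 3.0, 3.0], membrane=2.0))
# An impulse reveals one row of the matrix (the convolution kernel)
print(membrane_matvec_1d([0.0, 1.0, 0.0, 0.0], absolute=0.5, membrane=1.0))
```

With circular boundaries the matrix is a convolution, which is why an impulse response like the second call is enough to describe it.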
def flow_kernel(
    shape: list[int],
    absolute: float = 0,
    membrane: float = 0,
    bending: float = 0,
    shears: float = 0,
    div: float = 0,
    bound: list[str] = 'dft',
    voxel_size: list[float] = 1,
    out: Optional[Tensor] = None) -> Tensor: ...
"""Return the kernel of a Toeplitz regularization matrix.

Parameters
----------
shape : int or list[int]
    Number of spatial dimensions or shape of the tensor
absolute : float
    Penalty on absolute values.
membrane : float
    Penalty on first derivatives.
bending : float
    Penalty on second derivatives.
shears : float
    Penalty on local shears. Linear elastic energy's `mu`.
div : float
    Penalty on local volume changes. Linear elastic energy's `lambda`.
bound : [sequence of] {'zero', 'replicate', 'dct1', 'dct2', 'dst1', 'dst2', 'dft'}, default='dft'
    Boundary conditions.
voxel_size : [sequence of] float
    Voxel size.
out : (*shape, ndim, [ndim]) tensor, optional
    Output placeholder

Returns
-------
out : (*shape, ndim, [ndim]) tensor
    Convolution kernel.
    A matrix of kernels ([ndim, ndim]) if `shears` or `div`,
    else a vector of kernels ([ndim]).
"""

# We also implement variants that add to or subtract from an input tensor
def flow_kernel_add(inp: Tensor, ...): ...
def flow_kernel_add_(inp: Tensor, ...): ...
def flow_kernel_sub(inp: Tensor, ...): ...
def flow_kernel_sub_(inp: Tensor, ...): ...
def flow_diag(
    shape: list[int],
    weight: Optional[Tensor] = None,
    absolute: float = 0,
    membrane: float = 0,
    bending: float = 0,
    shears: float = 0,
    div: float = 0,
    bound: list[str] = 'dft',
    voxel_size: list[float] = 1,
    out: Optional[Tensor] = None) -> Tensor: ...
"""Return the diagonal of a regularization matrix.

Parameters
----------
shape : list[int]
    Shape of the tensor
weight : (*batch, *spatial) tensor, optional
    Weight map, to spatially modulate the regularization.
absolute : float
    Penalty on absolute values.
membrane : float
    Penalty on first derivatives.
bending : float
    Penalty on second derivatives.
shears : float
    Penalty on local shears.
div : float
    Penalty on local volume changes.
bound : [sequence of] {'zero', 'replicate', 'dct1', 'dct2', 'dst1', 'dst2', 'dft'}, default='dft'
    Boundary conditions.
voxel_size : [sequence of] float
    Voxel size.
out : (*batch, *spatial, ndim) tensor, optional
    Output placeholder

Returns
-------
out : (*batch, *spatial, ndim) tensor
"""

# We also implement variants that add to or subtract from an input tensor
def flow_diag_add(inp: Tensor, ...): ...
def flow_diag_add_(inp: Tensor, ...): ...
def flow_diag_sub(inp: Tensor, ...): ...
def flow_diag_sub_(inp: Tensor, ...): ...
def flow_relax_(
    vel: Tensor,
    hes: Tensor,
    grd: Tensor,
    weight: Optional[Tensor] = None,
    absolute: float = 0,
    membrane: float = 0,
    bending: float = 0,
    shears: float = 0,
    div: float = 0,
    bound: list[str] = 'dft',
    voxel_size: list[float] = 1,
    nb_iter: int = 1,
) -> Tensor: ...
"""Perform relaxation iterations.

Parameters
----------
vel : (*batch, *spatial, ndim) tensor
    Warm start.
hes : (*batch, *spatial, ndim*(ndim+1)//2) tensor
    Input symmetric Hessian, in voxels.
grd : (*batch, *spatial, ndim) tensor
    Input gradient, in voxels.
weight : (*batch, *spatial) tensor, optional
    Weight map, to spatially modulate the regularization.
absolute : float
    Penalty on absolute values.
membrane : float
    Penalty on first derivatives.
bending : float
    Penalty on second derivatives.
shears : float
    Penalty on local shears.
div : float
    Penalty on local volume changes.
bound : [sequence of] {'zero', 'replicate', 'dct1', 'dct2', 'dst1', 'dst2', 'dft'}, default='dft'
    Boundary conditions.
voxel_size : [sequence of] float
    Voxel size.
nb_iter : int
    Number of iterations

Returns
-------
vel : (*batch, *spatial, ndim) tensor
"""
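As an illustration of what a relaxation iteration does, the sketch below runs Gauss-Seidel sweeps on a 1D scalar problem with circular boundaries. It assumes that the system being relaxed is `(H + L) v = g`, with `H` the (here diagonal) Hessian, `g` the gradient and `L` the regularization matrix restricted to the `absolute` and `membrane` terms. That reading, the update order and the helper name `relax_1d` are assumptions made for illustration; the scheme actually used by jitfields may differ.

```python
def relax_1d(vel, hes, grd, absolute=0.0, membrane=0.0, nb_iter=1):
    """Gauss-Seidel sweeps for (diag(hes) + L) vel = grd, in place (sketch).

    L = absolute * I + membrane * D'D with circular boundary, i.e.
    (L v)[i] = (absolute + 2 * membrane) * v[i] - membrane * (v[i-1] + v[i+1]).
    """
    n = len(vel)
    for _ in range(nb_iter):
        for i in range(n):
            neighbours = vel[(i - 1) % n] + vel[(i + 1) % n]
            diag = hes[i] + absolute + 2 * membrane
            # Solve row i for vel[i], using the latest neighbour values
            vel[i] = (grd[i] + membrane * neighbours) / diag
    return vel


hes = [1.0, 1.0, 1.0, 1.0]
grd = [1.0, 2.0, 3.0, 4.0]
vel = relax_1d([0.0] * 4, hes, grd, membrane=1.0, nb_iter=200)

# Residual of the linear system after relaxation
residual = [
    (hes[i] + 2.0) * vel[i] - (vel[(i - 1) % 4] + vel[(i + 1) % 4]) - grd[i]
    for i in range(4)
]
```

Because each sweep reuses the current `vel`, passing a good warm start reduces the number of iterations needed.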
def flow_precond(
    mat: Tensor,
    vec: Tensor,
    weight: Optional[Tensor] = None,
    absolute: float = 0,
    membrane: float = 0,
    bending: float = 0,
    shears: float = 0,
    div: float = 0,
    bound: list[str] = 'dft',
    voxel_size: list[float] = 1,
    out: Optional[Tensor] = None) -> Tensor: ...
"""Apply the preconditioning `(M + diag(R)) \ v`

Parameters
----------
mat : (*batch, *spatial, DD) tensor
    DD == 1 | D | D*(D+1)//2 | D*D
    Preconditioning matrix `M`
vec : (*batch, *spatial, D) tensor
    Point `v` at which to solve the system.
weight : (*batch, *spatial) tensor, optional
    Regularization weight map.
absolute : float
    Penalty on absolute values.
membrane : float
    Penalty on first derivatives.
bending : float
    Penalty on second derivatives.
shears : float
    Penalty on local shears.
div : float
    Penalty on local volume changes.
bound : [sequence of] {'zero', 'replicate', 'dct1', 'dct2', 'dst1', 'dst2', 'dft'}, default='dft'
    Boundary conditions.
voxel_size : [sequence of] float
    Voxel size.
out : (*batch, *spatial, D) tensor
    Output placeholder.

Returns
-------
out : (*batch, *spatial, D) tensor
    Preconditioned vector.
"""
def flow_forward(
    mat: Tensor,
    vec: Tensor,
    weight: Optional[Tensor] = None,
    absolute: float = 0,
    membrane: float = 0,
    bending: float = 0,
    shears: float = 0,
    div: float = 0,
    bound: list[str] = 'dft',
    voxel_size: list[float] = 1,
    out: Optional[Tensor] = None) -> Tensor: ...
"""Apply the forward matrix-vector product `(M + R) @ v`

Parameters
----------
mat : (*batch, *spatial, DD) tensor
    DD == 1 | D | D*(D+1)//2 | D*D
    Matrix `M`
vec : (*batch, *spatial, D) tensor
    Vector `v` to which the matrix is applied.
weight : (*batch, *spatial) tensor, optional
    Regularization weight map.
absolute : float
    Penalty on absolute values.
membrane : float
    Penalty on first derivatives.
bending : float
    Penalty on second derivatives.
shears : float
    Penalty on local shears.
div : float
    Penalty on local volume changes.
bound : [sequence of] {'zero', 'replicate', 'dct1', 'dct2', 'dst1', 'dst2', 'dft'}, default='dft'
    Boundary conditions.
voxel_size : [sequence of] float
    Voxel size.
out : (*batch, *spatial, D) tensor
    Output placeholder.

Returns
-------
out : (*batch, *spatial, D) tensor
    Matrix-vector product.
"""