
abhijeetgangan/torch_matfunc


torch_matfunc

A collection of PyTorch matrix functions.

License: MIT. Requires Python 3.10+.

Implemented functions

  • expm_frechet: Matrix exponential and its Fréchet derivative.
  • matrix_log_33: Analytical matrix logarithm for a 3x3 matrix.
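
`matrix_log_33` computes the logarithm analytically. As a point of comparison, and not the library's formula, the logarithm of a symmetric positive-definite 3x3 matrix can be sketched from its eigendecomposition, log(A) = Q diag(log λ) Qᵀ. The helper `spd_logm` is hypothetical and only handles the SPD case; exponentiating its result with `torch.linalg.matrix_exp` should recover the input:

```python
import torch


def spd_logm(A: torch.Tensor) -> torch.Tensor:
    """Matrix logarithm of a symmetric positive-definite matrix via eigh.

    Hypothetical helper for illustration only; it is not torch_matfunc's
    analytical matrix_log_33 and does not handle general matrices.
    """
    eigvals, eigvecs = torch.linalg.eigh(A)
    if (eigvals <= 0).any():
        raise ValueError("A must be symmetric positive definite")
    # log(A) = Q diag(log(lambda)) Q^T
    return eigvecs @ torch.diag(torch.log(eigvals)) @ eigvecs.mT


# Symmetric and strictly diagonally dominant with a positive diagonal, hence SPD
A = torch.tensor([[4.0, 1.0, 0.0],
                  [1.0, 3.0, 0.5],
                  [0.0, 0.5, 2.0]], dtype=torch.float64)
log_A = spd_logm(A)

# exp(log(A)) should reproduce A up to round-off
print(torch.allclose(torch.linalg.matrix_exp(log_A), A))  # True
```

An eigendecomposition is the simplest route for SPD input; an analytical 3x3 formula avoids the iterative eigensolver, which is presumably why a dedicated `matrix_log_33` exists.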

Example Usage

Matrix exponential and its Fréchet derivative
```python
import torch
from torch_matfunc.matrix.expm_frechet import expm_frechet
from scipy.linalg import expm_frechet as scipy_expm_frechet

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float64

A = torch.tensor([[1, 2], [5, 6]], dtype=dtype, device=device)
E = torch.tensor([[3, 4], [7, 8]], dtype=dtype, device=device)

# SciPy works on NumPy arrays, which must live on the CPU
A_numpy = A.cpu().numpy()
E_numpy = E.cpu().numpy()

# Compute exp(A) and its Fréchet derivative L(A, E); use new names so the
# imported expm_frechet function is not shadowed
expm_A, frechet_A = expm_frechet(A, E, method="SPS", compute_expm=True)
expm_scipy, frechet_scipy = scipy_expm_frechet(A_numpy, E_numpy, method="SPS", compute_expm=True)

# Compare with SciPy
assert torch.allclose(expm_A.cpu(), torch.tensor(expm_scipy))
assert torch.allclose(frechet_A.cpu(), torch.tensor(frechet_scipy))
```
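
The Fréchet derivative L(A, E) is the directional derivative d/dt exp(A + tE) at t = 0. To cross-check it without SciPy, one option is the block-triangular identity exp([[A, E], [0, A]]) = [[exp(A), L(A, E)], [0, exp(A)]], evaluated with PyTorch's built-in `torch.linalg.matrix_exp`. This is a sketch of a different technique, not how the `"SPS"` method works; `frechet_block` is a hypothetical helper, and the result is compared against a central finite difference:

```python
import torch


def frechet_block(A: torch.Tensor, E: torch.Tensor) -> torch.Tensor:
    """Fréchet derivative of the matrix exponential via a 2n x 2n block matrix.

    Hypothetical helper: exp([[A, E], [0, A]]) = [[exp(A), L(A, E)], [0, exp(A)]].
    """
    n = A.shape[-1]
    top = torch.cat([A, E], dim=-1)
    bottom = torch.cat([torch.zeros_like(A), A], dim=-1)
    big = torch.linalg.matrix_exp(torch.cat([top, bottom], dim=-2))
    # The upper-right n x n block holds L(A, E)
    return big[..., :n, n:]


A = torch.tensor([[1.0, 2.0], [5.0, 6.0]], dtype=torch.float64)
E = torch.tensor([[3.0, 4.0], [7.0, 8.0]], dtype=torch.float64)
L = frechet_block(A, E)

# Central finite difference of t -> exp(A + tE) at t = 0
h = 1e-5
fd = (torch.linalg.matrix_exp(A + h * E) - torch.linalg.matrix_exp(A - h * E)) / (2 * h)
print(torch.allclose(L, fd, rtol=1e-6))
```

The block approach costs one exponential of a matrix twice the size, so it is convenient for testing but usually slower than a dedicated Fréchet algorithm.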

Matrix exponential and its Fréchet derivative (autograd)

```python
import torch
from torch_matfunc.matrix.expm_frechet import expm
from scipy.linalg import expm_frechet as scipy_expm_frechet

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float64

A = torch.tensor([[1, 2], [5, 6]], dtype=dtype, device=device, requires_grad=True)
E = torch.tensor([[3, 4], [7, 8]], dtype=dtype, device=device)

# Detach before converting: NumPy cannot hold a tensor that requires grad
A_numpy = A.detach().cpu().numpy()
E_numpy = E.cpu().numpy()

# Compute the matrix exponential through the autograd Function
# (a new name keeps the imported expm class from being shadowed)
expm_A = expm.apply(A)

# Backpropagate E through expm to obtain the derivative with respect to A
frechet_A = torch.autograd.grad(expm_A, A, grad_outputs=E)[0]
expm_scipy, frechet_scipy = scipy_expm_frechet(A_numpy, E_numpy, method="SPS", compute_expm=True)

# Compare with SciPy
assert torch.allclose(expm_A.cpu(), torch.tensor(expm_scipy))
assert torch.allclose(frechet_A.cpu(), torch.tensor(frechet_scipy))
```
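
`torch.autograd.grad(output, input, grad_outputs=E)` returns a vector-Jacobian product, i.e. the adjoint of the Fréchet derivative applied to E. For PyTorch's built-in `torch.linalg.matrix_exp` and a real A, that adjoint is L(Aᵀ, E), which coincides with L(A, E) when A is symmetric. The sketch below checks this with the block-matrix identity; `frechet_block` is a hypothetical helper, and the example says nothing about how torch_matfunc's `expm` implements its backward pass:

```python
import torch


def frechet_block(A: torch.Tensor, E: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: L(A, E) read off exp([[A, E], [0, A]])."""
    n = A.shape[-1]
    block = torch.cat([torch.cat([A, E], dim=-1),
                       torch.cat([torch.zeros_like(A), A], dim=-1)], dim=-2)
    return torch.linalg.matrix_exp(block)[..., :n, n:]


A = torch.tensor([[1.0, 2.0], [5.0, 6.0]], dtype=torch.float64, requires_grad=True)
E = torch.tensor([[3.0, 4.0], [7.0, 8.0]], dtype=torch.float64)

# Vector-Jacobian product of exp at A, seeded with E
vjp = torch.autograd.grad(torch.linalg.matrix_exp(A), A, grad_outputs=E)[0]

with torch.no_grad():
    L_adjoint = frechet_block(A.T, E)  # L(A^T, E)
    L_forward = frechet_block(A, E)    # L(A, E)

print(torch.allclose(vjp, L_adjoint))  # True
print(torch.allclose(vjp, L_forward))  # False: A is not symmetric
```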
