Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Typehinting #180

Merged
merged 12 commits into from
Jun 15, 2024
6 changes: 4 additions & 2 deletions examples/bayesopt_example.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import warnings

warnings.filterwarnings("ignore")
@@ -22,7 +24,7 @@
from torch import nn, optim
from torch.nn import functional as F

from laplace import Laplace
from laplace import BaseLaplace, Laplace


class LaplaceBNN(Model):
@@ -35,7 +37,7 @@ def __init__(
self,
train_X: torch.Tensor,
train_Y: torch.Tensor,
bnn: Laplace = None,
bnn: BaseLaplace | None = None,
likelihood: str = "regression",
batch_size: int = 1024,
):
22 changes: 18 additions & 4 deletions laplace/__init__.py
Original file line number Diff line number Diff line change
@@ -6,9 +6,6 @@
.. include:: ../examples/reward_modeling_example.md
"""

REGRESSION = "regression"
CLASSIFICATION = "classification"

from laplace.baselaplace import (
BaseLaplace,
DiagLaplace,
@@ -21,6 +18,15 @@
from laplace.lllaplace import DiagLLLaplace, FullLLLaplace, KronLLLaplace, LLLaplace
from laplace.marglik_training import marglik_training
from laplace.subnetlaplace import DiagSubnetLaplace, FullSubnetLaplace, SubnetLaplace
from laplace.utils.enums import (
HessianStructure,
Likelihood,
LinkApprox,
PredType,
PriorStructure,
SubsetOfWeights,
TuningMethod,
)

__all__ = [
"Laplace", # direct access to all Laplace classes via unified interface
@@ -38,4 +44,12 @@
"FullSubnetLaplace",
"DiagSubnetLaplace", # subnetwork
"marglik_training",
] # methods
# Enums
"SubsetOfWeights",
"HessianStructure",
"Likelihood",
"PredType",
"LinkApprox",
"TuningMethod",
"PriorStructure",
]
Loading
Loading