Source code for deeptab.distributions.registry
"""Distribution registry: maps family name strings to distribution classes."""
from __future__ import annotations
from deeptab.core.exceptions import InvalidParamError, invalid_param_error
from .base import BaseDistribution
from .beta import BetaDistribution, DirichletDistribution
from .categorical import CategoricalDistribution, MultinomialDistribution, Quantile
from .gamma import GammaDistribution, InverseGammaDistribution
from .mixture import MixtureOfGaussiansDistribution
from .negative_binomial import NegativeBinomialDistribution
from .normal import LogNormalDistribution, NormalDistribution
from .poisson import PoissonDistribution, ZeroInflatedPoissonDistribution
from .student_t import JohnsonSuDistribution, StudentTDistribution
from .tweedie import TweedieDistribution
# Lookup table from family name to the distribution class that implements it.
# Entries are grouped by the submodule the class is imported from; the
# insertion order is kept as-is.
DISTRIBUTION_REGISTRY: dict[str, type[BaseDistribution]] = {
    # .normal
    "normal": NormalDistribution,
    "lognormal": LogNormalDistribution,
    # .poisson
    "poisson": PoissonDistribution,
    "zip": ZeroInflatedPoissonDistribution,
    # .gamma
    "gamma": GammaDistribution,
    "inversegamma": InverseGammaDistribution,
    # .beta
    "beta": BetaDistribution,
    "dirichlet": DirichletDistribution,
    # .student_t
    "studentt": StudentTDistribution,
    "johnsonsu": JohnsonSuDistribution,
    # .negative_binomial
    "negativebinom": NegativeBinomialDistribution,
    # .categorical
    "categorical": CategoricalDistribution,
    "multinomial": MultinomialDistribution,
    "quantile": Quantile,
    # .tweedie
    "tweedie": TweedieDistribution,
    # .mixture
    "mog": MixtureOfGaussiansDistribution,
}
[docs]
def get_distribution(family: str, **kwargs: object) -> BaseDistribution:
    """Build a distribution object from its registry key.

    Parameters
    ----------
    family : str
        Key into ``DISTRIBUTION_REGISTRY`` (e.g. ``"normal"``, ``"gamma"``).
    **kwargs
        Keyword arguments handed unchanged to the constructor of the
        selected distribution class (e.g. ``quantiles=[0.1, 0.5, 0.9]``
        for ``"quantile"``).

    Returns
    -------
    BaseDistribution
        A new instance of the class registered under *family*.

    Raises
    ------
    InvalidParamError
        If *family* is not a key of ``DISTRIBUTION_REGISTRY``. The sorted
        list of registered keys is passed along to the error helper.
    """
    if family in DISTRIBUTION_REGISTRY:
        distribution_cls = DISTRIBUTION_REGISTRY[family]
        # Constructor signatures differ between families, so the forwarded
        # kwargs cannot be checked statically.
        return distribution_cls(**kwargs)  # type: ignore[call-arg]

    # Unknown key: report it together with every registered family name.
    raise invalid_param_error(
        "MambularLSS / LSS model",
        "family",
        family,
        "must be a registered distribution family name",
        sorted(DISTRIBUTION_REGISTRY),
    )