Source code for deeptab.distributions.beta

"""Beta and Dirichlet distributions for bounded / compositional LSS models."""

import torch
import torch.distributions as dist

from .base import BaseDistribution


[docs] class BetaDistribution(BaseDistribution): """ Represents a Beta distribution, a continuous distribution defined on the interval [0, 1], commonly used in Bayesian statistics for modeling probabilities. This class extends BaseDistribution and includes parameter transformation and loss computation specific to the Beta distribution. Parameters ---------- name (str): The name of the distribution, defaulted to "Beta". shape_transform (str or callable): Transformation for the alpha (shape) parameter to ensure it remains positive. scale_transform (str or callable): Transformation for the beta (scale) parameter to ensure it remains positive. """ def __init__( self, name="Beta", shape_transform="positive", scale_transform="positive", ): param_names = [ "alpha", "beta", ] super().__init__(name, param_names) self.alpha_transform = self.get_transform(shape_transform) self.beta_transform = self.get_transform(scale_transform)
[docs] def compute_loss(self, predictions, y_true): alpha = self.alpha_transform(predictions[:, self.param_names.index("alpha")]) beta = self.beta_transform(predictions[:, self.param_names.index("beta")]) beta_dist = dist.Beta(alpha, beta) nll = -beta_dist.log_prob(y_true).mean() return nll
[docs] class DirichletDistribution(BaseDistribution): """ Represents a Dirichlet distribution, a multivariate generalization of the Beta distribution. It is commonly used in Bayesian statistics for modeling multinomial distribution probabilities. This class extends BaseDistribution and includes parameter transformation and loss computation specific to the Dirichlet distribution. Parameters ---------- name (str): The name of the distribution, defaulted to "Dirichlet". concentration_transform (str or callable): Transformation to apply to concentration parameters to ensure they remain positive. """ def __init__(self, name="Dirichlet", concentration_transform="positive"): param_names = ["concentration"] super().__init__(name, param_names) self.concentration_transform = self.get_transform(concentration_transform)
[docs] def compute_loss(self, predictions, y_true): concentration = self.concentration_transform(predictions) dirichlet_dist = dist.Dirichlet(concentration) nll = -dirichlet_dist.log_prob(y_true).mean() return nll