Source code for deeptab.distributions.negative_binomial

"""Negative Binomial distribution for overdispersed count LSS models."""

import torch
import torch.distributions as dist

from .base import BaseDistribution


class NegativeBinomialDistribution(BaseDistribution):
    """Negative Binomial distribution for overdispersed count targets.

    The distribution is parameterized by a mean ``mu`` and a dispersion
    ``alpha``. ``compute_loss`` maps these onto the ``(total_count, probs)``
    parameters of :class:`torch.distributions.NegativeBinomial` using
    ``total_count = 1 / alpha`` and ``probs = mu / (total_count + mu)``, which
    gives ``E[y] = mu`` and ``Var[y] = mu + alpha * mu**2``.

    Parameters
    ----------
    name : str, default="NegativeBinomial"
        The name of the distribution.
    mean_transform : str or callable, default="positive"
        Transformation applied to the raw mean prediction so that it remains
        positive. Resolved through ``self.get_transform``.
    dispersion_transform : str or callable, default="positive"
        Transformation applied to the raw dispersion prediction so that it
        remains positive. Resolved through ``self.get_transform``.
    """

    def __init__(
        self,
        name="NegativeBinomial",
        mean_transform="positive",
        dispersion_transform="positive",
    ):
        # Column order of the prediction tensor follows this list.
        param_names = ["mean", "dispersion"]
        super().__init__(name, param_names)

        self.mean_transform = self.get_transform(mean_transform)
        self.dispersion_transform = self.get_transform(dispersion_transform)

    def compute_loss(self, predictions: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        """Return the mean negative log-likelihood of ``y_true``.

        Parameters
        ----------
        predictions : torch.Tensor
            Raw (untransformed) model outputs. Indexed as ``predictions[:, i]``
            with one column per entry of ``self.param_names``.
        y_true : torch.Tensor
            Observed counts.
            NOTE(review): the per-sample parameters are 1-D after column
            indexing, so ``y_true`` is presumably expected to be 1-D as well;
            a trailing singleton dimension would broadcast - confirm against
            the caller.

        Returns
        -------
        torch.Tensor
            Scalar tensor holding the negative log-likelihood averaged over
            all elements.
        """
        mean = self.mean_transform(predictions[:, self.param_names.index("mean")])
        dispersion = self.dispersion_transform(predictions[:, self.param_names.index("dispersion")])

        # total_count r = 1 / dispersion, so Var[y] = mean + dispersion * mean^2.
        total_count = torch.reciprocal(dispersion)

        # torch's NegativeBinomial has mean = total_count * probs / (1 - probs),
        # so the probability that reproduces `mean` is mean / (r + mean).
        # Using r / (r + mean) here would fit a distribution with mean r^2 / mean.
        probs = mean / (total_count + mean)

        negative_binomial_dist = dist.NegativeBinomial(total_count=total_count, probs=probs)
        return -negative_binomial_dist.log_prob(y_true).mean()