Source code for deeptab.distributions.categorical
"""Categorical, Quantile, and Multinomial distributions for multi-class / distribution-free LSS models."""
import torch
import torch.distributions as dist
from .base import BaseDistribution
[docs]
class CategoricalDistribution(BaseDistribution):
    """
    Categorical distribution over K mutually exclusive categories.

    The raw model output is mapped to category probabilities by the configured
    transform, and the training loss is the mean negative log-likelihood of the
    observed labels under ``torch.distributions.Categorical``.

    Parameters
    ----------
    name (str): The name of the distribution, defaulted to "Categorical".
    prob_transform (str or callable): Transformation applied to the predictions
        so that they form valid probabilities (i.e., non-negative and summing
        to 1). It is resolved through ``get_transform`` of the base class.
    """

    def __init__(self, name="Categorical", prob_transform="probabilities"):
        # A single parameter group named "probs" is registered with the base class.
        super().__init__(name, ["probs"])
        self.probs_transform = self.get_transform(prob_transform)

    def compute_loss(self, predictions, y_true):
        """
        Compute the mean negative log-likelihood of the observed labels.

        Parameters
        ----------
        predictions: Raw model outputs; they are passed through
            ``self.probs_transform`` to obtain the category probabilities.
        y_true: Observed labels, evaluated with ``Categorical.log_prob``
            (presumably integer class indices -- confirm against the caller).

        Returns
        -------
        torch.Tensor: The negated mean of the per-sample log-probabilities.
        """
        class_probs = self.probs_transform(predictions)
        log_likelihood = dist.Categorical(probs=class_probs).log_prob(y_true)
        return -log_likelihood.mean()
[docs]
class Quantile(BaseDistribution):
    """
    Quantile Regression Loss class.

    This class computes the quantile loss (also known as pinball loss) for a set of quantiles.
    It is used to handle quantile regression tasks where we aim to predict a given quantile
    of the target distribution.

    Parameters
    ----------
    name : str, optional
        The name of the distribution, by default "Quantile".
    quantiles : iterable of float, optional
        The quantiles to be used for computing the loss. When omitted (``None``),
        ``[0.25, 0.5, 0.75]`` is used.

    Attributes
    ----------
    quantiles : list of float
        List of quantiles for which the pinball loss is computed. Each instance
        owns its own list.

    Methods
    -------
    compute_loss(predictions, y_true)
        Computes the quantile regression loss between the predictions and true values.
    """

    def __init__(self, name="Quantile", quantiles=None):
        # A list literal as the default would be created once and shared by every
        # instance that relies on it; build a fresh list per instance instead.
        # Caller-supplied iterables are copied so later outside mutation cannot
        # desynchronise ``self.quantiles`` from the registered parameter names.
        quantiles = [0.25, 0.5, 0.75] if quantiles is None else list(quantiles)
        param_names = [f"q_{q}" for q in quantiles]
        super().__init__(name, param_names)
        self.quantiles = quantiles

    def compute_loss(self, predictions, y_true):
        """
        Compute the pinball loss summed over quantiles and averaged over the batch.

        Parameters
        ----------
        predictions : torch.Tensor
            Predicted quantiles; column ``i`` is compared against ``self.quantiles[i]``.
        y_true : torch.Tensor
            Target values. A column tensor of shape ``(batch, 1)`` is flattened to
            ``(batch,)`` before the comparison.

        Returns
        -------
        torch.Tensor
            The mean over samples of the per-sample sum of quantile losses.

        Raises
        ------
        ValueError
            If ``y_true`` requires gradients, or if the first dimension of
            ``predictions`` and ``y_true`` differ.
        """
        if y_true.requires_grad:
            raise ValueError("y_true should not require gradients")
        if predictions.size(0) != y_true.size(0):
            raise ValueError("Batch size of predictions and y_true must match")

        # ``predictions[:, i]`` has shape (batch,). Subtracting it from a
        # (batch, 1) target would broadcast to (batch, batch) and silently
        # average the loss over all sample pairs, so drop the singleton axis.
        if y_true.dim() == 2 and y_true.size(1) == 1:
            y_true = y_true.squeeze(1)

        losses = []
        for i, q in enumerate(self.quantiles):
            errors = y_true - predictions[:, i]
            # Pinball loss: q * e when e >= 0, (q - 1) * e otherwise.
            losses.append(torch.max((q - 1) * errors, q * errors))

        # Sum the losses across quantiles per sample, then average over samples.
        return torch.mean(torch.stack(losses, dim=1).sum(dim=1))
[docs]
class MultinomialDistribution(BaseDistribution):
    """
    Multinomial distribution for count vectors that sum to a known total
    (e.g. word counts per document, allele frequencies, multi-label counts
    where the total number of responses per sample is fixed).

    The neural network outputs ``num_classes`` logits which are converted to
    probabilities via softmax. ``total_count`` is a fixed constructor argument,
    not a predicted parameter.

    Parameters
    ----------
    name (str): Defaults to ``"Multinomial"``.
    num_classes (int): Number of categories K; one parameter name ``p_k`` is
        registered per category. Defaults to ``2``.
    total_count (int): Total number of trials n (e.g. 1 makes this equivalent
        to Categorical). Defaults to ``1``.
    prob_transform (str or callable): Transform for the class logits.
        Defaults to ``"probabilities"`` (softmax).
    """

    def __init__(
        self,
        name="Multinomial",
        num_classes=2,
        total_count=1,
        prob_transform="probabilities",
    ):
        # Register the parameter names p_0 ... p_{K-1} with the base class.
        super().__init__(name, [f"p_{k}" for k in range(num_classes)])
        self.total_count = total_count
        self.probs_transform = self.get_transform(prob_transform)

    def compute_loss(self, predictions, y_true):
        """
        Compute the mean negative log-likelihood of the observed count vectors.

        Parameters
        ----------
        predictions: Raw model outputs; they are passed through
            ``self.probs_transform`` to obtain the class probabilities.
        y_true: Observed counts, evaluated with ``Multinomial.log_prob``
            (presumably one count per class for each sample -- confirm
            against the caller).

        Returns
        -------
        torch.Tensor: The negated mean of the per-sample log-probabilities.
        """
        class_probs = self.probs_transform(predictions)
        count_dist = dist.Multinomial(total_count=self.total_count, probs=class_probs)
        log_likelihood = count_dist.log_prob(y_true)
        return -log_likelihood.mean()