Source code for deeptab.distributions.poisson

"""Poisson and Zero-Inflated Poisson distributions for count data LSS models."""

import numpy as np
import torch
import torch.distributions as dist

from .base import BaseDistribution


[docs] class PoissonDistribution(BaseDistribution): """ Represents a Poisson distribution, typically used for modeling count data or the number of events occurring within a fixed interval of time or space. This class extends the BaseDistribution and includes parameter transformation and loss computation specific to the Poisson distribution. Parameters ---------- name (str): The name of the distribution, defaulted to "Poisson". rate_transform (str or callable): Transformation to apply to the rate parameter to ensure it remains positive. """ def __init__(self, name="Poisson", rate_transform="positive"): param_names = ["rate"] super().__init__(name, param_names) self.rate_transform = self.get_transform(rate_transform)
[docs] def compute_loss(self, predictions, y_true): rate = self.rate_transform(predictions[:, self.param_names.index("rate")]) poisson_dist = dist.Poisson(rate) nll = -poisson_dist.log_prob(y_true).mean() return nll
[docs] def evaluate_nll(self, y_true, y_pred): metrics = super().evaluate_nll(y_true, y_pred) y_true_tensor = torch.tensor(y_true, dtype=torch.float32) y_pred_tensor = torch.tensor(y_pred, dtype=torch.float32) rate = self.rate_transform(y_pred_tensor[:, self.param_names.index("rate")]) mse_loss = torch.nn.functional.mse_loss(y_true_tensor, rate) # type: ignore rmse = np.sqrt(mse_loss.detach().numpy()) mae = ( torch.nn.functional.l1_loss(y_true_tensor, rate) # type: ignore .detach() .numpy() # type: ignore ) # type: ignore poisson_deviance = 2 * torch.sum(y_true_tensor * torch.log(y_true_tensor / rate) - (y_true_tensor - rate)) # type: ignore[operator] metrics["mse"] = mse_loss.detach().numpy() metrics["mae"] = mae metrics["rmse"] = rmse metrics["poisson_deviance"] = poisson_deviance.detach().numpy() return metrics
[docs] class ZeroInflatedPoissonDistribution(BaseDistribution): """ Represents a Zero-Inflated Poisson (ZIP) distribution for count data with excess zeros (e.g. number of insurance claims, rare-event counts). The model outputs two parameters: * **pi** — zero-inflation probability π ∈ (0, 1). Extra zeros arise with probability pi; with probability (1 - pi) the count follows Poisson(rate). * **rate** — Poisson rate λ > 0. The mixture probability mass function is: .. math:: P(Y = 0) &= \\pi + (1 - \\pi)\\,e^{-\\lambda} \\\\ P(Y = k>0) &= (1 - \\pi)\\,\\text{Poisson}(k;\\,\\lambda) Parameters ---------- name (str): Defaults to ``"ZeroInflatedPoisson"``. pi_transform (str or callable): Transform for the inflation probability. Defaults to ``"sigmoid"`` to map logits → (0, 1). rate_transform (str or callable): Transform for the Poisson rate. Defaults to ``"positive"`` (softplus). """ def __init__( self, name="ZeroInflatedPoisson", pi_transform="sigmoid", rate_transform="positive", ): param_names = ["pi", "rate"] super().__init__(name, param_names) self.pi_transform = self.get_transform(pi_transform) self.rate_transform = self.get_transform(rate_transform)
[docs] def compute_loss(self, predictions, y_true): pi = self.pi_transform(predictions[:, self.param_names.index("pi")]) rate = self.rate_transform(predictions[:, self.param_names.index("rate")]) # log P(Y=0) = log(pi + (1-pi)*exp(-rate)) log_zero = torch.log(pi + (1.0 - pi) * torch.exp(-rate) + 1e-8) # log P(Y=k>0) = log(1-pi) + Poisson log-prob log_nonzero = torch.log(1.0 - pi + 1e-8) + dist.Poisson(rate).log_prob(y_true) log_prob = torch.where(y_true == 0, log_zero, log_nonzero) nll = -log_prob.mean() return nll
[docs] def evaluate_nll(self, y_true, y_pred): metrics = super().evaluate_nll(y_true, y_pred) y_true_tensor = torch.tensor(y_true, dtype=torch.float32) y_pred_tensor = torch.tensor(y_pred, dtype=torch.float32) pi = self.pi_transform(y_pred_tensor[:, self.param_names.index("pi")]) rate = self.rate_transform(y_pred_tensor[:, self.param_names.index("rate")]) # E[Y] = (1 - pi) * rate mean_pred = (1.0 - pi) * rate mse_loss = torch.nn.functional.mse_loss(y_true_tensor, mean_pred) rmse = np.sqrt(mse_loss.detach().numpy()) mae = torch.nn.functional.l1_loss(y_true_tensor, mean_pred).detach().numpy() metrics["mse"] = mse_loss.detach().numpy() metrics["mae"] = mae metrics["rmse"] = rmse return metrics