# Source code for deeptab.distributions.normal

"""Normal (Gaussian) and Log-Normal distributions for LSS models."""

from collections.abc import Callable

import numpy as np
import torch
import torch.distributions as dist

from .base import BaseDistribution


class NormalDistribution(BaseDistribution):
    """
    Represents a Normal (Gaussian) distribution with parameters for mean and variance,
    including functionality for transforming these parameters and computing the loss.

    Inherits from BaseDistribution.

    Parameters
    ----------
    name (str): The name of the distribution. Defaults to "Normal".
    mean_transform (str or callable): The transformation for the mean parameter.
        Defaults to "none".
    var_transform (str or callable): The transformation for the variance parameter.
        Defaults to "positive".

    Notes
    -----
    ``compute_loss`` passes the transformed "variance" output to
    ``torch.distributions.Normal`` as its ``scale`` argument. Torch interprets
    ``scale`` as the standard deviation, so the learned quantity is effectively a
    standard deviation despite the parameter's name.
    """

    def __init__(self, name="Normal", mean_transform="none", var_transform="positive"):
        param_names = [
            "mean",
            "variance",
        ]
        super().__init__(name, param_names)

        # NOTE(review): attribute names mirror param_names ("mean" -> mean_transform,
        # "variance" -> variance_transform). Keep them in sync in case
        # BaseDistribution looks transforms up by parameter name -- confirm.
        self.mean_transform = self.get_transform(mean_transform)
        self.variance_transform = self.get_transform(var_transform)

    def compute_loss(self, predictions, y_true):
        """
        Compute the mean negative log-likelihood of ``y_true`` under the predicted Normal.

        Parameters
        ----------
        predictions (torch.Tensor): Raw (untransformed) model outputs, indexed as
            ``predictions[:, i]`` where ``i`` is the position of "mean" / "variance"
            in ``self.param_names``.
        y_true (torch.Tensor): Observed targets; presumably shape (batch,) so that
            it lines up with each predicted column -- confirm with callers.

        Returns
        -------
        torch.Tensor: Scalar mean NLL over the batch.
        """
        mean = self.mean_transform(predictions[:, self.param_names.index("mean")])
        variance = self.variance_transform(predictions[:, self.param_names.index("variance")])
        # NOTE(review): torch's Normal signature is (loc, scale), and scale is the
        # standard deviation. The "variance" output is therefore used as a std here.
        # This is kept unchanged so existing models and any downstream consumers
        # keep their semantics. Confirm before switching to ``variance.sqrt()``.
        normal_dist = dist.Normal(mean, variance)
        nll = -normal_dist.log_prob(y_true).mean()
        return nll

    def evaluate_nll(self, y_true, y_pred):
        """
        Evaluate the NLL and point-prediction metrics (MSE, MAE, RMSE) for the mean.

        Parameters
        ----------
        y_true (array-like): Observed targets, one per sample.
        y_pred (array-like): Raw model outputs, in the same layout as the
            ``predictions`` argument of ``compute_loss``.

        Returns
        -------
        dict: The metrics returned by ``BaseDistribution.evaluate_nll``, extended with
            ``"mse"``, ``"mae"`` and ``"rmse"`` computed against the predicted mean.
        """
        metrics = super().evaluate_nll(y_true, y_pred)

        # Flatten the targets. Otherwise a (n, 1) column vector would be broadcast
        # against the (n,) mean prediction into an (n, n) matrix inside
        # mse_loss/l1_loss, which silently produces a wrong metric.
        y_true_tensor = torch.tensor(y_true, dtype=torch.float32).reshape(-1)
        y_pred_tensor = torch.tensor(y_pred, dtype=torch.float32)

        # Apply the same transform compute_loss uses, so the metrics score the
        # model's actual mean prediction rather than the raw network output.
        mean_pred = self.mean_transform(y_pred_tensor[:, self.param_names.index("mean")])

        mse = torch.nn.functional.mse_loss(y_true_tensor, mean_pred).detach().numpy()
        mae = torch.nn.functional.l1_loss(y_true_tensor, mean_pred).detach().numpy()

        metrics["mse"] = mse
        metrics["mae"] = mae
        metrics["rmse"] = np.sqrt(mse)
        return metrics
class LogNormalDistribution(BaseDistribution):
    """
    Represents a Log-Normal distribution for right-skewed positive continuous targets
    such as wages, prices, latencies, and insurance claim amounts.

    The neural network predicts the mean (``loc``) and standard deviation (``scale``)
    of the underlying normal distribution in log-space. The median of the outcome is
    ``exp(loc)`` and the mean is ``exp(loc + scale²/2)``.

    Parameters
    ----------
    name (str): The name of the distribution. Defaults to ``"LogNormal"``.
    loc_transform (str or callable): Transform for the log-space mean.
        Defaults to ``"none"`` (identity — mean in log-space can be any real number).
    scale_transform (str or callable): Transform for the log-space standard deviation.
        Defaults to ``"positive"`` (softplus, ensures sigma > 0).
    """

    def __init__(self, name="LogNormal", loc_transform="none", scale_transform="positive"):
        param_names = ["loc", "scale"]
        super().__init__(name, param_names)

        # NOTE(review): attribute names mirror param_names ("loc" -> loc_transform,
        # "scale" -> scale_transform). Keep them in sync in case BaseDistribution
        # looks transforms up by parameter name -- confirm.
        self.loc_transform = self.get_transform(loc_transform)
        self.scale_transform = self.get_transform(scale_transform)

    def compute_loss(self, predictions, y_true):
        """
        Compute the mean negative log-likelihood of ``y_true`` under the predicted
        Log-Normal.

        Parameters
        ----------
        predictions (torch.Tensor): Raw (untransformed) model outputs, indexed as
            ``predictions[:, i]`` where ``i`` is the position of "loc" / "scale" in
            ``self.param_names``.
        y_true (torch.Tensor): Observed targets. They must be strictly positive,
            because that is the support of the Log-Normal. Presumably shape (batch,)
            -- confirm with callers.

        Returns
        -------
        torch.Tensor: Scalar mean NLL over the batch.
        """
        loc = self.loc_transform(predictions[:, self.param_names.index("loc")])
        scale = self.scale_transform(predictions[:, self.param_names.index("scale")])
        # torch's LogNormal is parameterised by the underlying normal's (loc, scale).
        lognormal_dist = dist.LogNormal(loc, scale)
        nll = -lognormal_dist.log_prob(y_true).mean()
        return nll

    def evaluate_nll(self, y_true, y_pred):
        """
        Evaluate the NLL and point-prediction metrics (MSE, MAE, RMSE) for the median.

        Parameters
        ----------
        y_true (array-like): Observed (positive) targets, one per sample.
        y_pred (array-like): Raw model outputs, in the same layout as the
            ``predictions`` argument of ``compute_loss``.

        Returns
        -------
        dict: The metrics returned by ``BaseDistribution.evaluate_nll``, extended with
            ``"mse"``, ``"mae"`` and ``"rmse"`` computed against ``exp(loc)``.
        """
        metrics = super().evaluate_nll(y_true, y_pred)

        # Flatten the targets. Otherwise a (n, 1) column vector would be broadcast
        # against the (n,) median prediction into an (n, n) matrix inside
        # mse_loss/l1_loss, which silently produces a wrong metric.
        y_true_tensor = torch.tensor(y_true, dtype=torch.float32).reshape(-1)
        y_pred_tensor = torch.tensor(y_pred, dtype=torch.float32)

        # The median exp(loc) is used as the point estimate. It is the
        # absolute-error-optimal summary. The mean exp(loc + scale²/2) would be the
        # squared-error-optimal one, so the reported MSE scores the median, not
        # the mean.
        loc = self.loc_transform(y_pred_tensor[:, self.param_names.index("loc")])
        median_pred = torch.exp(loc)

        mse = torch.nn.functional.mse_loss(y_true_tensor, median_pred).detach().numpy()
        mae = torch.nn.functional.l1_loss(y_true_tensor, median_pred).detach().numpy()

        metrics["mse"] = mse
        metrics["mae"] = mae
        metrics["rmse"] = np.sqrt(mse)
        return metrics