Source code for deeptab.distributions.student_t

"""Student-t and Johnson SU distributions for heavy-tailed / skewed LSS models."""

import numpy as np
import torch
import torch.distributions as dist

from .base import BaseDistribution


class StudentTDistribution(BaseDistribution):
    """
    Represents a Student's t-distribution, a family of continuous probability
    distributions that arise when estimating the mean of a normally
    distributed population in situations where the sample size is small.

    This class extends BaseDistribution and includes parameter transformation
    and loss computation specific to the Student's t-distribution.

    Parameters
    ----------
    name : str
        The name of the distribution, defaulted to "StudentT".
    df_transform : str or callable
        Transformation for the degrees of freedom parameter to ensure it
        remains positive.
    loc_transform : str or callable
        Transformation for the location parameter.
    scale_transform : str or callable
        Transformation for the scale parameter to ensure it remains positive.
    """

    def __init__(
        self,
        name="StudentT",
        df_transform="positive",
        loc_transform="none",
        scale_transform="positive",
    ):
        param_names = ["df", "loc", "scale"]
        super().__init__(name, param_names)

        # Each transform spec is resolved by the base class' ``get_transform``.
        self.df_transform = self.get_transform(df_transform)
        self.loc_transform = self.get_transform(loc_transform)
        self.scale_transform = self.get_transform(scale_transform)

    def compute_loss(self, predictions, y_true):
        """
        Compute the mean negative log-likelihood of ``y_true`` under the
        predicted Student's t-distribution.

        Parameters
        ----------
        predictions : torch.Tensor
            2-D tensor whose columns follow ``self.param_names``
            (``df``, ``loc``, ``scale``); raw values, transformed here.
        y_true : torch.Tensor
            Observed targets, one per row of ``predictions``.

        Returns
        -------
        torch.Tensor
            Scalar tensor holding the mean negative log-likelihood.
        """
        df = self.df_transform(predictions[:, self.param_names.index("df")])
        loc = self.loc_transform(predictions[:, self.param_names.index("loc")])
        scale = self.scale_transform(predictions[:, self.param_names.index("scale")])

        # A target shaped (N, 1) against parameters shaped (N,) would silently
        # broadcast to an (N, N) log-prob matrix; align the shapes when the
        # element counts agree so every target is paired with its own row.
        if y_true.shape != loc.shape and y_true.numel() == loc.numel():
            y_true = y_true.reshape(loc.shape)

        student_t_dist = dist.StudentT(df, loc, scale)  # type: ignore
        return -student_t_dist.log_prob(y_true).mean()

    def evaluate_nll(self, y_true, y_pred):
        """
        Extend the base-class metrics with point-estimate errors of the
        predicted location against the targets.

        Parameters
        ----------
        y_true : array-like or torch.Tensor
            Observed targets.
        y_pred : array-like or torch.Tensor
            2-D raw predictions whose columns follow ``self.param_names``.

        Returns
        -------
        dict
            The dictionary returned by ``super().evaluate_nll`` with the
            additional keys ``"mse"``, ``"mae"`` and ``"rmse"``.
        """
        metrics = super().evaluate_nll(y_true, y_pred)

        # ``as_tensor`` avoids the copy-construct warning ``torch.tensor``
        # emits for tensor inputs; detach/cpu keeps ``.numpy()`` valid.
        target = torch.as_tensor(y_true, dtype=torch.float32).detach().cpu()
        preds = torch.as_tensor(y_pred, dtype=torch.float32).detach().cpu()

        # Apply the same location transform as ``compute_loss`` so the
        # metrics describe the location the likelihood actually uses.
        loc = self.loc_transform(preds[:, self.param_names.index("loc")])

        # Same broadcasting guard as in ``compute_loss``.
        if target.shape != loc.shape and target.numel() == loc.numel():
            target = target.reshape(loc.shape)

        mse = torch.nn.functional.mse_loss(target, loc).detach().numpy()
        mae = torch.nn.functional.l1_loss(target, loc).detach().numpy()

        metrics["mse"] = mse
        metrics["mae"] = mae
        metrics["rmse"] = np.sqrt(mse)
        return metrics
class JohnsonSuDistribution(BaseDistribution):
    """
    Represents a Johnson's SU distribution with parameters for skewness,
    shape, location, and scale.

    Parameters
    ----------
    name : str
        The name of the distribution. Defaults to "JohnsonSu".
    skew_transform : str or callable
        The transformation for the skewness parameter. Defaults to "none".
    shape_transform : str or callable
        The transformation for the shape parameter. Defaults to "positive".
    loc_transform : str or callable
        The transformation for the location parameter. Defaults to "none".
    scale_transform : str or callable
        The transformation for the scale parameter. Defaults to "positive".
    """

    def __init__(
        self,
        name="JohnsonSu",
        skew_transform="none",
        shape_transform="positive",
        loc_transform="none",
        scale_transform="positive",
    ):
        param_names = ["skew", "shape", "location", "scale"]
        super().__init__(name, param_names)

        # Each transform spec is resolved by the base class' ``get_transform``.
        self.skew_transform = self.get_transform(skew_transform)
        self.shape_transform = self.get_transform(shape_transform)
        self.loc_transform = self.get_transform(loc_transform)
        self.scale_transform = self.get_transform(scale_transform)

    def log_prob(self, x, skew, shape, loc, scale):
        """
        Compute the log probability density of the Johnson's SU distribution.

        With ``u = (x - loc) / scale`` and ``z = skew + shape * asinh(u)``::

            log f(x) = log(shape) - log(scale) - 0.5 * log(2 * pi)
                       - 0.5 * z**2 - 0.5 * log(1 + u**2)

        Parameters
        ----------
        x : torch.Tensor
            Points at which to evaluate the density.
        skew, shape, loc, scale : torch.Tensor
            Already-transformed distribution parameters, broadcastable
            against ``x``; ``shape`` and ``scale`` must be positive.

        Returns
        -------
        torch.Tensor
            Element-wise log density.
        """
        u = (x - loc) / scale
        z = skew + shape * torch.asinh(u)

        # Split the log of the normalising quotient into separate terms so a
        # very small or very large ``shape / scale`` ratio cannot under- or
        # overflow before the log is taken.
        log_norm = torch.log(shape) - torch.log(scale) - float(0.5 * np.log(2.0 * np.pi))

        # log1p keeps precision when u**2 is tiny.
        return log_norm - 0.5 * z**2 - 0.5 * torch.log1p(u**2)

    def compute_loss(self, predictions, y_true):
        """
        Compute the mean negative log-likelihood of ``y_true`` under the
        predicted Johnson's SU distribution.

        Parameters
        ----------
        predictions : torch.Tensor
            2-D tensor whose columns follow ``self.param_names``
            (``skew``, ``shape``, ``location``, ``scale``); raw values,
            transformed here.
        y_true : torch.Tensor
            Observed targets, one per row of ``predictions``.

        Returns
        -------
        torch.Tensor
            Scalar tensor holding the mean negative log-likelihood.
        """
        skew = self.skew_transform(predictions[:, self.param_names.index("skew")])
        shape = self.shape_transform(predictions[:, self.param_names.index("shape")])
        loc = self.loc_transform(predictions[:, self.param_names.index("location")])
        scale = self.scale_transform(predictions[:, self.param_names.index("scale")])

        # A target shaped (N, 1) against parameters shaped (N,) would silently
        # broadcast to an (N, N) log-prob matrix; align the shapes when the
        # element counts agree so every target is paired with its own row.
        if y_true.shape != loc.shape and y_true.numel() == loc.numel():
            y_true = y_true.reshape(loc.shape)

        log_probs = self.log_prob(y_true, skew, shape, loc, scale)
        return -log_probs.mean()

    def evaluate_nll(self, y_true, y_pred):
        """
        Extend the base-class metrics with point-estimate errors of the
        predicted location against the targets.

        Parameters
        ----------
        y_true : array-like or torch.Tensor
            Observed targets.
        y_pred : array-like or torch.Tensor
            2-D raw predictions whose columns follow ``self.param_names``.

        Returns
        -------
        dict
            The dictionary returned by ``super().evaluate_nll`` with the
            additional keys ``"mse"``, ``"mae"`` and ``"rmse"``.
        """
        metrics = super().evaluate_nll(y_true, y_pred)

        # ``as_tensor`` avoids the copy-construct warning ``torch.tensor``
        # emits for tensor inputs; detach/cpu keeps ``.numpy()`` valid.
        target = torch.as_tensor(y_true, dtype=torch.float32).detach().cpu()
        preds = torch.as_tensor(y_pred, dtype=torch.float32).detach().cpu()

        # Apply the same location transform as ``compute_loss`` so the
        # metrics describe the location the likelihood actually uses.
        # NOTE(review): for Johnson's SU the location parameter is not the
        # distribution mean unless skew == 0; these errors are measured
        # against the location only — confirm that is the intended estimate.
        loc = self.loc_transform(preds[:, self.param_names.index("location")])

        # Same broadcasting guard as in ``compute_loss``.
        if target.shape != loc.shape and target.numel() == loc.numel():
            target = target.reshape(loc.shape)

        mse = torch.nn.functional.mse_loss(target, loc).detach().numpy()
        mae = torch.nn.functional.l1_loss(target, loc).detach().numpy()

        metrics.update({"mse": mse, "mae": mae, "rmse": np.sqrt(mse)})
        return metrics