"""Metric registry: maps (task, family) keys to default metric lists."""
from __future__ import annotations
from .base import DeepTabMetric
from .classification import AUROC, Accuracy, LogLoss
from .distributional import (
CRPS,
BetaBrierScore,
DirichletError,
GammaDeviance,
InverseGammaDeviance,
LogNormalNLL,
NegativeBinomialDeviance,
PoissonDeviance,
StudentTLoss,
TweedieDeviance,
)
from .regression import MeanAbsoluteError, PinballLoss, R2Score, RootMeanSquaredError
# ---------------------------------------------------------------------------
# Registry definition
# ---------------------------------------------------------------------------
# Keys follow the pattern "<task>" or "<task>:<family>".
# The first entry in each list is treated as the *primary* metric.
# All metrics here receive already-transformed distribution parameters
# (raw=False predictions). NegativeLogLikelihood is intentionally excluded
# from this registry because it requires raw logits; use model.score() for NLL.
# Default metrics per registry key. Within every list the first metric is the
# primary one; the remaining entries are secondary metrics.
METRIC_REGISTRY: dict[str, list[DeepTabMetric]] = {
    # Plain "<task>" keys: point-estimate tasks.
    "regression": [
        RootMeanSquaredError(),
        MeanAbsoluteError(),
        R2Score(),
    ],
    "classification": [
        Accuracy(),
        AUROC(),
        LogLoss(),
    ],
    # "lss:<family>" keys: one entry per supported distribution family.
    "lss:normal": [
        CRPS(family="normal"),
        RootMeanSquaredError(),
        MeanAbsoluteError(),
    ],
    "lss:lognormal": [
        LogNormalNLL(),
        CRPS(family="lognormal"),
        RootMeanSquaredError(),
    ],
    "lss:studentt": [
        StudentTLoss(),
        CRPS(family="studentt"),
    ],
    "lss:gamma": [
        GammaDeviance(),
        RootMeanSquaredError(),
    ],
    "lss:inversegamma": [
        InverseGammaDeviance(),
        GammaDeviance(),
    ],
    "lss:tweedie": [
        TweedieDeviance(),
        RootMeanSquaredError(),
    ],
    "lss:beta": [
        BetaBrierScore(),
        RootMeanSquaredError(),
    ],
    "lss:poisson": [
        PoissonDeviance(),
        RootMeanSquaredError(),
    ],
    # "zip" is given the same default metric types as "poisson".
    "lss:zip": [
        PoissonDeviance(),
        RootMeanSquaredError(),
    ],
    "lss:negativebinom": [
        NegativeBinomialDeviance(),
        RootMeanSquaredError(),
    ],
    "lss:categorical": [
        Accuracy(),
        LogLoss(),
    ],
    "lss:dirichlet": [
        DirichletError(),
    ],
    "lss:multinomial": [
        LogLoss(),
    ],
    "lss:johnsonsu": [
        CRPS(family="johnsonsu"),
        RootMeanSquaredError(),
    ],
    # "mog" is scored with the normal-family CRPS.
    "lss:mog": [
        CRPS(family="normal"),
        RootMeanSquaredError(),
    ],
    # Only the median (quantile=0.5) pinball loss is registered by default.
    "lss:quantile": [
        PinballLoss(quantile=0.5),
    ],
}
[docs]
def get_default_metrics(task: str, family: str | None = None) -> list[DeepTabMetric]:
    """Return the default list of metrics for a given task and distribution family.

    When *family* is given, the combined ``"<task>:<family>"`` registry key
    is tried first. If that key is not registered (or *family* is ``None``),
    the lookup falls back to the bare ``"<task>"`` key.

    Parameters
    ----------
    task : str
        One of ``"regression"``, ``"classification"``, or ``"lss"``.
    family : str, optional
        Distribution family key used for LSS tasks, e.g. ``"normal"``,
        ``"gamma"``, ``"poisson"``. When no ``"<task>:<family>"`` entry
        exists, the defaults registered for ``task`` alone are used.

    Returns
    -------
    list[DeepTabMetric]
        Ordered list of metric instances. The first entry is the primary
        metric. Returns an empty list when the combination is unknown.
        The list is a new shallow copy on every call, so callers may
        append to or reorder it without altering ``METRIC_REGISTRY``; the
        metric instances inside it are the ones held by the registry.
    """
    if family is not None:
        family_metrics = METRIC_REGISTRY.get(f"{task}:{family}")
        if family_metrics is not None:
            # Copy so that caller-side mutation cannot leak into the
            # module-level registry shared by every other caller.
            return list(family_metrics)
    # Fall back to the task-level defaults (empty when the task is unknown).
    return list(METRIC_REGISTRY.get(task, ()))
[docs]
def get_default_metrics_dict(task: str, family: str | None = None) -> dict[str, DeepTabMetric]:
    """Return the default metrics for ``task``/``family`` keyed by metric name.

    Convenience wrapper around :func:`get_default_metrics` for code paths
    that store metrics as dicts. Each metric is stored under its ``name``
    attribute, in the order :func:`get_default_metrics` yields them; if two
    metrics share a name, the later one replaces the earlier one.
    """
    metrics_by_name: dict[str, DeepTabMetric] = {}
    for metric in get_default_metrics(task, family):
        metrics_by_name[metric.name] = metric
    return metrics_by_name