Source code for deeptab.metrics.base
"""Base class for DeepTab evaluation metrics."""
from __future__ import annotations
from abc import ABC, abstractmethod
import numpy as np


class DeepTabMetric(ABC):
    """Abstract base class for all DeepTab evaluation metrics.

    Every metric in ``deeptab.metrics`` subclasses this ABC and exposes three
    class-level attributes that the training loop and registry read
    automatically. You never need to set them yourself when *using* a metric,
    only when *writing* a custom one.

    Attributes
    ----------
    name : str
        A short, machine-readable identifier for the metric. It is used as:

        * the key in the dict returned by ``model.evaluate()``
        * the suffix in training-log entries (e.g. ``val_rmse``)
        * the registry lookup key in :data:`~deeptab.metrics.METRIC_REGISTRY`

        Examples: ``"rmse"``, ``"crps"``, ``"auroc"``.
    higher_is_better : bool
        Tells the framework whether a *larger* or *smaller* value is
        preferable. This matters in two places:

        * **HPO**: hyperparameter search uses it to set the optimisation
          direction (maximise vs. minimise) when a metric is chosen as the
          objective.
        * **Early stopping / model selection**: callbacks can use it to
          decide whether a new checkpoint is an improvement.

        ``False`` (default) means *lower is better*, appropriate for loss
        functions and error metrics (MSE, MAE, NLL, deviances).
        ``True`` means *higher is better*, appropriate for scores such as R²,
        accuracy, and AUROC, and for CRPS variants that are oriented so that a
        higher value is desirable (the plain CRPS is lower-is-better).
    needs_raw : bool
        Controls *which* form of ``y_pred`` the training loop passes to this
        metric.

        * ``False`` (default): the metric receives **already-transformed**
          distribution parameters, i.e. the output of
          ``model.predict(X, raw=False)``. For example, a Normal distribution
          model returns ``[mean, std]`` where ``std > 0`` is guaranteed. This
          is the right choice for almost every metric.
        * ``True``: the metric receives **raw model logits** before the
          distribution's parameter transforms are applied.
          :class:`~deeptab.metrics.NegativeLogLikelihood` sets this to
          ``True`` because it calls ``distribution.compute_loss()``, which
          applies the transforms itself; passing already-transformed values
          would double-transform and produce wrong results.

    Examples
    --------
    Using a built-in metric directly:

    >>> from deeptab.metrics import RootMeanSquaredError
    >>> import numpy as np
    >>> metric = RootMeanSquaredError()
    >>> metric.name
    'rmse'
    >>> metric.higher_is_better
    False
    >>> metric(np.array([1.0, 2.0, 3.0]), np.array([1.1, 2.0, 2.9]))
    0.08164965809277261

    Passing metrics to ``model.fit()`` for live training logging:

    >>> from deeptab.metrics import CRPS, MeanAbsoluteError
    >>> # Logs val_crps and val_mae each epoch.
    >>> model.fit(X_train, y_train,
    ...           val_metrics={"crps": CRPS(family="normal"),
    ...                        "mae": MeanAbsoluteError()})

    Writing a custom metric:

    >>> from deeptab.metrics import DeepTabMetric
    >>> import numpy as np
    >>> class MedianAbsoluteError(DeepTabMetric):
    ...     name = "mdae"
    ...     higher_is_better = False   # lower error = better
    ...     needs_raw = False          # use transformed predictions
    ...
    ...     def __call__(self, y_true, y_pred):
    ...         y_pred = np.asarray(y_pred)
    ...         mean_pred = y_pred[:, 0] if y_pred.ndim == 2 else y_pred.ravel()
    ...         return float(np.median(np.abs(np.asarray(y_true).ravel() - mean_pred)))
    """

    name: str
    higher_is_better: bool = False
    needs_raw: bool = False

    @abstractmethod
    def __call__(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
        """Compute the metric value.

        Parameters
        ----------
        y_true : np.ndarray, shape (n,) or (n, d)
            Ground-truth target values.
        y_pred : np.ndarray, shape (n,) or (n, p)
            Model predictions.

            * When ``needs_raw=False`` (default): already-transformed
              distribution parameters from ``model.predict(X, raw=False)``.
              For a Normal distribution this is ``[[mean_0, std_0], ...]``.
            * When ``needs_raw=True``: raw logits from the model's final
              linear layer, before any parameter transform (e.g. softplus)
              is applied.

        Returns
        -------
        float
            Scalar metric value.
        """
        ...

    def __repr__(self) -> str:
        return f"{type(self).__name__}()"
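
The docstring above describes how ``higher_is_better`` and ``needs_raw`` are meant to be consumed by the framework. The following is a minimal, self-contained sketch of that idea and not DeepTab's actual training loop: ``MetricSketch`` is a stand-in that only mirrors the three class attributes, and ``is_improvement`` and ``pick_predictions`` are hypothetical helpers introduced purely for illustration.

```python
from abc import ABC, abstractmethod

import numpy as np


class MetricSketch(ABC):
    """Stand-in mirroring the three class attributes of DeepTabMetric."""

    name: str
    higher_is_better: bool = False
    needs_raw: bool = False

    @abstractmethod
    def __call__(self, y_true: np.ndarray, y_pred: np.ndarray) -> float: ...


class MedianAbsoluteError(MetricSketch):
    """Same custom metric as in the class docstring, on the stand-in base."""

    name = "mdae"

    def __call__(self, y_true, y_pred):
        y_pred = np.asarray(y_pred)
        # Column 0 is taken to be the predicted mean when parameters are 2-D.
        mean_pred = y_pred[:, 0] if y_pred.ndim == 2 else y_pred.ravel()
        return float(np.median(np.abs(np.asarray(y_true).ravel() - mean_pred)))


def is_improvement(metric, new_value, best_value):
    """Hypothetical helper: compare scores in the metric's preferred direction."""
    if best_value is None:
        return True
    if metric.higher_is_better:
        return new_value > best_value
    return new_value < best_value


def pick_predictions(metric, raw_logits, transformed):
    """Hypothetical helper: hand the metric the form of y_pred it asked for."""
    return raw_logits if metric.needs_raw else transformed


metric = MedianAbsoluteError()
y_true = np.array([1.0, 2.0, 4.0])
transformed = np.array([[1.5, 0.3], [2.0, 0.2], [3.0, 0.4]])  # rows of [mean, std]
raw_logits = np.zeros_like(transformed)  # placeholder; unused since needs_raw=False

y_pred = pick_predictions(metric, raw_logits, transformed)
score = metric(y_true, y_pred)  # absolute errors 0.5, 0.0, 1.0 -> median 0.5
print(f"val_{metric.name} = {score}")
print(is_improvement(metric, new_value=0.4, best_value=score))  # lower is better
```

Keeping the direction and the prediction form as class attributes means generic code such as the two helpers can treat every metric uniformly, without a per-metric special case.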