# Source code for deeptab.models._mixins.predict
"""Inference, encoding, and scoring logic for all DeepTab estimators."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any
import torch
from sklearn.utils.validation import check_is_fitted
from torch.utils.data import DataLoader
from tqdm import tqdm
from deeptab.core.sklearn_compat import validate_input_features
if TYPE_CHECKING:
from deeptab.core.interfaces import IDataModule, ITaskModel
class _PredictMixin:
"""Inference, encoding, and internal scoring.
Responsibilities
----------------
* ``predict`` — abstract; overridden by each concrete estimator to
return predictions in the expected sklearn shape.
* ``_validate_predict_input`` — checks the model is fitted and that
the input columns match those seen during ``fit``.
* ``encode`` — returns dense embedding vectors from the model backbone
for a given input DataFrame.
* ``_score`` — internal helper used by ``optimize_hparams`` to evaluate
validation loss with the best checkpoint loaded.
"""
if TYPE_CHECKING:
# Attributes provided by SklearnBase when this mixin is composed.
# Declared here for static type-checkers only; never initialised in this class.
config: Any
_best_model_path: str | None
_task_model: ITaskModel | None
_data_module: IDataModule | None
def predict(self, X, embeddings=None, device=None):
"""Return predictions for input *X*.
Parameters
----------
X : array-like or DataFrame of shape (n_samples, n_features)
Input features.
embeddings : array-like or None, optional
Pre-computed external embeddings aligned with the rows of *X*.
device : str or torch.device or None, optional
Device override for inference (e.g. ``"cpu"`` to force CPU).
When ``None`` the model's current device is used.
Returns
-------
numpy.ndarray
1-D array of shape ``(n_samples,)`` for classification and
regression tasks.
Raises
------
NotImplementedError
Always — this method must be overridden by each concrete subclass.
"""
raise NotImplementedError("The 'predict' method is not implemented in the Parent class.")
def _validate_predict_input(self, X):
"""Check the model is fitted and validate the input feature columns.
Parameters
----------
X : array-like or DataFrame
Raw input to be passed to ``predict``.
Returns
-------
pandas.DataFrame
The validated and coerced input, with columns verified against
those seen during ``fit``.
Raises
------
sklearn.exceptions.NotFittedError
If ``fit`` has not been called yet.
deeptab.core.exceptions.ColumnCountError
If the number of columns differs from ``n_features_in_``.
"""
check_is_fitted(self) # raises sklearn's NotFittedError before any other check
return validate_input_features(self, X)
def _score(self, X, y, embeddings, metric):
"""Evaluate *metric* on *X* / *y* using the best-checkpoint weights.
Reloads the best model checkpoint before running ``predict`` so that
the score reflects the best validation state rather than the last
epoch's weights.
Parameters
----------
X : array-like or DataFrame
Input features.
y : array-like
True target values.
embeddings : array-like or None
Pre-computed external embeddings aligned with *X*.
metric : Callable[[array-like, array-like], float]
A scoring callable that accepts ``(y_true, y_pred)`` and
returns a scalar (lower = better for losses, higher = better
for accuracy-style metrics).
Returns
-------
float
The metric value computed on the predictions.
"""
# Explicitly load the best model state if needed
if hasattr(self, "_trainer") and self._best_model_path:
torch.serialization.add_safe_globals([type(self.config)])
checkpoint = torch.load(self._best_model_path, weights_only=False)
self._task_model.load_state_dict(checkpoint["state_dict"]) # type: ignore
predictions = self.predict(X, embeddings)
return metric(y, predictions)
def encode(self, X, embeddings=None, batch_size=64):
"""Return dense embedding vectors from the model backbone.
Runs the fitted model's ``encode`` method on batches of *X* and
concatenates the results into a single tensor.
Parameters
----------
X : array-like or DataFrame of shape (n_samples, n_features)
Input features to encode.
embeddings : array-like or None, optional
Pre-computed external embeddings aligned with the rows of *X*.
batch_size : int, default=64
Number of samples processed in each forward pass.
Returns
-------
torch.Tensor of shape (n_samples, embedding_dim)
Encoded representations of the input data.
Raises
------
ValueError
If the model has not been fitted yet.
Examples
--------
>>> clf = MLPClassifier()
>>> clf.fit(X_train, y_train)
>>> embeddings = clf.encode(X_test) # (n_samples, embedding_dim)
>>> embeddings.shape
torch.Size([100, 64])
"""
if self._task_model is None or self._data_module is None:
raise ValueError("The model or data module has not been fitted yet.")
encoded_dataset = self._data_module.preprocess_new_data(X, embeddings)
data_loader = DataLoader(encoded_dataset, batch_size=batch_size, shuffle=False)
encoded_outputs = []
for batch in tqdm(data_loader):
emb = self._task_model.estimator.encode(batch) # type: ignore[union-attr]
encoded_outputs.append(emb)
return torch.cat(encoded_outputs, dim=0)