from __future__ import annotations
import warnings
from collections.abc import Callable
import numpy as np
import torch
from sklearn.metrics import accuracy_score, log_loss
from deeptab.core.exceptions import NotFittedError, not_fitted_error
from deeptab.metrics import get_default_metrics_dict
from deeptab.models.base import SklearnBase
from deeptab.training.losses import build_classification_loss, compute_class_weights
def _resolve_loss_and_sampler(loss_fct, class_weight, balanced_sampler, sample_weight, y, classes, num_classes):
"""Translate the imbalance-handling arguments into a ``(loss_fct, sampler)`` pair.
* ``loss_fct`` — an ``nn.Module``, a registered loss name (e.g. ``"focal"``),
or ``None``. Combined with ``class_weight`` via
:func:`deeptab.training.losses.build_classification_loss`.
* ``sampler`` — ``sample_weight`` (explicit per-row weights) takes precedence,
otherwise ``"balanced"`` when ``balanced_sampler`` is set, otherwise ``None``.
"""
class_weights = None
if class_weight is not None:
class_weights = compute_class_weights(class_weight, y, classes=classes)
resolved_loss = build_classification_loss(loss_fct, num_classes=num_classes, class_weights=class_weights)
if sample_weight is not None:
sampler = sample_weight
elif balanced_sampler:
sampler = "balanced"
else:
sampler = None
return resolved_loss, sampler
class SklearnBaseClassifier(SklearnBase):
def __init__(
self,
model_config=None,
preprocessing_config=None,
trainer_config=None,
observability_config=None,
random_state=None,
):
super().__init__(
model_config=model_config,
preprocessing_config=preprocessing_config,
trainer_config=trainer_config,
observability_config=observability_config,
random_state=random_state,
)
def build_model(
self,
X,
y,
val_size: float = 0.2,
X_val=None,
y_val=None,
embeddings=None,
embeddings_val=None,
random_state: int = 101,
batch_size: int = 128,
shuffle: bool = True,
stratify: bool = True,
lr: float | None = None,
lr_patience: int | None = None,
lr_factor: float | None = None,
weight_decay: float | None = None,
train_metrics: dict[str, Callable] | None = None,
val_metrics: dict[str, Callable] | None = None,
dataloader_kwargs={},
class_weight: str | dict | list | np.ndarray | None = None,
loss_fct=None,
balanced_sampler: bool = False,
sample_weight=None,
):
"""Builds the model using the provided training data.
Parameters
----------
X : DataFrame or array-like, shape (n_samples, n_features)
The training input samples.
y : array-like, shape (n_samples,) or (n_samples, n_targets)
The target values (real numbers).
val_size : float, default=0.2
The proportion of the dataset to include in the validation split if `X_val` is None.
Ignored if `X_val` is provided.
X_val : DataFrame or array-like, shape (n_samples, n_features), optional
The validation input samples. If provided, `X` and `y` are not split and this data is used for validation.
y_val : array-like, shape (n_samples,) or (n_samples, n_targets), optional
The validation target values. Required if `X_val` is provided.
random_state : int, default=101
Controls the shuffling applied to the data before applying the split.
batch_size : int, default=128
Number of samples per gradient update.
shuffle : bool, default=True
Whether to shuffle the training data before each epoch.
stratify : bool, default=True
Whether to stratify the validation split on `y` so the split keeps
the same class proportions. Set to False for a purely random split.
lr : float, default=1e-3
Learning rate for the optimizer.
lr_patience : int, default=10
Number of epochs with no improvement on the validation loss to wait before reducing the learning rate.
lr_factor : float, default=0.1
Factor by which the learning rate will be reduced.
train_metrics : dict, default=None
torch.metrics dict to be logged during training.
val_metrics : dict, default=None
torch.metrics dict to be logged during validation.
weight_decay : float, default=0.025
Weight decay (L2 penalty) coefficient.
dataloader_kwargs: dict, default={}
The kwargs for the pytorch dataloader class.
class_weight : {"balanced"}, dict, array-like, or None, default=None
Weights associated with classes for imbalanced data. ``"balanced"``
mirrors scikit-learn and uses ``n_samples / (n_classes * bincount(y))``.
A mapping ``{class_label: weight}`` or an array (ordered like
``np.unique(y)``) sets weights explicitly. Ignored when ``loss_fct``
is an ``nn.Module``.
loss_fct : nn.Module, str, or None, default=None
Custom loss. An ``nn.Module`` is used as-is; a registered loss name
(e.g. ``"focal"``, ``"bce"``, ``"cross_entropy"``) is built and
combined with ``class_weight``. ``None`` falls back to the default
(weighted) task loss.
balanced_sampler : bool, default=False
If ``True``, draw class-balanced mini-batches with a
``WeightedRandomSampler`` (oversamples minority classes).
sample_weight : array-like, optional
Explicit per-row sampling weights (length matches ``X``). Takes
precedence over ``balanced_sampler`` and drives the
``WeightedRandomSampler``.
Returns
-------
self : object
The built classifier.
"""
self.classes_ = np.unique(y)
num_classes = len(self.classes_)
loss_fct, sampler = _resolve_loss_and_sampler(
loss_fct, class_weight, balanced_sampler, sample_weight, y, self.classes_, num_classes
)
return super()._build_model(
X,
y,
regression=False,
val_size=val_size,
X_val=X_val,
y_val=y_val,
embeddings=embeddings,
embeddings_val=embeddings_val,
num_classes=num_classes,
random_state=random_state,
batch_size=batch_size,
shuffle=shuffle,
stratify=stratify,
lr=lr,
lr_patience=lr_patience,
lr_factor=lr_factor,
weight_decay=weight_decay,
train_metrics=train_metrics,
val_metrics=val_metrics,
dataloader_kwargs=dataloader_kwargs,
loss_fct=loss_fct,
sampler=sampler,
)
def fit(
self,
X,
y,
val_size: float = 0.2,
X_val=None,
y_val=None,
embeddings=None,
embeddings_val=None,
max_epochs: int = 100,
random_state: int = 101,
batch_size: int = 128,
shuffle: bool = True,
stratify: bool = True,
patience: int = 15,
monitor: str = "val_loss",
mode: str = "min",
lr: float | None = None,
lr_patience: int | None = None,
lr_factor: float | None = None,
weight_decay: float | None = None,
checkpoint_path="model_checkpoints",
train_metrics: dict[str, Callable] | None = None,
val_metrics: dict[str, Callable] | None = None,
dataloader_kwargs={},
rebuild=True,
class_weight: str | dict | list | np.ndarray | None = None,
loss_fct=None,
balanced_sampler: bool = False,
sample_weight=None,
**trainer_kwargs,
):
"""Trains the classification model using the provided training data. Optionally, a separate validation set can
be used.
Parameters
----------
X : DataFrame or array-like, shape (n_samples, n_features)
The training input samples.
y : array-like, shape (n_samples,) or (n_samples, n_targets)
The target values (real numbers).
val_size : float, default=0.2
The proportion of the dataset to include in the validation split if `X_val` is None.
Ignored if `X_val` is provided.
X_val : DataFrame or array-like, shape (n_samples, n_features), optional
The validation input samples. If provided, `X` and `y` are not split and this data is used for validation.
y_val : array-like, shape (n_samples,) or (n_samples, n_targets), optional
The validation target values. Required if `X_val` is provided.
max_epochs : int, default=100
Maximum number of epochs for training.
random_state : int, default=101
Controls the shuffling applied to the data before applying the split.
batch_size : int, default=64
Number of samples per gradient update.
shuffle : bool, default=True
Whether to shuffle the training data before each epoch.
stratify : bool, default=True
Whether to stratify the validation split on `y` so the split keeps
the same class proportions. Set to False for a purely random split.
When a `TrainerConfig` is set, its `stratify` value takes precedence.
patience : int, default=10
Number of epochs with no improvement on the validation loss to wait before early stopping.
monitor : str, default="val_loss"
The metric to monitor for early stopping.
mode : str, default="min"
Whether the monitored metric should be minimized (`min`) or maximized (`max`).
lr : float, default=1e-3
Learning rate for the optimizer.
lr_patience : int, default=10
Number of epochs with no improvement on the validation loss to wait before reducing the learning rate.
factor : float, default=0.1
Factor by which the learning rate will be reduced.
weight_decay : float, default=0.025
Weight decay (L2 penalty) coefficient.
checkpoint_path : str, default="model_checkpoints"
Path where the checkpoints are being saved.
train_metrics : dict, default=None
torch.metrics dict to be logged during training.
val_metrics : dict, default=None
torch.metrics dict to be logged during validation.
dataloader_kwargs: dict, default={}
The kwargs for the pytorch dataloader class.
rebuild: bool, default=True
Whether to rebuild the model when it already was built.
class_weight : {"balanced"}, dict, array-like, or None, default=None
Weights associated with classes for imbalanced data. ``"balanced"``
mirrors scikit-learn and uses ``n_samples / (n_classes * bincount(y))``
so under-represented classes contribute more to the loss. A mapping
``{class_label: weight}`` or an array (ordered like ``np.unique(y)``)
sets weights explicitly. For binary targets the weights are converted
to a ``pos_weight`` for ``BCEWithLogitsLoss``; for multiclass they
become the ``weight`` of ``CrossEntropyLoss``. Ignored when
``loss_fct`` is an ``nn.Module``.
loss_fct : nn.Module, str, or None, default=None
Custom loss. An ``nn.Module`` is used as-is; a registered loss name
(e.g. ``"focal"``, ``"bce"``, ``"cross_entropy"``) is built and
combined with ``class_weight`` (see
:func:`deeptab.training.losses.build_classification_loss`). ``None``
falls back to the default (weighted) task loss.
balanced_sampler : bool, default=False
If ``True``, draw class-balanced mini-batches with a
``WeightedRandomSampler`` (oversamples minority classes). This
rebalances the data instead of (or in addition to) reweighting the
loss.
sample_weight : array-like, optional
Explicit per-row sampling weights (length matches ``X``). Takes
precedence over ``balanced_sampler``; rows are drawn into batches in
proportion to their weight.
**trainer_kwargs : Additional keyword arguments for PyTorch Lightning's Trainer class.
Returns
-------
self : object
The fitted classifier.
"""
self.classes_ = np.unique(y)
num_classes = len(self.classes_)
loss_fct, sampler = _resolve_loss_and_sampler(
loss_fct, class_weight, balanced_sampler, sample_weight, y, self.classes_, num_classes
)
return super().fit(
X=X,
y=y,
regression=False,
val_size=val_size,
X_val=X_val,
y_val=y_val,
embeddings=embeddings,
embeddings_val=embeddings_val,
max_epochs=max_epochs,
random_state=random_state,
batch_size=batch_size,
shuffle=shuffle,
stratify=stratify,
patience=patience,
monitor=monitor,
mode=mode,
lr=lr,
lr_patience=lr_patience,
lr_factor=lr_factor,
weight_decay=weight_decay,
checkpoint_path=checkpoint_path,
dataloader_kwargs=dataloader_kwargs,
train_metrics=train_metrics,
val_metrics=val_metrics,
rebuild=rebuild,
num_classes=num_classes,
loss_fct=loss_fct,
sampler=sampler,
**trainer_kwargs,
)
def predict(self, X, embeddings=None, device=None):
"""Predicts target labels for the given input samples.
Parameters
----------
X : DataFrame or array-like, shape (n_samples, n_features)
The input samples for which to predict target values.
Returns
-------
predictions : ndarray, shape (n_samples,)
The predicted class labels.
"""
X = self._validate_predict_input(X)
if self._task_model is None:
raise not_fitted_error(type(self).__name__, "predict")
self._emit_event("predict_started", n_samples=len(X))
# Preprocess the data using the data module
if self._data_module is None:
raise not_fitted_error(type(self).__name__, "predict")
self._data_module.assign_predict_dataset(X, embeddings)
# Set model to evaluation mode
self._task_model.eval()
# Perform inference using PyTorch Lightning's predict function
if self._trainer is None:
raise not_fitted_error(type(self).__name__, "predict")
logits_list = self._trainer.predict(self._task_model, self._data_module) # type: ignore[arg-type]
# Concatenate predictions from all batches
logits = torch.cat(logits_list, dim=0) # type: ignore
# Check if ensemble is used
if getattr(self._estimator, "returns_ensemble", False): # If using ensemble
logits = logits.mean(dim=1) # Average over ensemble dimension
if logits.dim() == 1: # Ensure correct shape
logits = logits.unsqueeze(1)
# Check the shape of the logits to determine binary or multi-class classification
if logits.shape[1] == 1:
# Binary classification
probabilities = torch.sigmoid(logits)
predictions = (probabilities > 0.5).long().view(-1)
else:
# Multi-class classification
probabilities = torch.softmax(logits, dim=1)
predictions = torch.argmax(probabilities, dim=1)
# Convert predictions to NumPy array and return
predicted_indices = predictions.cpu().numpy()
classes = getattr(self, "classes_", None)
if classes is not None and len(classes) > 0:
result = classes[predicted_indices]
else:
result = predicted_indices
self._emit_event("predict_completed")
return result
def predict_proba(self, X, embeddings=None, device=None):
"""Predicts class probabilities for the given input samples.
Parameters
----------
X : DataFrame or array-like, shape (n_samples, n_features)
The input samples for which to predict class probabilities.
Returns
-------
probabilities : ndarray, shape (n_samples, n_classes)
The predicted class probabilities.
"""
X = self._validate_predict_input(X)
if self._task_model is None:
raise not_fitted_error(type(self).__name__, "predict_proba")
# Preprocess the data using the data module
if self._data_module is None:
raise not_fitted_error(type(self).__name__, "predict_proba")
self._data_module.assign_predict_dataset(X, embeddings)
# Set model to evaluation mode
self._task_model.eval()
# Perform inference using PyTorch Lightning's predict function
if self._trainer is None:
raise not_fitted_error(type(self).__name__, "predict_proba")
logits_list = self._trainer.predict(self._task_model, self._data_module) # type: ignore[arg-type]
# Concatenate predictions from all batches
logits = torch.cat(logits_list, dim=0) # type: ignore[arg-type]
# Check if ensemble is used
if getattr(self._estimator, "returns_ensemble", False): # If using ensemble
logits = logits.mean(dim=1) # Average over ensemble dimension
if logits.dim() == 1: # Ensure correct shape
logits = logits.unsqueeze(1)
# Compute probabilities
if logits.shape[1] > 1:
probabilities = torch.softmax(logits, dim=1) # Multi-class classification
else:
positive = torch.sigmoid(logits).view(-1, 1)
probabilities = torch.cat([1.0 - positive, positive], dim=1)
# Convert probabilities to NumPy array and return
return probabilities.cpu().numpy()
def evaluate(self, X, y_true, embeddings=None, metrics=None):
"""Evaluate the model on the given data using specified metrics.
Parameters
----------
X : array-like or pd.DataFrame of shape (n_samples, n_features)
The input samples to predict.
y_true : array-like of shape (n_samples,)
The true class labels.
embeddings : array-like or list, optional
Embeddings for unstructured data inputs.
metrics : dict, optional
A ``{name: callable}`` dictionary where each callable has the
signature ``metric(y_true, y_pred) -> float``. Each callable may
be a :class:`~deeptab.metrics.DeepTabMetric` instance or any plain
callable. Metrics that need probability scores (e.g. AUROC, LogLoss)
should accept the 2-D ``predict_proba`` output as ``y_pred``;
metrics that need class labels (e.g. Accuracy, F1) should accept
the 1-D ``predict`` output.
For :class:`~deeptab.metrics.DeepTabMetric` instances, the method
inspects the ``name`` attribute to decide which prediction format
to supply: probability-based metrics (``auroc``, ``auprc``,
``log_loss``, ``brier``, ``ece``) receive ``predict_proba`` output;
all others receive ``predict`` output.
If ``None``, defaults to the registry defaults for
``"classification"`` (Accuracy, AUROC, LogLoss).
Returns
-------
scores : dict
``{metric_name: score}`` dictionary.
"""
if metrics is None:
metrics = get_default_metrics_dict("classification")
# Metric names that work on probability scores
_PROBA_NAMES = {"auroc", "auprc", "log_loss", "brier", "ece"}
# Determine which prediction types are actually needed
needs_proba = any((getattr(fn, "name", None) in _PROBA_NAMES) for fn in metrics.values())
needs_labels = any((getattr(fn, "name", None) not in _PROBA_NAMES) for fn in metrics.values())
probabilities = self.predict_proba(X, embeddings) if needs_proba else None
predictions = self.predict(X, embeddings) if needs_labels else None
scores = {}
for metric_name, metric_func in metrics.items():
use_proba = getattr(metric_func, "name", None) in _PROBA_NAMES
preds = probabilities if use_proba else predictions
if preds is None:
scores[metric_name] = float("nan")
continue
try:
scores[metric_name] = metric_func(y_true, preds)
except Exception as exc:
warnings.warn(f"Metric '{metric_name}' failed: {exc}", RuntimeWarning, stacklevel=2)
scores[metric_name] = float("nan")
return scores
def score(self, X, y, embeddings=None, metric=None):
"""Calculate the score of the model using the specified metric.
Parameters
----------
X : array-like or pd.DataFrame of shape (n_samples, n_features)
The input samples to predict.
y : array-like of shape (n_samples,)
The true class labels against which to evaluate the predictions.
metric : tuple or callable, optional
A tuple containing the metric function and a boolean indicating whether
the metric requires probability scores (True) or class labels (False).
If omitted, accuracy is used to match scikit-learn classifier behavior.
Returns
-------
score : float
The score calculated using the specified metric.
"""
if metric is None:
return accuracy_score(y, self.predict(X, embeddings))
if isinstance(metric, tuple):
metric_func, use_proba = metric
else:
metric_func, use_proba = metric, False
if use_proba:
probabilities = self.predict_proba(X, embeddings)
return metric_func(y, probabilities)
else:
predictions = self.predict(X, embeddings)
return metric_func(y, predictions)
def pretrain(
self,
pretrain_epochs=15,
k_neighbors=10,
temperature=0.1,
save_path="pretrained_embeddings.pth",
lr=1e-3,
use_positive=True,
use_negative=False,
pool_sequence=True,
):
"""
Pretrains the embedding layer of the model using a contrastive learning approach.
This method performs pretraining by optimizing the embeddings with respect to
neighborhood structure in the feature space. The embeddings are saved after training.
Parameters
----------
pretrain_epochs : int, default=15
Number of epochs to run pretraining.
k_neighbors : int, default=10
Number of neighbors used in the contrastive loss computation.
temperature : float, default=0.1
Temperature parameter for contrastive loss scaling.
save_path : str, default="pretrained_embeddings.pth"
Path to save the pretrained embeddings.
lr : float, default=1e-3
Learning rate for the pretraining optimizer.
use_positive : bool, default=True
Whether to include positive pairs in contrastive learning.
use_negative : bool, default=False
Whether to include negative pairs in contrastive learning.
pool_sequence : bool, default=True
Whether to apply sequence pooling before computing contrastive loss.
Raises
------
ValueError
If the model has not been built before calling this method.
ValueError
If the model does not contain an embedding layer.
Notes
-----
- This function requires that `self.build_model()` has been called beforehand.
- The pretraining method uses `self.task_model.estimator.embedding_layer`.
- The method invokes `super()._pretrain()` with regression mode enabled.
"""
if not self._built:
raise ValueError("The model has not been built yet. Call model.build_model(**args) first.")
if not hasattr(self._task_model.estimator, "embedding_layer"): # type: ignore[union-attr]
raise ValueError("The model does not have an embedding layer")
if self._data_module is None:
raise not_fitted_error(type(self).__name__, "_pretrain")
self._data_module.setup("fit")
super()._pretrain(
self._task_model.estimator, # type: ignore[union-attr]
self._data_module,
pretrain_epochs=pretrain_epochs,
k_neighbors=k_neighbors,
temperature=temperature,
save_path=save_path,
regression=False,
lr=lr,
use_positive=use_positive,
use_negative=use_negative,
pool_sequence=pool_sequence,
)
def optimize_hparams(
self,
X,
y,
X_val=None,
y_val=None,
embeddings=None,
embeddings_val=None,
time=100,
max_epochs=200,
prune_by_epoch=True,
prune_epoch=5,
fixed_params={
"pooling_method": "avg",
"head_skip_layers": False,
"head_layer_size_length": 0,
"cat_encoding": "int",
"head_skip_layer": False,
"use_cls": False,
},
custom_search_space=None,
**optimize_kwargs,
):
"""Optimizes hyperparameters using Bayesian optimization with optional pruning.
Parameters
----------
X : array-like
Training data.
y : array-like
Training labels.
X_val, y_val : array-like, optional
Validation data and labels.
time : int
The number of optimization trials to run.
max_epochs : int
Maximum number of epochs for training.
prune_by_epoch : bool
Whether to prune based on a specific epoch (True) or the best validation loss (False).
prune_epoch : int
The specific epoch to prune by when prune_by_epoch is True.
**optimize_kwargs : dict
Additional keyword arguments passed to the fit method.
Returns
-------
best_hparams : list
Best hyperparameters found during optimization.
"""
return super().optimize_hparams(
X,
y,
regression=False,
X_val=X_val,
y_val=y_val,
embeddings=embeddings,
embeddings_val=embeddings_val,
time=time,
max_epochs=max_epochs,
prune_by_epoch=prune_by_epoch,
prune_epoch=prune_epoch,
fixed_params=fixed_params,
custom_search_space=custom_search_space,
**optimize_kwargs,
)