# Source code for deeptab.configs.core

from __future__ import annotations

from collections.abc import Callable
from dataclasses import dataclass, field

import torch.nn as nn
from sklearn.base import BaseEstimator

from deeptab.core.exceptions import (
    ConfigWarning,
    IncompatibleParamsError,
    InvalidParamError,
    incompatible_params_error,
    invalid_param_error,
    warn_config,
)

# Accepted values for the validated PreprocessingConfig fields. These are
# intended to mirror the options understood by ``pretab.Preprocessor``;
# ``None`` means "keep the preprocessor's own default".
_VALID_NUMERICAL_PREPROCESSING: frozenset[str | None] = frozenset(
    (
        None,
        "box-cox",
        "minmax",
        "ple",
        "quantile",
        "robust",
        "splines",
        "standardization",
        "yeo-johnson",
    )
)
_VALID_SCALING_STRATEGY: frozenset[str | None] = frozenset((None, "minmax", "robust", "standardization"))
_VALID_BINNING_STRATEGY: frozenset[str | None] = frozenset((None, "kmeans", "quantile", "uniform"))

# Accepted values for BaseModelConfig.cat_encoding.
_VALID_CAT_ENCODING: frozenset[str] = frozenset(("int", "linear", "one-hot"))

# Accepted values for TrainerConfig.mode (direction of the monitored metric).
_VALID_MONITOR_MODE: frozenset[str] = frozenset(("max", "min"))

# Public names exported by a star import of this module.
__all__ = ["BaseModelConfig", "PreprocessingConfig", "TrainerConfig"]


@dataclass
class BaseModelConfig(BaseEstimator):
    """Architecture hyperparameters shared by every DeepTab model.

    Only structural / architectural settings live here. Training-loop settings
    (``lr``, ``weight_decay``, ``max_epochs``, ...) are held by
    :class:`TrainerConfig`, and feature preprocessing settings by
    :class:`PreprocessingConfig`.

    Parameters
    ----------
    use_embeddings : bool, default=False
        Whether to use embedding layers for numerical/categorical features.
    embedding_activation : Callable, default=nn.Identity()
        Activation function applied to embeddings.
    embedding_type : str, default="linear"
        Type of embedding (``"linear"``, ``"plr"``, etc.).
    embedding_bias : bool, default=False
        Whether to add a bias term to embedding layers.
    layer_norm_after_embedding : bool, default=False
        Whether to apply layer normalisation after the embedding layer.
    d_model : int, default=32
        Embedding / model dimensionality. Must be >= 1.
    plr_lite : bool, default=False
        Whether to use the lightweight PLR embedding variant.
    n_frequencies : int, default=48
        Number of frequency components for PLR embeddings. Must be >= 1.
    frequencies_init_scale : float, default=0.01
        Initial scale for PLR frequency components. Must be > 0.
    embedding_projection : bool, default=True
        Whether to apply a linear projection after embeddings.
    batch_norm : bool, default=False
        Whether to use batch normalisation in the model body.
    layer_norm : bool, default=False
        Whether to use layer normalisation in the model body.
    layer_norm_eps : float, default=1e-5
        Epsilon for layer normalisation numerical stability. Must be > 0.
    activation : Callable, default=nn.ReLU()
        Activation function used throughout the model body.
    cat_encoding : str, default="int"
        How categorical features are encoded at the model input
        (``"int"``, ``"one-hot"``, ``"linear"``).

    Notes
    -----
    ``__post_init__`` also validates the following fields when a subclass
    declares them (they are looked up with ``getattr`` and skipped when
    absent or ``None``):

    * ``n_layers`` and ``n_heads`` must be >= 1, and ``d_model`` must be
      divisible by ``n_heads``.
    * ``dropout``, ``attn_dropout``, ``ff_dropout``, ``head_dropout`` and
      ``rnn_dropout`` must be in ``[0, 1)``.
    * ``expand_factor``, ``d_conv``, ``d_state``, ``dim_feedforward`` and
      ``transformer_dim_feedforward`` must be >= 1.
    """

    # --- Embedding parameters ---
    use_embeddings: bool = False
    # The default module instance is shared by all configs (hence the noqa).
    embedding_activation: Callable = nn.Identity()  # noqa: RUF009
    embedding_type: str = "linear"
    embedding_bias: bool = False
    layer_norm_after_embedding: bool = False
    d_model: int = 32
    plr_lite: bool = False
    n_frequencies: int = 48
    frequencies_init_scale: float = 0.01
    embedding_projection: bool = True

    # --- Architecture parameters ---
    batch_norm: bool = False
    layer_norm: bool = False
    layer_norm_eps: float = 1e-05
    activation: Callable = nn.ReLU()  # noqa: RUF009
    cat_encoding: str = "int"

    def __post_init__(self) -> None:  # type: ignore[override]
        """Validate the configuration, raising on the first invalid value.

        Errors are built with ``invalid_param_error`` (single bad value) or
        ``incompatible_params_error`` (``d_model`` / ``n_heads`` mismatch).
        Setting both ``batch_norm`` and ``layer_norm`` emits a config warning
        through ``warn_config`` instead of raising.
        """
        name = type(self).__name__

        # Core width and categorical encoding.
        if self.d_model < 1:
            raise invalid_param_error(name, "d_model", self.d_model, "must be >= 1")
        if self.cat_encoding not in _VALID_CAT_ENCODING:
            raise invalid_param_error(
                name,
                "cat_encoding",
                self.cat_encoding,
                "must be one of the known encoding strategies",
                sorted(_VALID_CAT_ENCODING),
            )

        # Optional fields that only some subclasses declare.
        if (layers := getattr(self, "n_layers", None)) is not None and layers < 1:
            raise invalid_param_error(name, "n_layers", layers, "must be >= 1")

        heads = getattr(self, "n_heads", None)
        if heads is not None:
            if heads < 1:
                raise invalid_param_error(name, "n_heads", heads, "must be >= 1")
            # Multi-head splits d_model evenly across heads.
            if self.d_model % heads != 0:
                raise incompatible_params_error(
                    name,
                    f"d_model ({self.d_model}) must be divisible by n_heads ({heads}).",
                )

        for attr in ("dropout", "attn_dropout", "ff_dropout", "head_dropout", "rnn_dropout"):
            rate = getattr(self, attr, None)
            if rate is None:
                continue
            # Written as a negated chained comparison so NaN is rejected too.
            if not (0.0 <= rate < 1.0):
                raise invalid_param_error(name, attr, rate, "must be in [0, 1)")

        # Constraints on BaseModelConfig's own numeric fields.
        if self.n_frequencies < 1:
            raise invalid_param_error(name, "n_frequencies", self.n_frequencies, "must be >= 1")
        if self.frequencies_init_scale <= 0:
            raise invalid_param_error(
                name, "frequencies_init_scale", self.frequencies_init_scale, "must be > 0"
            )
        if self.layer_norm_eps <= 0:
            raise invalid_param_error(name, "layer_norm_eps", self.layer_norm_eps, "must be > 0")

        # Both normalisations at once is allowed, but flagged.
        if self.batch_norm and self.layer_norm:
            warn_config(
                f"{name}: both batch_norm=True and layer_norm=True are set. "
                "Using both simultaneously is unusual and may produce unexpected results. "
                "Consider enabling only one.",
                stacklevel=3,
            )

        # Optional integer sizes declared by Mamba / RNN / Transformer subclasses.
        for attr in ("expand_factor", "d_conv", "d_state", "dim_feedforward", "transformer_dim_feedforward"):
            size = getattr(self, attr, None)
            if size is not None and size < 1:
                raise invalid_param_error(name, attr, size, "must be >= 1")
@dataclass
class PreprocessingConfig(BaseEstimator):
    """Configuration for input feature preprocessing.

    All fields map directly to arguments accepted by
    ``pretab.preprocessor.Preprocessor``. Using ``None`` for any field leaves
    the preprocessor default in effect: such fields are omitted from
    :meth:`to_preprocessor_kwargs`.

    Parameters
    ----------
    numerical_preprocessing : str or None, default=None
        Strategy for transforming numerical features. Accepted values are
        ``"ple"``, ``"quantile"``, ``"splines"``, ``"standardization"``,
        ``"minmax"``, ``"robust"``, ``"box-cox"`` and ``"yeo-johnson"``.
        ``None`` uses the preprocessor's built-in default.
    categorical_preprocessing : str or None, default=None
        Strategy for transforming categorical features (e.g. ``"int"``,
        ``"one-hot"``). ``None`` uses the preprocessor's built-in default.
    n_bins : int or None, default=None
        Number of bins for numerical binning. Must be >= 2 when given;
        ``None`` uses the preprocessor default.
    feature_preprocessing : str or None, default=None
        General feature-level preprocessing override.
    use_decision_tree_bins : bool or None, default=None
        Whether to use decision-tree-derived bin edges.
    binning_strategy : str or None, default=None
        Strategy for choosing bin edges: ``"uniform"``, ``"quantile"`` or
        ``"kmeans"``.
    task : str or None, default=None
        Task type passed to the preprocessor for task-aware transformations
        (e.g. ``"regression"``, ``"classification"``).
    cat_cutoff : float or None, default=None
        Threshold for treating integer columns as categorical. Must lie in
        the open interval ``(0, 1)`` when given.
    treat_all_integers_as_numerical : bool or None, default=None
        When ``True``, integer columns are never converted to categorical.
    degree : int or None, default=None
        Polynomial / spline degree for numerical feature expansion. Must be
        >= 1 when given.
    scaling_strategy : str or None, default=None
        Scaling method applied to numerical features: ``"standardization"``,
        ``"minmax"`` or ``"robust"``.
    n_knots : int or None, default=None
        Number of knots for spline preprocessing. Must be >= 2 when given.
    use_decision_tree_knots : bool or None, default=None
        Whether to use decision-tree-derived knot positions.
    knots_strategy : str or None, default=None
        Strategy for knot placement.
    spline_implementation : str or None, default=None
        Backend used for spline transformations.

    Notes
    -----
    ``__post_init__`` rejects out-of-range values for the constrained fields
    above by raising the error built by ``invalid_param_error``. Fields
    without a listed constraint are passed through unchecked.
    """

    numerical_preprocessing: str | None = None
    categorical_preprocessing: str | None = None
    n_bins: int | None = None
    feature_preprocessing: str | None = None
    use_decision_tree_bins: bool | None = None
    binning_strategy: str | None = None
    task: str | None = None
    cat_cutoff: float | None = None
    treat_all_integers_as_numerical: bool | None = None
    degree: int | None = None
    scaling_strategy: str | None = None
    n_knots: int | None = None
    use_decision_tree_knots: bool | None = None
    knots_strategy: str | None = None
    spline_implementation: str | None = None

    def __post_init__(self) -> None:  # type: ignore[override]
        """Validate the constrained fields, raising on the first invalid one."""
        if self.numerical_preprocessing not in _VALID_NUMERICAL_PREPROCESSING:
            raise invalid_param_error(
                "PreprocessingConfig",
                "numerical_preprocessing",
                self.numerical_preprocessing,
                "must be one of the known preprocessing methods",
                # None is accepted but not listed, since it is not a "method".
                sorted(x for x in _VALID_NUMERICAL_PREPROCESSING if x is not None),
            )
        if self.n_bins is not None and self.n_bins < 2:
            raise invalid_param_error("PreprocessingConfig", "n_bins", self.n_bins, "must be >= 2")
        if self.n_knots is not None and self.n_knots < 2:
            raise invalid_param_error("PreprocessingConfig", "n_knots", self.n_knots, "must be >= 2")
        if self.scaling_strategy not in _VALID_SCALING_STRATEGY:
            raise invalid_param_error(
                "PreprocessingConfig",
                "scaling_strategy",
                self.scaling_strategy,
                "must be one of the known scaling strategies",
                sorted(x for x in _VALID_SCALING_STRATEGY if x is not None),
            )
        if self.binning_strategy not in _VALID_BINNING_STRATEGY:
            raise invalid_param_error(
                "PreprocessingConfig",
                "binning_strategy",
                self.binning_strategy,
                "must be one of the known binning strategies",
                sorted(x for x in _VALID_BINNING_STRATEGY if x is not None),
            )
        if self.cat_cutoff is not None and not (0.0 < self.cat_cutoff < 1.0):
            raise invalid_param_error(
                "PreprocessingConfig",
                "cat_cutoff",
                self.cat_cutoff,
                "must be in the open interval (0, 1)",
            )
        if self.degree is not None and self.degree < 1:
            raise invalid_param_error("PreprocessingConfig", "degree", self.degree, "must be >= 1")

    def to_preprocessor_kwargs(self) -> dict:
        """Return the non-``None`` fields, suitable for ``Preprocessor(**...)``.

        Returns
        -------
        dict
            Mapping of field name -> value for every field whose value is not
            ``None``. Fields left at ``None`` are dropped, so the preprocessor
            falls back to its own defaults for them.
        """
        return {k: v for k, v in self.get_params(deep=False).items() if v is not None}
@dataclass
class TrainerConfig(BaseEstimator):
    """Configuration for training loop, optimizer, and runtime execution.

    These settings are entirely separate from model architecture. They control
    *how* a model is trained and executed, not *what* the model is.

    Parameters
    ----------
    max_epochs : int, default=100
        Maximum number of training epochs. Must be >= 1.
    batch_size : int, default=128
        Number of samples per gradient update. Must be >= 1.
    val_size : float, default=0.2
        Fraction of the training data held out for validation when no explicit
        validation set is provided. Must lie in the open interval ``(0, 1)``.
    shuffle : bool, default=True
        Whether to shuffle training data before each epoch.
    stratify : bool, default=True
        Whether to stratify the validation split on ``y`` for classification
        tasks, so the train and validation sets keep the same class
        proportions. Has no effect on regression, where a continuous target
        cannot be stratified. Set to ``False`` to draw a purely random split.
    patience : int, default=15
        Number of epochs with no improvement on ``monitor`` before early
        stopping is triggered.
    monitor : str, default="val_loss"
        Metric name to monitor for early stopping and checkpoint selection.
    mode : str, default="min"
        Whether the monitored metric should be minimised (``"min"``) or
        maximised (``"max"``).
    lr : float, default=1e-4
        Learning rate for the optimizer. Must be > 0.
    lr_patience : int, default=10
        Number of epochs with no improvement before the learning rate is
        reduced by ``lr_factor``. Must be >= 1.
    lr_factor : float, default=0.1
        Multiplicative factor applied to the learning rate when patience is
        exceeded. Must lie in the open interval ``(0, 1)``.
    weight_decay : float, default=1e-6
        L2 regularisation coefficient (weight decay) for the optimizer.
        Must be >= 0.
    optimizer_type : str, default="Adam"
        Optimizer class name. Must be a valid ``torch.optim`` class name or a
        name registered in the project's optimizer registry.
    optimizer_kwargs : dict or None, default=None
        Extra keyword arguments forwarded to the optimizer constructor.
    scheduler_type : str or None, default="ReduceLROnPlateau"
        LR-scheduler class name (case-insensitive), or ``None`` / ``"none"``
        to disable the scheduler entirely.
    scheduler_kwargs : dict or None, default=None
        Extra keyword arguments forwarded to the scheduler constructor.
        ``factor`` and ``patience`` are synthesised from ``lr_factor`` and
        ``lr_patience`` for ``ReduceLROnPlateau`` when absent here.
    scheduler_monitor : str or None, default=None
        Metric name for the scheduler to monitor. Falls back to the value of
        ``monitor`` when ``None``.
    scheduler_interval : str, default="epoch"
        Lightning scheduling granularity: ``"epoch"`` or ``"step"``.
    scheduler_frequency : int, default=1
        How often the scheduler steps at the given interval. Must be >= 1.
    no_weight_decay_for_bias_and_norm : bool, default=False
        When ``True``, bias vectors and normalisation-layer scale/shift
        parameters receive zero weight decay. Recommended for transformer-
        style models with ``LayerNorm``.
    checkpoint_path : str, default="model_checkpoints"
        Directory where PyTorch Lightning model checkpoints are saved.

    Notes
    -----
    ``__post_init__`` raises the error built by ``invalid_param_error`` for
    out-of-range values. It emits a config warning through ``warn_config``
    when ``patience >= max_epochs``. It also warns when
    ``lr_patience >= max_epochs`` and ``lr_patience`` is actually consumed,
    that is, when the scheduler is ``ReduceLROnPlateau`` and
    ``scheduler_kwargs`` does not provide its own ``patience``.
    """

    max_epochs: int = 100
    batch_size: int = 128
    val_size: float = 0.2
    shuffle: bool = True
    stratify: bool = True
    patience: int = 15
    monitor: str = "val_loss"
    mode: str = "min"
    lr: float = 1e-4
    lr_patience: int = 10
    lr_factor: float = 0.1
    weight_decay: float = 1e-6
    optimizer_type: str = "Adam"
    optimizer_kwargs: dict | None = None
    scheduler_type: str | None = "ReduceLROnPlateau"
    scheduler_kwargs: dict | None = None
    scheduler_monitor: str | None = None
    scheduler_interval: str = "epoch"
    scheduler_frequency: int = 1
    no_weight_decay_for_bias_and_norm: bool = False
    checkpoint_path: str = "model_checkpoints"

    def __post_init__(self) -> None:  # type: ignore[override]
        """Validate the training configuration.

        Raises on the first invalid value. Emits a warning for settings that
        are legal but cannot take effect within ``max_epochs``.
        """
        if self.max_epochs < 1:
            raise invalid_param_error("TrainerConfig", "max_epochs", self.max_epochs, "must be >= 1")
        if self.batch_size < 1:
            raise invalid_param_error("TrainerConfig", "batch_size", self.batch_size, "must be >= 1")
        if self.lr <= 0:
            raise invalid_param_error("TrainerConfig", "lr", self.lr, "must be > 0")
        if self.weight_decay < 0:
            raise invalid_param_error("TrainerConfig", "weight_decay", self.weight_decay, "must be >= 0")
        if not (0.0 < self.val_size < 1.0):
            raise invalid_param_error(
                "TrainerConfig",
                "val_size",
                self.val_size,
                "must be in the open interval (0, 1)",
            )
        if self.mode not in _VALID_MONITOR_MODE:
            raise invalid_param_error(
                "TrainerConfig",
                "mode",
                self.mode,
                "must be 'min' or 'max'",
                ["min", "max"],
            )
        if self.lr_patience < 1:
            raise invalid_param_error("TrainerConfig", "lr_patience", self.lr_patience, "must be >= 1")
        if not (0.0 < self.lr_factor < 1.0):
            raise invalid_param_error(
                "TrainerConfig",
                "lr_factor",
                self.lr_factor,
                "must be in the open interval (0, 1)",
            )
        if self.patience >= self.max_epochs:
            warn_config(
                f"TrainerConfig: patience={self.patience} >= "
                f"max_epochs={self.max_epochs}. "
                "Early stopping will never trigger before training ends. "
                "Consider reducing patience or increasing max_epochs.",
                stacklevel=3,
            )
        # lr_patience is only synthesised into ReduceLROnPlateau's ``patience``,
        # and only when scheduler_kwargs does not supply one. With the scheduler
        # disabled (None / "none"), another scheduler, or an explicit patience,
        # the value is unused and the warning would be spurious.
        scheduler_name = self.scheduler_type.lower() if isinstance(self.scheduler_type, str) else None
        lr_patience_in_use = scheduler_name == "reducelronplateau" and "patience" not in (
            self.scheduler_kwargs or {}
        )
        if self.lr_patience >= self.max_epochs and lr_patience_in_use:
            warn_config(
                f"TrainerConfig: lr_patience={self.lr_patience} >= "
                f"max_epochs={self.max_epochs}. "
                "The learning rate scheduler will never reduce the LR before training ends. "
                "Consider reducing lr_patience or increasing max_epochs.",
                stacklevel=3,
            )
        if self.scheduler_interval not in {"epoch", "step"}:
            raise invalid_param_error(
                "TrainerConfig",
                "scheduler_interval",
                self.scheduler_interval,
                "must be 'epoch' or 'step'",
                ["epoch", "step"],
            )
        if self.scheduler_frequency < 1:
            raise invalid_param_error(
                "TrainerConfig",
                "scheduler_frequency",
                self.scheduler_frequency,
                "must be >= 1",
            )