# Source code for deeptab.models.lss_base

import warnings
from collections.abc import Callable

import lightning as pl
import numpy as np
import torch
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint, ModelSummary
from pretab.preprocessor import Preprocessor
from torch.utils.data import DataLoader
from tqdm import tqdm

from deeptab.core.exceptions import not_fitted_error
from deeptab.core.serialization import _warn_extension, build_save_bundle, restore_base_state, restore_loaded_metadata
from deeptab.core.sklearn_compat import ensure_dataframe, set_input_feature_attributes, validate_input_features
from deeptab.data.datamodule import TabularDataModule
from deeptab.distributions import get_distribution
from deeptab.metrics import get_default_metrics_dict
from deeptab.models.base import SklearnBase, _validate_fit_inputs
from deeptab.training import TaskModel


class SklearnBaseLSS(SklearnBase):
    """Distributional regression base class (LSS variant of SklearnBase).

    Inherits all sklearn compatibility, parameter management, serialization,
    HPO, and observability from ``SklearnBase``. Overrides ``build_model``,
    ``fit``, ``predict``, ``save``, and ``load`` to add LSS-specific concerns:
    distribution family selection, ``lss=True`` flag to ``TaskModel``, and
    distribution-transform post-processing in ``predict``.
    """

    def build_model(
        self,
        X,
        y,
        val_size: float = 0.2,
        X_val=None,
        y_val=None,
        random_state: int = 101,
        batch_size: int = 128,
        shuffle: bool = True,
        lr: float | None = None,
        lr_patience: int | None = None,
        lr_factor: float | None = None,
        weight_decay: float | None = None,
        train_metrics: dict[str, Callable] | None = None,
        val_metrics: dict[str, Callable] | None = None,
        dataloader_kwargs={},
    ):
        """Builds the model using the provided training data.

        Parameters
        ----------
        X : DataFrame or array-like, shape (n_samples, n_features)
            The training input samples.
        y : array-like, shape (n_samples,) or (n_samples, n_targets)
            The target values (real numbers).
        val_size : float, default=0.2
            The proportion of the dataset to include in the validation split if `X_val` is None.
            Ignored if `X_val` is provided.
        X_val : DataFrame or array-like, shape (n_samples, n_features), optional
            The validation input samples. If provided, `X` and `y` are not split and this data is used for validation.
        y_val : array-like, shape (n_samples,) or (n_samples, n_targets), optional
            The validation target values. Required if `X_val` is provided.
        random_state : int, default=101
            Controls the shuffling applied to the data before applying the split.
        batch_size : int, default=64
            Number of samples per gradient update.
        shuffle : bool, default=True
            Whether to shuffle the training data before each epoch.
        lr : float, default=1e-3
            Learning rate for the optimizer.
        lr_patience : int, default=10
            Number of epochs with no improvement on the validation loss to wait before reducing the learning rate.
        lr_factor : float, default=0.1
            Factor by which the learning rate will be reduced.
        train_metrics : dict, default=None
            torch.metrics dict to be logged during training.
        val_metrics : dict, default=None
            torch.metrics dict to be logged during validation.
        weight_decay : float, default=0.025
            Weight decay (L2 penalty) coefficient.
        dataloader_kwargs: dict, default={}
            The kwargs for the pytorch dataloader class.

        Returns
        -------
        self : object
            The built distributional regressor.
        """
        # When trainer_config is active, resolve lr / scheduler params from it
        if self.trainer_config is not None:
            tc = self.trainer_config
            if lr is None:
                lr = tc.lr
            if lr_patience is None:
                lr_patience = tc.lr_patience
            if lr_factor is None:
                lr_factor = tc.lr_factor
            if weight_decay is None:
                weight_decay = tc.weight_decay

        # Re-sync preprocessor from current preprocessing_config state so that
        # direct mutations (e.g. clf.preprocessing_config.n_bins = 8) are
        # honoured on the next fit(), consistent with set_params() behaviour.
        if self.preprocessing_config is not None:
            self._preprocessor_kwargs = self.preprocessing_config.to_preprocessor_kwargs()
            self._preprocessor = Preprocessor(**self._preprocessor_kwargs)

        X = ensure_dataframe(X)
        set_input_feature_attributes(self, X)
        self.classes_ = np.unique(y) if getattr(self, "family_name", None) == "categorical" else None
        if hasattr(y, "values"):
            y = y.values
        if X_val is not None:
            X_val = ensure_dataframe(X_val)
            if y_val is not None and hasattr(y_val, "values"):
                y_val = y_val.values

        self._data_module = TabularDataModule(
            preprocessor=self._preprocessor,
            batch_size=batch_size,
            shuffle=shuffle,
            X_val=X_val,
            y_val=y_val,
            val_size=val_size,
            random_state=random_state,
            regression=getattr(self, "family_name", None) != "categorical",
            **dataloader_kwargs,
        )
        self._data_module.input_columns_ = self.input_columns_

        self._data_module.preprocess_data(X, y, X_val, y_val, val_size=val_size, random_state=random_state)

        # After the first build, self._estimator holds the model *instance*
        # (assigned below). Resolve back to the class so repeated builds
        # (e.g. HPO trials or a refit) construct a fresh model correctly.
        _model_class = self._estimator if isinstance(self._estimator, type) else type(self._estimator)
        self._task_model = TaskModel(
            model_class=_model_class,  # type: ignore
            num_classes=self.family.param_count,
            family=self.family,
            config=self.config,
            feature_information=(
                self._data_module.num_feature_info,
                self._data_module.cat_feature_info,
                self._data_module.embedding_feature_info,
            ),
            lr=lr if lr is not None else getattr(self.config, "lr", None),
            lr_patience=(lr_patience if lr_patience is not None else getattr(self.config, "lr_patience", None)),
            lr_factor=lr_factor if lr_factor is not None else getattr(self.config, "lr_factor", None),
            weight_decay=(weight_decay if weight_decay is not None else getattr(self.config, "weight_decay", None)),
            lss=True,
            train_metrics=train_metrics,
            val_metrics=val_metrics,
            optimizer_type=(  # type: ignore[arg-type]
                self.trainer_config.optimizer_type if self.trainer_config is not None else self._optimizer_type
            ),
            optimizer_args=(
                getattr(self.trainer_config, "optimizer_kwargs", None) or self._optimizer_kwargs
                if self.trainer_config is not None
                else self._optimizer_kwargs
            ),
        )

        self._built = True
        self._estimator = self._task_model.estimator

        return self

    def fit(
        self,
        X,
        y,
        family,
        val_size: float = 0.2,
        X_val=None,
        y_val=None,
        max_epochs: int = 100,
        random_state: int = 101,
        batch_size: int = 128,
        shuffle: bool = True,
        patience: int = 15,
        monitor: str = "val_loss",
        mode: str = "min",
        lr: float | None = None,
        lr_patience: int | None = None,
        lr_factor: float | None = None,
        weight_decay: float | None = None,
        checkpoint_path="model_checkpoints",
        distributional_kwargs=None,
        train_metrics: dict[str, Callable] | None = None,
        val_metrics: dict[str, Callable] | None = None,
        dataloader_kwargs={},
        rebuild=True,
        **trainer_kwargs,
    ):
        """Trains the regression model using the provided training data. Optionally, a separate validation set can be
        used.

        Parameters
        ----------
        X : DataFrame or array-like, shape (n_samples, n_features)
            The training input samples.
        y : array-like, shape (n_samples,) or (n_samples, n_targets)
            The target values (real numbers).
        family : str
            The name of the distribution family to use for the loss function. Examples include 'normal'
            for regression tasks.
        val_size : float, default=0.2
            The proportion of the dataset to include in the validation split if `X_val` is None.
            Ignored if `X_val` is provided.
        X_val : DataFrame or array-like, shape (n_samples, n_features), optional
            The validation input samples. If provided, `X` and `y` are not split and this data is used for validation.
        y_val : array-like, shape (n_samples,) or (n_samples, n_targets), optional
            The validation target values. Required if `X_val` is provided.
        max_epochs : int, default=100
            Maximum number of epochs for training.
        random_state : int, default=101
            Controls the shuffling applied to the data before applying the split.
        batch_size : int, default=64
            Number of samples per gradient update.
        shuffle : bool, default=True
            Whether to shuffle the training data before each epoch.
        patience : int, default=10
            Number of epochs with no improvement on the validation loss to wait before early stopping.
        monitor : str, default="val_loss"
            The metric to monitor for early stopping.
        mode : str, default="min"
            Whether the monitored metric should be minimized (`min`) or maximized (`max`).
        lr : float, default=1e-3
            Learning rate for the optimizer.
        lr_patience : int, default=10
            Number of epochs with no improvement on the validation loss to wait before reducing the learning rate.
        factor : float, default=0.1
            Factor by which the learning rate will be reduced.
        weight_decay : float, default=0.025
            Weight decay (L2 penalty) coefficient.
        distributional_kwargs : dict, default=None
            any arguments taht are specific for a certain distribution.
        train_metrics : dict, default=None
            torch.metrics dict to be logged during training.
        val_metrics : dict, default=None
            torch.metrics dict to be logged during validation.
        checkpoint_path : str, default="model_checkpoints"
            Path where the checkpoints are being saved.
        dataloader_kwargs: dict, default={}
            The kwargs for the pytorch dataloader class.
        **trainer_kwargs : Additional keyword arguments for PyTorch Lightning's Trainer class.


        Returns
        -------
        self : object
            The fitted regressor.
        """
        # When trainer_config is active, override all training-loop params from it
        if self.trainer_config is not None:
            tc = self.trainer_config
            max_epochs = tc.max_epochs
            batch_size = tc.batch_size
            val_size = tc.val_size
            shuffle = tc.shuffle
            patience = tc.patience
            monitor = tc.monitor
            mode = tc.mode
            checkpoint_path = tc.checkpoint_path

        # Validate inputs before any preprocessing or model construction
        _validate_fit_inputs(X, y, regression=True, family=family)

        # When random_state was fixed at construction time, honour it
        if self.random_state is not None:
            random_state = self.random_state

        if distributional_kwargs is None:
            distributional_kwargs = {}

        self.family = get_distribution(family, **distributional_kwargs)
        self.family_name = family

        if rebuild:
            self.build_model(
                X=X,
                y=y,
                val_size=val_size,
                X_val=X_val,
                y_val=y_val,
                random_state=random_state,
                batch_size=batch_size,
                shuffle=shuffle,
                lr=lr,
                lr_patience=lr_patience,
                lr_factor=lr_factor,
                train_metrics=train_metrics,
                val_metrics=val_metrics,
                weight_decay=weight_decay,
                dataloader_kwargs=dataloader_kwargs,
            )

        else:
            if not self._built:
                raise ValueError(
                    "The model must be built before calling the fit method. \
                                 Either call .build_model() or set rebuild=True"
                )

        early_stop_callback = EarlyStopping(
            monitor=monitor, min_delta=0.00, patience=patience, verbose=False, mode=mode
        )

        checkpoint_callback = ModelCheckpoint(
            monitor="val_loss",  # Adjust according to your validation metric
            mode="min",
            save_top_k=1,
            dirpath=checkpoint_path,  # Specify the directory to save checkpoints
            filename="best_model",
        )

        # Initialize the trainer and train the model
        self._trainer = pl.Trainer(
            max_epochs=max_epochs,
            callbacks=[
                early_stop_callback,
                checkpoint_callback,
                ModelSummary(max_depth=2),
            ],
            **trainer_kwargs,
        )
        self._trainer.fit(self._task_model, self._data_module)  # type: ignore

        self._best_model_path = checkpoint_callback.best_model_path
        if self._best_model_path:
            torch.serialization.add_safe_globals([type(self.config)])
            checkpoint = torch.load(self._best_model_path, weights_only=False)
            self._task_model.load_state_dict(checkpoint["state_dict"])  # type: ignore

        self.is_fitted_ = True
        return self

    def predict(self, X, raw=False, device=None):
        """Run inference on ``X`` and return the predictions as a NumPy array.

        Parameters
        ----------
        X : DataFrame or array-like, shape (n_samples, n_features)
            The input samples for which to predict.
        raw : bool, default=False
            When False, the concatenated model output is passed through the
            distribution family before being returned. When True, the model
            output is returned untransformed.
        device : optional
            Accepted for interface compatibility; not used by this method.

        Returns
        -------
        predictions : ndarray
            The (optionally family-transformed) model outputs.
        """
        X = self._validate_predict_input(X)
        task_model = self._task_model
        if task_model is None:
            raise not_fitted_error(type(self).__name__, "predict")

        self._emit_event("predict_started", n_samples=len(X))

        # Hand the new samples to the data module, then switch to eval mode.
        self._data_module.assign_predict_dataset(X)  # type: ignore[union-attr]
        task_model.eval()

        # One tensor per batch from Lightning's predict loop, joined on the batch axis.
        batch_outputs = self._trainer.predict(task_model, self._data_module)  # type: ignore[union-attr, arg-type]
        stacked = torch.cat(batch_outputs, dim=0)  # type: ignore[arg-type]

        # Ensemble estimators carry an extra member axis at dim 1; average it out.
        if getattr(self._estimator, "returns_ensemble", False):
            stacked = stacked.mean(dim=1)

        output = stacked if raw else task_model.family(stacked)  # type: ignore
        result = output.cpu().numpy()
        self._emit_event("predict_completed")
        return result

    def evaluate(self, X, y_true, metrics=None, distribution_family=None):
        """Evaluate the model on the given data using specified metrics.

        Parameters
        ----------
        X : array-like or pd.DataFrame of shape (n_samples, n_features)
            The input samples to predict.
        y_true : array-like of shape (n_samples,)
            The true target values.
        metrics : dict, optional
            A ``{name: callable}`` dictionary of metric functions with signature
            ``metric(y_true, y_pred) -> float``.  When a metric has a truthy
            ``needs_raw`` attribute, raw model outputs are passed instead of
            the family-transformed predictions.
            If ``None``, ``self.get_default_metrics`` is called with the
            distribution family.
        distribution_family : str, optional
            Distribution family key (e.g. ``"normal"``, ``"gamma"``).  When
            ``None`` it is inferred, in order, from the task model's
            ``distribution_family`` attribute, from ``self.family_name`` (set
            by ``fit`` and ``load``), and finally defaults to ``"normal"``.

        Returns
        -------
        scores : dict
            ``{metric_name: score}`` dictionary. A metric that raises is
            reported as ``nan`` and a ``RuntimeWarning`` is emitted.
        """
        # Infer the family from the fitted model. Without the family_name
        # fallback, every model silently got the "normal" default metrics.
        if distribution_family is None:
            distribution_family = (
                getattr(self._task_model, "distribution_family", None)
                or getattr(self, "family_name", None)
                or "normal"
            )

        # Setup default metrics if none are provided
        if metrics is None:
            metrics = self.get_default_metrics(distribution_family)

        # Raw predictions cost a second inference pass, so only compute them on demand.
        needs_any_raw = any(getattr(fn, "needs_raw", False) for fn in metrics.values())
        predictions_transformed = self.predict(X, raw=False)
        predictions_raw = self.predict(X, raw=True) if needs_any_raw else None

        y_true = np.asarray(y_true)
        scores = {}
        for metric_name, metric_func in metrics.items():
            _needs_raw = getattr(metric_func, "needs_raw", False)
            preds = predictions_raw if (_needs_raw and predictions_raw is not None) else predictions_transformed
            # Best-effort by design: one failing metric must not discard the others.
            try:
                scores[metric_name] = metric_func(y_true, preds)
            except Exception as exc:
                warnings.warn(f"Metric '{metric_name}' failed: {exc}", RuntimeWarning, stacklevel=2)
                scores[metric_name] = float("nan")

        return scores

    def get_default_metrics(self, distribution_family):
        """Look up the default evaluation metrics for a distribution family.

        This is a thin wrapper that calls ``get_default_metrics_dict`` with
        the task key ``"lss"`` and the given family.

        Parameters
        ----------
        distribution_family : str
            Distribution family key, e.g. ``"normal"``, ``"gamma"``.

        Returns
        -------
        dict
            Whatever ``get_default_metrics_dict`` returns for this family;
            ``evaluate`` consumes it as a ``{metric_name: callable}`` mapping.
        """
        default_metrics = get_default_metrics_dict("lss", family=distribution_family)
        return default_metrics

    def score(self, X, y, metric="NLL"):
        """Score the model on ``(X, y)`` with the family's negative log-likelihood.

        Parameters
        ----------
        X : array-like or pd.DataFrame of shape (n_samples, n_features)
            The input samples to predict.
        y : array-like of shape (n_samples,) or (n_samples, n_outputs)
            The true target values against which to evaluate the predictions.
        metric : str, default="NLL"
            Currently not consulted: the negative log-likelihood is always
            computed, regardless of the value passed.

        Returns
        -------
        score : float
            The result of ``family.evaluate_nll(y, predictions)``.
        """
        # Predict first so a not-fitted model fails inside predict().
        distribution_params = self.predict(X)
        return self._task_model.family.evaluate_nll(y, distribution_params)  # type: ignore

    def encode(self, X, batch_size=64):
        """Encode input data with the fitted estimator's ``encode`` method.

        The model is switched to evaluation mode and the forward passes run
        under ``torch.no_grad()``. Stochastic layers such as dropout are then
        disabled, and no autograd graph is retained across batches.

        Parameters
        ----------
        X : array-like or DataFrame
            Input data to be encoded.
        batch_size : int, optional, default=64
            Batch size for encoding.

        Returns
        -------
        torch.Tensor
            Encoded representations of the input data, concatenated along
            the first dimension.

        Raises
        ------
        ValueError
            If the model or data module is not fitted.
        AttributeError
            If the underlying estimator has no ``embedding_layer``.
        """
        # Ensure model and data module are initialized
        if self._task_model is None or self._data_module is None:
            raise ValueError("The model or data module has not been fitted yet.")
        estimator = self._task_model.estimator  # type: ignore[union-attr]
        if not hasattr(estimator, "embedding_layer"):
            raise AttributeError(f"{type(estimator).__name__} does not have an embedding_layer.")

        encoded_dataset = self._data_module.preprocess_new_data(X)
        data_loader = DataLoader(encoded_dataset, batch_size=batch_size, shuffle=False)

        # Same inference setup as predict(): eval mode, and no gradient tracking.
        self._task_model.eval()
        encoded_batches = []
        with torch.no_grad():
            # NOTE(review): assumes each batch unpacks into exactly
            # (numerical, categorical) features - confirm against the dataset.
            for num_features, cat_features in tqdm(data_loader):
                encoded_batches.append(estimator.encode(num_features, cat_features))

        # Concatenate all encoded outputs along the batch dimension
        return torch.cat(encoded_batches, dim=0)

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------

    def save(self, path: str) -> None:
        """Write the fitted model to *path* so that :meth:`load` can restore it.

        The artifact is assembled by
        :func:`~deeptab.core.serialization.build_save_bundle` (called with
        ``lss=True`` and the current ``family_name``) and written with
        ``torch.save``. :meth:`load` expects the bundle to carry, among
        others, the model class and config, the task-model state dict, the
        preprocessor, the feature information and the distribution family.

        Parameters
        ----------
        path : str
            Destination file path (e.g. ``"model.pt"``).

        Raises
        ------
        ValueError
            If the model has not been fitted yet (presumably raised by
            ``build_save_bundle`` - confirm against its implementation).

        Examples
        --------
        >>> model = MLPLSS()
        >>> model.fit(X_train, y_train, family="normal")
        >>> model.save("my_lss_model.deeptab")
        >>> loaded = MLPLSS.load("my_lss_model.deeptab")
        >>> predictions = loaded.predict(X_test)
        """
        _warn_extension(path)
        torch.save(build_save_bundle(self, lss=True, family=self.family_name), path)

    @classmethod
    def load(cls, path: str):
        """Rebuild a fitted estimator from a file written by :meth:`save`.

        The estimator class is taken from the bundle itself
        (``bundle["_class"]``), not from ``cls``. The instance is created
        without running ``__init__`` and then populated from the bundle.

        Parameters
        ----------
        path : str
            Path to a file previously written by :meth:`save`.

        Returns
        -------
        estimator
            The reconstructed estimator with its task model in evaluation
            mode and a minimal trainer attached for prediction. Metadata
            attributes are restored by ``restore_loaded_metadata``.

        Examples
        --------
        >>> loaded = MLPLSS.load("my_lss_model.deeptab")
        >>> predictions = loaded.predict(X_test)
        >>> print(loaded.task_info_["family"])
        'normal'
        """
        _warn_extension(path)
        # NOTE: weights_only=False unpickles arbitrary Python objects, so only
        # load bundles that come from a trusted source.
        bundle = torch.load(path, weights_only=False)

        # Allocate the instance without calling __init__, then restore its state.
        estimator_cls = bundle["_class"]
        obj = estimator_cls.__new__(estimator_cls)
        restore_base_state(obj, bundle)
        family_name = bundle["family"]
        obj.family = get_distribution(family_name)
        obj.family_name = family_name

        # Data module: shuffling is always disabled on a loaded model.
        obj._data_module = TabularDataModule(
            preprocessor=bundle["preprocessor"],
            batch_size=bundle["batch_size"],
            shuffle=False,
            regression=bundle["regression"],
        )
        feature_info = bundle["feature_info"]
        num_info = feature_info["num"]
        cat_info = feature_info["cat"]
        emb_info = feature_info["emb"]
        data_module = obj._data_module
        data_module.num_feature_info = num_info
        data_module.cat_feature_info = cat_info
        data_module.embedding_feature_info = emb_info
        data_module.input_columns_ = bundle.get("input_columns")

        # Task model: rebuild the architecture, then load the saved weights.
        obj._task_model = TaskModel(
            model_class=bundle["model_class"],
            config=bundle["config"],
            feature_information=(num_info, cat_info, emb_info),
            num_classes=bundle["num_classes"],
            lss=bundle["lss"],
            family=obj.family,
            optimizer_type=bundle["optimizer_type"],
            optimizer_args=bundle["optimizer_kwargs"],
            lr=bundle["lr"],
            lr_patience=bundle["lr_patience"],
            lr_factor=bundle["lr_factor"],
            weight_decay=bundle["weight_decay"],
        )
        task_model = obj._task_model
        task_model.load_state_dict(bundle["task_model_state_dict"])
        task_model.eval()
        obj._estimator = task_model.estimator

        # A quiet trainer is needed because predict() runs through Lightning.
        obj._trainer = pl.Trainer(
            max_epochs=1,
            enable_progress_bar=False,
            enable_model_summary=False,
            logger=False,
        )
        restore_loaded_metadata(obj, bundle)
        # Re-sync after metadata restoration so the data module uses the
        # estimator's final input column list.
        obj._data_module.input_columns_ = obj.input_columns_

        return obj

    def optimize_hparams(
        self,
        X,
        y,
        X_val=None,
        y_val=None,
        time=100,
        max_epochs=200,
        prune_by_epoch=True,
        prune_epoch=5,
        fixed_params=None,
        custom_search_space=None,
        **optimize_kwargs,
    ):
        """Optimize hyperparameters by delegating to the parent implementation.

        All arguments are forwarded to ``super().optimize_hparams`` together
        with ``regression=False``.

        Parameters
        ----------
        X : array-like
            Training data.
        y : array-like
            Training labels.
        X_val, y_val : array-like, optional
            Validation data and labels.
        time : int
            The number of optimization trials to run.
        max_epochs : int
            Maximum number of epochs for training.
        prune_by_epoch : bool
            Whether to prune based on a specific epoch (True) or the best validation loss (False).
        prune_epoch : int
            The specific epoch to prune by when prune_by_epoch is True.
        fixed_params : dict, optional
            Hyperparameters held fixed during the search. When ``None``, a
            fresh copy of the built-in defaults (see the body) is used on
            every call.
        custom_search_space : optional
            Forwarded unchanged to the parent implementation.
        **optimize_kwargs : dict
            Additional keyword arguments passed to the fit method.
            NOTE(review): ``fit`` requires ``family``; presumably it must be
            supplied here - confirm against the parent implementation.

        Returns
        -------
        best_hparams : list
            Best hyperparameters found during optimization, as returned by
            the parent implementation.
        """
        # Build the defaults per call: a dict literal in the signature would be
        # one shared object that any downstream mutation leaks into later calls.
        if fixed_params is None:
            fixed_params = {
                "pooling_method": "avg",
                "head_skip_layers": False,
                "head_layer_size_length": 0,
                "cat_encoding": "int",
                "head_skip_layer": False,
                "use_cls": False,
            }

        # NOTE(review): regression=False is passed although fit() validates
        # with regression=True - confirm this is intended for LSS models.
        return super().optimize_hparams(
            X,
            y,
            regression=False,
            X_val=X_val,
            y_val=y_val,
            time=time,
            max_epochs=max_epochs,
            prune_by_epoch=prune_by_epoch,
            prune_epoch=prune_epoch,
            fixed_params=fixed_params,
            custom_search_space=custom_search_space,
            **optimize_kwargs,
        )