# Source code for deeptab.models._mixins.serialization

"""Save and load logic for all DeepTab estimators.

The :meth:`save` / :meth:`load` pair is the canonical persistence
mechanism.  Standard :mod:`pickle` is intentionally **not** supported:
``__getstate__`` clears ``task_model`` to avoid serialising Lightning
modules, so a pickled estimator cannot make predictions after
unpickling.  Use :meth:`save` / :meth:`load` for all persistence needs.
"""

from __future__ import annotations

from typing import TYPE_CHECKING

import lightning as pl
import torch

from deeptab.core.default_factories import DefaultDataModuleFactory, DefaultTaskModelFactory
from deeptab.core.serialization import _warn_extension, build_save_bundle, restore_base_state, restore_loaded_metadata


class _SerializationMixin:
    """Bundle-based model persistence.

    Provides :meth:`save` and the classmethod :meth:`load` as the
    sole supported persistence mechanism for fitted DeepTab estimators.
    The bundle format is defined by
    :func:`~deeptab.core.serialization.build_save_bundle` and contains
    all state needed for inference: architecture config, neural-network
    weights, fitted preprocessor, feature schema, column order, task
    metadata, and a version snapshot.

    Note
    ----
    :class:`pickle` is **not** supported.  ``__getstate__`` intentionally
    clears ``task_model`` to prevent serialising Lightning modules.  Always
    use :meth:`save` / :meth:`load` instead.

    Warning
    -------
    Bundles are written with :func:`torch.save` and read back with
    ``torch.load(..., weights_only=False)``, i.e. through the full pickle
    machinery.  Only load bundles that come from a trusted source.
    """

    if TYPE_CHECKING:
        # _emit_event is provided at runtime by _ObservabilityMixin via the MRO.
        # The stub here lets type-checkers resolve the call sites in save/load.
        def _emit_event(self, event: str, **kwargs) -> None: ...

    def save(self, path: str | None = None) -> str:
        """Save the fitted model to *path*.

        The bundle written by this method can be restored with
        :meth:`load`.  It contains all state required for inference:
        architecture/config, neural-network weights, fitted preprocessing
        state, feature schema, column order, task metadata, classifier
        classes (when available), and package versions for debugging
        reloads across environments.

        ``save_started`` and ``save_completed`` events are emitted around
        the write via ``_emit_event``.

        Parameters
        ----------
        path : str or None, default=None
            Destination file path (e.g. ``"model.pt"``).  When ``None``
            and a run directory is active (i.e. ``configure_observability``
            was called with a config that creates a run dir), the model is
            saved to ``<run_dir>/artifacts/model.deeptab`` automatically and
            the ``artifacts`` directory is created if needed.  When no run
            dir is active either, raises ``ValueError``.  For an explicit
            *path* the parent directory is **not** created.

        Returns
        -------
        str
            The resolved path the bundle was written to.

        Raises
        ------
        ValueError
            If *path* is ``None`` and no run directory is active.  An
            unfitted model is expected to be rejected by
            ``build_save_bundle``; this method performs no fitted-state
            check of its own.

        Examples
        --------
        >>> model = MLPClassifier()
        >>> model.fit(X_train, y_train)
        >>> saved_path = model.save("my_model.deeptab")
        >>> loaded = MLPClassifier.load(saved_path)
        >>> predictions = loaded.predict(X_test)
        """
        import os

        if path is None:
            # Fall back to the observability run directory, if one exists.
            run_dir = getattr(self, "_run_dir", None)
            if not run_dir:
                raise ValueError(
                    "path is required when no run directory is active. "
                    "Either pass an explicit path to save() or call "
                    "configure_observability() before fit() to enable run tracking."
                )
            path = os.path.join(run_dir, "artifacts", "model.deeptab")
            os.makedirs(os.path.dirname(path), exist_ok=True)

        self._emit_event("save_started", path=path)
        _warn_extension(path)
        # This mixin always writes a non-LSS bundle (lss=False, family=None).
        bundle = build_save_bundle(self, lss=False, family=None)
        torch.save(bundle, path)
        self._emit_event("save_completed", path=path)
        return path

    @classmethod
    def load(cls, path: str):
        """Load and return a fitted model from *path*.

        The estimator is rebuilt without calling ``__init__``: base state is
        restored from the bundle, a data module and task model are created
        through the estimator's factories (falling back to the default
        factories), the saved weights are loaded, and the task model is put
        into eval mode.  A ``load_completed`` event is emitted at the end.

        Parameters
        ----------
        path : str
            Path to a file previously written by :meth:`save`.

        Returns
        -------
        estimator
            A fully reconstructed, ready-to-predict estimator of the same
            type that was saved.  The concrete class is taken from the
            bundle (``bundle["_class"]``), not from the class ``load`` was
            called on.

        Warning
        -------
        The file is unpickled with ``weights_only=False``, which can execute
        arbitrary code embedded in the file.  Only load trusted bundles.

        Notes
        -----
        Tensors in the bundle are mapped to CPU while deserialising so that
        a bundle written on an accelerator host can be restored on a
        CPU-only machine.

        Examples
        --------
        >>> loaded = MLPClassifier.load("my_model.deeptab")
        >>> predictions = loaded.predict(X_test)
        >>> print(loaded.task_info_["task"])
        'classification'
        >>> print(loaded.n_features_in_)
        6
        """
        _warn_extension(path)
        # SECURITY: weights_only=False is required because the bundle holds
        # arbitrary Python objects (estimator class, preprocessor, config),
        # but it means torch.load runs the full pickle machinery.  Never call
        # this on files from an untrusted source.
        # map_location="cpu" keeps deserialisation independent of the device
        # the bundle was written from; the weights are copied into a freshly
        # built model below, so their storage device does not matter.
        bundle = torch.load(path, map_location="cpu", weights_only=False)

        # Instantiate the recorded estimator class without running __init__.
        estimator_cls = bundle["_class"]
        obj = estimator_cls.__new__(estimator_cls)
        restore_base_state(obj, bundle)

        # load() bypasses __init__, so factories may not be set yet.
        # Initialise them to production defaults before using them.
        if getattr(obj, "_data_module_factory", None) is None:
            obj._data_module_factory = DefaultDataModuleFactory()
        if getattr(obj, "_task_model_factory", None) is None:
            obj._task_model_factory = DefaultTaskModelFactory()

        feature_info = bundle["feature_info"]
        num_info = feature_info["num"]
        cat_info = feature_info["cat"]
        emb_info = feature_info["emb"]

        data_module = obj._data_module_factory.create(
            preprocessor=bundle["preprocessor"],
            batch_size=bundle["batch_size"],
            shuffle=False,
            regression=bundle["regression"],
        )
        data_module.num_feature_info = num_info
        data_module.cat_feature_info = cat_info
        data_module.embedding_feature_info = emb_info
        # Provisional value; re-synchronised after restore_loaded_metadata().
        data_module.input_columns_ = bundle.get("input_columns")
        obj._data_module = data_module

        task_model = obj._task_model_factory.create(
            model_class=bundle["model_class"],
            config=bundle["config"],
            feature_information=(num_info, cat_info, emb_info),
            num_classes=bundle["num_classes"],
            lss=bundle["lss"],
            family=bundle["family"],
            optimizer_type=bundle["optimizer_type"],
            optimizer_args=bundle["optimizer_kwargs"],
            lr=bundle["lr"],
            lr_patience=bundle["lr_patience"],
            lr_factor=bundle["lr_factor"],
            weight_decay=bundle["weight_decay"],
        )
        obj._task_model = task_model
        task_model.load_state_dict(bundle["task_model_state_dict"])
        task_model.eval()
        obj._estimator = task_model.estimator

        # Quiet trainer: no progress bar, model summary or logger.
        obj._trainer = pl.Trainer(
            max_epochs=1,
            enable_progress_bar=False,
            enable_model_summary=False,
            logger=False,
        )
        restore_loaded_metadata(obj, bundle)
        # Keep the data module's column order in sync with the estimator's
        # attribute as it stands after the metadata restore.
        obj._data_module.input_columns_ = obj.input_columns_

        obj._emit_event("load_completed", path=path)
        return obj