Source code for deeptab.models.base

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, ClassVar

import lightning as pl
import numpy as np
from pretab.preprocessor import Preprocessor
from sklearn.base import BaseEstimator

from deeptab.configs.core import BaseModelConfig, PreprocessingConfig, TrainerConfig
from deeptab.core.default_factories import DefaultDataModuleFactory, DefaultTaskModelFactory
from deeptab.core.exceptions import (
    DataError,
    target_nan_error,
    target_range_error,
    warn_config,
    warn_data,
    xy_length_mismatch_error,
)
from deeptab.core.inspection import InspectionMixin
from deeptab.core.interfaces import IDataModule, IDataModuleFactory, ITaskModel, ITaskModelFactory
from deeptab.models._mixins import (
    _FitMixin,
    _HyperparameterMixin,
    _ObservabilityMixin,
    _PredictMixin,
    _SerializationMixin,
)

if TYPE_CHECKING:
    from deeptab.core.observability import ObservabilityConfig


def _warn_on_misplaced_configs(model_config, preprocessing_config, trainer_config) -> None:
    """Emit an advisory warning for any config handed to the wrong constructor slot.

    Each argument is compared against the config class its slot expects. A
    warning is issued through ``warn_config`` only when the value is an
    instance of one of the three known config classes (``BaseModelConfig``,
    ``PreprocessingConfig``, ``TrainerConfig``) but *not* of the class that
    slot expects. ``None`` values and objects of any other type (mocks,
    duck-typed test doubles) are skipped. This function never raises on a
    mismatch itself and never moves a config into a different slot.

    Parameters
    ----------
    model_config, preprocessing_config, trainer_config
        The values received by the constructor arguments of the same name.
    """
    recognised_types = (BaseModelConfig, PreprocessingConfig, TrainerConfig)
    expected_by_slot = {
        "model_config": (model_config, BaseModelConfig),
        "preprocessing_config": (preprocessing_config, PreprocessingConfig),
        "trainer_config": (trainer_config, TrainerConfig),
    }

    for slot_name, (candidate, wanted_cls) in expected_by_slot.items():
        if candidate is None:
            continue
        # Objects that are not one of our config classes are left alone so
        # that duck-typed stand-ins keep working without noise.
        if not isinstance(candidate, recognised_types):
            continue
        # Correctly placed config: nothing to report.
        if isinstance(candidate, wanted_cls):
            continue
        warn_config(
            f"{type(candidate).__name__} was passed as '{slot_name}', but '{slot_name}' "
            f"expects a {wanted_cls.__name__}. Configs are not reordered for you, "
            f"so this one will be misused or silently ignored. Pass it as its matching "
            f"argument instead.",
            stacklevel=4,
        )


def _validate_fit_inputs(
    X,
    y,
    regression: bool,
    family: str | None = None,
) -> None:
    """Validate X and y before any preprocessing or model building.

    Raises
    ------
    EmptyDataError
        If X is empty (caught later by ensure_dataframe).
    DataError
        If len(X) != len(y), y contains NaN, or y violates the distribution
        family's range constraint.
    """
    n_X = len(X)
    n_y = len(y)
    if n_X != n_y:
        raise xy_length_mismatch_error(n_X, n_y)

    if hasattr(X, "ndim") and X.ndim == 1:
        raise ValueError(
            "Expected a 2D array for X, got a 1D array instead. "
            "Reshape your data using X.reshape(-1, 1) for a single feature."
        )

    y_arr = np.asarray(y)
    if y_arr.ndim <= 2 and np.issubdtype(y_arr.dtype, np.floating) and np.isnan(y_arr).any():
        raise target_nan_error()

    # Distribution family range constraints
    if family is not None:
        family_lower = family.lower()
        if family_lower in {"poisson", "negativebinom"} and (y_arr < 0).any():
            raise target_range_error(family, "non-negative")
        if family_lower in {"gamma", "inversegaussian"} and (y_arr <= 0).any():
            raise target_range_error(family, "strictly positive")
        if family_lower == "binomial" and not np.all((y_arr == 0) | (y_arr == 1)):
            raise target_range_error(family, "binary (0 or 1)")

    # Warn about high-NaN columns
    if hasattr(X, "isna"):
        nan_rate = X.isna().mean()
        high_nan = nan_rate[nan_rate > 0.5].index.tolist()
        if high_nan:
            warn_data(
                f"Columns with >50% missing values: {[str(c) for c in high_nan]}. "
                "Consider dropping or imputing them before calling fit().",
                stacklevel=5,
            )


class SklearnBase(
    _ObservabilityMixin,
    _FitMixin,
    _PredictMixin,
    _SerializationMixin,
    _HyperparameterMixin,
    InspectionMixin,
    BaseEstimator,
):
    """Thin coordinator — all behaviour lives in the mixins.

    MRO:
        _ObservabilityMixin  →  _FitMixin  →  _PredictMixin
        → _SerializationMixin  →  _HyperparameterMixin
        → InspectionMixin  →  BaseEstimator

    Concrete estimators declare the architecture and its default config class
    via the ``_model_cls`` and ``_config_cls`` class attributes instead of
    passing them through ``__init__``. This keeps the constructor signature
    limited to the public, sklearn-introspectable parameters.

    Parameters
    ----------
    model_config : object, optional
        Architecture config. It becomes ``self.config`` when it exposes
        ``get_params``; otherwise a default ``_config_cls()`` is built.
    preprocessing_config : object, optional
        Preprocessing config. When at least one config is supplied and this
        one is omitted, a default ``PreprocessingConfig()`` is used.
    trainer_config : object, optional
        Training config. When at least one config is supplied and this one is
        omitted, a default ``TrainerConfig()`` is used.
    observability_config : ObservabilityConfig, optional
        Stored on ``_observability_config``. ``configure_observability`` is
        called only when the value exposes a ``structured_logging`` attribute.
    random_state : optional
        Stored unchanged on ``self.random_state``.

    Raises
    ------
    TypeError
        If the concrete class does not define both ``_model_cls`` and
        ``_config_cls``.
    """

    # Set by concrete estimator subclasses (e.g. ``_model_cls = MLP``).
    _model_cls: ClassVar[type | None] = None
    _config_cls: ClassVar[type | None] = None

    def __init__(
        self,
        model_config=None,
        preprocessing_config=None,
        trainer_config=None,
        observability_config: ObservabilityConfig | None = None,
        random_state=None,
    ):
        model_cls = type(self)._model_cls
        config_cls = type(self)._config_cls
        if model_cls is None or config_cls is None:
            raise TypeError(
                f"{type(self).__name__} must define the '_model_cls' and "
                "'_config_cls' class attributes (the architecture class and its "
                "default config class)."
            )
        self.random_state = random_state
        # Parameter names that belong to the preprocessor rather than to the
        # model config; used to route flat parameters in get_params/set_params.
        self._preprocessor_arg_names = [
            "n_bins",
            "feature_preprocessing",
            "numerical_preprocessing",
            "categorical_preprocessing",
            "use_decision_tree_bins",
            "binning_strategy",
            "task",
            "cat_cutoff",
            "treat_all_integers_as_numerical",
            "degree",
            "scaling_strategy",
            "n_knots",
            "use_decision_tree_knots",
            "knots_strategy",
            "spline_implementation",
        ]

        if model_config is not None or preprocessing_config is not None or trainer_config is not None:
            # ---- New split-config path ----
            _warn_on_misplaced_configs(model_config, preprocessing_config, trainer_config)
            self.model_config = model_config
            self.preprocessing_config = (
                preprocessing_config if preprocessing_config is not None else PreprocessingConfig()
            )
            self.trainer_config = trainer_config if trainer_config is not None else TrainerConfig()

            if model_config is not None and hasattr(model_config, "get_params"):
                self._config_kwargs = model_config.get_params(deep=False)
                self.config = model_config
            else:
                self._config_kwargs = {}
                self.config = config_cls()

            if hasattr(self.preprocessing_config, "to_preprocessor_kwargs"):
                self._preprocessor_kwargs = self.preprocessing_config.to_preprocessor_kwargs()
            else:
                self._preprocessor_kwargs = {}
            self._preprocessor = Preprocessor(**self._preprocessor_kwargs)

            self._optimizer_type = getattr(self.trainer_config, "optimizer_type", "Adam")
            self._optimizer_kwargs = {}
        else:
            # ---- No configs provided: fall back to defaults ----
            self.model_config = None
            self.preprocessing_config = None
            self.trainer_config = None

            self._config_kwargs = {}
            self.config = config_cls()

            self._preprocessor_kwargs = {}
            self._preprocessor = Preprocessor()

            self._optimizer_type = "Adam"
            self._optimizer_kwargs = {}

        self._estimator = model_cls
        self._task_model = None
        self._built = False
        # Fitted attributes (_data_module, _trainer, _best_model_path) are
        # initialised here so fit() never *adds* new public attributes.
        # input_columns_ is a proper fitted attribute (trailing _) set only
        # in fit() via set_input_feature_attributes(); not initialised here.
        self._data_module: IDataModule | None = None
        self._trainer: pl.Trainer | None = None
        self._best_model_path: str | None = None
        # Dependency-inversion factories (underscore-prefixed: ignored by
        # sklearn's get_params/set_params; clones always get fresh defaults).
        # Set via direct attribute assignment to inject test doubles:
        #   estimator._data_module_factory = MyFactory()
        self._data_module_factory: IDataModuleFactory = DefaultDataModuleFactory()
        self._task_model_factory: ITaskModelFactory = DefaultTaskModelFactory()
        # Observability — wire up backends if a config was provided.
        # Underscore-prefix: hidden from sklearn get_params/set_params/clone.
        # Only wire up for a genuine ObservabilityConfig; like the model and
        # preprocessing configs above, an unexpected value is stored as-is and
        # validation is deferred rather than raising inside __init__.
        self._observability_config: ObservabilityConfig | None = observability_config
        if observability_config is not None and hasattr(observability_config, "structured_logging"):
            self.configure_observability(observability_config)

    @property
    def config(self):
        """The instantiated model config object backing this estimator.

        Stored on the private ``_config`` attribute so it stays out of
        sklearn's ``get_params``/``__init__`` introspection (it is derived
        from ``model_config``/``_model_cls`` rather than a constructor
        parameter), while remaining readable and settable as ``estimator.config``.
        """
        return self._config

    @config.setter
    def config(self, value):
        self._config = value

    def get_params(self, deep=True):
        """Get parameters for this estimator.

        Parameters
        ----------
        deep : bool, default=True
            In split-config mode, also expose each config's own parameters
            under ``<slot>__<name>`` keys. In no-config mode, also include
            the preprocessor parameters listed in ``_preprocessor_arg_names``.

        Returns
        -------
        dict
            Parameter names mapped to their current values. ``random_state``
            is reported in both modes.
        """
        if self.model_config is not None or self.preprocessing_config is not None or self.trainer_config is not None:
            # New split-config style
            params = {
                "model_config": self.model_config,
                "preprocessing_config": self.preprocessing_config,
                "trainer_config": self.trainer_config,
                "random_state": self.random_state,
            }
            if deep:
                # Flatten each config's own parameters, in slot order.
                for slot in ("model_config", "preprocessing_config", "trainer_config"):
                    slot_config = params[slot]
                    if slot_config is not None and hasattr(slot_config, "get_params"):
                        for k, v in slot_config.get_params(deep=False).items():
                            params[f"{slot}__{k}"] = v
            return params

        # Legacy flat-kwargs style (no config objects were supplied)
        params = {}
        params.update(self._config_kwargs)
        params.update(self._preprocessor_kwargs)
        if deep:
            get_params_fn = getattr(self._preprocessor, "get_params", None)
            if get_params_fn is not None:
                preprocessor_params = {
                    key: value for key, value in get_params_fn().items() if key in self._preprocessor_arg_names
                }
                params.update(preprocessor_params)
        # random_state is a constructor argument, so it has to be reported
        # here as well; otherwise cloning the estimator drops the seed. The
        # attribute is authoritative, hence it is written last.
        params["random_state"] = self.random_state
        return params

    def set_params(self, **parameters):
        """Set the parameters of this estimator.

        In split-config mode, ``model_config``, ``preprocessing_config``,
        ``trainer_config`` and ``random_state`` are assigned directly, and
        ``<slot>__<name>`` keys are forwarded to the matching config's
        ``set_params``. Other keys are ignored in that mode. In no-config
        mode, ``random_state`` is assigned to the estimator, names listed in
        ``_preprocessor_arg_names`` go to the preprocessor, and everything
        else is stored in ``_config_kwargs``.

        Returns
        -------
        self
        """
        if self.model_config is not None or self.preprocessing_config is not None or self.trainer_config is not None:
            # New split-config style
            direct_params = {}
            model_config_params = {}
            preprocessing_config_params = {}
            trainer_config_params = {}

            for k, v in parameters.items():
                if k.startswith("model_config__"):
                    model_config_params[k[len("model_config__") :]] = v
                elif k.startswith("preprocessing_config__"):
                    preprocessing_config_params[k[len("preprocessing_config__") :]] = v
                elif k.startswith("trainer_config__"):
                    trainer_config_params[k[len("trainer_config__") :]] = v
                else:
                    direct_params[k] = v

            # Whole-object replacements are applied first so that nested
            # updates below target the newly assigned configs.
            for k, v in direct_params.items():
                if k == "model_config":
                    self.model_config = v
                    if v is not None and hasattr(v, "get_params"):
                        self.config = v
                        self._config_kwargs = v.get_params(deep=False)
                elif k == "preprocessing_config":
                    self.preprocessing_config = v
                    if v is not None and hasattr(v, "to_preprocessor_kwargs"):
                        self._preprocessor_kwargs = v.to_preprocessor_kwargs()
                        self._preprocessor = Preprocessor(**self._preprocessor_kwargs)
                elif k == "trainer_config":
                    self.trainer_config = v
                    if v is not None and hasattr(v, "optimizer_type"):
                        self._optimizer_type = v.optimizer_type
                elif k == "random_state":
                    self.random_state = v

            if model_config_params and self.model_config is not None and hasattr(self.model_config, "set_params"):
                self.model_config.set_params(**model_config_params)
                self._config_kwargs = self.model_config.get_params(deep=False)
            if (
                preprocessing_config_params
                and self.preprocessing_config is not None
                and hasattr(self.preprocessing_config, "set_params")
            ):
                self.preprocessing_config.set_params(**preprocessing_config_params)
                # Same duck-typing guard as in __init__ and in the direct
                # branch above: rebuild only when the config can produce
                # preprocessor kwargs.
                if hasattr(self.preprocessing_config, "to_preprocessor_kwargs"):
                    self._preprocessor_kwargs = self.preprocessing_config.to_preprocessor_kwargs()
                    self._preprocessor = Preprocessor(**self._preprocessor_kwargs)
            if trainer_config_params and self.trainer_config is not None and hasattr(self.trainer_config, "set_params"):
                self.trainer_config.set_params(**trainer_config_params)
                # Keep the current optimizer when the config has no
                # optimizer_type attribute instead of raising AttributeError.
                self._optimizer_type = getattr(self.trainer_config, "optimizer_type", self._optimizer_type)

            return self

        # Legacy flat-kwargs style (no config objects were supplied)
        if "random_state" in parameters:
            # random_state lives on the estimator itself, not in the model
            # config kwargs.
            self.random_state = parameters["random_state"]

        config_params = {
            k: v for k, v in parameters.items() if k != "random_state" and k not in self._preprocessor_arg_names
        }
        preprocessor_params = {k: v for k, v in parameters.items() if k in self._preprocessor_arg_names}

        if config_params:
            self._config_kwargs.update(config_params)

        if preprocessor_params:
            self._preprocessor_kwargs.update(preprocessor_params)
            self._preprocessor.set_params(**self._preprocessor_kwargs)  # type: ignore[attr-defined]

        return self

    def __sklearn_is_fitted__(self) -> bool:
        """sklearn hook: report whether the estimator has been fitted.

        Returns ``True`` only when the ``is_fitted_`` attribute exists and is
        truthy. Declaring this method makes sklearn's ``check_is_fitted``
        rely on this flag instead of scanning for attributes ending with
        ``_``.
        """
        return bool(getattr(self, "is_fitted_", False))

    def __getstate__(self):
        """Return a copy of the instance ``__dict__`` with the task model dropped."""
        state = self.__dict__.copy()
        # The task model lives on ``_task_model``; blank that key so the
        # model object itself is not serialized with the estimator.
        state["_task_model"] = None
        return state

    def __setstate__(self, state):
        """Restore the instance ``__dict__`` and reset the task model to ``None``."""
        self.__dict__.update(state)
        self._task_model = None  # Reinitialize task model