Source code for deeptab.models._mixins.fit

"""Model construction and training-loop logic for all DeepTab estimators."""

from __future__ import annotations

import os
import re
import time
import uuid
from collections.abc import Callable
from dataclasses import fields as dataclass_fields
from dataclasses import is_dataclass
from typing import TYPE_CHECKING, Any

import lightning as pl
import numpy as np
import torch
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint, ModelSummary
from pretab.preprocessor import Preprocessor

from deeptab.core.sklearn_compat import ensure_dataframe, set_input_feature_attributes
from deeptab.training import pretrain_embeddings

if TYPE_CHECKING:
    from deeptab.configs import PreprocessingConfig, TrainerConfig
    from deeptab.core.default_factories import DefaultDataModuleFactory, DefaultTaskModelFactory
    from deeptab.core.observability import ObservabilityConfig
    from deeptab.models._mixins.observability import _SupportsInfo


def _build_trainer_loggers(
    obs_config: ObservabilityConfig | None,
    run_dir_name: str | None = None,
) -> bool | list[Any]:
    """Return Lightning loggers derived from *obs_config*.

    Returns ``False`` (no logger) when no experiment trackers are configured
    so that Lightning never writes a spurious ``lightning_logs/`` directory.
    Returns a list of loggers when trackers are active.
    """
    if obs_config is None or not obs_config.experiment_trackers:
        return False  # suppress Lightning's default CSVLogger
    from deeptab.core.observability import build_lightning_loggers

    loggers = build_lightning_loggers(obs_config, run_dir_name=run_dir_name)
    return loggers if loggers else False


class _FitMixin:
    # ---------------------------------------------------------------------------
    # Attributes provided by SklearnBase when this mixin is composed.
    # Declared here for static type-checkers only; never initialised in this class.
    # ---------------------------------------------------------------------------
    if TYPE_CHECKING:
        random_state: int | None
        trainer_config: TrainerConfig | None
        preprocessing_config: PreprocessingConfig | None
        config: Any
        input_columns_: list[str] | None
        _data_module_factory: DefaultDataModuleFactory
        _task_model_factory: DefaultTaskModelFactory
        _optimizer_type: str | None
        _optimizer_kwargs: dict | None
        _event_logger: _SupportsInfo | None

        def _emit_event(self, event: str, **kwargs: Any) -> None: ...

    """Model construction and training loop.

    Responsibilities
    ----------------
    * ``_build_model`` — creates and configures the ``IDataModule`` and
      ``ITaskModel`` collaborators using the injected factories.
    * ``fit`` — orchestrates data validation, model construction, Lightning
      Trainer setup, weight checkpointing, and best-weight restoration.
    * ``get_number_of_params`` — counts trainable / total parameters after a
      model has been built.
    * ``_pretrain`` — contrastive pre-training pass (optional, used for
      embedding warm-start).
    """

    # ------------------------------------------------------------------
    # Model construction
    # ------------------------------------------------------------------

    def _build_model(
        self,
        X,
        y,
        regression: bool,
        val_size: float = 0.2,
        X_val=None,
        y_val=None,
        embeddings=None,
        embeddings_val=None,
        num_classes: int | None = None,
        random_state: int = 101,
        batch_size: int = 128,
        shuffle: bool = True,
        stratify: bool = True,
        lr: float | None = None,
        lr_patience: int | None = None,
        lr_factor: float | None = None,
        weight_decay: float | None = None,
        train_metrics: dict[str, Callable] | None = None,
        val_metrics: dict[str, Callable] | None = None,
        dataloader_kwargs={},
        loss_fct: Callable | None = None,
        sampler=None,
    ):
        """Builds the model using the provided training data."""
        # When trainer_config is active, use its values for lr / weight_decay / scheduler
        if self.trainer_config is not None:
            tc = self.trainer_config
            if lr is None:
                lr = tc.lr
            if lr_patience is None:
                lr_patience = tc.lr_patience
            if lr_factor is None:
                lr_factor = tc.lr_factor
            if weight_decay is None:
                weight_decay = tc.weight_decay

        # Collect new scheduler/optimizer fields from TrainerConfig
        _tc = self.trainer_config
        _scheduler_type = (
            getattr(_tc, "scheduler_type", "ReduceLROnPlateau") if _tc is not None else "ReduceLROnPlateau"
        )
        _scheduler_kwargs = getattr(_tc, "scheduler_kwargs", None) if _tc is not None else None
        _scheduler_monitor = getattr(_tc, "scheduler_monitor", None) if _tc is not None else None
        _scheduler_interval = getattr(_tc, "scheduler_interval", "epoch") if _tc is not None else "epoch"
        _scheduler_frequency = getattr(_tc, "scheduler_frequency", 1) if _tc is not None else 1
        _no_wd_bn = getattr(_tc, "no_weight_decay_for_bias_and_norm", False) if _tc is not None else False
        _optimizer_kwargs = getattr(_tc, "optimizer_kwargs", None) if _tc is not None else None

        # Re-sync preprocessor from current preprocessing_config state so that
        # direct mutations (e.g. clf.preprocessing_config.n_bins = 8) are
        # honoured on the next fit(), consistent with set_params() behaviour.
        if self.preprocessing_config is not None:
            self._preprocessor_kwargs = self.preprocessing_config.to_preprocessor_kwargs()
            self._preprocessor = Preprocessor(**self._preprocessor_kwargs)

        X = ensure_dataframe(X)
        set_input_feature_attributes(self, X)
        if hasattr(y, "values"):
            y = y.values
        if X_val is not None:
            X_val = ensure_dataframe(X_val)
            if y_val is not None and hasattr(y_val, "values"):
                y_val = y_val.values

        self._data_module = self._data_module_factory.create(
            preprocessor=self._preprocessor,
            batch_size=batch_size,
            shuffle=shuffle,
            X_val=X_val,
            y_val=y_val,
            val_size=val_size,
            random_state=random_state,
            regression=regression,
            stratify=stratify,
            sampler=sampler,
            **dataloader_kwargs,
        )
        # Insert timer start for data module before preprocess_data call
        _t_data = time.monotonic()
        self._data_module.input_columns_ = self.input_columns_

        self._data_module.preprocess_data(
            X,
            y,
            X_val=X_val,
            y_val=y_val,
            embeddings_train=embeddings,
            embeddings_val=embeddings_val,
            val_size=val_size,
            random_state=random_state,
        )
        _dm = self._data_module
        _n_train = len(_dm.y_train) if getattr(_dm, "y_train", None) is not None else None  # type: ignore[union-attr]
        _n_val = len(_dm.y_val) if getattr(_dm, "y_val", None) is not None else None  # type: ignore[union-attr]
        _n_num = len(_dm.num_feature_info) if getattr(_dm, "num_feature_info", None) is not None else None  # type: ignore[union-attr]
        _n_cat = len(_dm.cat_feature_info) if getattr(_dm, "cat_feature_info", None) is not None else None  # type: ignore[union-attr]
        self._emit_event(
            "data.created",
            n_train=_n_train,
            n_val=_n_val,
            n_num_features=_n_num,
            n_cat_features=_n_cat,
            val_size=val_size,
            duration_min=round((time.monotonic() - _t_data) / 60, 4),
        )

        _t_model = time.monotonic()
        # After the first build, self._estimator holds the model *instance*
        # (assigned below). Resolve back to the class so repeated builds
        # (e.g. HPO trials or a refit) construct a fresh model correctly.
        _model_class = self._estimator if isinstance(self._estimator, type) else type(self._estimator)
        self._task_model = self._task_model_factory.create(
            model_class=_model_class,  # type: ignore
            config=self.config,
            feature_information=(
                self._data_module.num_feature_info,  # type: ignore[arg-type]
                self._data_module.cat_feature_info,  # type: ignore[arg-type]
                self._data_module.embedding_feature_info,  # type: ignore[arg-type]
            ),
            lr=lr if lr is not None else getattr(self.config, "lr", None),
            lr_patience=(lr_patience if lr_patience is not None else getattr(self.config, "lr_patience", None)),
            lr_factor=lr_factor if lr_factor is not None else getattr(self.config, "lr_factor", None),
            weight_decay=(weight_decay if weight_decay is not None else getattr(self.config, "weight_decay", None)),
            num_classes=num_classes,  # type: ignore[arg-type]
            train_metrics=train_metrics,
            val_metrics=val_metrics,
            optimizer_type=(
                self.trainer_config.optimizer_type if self.trainer_config is not None else self._optimizer_type
            ),
            optimizer_args=_optimizer_kwargs if _optimizer_kwargs is not None else self._optimizer_kwargs,
            scheduler_type=_scheduler_type,
            scheduler_kwargs=_scheduler_kwargs,
            monitor=_scheduler_monitor
            if _scheduler_monitor is not None
            else (
                getattr(self.trainer_config, "monitor", "val_loss") if self.trainer_config is not None else "val_loss"
            ),
            mode=getattr(self.trainer_config, "mode", "min") if self.trainer_config is not None else "min",
            scheduler_interval=_scheduler_interval,
            scheduler_frequency=_scheduler_frequency,
            no_weight_decay_for_bias_and_norm=_no_wd_bn,
            loss_fct=loss_fct,
        )

        self._built = True
        self._estimator = self._task_model.estimator
        _n_params_build = sum(p.numel() for p in self._task_model.parameters() if p.requires_grad)
        self._emit_event(
            "model.created",
            backbone=type(self._estimator).__name__,
            n_params=_n_params_build,
            n_num_features=_n_num,
            n_cat_features=_n_cat,
            duration_min=round((time.monotonic() - _t_model) / 60, 4),
        )

        return self

    def get_number_of_params(self, requires_grad=True):
        """Calculate the number of parameters in the model.

        Parameters
        ----------
        requires_grad : bool, optional
            If True, only count the parameters that require gradients (trainable parameters).
            If False, count all parameters. Default is True.

        Returns
        -------
        int
            The total number of parameters in the model.

        Raises
        ------
        ValueError
            If the model has not been built prior to calling this method.
        """
        if not self._built:
            raise ValueError("The model must be built before the number of parameters can be estimated")
        if requires_grad:
            return sum(p.numel() for p in self._task_model.parameters() if p.requires_grad)  # type: ignore
        return sum(p.numel() for p in self._task_model.parameters())  # type: ignore

    # ------------------------------------------------------------------
    # Training loop
    # ------------------------------------------------------------------

    def fit(
        self,
        X,
        y,
        regression: bool,
        val_size: float = 0.2,
        X_val=None,
        y_val=None,
        embeddings=None,
        embeddings_val=None,
        num_classes: int | None = None,
        max_epochs: int = 100,
        random_state: int = 101,
        batch_size: int = 128,
        shuffle: bool = True,
        stratify: bool = True,
        patience: int = 15,
        monitor: str = "val_loss",
        mode: str = "min",
        lr: float | None = None,
        lr_patience: int | None = None,
        lr_factor: float | None = None,
        weight_decay: float | None = None,
        checkpoint_path="model_checkpoints",
        dataloader_kwargs={},
        train_metrics: dict[str, Callable] | None = None,
        val_metrics: dict[str, Callable] | None = None,
        rebuild=True,
        loss_fct: Callable | None = None,
        sampler=None,
        **trainer_kwargs,
    ):
        """Trains the model using the provided training data.

        Parameters
        ----------
        X : DataFrame or array-like, shape (n_samples, n_features)
            The training input samples.
        y : array-like, shape (n_samples,) or (n_samples, n_targets)
            The target values.
        regression : bool
            Whether this is a regression task.
        val_size : float, default=0.2
            Proportion of the dataset for the validation split when ``X_val``
            is ``None``.
        X_val : DataFrame or array-like, optional
            Explicit validation features.
        y_val : array-like, optional
            Explicit validation targets.
        embeddings : array-like, optional
            Pre-computed embeddings for training samples.
        embeddings_val : array-like, optional
            Pre-computed embeddings for validation samples.
        num_classes : int or None, optional
            Number of target classes (classification only).
        max_epochs : int, default=100
            Maximum number of training epochs.
        random_state : int, default=101
            RNG seed for reproducibility.
        batch_size : int, default=128
            Mini-batch size.
        shuffle : bool, default=True
            Whether to shuffle training data each epoch.
        stratify : bool, default=True
            Whether to stratify the validation split on ``y`` for classification
            tasks so the split keeps the same class proportions. Ignored for
            regression. When a ``TrainerConfig`` is set, its ``stratify`` value
            takes precedence.
        patience : int, default=15
            Early-stopping patience (epochs without validation improvement).
        monitor : str, default="val_loss"
            Metric to monitor for early stopping.
        mode : str, default="min"
            Whether the monitored metric should be minimised (``"min"``) or
            maximised (``"max"``).
        lr : float or None, optional
            Learning rate override.
        lr_patience : int or None, optional
            LR scheduler patience override.
        lr_factor : float or None, optional
            LR scheduler reduction factor override.
        weight_decay : float or None, optional
            Weight-decay (L2 penalty) override.
        checkpoint_path : str, default="model_checkpoints"
            Directory for Lightning checkpoints.
        dataloader_kwargs : dict, default={}
            Extra kwargs forwarded to the PyTorch DataLoader.
        train_metrics : dict or None, optional
            TorchMetrics to log during training.
        val_metrics : dict or None, optional
            TorchMetrics to log during validation.
        rebuild : bool, default=True
            Whether to rebuild the model when already built.
        loss_fct : Callable or None, optional
            Custom loss function override.
        sampler : {"balanced", True}, array-like, or None, optional
            Weighted-sampling specification.
        **trainer_kwargs
            Additional keyword arguments forwarded to ``pl.Trainer``.

        Returns
        -------
        self
        """
        # When trainer_config is active, override all training-loop params from it
        if self.trainer_config is not None:
            tc = self.trainer_config
            max_epochs = tc.max_epochs
            batch_size = tc.batch_size
            val_size = tc.val_size
            shuffle = tc.shuffle
            stratify = tc.stratify
            patience = tc.patience
            monitor = tc.monitor
            mode = tc.mode
            checkpoint_path = tc.checkpoint_path

        # Validate inputs before any preprocessing or model construction
        from deeptab.models.base import _validate_fit_inputs

        _validate_fit_inputs(X, y, regression=regression)

        # When random_state was fixed at construction time, honour it
        if self.random_state is not None:
            random_state = self.random_state

        # Seed all RNGs so that weight init, dropout masks, and DataLoader
        # shuffling are all deterministic when a random_state is provided.
        if random_state is not None:
            from deeptab.core.reproducibility import set_seed

            set_seed(random_state)

        # Generate a short unique run id for this fit() call so that
        # concurrent/repeated runs are distinguishable in the event log.
        self._run_id = uuid.uuid4().hex[:8]
        self._fit_start_ms = time.monotonic()

        # ---------------------------------------------------------------
        # Per-run output directory
        # Create a run directory whenever an ObservabilityConfig is present
        # so that ModelCheckpoint always writes into <run_dir>/checkpoints/
        # instead of the fallback global 'model_checkpoints/' directory.
        # ---------------------------------------------------------------
        _obs_config = getattr(self, "_observability_config", None)
        _run_dir_name: str | None = None
        self._run_dir = None
        if _obs_config is not None:
            from deeptab.core.observability import create_run_dir, write_run_config

            self._run_dir, _run_dir_name = create_run_dir(_obs_config, self._run_id)
            # Write config.yaml to the run directory.
            try:
                write_run_config(self._run_dir, self.get_params())  # type: ignore[attr-defined]
            except Exception:  # noqa: S110
                pass
            # (Re-)build the per-run structured logger so lifecycle.jsonl
            # lands inside this run's directory.
            if _obs_config.structured_logging:
                from deeptab.core.observability import build_structlog_logger

                self._event_logger = build_structlog_logger(_obs_config, run_dir=self._run_dir)

        self._emit_event(
            "fit.started",
            model_class=type(self).__name__,
            n_samples=len(X),
            n_features=X.shape[1] if hasattr(X, "shape") else len(X.columns),
            random_state=getattr(self, "random_state", None),
        )

        if rebuild:
            self._build_model(
                X=X,
                y=y,
                regression=regression,
                val_size=val_size,
                X_val=X_val,
                y_val=y_val,
                embeddings=embeddings,
                embeddings_val=embeddings_val,
                num_classes=num_classes,
                random_state=random_state,  # type: ignore[arg-type]
                batch_size=batch_size,
                shuffle=shuffle,
                stratify=stratify,
                lr=lr,
                lr_patience=lr_patience,
                lr_factor=lr_factor,
                weight_decay=weight_decay,
                dataloader_kwargs=dataloader_kwargs,
                train_metrics=train_metrics,
                val_metrics=val_metrics,
                loss_fct=loss_fct,
                sampler=sampler,
            )
        else:
            if not self._built:
                raise ValueError(
                    "The model must be built before calling the fit method. "
                    "Either call .build_model() or set rebuild=True"
                )

        # n_params computed in _build_model and emitted via model.created;
        # recalculate here for _log_run_metadata_to_mlflow and fit.completed.
        _n_params = sum(p.numel() for p in self._task_model.parameters() if p.requires_grad)  # type: ignore[union-attr]

        early_stop_callback = EarlyStopping(
            monitor=monitor, min_delta=0.00, patience=patience, verbose=False, mode=mode
        )

        checkpoint_callback = ModelCheckpoint(
            monitor="val_loss",
            mode="min",
            save_top_k=1,
            # Use the per-run checkpoints/ sub-directory when a run dir exists.
            # When no run dir is active (no observability config), use a temp
            # directory so no model_checkpoints/ folder is left behind.
            dirpath=os.path.join(self._run_dir, "checkpoints") if self._run_dir else None,
            filename="best_model",
        )

        self._trainer = pl.Trainer(
            max_epochs=max_epochs,
            callbacks=[
                early_stop_callback,
                checkpoint_callback,
                ModelSummary(max_depth=2),
            ],
            # Let an explicit `logger=` in trainer_kwargs override our default.
            logger=trainer_kwargs.pop(
                "logger",
                _build_trainer_loggers(getattr(self, "_observability_config", None), _run_dir_name),
            ),
            **trainer_kwargs,
        )
        self._task_model.train()  # type: ignore[union-attr]
        self._task_model.estimator.train()  # type: ignore[union-attr]

        _t_train = time.monotonic()
        self._emit_event(
            "train.started",
            max_epochs=max_epochs,
            batch_size=batch_size,
            lr=lr,
            optimizer=getattr(self.trainer_config, "optimizer_type", None) if self.trainer_config is not None else None,
            patience=patience,
            val_size=val_size,
        )
        self._trainer.fit(self._task_model, self._data_module)  # type: ignore

        self._best_model_path = checkpoint_callback.best_model_path
        if self._best_model_path:
            torch.serialization.add_safe_globals([type(self.config)])
            checkpoint = torch.load(self._best_model_path, weights_only=False)
            self._task_model.load_state_dict(checkpoint["state_dict"])  # type: ignore

        # Parse best epoch from checkpoint filename (epoch=N pattern).
        _best_epoch: int | None = None
        if self._best_model_path:
            _m = re.search(r"epoch=(\d+)", self._best_model_path)
            if _m:
                _best_epoch = int(_m.group(1))
        _best_val_loss = (
            checkpoint_callback.best_model_score.item() if checkpoint_callback.best_model_score is not None else None
        )
        _n_params = sum(p.numel() for p in self._task_model.parameters() if p.requires_grad)  # type: ignore[union-attr]
        self._emit_event(
            "train.completed",
            best_epoch=_best_epoch,
            best_val_loss=_best_val_loss,
            n_epochs_run=getattr(self._trainer, "current_epoch", None),
            duration_min=round((time.monotonic() - _t_train) / 60, 4),
        )

        _total_duration_min = round((time.monotonic() - self._fit_start_ms) / 60, 4)

        # Write per-run summary.json BEFORE MLflow artifact logging so it
        # can be uploaded alongside config.yaml and lifecycle.jsonl.
        if self._run_dir is not None:
            from deeptab.core.observability import write_run_summary

            write_run_summary(
                self._run_dir,
                {
                    "run_id": self._run_id,
                    "model_class": type(self).__name__,
                    "n_params": _n_params,
                    "n_samples": len(X) if hasattr(X, "__len__") else None,
                    "best_val_loss": _best_val_loss,
                    "best_epoch": _best_epoch,
                    "n_epochs_run": getattr(self._trainer, "current_epoch", None),
                    "duration_min": _total_duration_min,
                },
            )

        self.is_fitted_ = True
        self._log_run_metadata_to_mlflow(
            n_samples=len(X) if hasattr(X, "__len__") else None,
            n_features=getattr(self, "n_features_in_", None),
            n_train=getattr(getattr(self, "_data_module", None), "y_train", None),
            n_val=getattr(getattr(self, "_data_module", None), "y_val", None),
            n_params=_n_params,
            best_val_loss=_best_val_loss,
            best_epoch=_best_epoch,
        )
        self._emit_event(
            "fit.completed",
            status="success",
            model_class=type(self).__name__,
            n_params=_n_params,
            best_val_loss=_best_val_loss,
            duration_min=_total_duration_min,
        )
        return self

    def _log_run_metadata_to_mlflow(
        self,
        n_samples: int | None,
        n_features: int | None,
        n_train: Any,
        n_val: Any,
        n_params: int,
        best_val_loss: float | None,
        best_epoch: int | None,
    ) -> None:
        """Log hyperparameters, dataset stats, tags, and run summary to MLflow.

        Called once at the end of ``fit()``.  Does nothing when MLflow is not
        in the active experiment trackers.
        """
        obs = getattr(self, "_observability_config", None)
        if obs is None or "mlflow" not in obs.experiment_trackers:
            return

        try:
            from lightning.pytorch.loggers import MLFlowLogger
        except ImportError:
            return

        # Find the MLFlowLogger that was active during this training run.
        mlflow_logger: Any = next(
            (lg for lg in (getattr(self._trainer, "loggers", None) or []) if isinstance(lg, MLFlowLogger)),
            None,
        )
        if mlflow_logger is None or mlflow_logger.run_id is None:
            return

        run_id: str = mlflow_logger.run_id
        client = mlflow_logger.experiment  # MlflowClient

        # ------------------------------------------------------------------
        # 1. Hyperparameters — model config + trainer config (flat, prefixed)
        # ------------------------------------------------------------------
        params: dict[str, str] = {}

        if is_dataclass(self.config):
            for f in dataclass_fields(self.config):
                v = getattr(self.config, f.name)
                if v is not None:
                    params[f"model/{f.name}"] = str(v)

        tc = getattr(self, "trainer_config", None)
        if tc is not None and is_dataclass(tc):
            for f in dataclass_fields(tc):
                v = getattr(tc, f.name)
                if v is not None:
                    params[f"trainer/{f.name}"] = str(v)

        # ------------------------------------------------------------------
        # 2. Dataset stats
        # ------------------------------------------------------------------
        dm = getattr(self, "_data_module", None)
        _n_train = len(n_train) if n_train is not None else None
        _n_val = len(n_val) if n_val is not None else None
        for k, v in {
            "data/n_samples": n_samples,
            "data/n_features": n_features,
            "data/n_train": _n_train,
            "data/n_val": _n_val,
            "data/n_num_features": len(dm.num_feature_info)
            if dm is not None and getattr(dm, "num_feature_info", None) is not None
            else None,  # type: ignore[union-attr]
            "data/n_cat_features": len(dm.cat_feature_info)
            if dm is not None and getattr(dm, "cat_feature_info", None) is not None
            else None,  # type: ignore[union-attr]
        }.items():
            if v is not None:
                params[k] = str(v)

        # ------------------------------------------------------------------
        # 3. Training summary
        # ------------------------------------------------------------------
        for k, v in {
            "train/n_params": n_params,
            "train/best_epoch": best_epoch,
            "train/best_val_loss": f"{best_val_loss:.6f}" if best_val_loss is not None else None,
        }.items():
            if v is not None:
                params[k] = str(v)

        # Log params in batches of 100 (MLflow API limit per call).
        import mlflow.entities  # type: ignore[import-untyped]

        items = list(params.items())
        for i in range(0, len(items), 100):
            batch = [mlflow.entities.Param(k, v) for k, v in items[i : i + 100]]
            client.log_batch(run_id, params=batch)

        # ------------------------------------------------------------------
        # 4. Tags — model class, deeptab version, task type
        # ------------------------------------------------------------------
        try:
            from deeptab._version import __version__ as _dtv
        except ImportError:
            _dtv = "unknown"

        for tag_key, tag_val in {
            "deeptab.model_class": type(self).__name__,
            "deeptab.version": _dtv,
            "deeptab.random_state": str(getattr(self, "random_state", None)),
        }.items():
            client.set_tag(run_id, tag_key, tag_val)

        # ------------------------------------------------------------------
        # 5. Run artifacts — config.yaml, lifecycle.jsonl, summary.json,
        #    and checkpoints from the per-run directory (when present).
        # ------------------------------------------------------------------
        import os

        _run_dir = getattr(self, "_run_dir", None)
        if _run_dir is not None:
            for fname in ("config.yaml", "config.json", "lifecycle.jsonl", "summary.json"):
                fpath = os.path.join(_run_dir, fname)
                if os.path.exists(fpath):
                    try:
                        client.log_artifact(run_id, fpath)
                    except Exception:  # noqa: S110
                        pass
            ckpt_dir = os.path.join(_run_dir, "checkpoints")
            if os.path.isdir(ckpt_dir):
                for ckpt in os.listdir(ckpt_dir):
                    try:
                        client.log_artifact(run_id, os.path.join(ckpt_dir, ckpt), artifact_path="checkpoints")
                    except Exception:  # noqa: S110
                        pass

    # ------------------------------------------------------------------
    # Pre-training
    # ------------------------------------------------------------------

    def _pretrain(
        self,
        base_model,
        train_dataloader,
        pretrain_epochs=5,
        k_neighbors=5,
        temperature=0.1,
        save_path="pretrained_embeddings.pth",
        regression=True,
        lr=1e-3,
        use_positive=True,
        use_negative=True,
        pool_sequence=True,
    ):
        """Run a contrastive pre-training pass to warm-start embeddings.

        Delegates to :func:`~deeptab.training.pretrain_embeddings`. Call
        this before :meth:`fit` when you want to initialise the backbone
        with representation learning before fine-tuning on the target task.

        Parameters
        ----------
        base_model :
            The backbone model to pre-train.
        train_dataloader : DataLoader
            DataLoader that yields batches of tabular features.
        pretrain_epochs : int, default=5
            Number of contrastive pre-training epochs.
        k_neighbors : int, default=5
            Number of nearest neighbours used to construct positive pairs.
        temperature : float, default=0.1
            Softmax temperature for the contrastive loss.
        save_path : str, default="pretrained_embeddings.pth"
            Path to save the pre-trained weights.
        regression : bool, default=True
            Whether the downstream task is regression.
        lr : float, default=1e-3
            Learning rate for the pre-training optimiser.
        use_positive : bool, default=True
            Whether to include positive-pair terms in the loss.
        use_negative : bool, default=True
            Whether to include negative-pair terms in the loss.
        pool_sequence : bool, default=True
            Whether to pool sequence-dimension embeddings before computing
            the contrastive loss.
        """
        pretrain_embeddings(
            base_model=base_model,
            train_dataloader=train_dataloader,
            pretrain_epochs=pretrain_epochs,
            k_neighbors=k_neighbors,
            temperature=temperature,
            save_path=save_path,
            regression=regression,
            lr=lr,
            use_positive=use_positive,
            use_negative=use_negative,
            pool_sequence=pool_sequence,
        )