# Source code for deeptab.models.utils.sklearn_parent

import warnings
from collections.abc import Callable

import lightning as pl
import pandas as pd
import torch
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint, ModelSummary
from pretab.preprocessor import Preprocessor
from sklearn.base import BaseEstimator
from skopt import gp_minimize
from torch.utils.data import DataLoader
from tqdm import tqdm

from ...base_models.utils.lightning_wrapper import TaskModel
from ...base_models.utils.pretraining import pretrain_embeddings
from ...data_utils.datamodule import MambularDataModule
from ...utils.config_mapper import activation_mapper, get_search_space, round_to_nearest_16


class SklearnBase(BaseEstimator):
    """Scikit-learn style base estimator wrapping a deeptab network and its Lightning training loop.

    Constructor keyword arguments are routed as follows:

    * names listed in ``_PREPROCESSOR_ARG_NAMES`` configure the ``pretab`` ``Preprocessor``;
    * ``optimizer_type`` selects the optimizer (default ``"Adam"``) and every other key starting
      with ``"optimizer_"`` is collected into ``optimizer_kwargs``;
    * every remaining key that does not start with ``"optimizer"`` is passed to the ``config`` class.

    Subclasses are expected to implement ``predict`` and, for :meth:`optimize_hparams`, ``score``.
    """

    # Keyword arguments forwarded to the Preprocessor instead of the model config.
    _PREPROCESSOR_ARG_NAMES = (
        "n_bins",
        "feature_preprocessing",
        "numerical_preprocessing",
        "categorical_preprocessing",
        "use_decision_tree_bins",
        "binning_strategy",
        "task",
        "cat_cutoff",
        "treat_all_integers_as_numerical",
        "degree",
        "scaling_strategy",
        "n_knots",
        "use_decision_tree_knots",
        "knots_strategy",
        "spline_implementation",
    )

    # Keys never collected as extra optimizer arguments, even if they start with "optimizer_".
    _NON_OPTIMIZER_ARG_NAMES = ("lr", "weight_decay", "patience", "lr_patience", "optimizer_type")

    # Hyperparameters held fixed by optimize_hparams when the caller passes no ``fixed_params``.
    _DEFAULT_FIXED_HPARAMS = {
        "pooling_method": "avg",
        "head_skip_layers": False,
        "head_layer_size_length": 0,
        "cat_encoding": "int",
        "head_skip_layer": False,
        "use_cls": False,
    }

    def __init__(self, model, config, **kwargs):
        # Instance copy kept for code that reads ``preprocessor_arg_names`` directly.
        self.preprocessor_arg_names = list(self._PREPROCESSOR_ARG_NAMES)

        self.config_kwargs, self.preprocessor_kwargs, self.optimizer_kwargs = self._partition_kwargs(kwargs)
        self.config = config(**self.config_kwargs)
        self.preprocessor = Preprocessor(**self.preprocessor_kwargs)

        # Holds the network *class* until _build_model replaces it with the built instance.
        self.estimator = model
        self.task_model = None
        self.data_module = None
        self.best_model_path = None
        self.built = False

        self.optimizer_type = kwargs.get("optimizer_type", "Adam")

    @classmethod
    def _partition_kwargs(cls, kwargs):
        """Split keyword arguments into ``(config_kwargs, preprocessor_kwargs, optimizer_kwargs)``.

        Keys starting with ``"optimizer"`` never reach the config; of those, only keys starting
        with ``"optimizer_"`` (other than ``"optimizer_type"``) become extra optimizer arguments.
        """
        config_kwargs = {}
        preprocessor_kwargs = {}
        optimizer_kwargs = {}
        for key, value in kwargs.items():
            if key in cls._PREPROCESSOR_ARG_NAMES:
                preprocessor_kwargs[key] = value
            elif key.startswith("optimizer"):
                if key.startswith("optimizer_") and key not in cls._NON_OPTIMIZER_ARG_NAMES:
                    optimizer_kwargs[key] = value
            else:
                config_kwargs[key] = value
        return config_kwargs, preprocessor_kwargs, optimizer_kwargs

    def get_params(self, deep=True):
        """Get parameters for this estimator.

        The result contains the config kwargs, the preprocessor kwargs and the optimizer settings
        (``optimizer_type`` plus any ``optimizer_*`` extras), so that ``sklearn.base.clone`` can
        rebuild an equivalent estimator. With ``deep=True``, values reported by the preprocessor's
        own ``get_params`` (restricted to known preprocessor argument names) take precedence.
        """
        params = dict(self.config_kwargs)
        params.update(self.preprocessor_kwargs)
        params["optimizer_type"] = self.optimizer_type
        params.update(self.optimizer_kwargs)
        if deep:
            get_params_fn = getattr(self.preprocessor, "get_params", None)
            if get_params_fn is not None:
                params.update(
                    {key: value for key, value in get_params_fn().items() if key in self._PREPROCESSOR_ARG_NAMES}
                )
        return params

    def set_params(self, **parameters):
        """Set the parameters of this estimator.

        Parameters are routed exactly like constructor kwargs. Config parameters are also written
        onto the live ``self.config`` so they take effect on the next build, not only after a clone.

        Returns
        -------
        self
        """
        config_params, preprocessor_params, optimizer_params = self._partition_kwargs(parameters)

        if config_params:
            self.config_kwargs.update(config_params)
            for key, value in config_params.items():
                setattr(self.config, key, value)

        if preprocessor_params:
            self.preprocessor_kwargs.update(preprocessor_params)
            self.preprocessor.set_params(**self.preprocessor_kwargs)  # type: ignore[attr-defined]

        if "optimizer_type" in parameters:
            self.optimizer_type = parameters["optimizer_type"]
        if optimizer_params:
            self.optimizer_kwargs.update(optimizer_params)

        return self

    def __getstate__(self):
        """Return the pickling state with ``task_model`` dropped (it is not serialized)."""
        state = self.__dict__.copy()
        state["task_model"] = None  # Avoid serializing the task model
        return state

    def __setstate__(self, state):
        """Restore pickled state; ``task_model`` is always ``None`` afterwards."""
        self.__dict__.update(state)
        self.task_model = None  # Reinitialize task model

    def _get_model_class(self):
        """Return the network class, whether ``self.estimator`` holds the class or a built instance."""
        return self.estimator if isinstance(self.estimator, type) else type(self.estimator)

    def _load_best_checkpoint(self):
        """Load the weights of the best checkpoint from the last ``fit`` into ``task_model``, if any."""
        best_model_path = getattr(self, "best_model_path", None)
        if not best_model_path:
            return
        torch.serialization.add_safe_globals([type(self.config)])
        checkpoint = torch.load(best_model_path, weights_only=False)
        self.task_model.load_state_dict(checkpoint["state_dict"])  # type: ignore[union-attr]

    def _build_model(
        self,
        X,
        y,
        regression: bool,
        val_size: float = 0.2,
        X_val=None,
        y_val=None,
        embeddings=None,
        embeddings_val=None,
        num_classes: int | None = None,
        random_state: int = 101,
        batch_size: int = 128,
        shuffle: bool = True,
        lr: float | None = None,
        lr_patience: int | None = None,
        lr_factor: float | None = None,
        weight_decay: float | None = None,
        train_metrics: dict[str, Callable] | None = None,
        val_metrics: dict[str, Callable] | None = None,
        dataloader_kwargs: dict | None = None,
    ):
        """Build the data module and the Lightning ``TaskModel`` from the training data.

        Parameters
        ----------
        X : DataFrame or array-like, shape (n_samples, n_features)
            The training input samples.
        y : array-like, shape (n_samples,) or (n_samples, n_targets)
            The target values.
        regression : bool
            Forwarded to the data module.
        val_size : float, default=0.2
            The proportion of the dataset to include in the validation split if `X_val` is None.
            Ignored if `X_val` is provided.
        X_val : DataFrame or array-like, shape (n_samples, n_features), optional
            The validation input samples. If provided, `X` and `y` are not split and this data is used for validation.
        y_val : array-like, shape (n_samples,) or (n_samples, n_targets), optional
            The validation target values. Required if `X_val` is provided.
        embeddings, embeddings_val : optional
            Precomputed embeddings for the training / validation rows, forwarded to the data module.
        num_classes : int, optional
            Forwarded to ``TaskModel``.
        random_state : int, default=101
            Controls the shuffling applied to the data before applying the split.
        batch_size : int, default=128
            Number of samples per gradient update.
        shuffle : bool, default=True
            Whether to shuffle the training data before each epoch.
        lr : float, optional
            Learning rate; falls back to ``config.lr`` when None.
        lr_patience : int, optional
            Epochs without validation improvement before reducing the learning rate;
            falls back to ``config.lr_patience`` when None.
        lr_factor : float, optional
            Factor by which the learning rate is reduced; falls back to ``config.lr_factor`` when None.
        weight_decay : float, optional
            Weight decay coefficient; falls back to ``config.weight_decay`` when None.
        train_metrics : dict, default=None
            torch.metrics dict to be logged during training.
        val_metrics : dict, default=None
            torch.metrics dict to be logged during validation.
        dataloader_kwargs : dict, optional
            Extra kwargs for the data module / PyTorch dataloaders. None means no extra kwargs.

        Returns
        -------
        self : object
            The built estimator.
        """
        if dataloader_kwargs is None:
            dataloader_kwargs = {}

        if not isinstance(X, pd.DataFrame):
            X = pd.DataFrame(X)
        if isinstance(y, pd.Series):
            y = y.values
        if X_val is not None:
            if not isinstance(X_val, pd.DataFrame):
                X_val = pd.DataFrame(X_val)
            if isinstance(y_val, pd.Series):
                y_val = y_val.values

        self.data_module = MambularDataModule(
            preprocessor=self.preprocessor,
            batch_size=batch_size,
            shuffle=shuffle,
            X_val=X_val,
            y_val=y_val,
            val_size=val_size,
            random_state=random_state,
            regression=regression,
            **dataloader_kwargs,
        )

        self.data_module.preprocess_data(
            X,
            y,
            X_val=X_val,
            y_val=y_val,
            embeddings_train=embeddings,
            embeddings_val=embeddings_val,
            val_size=val_size,
            random_state=random_state,
        )

        self.task_model = TaskModel(
            # self.estimator is an instance after the first build, so resolve the class explicitly.
            model_class=self._get_model_class(),  # type: ignore
            config=self.config,
            feature_information=(
                self.data_module.num_feature_info,
                self.data_module.cat_feature_info,
                self.data_module.embedding_feature_info,
            ),
            lr=lr if lr is not None else self.config.lr,
            lr_patience=(lr_patience if lr_patience is not None else self.config.lr_patience),
            lr_factor=lr_factor if lr_factor is not None else self.config.lr_factor,
            weight_decay=(weight_decay if weight_decay is not None else self.config.weight_decay),
            num_classes=num_classes,  # type: ignore[arg-type]
            train_metrics=train_metrics,
            val_metrics=val_metrics,
            optimizer_type=self.optimizer_type,
            optimizer_args=self.optimizer_kwargs,
        )

        self.built = True
        self.estimator = self.task_model.estimator

        return self

    def get_number_of_params(self, requires_grad=True):
        """Calculate the number of parameters in the model.

        Parameters
        ----------
        requires_grad : bool, optional
            If True, only count the parameters that require gradients (trainable parameters).
            If False, count all parameters. Default is True.

        Returns
        -------
        int
            The total number of parameters in the model.

        Raises
        ------
        ValueError
            If the model has not been built prior to calling this method.
        """
        if not self.built:
            raise ValueError("The model must be built before the number of parameters can be estimated")
        parameters = self.task_model.parameters()  # type: ignore[union-attr]
        return sum(p.numel() for p in parameters if p.requires_grad or not requires_grad)

    def fit(
        self,
        X,
        y,
        regression: bool,
        val_size: float = 0.2,
        X_val=None,
        y_val=None,
        embeddings=None,
        embeddings_val=None,
        num_classes: int | None = None,
        max_epochs: int = 100,
        random_state: int = 101,
        batch_size: int = 128,
        shuffle: bool = True,
        patience: int = 15,
        monitor: str = "val_loss",
        mode: str = "min",
        lr: float | None = None,
        lr_patience: int | None = None,
        lr_factor: float | None = None,
        weight_decay: float | None = None,
        checkpoint_path="model_checkpoints",
        dataloader_kwargs: dict | None = None,
        train_metrics: dict[str, Callable] | None = None,
        val_metrics: dict[str, Callable] | None = None,
        rebuild=True,
        **trainer_kwargs,
    ):
        """Train the model on the provided data, optionally with a separate validation set.

        Parameters
        ----------
        X : DataFrame or array-like, shape (n_samples, n_features)
            The training input samples.
        y : array-like, shape (n_samples,) or (n_samples, n_targets)
            The target values.
        regression : bool
            Forwarded to the data module when the model is (re)built.
        val_size : float, default=0.2
            The proportion of the dataset to include in the validation split if `X_val` is None.
            Ignored if `X_val` is provided.
        X_val : DataFrame or array-like, shape (n_samples, n_features), optional
            The validation input samples. If provided, `X` and `y` are not split and this data is used for validation.
        y_val : array-like, shape (n_samples,) or (n_samples, n_targets), optional
            The validation target values. Required if `X_val` is provided.
        embeddings, embeddings_val : optional
            Precomputed embeddings for the training / validation rows.
        num_classes : int, optional
            Forwarded to ``TaskModel`` when the model is (re)built.
        max_epochs : int, default=100
            Maximum number of epochs for training.
        random_state : int, default=101
            Controls the shuffling applied to the data before applying the split.
        batch_size : int, default=128
            Number of samples per gradient update.
        shuffle : bool, default=True
            Whether to shuffle the training data before each epoch.
        patience : int, default=15
            Number of epochs with no improvement of `monitor` to wait before early stopping.
        monitor : str, default="val_loss"
            The metric monitored for early stopping and for selecting the best checkpoint.
        mode : str, default="min"
            Whether the monitored metric should be minimized (`min`) or maximized (`max`).
        lr, lr_patience, lr_factor, weight_decay : optional
            Optimizer / scheduler settings; each falls back to the config value when None.
            Only used when the model is (re)built.
        checkpoint_path : str, default="model_checkpoints"
            Directory where the checkpoints are saved.
        dataloader_kwargs : dict, optional
            Extra kwargs for the data module / PyTorch dataloaders.
        train_metrics : dict, default=None
            torch.metrics dict to be logged during training.
        val_metrics : dict, default=None
            torch.metrics dict to be logged during validation.
        rebuild : bool, default=True
            If True, (re)build the data module and model from `X`/`y` even if already built.
            If False, train the already-built model; raises if nothing has been built.
        **trainer_kwargs
            Additional keyword arguments for PyTorch Lightning's Trainer class.

        Returns
        -------
        self : object
            The fitted estimator.

        Raises
        ------
        ValueError
            If ``rebuild`` is False and the model has not been built.
        """
        if rebuild:
            self._build_model(
                X=X,
                y=y,
                regression=regression,
                val_size=val_size,
                X_val=X_val,
                y_val=y_val,
                embeddings=embeddings,
                embeddings_val=embeddings_val,
                num_classes=num_classes,
                random_state=random_state,
                batch_size=batch_size,
                shuffle=shuffle,
                lr=lr,
                lr_patience=lr_patience,
                lr_factor=lr_factor,
                weight_decay=weight_decay,
                dataloader_kwargs=dataloader_kwargs,
                train_metrics=train_metrics,
                val_metrics=val_metrics,
            )
        elif not self.built:
            raise ValueError(
                "The model must be built before calling the fit method. "
                "Either call .build_model() or set rebuild=True"
            )

        early_stop_callback = EarlyStopping(
            monitor=monitor, min_delta=0.00, patience=patience, verbose=False, mode=mode
        )

        # Select the best checkpoint with the same metric early stopping uses.
        checkpoint_callback = ModelCheckpoint(
            monitor=monitor,
            mode=mode,
            save_top_k=1,
            dirpath=checkpoint_path,
            filename="best_model",
        )

        self.trainer = pl.Trainer(
            max_epochs=max_epochs,
            callbacks=[
                early_stop_callback,
                checkpoint_callback,
                ModelSummary(max_depth=2),
            ],
            **trainer_kwargs,
        )
        self.task_model.train()  # type: ignore[union-attr]
        self.task_model.estimator.train()  # type: ignore[union-attr]
        self.trainer.fit(self.task_model, self.data_module)  # type: ignore

        # Restore the weights of the best epoch rather than keeping the last one.
        self.best_model_path = checkpoint_callback.best_model_path
        self._load_best_checkpoint()

        self.is_fitted_ = True
        return self

    def _score(self, X, y, embeddings, metric):
        """Return ``metric(y, self.predict(X, embeddings))`` using the best checkpoint weights, if any."""
        if hasattr(self, "trainer"):
            # getattr-based lookup: models restored by `load` have a trainer but no checkpoint path.
            self._load_best_checkpoint()

        predictions = self.predict(X, embeddings)

        return metric(y, predictions)

    def predict(self, X, embeddings=None, device=None):
        """Must be implemented by subclasses."""
        raise NotImplementedError("The 'predict' method is not implemented in the Parent class.")

    def encode(self, X, embeddings=None, batch_size=64):
        """Encode input data with the trained network's ``encode`` method.

        The network is put in eval mode and run without gradient tracking (so dropout is disabled
        and the outputs are deterministic); its previous train/eval mode is restored afterwards.

        Parameters
        ----------
        X : array-like or DataFrame
            Input data to be encoded.
        embeddings : optional
            Precomputed embeddings for the rows of `X`, forwarded to the data module.
        batch_size : int, optional, default=64
            Batch size for encoding.

        Returns
        -------
        torch.Tensor
            Encoded representations concatenated along dimension 0.

        Raises
        ------
        ValueError
            If the model or data module is not fitted.
        """
        if getattr(self, "task_model", None) is None or getattr(self, "data_module", None) is None:
            raise ValueError("The model or data module has not been fitted yet.")
        encoded_dataset = self.data_module.preprocess_new_data(X, embeddings)

        data_loader = DataLoader(encoded_dataset, batch_size=batch_size, shuffle=False)

        network = self.task_model.estimator  # type: ignore[union-attr]
        was_training = network.training
        network.eval()
        encoded_batches = []
        try:
            with torch.no_grad():
                for batch in tqdm(data_loader):
                    encoded_batches.append(network.encode(batch))
        finally:
            network.train(was_training)

        return torch.cat(encoded_batches, dim=0)

    def _pretrain(
        self,
        base_model,
        train_dataloader,
        pretrain_epochs=5,
        k_neighbors=5,
        temperature=0.1,
        save_path="pretrained_embeddings.pth",
        regression=True,
        lr=1e-3,
        use_positive=True,
        use_negative=True,
        pool_sequence=True,
    ):
        """Thin wrapper forwarding all arguments to ``pretrain_embeddings``."""
        pretrain_embeddings(
            base_model=base_model,
            train_dataloader=train_dataloader,
            pretrain_epochs=pretrain_epochs,
            k_neighbors=k_neighbors,
            temperature=temperature,
            save_path=save_path,
            regression=regression,
            lr=lr,
            use_positive=use_positive,
            use_negative=use_negative,
            pool_sequence=pool_sequence,
        )

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------

    def save(self, path: str) -> None:
        """Save the fitted model to *path*.

        The bundle written by this method can be restored with
        :meth:`load`.  It contains all state required for inference:
        the config, the fitted preprocessor, feature metadata, and
        the neural-network weights.

        Parameters
        ----------
        path : str
            Destination file path (e.g. ``"model.pt"``).

        Raises
        ------
        ValueError
            If the model has not been fitted yet.
        RuntimeError
            If ``task_model`` is missing on a fitted estimator (e.g. after unpickling).
        """
        if not getattr(self, "is_fitted_", False):
            raise ValueError("Model must be fitted before saving.")
        if self.task_model is None:
            raise RuntimeError("task_model is unexpectedly None after fitting.")
        bundle = {
            "_class": type(self),
            "config": self.config,
            "config_kwargs": self.config_kwargs,
            "preprocessor_kwargs": getattr(self, "preprocessor_kwargs", {}),
            "preprocessor": self.preprocessor,
            "feature_info": {
                "num": self.data_module.num_feature_info,
                "cat": self.data_module.cat_feature_info,
                "emb": self.data_module.embedding_feature_info,
            },
            "batch_size": self.data_module.batch_size,
            "regression": self.data_module.regression,
            "model_class": self._get_model_class(),
            "num_classes": self.task_model.num_classes,
            "lss": False,
            "family": None,
            "optimizer_type": self.optimizer_type,
            "optimizer_kwargs": self.optimizer_kwargs,
            "lr": self.task_model.lr,
            "lr_patience": self.task_model.lr_patience,
            "lr_factor": self.task_model.lr_factor,
            "weight_decay": self.task_model.weight_decay,
            "task_model_state_dict": self.task_model.state_dict(),
        }
        torch.save(bundle, path)

    @classmethod
    def load(cls, path: str):
        """Load and return a fitted model from *path*.

        Security note: the bundle is read with ``weights_only=False``, which unpickles arbitrary
        objects. Only load files from trusted sources.

        Parameters
        ----------
        path : str
            Path to a file previously written by :meth:`save`.

        Returns
        -------
        estimator
            A fully reconstructed, ready-to-predict estimator of the
            same type that was saved (``bundle["_class"]``, not necessarily ``cls``).
        """
        bundle = torch.load(path, weights_only=False)

        obj = bundle["_class"].__new__(bundle["_class"])
        obj.config = bundle["config"]
        obj.config_kwargs = bundle["config_kwargs"]
        obj.preprocessor_kwargs = bundle.get("preprocessor_kwargs", {})
        obj.preprocessor = bundle["preprocessor"]
        obj.optimizer_type = bundle["optimizer_type"]
        obj.optimizer_kwargs = bundle["optimizer_kwargs"]
        obj.built = True
        obj.is_fitted_ = True
        obj.best_model_path = None
        obj.preprocessor_arg_names = list(obj._PREPROCESSOR_ARG_NAMES)

        obj.data_module = MambularDataModule(
            preprocessor=bundle["preprocessor"],
            batch_size=bundle["batch_size"],
            shuffle=False,
            regression=bundle["regression"],
        )
        obj.data_module.num_feature_info = bundle["feature_info"]["num"]
        obj.data_module.cat_feature_info = bundle["feature_info"]["cat"]
        obj.data_module.embedding_feature_info = bundle["feature_info"]["emb"]

        obj.task_model = TaskModel(
            model_class=bundle["model_class"],
            config=bundle["config"],
            feature_information=(
                bundle["feature_info"]["num"],
                bundle["feature_info"]["cat"],
                bundle["feature_info"]["emb"],
            ),
            num_classes=bundle["num_classes"],
            lss=bundle["lss"],
            family=bundle["family"],
            optimizer_type=bundle["optimizer_type"],
            optimizer_args=bundle["optimizer_kwargs"],
            lr=bundle["lr"],
            lr_patience=bundle["lr_patience"],
            lr_factor=bundle["lr_factor"],
            weight_decay=bundle["weight_decay"],
        )
        obj.task_model.load_state_dict(bundle["task_model_state_dict"])
        obj.task_model.eval()
        obj.estimator = obj.task_model.estimator

        obj.trainer = pl.Trainer(
            max_epochs=1,
            enable_progress_bar=False,
            enable_model_summary=False,
            logger=False,
        )

        return obj

    @staticmethod
    def _is_callable_type(field_type):
        """Return True if a dataclass field annotation denotes a callable (e.g. an activation function)."""
        if isinstance(field_type, str):  # postponed annotations are stored as strings
            return field_type in ("callable", "Callable")
        # typing.Callable and Callable[...] aliases expose collections.abc.Callable as __origin__.
        return field_type is callable or getattr(field_type, "__origin__", field_type) is Callable

    def _apply_hparams(self, param_names, values):
        """Write one point of the search space into ``self.config``.

        * ``head_layer_size_length`` truncates the collected ``head_layer_size_*`` values;
        * ``head_layer_size_*`` / ``layer_size_*`` values are rounded with ``round_to_nearest_16`` and
          stored as ``config.head_layer_sizes`` / ``config.layer_sizes`` (``layer_size_*`` only when the
          config declares ``layer_sizes``);
        * string values for callable-typed fields are mapped through ``activation_mapper``;
        * every other value is set on the config as is.

        Raises
        ------
        ValueError
            If a callable-typed field receives a name missing from ``activation_mapper``.
        KeyError
            If a parameter name is not a field of the config dataclass.
        """
        config_fields = self.config.__dataclass_fields__
        head_layer_sizes = []
        layer_sizes = []
        head_layer_size_length = None

        for key, value in zip(param_names, values, strict=False):
            if key == "head_layer_size_length":
                head_layer_size_length = value
            elif key.startswith("head_layer_size_"):
                head_layer_sizes.append(round_to_nearest_16(value))
            elif key.startswith("layer_size_") and "layer_sizes" in config_fields:
                layer_sizes.append(round_to_nearest_16(value))
            else:
                field_type = config_fields[key].type
                if isinstance(value, str) and self._is_callable_type(field_type):
                    if value not in activation_mapper:
                        raise ValueError(f"Unknown activation function: {value}")
                    value = activation_mapper[value]
                setattr(self.config, key, value)

        if head_layer_size_length is not None:
            head_layer_sizes = head_layer_sizes[:head_layer_size_length]
        if head_layer_size_length is not None or head_layer_sizes:
            self.config.head_layer_sizes = head_layer_sizes
        if layer_sizes:
            self.config.layer_sizes = layer_sizes

    def _validation_loss(self, X_val, y_val):
        """Score the current model: ``self.score`` on explicit validation data, else Lightning validation loss."""
        if not callable(getattr(self, "score", None)):
            raise NotImplementedError("The 'score' method is not implemented in the child class.")
        if X_val is not None and y_val is not None:
            return self.score(X_val, y_val)  # type: ignore[attr-defined]
        return self.trainer.validate(self.task_model, self.data_module)[0]["val_loss"]

    def optimize_hparams(
        self,
        X,
        y,
        regression,
        X_val=None,
        y_val=None,
        embeddings=None,
        embeddings_val=None,
        time=100,
        max_epochs=200,
        prune_by_epoch=True,
        prune_epoch=5,
        fixed_params=None,
        custom_search_space=None,
        **optimize_kwargs,
    ):
        """Optimize config hyperparameters with Bayesian optimization (``skopt.gp_minimize``) and pruning.

        The objective is the value returned by ``self.score(X_val, y_val)`` when validation data is
        given, otherwise the Lightning ``val_loss``; it is minimized. After optimization the best
        values are written into ``self.config``, but the model is not refit: call ``fit`` again to
        train with them.

        Parameters
        ----------
        X : array-like
            Training data.
        y : array-like
            Training labels.
        regression : bool
            Forwarded to ``fit`` / ``_build_model``.
        X_val, y_val : array-like, optional
            Validation data and labels.
        embeddings, embeddings_val : optional
            Precomputed embeddings for the training / validation rows.
        time : int
            The number of optimization trials to run (``n_calls`` of ``gp_minimize``).
        max_epochs : int
            Maximum number of epochs for training.
        prune_by_epoch : bool
            Whether to prune based on a specific epoch (True) or the best validation loss (False).
        prune_epoch : int
            The specific epoch to prune by when prune_by_epoch is True.
        fixed_params : dict, optional
            Hyperparameters excluded from the search. None uses ``_DEFAULT_FIXED_HPARAMS``.
        custom_search_space : optional
            Forwarded to ``get_search_space``.
        **optimize_kwargs : dict
            Additional keyword arguments passed to ``_build_model`` for every trial.

        Returns
        -------
        best_hparams : list
            Best hyperparameters found during optimization, in ``get_search_space`` order.
        """
        if fixed_params is None:
            fixed_params = dict(self._DEFAULT_FIXED_HPARAMS)

        param_names, param_space = get_search_space(
            self.config,
            fixed_params=fixed_params,
            custom_search_space=custom_search_space,
        )

        # Baseline fit with the current configuration; its losses seed the pruning thresholds.
        self.fit(
            X,
            y,
            regression=regression,
            X_val=X_val,
            y_val=y_val,
            embeddings=embeddings,
            embeddings_val=embeddings_val,
            max_epochs=max_epochs,
        )
        best_val_loss = self._validation_loss(X_val, y_val)
        best_epoch_val_loss = self.task_model.epoch_val_loss_at(  # type: ignore
            prune_epoch
        )

        def _objective(hyperparams):
            nonlocal best_val_loss, best_epoch_val_loss  # Shared across trials

            self._apply_hparams(param_names, hyperparams)

            self._build_model(
                X,
                y,
                regression=regression,
                X_val=X_val,
                y_val=y_val,
                embeddings=embeddings,
                embeddings_val=embeddings_val,
                lr=self.config.lr,
                **optimize_kwargs,
            )

            if prune_by_epoch:
                early_pruning_threshold = best_epoch_val_loss * 1.5  # Prune based on specific epoch loss
            else:
                early_pruning_threshold = best_val_loss * 1.5  # type: ignore[operator]

            self.task_model.early_pruning_threshold = early_pruning_threshold  # type: ignore
            self.task_model.pruning_epoch = prune_epoch  # type: ignore

            try:
                # The model was just built above, so train it without rebuilding.
                self.fit(
                    X,
                    y,
                    regression=regression,
                    X_val=X_val,
                    y_val=y_val,
                    max_epochs=max_epochs,
                    rebuild=False,
                )

                val_loss = self._validation_loss(X_val, y_val)

                epoch_val_loss = self.task_model.epoch_val_loss_at(  # type: ignore
                    prune_epoch
                )

                if prune_by_epoch and epoch_val_loss < best_epoch_val_loss:
                    best_epoch_val_loss = epoch_val_loss

                if val_loss < best_val_loss:  # type: ignore[operator]
                    best_val_loss = val_loss

                return val_loss

            except Exception as e:
                # Penalize the failing configuration instead of aborting the whole search.
                warnings.warn(
                    f"Error encountered during fit with hyperparameters {hyperparams}: {e}",
                    stacklevel=2,
                )
                return best_val_loss * 100  # type: ignore[operator]

        result = gp_minimize(_objective, param_space, n_calls=time, random_state=42)

        best_hparams = result.x  # type: ignore
        self._apply_hparams(param_names, best_hparams)

        print("Best hyperparameters found:", best_hparams)

        return best_hparams