MambAttention
Hybrid Mamba + Attention architecture. MambAttention interleaves Mamba SSM layers with multi-head self-attention layers, allowing the model to capture both local sequential patterns (via Mamba’s linear-time recurrence) and global dependencies across all features simultaneously (via attention).
When to Use
Use MambAttention when you need the memory efficiency of Mamba for local patterns and the expressiveness of attention for global feature interactions. It is a natural upgrade from either Mambular or FTTransformer when neither alone is sufficient.
Limitations
More hyperparameters than either Mambular or FTTransformer alone.
Higher compute and memory cost than a pure Mamba or pure attention model.
Fewer community benchmarks available; expect more tuning effort.
API Reference
- class deeptab.models.MambAttentionRegressor(*args: Any, **kwargs: Any)[source]
MambAttention regressor. This class extends the SklearnBaseRegressor class and uses the MambAttention model with the default MambAttention configuration.
Notes
The parameters for this class include the attributes from the config dataclass as well as preprocessing arguments handled by the base class.
Configuration class for the Default Mambular Attention model with predefined hyperparameters.
- Parameters:
d_model (int, default=64) – Dimensionality of the model.
n_layers (int, default=4) – Number of layers in the model.
expand_factor (int, default=2) – Expansion factor for the feed-forward layers.
n_heads (int, default=8) – Number of attention heads in the model.
last_layer (str, default="attn") – Type of the last layer (e.g., ‘attn’).
n_mamba_per_attention (int, default=1) – Number of Mamba blocks per attention layer.
bias (bool, default=False) – Whether to use bias in the linear layers.
d_conv (int, default=4) – Kernel size of the convolutional layers.
conv_bias (bool, default=True) – Whether to use bias in the convolutional layers.
dropout (float, default=0.0) – Dropout rate for regularization.
attn_dropout (float, default=0.2) – Dropout rate for the attention mechanism.
dt_rank (str, default="auto") – Rank of the projection that produces the SSM step size Δ (dt).
d_state (int, default=128) – Dimensionality of the state in recurrent layers.
dt_scale (float, default=1.0) – Scaling factor for the step size (dt).
dt_init (str, default="random") – Initialization method for the step size (dt).
dt_max (float, default=0.1) – Maximum value for step-size (dt) initialization.
dt_min (float, default=1e-04) – Minimum value for step-size (dt) initialization.
dt_init_floor (float, default=1e-04) – Floor value for step-size (dt) initialization.
norm (str, default="LayerNorm") – Type of normalization used in the model.
activation (callable, default=nn.SiLU()) – Activation function for the model.
head_layer_sizes (list, default=()) – Sizes of the fully connected layers in the model’s head.
head_dropout (float, default=0.5) – Dropout rate for the head layers.
head_skip_layers (bool, default=False) – Whether to use skip connections in the head layers.
head_activation (callable, default=nn.SELU()) – Activation function for the head layers.
head_use_batch_norm (bool, default=False) – Whether to use batch normalization in the head layers.
pooling_method (str, default="avg") – Pooling method to be used (‘avg’, ‘max’, etc.).
bidirectional (bool, default=False) – Whether to process input sequences bidirectionally.
use_learnable_interaction (bool, default=False) – Whether to use learnable feature interactions before passing through Mamba blocks.
use_cls (bool, default=False) – Whether to append a CLS token for sequence pooling.
shuffle_embeddings (bool, default=False) – Whether to shuffle embeddings before passing to Mamba layers.
cat_encoding (str, default="int") – Encoding method for categorical features (‘int’, ‘one-hot’, etc.).
AD_weight_decay (bool, default=True) – Whether weight decay is applied to A-D matrices.
BC_layer_norm (bool, default=False) – Whether to apply layer normalization to B-C matrices.
use_pscan (bool, default=False) – Whether to use PSCAN for the state-space model.
n_attention_layers (int, default=1) – Number of attention layers in the model.
feature_preprocessing (dict, optional) – Dictionary mapping feature names to specific preprocessing methods. Overrides global defaults.
n_bins (int, default=64) – Number of bins used for binning-based preprocessing (e.g., for discretizers or PLE).
numerical_preprocessing (str, default="ple") – Preprocessing method for numerical features (e.g., “standardization”, “minmax”, “ple”, “rbf”, etc.).
categorical_preprocessing (str, default="int") – Preprocessing method for categorical features (e.g., “int”, “ordinal”, “onehot”).
use_decision_tree_bins (bool, default=False) – Whether to use decision tree binning for numerical discretization.
binning_strategy (str, default="uniform") – Strategy for bin placement when not using tree-based methods. Options: “uniform”, “quantile”.
task (str, default="regression") – Problem type used to guide preprocessing (e.g., “regression” or “classification”).
cat_cutoff (float or int, default=0.03) – Threshold to determine whether integer-valued features are treated as categorical.
treat_all_integers_as_numerical (bool, default=False) – If True, treat all integer-typed columns as numerical regardless of cardinality.
degree (int, default=3) – Degree of polynomial or spline basis functions where applicable.
scaling_strategy (str, default="minmax") – Strategy for feature scaling (e.g., “standardization”, “minmax”, etc.).
n_knots (int, default=64) – Number of knots used in spline-based feature expansions.
use_decision_tree_knots (bool, default=True) – Whether to use decision tree-based knot placement for spline transformations.
knots_strategy (str, default="uniform") – Strategy for placing knots for splines (“uniform” or “quantile”).
spline_implementation (str, default="sklearn") – Which spline backend implementation to use (e.g., “sklearn”, “custom”).
min_unique_vals (int, default=5) – Minimum number of unique values required for a feature to be treated as numerical.
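The dt_* parameters control the step size Δ of the state-space model. As a rough illustration, the sketch below follows the convention of the reference Mamba implementation, where dt_rank="auto" resolves to ceil(d_model / 16) and initial step sizes are drawn log-uniformly between dt_min and dt_max, then clamped at dt_init_floor. Whether deeptab resolves these values in exactly this way is an assumption; the helper names resolve_dt_rank and sample_dt are hypothetical and not part of the library.

```python
import math
import random


def resolve_dt_rank(dt_rank, d_model):
    """Resolve dt_rank; "auto" uses ceil(d_model / 16) (assumed convention)."""
    if dt_rank == "auto":
        return math.ceil(d_model / 16)
    return int(dt_rank)


def sample_dt(n, dt_min=1e-4, dt_max=0.1, dt_init_floor=1e-4, seed=0):
    """Draw n step sizes log-uniformly in [dt_min, dt_max], clamped from below."""
    rng = random.Random(seed)
    log_min, log_max = math.log(dt_min), math.log(dt_max)
    samples = []
    for _ in range(n):
        # Uniform in log space, then mapped back with exp.
        dt = math.exp(rng.random() * (log_max - log_min) + log_min)
        samples.append(max(dt, dt_init_floor))
    return samples


rank = resolve_dt_rank("auto", d_model=64)
dts = sample_dt(5)
print(rank, [round(dt, 5) for dt in dts])
```

With the documented default d_model=64, this convention gives a rank of 4.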
Examples
>>> from deeptab.models import MambAttentionRegressor
>>> model = MambAttentionRegressor(d_model=64, n_layers=8)
>>> model.fit(X_train, y_train)
>>> preds = model.predict(X_test)
>>> model.evaluate(X_test, y_test)
- build_model(X, y, val_size=0.2, X_val=None, y_val=None, embeddings=None, embeddings_val=None, random_state=101, batch_size=128, shuffle=True, lr=None, lr_patience=None, lr_factor=None, weight_decay=None, train_metrics=None, val_metrics=None, dataloader_kwargs={})
Builds the model using the provided training data.
- Parameters:
X (DataFrame or array-like, shape (n_samples, n_features)) – The training input samples.
y (array-like, shape (n_samples,) or (n_samples, n_targets)) – The target values (real numbers).
val_size (float) – The proportion of the dataset to include in the validation split if X_val is None. Ignored if X_val is provided.
X_val (DataFrame or array-like, shape (n_samples, n_features), optional) – The validation input samples. If provided, X and y are not split and this data is used for validation.
y_val (array-like, shape (n_samples,) or (n_samples, n_targets), optional) – The validation target values. Required if X_val is provided.
random_state (int) – Controls the shuffling applied to the data before applying the split.
batch_size (int) – Number of samples per gradient update.
shuffle (bool) – Whether to shuffle the training data before each epoch.
lr (float | None) – Learning rate for the optimizer.
lr_patience (int | None) – Number of epochs with no improvement on the validation loss to wait before reducing the learning rate.
lr_factor (float | None) – Factor by which the learning rate will be reduced.
weight_decay (float | None) – Weight decay (L2 penalty) coefficient.
train_metrics (dict[str, Callable] | None) – Dictionary of metric callables to be logged during training.
val_metrics (dict[str, Callable] | None) – Dictionary of metric callables to be logged during validation.
dataloader_kwargs (dict, default={}) – Keyword arguments for the PyTorch DataLoader class.
- Returns:
self – The built regressor.
- Return type:
object
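To make the interplay of lr, lr_patience and lr_factor concrete, here is a minimal sketch of a reduce-on-plateau rule: the learning rate is multiplied by lr_factor once the validation loss has failed to improve for more than lr_patience consecutive epochs. This is an illustration of the documented parameter semantics only; the function plateau_schedule is hypothetical, and the scheduler deeptab actually uses (including what happens when these arguments are left at None) may differ.

```python
def plateau_schedule(val_losses, lr, lr_patience, lr_factor):
    """Return the learning rate in effect during each epoch."""
    best = float("inf")
    stale = 0  # consecutive epochs without improvement
    history = []
    for loss in val_losses:
        history.append(lr)
        if loss < best:
            best = loss
            stale = 0
        else:
            stale += 1
            if stale > lr_patience:
                lr *= lr_factor  # shrink the learning rate
                stale = 0
    return history


losses = [1.0, 0.8, 0.9, 0.85, 0.81, 0.7]
schedule = plateau_schedule(losses, lr=0.1, lr_patience=2, lr_factor=0.1)
print(schedule)
```

In this run the loss stalls for three epochs after the second one, so only the final epoch trains with the reduced rate.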
- encode(X, embeddings=None, batch_size=64)
Encodes input data using the trained model’s embedding layer.
- Parameters:
X (array-like or DataFrame) – Input data to be encoded.
batch_size (int, optional, default=64) – Batch size for encoding.
- Returns:
Encoded representations of the input data.
- Return type:
torch.Tensor
- Raises:
ValueError – If the model or data module is not fitted.
- evaluate(X, y_true, embeddings=None, metrics=None)
Evaluate the model on the given data using specified metrics.
- Parameters:
X (array-like or pd.DataFrame of shape (n_samples, n_features)) – The input samples to predict.
y_true (array-like of shape (n_samples,) or (n_samples, n_outputs)) – The true target values against which to evaluate the predictions.
metrics (dict) – A dictionary where keys are metric names and values are the metric functions.
Notes
This method uses the predict method to generate predictions and computes each metric.
- Returns:
scores – A dictionary with metric names as keys and their corresponding scores as values.
- Return type:
dict
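As an illustration of the metrics argument, the sketch below builds a dictionary of metric functions and applies each one to a set of predictions. It assumes every metric is called as metric(y_true, y_pred), the signature documented for score(); the helper apply_metrics is hypothetical and only mimics the described behaviour, and y_pred is a stand-in for the output of predict.

```python
import math


def mean_absolute_error(y_true, y_pred):
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


def root_mean_squared_error(y_true, y_pred):
    squared = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    return math.sqrt(squared / len(y_true))


# Keys are metric names, values are the metric functions.
metrics = {"mae": mean_absolute_error, "rmse": root_mean_squared_error}


def apply_metrics(y_true, y_pred, metrics):
    """Compute one score per named metric (hypothetical helper)."""
    return {name: fn(y_true, y_pred) for name, fn in metrics.items()}


y_true = [3.0, 5.0, 2.0]
y_pred = [2.0, 5.0, 4.0]  # stand-in for model.predict(X_test)
scores = apply_metrics(y_true, y_pred, metrics)
print(scores)
```

With a fitted regressor, the same dictionary would be passed as model.evaluate(X_test, y_test, metrics=metrics).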
- fit(X, y, val_size=0.2, X_val=None, y_val=None, embeddings=None, embeddings_val=None, max_epochs=100, random_state=101, batch_size=128, shuffle=True, patience=15, monitor='val_loss', mode='min', lr=None, lr_patience=None, lr_factor=None, weight_decay=None, checkpoint_path='model_checkpoints', dataloader_kwargs={}, train_metrics=None, val_metrics=None, rebuild=True, **trainer_kwargs)
Trains the regression model using the provided training data. Optionally, a separate validation set can be used.
- Parameters:
X (DataFrame or array-like, shape (n_samples, n_features)) – The training input samples.
y (array-like, shape (n_samples,) or (n_samples, n_targets)) – The target values (real numbers).
val_size (float) – The proportion of the dataset to include in the validation split if X_val is None. Ignored if X_val is provided.
X_val (DataFrame or array-like, shape (n_samples, n_features), optional) – The validation input samples. If provided, X and y are not split and this data is used for validation.
y_val (array-like, shape (n_samples,) or (n_samples, n_targets), optional) – The validation target values. Required if X_val is provided.
max_epochs (int) – Maximum number of epochs for training.
random_state (int) – Controls the shuffling applied to the data before applying the split.
batch_size (int) – Number of samples per gradient update.
shuffle (bool) – Whether to shuffle the training data before each epoch.
patience (int) – Number of epochs with no improvement on the validation loss to wait before early stopping.
monitor (str) – The metric to monitor for early stopping.
mode (str) – Whether the monitored metric should be minimized (min) or maximized (max).
lr (float | None) – Learning rate for the optimizer.
lr_patience (int | None) – Number of epochs with no improvement on the validation loss to wait before reducing the learning rate.
lr_factor (float | None) – Factor by which the learning rate will be reduced.
weight_decay (float | None) – Weight decay (L2 penalty) coefficient.
checkpoint_path (str, default="model_checkpoints") – Path where the checkpoints are saved.
dataloader_kwargs (dict, default={}) – Keyword arguments for the PyTorch DataLoader class.
train_metrics (dict[str, Callable] | None) – Dictionary of metric callables to be logged during training.
val_metrics (dict[str, Callable] | None) – Dictionary of metric callables to be logged during validation.
rebuild (bool, default=True) – Whether to rebuild the model if it has already been built.
**trainer_kwargs – Additional keyword arguments for PyTorch Lightning's Trainer class.
- Returns:
self – The fitted regressor.
- Return type:
object
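The patience, monitor and mode arguments describe patience-based early stopping. A minimal sketch of that rule follows: training stops once the monitored value has gone patience consecutive epochs without improving, where "improving" means decreasing for mode="min" and increasing for mode="max". The function early_stop_epoch is hypothetical and written only for illustration; in the library the stopping decision is made by the training loop, whose exact tie-breaking and thresholds are not shown here.

```python
def early_stop_epoch(values, patience, mode="min"):
    """Return the epoch index at which training would stop, or None."""
    if mode not in ("min", "max"):
        raise ValueError("mode must be 'min' or 'max'")
    best = None
    stale = 0  # consecutive epochs without improvement
    for epoch, value in enumerate(values):
        if best is None:
            improved = True
        elif mode == "min":
            improved = value < best
        else:
            improved = value > best
        if improved:
            best, stale = value, 0
        else:
            stale += 1
            if stale >= patience:
                return epoch
    return None


val_loss = [0.9, 0.7, 0.72, 0.71, 0.75, 0.6]
stop_at = early_stop_epoch(val_loss, patience=3, mode="min")
print(stop_at)
```

Note that in this example the loss would have improved again at the last epoch; a larger patience trades longer training for a lower risk of stopping too early.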
- get_number_of_params(requires_grad=True)
Calculate the number of parameters in the model.
- Parameters:
requires_grad (bool, optional) – If True, only count the parameters that require gradients (trainable parameters). If False, count all parameters. Default is True.
- Returns:
The total number of parameters in the model.
- Return type:
int
- Raises:
ValueError – If the model has not been built prior to calling this method.
- get_params(deep=True)
Get parameters for this estimator.
- classmethod load(path)
Load and return a fitted model from path.
- Parameters:
path (str) – Path to a file previously written by save().
- Returns:
A fully reconstructed, ready-to-predict estimator of the same type that was saved.
- Return type:
estimator
- optimize_hparams(X, y, X_val=None, y_val=None, embeddings=None, embeddings_val=None, time=100, max_epochs=200, prune_by_epoch=True, prune_epoch=5, fixed_params={'cat_encoding': 'int', 'head_layer_size_length': 0, 'head_skip_layer': False, 'head_skip_layers': False, 'pooling_method': 'avg', 'use_cls': False}, custom_search_space=None, **optimize_kwargs)
Optimizes hyperparameters using Bayesian optimization with optional pruning.
- Parameters:
X (array-like) – Training data.
y (array-like) – Training labels.
X_val (array-like, optional) – Validation data.
y_val (array-like, optional) – Validation labels.
time (int) – The number of optimization trials to run.
max_epochs (int) – Maximum number of epochs for training.
prune_by_epoch (bool) – Whether to prune based on a specific epoch (True) or the best validation loss (False).
prune_epoch (int) – The specific epoch to prune by when prune_by_epoch is True.
**optimize_kwargs (dict) – Additional keyword arguments passed to the fit method.
- Returns:
best_hparams – Best hyperparameters found during optimization.
- Return type:
list
- predict(X, embeddings=None, device=None)
Predicts target values for the given input samples.
- Parameters:
X (DataFrame or array-like, shape (n_samples, n_features)) – The input samples for which to predict target values.
- Returns:
predictions – The predicted target values.
- Return type:
ndarray, shape (n_samples,) or (n_samples, n_outputs)
- pretrain(pretrain_epochs=15, k_neighbors=10, temperature=0.1, save_path='pretrained_embeddings.pth', lr=0.001, use_positive=True, use_negative=False, pool_sequence=True)
Pretrains the embedding layer of the model using a contrastive learning approach.
This method performs pretraining by optimizing the embeddings with respect to neighborhood structure in the feature space. The embeddings are saved after training.
- Parameters:
pretrain_epochs (int, default=15) – Number of epochs to run pretraining.
k_neighbors (int, default=10) – Number of neighbors used in the contrastive loss computation.
temperature (float, default=0.1) – Temperature parameter for contrastive loss scaling.
save_path (str, default="pretrained_embeddings.pth") – Path to save the pretrained embeddings.
lr (float, default=1e-3) – Learning rate for the pretraining optimizer.
use_positive (bool, default=True) – Whether to include positive pairs in contrastive learning.
use_negative (bool, default=False) – Whether to include negative pairs in contrastive learning.
pool_sequence (bool, default=True) – Whether to apply sequence pooling before computing contrastive loss.
- Raises:
ValueError – If the model has not been built before calling this method.
ValueError – If the model does not contain an embedding layer.
Notes
This function requires that self.build_model() has been called beforehand.
The pretraining method uses self.task_model.estimator.embedding_layer.
The method invokes super()._pretrain() with regression mode enabled.
- save(path)
Save the fitted model to path.
The bundle written by this method can be restored with load(). It contains all state required for inference: the config, the fitted preprocessor, feature metadata, and the neural-network weights.
- Parameters:
path (str) – Destination file path (e.g. "model.pt").
- Raises:
ValueError – If the model has not been fitted yet.
- Return type:
None
- score(X, y, embeddings=None, metric=sklearn.metrics.mean_squared_error)
Calculate the score of the model using the specified metric.
- Parameters:
X (array-like or pd.DataFrame of shape (n_samples, n_features)) – The input samples to predict.
y (array-like of shape (n_samples,) or (n_samples, n_outputs)) – The true target values against which to evaluate the predictions.
metric (callable, default=mean_squared_error) – The metric function to use for evaluation. Must be a callable with the signature metric(y_true, y_pred).
- Returns:
score – The score calculated using the specified metric.
- Return type:
float
- set_params(**parameters)
Set the parameters of this estimator.
- class deeptab.models.MambAttentionClassifier(*args: Any, **kwargs: Any)[source]
MambAttention classifier. This class extends the SklearnBaseClassifier class and uses the MambAttention model with the default MambAttention configuration.
Notes
The parameters for this class include the attributes from the config dataclass as well as preprocessing arguments handled by the base class.
Configuration class for the Default Mambular Attention model with predefined hyperparameters.
- Parameters:
d_model (int, default=64) – Dimensionality of the model.
n_layers (int, default=4) – Number of layers in the model.
expand_factor (int, default=2) – Expansion factor for the feed-forward layers.
n_heads (int, default=8) – Number of attention heads in the model.
last_layer (str, default="attn") – Type of the last layer (e.g., ‘attn’).
n_mamba_per_attention (int, default=1) – Number of Mamba blocks per attention layer.
bias (bool, default=False) – Whether to use bias in the linear layers.
d_conv (int, default=4) – Kernel size of the convolutional layers.
conv_bias (bool, default=True) – Whether to use bias in the convolutional layers.
dropout (float, default=0.0) – Dropout rate for regularization.
attn_dropout (float, default=0.2) – Dropout rate for the attention mechanism.
dt_rank (str, default="auto") – Rank of the projection that produces the SSM step size Δ (dt).
d_state (int, default=128) – Dimensionality of the state in recurrent layers.
dt_scale (float, default=1.0) – Scaling factor for the step size (dt).
dt_init (str, default="random") – Initialization method for the step size (dt).
dt_max (float, default=0.1) – Maximum value for step-size (dt) initialization.
dt_min (float, default=1e-04) – Minimum value for step-size (dt) initialization.
dt_init_floor (float, default=1e-04) – Floor value for step-size (dt) initialization.
norm (str, default="LayerNorm") – Type of normalization used in the model.
activation (callable, default=nn.SiLU()) – Activation function for the model.
head_layer_sizes (list, default=()) – Sizes of the fully connected layers in the model’s head.
head_dropout (float, default=0.5) – Dropout rate for the head layers.
head_skip_layers (bool, default=False) – Whether to use skip connections in the head layers.
head_activation (callable, default=nn.SELU()) – Activation function for the head layers.
head_use_batch_norm (bool, default=False) – Whether to use batch normalization in the head layers.
pooling_method (str, default="avg") – Pooling method to be used (‘avg’, ‘max’, etc.).
bidirectional (bool, default=False) – Whether to process input sequences bidirectionally.
use_learnable_interaction (bool, default=False) – Whether to use learnable feature interactions before passing through Mamba blocks.
use_cls (bool, default=False) – Whether to append a CLS token for sequence pooling.
shuffle_embeddings (bool, default=False) – Whether to shuffle embeddings before passing to Mamba layers.
cat_encoding (str, default="int") – Encoding method for categorical features (‘int’, ‘one-hot’, etc.).
AD_weight_decay (bool, default=True) – Whether weight decay is applied to A-D matrices.
BC_layer_norm (bool, default=False) – Whether to apply layer normalization to B-C matrices.
use_pscan (bool, default=False) – Whether to use PSCAN for the state-space model.
n_attention_layers (int, default=1) – Number of attention layers in the model.
feature_preprocessing (dict, optional) – Dictionary mapping feature names to specific preprocessing methods. Overrides global defaults.
n_bins (int, default=64) – Number of bins used for binning-based preprocessing (e.g., for discretizers or PLE).
numerical_preprocessing (str, default="ple") – Preprocessing method for numerical features (e.g., “standardization”, “minmax”, “ple”, “rbf”, etc.).
categorical_preprocessing (str, default="int") – Preprocessing method for categorical features (e.g., “int”, “ordinal”, “onehot”).
use_decision_tree_bins (bool, default=False) – Whether to use decision tree binning for numerical discretization.
binning_strategy (str, default="uniform") – Strategy for bin placement when not using tree-based methods. Options: “uniform”, “quantile”.
task (str, default="regression") – Problem type used to guide preprocessing (e.g., “regression” or “classification”).
cat_cutoff (float or int, default=0.03) – Threshold to determine whether integer-valued features are treated as categorical.
treat_all_integers_as_numerical (bool, default=False) – If True, treat all integer-typed columns as numerical regardless of cardinality.
degree (int, default=3) – Degree of polynomial or spline basis functions where applicable.
scaling_strategy (str, default="minmax") – Strategy for feature scaling (e.g., “standardization”, “minmax”, etc.).
n_knots (int, default=64) – Number of knots used in spline-based feature expansions.
use_decision_tree_knots (bool, default=True) – Whether to use decision tree-based knot placement for spline transformations.
knots_strategy (str, default="uniform") – Strategy for placing knots for splines (“uniform” or “quantile”).
spline_implementation (str, default="sklearn") – Which spline backend implementation to use (e.g., “sklearn”, “custom”).
min_unique_vals (int, default=5) – Minimum number of unique values required for a feature to be treated as numerical.
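The default numerical_preprocessing="ple" refers to piecewise linear encoding, which turns a scalar into one value per bin. The sketch below shows a clamped variant with uniformly spaced bin edges (compare binning_strategy="uniform"): bins entirely below the value are encoded as 1, bins entirely above it as 0, and the bin containing it as the fraction covered. It is a simplified illustration under those assumptions, not deeptab's preprocessing code; uniform_edges and ple_encode are hypothetical names, and the library may place edges differently or handle out-of-range values in another way.

```python
def uniform_edges(lo, hi, n_bins):
    """Return n_bins + 1 equally spaced bin edges between lo and hi."""
    width = (hi - lo) / n_bins
    return [lo + i * width for i in range(n_bins + 1)]


def ple_encode(x, edges):
    """Piecewise linear encoding of x, clamped to [0, 1] per bin."""
    encoded = []
    for left, right in zip(edges[:-1], edges[1:]):
        frac = (x - left) / (right - left)
        encoded.append(min(1.0, max(0.0, frac)))
    return encoded


edges = uniform_edges(0.0, 4.0, n_bins=4)
encoding = ple_encode(2.5, edges)
print(encoding)  # [1.0, 1.0, 0.5, 0.0]
```

The output length equals the number of bins, so the documented default n_bins=64 would yield a 64-dimensional encoding per numerical feature in this scheme.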
Examples
>>> from deeptab.models import MambAttentionClassifier
>>> model = MambAttentionClassifier(d_model=64, n_layers=8)
>>> model.fit(X_train, y_train)
>>> preds = model.predict(X_test)
>>> model.evaluate(X_test, y_test)
- build_model(X, y, val_size=0.2, X_val=None, y_val=None, embeddings=None, embeddings_val=None, random_state=101, batch_size=128, shuffle=True, lr=None, lr_patience=None, lr_factor=None, weight_decay=None, train_metrics=None, val_metrics=None, dataloader_kwargs={})
Builds the model using the provided training data.
- Parameters:
X (DataFrame or array-like, shape (n_samples, n_features)) – The training input samples.
y (array-like, shape (n_samples,) or (n_samples, n_targets)) – The target class labels.
val_size (float) – The proportion of the dataset to include in the validation split if X_val is None. Ignored if X_val is provided.
X_val (DataFrame or array-like, shape (n_samples, n_features), optional) – The validation input samples. If provided, X and y are not split and this data is used for validation.
y_val (array-like, shape (n_samples,) or (n_samples, n_targets), optional) – The validation target values. Required if X_val is provided.
random_state (int) – Controls the shuffling applied to the data before applying the split.
batch_size (int) – Number of samples per gradient update.
shuffle (bool) – Whether to shuffle the training data before each epoch.
lr (float | None) – Learning rate for the optimizer.
lr_patience (int | None) – Number of epochs with no improvement on the validation loss to wait before reducing the learning rate.
lr_factor (float | None) – Factor by which the learning rate will be reduced.
train_metrics (dict[str, Callable] | None) – Dictionary of metric callables to be logged during training.
val_metrics (dict[str, Callable] | None) – Dictionary of metric callables to be logged during validation.
weight_decay (float | None) – Weight decay (L2 penalty) coefficient.
dataloader_kwargs (dict, default={}) – Keyword arguments for the PyTorch DataLoader class.
- Returns:
self – The built classifier.
- Return type:
object
- encode(X, embeddings=None, batch_size=64)
Encodes input data using the trained model’s embedding layer.
- Parameters:
X (array-like or DataFrame) – Input data to be encoded.
batch_size (int, optional, default=64) – Batch size for encoding.
- Returns:
Encoded representations of the input data.
- Return type:
torch.Tensor
- Raises:
ValueError – If the model or data module is not fitted.
- evaluate(X, y_true, embeddings=None, metrics=None)
Evaluate the model on the given data using specified metrics.
- Parameters:
X (array-like or pd.DataFrame of shape (n_samples, n_features)) – The input samples to predict.
y_true (array-like of shape (n_samples,)) – The true class labels against which to evaluate the predictions.
embeddings (array-like or list of shape (n_samples, dimension)) – List or array with embeddings for unstructured data inputs.
metrics (dict) – A dictionary where keys are metric names and values are tuples containing the metric function and a boolean indicating whether the metric requires probability scores (True) or class labels (False).
- Returns:
scores – A dictionary with metric names as keys and their corresponding scores as values.
- Return type:
dict
Notes
This method uses either the predict or predict_proba method depending on the metric requirements.
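The tuple format of the classifier's metrics argument can be sketched as follows: each entry pairs a metric function with a boolean saying whether it should receive probabilities (True) or class labels (False). The helper apply_metrics is hypothetical and only mimics the dispatch described above; proba and labels are stand-ins for the outputs of predict_proba and predict, and class labels are assumed to be integer indices into each probability row.

```python
import math


def accuracy(y_true, y_pred):
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def log_loss(y_true, proba, eps=1e-15):
    total = 0.0
    for label, row in zip(y_true, proba):
        # Clip to avoid log(0); label is assumed to be a column index.
        p = min(1.0 - eps, max(eps, row[label]))
        total -= math.log(p)
    return total / len(y_true)


# name -> (metric function, needs probability scores?)
metrics = {"accuracy": (accuracy, False), "log_loss": (log_loss, True)}


def apply_metrics(y_true, labels, proba, metrics):
    """Route probabilities or labels to each metric (hypothetical helper)."""
    scores = {}
    for name, (fn, needs_proba) in metrics.items():
        scores[name] = fn(y_true, proba if needs_proba else labels)
    return scores


y_true = [0, 1, 1]
proba = [[0.8, 0.2], [0.4, 0.6], [0.7, 0.3]]  # stand-in for predict_proba
labels = [row.index(max(row)) for row in proba]  # stand-in for predict
scores = apply_metrics(y_true, labels, proba, metrics)
print(scores)
```

With a fitted classifier, the same dictionary would be passed as model.evaluate(X_test, y_test, metrics=metrics).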
- fit(X, y, val_size=0.2, X_val=None, y_val=None, embeddings=None, embeddings_val=None, max_epochs=100, random_state=101, batch_size=128, shuffle=True, patience=15, monitor='val_loss', mode='min', lr=None, lr_patience=None, lr_factor=None, weight_decay=None, checkpoint_path='model_checkpoints', train_metrics=None, val_metrics=None, dataloader_kwargs={}, rebuild=True, **trainer_kwargs)
Trains the classification model using the provided training data. Optionally, a separate validation set can be used.
- Parameters:
X (DataFrame or array-like, shape (n_samples, n_features)) – The training input samples.
y (array-like, shape (n_samples,) or (n_samples, n_targets)) – The target class labels.
val_size (float) – The proportion of the dataset to include in the validation split if X_val is None. Ignored if X_val is provided.
X_val (DataFrame or array-like, shape (n_samples, n_features), optional) – The validation input samples. If provided, X and y are not split and this data is used for validation.
y_val (array-like, shape (n_samples,) or (n_samples, n_targets), optional) – The validation target values. Required if X_val is provided.
max_epochs (int) – Maximum number of epochs for training.
random_state (int) – Controls the shuffling applied to the data before applying the split.
batch_size (int) – Number of samples per gradient update.
shuffle (bool) – Whether to shuffle the training data before each epoch.
patience (int) – Number of epochs with no improvement on the validation loss to wait before early stopping.
monitor (str) – The metric to monitor for early stopping.
mode (str) – Whether the monitored metric should be minimized (min) or maximized (max).
lr (float | None) – Learning rate for the optimizer.
lr_patience (int | None) – Number of epochs with no improvement on the validation loss to wait before reducing the learning rate.
lr_factor (float | None) – Factor by which the learning rate will be reduced.
weight_decay (float | None) – Weight decay (L2 penalty) coefficient.
checkpoint_path (str, default="model_checkpoints") – Path where the checkpoints are saved.
train_metrics (dict[str, Callable] | None) – Dictionary of metric callables to be logged during training.
val_metrics (dict[str, Callable] | None) – Dictionary of metric callables to be logged during validation.
dataloader_kwargs (dict, default={}) – Keyword arguments for the PyTorch DataLoader class.
rebuild (bool, default=True) – Whether to rebuild the model if it has already been built.
**trainer_kwargs – Additional keyword arguments for PyTorch Lightning's Trainer class.
- Returns:
self – The fitted classifier.
- Return type:
object
- get_number_of_params(requires_grad=True)
Calculate the number of parameters in the model.
- Parameters:
requires_grad (bool, optional) – If True, only count the parameters that require gradients (trainable parameters). If False, count all parameters. Default is True.
- Returns:
The total number of parameters in the model.
- Return type:
int
- Raises:
ValueError – If the model has not been built prior to calling this method.
- get_params(deep=True)
Get parameters for this estimator.
- classmethod load(path)
Load and return a fitted model from path.
- Parameters:
path (str) – Path to a file previously written by save().
- Returns:
A fully reconstructed, ready-to-predict estimator of the same type that was saved.
- Return type:
estimator
- optimize_hparams(X, y, X_val=None, y_val=None, embeddings=None, embeddings_val=None, time=100, max_epochs=200, prune_by_epoch=True, prune_epoch=5, fixed_params={'cat_encoding': 'int', 'head_layer_size_length': 0, 'head_skip_layer': False, 'head_skip_layers': False, 'pooling_method': 'avg', 'use_cls': False}, custom_search_space=None, **optimize_kwargs)
Optimizes hyperparameters using Bayesian optimization with optional pruning.
- Parameters:
X (array-like) – Training data.
y (array-like) – Training labels.
X_val (array-like, optional) – Validation data.
y_val (array-like, optional) – Validation labels.
time (int) – The number of optimization trials to run.
max_epochs (int) – Maximum number of epochs for training.
prune_by_epoch (bool) – Whether to prune based on a specific epoch (True) or the best validation loss (False).
prune_epoch (int) – The specific epoch to prune by when prune_by_epoch is True.
**optimize_kwargs (dict) – Additional keyword arguments passed to the fit method.
- Returns:
best_hparams – Best hyperparameters found during optimization.
- Return type:
list
- predict(X, embeddings=None, device=None)
Predicts target labels for the given input samples.
- Parameters:
X (DataFrame or array-like, shape (n_samples, n_features)) – The input samples for which to predict target values.
- Returns:
predictions – The predicted class labels.
- Return type:
ndarray, shape (n_samples,)
- predict_proba(X, embeddings=None, device=None)
Predicts class probabilities for the given input samples.
- Parameters:
X (DataFrame or array-like, shape (n_samples, n_features)) – The input samples for which to predict class probabilities.
- Returns:
probabilities – The predicted class probabilities.
- Return type:
ndarray, shape (n_samples, n_classes)
- pretrain(pretrain_epochs=15, k_neighbors=10, temperature=0.1, save_path='pretrained_embeddings.pth', lr=0.001, use_positive=True, use_negative=False, pool_sequence=True)
Pretrains the embedding layer of the model using a contrastive learning approach.
This method performs pretraining by optimizing the embeddings with respect to neighborhood structure in the feature space. The embeddings are saved after training.
- Parameters:
pretrain_epochs (int, default=15) – Number of epochs to run pretraining.
k_neighbors (int, default=10) – Number of neighbors used in the contrastive loss computation.
temperature (float, default=0.1) – Temperature parameter for contrastive loss scaling.
save_path (str, default="pretrained_embeddings.pth") – Path to save the pretrained embeddings.
lr (float, default=1e-3) – Learning rate for the pretraining optimizer.
use_positive (bool, default=True) – Whether to include positive pairs in contrastive learning.
use_negative (bool, default=False) – Whether to include negative pairs in contrastive learning.
pool_sequence (bool, default=True) – Whether to apply sequence pooling before computing contrastive loss.
- Raises:
ValueError – If the model has not been built before calling this method.
ValueError – If the model does not contain an embedding layer.
Notes
This function requires that self.build_model() has been called beforehand.
The pretraining method uses self.task_model.estimator.embedding_layer.
The method invokes super()._pretrain() with regression mode enabled.
- save(path)
Save the fitted model to path.
The bundle written by this method can be restored with load(). It contains all state required for inference: the config, the fitted preprocessor, feature metadata, and the neural-network weights.
- Parameters:
path (str) – Destination file path (e.g. "model.pt").
- Raises:
ValueError – If the model has not been fitted yet.
- Return type:
None
- score(X, y, embeddings=None, metric=(sklearn.metrics.log_loss, True))
Calculate the score of the model using the specified metric.
- Parameters:
X (array-like or pd.DataFrame of shape (n_samples, n_features)) – The input samples to predict.
y (array-like of shape (n_samples,)) – The true class labels against which to evaluate the predictions.
metric (tuple, default=(log_loss, True)) – A tuple containing the metric function and a boolean indicating whether the metric requires probability scores (True) or class labels (False).
- Returns:
score – The score calculated using the specified metric.
- Return type:
float
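The metric argument pairs a function with a flag that says whether the function expects probabilities or class labels. A minimal sketch of how such a tuple could be dispatched is given here; the hand-written log_loss and accuracy stand in for their sklearn.metrics counterparts, and score_with is a hypothetical helper, not part of deeptab.

```python
import math


def log_loss(y_true, proba, eps=1e-15):
    """Mean negative log-probability of the true class."""
    total = 0.0
    for label, row in zip(y_true, proba):
        p = min(max(row[label], eps), 1 - eps)  # clip to avoid log(0)
        total -= math.log(p)
    return total / len(y_true)


def accuracy(y_true, y_pred):
    """Fraction of labels predicted correctly."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def score_with(metric, y_true, proba):
    """Apply a (function, needs_proba) tuple to predicted probabilities."""
    func, needs_proba = metric
    if needs_proba:
        return func(y_true, proba)
    # Metric works on class labels: take the argmax of each probability row.
    labels = [max(range(len(row)), key=row.__getitem__) for row in proba]
    return func(y_true, labels)


y_true = [0, 1, 1]
proba = [[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]]
nll = score_with((log_loss, True), y_true, proba)
acc = score_with((accuracy, False), y_true, proba)
```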
- set_params(**parameters)
Set the parameters of this estimator.
- class deeptab.models.MambAttentionLSS(*args: Any, **kwargs: Any)[source]
MambAttention LSS for distributional regression. This class extends the SklearnBaseLSS class and uses the MambAttention model with the default MambAttention configuration.
Notes
The parameters for this class include the attributes from the config dataclass as well as preprocessing arguments handled by the base class.
Configuration class for the Default Mambular Attention model with predefined hyperparameters.
- Parameters:
d_model (int, default=64) – Dimensionality of the model.
n_layers (int, default=4) – Number of layers in the model.
expand_factor (int, default=2) – Expansion factor for the feed-forward layers.
n_heads (int, default=8) – Number of attention heads in the model.
last_layer (str, default="attn") – Type of the last layer (e.g., ‘attn’).
n_mamba_per_attention (int, default=1) – Number of Mamba blocks per attention layer.
bias (bool, default=False) – Whether to use bias in the linear layers.
d_conv (int, default=4) – Kernel size of the 1D convolution inside each Mamba block.
conv_bias (bool, default=True) – Whether to use bias in the convolutional layers.
dropout (float, default=0.0) – Dropout rate for regularization.
attn_dropout (float, default=0.2) – Dropout rate for the attention mechanism.
dt_rank (str, default="auto") – Rank of the low-rank projection that produces the state-space step size Δ (dt).
d_state (int, default=128) – Dimensionality of the state in recurrent layers.
dt_scale (float, default=1.0) – Scaling factor applied when initializing the step-size (dt) projection.
dt_init (str, default="random") – Initialization method for the step-size (dt) projection.
dt_max (float, default=0.1) – Maximum value for the step-size (dt) initialization.
dt_min (float, default=1e-04) – Minimum value for the step-size (dt) initialization.
dt_init_floor (float, default=1e-04) – Floor value for the step-size (dt) initialization.
norm (str, default="LayerNorm") – Type of normalization used in the model.
activation (callable, default=nn.SiLU()) – Activation function for the model.
head_layer_sizes (list, default=()) – Sizes of the fully connected layers in the model’s head.
head_dropout (float, default=0.5) – Dropout rate for the head layers.
head_skip_layers (bool, default=False) – Whether to use skip connections in the head layers.
head_activation (callable, default=nn.SELU()) – Activation function for the head layers.
head_use_batch_norm (bool, default=False) – Whether to use batch normalization in the head layers.
pooling_method (str, default="avg") – Pooling method to be used (‘avg’, ‘max’, etc.).
bidirectional (bool, default=False) – Whether to process input sequences bidirectionally.
use_learnable_interaction (bool, default=False) – Whether to use learnable feature interactions before passing through Mamba blocks.
use_cls (bool, default=False) – Whether to append a CLS token for sequence pooling.
shuffle_embeddings (bool, default=False) – Whether to shuffle embeddings before passing to Mamba layers.
cat_encoding (str, default="int") – Encoding method for categorical features (‘int’, ‘one-hot’, etc.).
AD_weight_decay (bool, default=True) – Whether weight decay is applied to A-D matrices.
BC_layer_norm (bool, default=False) – Whether to apply layer normalization to B-C matrices.
use_pscan (bool, default=False) – Whether to use PSCAN for the state-space model.
n_attention_layers (int, default=1) – Number of attention layers in the model.
feature_preprocessing (dict, optional) – Dictionary mapping feature names to specific preprocessing methods. Overrides global defaults.
n_bins (int, default=64) – Number of bins used for binning-based preprocessing (e.g., for discretizers or PLE).
numerical_preprocessing (str, default="ple") – Preprocessing method for numerical features (e.g., “standardization”, “minmax”, “ple”, “rbf”, etc.).
categorical_preprocessing (str, default="int") – Preprocessing method for categorical features (e.g., “int”, “ordinal”, “onehot”).
use_decision_tree_bins (bool, default=False) – Whether to use decision tree binning for numerical discretization.
binning_strategy (str, default="uniform") – Strategy for bin placement when not using tree-based methods. Options: “uniform”, “quantile”.
task (str, default="regression") – Problem type used to guide preprocessing (e.g., “regression” or “classification”).
cat_cutoff (float or int, default=0.03) – Threshold to determine whether integer-valued features are treated as categorical.
treat_all_integers_as_numerical (bool, default=False) – If True, treat all integer-typed columns as numerical regardless of cardinality.
degree (int, default=3) – Degree of polynomial or spline basis functions where applicable.
scaling_strategy (str, default="minmax") – Strategy for feature scaling (e.g., “standardization”, “minmax”, etc.).
n_knots (int, default=64) – Number of knots used in spline-based feature expansions.
use_decision_tree_knots (bool, default=True) – Whether to use decision tree-based knot placement for spline transformations.
knots_strategy (str, default="uniform") – Strategy for placing knots for splines (“uniform” or “quantile”).
spline_implementation (str, default="sklearn") – Which spline backend implementation to use (e.g., “sklearn”, “custom”).
min_unique_vals (int, default=5) – Minimum number of unique values required for a feature to be treated as numerical.
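The dt_* parameters follow the naming of the reference Mamba implementation, where dt is the discretization step size Δ of the state-space model. The sketch below, in plain Python, assumes that deeptab mirrors that reference initialization, which this page does not confirm. It shows how dt_rank="auto" is conventionally resolved and how dt_min, dt_max, and dt_init_floor bound the initial step sizes. All function names are hypothetical.

```python
import math
import random


def resolve_dt_rank(dt_rank, d_model):
    """'auto' conventionally means ceil(d_model / 16) in the reference Mamba code."""
    return math.ceil(d_model / 16) if dt_rank == "auto" else int(dt_rank)


def sample_dt(n, dt_min=1e-4, dt_max=0.1, dt_init_floor=1e-4, seed=0):
    """Draw n step sizes log-uniformly from [dt_min, dt_max], clamped at the floor."""
    rng = random.Random(seed)
    lo, hi = math.log(dt_min), math.log(dt_max)
    return [max(math.exp(rng.random() * (hi - lo) + lo), dt_init_floor)
            for _ in range(n)]


def inverse_softplus(dt):
    """Bias value b such that softplus(b) == dt."""
    return dt + math.log(-math.expm1(-dt))


def softplus(x):
    return math.log1p(math.exp(x))


rank = resolve_dt_rank("auto", d_model=64)
step_sizes = sample_dt(8)
biases = [inverse_softplus(dt) for dt in step_sizes]
```

Storing the inverse softplus of each sampled value as a bias means that applying softplus in the forward pass recovers a positive step size inside the configured range.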
Examples
>>> from deeptab.models import MambAttentionLSS
>>> model = MambAttentionLSS(d_model=64, n_layers=8)
>>> model.fit(X_train, y_train, family='normal')
>>> preds = model.predict(X_test)
>>> model.evaluate(X_test, y_test)
- build_model(X, y, val_size=0.2, X_val=None, y_val=None, random_state=101, batch_size=128, shuffle=True, lr=None, lr_patience=None, lr_factor=None, weight_decay=None, train_metrics=None, val_metrics=None, dataloader_kwargs={})
Builds the model using the provided training data.
- Parameters:
X (DataFrame or array-like, shape (n_samples, n_features)) – The training input samples.
y (array-like, shape (n_samples,) or (n_samples, n_targets)) – The target values (real numbers).
val_size (float) – The proportion of the dataset to include in the validation split if X_val is None. Ignored if X_val is provided.
X_val (DataFrame or array-like, shape (n_samples, n_features), optional) – The validation input samples. If provided, X and y are not split and this data is used for validation.
y_val (array-like, shape (n_samples,) or (n_samples, n_targets), optional) – The validation target values. Required if X_val is provided.
random_state (int) – Controls the shuffling applied to the data before applying the split.
batch_size (int) – Number of samples per gradient update.
shuffle (bool) – Whether to shuffle the training data before each epoch.
lr (float | None) – Learning rate for the optimizer.
lr_patience (int | None) – Number of epochs with no improvement on the validation loss to wait before reducing the learning rate.
lr_factor (float | None) – Factor by which the learning rate will be reduced.
train_metrics (dict[str, Callable] | None) – torch.metrics dict to be logged during training.
val_metrics (dict[str, Callable] | None) – torch.metrics dict to be logged during validation.
weight_decay (float | None) – Weight decay (L2 penalty) coefficient.
dataloader_kwargs (dict, default={}) – The kwargs for the pytorch dataloader class.
- Returns:
self – The built distributional regressor.
- Return type:
object
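The interplay of val_size, X_val, y_val, and random_state can be illustrated with a small stand-alone split routine. This is a sketch of the documented behavior in plain Python, not deeptab's own splitting code; train_val_split is a hypothetical name.

```python
import random


def train_val_split(X, y, val_size=0.2, X_val=None, y_val=None, random_state=101):
    """Return (X_train, y_train, X_val, y_val) following the documented rules."""
    if X_val is not None:
        # An explicit validation set disables splitting; val_size is ignored.
        if y_val is None:
            raise ValueError("y_val is required when X_val is provided")
        return X, y, X_val, y_val
    idx = list(range(len(X)))
    # random_state controls the shuffle applied before the split.
    random.Random(random_state).shuffle(idx)
    n_val = max(1, round(len(X) * val_size))
    val_idx, train_idx = idx[:n_val], idx[n_val:]
    return ([X[i] for i in train_idx], [y[i] for i in train_idx],
            [X[i] for i in val_idx], [y[i] for i in val_idx])


X = [[float(i)] for i in range(10)]
y = list(range(10))
X_tr, y_tr, X_va, y_va = train_val_split(X, y, val_size=0.2)
```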
- encode(X, batch_size=64)
Encodes input data using the trained model’s embedding layer.
- Parameters:
X (array-like or DataFrame) – Input data to be encoded.
batch_size (int, optional, default=64) – Batch size for encoding.
- Returns:
Encoded representations of the input data.
- Return type:
torch.Tensor
- Raises:
ValueError – If the model or data module is not fitted.
- evaluate(X, y_true, metrics=None, distribution_family=None)
Evaluate the model on the given data using specified metrics.
- Parameters:
X (array-like or pd.DataFrame of shape (n_samples, n_features)) – The input samples to predict.
y_true (array-like of shape (n_samples,)) – The true target values against which to evaluate the predictions.
metrics (dict) – A dictionary where keys are metric names and values are tuples containing the metric function and a boolean indicating whether the metric requires probability scores (True) or class labels (False).
distribution_family (str, optional) – Specifies the distribution family the model is predicting for. If None, it will attempt to infer based on the model’s settings.
- Returns:
scores – A dictionary with metric names as keys and their corresponding scores as values.
- Return type:
dict
Notes
This method uses either the predict or predict_proba method depending on the metric requirements.
- fit(X, y, family, val_size=0.2, X_val=None, y_val=None, max_epochs=100, random_state=101, batch_size=128, shuffle=True, patience=15, monitor='val_loss', mode='min', lr=None, lr_patience=None, lr_factor=None, weight_decay=None, checkpoint_path='model_checkpoints', distributional_kwargs=None, train_metrics=None, val_metrics=None, dataloader_kwargs={}, rebuild=True, **trainer_kwargs)
Trains the regression model using the provided training data. Optionally, a separate validation set can be used.
- Parameters:
X (DataFrame or array-like, shape (n_samples, n_features)) – The training input samples.
y (array-like, shape (n_samples,) or (n_samples, n_targets)) – The target values (real numbers).
family (str) – The name of the distribution family to use for the loss function. Examples include ‘normal’ for regression tasks.
val_size (float) – The proportion of the dataset to include in the validation split if X_val is None. Ignored if X_val is provided.
X_val (DataFrame or array-like, shape (n_samples, n_features), optional) – The validation input samples. If provided, X and y are not split and this data is used for validation.
y_val (array-like, shape (n_samples,) or (n_samples, n_targets), optional) – The validation target values. Required if X_val is provided.
max_epochs (int) – Maximum number of epochs for training.
random_state (int) – Controls the shuffling applied to the data before applying the split.
batch_size (int) – Number of samples per gradient update.
shuffle (bool) – Whether to shuffle the training data before each epoch.
patience (int) – Number of epochs with no improvement on the validation loss to wait before early stopping.
monitor (str) – The metric to monitor for early stopping.
mode (str) – Whether the monitored metric should be minimized (min) or maximized (max).
lr (float | None) – Learning rate for the optimizer.
lr_patience (int | None) – Number of epochs with no improvement on the validation loss to wait before reducing the learning rate.
lr_factor (float | None) – Factor by which the learning rate will be reduced.
weight_decay (float | None) – Weight decay (L2 penalty) coefficient.
distributional_kwargs (dict, default=None) – Any arguments that are specific to a certain distribution.
train_metrics (dict[str, Callable] | None) – torch.metrics dict to be logged during training.
val_metrics (dict[str, Callable] | None) – torch.metrics dict to be logged during validation.
checkpoint_path (str, default="model_checkpoints") – Path where the checkpoints are saved.
dataloader_kwargs (dict, default={}) – The kwargs for the pytorch dataloader class.
**trainer_kwargs – Additional keyword arguments for PyTorch Lightning's Trainer class.
- Returns:
self – The fitted regressor.
- Return type:
object
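The patience, mode, lr_patience, and lr_factor arguments describe two plateau rules that run side by side: early stopping and learning-rate reduction. The following plain-Python sketch replays both rules over a list of validation losses. It assumes the common conventions (stop once `patience` epochs pass without improvement; cut the rate once more than `lr_patience` such epochs pass) and is not the trainer's actual code. run_schedule is a hypothetical helper.

```python
def run_schedule(val_losses, lr=1e-3, patience=15, lr_patience=10,
                 lr_factor=0.1, mode="min"):
    """Replay early stopping and reduce-on-plateau over recorded losses."""
    if mode == "min":
        improved = lambda new, best: new < best
    else:
        improved = lambda new, best: new > best
    best = None
    stale = 0      # epochs since the last improvement (early stopping)
    lr_stale = 0   # epochs since the last improvement or learning-rate cut
    history = []   # (epoch, learning rate in effect after this epoch)
    for epoch, loss in enumerate(val_losses):
        if best is None or improved(loss, best):
            best, stale, lr_stale = loss, 0, 0
        else:
            stale += 1
            lr_stale += 1
        if lr_stale > lr_patience:
            lr *= lr_factor
            lr_stale = 0
        history.append((epoch, lr))
        if stale >= patience:
            break  # early stopping
    return best, history


losses = [1.0, 0.8, 0.9, 0.9, 0.9, 0.9, 0.9]
best, history = run_schedule(losses, lr=1.0, patience=4,
                             lr_patience=1, lr_factor=0.5)
```

In this run the learning rate is halved twice before the early-stopping rule ends training, which is why lr_patience is usually set below patience.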
- get_default_metrics(distribution_family)
Provides default metrics based on the distribution family.
- Parameters:
distribution_family (str) – The distribution family for which to provide default metrics.
- Returns:
metrics – A dictionary of default metric functions.
- Return type:
dict
- get_number_of_params(requires_grad=True)
Calculate the number of parameters in the model.
- Parameters:
requires_grad (bool, optional) – If True, only count the parameters that require gradients (trainable parameters). If False, count all parameters. Default is True.
- Returns:
The total number of parameters in the model.
- Return type:
int
- Raises:
ValueError – If the model has not been built prior to calling this method.
- get_params(deep=True)
Get parameters for this estimator.
- Parameters:
deep (bool, default=True) – If True, will return the parameters for this estimator and contained subobjects that are estimators.
- Returns:
params – Parameter names mapped to their values.
- Return type:
dict
- classmethod load(path)
Load and return a fitted model from path.
- Parameters:
path (str) – Path to a file previously written by save().
- Returns:
A fully reconstructed, ready-to-predict estimator.
- Return type:
estimator
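Because load is a classmethod, a fitted estimator can be written with save(path) and restored through its own class without constructing a new instance first. A minimal sketch of that round trip, relying only on the save(path) and load(path) signatures documented here, is shown below; save_and_reload and the "model.pt" file name are illustrative choices.

```python
import os
import tempfile


def save_and_reload(model, directory=None):
    """Write `model` with save(path), then restore it via the load(path) classmethod."""
    directory = directory or tempfile.mkdtemp()
    path = os.path.join(directory, "model.pt")
    model.save(path)               # raises ValueError if the model is not fitted
    return type(model).load(path)  # classmethod: no instance needs to exist


# Usage, assuming `fitted_model` is an estimator that has already been fitted:
# restored = save_and_reload(fitted_model)
# preds = restored.predict(X_test)
```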
- optimize_hparams(X, y, X_val=None, y_val=None, time=100, max_epochs=200, prune_by_epoch=True, prune_epoch=5, fixed_params={'cat_encoding': 'int', 'head_layer_size_length': 0, 'head_skip_layer': False, 'head_skip_layers': False, 'pooling_method': 'avg', 'use_cls': False}, custom_search_space=None, **optimize_kwargs)
Optimizes hyperparameters using Bayesian optimization with optional pruning.
- Parameters:
X (array-like) – Training data.
y (array-like) – Training labels.
X_val (array-like, optional) – Validation data.
y_val (array-like, optional) – Validation labels.
time (int) – The number of optimization trials to run.
max_epochs (int) – Maximum number of epochs for training.
prune_by_epoch (bool) – Whether to prune based on a specific epoch (True) or the best validation loss (False).
prune_epoch (int) – The specific epoch to prune by when prune_by_epoch is True.
**optimize_kwargs (dict) – Additional keyword arguments passed to the fit method.
- Returns:
best_hparams – Best hyperparameters found during optimization.
- Return type:
list
- predict(X, raw=False, device=None)
Predicts target values for the given input samples.
- Parameters:
X (DataFrame or array-like, shape (n_samples, n_features)) – The input samples for which to predict target values.
- Returns:
predictions – The predicted target values.
- Return type:
ndarray, shape (n_samples,) or (n_samples, n_outputs)
- save(path)
Save the fitted model to path.
- Parameters:
path (str) – Destination file path (e.g. "model.pt").
- Raises:
ValueError – If the model has not been fitted yet.
- Return type:
None
- score(X, y, metric='NLL')
Calculate the score of the model using the specified metric.
- Parameters:
X (array-like or pd.DataFrame of shape (n_samples, n_features)) – The input samples to predict.
y (array-like of shape (n_samples,) or (n_samples, n_outputs)) – The true target values against which to evaluate the predictions.
metric (str, default="NLL") – Metric to compute. So far, only negative log-likelihood is supported.
- Returns:
score – The score calculated using the specified metric.
- Return type:
float
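For a model fitted with family='normal', the negative log-likelihood has a closed form in the predicted mean and standard deviation. The sketch below computes it in plain Python, assuming the score is averaged over samples (the page does not say whether deeptab averages or sums). normal_nll is a hypothetical helper, not a deeptab function.

```python
import math


def normal_nll(y, mean, std):
    """Mean Gaussian negative log-likelihood of y under N(mean, std**2)."""
    total = 0.0
    for yi, mu, sigma in zip(y, mean, std):
        total += 0.5 * math.log(2 * math.pi * sigma ** 2)
        total += (yi - mu) ** 2 / (2 * sigma ** 2)
    return total / len(y)


# A prediction centred on the target scores better (lower) than a biased one.
centred = normal_nll([0.0], [0.0], [1.0])
biased = normal_nll([0.0], [2.0], [1.0])
```

Unlike squared error, this score also penalizes a predicted standard deviation that is too small or too large for the observed residuals.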
- set_params(**parameters)
Set the parameters of this estimator.
- Parameters:
**parameters (dict) – Estimator parameters.
- Returns:
self – Estimator instance.
- Return type:
object