Configurations

DeepTab uses a split-config API: model hyperparameters are divided across three separate dataclasses so that architecture choices, data preprocessing, and training settings can be managed, versioned, and shared independently.

| Config class | Controls | Typical fields |
| --- | --- | --- |
| ModelConfig (e.g. MLPConfig) | Neural architecture | d_model, n_layers, dropout, activation, … |
| PreprocessingConfig | Feature engineering | numerical_preprocessing, n_bins, scaling_strategy, … |
| TrainerConfig | Training loop | max_epochs, lr, batch_size, patience, … |


Quick-start by task

All three model variants (Classifier, Regressor, and LSS) accept the same config objects. The only difference is the class you import.

Classification

from deeptab.configs import MLPConfig, PreprocessingConfig, TrainerConfig
from deeptab.models import MLPClassifier

model = MLPClassifier(
    model_config=MLPConfig(d_model=128, dropout=0.1),
    preprocessing_config=PreprocessingConfig(numerical_preprocessing="quantile"),
    trainer_config=TrainerConfig(max_epochs=50, lr=1e-3),
)
model.fit(X_train, y_train)
preds = model.predict(X_test)          # class labels
proba = model.predict_proba(X_test)    # class probabilities

Regression

from deeptab.configs import ResNetConfig, TrainerConfig
from deeptab.models import ResNetRegressor

model = ResNetRegressor(
    model_config=ResNetConfig(d_model=256, n_layers=4),
    trainer_config=TrainerConfig(max_epochs=100, lr=5e-4, patience=10),
)
model.fit(X_train, y_train)
preds = model.predict(X_test)          # continuous values

Distributional regression (LSS)

LSS models predict the full distribution of the target, not just a point estimate. Pass family to fit to select the output distribution.

from deeptab.configs import MambularConfig, TrainerConfig
from deeptab.models import MambularLSS

model = MambularLSS(
    model_config=MambularConfig(d_model=64, n_layers=6),
    trainer_config=TrainerConfig(max_epochs=100, lr=1e-3),
)
model.fit(X_train, y_train, family="normal")   # learns μ and σ per row
dist_params = model.predict(X_test)            # shape (N, n_params)

Common families: "normal", "poisson", "gamma", "beta", "dirichlet".
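
To make the `(N, n_params)` shape concrete: with family="normal" each row would hold two values. The sketch below turns such rows into central prediction intervals using only the standard library. It assumes the columns are ordered (μ, σ) and that σ is already a positive standard deviation; neither is confirmed here, so inspect the output of `model.predict` before relying on it. The helper `normal_intervals` and the sample values are hypothetical.

```python
from statistics import NormalDist


def normal_intervals(dist_params, level=0.95):
    """Return one (lower, upper) pair per row, assuming each row is (mu, sigma)."""
    tail = (1.0 - level) / 2.0
    intervals = []
    for mu, sigma in dist_params:
        dist = NormalDist(mu=mu, sigma=sigma)
        intervals.append((dist.inv_cdf(tail), dist.inv_cdf(1.0 - tail)))
    return intervals


# Made-up values standing in for the array returned by model.predict(X_test).
example_params = [(10.0, 2.0), (0.0, 1.0)]
for lower, upper in normal_intervals(example_params):
    print(f"{lower:.2f} .. {upper:.2f}")
# 6.08 .. 13.92
# -1.96 .. 1.96
```

Other families have different parameters, so the same idea would need the matching distribution in place of `NormalDist`.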


Scikit-learn compatibility

Every config dataclass extends sklearn.base.BaseEstimator, so each one supports the scikit-learn parameter protocol: get_params, set_params, and clone.

get_params

Returns a flat dictionary of all hyperparameters, identical to the behaviour of any scikit-learn estimator:

from deeptab.configs import MLPConfig, TrainerConfig

cfg = MLPConfig(d_model=128, dropout=0.2)
print(cfg.get_params())
# {'d_model': 128, 'dropout': 0.2, 'layer_sizes': [256, 128, 32], ...}

trainer = TrainerConfig(max_epochs=50)
print(trainer.get_params())
# {'max_epochs': 50, 'lr': 0.0001, 'batch_size': 128, ...}
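
One practical use of the parameter dictionary is recording only the values that differ from the defaults, which keeps experiment logs short. The sketch below works on plain dictionaries so that it runs without DeepTab installed; in practice the two inputs would come from calling `get_params()` on a default config and on a customised one. The helper name `changed_params` and the default values shown are hypothetical.

```python
def changed_params(params, defaults):
    """Return only the entries of `params` whose value differs from `defaults`."""
    return {
        name: value
        for name, value in params.items()
        if name not in defaults or defaults[name] != value
    }


# Stand-ins for the dictionaries returned by get_params(); the default
# values here are invented for illustration.
defaults = {"d_model": 64, "dropout": 0.0, "layer_sizes": [256, 128, 32]}
current = {"d_model": 128, "dropout": 0.2, "layer_sizes": [256, 128, 32]}

overrides = changed_params(current, defaults)
print(overrides)  # {'d_model': 128, 'dropout': 0.2}
```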

set_params

set_params updates parameters in place and returns self, which is the behaviour scikit-learn pipelines and grid search rely on:

cfg = MLPConfig()
cfg.set_params(d_model=256, dropout=0.3)

trainer = TrainerConfig()
trainer.set_params(max_epochs=200, lr=5e-4)
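
The in-place, return-self contract can be sketched with a stand-in dataclass. `ToyConfig` is hypothetical and hand-rolls both methods with the standard library; the DeepTab configs inherit them from `BaseEstimator` instead. Rejecting unknown names with `ValueError` mirrors what scikit-learn's own `set_params` does.

```python
from dataclasses import dataclass, fields


@dataclass
class ToyConfig:
    """Hypothetical stand-in for a config class; not part of DeepTab."""

    d_model: int = 64
    dropout: float = 0.0

    def get_params(self):
        # One entry per dataclass field, holding its current value.
        return {f.name: getattr(self, f.name) for f in fields(self)}

    def set_params(self, **params):
        valid = {f.name for f in fields(self)}
        for name, value in params.items():
            if name not in valid:
                raise ValueError(
                    f"Invalid parameter {name!r} for {type(self).__name__}"
                )
            setattr(self, name, value)
        return self  # returning self allows calls to be chained


cfg = ToyConfig()
same = cfg.set_params(d_model=256).set_params(dropout=0.3)
print(same is cfg)        # True
print(cfg.get_params())   # {'d_model': 256, 'dropout': 0.3}
```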

Hyperparameter search with GridSearchCV

Because the estimator itself also follows get_params / set_params, you can tune any config field via GridSearchCV using the <config_attr>__<field> double-underscore notation:

from sklearn.model_selection import GridSearchCV
from deeptab.configs import MLPConfig, TrainerConfig
from deeptab.models import MLPClassifier

model = MLPClassifier(
    model_config=MLPConfig(),
    trainer_config=TrainerConfig(max_epochs=20),
)

param_grid = {
    "model_config__d_model": [64, 128, 256],
    "model_config__dropout": [0.1, 0.3],
    "trainer_config__lr": [1e-3, 5e-4],
}

search = GridSearchCV(model, param_grid, cv=3, scoring="accuracy")
search.fit(X_train, y_train)
print(search.best_params_)
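
To see what the search does with this grid, the sketch below expands it into individual candidates and groups each `prefix__field` key by its prefix, splitting on the first double underscore. It uses only the standard library and is an illustration of the naming convention, not scikit-learn's or DeepTab's actual code; `expand_grid` and `route` are hypothetical helpers.

```python
from itertools import product

param_grid = {
    "model_config__d_model": [64, 128, 256],
    "model_config__dropout": [0.1, 0.3],
    "trainer_config__lr": [1e-3, 5e-4],
}


def expand_grid(grid):
    """Yield one flat {key: value} dict per combination in the grid."""
    keys = sorted(grid)
    for values in product(*(grid[key] for key in keys)):
        yield dict(zip(keys, values))


def route(flat_params):
    """Group 'prefix__field' keys by prefix, splitting on the first '__'."""
    nested = {}
    for key, value in flat_params.items():
        prefix, _, field_name = key.partition("__")
        nested.setdefault(prefix, {})[field_name] = value
    return nested


candidates = list(expand_grid(param_grid))
print(len(candidates))       # 12
print(route(candidates[0]))
# {'model_config': {'d_model': 64, 'dropout': 0.1}, 'trainer_config': {'lr': 0.001}}
```

With 12 candidates and cv=3, the search runs 36 cross-validation fits, plus one final refit on the full training set when `refit=True` (the scikit-learn default). Keeping max_epochs low during the search, as in the example, keeps that cost manageable.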

sklearn clone

Configs can be deep-copied with sklearn.base.clone:

from sklearn.base import clone
from deeptab.configs import MLPConfig

original = MLPConfig(d_model=128)
cfg_copy = clone(original)   # independent copy with the same parameters
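
Independence matters most for mutable fields such as a list of layer sizes. The sketch below contrasts a shallow copy, which shares the list, with a deep copy, which does not. It uses `copy.deepcopy` as a standard-library analogue of the deep copy that `clone` applies to parameter values; `ToyConfig` is a hypothetical stand-in, not a DeepTab class.

```python
import copy
from dataclasses import dataclass, field, replace


@dataclass
class ToyConfig:
    """Hypothetical stand-in for a config class; not part of DeepTab."""

    d_model: int = 64
    layer_sizes: list = field(default_factory=lambda: [256, 128, 32])


original = ToyConfig(d_model=128)
shallow = replace(original)      # new object, but the list is shared
deep = copy.deepcopy(original)   # new object with its own list

original.layer_sizes.append(16)
print(shallow.layer_sizes)  # [256, 128, 32, 16]  (changed along with original)
print(deep.layer_sizes)     # [256, 128, 32]      (unaffected)
```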

Sharing and versioning configs

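Because configs are dataclasses, dataclasses.asdict turns them into plain dictionaries that serialise to JSON, provided every field value is itself JSON-serialisable:

import dataclasses
import json

from deeptab.configs import MLPConfig

cfg = MLPConfig(d_model=128, dropout=0.1)
# serialise to a JSON string
blob = json.dumps(dataclasses.asdict(cfg))
# restore by passing the decoded fields back to the constructor
cfg2 = MLPConfig(**json.loads(blob))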
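
For versioning, it can help to derive a short fingerprint from the serialised config so that two runs with identical settings are easy to match. The sketch below is one possible approach, not a DeepTab feature: `ToyConfig` and `config_fingerprint` are hypothetical, and the same caveat about JSON-serialisable field values applies.

```python
import dataclasses
import hashlib
import json


@dataclasses.dataclass
class ToyConfig:
    """Hypothetical stand-in for a config class; not part of DeepTab."""

    d_model: int = 64
    dropout: float = 0.0


def config_fingerprint(cfg):
    """Return a short hash that depends only on the class name and field values."""
    payload = {"class": type(cfg).__name__, "params": dataclasses.asdict(cfg)}
    blob = json.dumps(payload, sort_keys=True)  # sorted keys give a stable string
    return hashlib.sha256(blob.encode("utf-8")).hexdigest()[:12]


a = ToyConfig(d_model=128, dropout=0.1)
b = ToyConfig(dropout=0.1, d_model=128)
print(config_fingerprint(a) == config_fingerprint(b))            # True
print(config_fingerprint(a) == config_fingerprint(ToyConfig()))  # False
```

The fingerprint can be stored next to a checkpoint or used in a results filename; truncating to 12 hex characters is an arbitrary choice made here for readability.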

Available model configs

| Config class | Model family |
| --- | --- |
| AutoIntConfig | AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks |
| ENODEConfig | ENODE: Extended Neural Oblivious Decision Ensembles |
| FTTransformerConfig | FT-Transformer: Feature Tokenizer Transformer |
| MambaTabConfig | MambaTab: Mamba-based tabular model |
| MambAttentionConfig | MambAttention: Mamba + self-attention hybrid |
| MambularConfig | Mambular: general-purpose Mamba backbone |
| MLPConfig | MLP: multilayer perceptron baseline |
| ModernNCAConfig | ModernNCA: modern Neighbourhood Components Analysis (experimental) |
| NDTFConfig | NDTF: Neural Decision Tree Forest |
| NODEConfig | NODE: Neural Oblivious Decision Ensembles |
| ResNetConfig | ResNet: residual network for tabular data |
| SAINTConfig | SAINT: Self-Attention and Intersample Attention Transformer |
| TabMConfig | TabM: Batch-Ensembling MLP |
| TabRConfig | TabR: Retrieval-Augmented Tabular model |
| TabTransformerConfig | TabTransformer: transformer with categorical embeddings |
| TabulaRNNConfig | TabulaRNN: LSTM / GRU recurrent baseline |
| TangosConfig | Tangos: Targeted Regularisation (experimental) |
| TromptConfig | Trompt: prompt-learning-inspired tabular model (experimental) |