Configurations
DeepTab uses a split-config API: model hyperparameters are divided across three separate dataclasses so that architecture choices, data preprocessing, and training settings can be managed, versioned, and shared independently.
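| Config class | Controls | Typical fields |
|---|---|---|
| Model config (e.g. MLPConfig, ResNetConfig, MambularConfig) | Neural architecture | e.g. d_model, n_layers, dropout |
| PreprocessingConfig | Feature engineering | e.g. numerical_preprocessing |
| TrainerConfig | Training loop | e.g. max_epochs, lr, batch_size, patience |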
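To illustrate why the split helps, the sketch below models the three config groups as plain dataclasses and serialises each one separately. The class names (SketchModelConfig and so on) and the default values are hypothetical stand-ins, not DeepTab's actual classes; only the field names are taken from the examples on this page.

```python
import json
from dataclasses import asdict, dataclass


# Hypothetical stand-ins for the three config groups (not DeepTab classes).
@dataclass
class SketchModelConfig:
    d_model: int = 128
    dropout: float = 0.1


@dataclass
class SketchPreprocessingConfig:
    numerical_preprocessing: str = "quantile"


@dataclass
class SketchTrainerConfig:
    max_epochs: int = 50
    lr: float = 1e-3


def dump_config(cfg) -> str:
    """Serialise one config group to JSON so it can be versioned on its own."""
    return json.dumps(asdict(cfg), sort_keys=True)


# Keep the architecture fixed and vary only the training settings.
model_cfg = SketchModelConfig(d_model=256)
short_run = SketchTrainerConfig(max_epochs=5)
long_run = SketchTrainerConfig(max_epochs=200, lr=5e-4)

# Round-trip the architecture config through JSON without touching the others.
model_json = dump_config(model_cfg)
restored = SketchModelConfig(**json.loads(model_json))
```

Because each group is its own object, the same architecture settings can be paired with either training run, and each JSON string can be stored or shared separately.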
Quick-start by task
All three model variants (Classifier, Regressor, and LSS) accept the same config objects, so switching tasks is mostly a matter of importing a different class. LSS models additionally take a family argument in fit.
Classification
from deeptab.configs import MLPConfig, PreprocessingConfig, TrainerConfig
from deeptab.models import MLPClassifier
model = MLPClassifier(
    model_config=MLPConfig(d_model=128, dropout=0.1),
    preprocessing_config=PreprocessingConfig(numerical_preprocessing="quantile"),
    trainer_config=TrainerConfig(max_epochs=50, lr=1e-3),
)
model.fit(X_train, y_train)
preds = model.predict(X_test) # class labels
proba = model.predict_proba(X_test) # class probabilities
Regression
from deeptab.configs import ResNetConfig, TrainerConfig
from deeptab.models import ResNetRegressor
model = ResNetRegressor(
    model_config=ResNetConfig(d_model=256, n_layers=4),
    trainer_config=TrainerConfig(max_epochs=100, lr=5e-4, patience=10),
)
model.fit(X_train, y_train)
preds = model.predict(X_test) # continuous values
Distributional regression (LSS)
LSS models predict the full distribution of the target, not just a point estimate.
Pass family to fit to select the output distribution.
from deeptab.configs import MambularConfig, TrainerConfig
from deeptab.models import MambularLSS
model = MambularLSS(
    model_config=MambularConfig(d_model=64, n_layers=6),
    trainer_config=TrainerConfig(max_epochs=100, lr=1e-3),
)
model.fit(X_train, y_train, family="normal") # learns μ and σ per row
dist_params = model.predict(X_test) # shape (N, n_params)
Common families: "normal", "poisson", "gamma", "beta", "dirichlet".
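One way to use the predicted parameters for the "normal" family is to turn each row into a prediction interval. The sketch below assumes, hypothetically, that column 0 holds the mean μ and column 1 the standard deviation σ; check the actual output layout before relying on it, since some libraries return a variance or a log-scale value instead. The dist_params values are made-up illustration data, not model output.

```python
from statistics import NormalDist


def normal_interval(mu: float, sigma: float, coverage: float = 0.95):
    """Central prediction interval of a Normal(mu, sigma) distribution."""
    tail = (1.0 - coverage) / 2.0
    dist = NormalDist(mu=mu, sigma=sigma)
    return dist.inv_cdf(tail), dist.inv_cdf(1.0 - tail)


# Assumed layout: one row per sample, columns = [mu, sigma] (illustrative values).
dist_params = [
    [0.0, 1.0],
    [10.0, 2.0],
]

intervals = [normal_interval(mu, sigma) for mu, sigma in dist_params]
for (mu, sigma), (low, high) in zip(dist_params, intervals):
    print(f"mu={mu:.1f} sigma={sigma:.1f} -> 95% interval [{low:.2f}, {high:.2f}]")
```

The same per-row parameters could equally feed a quantile, a tail probability, or a log-likelihood, which is what a point estimate alone cannot provide.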
Scikit-learn compatibility
Every config dataclass extends sklearn.base.BaseEstimator, so the full
scikit-learn parameter protocol is available.
get_params
Returns a flat dictionary of all hyperparameters, just as get_params does on any scikit-learn estimator:
from deeptab.configs import MLPConfig, TrainerConfig
cfg = MLPConfig(d_model=128, dropout=0.2)
print(cfg.get_params())
# {'d_model': 128, 'dropout': 0.2, 'layer_sizes': [256, 128, 32], ...}
trainer = TrainerConfig(max_epochs=50)
print(trainer.get_params())
# {'max_epochs': 50, 'lr': 0.0001, 'batch_size': 128, ...}
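For intuition, the flat-dictionary behaviour can be reproduced with nothing but the standard library. The sketch below is a simplified stand-in, not scikit-learn's or DeepTab's implementation: SketchConfig and its defaults are hypothetical, and the real BaseEstimator.get_params discovers parameter names from the __init__ signature rather than from dataclass fields.

```python
from dataclasses import dataclass, field, fields
from typing import Any, Dict, List


@dataclass
class SketchConfig:
    """Hypothetical config; field names mirror the examples, defaults are made up."""

    d_model: int = 64
    dropout: float = 0.0
    layer_sizes: List[int] = field(default_factory=lambda: [256, 128, 32])

    def get_params(self) -> Dict[str, Any]:
        # One flat entry per declared field, in declaration order.
        return {f.name: getattr(self, f.name) for f in fields(self)}


cfg = SketchConfig(d_model=128, dropout=0.2)
params = cfg.get_params()
print(params)
```

Fields that were not passed explicitly still appear in the result with their defaults, which is what makes the dictionary a complete record of a configuration.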
set_params
Updates parameters in place and returns self, so calls can be chained and configs work with
scikit-learn pipelines and grid search:
from deeptab.configs import MLPConfig, TrainerConfig
cfg = MLPConfig()
cfg.set_params(d_model=256, dropout=0.3)
trainer = TrainerConfig()
trainer.set_params(max_epochs=200, lr=5e-4)
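Returning self is what makes chained calls possible. The following sketch shows the pattern on a hypothetical SketchTrainerConfig (its defaults are made up); like scikit-learn, it rejects unknown parameter names instead of silently creating new attributes.

```python
from dataclasses import dataclass, fields


@dataclass
class SketchTrainerConfig:
    """Hypothetical config used only to illustrate the set_params pattern."""

    max_epochs: int = 100
    lr: float = 1e-3

    def set_params(self, **params):
        valid = {f.name for f in fields(self)}
        for name, value in params.items():
            if name not in valid:
                raise ValueError(f"Invalid parameter {name!r} for {type(self).__name__}")
            setattr(self, name, value)  # update in place
        return self  # returning self allows chaining


trainer = SketchTrainerConfig()
same_object = trainer.set_params(max_epochs=200).set_params(lr=5e-4)

try:
    trainer.set_params(learning_rate=0.1)  # not a declared field, so it is rejected
    rejected = False
except ValueError:
    rejected = True
```

Failing loudly on unknown names matters during tuning, where a misspelt key would otherwise leave the real parameter at its default without any warning.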
Hyperparameter search with GridSearchCV
Because the estimator itself also follows get_params / set_params, you can
tune any config field via GridSearchCV using the <config_attr>__<field>
double-underscore notation:
from sklearn.model_selection import GridSearchCV
from deeptab.configs import MLPConfig, TrainerConfig
from deeptab.models import MLPClassifier
model = MLPClassifier(
    model_config=MLPConfig(),
    trainer_config=TrainerConfig(max_epochs=20),
)
param_grid = {
    "model_config__d_model": [64, 128, 256],
    "model_config__dropout": [0.1, 0.3],
    "trainer_config__lr": [1e-3, 5e-4],
}
search = GridSearchCV(model, param_grid, cv=3, scoring="accuracy")
search.fit(X_train, y_train)
print(search.best_params_)
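It helps to see what the grid above expands to. The sketch below uses only the standard library to enumerate the candidates and to split each key at the double underscore; it mirrors how the notation is read and is not scikit-learn's own code. With 3 × 2 × 2 = 12 candidates and cv=3, the search runs 36 cross-validation fits, plus one final refit on the full training set under GridSearchCV's default refit=True.

```python
from itertools import product

param_grid = {
    "model_config__d_model": [64, 128, 256],
    "model_config__dropout": [0.1, 0.3],
    "trainer_config__lr": [1e-3, 5e-4],
}


def expand_grid(grid):
    """Yield one {key: value} dict per combination in the grid."""
    keys = list(grid)
    for values in product(*(grid[k] for k in keys)):
        yield dict(zip(keys, values))


def route(candidate):
    """Group 'config_attr__field' keys by the config object they target."""
    routed = {}
    for key, value in candidate.items():
        config_attr, field_name = key.split("__", 1)
        routed.setdefault(config_attr, {})[field_name] = value
    return routed


candidates = list(expand_grid(param_grid))
n_cv_fits = len(candidates) * 3  # cv=3 folds per candidate
first = route(candidates[0])
print(len(candidates), n_cv_fits, first)
```

Since every added value multiplies the number of fits, it is worth keeping max_epochs low during the search, as the example above does, and retraining the best configuration for longer afterwards.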
sklearn clone
Configs can be deep-copied with sklearn.base.clone:
from sklearn.base import clone
from deeptab.configs import MLPConfig
original = MLPConfig(d_model=128)
cfg_copy = clone(original)  # fully independent copy with the same parameters
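To see what "independent" means here: for an estimator-like object, clone reads the constructor parameters, copies them, and builds a new instance of the same class. The sketch below imitates that with the standard library on a hypothetical SketchConfig; it is a simplified illustration, not scikit-learn's implementation.

```python
import copy
from dataclasses import dataclass, field, fields
from typing import List


@dataclass
class SketchConfig:
    """Hypothetical config with a mutable field to make aliasing visible."""

    d_model: int = 128
    layer_sizes: List[int] = field(default_factory=lambda: [256, 128, 32])


def clone_config(cfg):
    """Rebuild cfg from deep-copied field values (a simplified clone)."""
    params = {f.name: copy.deepcopy(getattr(cfg, f.name)) for f in fields(cfg)}
    return type(cfg)(**params)


original = SketchConfig(d_model=128)
duplicate = clone_config(original)

duplicate.layer_sizes.append(16)  # mutate the copy only
```

Because the list is deep-copied, changing layer_sizes on the duplicate leaves the original untouched, which is the property a search procedure needs when it derives many candidates from one base config.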
Available model configs
| Config class | Model family |
|---|---|
| | AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks |
| | ENODE: Extended Neural Oblivious Decision Ensembles |
| | FT-Transformer: Feature Tokenizer Transformer |
| | MambaTab: Mamba-based tabular model |
| | MambAttention: Mamba + self-attention hybrid |
| MambularConfig | Mambular: general-purpose Mamba backbone |
| MLPConfig | MLP: multilayer perceptron baseline |
| | ModernNCA: modern take on Neighbourhood Components Analysis (experimental) |
| | NDTF: Neural Decision Tree Forest |
| | NODE: Neural Oblivious Decision Ensembles |
| ResNetConfig | ResNet: residual network for tabular data |
| | SAINT: Self-Attention and Intersample Attention Transformer |
| | TabM: Batch-Ensembling MLP |
| | TabR: Retrieval-Augmented Tabular model |
| | TabTransformer: transformer with categorical embeddings |
| | TabulaRNN: LSTM / GRU recurrent baseline |
| | Tangos: Targeted Regularisation (experimental) |
| | Trompt: tree-inspired tabular model (experimental) |