scikit-learn Compatible API
DeepTab estimators follow the scikit-learn pattern while training PyTorch models under the hood. You instantiate an estimator, call fit, then use predict, evaluate, score, save, and load.
What “scikit-learn compatible” means
scikit-learn defines a small set of conventions that every estimator is expected to honour. Meeting them is what lets a model drop into tools like Pipeline, GridSearchCV, and cross_val_score without special-casing. The table below lists each convention, what it requires, and whether DeepTab satisfies it.
| Convention | What it requires | DeepTab |
|---|---|---|
| Subclasses `BaseEstimator` | Inherit from sklearn's base class for shared machinery | ✓ |
| Params set in `__init__` | The constructor stores arguments verbatim and does no heavy work | ✓ |
| `get_params` / `set_params` | Expose and update hyperparameters by name (also nested, e.g. `model_config__d_model`) | ✓ |
| `fit` returns `self` | Training mutates the estimator in place and returns it for chaining | ✓ |
| `predict` | Produce predictions from a fitted estimator | ✓ |
| `score` | Default metric, higher is better (R² for regression, accuracy for classification) | ✓ |
| Fitted attributes end with `_` | Learned state is exposed as trailing-underscore attributes such as `n_features_in_` and `classes_` | ✓ |
| … | Defines … | ✓ |
| Clone friendly | Works with `sklearn.base.clone` | ✓ |
| `predict_proba` | Probability estimates for classification tasks | ✓ |
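The following pure-Python sketch shows each convention in place. `MeanRegressor` is a hypothetical toy that always predicts the training mean. It is not part of DeepTab or scikit-learn. It only marks where each convention lives in an estimator class.

```python
class MeanRegressor:
    """Hypothetical toy estimator that always predicts the training mean."""

    def __init__(self, shrink=0.0):
        # Convention: store constructor arguments verbatim, do no work here.
        self.shrink = shrink

    def get_params(self, deep=True):
        # Convention: hyperparameters are readable by name.
        return {"shrink": self.shrink}

    def set_params(self, **params):
        # Convention: hyperparameters are writable by name; returns self.
        for name, value in params.items():
            setattr(self, name, value)
        return self

    def fit(self, X, y):
        # Convention: learned state gets a trailing underscore...
        self.n_features_in_ = len(X[0])
        self.mean_ = (1.0 - self.shrink) * sum(y) / len(y)
        # ...and fit returns self so calls can be chained.
        return self

    def predict(self, X):
        return [self.mean_ for _ in X]

    def score(self, X, y):
        # Convention: higher is better. Here R^2 = 1 - SS_res / SS_tot.
        y_pred = self.predict(X)
        y_bar = sum(y) / len(y)
        ss_res = sum((t - p) ** 2 for t, p in zip(y, y_pred))
        ss_tot = sum((t - y_bar) ** 2 for t in y)
        return 1.0 - ss_res / ss_tot
```

A constant-mean predictor scores exactly 0.0 on its own training data. That is the R² baseline any useful regressor should beat.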
Note
DeepTab implements score directly rather than inheriting ClassifierMixin / RegressorMixin, but it follows the same “higher is better” convention, so GridSearchCV and friends behave as expected.
Important
Because every constructor argument is stored untouched and all heavy lifting happens in fit, DeepTab estimators are safe to clone and reuse inside Pipeline and cross-validation. Avoid mutating private (underscore-prefixed) attributes if you rely on cloning, since those are deliberately hidden from get_params.
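Private attributes are off-limits because of how cloning works. scikit-learn's `clone` builds a new estimator of the same class from `get_params(deep=False)` alone, so anything outside those parameters does not carry over. The sketch below imitates that idea with a hypothetical `ToyModel` and a simplified `clone_like` helper. It is not scikit-learn's implementation, which also clones nested estimators and verifies that parameters round-trip.

```python
import copy


class ToyModel:
    """Hypothetical estimator used only to illustrate cloning."""

    def __init__(self, lr=1e-3):
        self.lr = lr  # public hyperparameter, reported by get_params

    def get_params(self, deep=True):
        return {"lr": self.lr}

    def fit(self, X, y):
        self._cache = "built during fit"  # private state, hidden from get_params
        self.fitted_ = True  # fitted attribute
        return self


def clone_like(estimator):
    """Simplified clone: rebuild the estimator from its public parameters only."""
    params = {k: copy.deepcopy(v) for k, v in estimator.get_params(deep=False).items()}
    return type(estimator)(**params)
```

The clone keeps `lr` but has neither `_cache` nor `fitted_`. Any private attribute you set by hand is silently lost in the same way.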
Basic Workflow
```python
from deeptab.configs import MambularConfig, TrainerConfig
from deeptab.models import MambularClassifier

model = MambularClassifier(
    model_config=MambularConfig(d_model=64, n_layers=4),
    trainer_config=TrainerConfig(max_epochs=50, patience=10),
    random_state=101,
)
model.fit(X_train, y_train)
predictions = model.predict(X_test)
metrics = model.evaluate(X_test, y_test)
```
Estimator Families
Most architectures expose three task variants:
| Suffix | Task | Example |
|---|---|---|
| `Classifier` | Binary or multiclass classification | `MambularClassifier` |
| `Regressor` | Point-estimate regression | `MambularRegressor` |
| `LSS` | Distributional regression | `MambularLSS` |
Stable models are imported from deeptab.models. Experimental models are imported from deeptab.models.experimental.
Accepted Inputs
Use pandas DataFrames when possible:
```python
import pandas as pd

X = pd.DataFrame({
    "age": [25, 32, 47],
    "city": pd.Series(["NYC", "Boston", "Chicago"], dtype="category"),
    "income": [50000.0, 75000.0, 90000.0],
})
```
NumPy arrays are accepted, but they lose column names and dtype semantics:
```python
import numpy as np

X = np.random.randn(1000, 10)
```
For mixed numerical/categorical data, DataFrames are strongly preferred.
Constructor Pattern
```python
from deeptab.configs import MLPConfig, PreprocessingConfig, TrainerConfig
from deeptab.models import MLPRegressor

model = MLPRegressor(
    model_config=MLPConfig(layer_sizes=[256, 128, 32], dropout=0.2),
    preprocessing_config=PreprocessingConfig(numerical_preprocessing="standardization"),
    trainer_config=TrainerConfig(lr=1e-3, batch_size=256, max_epochs=100),
    random_state=101,
)
```
The split-config API is the recommended style for new code.
Fit
You can train in one of two ways. Pass X and y alone and DeepTab holds out a validation fraction internally, or pass your own X_val and y_val to control the split yourself.
```python
# Auto split: DeepTab holds out val_size (default 0.2) for validation
model.fit(X, y)

# Explicit split: you supply the validation set, e.g. a time-based holdout
model.fit(
    X_train,
    y_train,
    X_val=X_val,
    y_val=y_val,
)
```
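The two styles can be compared with a plain-Python sketch. The first helper makes a seeded random holdout with `val_size=0.2`. The second makes a time-based split of the kind you might pass explicitly. `random_holdout` and `time_holdout` are hypothetical helpers that return row indices. DeepTab's internal splitting may differ, for example in how it shuffles or stratifies.

```python
import random


def random_holdout(n_rows, val_size=0.2, seed=101):
    """Return (train_idx, val_idx) for a shuffled holdout of size val_size."""
    idx = list(range(n_rows))
    random.Random(seed).shuffle(idx)  # seeded, so the split is reproducible
    n_val = int(round(n_rows * val_size))
    return sorted(idx[n_val:]), sorted(idx[:n_val])


def time_holdout(timestamps, cutoff):
    """Rows before `cutoff` train the model; rows at or after it validate it."""
    train_idx = [i for i, t in enumerate(timestamps) if t < cutoff]
    val_idx = [i for i, t in enumerate(timestamps) if t >= cutoff]
    return train_idx, val_idx
```

A random holdout suits exchangeable rows. With temporal data it lets future rows leak into training, which is why an explicit time-based `X_val`/`y_val` is often the safer choice.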
Note
X and y are required; X_val and y_val are optional. When you pass X_val you must also pass y_val, and val_size is then ignored because nothing is held out from X. There is no separate test set inside fit(): keep your test data aside and use predict() or evaluate() on it afterwards.
Early stopping, the learning-rate scheduler, and checkpointing all watch the validation metric. A representative validation set is therefore essential for good results, whether DeepTab creates it or you pass it explicitly.
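The sketch below shows a generic patience rule, which is how `TrainerConfig(patience=...)` ties training length to the validation set. It assumes a loss-like metric where lower is better. `early_stopping_epoch` is a hypothetical name, and DeepTab's trainer may differ in details such as a minimum-improvement threshold.

```python
def early_stopping_epoch(val_losses, patience):
    """Return (stop_epoch, best_epoch) for a patience-based stopping rule.

    Training stops once `patience` consecutive epochs pass without a new
    best (lowest) validation loss; a checkpointing callback would typically
    keep the weights from `best_epoch`.
    """
    best_loss = float("inf")
    best_epoch = 0
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            best_loss, best_epoch = loss, epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch, best_epoch
    return len(val_losses) - 1, best_epoch
```

With validation losses `[1.0, 0.8, 0.7, 0.75, 0.72, 0.71, 0.9]` and `patience=3`, training stops at epoch 5 and the best epoch is 2. A noisy or unrepresentative validation set shifts both numbers.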
Useful fit arguments:
| Argument | Use |
|---|---|
| `X`, `y` | Training features and targets. |
| `X_val`, `y_val` | Explicit validation set. If omitted, DeepTab creates one. |
| … | Optional external embeddings for train/validation data. |
| … | Legacy fit-time overrides; prefer the config objects (e.g. `TrainerConfig`). |
| … | Optional Lightning metrics logged during training. |
| … | Additional Lightning trainer keyword arguments. |
For LSS models, family is required:
```python
from deeptab.models import MambularLSS

model = MambularLSS()
model.fit(X_train, y_train, family="normal")
```
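An LSS model predicts the parameters of a distribution rather than a single value. With `family="normal"` that means a mean and a scale per row. A natural way to judge such predictions is the Gaussian negative log-likelihood, sketched below in plain Python. The `(mean, std)` pairs are hypothetical stand-ins for the output of `lss_model.predict`. Check the API reference for the exact output layout, for instance whether the scale is reported as a standard deviation or a variance.

```python
import math


def gaussian_nll(y_true, params):
    """Mean negative log-likelihood of y_true under per-row Normal(mean, std)."""
    total = 0.0
    for y, (mean, std) in zip(y_true, params):
        total += 0.5 * math.log(2 * math.pi * std**2) + (y - mean) ** 2 / (2 * std**2)
    return total / len(y_true)
```

Lower is better here. A narrow distribution centred on the truth is rewarded, while a narrow distribution centred in the wrong place is penalised more heavily than a wide one.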
Predict
```python
labels = classifier.predict(X_test)  # class labels
values = regressor.predict(X_test)  # point estimates
params = lss_model.predict(X_test)  # distribution parameters
```
For classifiers:
```python
probabilities = classifier.predict_proba(X_test)
```
For external embeddings at inference:
```python
predictions = model.predict(X_test, embeddings=test_embeddings)
```
Evaluate
evaluate() returns a {metric_name: score} dictionary. With no metrics argument it uses the task defaults from the metric registry, so the keys are the metric short names:
```python
classifier.evaluate(X_test, y_test)
# {"accuracy": ..., "auroc": ..., "log_loss": ...}
regressor.evaluate(X_test, y_test)
# {"rmse": ..., "mae": ..., "r2": ...}
```
For tutorials and papers, pass explicit metrics. The dictionary values are callables with the signature metric(y_true, y_pred); the built-in DeepTabMetric classes route probability-based metrics (such as LogLoss and AUROC) to predict_proba automatically:
```python
from deeptab.metrics import Accuracy, AUROC, LogLoss

classifier.evaluate(
    X_test,
    y_test,
    metrics={
        "accuracy": Accuracy(),
        "auroc": AUROC(),
        "log_loss": LogLoss(),
    },
)
```
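Each metric is just a callable `metric(y_true, y_pred)`, so the dictionary that `evaluate()` returns can be pictured as a comprehension over the `metrics` mapping. The sketch below uses hypothetical plain-Python metrics and a hypothetical `evaluate_predictions` helper. It omits the routing of probability-based metrics to `predict_proba`, which DeepTab's built-in metric classes handle.

```python
def accuracy(y_true, y_pred):
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def error_rate(y_true, y_pred):
    return 1.0 - accuracy(y_true, y_pred)


def evaluate_predictions(y_true, y_pred, metrics):
    """Apply every metric(y_true, y_pred) and collect {name: score}."""
    return {name: metric(y_true, y_pred) for name, metric in metrics.items()}
```

For `y_true=[0, 1, 1, 0]` and `y_pred=[0, 1, 0, 0]` this yields `{"accuracy": 0.75, "error_rate": 0.25}`.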
Score
score() follows the scikit-learn convention of one default metric per estimator family (higher is better):
| Estimator | Default |
|---|---|
| Classifier | accuracy |
| Regressor | R² |
| LSS | negative log-likelihood |
Pass a metric explicitly if you need F1, log loss, or another convention:
```python
from sklearn.metrics import log_loss

loss = classifier.score(X_test, y_test, metric=(log_loss, True))
```
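The "higher is better" rule lets model-selection tools pick a winner with a plain `max`. Loss-type metrics such as log loss are therefore negated before comparison, which is the convention behind scikit-learn's `neg_log_loss` scorer. Here is a small sketch with hand-picked probabilities for two hypothetical candidates. `as_score` is an illustrative helper, not a DeepTab or scikit-learn function.

```python
import math


def log_loss(y_true, p_pos):
    """Binary log loss (lower is better) for positive-class probabilities."""
    eps = 1e-15
    total = 0.0
    for y, p in zip(y_true, p_pos):
        p = min(max(p, eps), 1 - eps)  # clip to avoid log(0)
        total -= y * math.log(p) + (1 - y) * math.log(1 - p)
    return total / len(y_true)


def as_score(loss_value):
    """Turn a lower-is-better loss into a higher-is-better score."""
    return -loss_value


y = [1, 0, 1]
candidates = {
    "confident": [0.9, 0.1, 0.8],  # hypothetical predicted probabilities
    "hesitant": [0.6, 0.4, 0.5],
}
scores = {name: as_score(log_loss(y, p)) for name, p in candidates.items()}
best = max(scores, key=scores.get)  # the highest score wins
```

Both scores are negative, but the better-calibrated candidate has the larger one and is selected.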
Learned Attributes
After fit() or build_model(), DeepTab estimators expose common sklearn-style fitted attributes:
| Attribute | Available on | Meaning |
|---|---|---|
| `n_features_in_` | Classifier, regressor, LSS | Number of input columns seen during fitting. |
| `feature_names_in_` | Estimators fitted with string-named DataFrame columns | Feature names and order seen during fitting. |
| `classes_` | Classifiers and categorical LSS | Class labels seen during fitting. |
Prediction inputs are checked against the fitted feature count. When the model was fitted with named DataFrame columns, prediction DataFrames must use the same feature names in the same order. This catches accidental column drops, additions, and reordering before inference.
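The check described above can be sketched in a few lines of plain Python. `check_columns` is a hypothetical helper that works on lists of column names. It is not DeepTab's implementation, and its error messages are not DeepTab's. Passing `None` as the names stands in for a model fitted on a NumPy array, where only the feature count can be checked.

```python
def check_columns(feature_names_in, n_features_in, columns):
    """Raise ValueError if `columns` does not match the fitted schema."""
    if len(columns) != n_features_in:
        raise ValueError(
            f"X has {len(columns)} features, but the model was fitted with {n_features_in}."
        )
    if feature_names_in is not None and list(columns) != list(feature_names_in):
        missing = [c for c in feature_names_in if c not in columns]
        unexpected = [c for c in columns if c not in feature_names_in]
        raise ValueError(
            f"Feature names differ from fit time: missing={missing}, "
            f"unexpected={unexpected}; order must also match."
        )
```

Dropped, added, and reordered columns all fail. The reordering case is the important one, because a positional model would otherwise run silently on the wrong features.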
Save and Load
DeepTab has two persistence layers:
| Method | Scope | Use case |
|---|---|---|
| `save()` / `load()` | Full fitted estimator artifact | Reuse a trained classifier, regressor, or LSS model for inference or reproducible experiments. |
| … | Raw PyTorch architecture weights only | Low-level architecture work where you already know how to rebuild the model and preprocessing pipeline. |
For normal user workflows, prefer the estimator-level API:
```python
model.fit(X_train, y_train)
model.save("model.deeptab")
loaded = type(model).load("model.deeptab")  # load() is called on the class
predictions = loaded.predict(X_test)
```
The saved estimator bundle is designed as a fitted inference artifact. It includes:
| Artifact field | Why it matters |
|---|---|
| Architecture metadata | Stores the model class, module, registry status, config class, and resolved config values. |
| Trained weights | Restores the fitted … |
| Fitted preprocessing state | Reuses the exact fitted preprocessing object instead of refitting on future data. |
| Feature schema | Stores column order, numerical/categorical/embedding feature groups, dimensions, and feature preprocessing metadata. |
| Task metadata | Stores the task type, regression/LSS flags, distribution family for LSS, number of output classes, and … |
| Runtime/debug metadata | Stores Python, platform, DeepTab, PyTorch, Lightning, pandas, NumPy, scikit-learn, pretab, and related dependency versions. |
Using pandas DataFrames is recommended because the saved schema can preserve meaningful column names. NumPy inputs are supported, but their inferred column order is positional.
```python
loaded = MambularClassifier.load("model.deeptab")
loaded.input_columns_
loaded.feature_schema_
loaded.task_info_
loaded.versions_
```
load() keeps backward compatibility with older DeepTab artifacts that do not contain the richer metadata block, but newer artifacts are easier to audit and debug across environments.
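The standard library is enough to capture the kind of runtime/debug metadata listed in the artifact table. The sketch below shows one way such a block could be collected, assuming installed-distribution lookups. It is not DeepTab's implementation, and the package names are illustrative PyPI distribution names. Packages that are not installed are recorded as `None` rather than raising.

```python
import platform
from importlib import metadata


def collect_versions(
    packages=("deeptab", "torch", "lightning", "pandas", "numpy", "scikit-learn", "pretab"),
):
    """Return interpreter/platform info plus {package: version-or-None}."""
    versions = {
        "python": platform.python_version(),
        "platform": platform.platform(),
    }
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # not installed in this environment
    return versions
```

Comparing two such dictionaries is often the quickest way to explain why a loaded model behaves differently on another machine.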
Model Inspection
DeepTab estimators expose a small inspection layer for understanding a configured or fitted model.
| Method | Returns | When to use |
|---|---|---|
| `describe()` | Dictionary with estimator, architecture, task, feature counts, config classes, and parameter counts when available | Programmatic metadata for reports and experiment tracking |
| `summary()` | Compact human-readable string | Notebook/log output before or after training |
| `parameter_table()` | … | Auditing model size and trainable layers |
| `runtime_info()` | Dictionary with device, dtype, precision, accelerator, strategy, batch size, optimizer, and trainer state | Checking how the model is actually running |
```python
model.fit(X_train, y_train)
print(model.summary())
metadata = model.describe()
params = model.parameter_table()
runtime = model.runtime_info()
```
describe(), summary(), and runtime_info() are safe to call before fitting. parameter_table() requires a built or fitted model because the PyTorch modules do not exist until DeepTab has seen the feature schema.
```python
model = MambularClassifier()
print(model.describe()["built"])
print(model.runtime_info()["batch_size"])
# Raises ValueError until fit() or build_model() has created the network.
model.parameter_table()
```
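`parameter_table()` needs a built model because layer shapes, and therefore parameter counts, depend on the input width. That width is only known once DeepTab has seen the feature schema. The back-of-the-envelope sketch below covers a plain stack of dense layers. It assumes each layer has a weight matrix plus a bias and ignores embeddings, normalization, and anything else a real DeepTab architecture adds, so its numbers will not match `parameter_table()`.

```python
def dense_param_counts(n_inputs, layer_sizes, n_outputs):
    """Parameters per dense layer (weights + biases) for a simple MLP stack."""
    sizes = [n_inputs, *layer_sizes, n_outputs]
    return [fan_in * fan_out + fan_out for fan_in, fan_out in zip(sizes, sizes[1:])]
```

With `layer_sizes=[256, 128, 32]`, 10 input features, and one output, the counts are `[2816, 32896, 4128, 33]`. Only the first entry changes with the number of input features, and that is the part that cannot be known before fitting.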
Tip
Use runtime_info() in benchmark notebooks and experiment logs. It records the resolved runtime state, which can differ from what you intended if Lightning chooses a different accelerator or if the model was loaded on CPU.
scikit-learn Integration
DeepTab implements get_params and set_params, including nested config parameters:
```python
model.get_params()
model.set_params(
    model_config__d_model=128,
    trainer_config__lr=3e-4,
)
```
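The double-underscore names follow scikit-learn's nested-parameter convention. The part before `__` selects a constructor argument, and the part after it names an attribute on that object. The sketch below simplifies this with hypothetical dataclass configs. `ToyModelConfig`, `ToyTrainerConfig`, `ToyEstimator`, and `set_nested_params` are illustrative names, not DeepTab classes. scikit-learn's own `set_params` additionally recurses through deeper nesting and validates keys against `get_params`.

```python
from dataclasses import dataclass, field


@dataclass
class ToyModelConfig:
    d_model: int = 64
    n_layers: int = 4


@dataclass
class ToyTrainerConfig:
    lr: float = 1e-3
    max_epochs: int = 50


@dataclass
class ToyEstimator:
    model_config: ToyModelConfig = field(default_factory=ToyModelConfig)
    trainer_config: ToyTrainerConfig = field(default_factory=ToyTrainerConfig)


def set_nested_params(estimator, **params):
    """Route 'component__attr' keys to the matching sub-object."""
    for key, value in params.items():
        component, sep, attr = key.partition("__")
        if not sep:
            setattr(estimator, key, value)  # top-level parameter
            continue
        target = getattr(estimator, component)
        if not hasattr(target, attr):
            raise ValueError(f"Invalid parameter {attr!r} for {component!r}")
        setattr(target, attr, value)
    return estimator
```

Untouched fields keep their values. A typo such as `model_config__depth` fails loudly instead of silently adding a new attribute.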
This enables GridSearchCV:
```python
from sklearn.model_selection import GridSearchCV

from deeptab.configs import MambularConfig, PreprocessingConfig, TrainerConfig
from deeptab.models import MambularClassifier

estimator = MambularClassifier(
    model_config=MambularConfig(),
    preprocessing_config=PreprocessingConfig(),
    trainer_config=TrainerConfig(max_epochs=30, patience=5),
)
search = GridSearchCV(
    estimator=estimator,
    param_grid={
        "model_config__d_model": [32, 64],
        "trainer_config__lr": [1e-3, 3e-4],
    },
    cv=3,
    n_jobs=1,  # fit candidates sequentially
)
search.fit(X_train, y_train)
print(search.best_params_)
```
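It helps to count the training runs before launching the search, since each one is a full neural-network fit. `GridSearchCV` fits one model per parameter combination per fold. With `refit=True` (the default) it also refits the best combination once on the full data. The grid itself is a Cartesian product, sketched below with `itertools`. `expand_grid` is a hypothetical helper that enumerates the same combinations as scikit-learn's `ParameterGrid` for a single dict, though not necessarily in the same order.

```python
from itertools import product


def expand_grid(param_grid):
    """List every combination in a dict-of-lists parameter grid."""
    keys = list(param_grid)
    return [dict(zip(keys, values)) for values in product(*(param_grid[k] for k in keys))]


param_grid = {
    "model_config__d_model": [32, 64],
    "trainer_config__lr": [1e-3, 3e-4],
}
candidates = expand_grid(param_grid)
cv = 3
n_fits = len(candidates) * cv + 1  # +1 for the final refit (refit=True)
```

This small grid already means 13 training runs, each capped at `max_epochs=30`. Every new value added to a list multiplies that number.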
Practical Differences From sklearn
DeepTab models train neural networks, so fit() is slower than fitting most classical sklearn estimators. Validation data, early stopping, checkpoints, GPU settings, and random seeds matter.
For reproducible research:
- Use explicit train/validation/test splits.
- Set `random_state` on the estimator and split functions.
- Save model, preprocessing, and config choices.
- Report the exact DeepTab version.