Model Operations
This page covers what you can do with a fitted DeepTab model beyond training: how to save and reload artifacts, and how to inspect any model’s architecture, parameters, device, and runtime characteristics.
Serialisation
DeepTab models save the complete artifact needed for inference: weights, fitted preprocessor, feature schema, model config, task metadata, and package versions.
Saving and loading
The recommended extension is .deeptab. DeepTab emits a UserWarning when a different extension is used (e.g. .pt), but any path is accepted.
# Save
model.save("my_model.deeptab")
# Load (returns a fully ready estimator, no re-fitting needed)
from deeptab.models import MLPClassifier
loaded = MLPClassifier.load("my_model.deeptab")
predictions = loaded.predict(X_test)
Tip
Use the class that matches the saved model type. Using the wrong class will raise an error with a clear message pointing to the mismatch.
What is inside the artifact
The bundle saved to disk is a PyTorch-serialised dictionary containing:
Key |
Contents |
|---|---|
|
Neural network weights (Lightning module state dict) |
|
Fitted |
|
Numerical, categorical, and embedding feature metadata |
|
Model config dataclass used during training |
|
Architecture, schema, preprocessing, task, and version sub-blocks |
|
Ordered list of column names, for feature-name validation at predict time |
|
Class labels for classifiers |
|
Python, PyTorch, Lightning, NumPy, pandas, scikit-learn versions |
Why everything lives in one bundle
A trained model is more than its weights. To turn raw input into a prediction you also need the fitted preprocessor that scaled and encoded the features, the feature schema that says which columns belong where, the architecture and its config to rebuild the network, and the task metadata that decides whether an output is a class label, a point estimate, or distribution parameters. If any of these travel separately, a reload can silently go wrong: a column in the wrong position or a re-fitted scaler will produce confident but incorrect predictions.
DeepTab keeps all of it together so that one file is enough to reproduce the exact model you trained. The saved package versions make that promise auditable, so when a colleague loads your artifact a year later they can see the environment it was built in.
Note
The metadata is tiny next to the weights. Schema, config, task info, and version stamps add a few kilobytes and grow with the number of features, not the number of training rows. A model trained on ten rows and one trained on ten million carry the same metadata footprint.
How .deeptab compares to raw formats
.deeptab is not a new on-disk format. It is a PyTorch-serialised dictionary with a clear name, and the value it adds over saving raw weights is everything wrapped around those weights.
Capability |
|
|
|
|
|---|---|---|---|---|
Model weights |
✓ |
✓ |
✓ |
✓ |
Rebuilds the correct architecture automatically |
✗ |
depends on class |
✗ |
✓ |
Fitted preprocessor (scalers, encoders) |
✗ |
sometimes |
✗ |
✓ |
Feature schema for predict-time validation |
✗ |
✗ |
✗ |
✓ |
Task metadata (regression, LSS family, |
✗ |
sometimes |
✗ |
✓ |
Environment version stamps |
✗ |
✗ |
✗ |
✓ |
Self-contained: predict with no extra glue code |
✗ |
✗ |
✗ |
✓ |
With a bare .pt file you have to recreate the architecture by hand and re-attach a preprocessor before the weights mean anything. A pickled estimator can capture more, but it stores a live Python object graph that breaks the moment a class is renamed or a dependency shifts, and unpickling it runs arbitrary code. .deeptab sidesteps both problems by storing structured metadata alongside the weights and reconstructing the model through DeepTab’s own loader.
Important
The self-contained reload is a feature of the DeepTab package, not of the file on its own. Loading a .deeptab artifact needs deeptab installed, ideally at a compatible version, which is exactly why the version snapshot is saved. The file is not a framework-independent interchange format. If you need a model that runs in a non-Python or non-DeepTab runtime, export to ONNX or TorchScript instead.
Warning
Because the artifact is pickle-backed under the hood, only load .deeptab files from sources you trust, the same caution that applies to any torch.load or pickle file.
Verifying a round-trip
model.save("my_model.deeptab")
loaded = MLPClassifier.load("my_model.deeptab")
# Hard predictions must be bit-identical
assert (model.predict(X_test) == loaded.predict(X_test)).all()
# Probabilities within floating-point tolerance
import numpy as np
np.testing.assert_allclose(
model.predict_proba(X_test),
loaded.predict_proba(X_test),
atol=1e-5,
)
print("Round-trip verified ✓")
Metadata attributes after loading
After load() the estimator exposes several read-only metadata attributes:
loaded.artifact_metadata_ # full metadata dict
loaded.architecture_metadata_ # architecture sub-block
loaded.feature_schema_ # feature schema sub-block
loaded.task_info_ # {"task": "classification", "num_classes": 2, ...}
loaded.classes_ # class labels
loaded.versions_ # package version snapshot
loaded.n_features_in_ # number of input features
loaded.input_columns_ # ordered feature names
The feature schema
feature_schema_ is the model’s data contract. When the preprocessor fits, DeepTab records the exact column order, which features are numerical, categorical, or embedding, and the per-feature information the architecture needs to size its layers.
loaded.feature_schema_
# {
# "column_order": ["age", "income", "city"],
# "feature_groups": {
# "numerical": ["age", "income"],
# "categorical": ["city"],
# "embedding": [],
# },
# "feature_info": { # per-feature details used to build the network
# "num": {"age": {...}, "income": {...}},
# "cat": {"city": {...}},
# "emb": {},
# },
# "schema": {...}, # full preprocessing-derived schema snapshot
# }
This single description does several jobs. The architecture reads it to size its input and embedding layers, so you never wire feature counts by hand. It records which columns the model expects, in what order and of what type, which is what lets InferenceModel.validate_input() reject a mismatched request at serving time. Because it is saved inside the artifact, a reloaded model knows its feature layout without re-fitting.
Note
The schema grows with the number of features, not the number of rows. It is the piece that lets a saved model carry “how to feed me” alongside its weights, so think of it as the bridge between preprocessing, the network, and deployment.
Model Inspection
All DeepTab estimators inherit InspectionMixin, which provides four read-only methods and one dry-run profiler. They are safe to call before or after fitting.
describe(): structured dict
Returns a structured snapshot of the estimator and its fitted state:
info = model.describe()
# {
# "estimator": "MLPClassifier",
# "architecture": "MLP",
# "task": "classification",
# "built": True,
# "fitted": True,
# "model_config": "MLPConfig",
# "feature_counts": {"numerical": 8, "categorical": 2, "embedding": 0, "total": 10},
# "num_classes": 2,
# "parameters": {"total": 45312, "trainable": 45312, "non_trainable": 0},
# }
Safe to call before fitting: parameter and feature metadata are omitted when the model is not yet built.
summary(): human-readable string
Compact text report combining describe() and runtime_info():
print(model.summary())
# MLPClassifier summary
# Architecture: MLP
# Task: classification
# Built: True
# Fitted: True
# Model config: MLPConfig
# Features: 10 total (8 numerical, 2 categorical, 0 embedding)
# Parameters: 45,312 total, 45,312 trainable, 0 non-trainable
# Device: cpu
# Precision: None
# Accelerator: None
parameter_table(): per-parameter DataFrame
Returns one row per parameter:
df = model.parameter_table()
df.head()
# name module shape num_params trainable dtype device
# estimator.embedding.weight estimator.embedding (50, 32) 1600 True float32 cpu
# ...
# Trainable only
df_train = model.parameter_table(trainable_only=True)
runtime_info(): device and training setup
info = model.runtime_info()
# {
# "built": True,
# "fitted": True,
# "device": "cpu",
# "dtype": "float32",
# "precision": None,
# "accelerator": None,
# "max_epochs": 100,
# "current_epoch": 87,
# "batch_size": 64,
# "lr": 0.0001,
# "weight_decay": 1e-06,
# ...
# }
profile(): pre-training dry run
profile() builds the model on a small sample, runs a forward pass, and returns a complete picture of what training will look like, without any gradient updates.
result = model.profile(X, y) # dry_run=True by default
# {
# "builds": True,
# "error": None,
# "device": "cpu",
# "dtype": "float32",
# "total_params": 45312,
# "trainable_params": 45312,
# "memory_mb": 0.173,
# "batch_shape": {"num_features": [[64, 20], ...], "cat_features": [], "labels": [64, 1]},
# "output_shape": [64, 1],
# "loss_fct": "BCEWithLogitsLoss",
# "forward_ms_median": 1.4,
# "forward_ms_min": 1.1,
# "describe": {...},
# "runtime": {...},
# }
Key parameters:
Parameter |
Default |
Effect |
|---|---|---|
|
|
Discard temporary build after profiling; leaves estimator unfitted |
|
|
Number of passes used to estimate timing; median is reported |
|
|
Override batch size for timing (defaults to |
|
|
Seed for the dry-run build |
When dry_run=False, the estimator is left built after the call and can proceed directly to fit().
If the build fails for any reason, result["builds"] is False and result["error"] contains the exception message, while all other keys are still present.