Advanced Training and Production Inference

This tutorial covers the parts of DeepTab you reach for once the basics feel comfortable: tuning the optimizer, controlling the learning-rate schedule, plugging in your own optimizer or scheduler, and deploying a trained model with InferenceModel. Each part builds on the one before it, but the sections are self-contained, so feel free to jump straight to the topic you need.

Note

The notebook linked above mirrors this tutorial. Use the markdown page for reading; use the notebook when you want to execute cells directly.

What You Will Learn

  • How to discover available optimizers and schedulers at runtime.

  • How to pass optimizer_type, optimizer_kwargs, and scheduler fields through TrainerConfig.

  • What no_weight_decay_for_bias_and_norm does and when to use it.

  • How to register a custom optimizer or scheduler so it works with the same config interface.

  • How to use InferenceModel for schema-validated, deployment-friendly inference.

  • How validate_input, predict_proba, and predict_params behave in production.

Setup

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from sklearn.datasets import make_classification
from sklearn.metrics import accuracy_score, roc_auc_score
from sklearn.model_selection import train_test_split

from deeptab import InferenceModel
from deeptab.configs import MambularConfig, PreprocessingConfig, TrainerConfig
from deeptab.models import MambularClassifier
from deeptab.training import (
    available_optimizers,
    available_schedulers,
    register_optimizer,
    register_scheduler,
    unregister_optimizer,
    unregister_scheduler,
)

Note

For a quick demonstration these tutorials train with very low max_epochs and patience (5 and 2). Treat these as placeholders and choose values that match your own compute budget and problem. As a starting point, at least max_epochs=100 and patience=10 are recommended for meaningful results.

Data

All examples in this tutorial share a single binary classification dataset.

RANDOM_STATE = 42

X_num, y = make_classification(
    n_samples=1500,
    n_features=12,
    n_informative=8,
    n_redundant=2,
    random_state=RANDOM_STATE,
)

X = pd.DataFrame(X_num, columns=[f"feat_{i}" for i in range(X_num.shape[1])])

X_train, X_temp, y_train, y_temp = train_test_split(
    X, y, test_size=0.3, stratify=y, random_state=RANDOM_STATE
)
X_val, X_test, y_val, y_test = train_test_split(
    X_temp, y_temp, test_size=0.5, stratify=y_temp, random_state=RANDOM_STATE
)

Part 1: Optimizers

The optimizer decides how each gradient update turns into a change in the model’s weights. DeepTab defaults to Adam, a dependable starting point for most tabular problems. When you want more control, you can select any optimizer in the registry and forward custom arguments to it through TrainerConfig.

Discovering available optimizers

available_optimizers() returns a sorted list of all names registered in the optimizer registry. All standard torch.optim classes are pre-registered at import time.

opts = available_optimizers()
print(opts)
# ['adadelta', 'adagrad', 'adam', 'adamax', 'adamw', 'asgd', 'lbfgs',
#  'nadam', 'radam', 'rmsprop', 'rprop', 'sgd', 'sparseadam']

Note

Registry names are stored in lowercase, so available_optimizers() always returns lowercase strings. Lookups are case insensitive, so optimizer_type="AdamW" and optimizer_type="adamw" resolve to the same class.

Using AdamW instead of the default Adam

Pass optimizer_type to TrainerConfig. Any additional optimizer constructor arguments go in optimizer_kwargs:

trainer = TrainerConfig(
    max_epochs=5,
    batch_size=128,
    lr=3e-4,
    patience=2,
    optimizer_type="AdamW",
    optimizer_kwargs={
        "betas": (0.9, 0.98),   # custom momentum coefficients
        "eps": 1e-8,            # numerical stability term
    },
    weight_decay=1e-2,          # passed as a top-level TrainerConfig field
)

clf = MambularClassifier(
    model_config=MambularConfig(d_model=64, n_layers=3),
    preprocessing_config=PreprocessingConfig(numerical_preprocessing="quantile"),
    trainer_config=trainer,
    random_state=RANDOM_STATE,
)
clf.fit(X_train, y_train, X_val=X_val, y_val=y_val)
print("AdamW AUROC:", roc_auc_score(y_test, clf.predict_proba(X_test)[:, 1]))

Note

lr and weight_decay are top-level TrainerConfig fields because they are also used by the early-stopping monitor and parameter-group logic. All other optimizer-specific arguments go in optimizer_kwargs.

Weight-decay exemption for bias and normalisation layers

Setting no_weight_decay_for_bias_and_norm=True splits model parameters into two groups: one with weight_decay as configured and one (biases and normalisation weights) with weight_decay=0. This is the recommended practice for transformer-style architectures.

trainer_wd = TrainerConfig(
    max_epochs=5,
    batch_size=128,
    lr=3e-4,
    patience=2,
    optimizer_type="AdamW", # Case-insensitive, should work the same as "adamw"
    weight_decay=1e-2,
    no_weight_decay_for_bias_and_norm=True,   # enable the weight-decay split
)

clf_wd = MambularClassifier(
    model_config=MambularConfig(d_model=64, n_layers=3),
    preprocessing_config=PreprocessingConfig(numerical_preprocessing="quantile"),
    trainer_config=trainer_wd,
    random_state=RANDOM_STATE,
)
clf_wd.fit(X_train, y_train, X_val=X_val, y_val=y_val)

Using SGD with momentum

SGD with momentum takes more tuning than Adam, but paired with a good learning-rate schedule it can settle into flatter minima that generalise well. Nesterov momentum usually adds a small further improvement at no extra cost.

clf_sgd = MambularClassifier(
    model_config=MambularConfig(d_model=64, n_layers=3),
    preprocessing_config=PreprocessingConfig(numerical_preprocessing="quantile"),
    trainer_config=TrainerConfig(
        max_epochs=5,
        batch_size=128,
        lr=5e-3,
        patience=2,
        optimizer_type="SGD",
        optimizer_kwargs={"momentum": 0.9, "nesterov": True},
        weight_decay=1e-4,
    ),
    random_state=RANDOM_STATE,
)
clf_sgd.fit(X_train, y_train, X_val=X_val, y_val=y_val)

Tip

Unsure which optimizer to pick? Start with AdamW at the default learning rate. It converges quickly and is forgiving of hyperparameter choices. Reach for SGD with momentum only when you have the budget to tune the learning-rate schedule carefully.


Part 2: Schedulers

A scheduler adjusts the learning rate as training progresses, and a good schedule often matters as much as the optimizer itself. A higher rate early on lets the model make rapid progress, while a lower rate later helps it settle into a good solution instead of bouncing around it.

Discovering available schedulers

scheds = available_schedulers()
print(scheds)
# ['constantlr', 'cosineannealinglr', 'cosineannealingwarmrestarts', 'cycliclr',
#  'exponentiallr', 'linearlr', 'multisteplr', 'onecyclelr', 'reducelronplateau',
#  'sequentiallr', 'steplr']

CosineAnnealingLR

Cosine annealing lowers the learning rate from its starting value toward eta_min along a cosine curve spread over T_max epochs. It needs very little tuning and is a strong default when you train for a fixed number of epochs.

clf_cos = MambularClassifier(
    model_config=MambularConfig(d_model=64, n_layers=3),
    preprocessing_config=PreprocessingConfig(numerical_preprocessing="quantile"),
    trainer_config=TrainerConfig(
        max_epochs=5,
        batch_size=128,
        lr=3e-4,
        patience=2,
        optimizer_type="AdamW",
        weight_decay=1e-2,
        scheduler_type="CosineAnnealingLR",
        scheduler_kwargs={"T_max": 5, "eta_min": 1e-6},
        scheduler_interval="epoch",
    ),
    random_state=RANDOM_STATE,
)
clf_cos.fit(X_train, y_train, X_val=X_val, y_val=y_val)

ReduceLROnPlateau (default scheduler)

ReduceLROnPlateau is the default scheduler. It watches a metric and reduces the learning rate when that metric stops improving. The TrainerConfig.mode field tells it which direction counts as improvement: mode="min" (the default) for losses, mode="max" for metrics where higher is better such as accuracy or AUROC.

clf_plateau = MambularClassifier(
    model_config=MambularConfig(d_model=64, n_layers=3),
    preprocessing_config=PreprocessingConfig(numerical_preprocessing="quantile"),
    trainer_config=TrainerConfig(
        max_epochs=5,
        batch_size=128,
        lr=3e-4,
        patience=2,
        optimizer_type="AdamW",
        weight_decay=1e-2,
        scheduler_type="ReduceLROnPlateau",
        scheduler_monitor="val_loss",    # metric the scheduler watches
        scheduler_kwargs={
            "factor": 0.5,
            "patience": 5,
            "min_lr": 1e-6,
        },
    ),
    random_state=RANDOM_STATE,
)
clf_plateau.fit(X_train, y_train, X_val=X_val, y_val=y_val)

Important

scheduler_monitor defaults to None. When it is None, DeepTab falls back to TrainerConfig.monitor (which is "val_loss" by default). The reduction direction is not inferred from the monitor name: it is taken from TrainerConfig.mode. If you monitor a higher-is-better metric such as accuracy or AUROC, set mode="max" on the TrainerConfig so the scheduler reduces the learning rate at the right moment.

Disabling the scheduler

Set scheduler_type=None to use a constant learning rate:

clf_const_lr = MambularClassifier(
    model_config=MambularConfig(d_model=64, n_layers=3),
    preprocessing_config=PreprocessingConfig(numerical_preprocessing="quantile"),
    trainer_config=TrainerConfig(
        max_epochs=5,
        batch_size=128,
        lr=3e-4,
        patience=2,
        scheduler_type=None,
    ),
    random_state=RANDOM_STATE,
)
clf_const_lr.fit(X_train, y_train, X_val=X_val, y_val=y_val)

Step-level scheduler (OneCycleLR)

Some schedulers need to step every batch, not every epoch. Set scheduler_interval="step":

steps_per_epoch = int(np.ceil(len(X_train) / 128))

clf_onecycle = MambularClassifier(
    model_config=MambularConfig(d_model=64, n_layers=3),
    preprocessing_config=PreprocessingConfig(numerical_preprocessing="quantile"),
    trainer_config=TrainerConfig(
        max_epochs=5,
        batch_size=128,
        lr=1e-3,
        patience=2,
        optimizer_type="AdamW",
        weight_decay=1e-2,
        scheduler_type="OneCycleLR",
        scheduler_kwargs={
            "max_lr": 1e-3,
            "total_steps": 40 * steps_per_epoch,
            "anneal_strategy": "cos",
        },
        scheduler_interval="step",
    ),
    random_state=RANDOM_STATE,
)

Note

Some schedulers such as OneCycleLR define their own learning-rate curve and work best with scheduler_interval="step". Pass every required scheduler argument (for example total_steps) through scheduler_kwargs.

Warning

OneCycleLR raises an error if training runs for more steps than total_steps. Set total_steps to at least max_epochs * steps_per_epoch, or pass epochs and steps_per_epoch instead, so the schedule covers the whole run.


Part 3: Custom Optimizer and Scheduler Registration

Sometimes the built-in choices are not enough, whether you are reproducing a paper or experimenting with an idea of your own. The registry pattern lets you plug in any optimizer or scheduler that follows the standard torch.optim.Optimizer or torch.optim.lr_scheduler.LRScheduler interface. Once registered, it works through the same TrainerConfig fields as the built-in classes.

How the registry works

DeepTab keeps a process-global mapping of name -> class for optimizers and another for schedulers. When you pass optimizer_type="adamw" to TrainerConfig, DeepTab simply looks that name up in the registry. Three functions act on each registry:

  • register_optimizer(name, cls) / register_scheduler(name, cls): add a new entry.

  • available_optimizers() / available_schedulers(): list what is registered.

  • unregister_optimizer(name) / unregister_scheduler(name): remove an entry you added.

Two rules keep this safe to use:

  • Names are unique. Registering a name that already exists raises a ValueError:

    ValueError: Optimizer 'scaledadam' is already registered. Pass override=True to replace it.
    

    Pass override=True to intentionally replace the entry. This is what you want when you iterate on an implementation and re-run a cell, or when you swap a built-in for your own variant.

  • Built-ins are protected. You can override a built-in like adam, but you cannot unregister it; removing it would break every estimator in the process. Only names you registered yourself can be removed.

Registering a custom optimizer

override=True makes registration idempotent, so re-running the snippet does not raise the “already registered” error above.

class ScaledAdam(torch.optim.Adam):
    """Adam with gradient pre-scaling (toy example)."""

    def __init__(self, params, lr=1e-3, scale=1.0, **kwargs):
        super().__init__(params, lr=lr * scale, **kwargs)


register_optimizer("scaledadam", ScaledAdam, override=True)

# Verify registration (names are stored lowercase)
print("scaledadam" in available_optimizers())   # True

# Use it via TrainerConfig
clf_custom_opt = MambularClassifier(
    model_config=MambularConfig(d_model=64, n_layers=3),
    preprocessing_config=PreprocessingConfig(numerical_preprocessing="quantile"),
    trainer_config=TrainerConfig(
        max_epochs=5,
        batch_size=128,
        lr=3e-4,
        patience=2,
        optimizer_type="scaledadam",
        optimizer_kwargs={"scale": 0.8},
    ),
    random_state=RANDOM_STATE,
)
clf_custom_opt.fit(X_train, y_train, X_val=X_val, y_val=y_val)

Registering a custom scheduler

Schedulers follow exactly the same rules: override=True for idempotent re-registration, and the same protection for built-ins.

class WarmupConstant(torch.optim.lr_scheduler.LambdaLR):
    """Linear warmup for `warmup_steps`, then constant LR."""

    def __init__(self, optimizer, warmup_steps: int = 100):
        def _lambda(step: int) -> float:
            if step < warmup_steps:
                return float(step) / max(1, warmup_steps)
            return 1.0

        super().__init__(optimizer, lr_lambda=_lambda)


register_scheduler("warmupconstant", WarmupConstant, override=True)

print("warmupconstant" in available_schedulers())   # True

clf_warmup = MambularClassifier(
    model_config=MambularConfig(d_model=64, n_layers=3),
    preprocessing_config=PreprocessingConfig(numerical_preprocessing="quantile"),
    trainer_config=TrainerConfig(
        max_epochs=5,
        batch_size=128,
        lr=3e-4,
        patience=2,
        scheduler_type="warmupconstant",
        scheduler_kwargs={"warmup_steps": 200},
        scheduler_interval="step",
    ),
    random_state=RANDOM_STATE,
)
clf_warmup.fit(X_train, y_train, X_val=X_val, y_val=y_val)

Cleaning up: unregistering your entries

If you no longer need a custom optimizer or scheduler (for example to free up a name or reset state between experiments), remove it with unregister_optimizer / unregister_scheduler. Use missing_ok=True for idempotent teardown that will not raise if the entry is already gone. Built-in DeepTab names are protected and cannot be removed.

# Remove the custom entries we added above.
unregister_optimizer("scaledadam")
unregister_scheduler("warmupconstant")
print("scaledadam" in available_optimizers())   # False

# Safe to call again: missing_ok avoids an error if it is already gone.
unregister_optimizer("scaledadam", missing_ok=True)

# Built-ins are protected: this raises, by design.
try:
    unregister_optimizer("adam")
except ValueError as err:
    print("Refused to remove built-in:", err)

Part 4: Production Inference with InferenceModel

InferenceModel wraps a fitted estimator and exposes only the prediction surface. Training methods (fit, optimize_hparams, etc.) are absent, which prevents accidental retraining in service code.

Save a model to disk

clf_wd.save("advanced_clf.deeptab")

Load via from_path

model = InferenceModel.from_path("advanced_clf.deeptab")
print(model)
# InferenceModel(task='classification', estimator='MambularClassifier',
#                n_features=12, features=['feat_0', 'feat_1', 'feat_2', ...], n_classes=2)

Wrap an already-fitted estimator

If the estimator is already in memory, skip the save/load round-trip:

model_live = InferenceModel.from_estimator(clf_wd)
print(model_live.task)          # classification
print(model_live.n_features)    # 12

Introspection

info = model.describe()
print(list(info))
# ['estimator', 'architecture', 'task', 'built', 'fitted', 'model_config',
#  'preprocessing_config', 'trainer_config', 'feature_counts', 'num_classes',
#  'family', 'returns_ensemble', 'parameters', 'inference_task']

rt = model.runtime_info()
print(list(rt))
# ['built', 'fitted', 'device', 'dtype', 'precision', 'accelerator', 'strategy',
#  'num_devices', 'root_device', 'max_epochs', 'current_epoch', 'global_step',
#  'batch_size', 'optimizer_type', 'lr', 'weight_decay', 'logger', 'deterministic']

params_df = model.parameter_table()
print(params_df.head())

Schema validation

validate_input checks that the incoming DataFrame matches the feature schema seen during training. Call it before every forward pass in production.

# Happy path
X_clean = model.validate_input(X_test)

# Missing column
X_bad = X_test.drop(columns=["feat_0"])
try:
    model.validate_input(X_bad)
except ValueError as exc:
    print(exc)
# ValueError: Input is missing 1 column(s) that were present during training:
# ['feat_0'].

# Extra columns are dropped with a warning in lenient mode
X_wide = X_test.copy()
X_wide["audit_id"] = range(len(X_test))
X_clean = model.validate_input(X_wide, allow_extra_columns=True)
# UserWarning: Input has 1 column(s) not seen during training (['audit_id']);
# they will be dropped.

Prediction

# Hard class labels
labels = model.predict(X_clean)
print("Accuracy:", accuracy_score(y_test, labels))

# Class probabilities (classification only)
proba = model.predict_proba(X_clean)
print("AUROC:", roc_auc_score(y_test, proba[:, 1]))

predict_proba raises TypeError for non-classification tasks:

# model.predict_proba(X_clean)
# TypeError: predict_proba() is only available for classification models,
# but this model's task is 'regression'.

Production service pattern

A minimal FastAPI-style handler using InferenceModel:

# Module-level: load once at startup
_MODEL = InferenceModel.from_path("advanced_clf.deeptab")


def score(payload: dict) -> dict:
    X = pd.DataFrame([payload])
    X_clean = _MODEL.validate_input(X, allow_extra_columns=True)
    proba   = _MODEL.predict_proba(X_clean)
    label   = _MODEL.predict(X_clean)
    return {
        "probability_positive": float(proba[0, 1]),
        "label": int(label[0]),
    }

Configuration Reference

TrainerConfig field

Default

Effect

optimizer_type

"Adam"

Optimizer class name from the registry

optimizer_kwargs

None

Extra constructor kwargs (beyond lr, weight_decay)

weight_decay

1e-6

Passed to optimizer; exempt layers use 0.0

no_weight_decay_for_bias_and_norm

False

Split params into WD/no-WD groups

scheduler_type

"ReduceLROnPlateau"

Scheduler class name, or None

scheduler_kwargs

None

Scheduler constructor kwargs

scheduler_monitor

None

Metric watched by plateau schedulers; falls back to monitor

scheduler_interval

"epoch"

"epoch" or "step"

scheduler_frequency

1

Step frequency multiplier

Next Steps