TabTransformer
Overview
TabTransformer uses self-attention to contextualize categorical feature embeddings. DeepTab's implementation follows that core idea: categorical and external embedding features pass through a Transformer encoder, while numerical features bypass the encoder. Instead, they are layer-normalized and concatenated with the pooled categorical representation just before the prediction head.
Use it when categorical interactions are central to the task. If the dataset has no categorical features, use FTTransformer, MLP, ResNet, or TabM instead.
Architectural Details
DeepTab’s TabTransformer pipeline is:
1. Validate that categorical feature information is present.
2. Embed categorical and external embedding features with `EmbeddingLayer`.
3. Apply a Transformer encoder to the categorical token sequence.
4. Pool the contextualized categorical tokens.
5. Concatenate the pooled categorical representation with layer-normalized numerical features.
6. Predict with `MLPhead`.
```text
categorical tokens -> TransformerEncoder -> pooling
numerical features -> LayerNorm
[pooled categorical, normalized numerical] -> MLPhead
```
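The pipeline above can be sketched end to end in NumPy. This is a minimal illustration, not DeepTab's code. It assumes one embedding table per categorical column and replaces the full Transformer encoder with a single-head attention block (no feed-forward sublayer). It also uses average pooling, omits external embedding features, and uses one linear layer as a stand-in for `MLPhead`. All function and parameter names here are hypothetical.

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    # Normalize over the last axis (learned scale/shift omitted in this sketch).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def self_attention(tokens, w_q, w_k, w_v):
    # Single-head scaled dot-product attention across the categorical tokens.
    q, k, v = tokens @ w_q, tokens @ w_k, tokens @ w_v
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(q.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v


def tab_transformer_forward(x_cat, x_num, params):
    """x_cat: (batch, n_cat) integer codes; x_num: (batch, n_num) floats."""
    # Step 1: validate that categorical features are present.
    if x_cat.shape[1] == 0:
        raise ValueError("TabTransformer needs at least one categorical feature")
    # Step 2: one embedding table per column -> (batch, n_cat, d_model).
    tokens = np.stack(
        [params["emb"][j][x_cat[:, j]] for j in range(x_cat.shape[1])], axis=1
    )
    # Step 3: simplified encoder block = attention + residual + LayerNorm.
    attended = self_attention(tokens, params["w_q"], params["w_k"], params["w_v"])
    tokens = layer_norm(tokens + attended)
    # Step 4: average-pool over the token axis -> (batch, d_model).
    pooled = tokens.mean(axis=1)
    # Step 5: LayerNorm the raw numerical vector, then concatenate.
    features = np.concatenate([pooled, layer_norm(x_num)], axis=1)
    # Step 6: a linear map stands in for MLPhead.
    return features @ params["w_head"]


rng = np.random.default_rng(0)
d_model, n_num, n_classes = 8, 3, 2
cardinalities = [4, 5]  # number of categories per categorical column
params = {
    "emb": [rng.normal(size=(c, d_model)) for c in cardinalities],
    "w_q": rng.normal(size=(d_model, d_model)),
    "w_k": rng.normal(size=(d_model, d_model)),
    "w_v": rng.normal(size=(d_model, d_model)),
    "w_head": rng.normal(size=(d_model + n_num, n_classes)),
}
x_cat = np.array([[0, 4], [3, 1]])
x_num = rng.normal(size=(2, n_num))
logits = tab_transformer_forward(x_cat, x_num, params)
print(logits.shape)  # (2, 2)
```

The head's input width is `d_model + n_num`. Only the categorical part is contextualized, which is the main difference from FTTransformer, where numerical features are tokenized as well.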
Main Building Blocks
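| Component | DeepTab implementation | Role |
|---|---|---|
| Categorical tokenizer | `EmbeddingLayer` | Embeds categorical (and external embedding) features; numerical columns bypass it. |
| Transformer | `TransformerEncoder` | Contextualizes categorical tokens. |
| Numerical path | `LayerNorm` | Normalizes the raw numerical vector. |
| Pooling | `pooling_method` config option | Reduces the categorical tokens to a single vector. |
| Head | `MLPhead` | Predicts from the concatenated categorical and numerical representations. |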
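This page's config example uses only one pooling value, `"avg"`. The sketch below shows what average pooling does to a `(batch, n_tokens, d_model)` tensor. The `"max"` and `"sum"` branches are illustrative alternatives, and DeepTab is not confirmed to accept them as values. `pool_tokens` is a hypothetical name.

```python
import numpy as np


def pool_tokens(tokens, method="avg"):
    """Reduce (batch, n_tokens, d_model) to (batch, d_model).

    Only "avg" comes from the documented config; "max" and "sum" are
    hypothetical alternatives shown for comparison.
    """
    if method == "avg":
        return tokens.mean(axis=1)
    if method == "max":
        return tokens.max(axis=1)
    if method == "sum":
        return tokens.sum(axis=1)
    raise ValueError(f"unknown pooling method: {method!r}")


# One sample, three tokens, d_model = 4.
tokens = np.arange(12, dtype=float).reshape(1, 3, 4)
print(pool_tokens(tokens, "avg"))  # [[4. 5. 6. 7.]]
```

Average pooling keeps the pooled vector on the same scale regardless of how many categorical columns the dataset has. A sum would grow with the column count.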
Implementation Notes
DeepTab raises a ValueError if no categorical features are available. This is intentional: the Transformer body is applied only to categorical tokens, so without them there is no token sequence for the encoder to contextualize.
The default config uses d_model=128, n_layers=4, n_heads=8, transformer_activation=ReGLU(), and transformer_dim_feedforward=512.
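ReGLU is a gated linear unit with a ReLU gate. The input is split into two halves, and one half is multiplied elementwise by the ReLU of the other. The following minimal NumPy sketch of a ReGLU feed-forward sublayer follows a common convention in FT-Transformer-style code: the first projection has width `2 * transformer_dim_feedforward`, so the gated output has width `transformer_dim_feedforward`. Whether DeepTab sizes its projection this way is an assumption, and the function names are hypothetical.

```python
import numpy as np


def reglu(x):
    # Split the last axis in half: first half is the value, second the gate.
    a, b = np.split(x, 2, axis=-1)
    return a * np.maximum(b, 0.0)


def feed_forward(tokens, w_in, w_out):
    # w_in doubles the width so ReGLU halves it back to dim_ff.
    return reglu(tokens @ w_in) @ w_out


d_model, dim_ff = 128, 512  # defaults quoted above
rng = np.random.default_rng(0)
w_in = rng.normal(size=(d_model, 2 * dim_ff)) / np.sqrt(d_model)
w_out = rng.normal(size=(dim_ff, d_model)) / np.sqrt(dim_ff)
tokens = rng.normal(size=(2, 5, d_model))  # (batch, n_cat, d_model)
out = feed_forward(tokens, w_in, w_out)
print(out.shape)  # (2, 5, 128)
```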
Practical Config
```python
from deeptab.configs import PreprocessingConfig, TabTransformerConfig, TrainerConfig
from deeptab.models import TabTransformerClassifier

model = TabTransformerClassifier(
    # Architecture: token width, encoder depth, attention heads, dropout, pooling.
    model_config=TabTransformerConfig(
        d_model=128,
        n_layers=4,
        n_heads=8,
        attn_dropout=0.2,
        ff_dropout=0.1,
        pooling_method="avg",
    ),
    # Standardize numerical columns; encode categorical columns as integer codes.
    preprocessing_config=PreprocessingConfig(
        numerical_preprocessing="standard",
        categorical_preprocessing="int",
    ),
    # Optimization settings.
    trainer_config=TrainerConfig(lr=3e-4, batch_size=128, max_epochs=100),
    random_state=101,
)
```
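The two preprocessing choices above can be illustrated in plain NumPy. This sketch assumes that `"standard"` means column-wise z-score scaling fitted on the training split and that `"int"` means integer codes that index the embedding tables. The helper names and the handling of unseen categories (a reserved index `K`) are hypothetical, and DeepTab's own preprocessing may differ.

```python
import numpy as np


def fit_standard(x):
    # Column-wise mean/std from the training split ("standard" scaling).
    mean, std = x.mean(axis=0), x.std(axis=0)
    return mean, np.where(std == 0, 1.0, std)  # guard constant columns


def fit_int_codes(column):
    # Sorted unique categories -> 0..K-1 ("int" encoding).
    return {cat: i for i, cat in enumerate(sorted(set(column)))}


def encode(column, mapping):
    # Assumption: unseen categories map to a reserved index K.
    unknown = len(mapping)
    return np.array([mapping.get(c, unknown) for c in column])


x_num_train = np.array([[1.0, 10.0], [3.0, 30.0]])
mean, std = fit_standard(x_num_train)
scaled = (x_num_train - mean) / std  # each column now has mean 0, std 1

colors = ["red", "blue", "red", "green"]
mapping = fit_int_codes(colors)
print(encode(colors, mapping))     # [2 0 2 1]
print(encode(["purple"], mapping))  # [3]
```

Integer codes matter here because each categorical token is an embedding lookup, so every category needs a stable index into its column's table.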
Key settings:
| Setting | Typical range | Effect |
|---|---|---|
| `d_model` | | Categorical token width. |
| `n_layers` | | Contextualization depth. |
| `n_heads` | | Attention heads. |
| `pooling_method` | | How categorical tokens are reduced. |
| | | Extra capacity after concatenation. |
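`n_heads` and `d_model` are coupled. Standard multi-head attention, including PyTorch's `nn.MultiheadAttention`, splits `d_model` evenly across heads, so `d_model` must be divisible by `n_heads`. With the defaults, each head works in 128 / 8 = 16 dimensions. The small helper below screens a hyperparameter grid. The helper name and the candidate values are hypothetical, not recommended ranges, and DeepTab may perform its own check.

```python
def head_width(d_model, n_heads):
    # Multi-head attention splits d_model evenly across heads.
    if d_model % n_heads != 0:
        raise ValueError(f"d_model={d_model} is not divisible by n_heads={n_heads}")
    return d_model // n_heads


print(head_width(128, 8))  # 16

# Drop invalid combinations before launching a sweep (illustrative values).
candidates = [(d, h) for d in (64, 128, 256) for h in (4, 6, 8)]
valid = [(d, h) for d, h in candidates if d % h == 0]
print(valid)
```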
When To Use
Use TabTransformer for categorical-heavy datasets where context-dependent categorical embeddings are likely to matter. Prefer FTTransformer for numerical-heavy datasets.
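The guidance in this page can be written as a rough rule of thumb. This sketch treats a dataset as "categorical-heavy" when categorical columns are at least as numerous as numerical ones. That threshold is an illustrative assumption, not a DeepTab rule, and the function name is hypothetical.

```python
def suggest_model(n_categorical, n_numerical):
    """Illustrative heuristic; the threshold is an assumption."""
    if n_categorical == 0:
        # TabTransformer would raise ValueError; MLP, ResNet, or TabM also fit.
        return "FTTransformer"
    if n_categorical >= n_numerical:
        return "TabTransformer"
    return "FTTransformer"


print(suggest_model(6, 2))  # TabTransformer
```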
References
Huang et al., TabTransformer: Tabular Data Modeling Using Contextual Embeddings.
Vaswani et al., Attention Is All You Need.