"""Schema definitions for tabular data structures.
Provides typed containers and metadata for tabular datasets.
New in v2.0.0.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
import torch
[docs]
@dataclass
class FeatureInfo:
    """Metadata describing one feature of a tabular dataset.

    Parameters
    ----------
    name : str
        Feature name or identifier.
    preprocessing : str
        Preprocessing strategy applied to this feature.
    dimension : int
        Output dimension after preprocessing (e.g., embedding size).
    categories : list or None
        Category values for a categorical feature, ``None`` for a
        numerical one.
    """

    name: str
    preprocessing: str
    dimension: int
    categories: list[Any] | None = None

    @property
    def is_categorical(self) -> bool:
        """Whether a category list is attached to this feature."""
        return self.categories is not None

    def to_dict(self) -> dict[str, Any]:
        """Build a plain-dict view of the feature metadata.

        Returns
        -------
        dict[str, Any]
            Mapping with the keys ``name``, ``preprocessing``,
            ``dimension`` and ``categories``.
        """
        categories = self.categories
        if hasattr(categories, "tolist"):
            # Anything exposing ``tolist`` is converted through it; every
            # other value (including None) is passed through unchanged.
            categories = categories.tolist()  # type: ignore[union-attr]
        return dict(
            name=self.name,
            preprocessing=self.preprocessing,
            dimension=self.dimension,
            categories=categories,
        )

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> FeatureInfo:
        """Rebuild a ``FeatureInfo`` from serialized metadata.

        Parameters
        ----------
        data : dict[str, Any]
            Mapping that must contain ``name``. Missing ``preprocessing``
            falls back to ``"unknown"``, missing ``dimension`` to ``1`` and
            missing ``categories`` to ``None``.

        Returns
        -------
        FeatureInfo
            The reconstructed feature metadata.

        Raises
        ------
        KeyError
            If ``data`` has no ``name`` entry.
        """
        kwargs = {
            "name": data["name"],
            "preprocessing": data.get("preprocessing", "unknown"),
            "dimension": data.get("dimension", 1),
            "categories": data.get("categories"),
        }
        return cls(**kwargs)
[docs]
@dataclass
class FeatureSchema:
    """Schema describing the structure of tabular input features.

    Tracks numerical, categorical, and embedding features together with
    their preprocessing metadata and dimensions.

    Parameters
    ----------
    numerical_features : dict[str, FeatureInfo]
        Dictionary mapping numerical feature names to their metadata.
    categorical_features : dict[str, FeatureInfo]
        Dictionary mapping categorical feature names to their metadata.
    embedding_features : dict[str, FeatureInfo] | None
        Dictionary mapping embedding feature names to their metadata, or
        ``None`` when there are no embedding features.
    """

    numerical_features: dict[str, FeatureInfo]
    categorical_features: dict[str, FeatureInfo]
    embedding_features: dict[str, FeatureInfo] | None = None

    @property
    def num_numerical_features(self) -> int:
        """Total number of numerical features."""
        return len(self.numerical_features)

    @property
    def num_categorical_features(self) -> int:
        """Total number of categorical features."""
        return len(self.categorical_features)

    @property
    def num_embedding_features(self) -> int:
        """Total number of embedding features (0 when none are set)."""
        return len(self.embedding_features) if self.embedding_features else 0

    @property
    def total_numerical_dim(self) -> int:
        """Sum of ``dimension`` over all numerical features."""
        return sum(f.dimension for f in self.numerical_features.values())

    @property
    def total_categorical_dim(self) -> int:
        """Sum of ``dimension`` over all categorical features."""
        return sum(f.dimension for f in self.categorical_features.values())

    @property
    def total_embedding_dim(self) -> int:
        """Sum of ``dimension`` over all embedding features (0 when none are set)."""
        if not self.embedding_features:
            return 0
        return sum(f.dimension for f in self.embedding_features.values())

    def to_dict(self) -> dict[str, Any]:
        """Return a plain-dict representation of the feature schema.

        Returns
        -------
        dict[str, Any]
            Mapping with the three feature groups (each feature converted
            via its own ``to_dict``) plus a ``dimensions`` entry holding
            the derived counts and totals. A missing or empty embedding
            group is emitted as ``None``.
        """
        return {
            "numerical_features": {name: info.to_dict() for name, info in self.numerical_features.items()},
            "categorical_features": {name: info.to_dict() for name, info in self.categorical_features.items()},
            "embedding_features": (
                {name: info.to_dict() for name, info in self.embedding_features.items()}
                if self.embedding_features
                else None
            ),
            "dimensions": {
                "num_numerical_features": self.num_numerical_features,
                "num_categorical_features": self.num_categorical_features,
                "num_embedding_features": self.num_embedding_features,
                "total_numerical_dim": self.total_numerical_dim,
                "total_categorical_dim": self.total_categorical_dim,
                "total_embedding_dim": self.total_embedding_dim,
            },
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> FeatureSchema:
        """Create a FeatureSchema object from serialized metadata.

        Parameters
        ----------
        data : dict[str, Any]
            Mapping with optional ``numerical_features``,
            ``categorical_features`` and ``embedding_features`` entries.
            Any of them may be absent or ``None``. The ``dimensions``
            entry written by ``to_dict`` is not read.

        Returns
        -------
        FeatureSchema
            Reconstructed feature schema.
        """
        # ``or {}`` also covers a key that is present but explicitly None,
        # which ``data.get(key, {})`` would hand straight to ``.items()``.
        numerical = data.get("numerical_features") or {}
        categorical = data.get("categorical_features") or {}
        embedding = data.get("embedding_features")
        return cls(
            numerical_features={name: FeatureInfo.from_dict(info) for name, info in numerical.items()},
            categorical_features={name: FeatureInfo.from_dict(info) for name, info in categorical.items()},
            embedding_features=(
                {name: FeatureInfo.from_dict(info) for name, info in embedding.items()} if embedding else None
            ),
        )

    @staticmethod
    def _features_from_info(feature_info: dict | None, *, keep_categories: bool) -> dict[str, FeatureInfo]:
        """Convert one preprocessor info mapping into ``FeatureInfo`` entries."""
        if not feature_info:
            return {}
        return {
            # Feature names are normalised to str for both the key and the
            # stored name.
            str(name): FeatureInfo(
                name=str(name),
                preprocessing=info.get("preprocessing", "unknown"),
                dimension=info.get("dimension", 1),
                categories=info.get("categories") if keep_categories else None,
            )
            for name, info in feature_info.items()
        }

    @classmethod
    def from_preprocessor_info(
        cls,
        num_feature_info: dict | None,
        cat_feature_info: dict | None,
        embedding_feature_info: dict | None = None,
    ) -> FeatureSchema:
        """Create a FeatureSchema from preprocessor feature info dictionaries.

        Parameters
        ----------
        num_feature_info : dict or None
            Numerical feature information from preprocessor.
        cat_feature_info : dict or None
            Categorical feature information from preprocessor.
        embedding_feature_info : dict or None
            Embedding feature information from preprocessor.

        Returns
        -------
        FeatureSchema
            Constructed feature schema. Only categorical features keep a
            ``categories`` entry. A missing or empty embedding mapping
            results in ``embedding_features`` being ``None``.
        """
        return cls(
            numerical_features=cls._features_from_info(num_feature_info, keep_categories=False),
            categorical_features=cls._features_from_info(cat_feature_info, keep_categories=True),
            # An empty result collapses to None, matching the field default.
            embedding_features=cls._features_from_info(embedding_feature_info, keep_categories=False) or None,
        )
[docs]
@dataclass
class TabularBatch:
    """Typed container for a batch of tabular data.

    Provides a structured interface for accessing different feature types
    and labels in a batch, replacing raw tuples.

    Parameters
    ----------
    numerical_features : list[torch.Tensor]
        List of tensors for numerical features.
    categorical_features : list[torch.Tensor]
        List of tensors for categorical features.
    embeddings : list[torch.Tensor] | None
        List of tensors for precomputed embeddings, if any.
    labels : torch.Tensor | None
        Labels for supervised learning, None for prediction mode.
    """

    numerical_features: list[torch.Tensor]
    categorical_features: list[torch.Tensor]
    embeddings: list[torch.Tensor] | None = None
    labels: torch.Tensor | None = None

    def to(self, device: torch.device | str) -> TabularBatch:
        """Move all tensors in the batch to the specified device.

        Parameters
        ----------
        device : torch.device or str
            Target device (e.g., 'cuda', 'cpu', 'mps').

        Returns
        -------
        TabularBatch
            A new batch whose tensors are the result of ``Tensor.to(device)``.
            ``None`` fields stay ``None`` and an empty ``embeddings`` list
            stays an empty list.
        """
        return TabularBatch(
            numerical_features=[t.to(device) for t in self.numerical_features],
            categorical_features=[t.to(device) for t in self.categorical_features],
            # Compare against None (as for labels) so that an empty list is
            # preserved instead of being collapsed to None.
            embeddings=[t.to(device) for t in self.embeddings] if self.embeddings is not None else None,
            labels=self.labels.to(device) if self.labels is not None else None,
        )

    @classmethod
    def from_tuple(cls, batch_tuple: tuple) -> TabularBatch:
        """Create a TabularBatch from the legacy tuple format.

        Parameters
        ----------
        batch_tuple : tuple
            Either ((num_feats, cat_feats, embeddings), labels) or
            (num_feats, cat_feats, embeddings).

        Returns
        -------
        TabularBatch
            Typed batch container.

        Raises
        ------
        ValueError
            If the feature group does not unpack into exactly three items.
        """
        if len(batch_tuple) == 2:
            # Supervised mode: (features, labels)
            features, labels = batch_tuple
        else:
            # Prediction mode: the tuple itself is the feature group
            features, labels = batch_tuple, None
        num_feats, cat_feats, embeddings = features
        return cls(
            numerical_features=num_feats,
            categorical_features=cat_feats,
            embeddings=embeddings,
            labels=labels,
        )

    def to_tuple(self) -> tuple:
        """Convert back to legacy tuple format for backward compatibility.

        Returns
        -------
        tuple
            ((num_feats, cat_feats, embeddings), labels) when labels are
            set, otherwise (num_feats, cat_feats, embeddings).
        """
        features = (self.numerical_features, self.categorical_features, self.embeddings)
        if self.labels is not None:
            return (features, self.labels)
        return features