Source code for deeptab.data.dataset
from torch.utils.data import Dataset
from deeptab.data.schema import TabularBatch
[docs]
class TabularDataset(Dataset):
    """Map-style dataset for structured tabular data.

    Categorical features, numerical features and optional embeddings are each
    held as a list of per-feature tensors; a sample is obtained by indexing
    every tensor in those lists with the same row index.

    This dataset is task-agnostic and simply stores and retrieves features and
    labels without any task-specific preprocessing. Label dtype conversion
    should be handled externally by the DataModule or training logic.

    Parameters
    ----------
    cat_features_list : list of Tensors or None
        A list of tensors representing the categorical features. May be empty
        or None when only numerical features are available.
    num_features_list : list of Tensors or None
        A list of tensors representing the numerical features. May be empty
        or None when only categorical features are available.
    embeddings_list : list of Tensors, optional
        A list of tensors representing the embeddings.
    labels : Tensor, optional
        A tensor of labels. If None, the dataset is used for prediction.
    return_batch_object : bool, default=False
        If True, returns a TabularBatch object instead of a tuple. For backward
        compatibility, defaults to False.

    Raises
    ------
    AssertionError
        If both ``cat_features_list`` and ``num_features_list`` are empty or
        None.
    """

    def __init__(
        self,
        cat_features_list,
        num_features_list,
        embeddings_list=None,
        labels=None,
        return_batch_object=False,
    ):
        # At least one feature group is required, otherwise the dataset length
        # cannot be derived.
        # NOTE(review): `assert` is skipped under `python -O`; it is kept here
        # so that the exception type seen by existing callers does not change.
        assert cat_features_list or num_features_list, (  # noqa: S101
            "TabularDataset needs at least one categorical or numerical feature tensor"
        )

        self.cat_features_list = cat_features_list  # Categorical feature tensors
        self.num_features_list = num_features_list  # Numerical feature tensors
        self.embeddings_list = embeddings_list  # Embedding tensors (optional)
        self.labels = labels  # Labels (optional, None in prediction mode)
        self.return_batch_object = return_batch_object

    def __len__(self):
        """Return the number of samples, read from the first available feature tensor.

        Numerical features are preferred; categorical features are used when
        the numerical list is empty or None.
        """
        reference_features = self.num_features_list or self.cat_features_list
        return len(reference_features[0])

    def __getitem__(self, idx):
        """Retrieve the features and label for a given index.

        Parameters
        ----------
        idx : int
            The index of the data point.

        Returns
        -------
        tuple or TabularBatch
            If return_batch_object is False (default), returns the legacy tuple
            format: ``((num_features, cat_features, embeddings), label)`` when
            labels are available, otherwise
            ``(num_features, cat_features, embeddings)``. ``num_features`` and
            ``cat_features`` are lists with one entry per feature tensor and
            ``embeddings`` is such a list or None when no embeddings were given.
            If return_batch_object is True, returns a TabularBatch object.
        """
        # The constructor accepts None for one of the two feature groups, so a
        # missing group must yield an empty list here instead of failing when
        # iterated.
        cat_features = [column[idx] for column in (self.cat_features_list or [])]
        num_features = [column[idx] for column in (self.num_features_list or [])]

        embeddings = None
        if self.embeddings_list is not None:
            embeddings = [column[idx] for column in self.embeddings_list]

        label = None if self.labels is None else self.labels[idx]

        if self.return_batch_object:
            return TabularBatch(
                numerical_features=num_features,
                categorical_features=cat_features,
                embeddings=embeddings,
                labels=label,
            )

        # Legacy tuple format: the label is only attached when one exists.
        features = (num_features, cat_features, embeddings)
        if label is None:
            return features
        return features, label