Source code for deeptab.configs.models.tabularnn_config
from collections.abc import Callable
from dataclasses import dataclass, field
import torch.nn as nn
from ..core import BaseModelConfig
[docs]
@dataclass
class TabulaRNNConfig(BaseModelConfig):
    """Architecture-only configuration for TabulaRNN models (DeepTab 2.0 API).

    All fields have defaults. ``d_model`` and ``activation`` are re-declared
    here to override the defaults inherited from ``BaseModelConfig``.

    Parameters
    ----------
    d_model : int, default=128
        Dimensionality of embeddings or model representations.
    activation : Callable, default=nn.SELU()
        Activation function for the RNN layers. The default is a single
        ``nn.SELU`` instance created when the class is defined, so every
        config that keeps the default refers to the same module object.
    model_type : str, default='RNN'
        Type of model, one of "RNN", "LSTM", "GRU", "mLSTM", "sLSTM".
    n_layers : int, default=4
        Number of layers in the RNN.
    rnn_dropout : float, default=0.2
        Dropout rate for the RNN layers.
    norm : str, default='RMSNorm'
        Normalization method to be used.
    residuals : bool, default=False
        Whether to include residual connections in the RNN.
    norm_first : bool, default=False
        Whether to apply normalization before other operations in each block.
    bias : bool, default=True
        Whether to use bias in the linear layers.
    rnn_activation : str, default='relu'
        Name of the activation function for the RNN layers, given as a
        string (unlike ``activation``, which is a callable).
        NOTE(review): how this differs in use from ``activation`` is not
        visible in this file -- confirm against the model implementation.
    dim_feedforward : int, default=256
        Size of the feedforward network.
    d_conv : int, default=4
        Size of the convolutional layer for embedding features.
    dilation : int, default=1
        Dilation factor for the convolution.
    conv_bias : bool, default=True
        Whether to use bias in the convolutional layers.
    head_layer_sizes : list, default=[]
        Sizes of the layers in the head of the model. Built with
        ``field(default_factory=list)``, so each instance gets its own
        empty list.
    head_dropout : float, default=0.5
        Dropout rate for the head layers.
    head_skip_layers : bool, default=False
        Whether to skip layers in the head.
    head_activation : Callable, default=nn.SELU()
        Activation function for the head layers. As with ``activation``,
        the default is one ``nn.SELU`` instance shared by all configs that
        do not override it.
    head_use_batch_norm : bool, default=False
        Whether to use batch normalization in the head layers.
    pooling_method : str, default='avg'
        Pooling method to be used ('avg', 'cls', etc.).
    """

    # Override parent defaults
    d_model: int = 128
    # RUF009 (function call in a dataclass default) is suppressed on purpose:
    # nn.SELU() is evaluated once here and that instance is the shared default.
    activation: Callable = nn.SELU()  # noqa: RUF009

    # RNN-specific architecture
    model_type: str = "RNN"
    n_layers: int = 4
    rnn_dropout: float = 0.2
    norm: str = "RMSNorm"
    residuals: bool = False
    norm_first: bool = False
    bias: bool = True
    rnn_activation: str = "relu"
    dim_feedforward: int = 256
    d_conv: int = 4
    dilation: int = 1
    conv_bias: bool = True

    # Head
    # default_factory gives every config instance its own list object.
    head_layer_sizes: list = field(default_factory=list)
    head_dropout: float = 0.5
    head_skip_layers: bool = False
    # Same deliberate RUF009 suppression as `activation` above.
    head_activation: Callable = nn.SELU()  # noqa: RUF009
    head_use_batch_norm: bool = False

    # Pooling
    pooling_method: str = "avg"