Source code for deeptab.configs.models.tabr_config

from collections.abc import Callable
from dataclasses import dataclass, field

import torch.nn as nn

from ..core import BaseModelConfig


[docs] @dataclass class TabRConfig(BaseModelConfig): """Architecture-only configuration for TabR models (DeepTab 2.0 API). Training fields (``lr``, ``weight_decay``, ``lr_factor``) are configured via :class:`~deeptab.configs.trainer_config.TrainerConfig`. Parameters ---------- embedding_type : str, default='plr' Type of feature embedding to use (e.g., 'plr', 'ple'). plr_lite : bool, default=True Whether to use the lightweight PLR embedding variant. n_frequencies : int, default=75 Number of random Fourier feature frequencies. frequencies_init_scale : float, default=0.045 Scale for initializing Fourier feature frequencies. d_main : int, default=256 Main hidden dimensionality of the predictor network. context_dropout : float, default=0.38920071545944357 Dropout applied to context (candidate) representations. d_multiplier : int, default=2 Multiplier for intermediate dimensions inside the predictor. encoder_n_blocks : int, default=0 Number of residual blocks in the feature encoder. predictor_n_blocks : int, default=1 Number of residual blocks in the predictor network. mixer_normalization : str, default='auto' Normalization strategy for the mixer (``'auto'`` selects adaptively). dropout0 : float, default=0.38852797479169876 Dropout rate on the first linear projection. dropout1 : float, default=0.0 Dropout rate on the second linear projection. normalization : str, default='LayerNorm' Type of normalization layer to use. memory_efficient : bool, default=False Whether to trade compute for lower memory in candidate lookups. candidate_encoding_batch_size : int, default=0 Batch size for encoding candidates (0 = full batch). context_size : int, default=96 Number of nearest-neighbour candidates to retrieve per sample. """ # Override embedding defaults specific to TabR embedding_type: str = "plr" plr_lite: bool = True n_frequencies: int = 75 frequencies_init_scale: float = 0.045 # Architecture d_main: int = 256 context_dropout: float = 0.38920071545944357 d_multiplier: int = 2 encoder_n_blocks: int = 0 predictor_n_blocks: int = 1 mixer_normalization: str = "auto" dropout0: float = 0.38852797479169876 dropout1: float = 0.0 normalization: str = "LayerNorm" memory_efficient: bool = False candidate_encoding_batch_size: int = 0 context_size: int = 96