Source code for deeptab.configs.experimental.modernnca_config
from collections.abc import Callable
from dataclasses import dataclass, field
import torch.nn as nn
from ..core import BaseModelConfig
[docs]
@dataclass
class ModernNCAConfig(BaseModelConfig):
    """Architecture-only configuration for ModernNCA models (DeepTab 2.0 API).

    Parameters
    ----------
    embedding_type : str, default='plr'
        Type of feature embedding to use (e.g., 'plr', 'ple').
    plr_lite : bool, default=True
        Whether to use the lightweight PLR embedding variant.
    n_frequencies : int, default=75
        Number of random Fourier feature frequencies.
    frequencies_init_scale : float, default=0.045
        Scale for initializing Fourier feature frequencies.
    dim : int, default=128
        Embedding dimensionality per feature.
    d_block : int, default=512
        Hidden size of each residual block.
    n_blocks : int, default=4
        Number of residual blocks.
    dropout : float, default=0.1
        Dropout rate applied inside each block.
    temperature : float, default=0.75
        Temperature scaling for NCA softmax similarity.
    sample_rate : float, default=0.5
        Fraction of training candidates used per forward pass.
    num_embeddings : dict | None, default=None
        Optional dict mapping feature indices to embedding sizes.
    head_layer_sizes : list, default=[]
        Sizes of the fully connected layers in the prediction head.
        Each instance receives its own new empty list.
    head_dropout : float, default=0.5
        Dropout rate for the head layers.
    head_skip_layers : bool, default=False
        Whether to use skip connections in the head layers.
    head_activation : Callable, default=nn.SELU()
        Activation function for the head layers.
    head_use_batch_norm : bool, default=False
        Whether to use batch normalization in the head layers.

    Notes
    -----
    The default for ``head_activation`` is a single ``nn.SELU()`` object that
    is created once, when the class body is executed, and is therefore shared
    by every instance that does not pass its own activation.
    """

    # Override parent defaults
    embedding_type: str = "plr"
    plr_lite: bool = True
    n_frequencies: int = 75
    frequencies_init_scale: float = 0.045

    # ModernNCA-specific architecture
    dim: int = 128
    d_block: int = 512
    n_blocks: int = 4
    dropout: float = 0.1
    temperature: float = 0.75
    sample_rate: float = 0.5
    num_embeddings: dict | None = None

    # Head
    # default_factory gives each instance its own list; dataclasses reject a
    # bare ``[]`` default because it would be shared between instances.
    head_layer_sizes: list = field(default_factory=list)
    head_dropout: float = 0.5
    head_skip_layers: bool = False
    # The call in the default is evaluated once at class definition, which is
    # what Ruff's RUF009 warns about; the suppression keeps it deliberate.
    head_activation: Callable = nn.SELU()  # noqa: RUF009
    head_use_batch_norm: bool = False