Source code for deeptab.configs.models.resnet_config

from collections.abc import Callable
from dataclasses import dataclass, field

import torch.nn as nn

from ..core import BaseModelConfig


# Immutable template for the default layer widths; every config instance
# receives its own fresh list copy of it (see ``layer_sizes`` below).
_DEFAULT_LAYER_SIZES = (256, 128, 32)


@dataclass
class ResNetConfig(BaseModelConfig):
    """Architecture-only configuration for ResNet models (DeepTab 2.0 API).

    Parameters
    ----------
    activation : Callable, default=nn.SELU()
        Activation function applied in the ResNet layers.
    layer_sizes : list, default=[256, 128, 32]
        Sizes of the layers in the ResNet.
    dropout : float, default=0.5
        Dropout rate used for regularization.
    norm : bool, default=False
        Whether normalization is used in the ResNet.
    num_blocks : int, default=3
        Number of residual blocks in the ResNet.
    """

    # --- Overrides of parent defaults ---
    # The default module is built once, when the class body executes, so every
    # config that does not pass ``activation`` refers to the same object; the
    # ``noqa`` silences ruff's "function call in dataclass default" rule.
    activation: Callable = nn.SELU()  # noqa: RUF009

    # --- ResNet-specific architecture ---
    # A factory is used so each instance owns an independent, mutable list.
    layer_sizes: list = field(default_factory=lambda: list(_DEFAULT_LAYER_SIZES))
    dropout: float = 0.5
    norm: bool = False
    num_blocks: int = 3