from __future__ import annotations
import sys
from dataclasses import dataclass
from qualia_codegen_core.typing import TYPE_CHECKING, NDArrayFloatOrInt
from .TBaseLayer import TBaseLayer
if TYPE_CHECKING:
from collections import OrderedDict # noqa: TC003
from .TActivationLayer import TActivation # noqa: TC001
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
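# NOTE: a stray "[docs]" marker stood here; it appears to be a Sphinx source-link label left over from HTML extraction, not code.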
@dataclass
class TConvLayer(TBaseLayer):
    """Dataclass record for a convolution layer, extending ``TBaseLayer``.

    In addition to whatever ``TBaseLayer`` declares, it stores the fields
    below and contributes its ``kernel`` (and, when ``use_bias`` is true, its
    ``bias``) to the inherited ``weights`` mapping.

    NOTE(review): the per-field meanings given in the comments are inferred
    from the field names and annotations only; confirm them against the code
    that constructs these layers.
    """

    # Activation attached to this layer.
    activation: TActivation
    # Kernel array; exposed as ``weights['kernel']``.
    kernel: NDArrayFloatOrInt
    # Presumably the kernel extent along each convolved dimension -- TODO confirm.
    kernel_size: tuple[int, ...]
    # Presumably the step along each convolved dimension -- TODO confirm.
    strides: tuple[int, ...]
    # Presumably the number of output filters/channels -- TODO confirm.
    filters: int
    # When true, ``bias`` is included in ``weights``.
    use_bias: bool
    # Bias array; exposed as ``weights['bias']`` only if ``use_bias`` is true.
    bias: NDArrayFloatOrInt
    # Presumably the number of convolution groups -- TODO confirm.
    groups: int

    @property
    @override
    def weights(self) -> OrderedDict[str, NDArrayFloatOrInt]:
        """Return the parent's weights mapping extended with this layer's arrays.

        The mapping obtained from ``super().weights`` is updated in place:
        ``'kernel'`` is always set, and ``'bias'`` is set afterwards only when
        ``use_bias`` is true. That same mapping object is then returned.
        """
        layer_weights = super().weights
        # Insert the kernel first so that it precedes the bias in the mapping.
        layer_weights['kernel'] = self.kernel
        if self.use_bias:
            layer_weights['bias'] = self.bias
        return layer_weights