Source code for qualia_codegen_plugin_snn.graph.layers.TObjectDetectionPostProcessLayer


from __future__ import annotations

import sys
from dataclasses import dataclass

from qualia_codegen_core.graph.layers.TBaseLayer import TBaseLayer
from qualia_codegen_core.typing import TYPE_CHECKING, NDArrayFloatOrInt

if TYPE_CHECKING:
    from collections import OrderedDict  # noqa: TCH003

if sys.version_info >= (3, 12):
    from typing import override
else:
    from typing_extensions import override

@dataclass
class TObjectDetectionPostProcessLayer(TBaseLayer):
    """Graph layer record for an object-detection post-processing step.

    On top of whatever ``TBaseLayer`` provides, this dataclass stores the
    post-processing settings listed below and publishes ``anchors`` through
    the ``weights`` mapping.
    """

    # NOTE(review): the meaning of each setting is inferred from its name only;
    # confirm against the code that consumes this layer.
    num_classes: int
    num_fms: int
    image_shape: tuple[int, ...]
    score_threshold: float
    nms_threshold: float
    topk_candidates: int
    detections_per_image: int
    box_coder_weights: list[float]
    # Array exposed by the ``weights`` property under the 'anchors' key.
    anchors: NDArrayFloatOrInt

    @property
    @override
    def weights(self) -> OrderedDict[str, NDArrayFloatOrInt]:
        """Return the parent layer's weights with ``anchors`` added.

        The mapping obtained from ``super().weights`` is updated in place
        (key ``'anchors'``) and that same object is returned.
        """
        layer_weights = super().weights
        layer_weights['anchors'] = self.anchors
        return layer_weights