File size: 9,657 Bytes
56b6e17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
from transformers.modeling_outputs import ModelOutput
from transformers import PreTrainedModel
from transformers.utils import logging

from typing import Optional, Tuple, Union, List, Dict
from dataclasses import dataclass

import torch.nn as nn
import torch
from copy import deepcopy

from cerberusdet.models.cerberus import CerberusDet
from cerberusdet.utils.general import (
    nms_between_tasks,
    non_max_suppression,
    scale_boxes,
)

from .configuration_cerberus import CerberusDetConfig


logger = logging.get_logger(__name__)


@dataclass
class CerberusOutput(ModelOutput):
    """
    Container returned by `CerberusDetForObjectDetection.forward`.

    Every list has one entry per image of the input batch. `forward` leaves the last four
    fields at `None` when it is called with `postprocess=False`.
    """

    # task name -> per-image output of `non_max_suppression` for that task
    nms_results_per_task_batch: Optional[Dict[str, List[torch.Tensor]]] = None
    # per image: first four detection columns cast to int32; rescaled to the original image
    # size when `original_shapes` is passed to `forward`
    boxes: Optional[List[torch.IntTensor]] = None
    # per image: detection column 4 (confidence)
    scores: Optional[List[torch.FloatTensor]] = None
    # per image: detection column 5 cast to int32, i.e. the class id after mapping through
    # `config.categories_inds_map`
    labels: Optional[List[torch.IntTensor]] = None
    # per image: task id looked up in `config.global_cls_to_task_id_map` for every label
    # (-1 when the label is not in the map)
    tasks_ids: Optional[List[torch.IntTensor]] = None


class CerberusDetPreTrainedModel(PreTrainedModel):
    """Base class that binds `CerberusDetConfig` to the `PreTrainedModel` machinery."""

    config_class = CerberusDetConfig
    base_model_prefix = 'model'
    _no_split_modules = ['model']

    def _init_weights(self, module: nn.Module) -> None:
        """
        Applies the per-layer defaults used by the detector to a single sub-module.

        Args:
            module (`nn.Module`): The sub-module to configure.
        """
        # `module` is an instance, so it has to be tested with `isinstance`; comparing it to
        # the classes themselves (`module is nn.BatchNorm2d`) never matches.
        if isinstance(module, nn.BatchNorm2d):
            module.eps = 1e-3
            module.momentum = 0.03
        elif isinstance(module, (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU)):
            module.inplace = True
        # nn.Conv2d (and everything else) deliberately keeps its default initialization.


class CerberusDetForObjectDetection(CerberusDetPreTrainedModel):
    """
    CerberusDet multi-task detector exposed through the `PreTrainedModel` interface.

    `forward` runs the wrapped network, applies NMS per task, merges the tasks of every image,
    applies NMS between tasks and optionally rescales the boxes to the original image size.
    """

    def __init__(self, config: CerberusDetConfig):
        """
        Initializes the CerberusDet object detection model based on the provided configuration.

        Args:
            config (CerberusDetConfig): The configuration object containing model parameters.
        """
        super().__init__(config)
        self.config = config

        # initialize a model
        self.model = CerberusDet(
            task_ids=config.task_ids,
            nc=config.tasks_nc,
            cfg=config.cfg,
            ch=3,
            verbose=config.verbose,
            stride=config.strides,
            backbone_verbose=False,
        )
        self.model.names = config.names
        self.cerber_schedule = config.cfg["cerber"]
        # A copy is handed over so that `self.cerber_schedule` stays untouched whatever
        # `sequential_split` does with its argument.
        self.model.sequential_split(deepcopy(self.cerber_schedule), device=None)

        # NOTE(review): this instance attribute shadows `nn.Module.half()`, so calling
        # `model.half()` on this class fails. Kept as-is because external code may read
        # `model.half`; consider renaming it.
        self.half = config.torch_dtype is torch.float16

        # Initialize weights and apply final processing
        self.post_init()

    def _combine_output(self, output_per_task: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Combines task results for a SINGLE image.

        Args:
            output_per_task (`Dict[str, torch.Tensor]`):
                Detections of one image for every task; column 5 holds the class index that is
                looked up in `config.categories_inds_map[task]`.

        Returns:
            `torch.Tensor`: A CPU tensor with the rows of all non-empty tasks concatenated and
            column 5 replaced by the mapped class id. Has shape `(0, 6)` when every task is empty.
            The input tensors are never modified.
        """
        output = torch.zeros((0, 6))
        for task, bboxes in output_per_task.items():
            if bboxes.shape[0] == 0:
                continue

            # `Tensor.cpu()` returns the very same tensor when it already lives on the CPU, so
            # remapping in place would silently rewrite the caller's per-task NMS results.
            # Always work on a private CPU copy (`apply_` is CPU-only anyway).
            bboxes = bboxes.detach().to("cpu", copy=True)

            # Mapping local class IDs to global ones
            local_to_global = self.config.categories_inds_map[task]
            bboxes[:, 5].apply_(lambda cat: local_to_global[int(cat)])
            output = torch.cat((output, bboxes), 0)
        return output

    def forward(
        self,
        pixel_values: torch.Tensor,
        original_shapes: Optional[Union[List[Tuple[int, int]], torch.Tensor]] = None,
        max_det: int = 300,
        agnostic_nms: Optional[bool] = False,
        conf_thres: Optional[float] = None,
        iou_thres: Optional[float] = None,
        iou_thres_between_tasks: Optional[float] = None,
        return_dict: Optional[bool] = None,
        postprocess: bool = True,
    ) -> Union[Tuple, CerberusOutput]:
        """
        Performs the forward pass of the model and the full post-processing pipeline, including
        multi-task inference, Non-Maximum Suppression (NMS), and coordinate rescaling.

        Args:
            pixel_values (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
                Pixel values.
            original_shapes (`List[Tuple[int, int]]` or `torch.Tensor`, *optional*):
                The original (height, width) of each image in the batch. If provided, bounding box
                coordinates will be rescaled from the model input size back to these original dimensions.
            max_det (`int`, *optional*, defaults to 300):
                The maximum number of detections for each task to return per image.
            agnostic_nms (`bool`, *optional*, defaults to `False`):
                If `True`, performs class-agnostic NMS (overlapping boxes of different classes can be suppressed).
                The config value is only used when `None` is passed explicitly.
            conf_thres (`float`, *optional*):
                The confidence threshold for filtering detections. If `None`, defaults to the value in the config.
            iou_thres (`float`, *optional*):
                The IoU threshold for NMS within a single task (head). If `None`, defaults to the value in the config.
            iou_thres_between_tasks (`float`, *optional*):
                The IoU threshold for NMS applied to detections overlapping between different tasks.
                If `None`, defaults to the value in the config.
            return_dict (`bool`, *optional*):
                Whether to return a [`~utils.ModelOutput`] object instead of a plain tuple.
                If `None`, defaults to `config.use_return_dict`.
            postprocess (`bool`, *optional*, defaults to `True`):
                If `False`, stops after the per-task NMS: only `nms_results_per_task_batch` is
                filled and the remaining outputs are `None`.

        Returns:
            `Union[Tuple, CerberusOutput]`:
                An instance of [`CerberusOutput`] containing the per-task NMS results, detected boxes,
                scores, labels, and task IDs, or the tuple
                `(nms_results_per_task_batch, boxes, scores, labels, tasks_ids)` if `return_dict=False`.
        """

        if return_dict is None:
            return_dict = self.config.use_return_dict

        # An explicit `None` means "take the value from the config".
        if agnostic_nms is None:
            agnostic_nms = self.config.agnostic_nms
        if conf_thres is None:
            conf_thres = self.config.conf_thres
        if iou_thres is None:
            iou_thres = self.config.iou_thres
        if iou_thres_between_tasks is None:
            iou_thres_between_tasks = self.config.iou_thres_between_tasks

        # 1. Forward pass
        # model() returns dict: {task_name: (predictions, ...)}
        all_out = self.model(pixel_values)

        batch_size = pixel_values.shape[0]

        # Dictionary to store "raw" NMS results for each task across the entire batch
        # task_name -> List[Tensor] (list length = batch_size)
        nms_results_per_task_batch: Dict[str, List[torch.Tensor]] = {}

        for task, (task_pred, _) in all_out.items():
            # non_max_suppression returns a list of tensors of length batch_size
            nms_results_per_task_batch[task] = non_max_suppression(
                task_pred,
                conf_thres,
                iou_thres,
                agnostic=agnostic_nms,
                max_det=max_det,
            )

        if not postprocess:
            if not return_dict:
                return (nms_results_per_task_batch, None, None, None, None)
            return CerberusOutput(nms_results_per_task_batch=nms_results_per_task_batch)

        # One entry per image; images may hold different numbers of detections, so the
        # results stay lists instead of being stacked into a single tensor.
        batch_boxes = []
        batch_scores = []
        batch_labels = []
        batch_tasks = []

        cls_to_task_id = self.config.global_cls_to_task_id_map

        # 2. Process each image in the batch individually
        for i in range(batch_size):
            # Collect detections from all tasks for the current image i
            # task_name -> Tensor (detections for img[i])
            current_img_task_results = {
                task: preds_list[i].detach()
                for task, preds_list in nms_results_per_task_batch.items()
            }

            # 3. Combine & Cross-task NMS
            # det shape: [N, 6] -> (x1, y1, x2, y2, conf, global_cls)
            det = self._combine_output(current_img_task_results)
            det = nms_between_tasks(det, self.config.categories_inds_map, iou_thres=iou_thres_between_tasks)

            if len(det) == 0:
                # Empty tensors if no detection
                batch_boxes.append(torch.zeros((0, 4), dtype=torch.int32))
                batch_scores.append(torch.zeros((0,), dtype=torch.float32))
                batch_labels.append(torch.zeros((0,), dtype=torch.int32))
                batch_tasks.append(torch.zeros((0,), dtype=torch.int32))
                continue

            # 4. Scale coords to original image
            if original_shapes is not None:
                curr_shape = original_shapes[i]
                if isinstance(original_shapes, torch.Tensor):
                    curr_shape = curr_shape.tolist()

                # Scale: pixel_values.shape[2:] is (H_net, W_net) of the network input
                det[:, :4] = scale_boxes(pixel_values.shape[2:], det[:, :4], curr_shape).round()

            # 5. Collect tensors for this image
            batch_boxes.append(det[:, :4].int())
            batch_scores.append(det[:, 4])
            batch_labels.append(det[:, 5].int())

            # Map global class ID back to Task ID (-1 for classes missing from the map)
            current_task_ids = [cls_to_task_id.get(int(cls_id), -1) for cls_id in det[:, 5].tolist()]
            batch_tasks.append(torch.tensor(current_task_ids, dtype=torch.int32))

        if not return_dict:
            return (nms_results_per_task_batch, batch_boxes, batch_scores, batch_labels, batch_tasks)

        return CerberusOutput(
            nms_results_per_task_batch=nms_results_per_task_batch,
            boxes=batch_boxes,
            scores=batch_scores,
            labels=batch_labels,
            tasks_ids=batch_tasks,
        )