File size: 9,657 Bytes
56b6e17 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 | from transformers.modeling_outputs import ModelOutput
from transformers import PreTrainedModel
from transformers.utils import logging
from typing import Optional, Tuple, Union, List, Dict
from dataclasses import dataclass
import torch.nn as nn
import torch
from copy import deepcopy
from cerberusdet.models.cerberus import CerberusDet
from cerberusdet.utils.general import (
nms_between_tasks,
non_max_suppression,
scale_boxes,
)
from .configuration_cerberus import CerberusDetConfig
logger = logging.get_logger(__name__)
@dataclass
class CerberusOutput(ModelOutput):
    """
    Output container returned by `CerberusDetForObjectDetection.forward`.

    Every list-valued field holds one entry per image of the input batch.
    When `forward` is called with `postprocess=False`, only
    `nms_results_per_task_batch` is filled and the other fields keep their
    `None` default.
    """

    # task name -> list (one tensor per image) produced by the per-task NMS step
    nms_results_per_task_batch: Optional[Dict[str, List[torch.Tensor]]] = None
    # per image: int32 tensor with 4 columns (box coordinates) after cross-task NMS
    boxes: Optional[List[torch.IntTensor]] = None
    # per image: confidence of each kept detection
    scores: Optional[List[torch.FloatTensor]] = None
    # per image: int32 global class id of each kept detection
    labels: Optional[List[torch.IntTensor]] = None
    # per image: int32 task id looked up from the global class id (-1 when unmapped)
    tasks_ids: Optional[List[torch.IntTensor]] = None
class CerberusDetPreTrainedModel(PreTrainedModel):
    """
    Base class that plugs CerberusDet into the `transformers` `PreTrainedModel` API.

    It declares the configuration class and the attribute name (`model`) under
    which the wrapped network is stored, and provides the per-module weight
    initialisation hook.
    """

    config_class = CerberusDetConfig
    base_model_prefix = 'model'
    _no_split_modules = ['model']

    def _init_weights(self, module: nn.Module) -> None:
        """
        Initializes the weights / hyper-parameters of a single model layer.

        Args:
            module (nn.Module): The layer to initialise (an instance, not a class).
        """
        # The checks must use isinstance(): `module` is a layer *instance*, so
        # comparing it to the layer classes with `is` / `in` never matches and
        # would turn this whole method into a no-op.
        if isinstance(module, nn.Conv2d):
            # Convolutions intentionally keep the PyTorch default initialisation
            # (kaiming_normal_ with mode='fan_out', nonlinearity='relu' is disabled).
            return
        if isinstance(module, nn.BatchNorm2d):
            module.eps = 1e-3
            module.momentum = 0.03
        elif isinstance(module, (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU)):
            # Run the activation in place instead of allocating a new output tensor.
            module.inplace = True
class CerberusDetForObjectDetection(CerberusDetPreTrainedModel):
    """
    CerberusDet multi-task detector whose `forward` also runs the post-processing
    pipeline: per-task NMS, cross-task NMS and optional rescaling of the boxes to
    the original image sizes.
    """

    def __init__(self, config: CerberusDetConfig):
        """
        Initializes the CerberusDet object detection model based on the provided configuration.

        Args:
            config (CerberusDetConfig): The configuration object containing model parameters.
        """
        super().__init__(config)
        self.config = config

        # initialize a model
        self.model = CerberusDet(
            task_ids=config.task_ids,
            nc=config.tasks_nc,
            cfg=config.cfg,
            ch=3,
            verbose=config.verbose,
            stride=config.strides,
            backbone_verbose=False,
        )
        self.model.names = config.names
        self.cerber_schedule = config.cfg["cerber"]
        # The model receives a deep copy, so the schedule kept on this instance
        # is never the same object as the one handed to `sequential_split`.
        self.model.sequential_split(deepcopy(self.cerber_schedule), device=None)
        # NOTE(review): a plain attribute called `half` shadows the inherited
        # `nn.Module.half()` method on this instance, so `model.half()` raises
        # "'bool' object is not callable". Kept as is because the attribute is
        # part of the public surface -- consider renaming it.
        self.half = config.torch_dtype is torch.float16

        # Initialize weights and apply final processing
        self.post_init()

    def _combine_output(self, output_per_task: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Combines task results for a SINGLE image.

        Args:
            output_per_task: Mapping task name -> detections of that task for one
                image. Column 5 holds the task-local class index.

        Returns:
            A CPU tensor with 6 columns containing the rows of all tasks stacked
            in the iteration order of `output_per_task`, with column 5 remapped
            to global class ids through `config.categories_inds_map`. The input
            tensors are left untouched.
        """
        # Start from an empty (0, 6) float tensor so the result is well-formed
        # even when no task produced a detection.
        chunks = [torch.zeros((0, 6))]
        for task, bboxes in output_per_task.items():
            # Work on a private CPU copy. `apply_` only runs on CPU tensors and
            # edits them in place; a bare `.cpu()` returns the very same tensor
            # when it already lives on the CPU, which would silently rewrite the
            # caller's per-task NMS results on CPU runs (but not on GPU runs).
            bboxes = bboxes.to("cpu", copy=True)
            # Mapping local class IDs to global ones
            if bboxes.shape[0] > 0:
                local_to_global = self.config.categories_inds_map[task]
                bboxes[:, 5].apply_(lambda cat: local_to_global[int(cat)])
            chunks.append(bboxes)
        return torch.cat(chunks, 0)

    def forward(
        self,
        pixel_values: torch.Tensor,
        original_shapes: Optional[Union[List[Tuple[int, int]], torch.Tensor]] = None,
        max_det: int = 300,
        agnostic_nms: Optional[bool] = None,
        conf_thres: Optional[float] = None,
        iou_thres: Optional[float] = None,
        iou_thres_between_tasks: Optional[float] = None,
        return_dict: Optional[bool] = None,
        postprocess: bool = True,
    ) -> Union[Tuple, CerberusOutput]:
        """
        Performs the forward pass of the model and the full post-processing pipeline, including
        multi-task inference, Non-Maximum Suppression (NMS), and coordinate rescaling.

        Args:
            pixel_values (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
                Pixel values.
            original_shapes (`List[Tuple[int, int]]` or `torch.Tensor`, *optional*):
                The original (height, width) of each image in the batch. If provided, bounding box
                coordinates will be rescaled from the model input size back to these original
                dimensions and rounded. If `None`, boxes stay in model-input coordinates.
            max_det (`int`, *optional*, defaults to 300):
                The maximum number of detections for each task to return per image.
            agnostic_nms (`bool`, *optional*):
                If `True`, performs class-agnostic NMS (overlapping boxes of different classes can
                be suppressed). If `None`, defaults to `config.agnostic_nms`, or `False` when the
                config does not define it.
            conf_thres (`float`, *optional*):
                The confidence threshold for filtering detections. If `None`, defaults to the value in the config.
            iou_thres (`float`, *optional*):
                The IoU threshold for NMS within a single task (head). If `None`, defaults to the value in the config.
            iou_thres_between_tasks (`float`, *optional*):
                The IoU threshold for NMS applied to detections overlapping between different tasks.
                If `None`, defaults to the value in the config.
            return_dict (`bool`, *optional*):
                Whether to return a [`~utils.ModelOutput`] object instead of a plain tuple.
                If `None`, defaults to `config.use_return_dict`.
            postprocess (`bool`, *optional*, defaults to `True`):
                If `False`, stop after the per-task NMS: only `nms_results_per_task_batch` is
                returned and boxes, scores, labels and task ids are `None`.

        Returns:
            `Union[Tuple, CerberusOutput]`:
                An instance of [`CerberusOutput`] containing the per-task NMS results plus, per image,
                the detected boxes, scores, labels, and task IDs, or a 5-tuple
                `(nms_results_per_task_batch, boxes, scores, labels, tasks_ids)` if `return_dict=False`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Use defaults from the config for every option the caller left unset.
        if agnostic_nms is None:
            # The config may not define this option; fall back to the historical default.
            agnostic_nms = getattr(self.config, "agnostic_nms", False)
        if conf_thres is None:
            conf_thres = self.config.conf_thres
        if iou_thres is None:
            iou_thres = self.config.iou_thres
        if iou_thres_between_tasks is None:
            iou_thres_between_tasks = self.config.iou_thres_between_tasks

        # 1. Forward pass
        # model() returns dict: {task_name: (predictions, ...)}
        all_out = self.model(pixel_values)
        batch_size = pixel_values.shape[0]

        # Dictionary to store "raw" NMS results for each task across the entire batch
        # task_name -> List[Tensor] (list length = batch_size)
        nms_results_per_task_batch: Dict[str, List[torch.Tensor]] = {}
        for task, task_out in all_out.items():
            task_pred, _ = task_out
            # non_max_suppression returns a list of tensors of length batch_size
            nms_results_per_task_batch[task] = non_max_suppression(
                task_pred,
                conf_thres,
                iou_thres,
                agnostic=agnostic_nms,
                max_det=max_det,
            )

        if not postprocess:
            if not return_dict:
                return (nms_results_per_task_batch, None, None, None, None)
            return CerberusOutput(nms_results_per_task_batch=nms_results_per_task_batch)

        # Lists to collect results per image
        batch_boxes = []
        batch_scores = []
        batch_labels = []
        batch_tasks = []

        # 2. Process each image in the batch individually
        for i in range(batch_size):
            # Collect detections from all tasks for the current image i
            # task_name -> Tensor (detections for img[i])
            current_img_task_results = {
                task: preds_list[i].detach()
                for task, preds_list in nms_results_per_task_batch.items()
            }

            # 3. Combine & Cross-task NMS
            # det shape: [N, 6] -> (x1, y1, x2, y2, conf, global_cls)
            det = self._combine_output(current_img_task_results)
            det = nms_between_tasks(det, self.config.categories_inds_map, iou_thres=iou_thres_between_tasks)

            # 4. Scale coords to original image
            if len(det) > 0 and original_shapes is not None:
                if isinstance(original_shapes, torch.Tensor):
                    curr_shape = original_shapes[i].tolist()
                else:
                    curr_shape = original_shapes[i]
                # Scale: pixel_values.shape[2:] is (H_net, W_net) of the network input
                det[:, :4] = scale_boxes(pixel_values.shape[2:], det[:, :4], curr_shape).round()

            # 5. Collect tensors for this image
            if len(det) > 0:
                # Boxes and labels are cast to int32; scores keep their dtype
                batch_boxes.append(det[:, :4].int())
                batch_scores.append(det[:, 4])
                batch_labels.append(det[:, 5].int())
                # Map global class ID back to Task ID (-1 when the class is unknown)
                current_task_ids = [
                    self.config.global_cls_to_task_id_map.get(int(c.item()), -1)
                    for c in det[:, 5]
                ]
                batch_tasks.append(torch.tensor(current_task_ids, dtype=torch.int32))
            else:
                # Empty tensors if no detection
                batch_boxes.append(torch.zeros((0, 4), dtype=torch.int32))
                batch_scores.append(torch.zeros((0,), dtype=torch.float32))
                batch_labels.append(torch.zeros((0,), dtype=torch.int32))
                batch_tasks.append(torch.zeros((0,), dtype=torch.int32))

        if not return_dict:
            return (nms_results_per_task_batch, batch_boxes, batch_scores, batch_labels, batch_tasks)

        return CerberusOutput(
            nms_results_per_task_batch=nms_results_per_task_batch,
            boxes=batch_boxes,
            scores=batch_scores,
            labels=batch_labels,
            tasks_ids=batch_tasks,
        )
|