from transformers.modeling_outputs import ModelOutput
from transformers import PreTrainedModel
from transformers.utils import logging
from typing import Optional, Tuple, Union, List, Dict
from dataclasses import dataclass
import torch.nn as nn
import torch
from copy import deepcopy

from cerberusdet.models.cerberus import CerberusDet
from cerberusdet.utils.general import (
    nms_between_tasks,
    non_max_suppression,
    scale_boxes,
)
from .configuration_cerberus import CerberusDetConfig

logger = logging.get_logger(__name__)


@dataclass
class CerberusOutput(ModelOutput):
    """Output of `CerberusDetForObjectDetection.forward`."""

    # task name -> list (one entry per image) of per-task NMS results,
    # with class IDs local to that task.
    nms_results_per_task_batch: Optional[Dict[str, List[torch.Tensor]]] = None
    # Per-image (N, 4) int32 boxes as (x1, y1, x2, y2), rescaled when original_shapes is given.
    boxes: Optional[List[torch.IntTensor]] = None
    # Per-image (N,) confidence scores.
    scores: Optional[List[torch.FloatTensor]] = None
    # Per-image (N,) int32 global class IDs.
    labels: Optional[List[torch.IntTensor]] = None
    # Per-image (N,) int32 task IDs (-1 when a class is missing from global_cls_to_task_id_map).
    tasks_ids: Optional[List[torch.IntTensor]] = None


class CerberusDetPreTrainedModel(PreTrainedModel):
    """Base class wiring CerberusDet into the Hugging Face `PreTrainedModel` machinery."""

    config_class = CerberusDetConfig
    base_model_prefix = 'model'
    _no_split_modules = ['model']

    def _init_weights(self, module: nn.Module) -> None:
        """Apply per-layer initialization settings to a single submodule.

        Conv2d layers keep PyTorch's default initialization. BatchNorm2d layers get
        eps=1e-3 / momentum=0.03. The listed activation layers are switched to in-place mode.
        """
        # Compare the module's *type*: `module` is an instance, so the previous
        # identity/membership checks against the classes themselves never matched.
        module_type = type(module)
        if module_type is nn.Conv2d:
            pass  # intentionally left with PyTorch's default initialization
        elif module_type is nn.BatchNorm2d:
            module.eps = 1e-3
            module.momentum = 0.03
        elif module_type in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU):
            module.inplace = True


class CerberusDetForObjectDetection(CerberusDetPreTrainedModel):
    """Multi-task CerberusDet detector with built-in NMS and box post-processing."""

    def __init__(self, config: CerberusDetConfig):
        """
        Initializes the CerberusDet object detection model based on the provided configuration.

        Args:
            config (CerberusDetConfig): The configuration object containing model parameters.
        """
        super().__init__(config)
        self.config = config

        # initialize a model
        self.model = CerberusDet(
            task_ids=config.task_ids,
            nc=config.tasks_nc,
            cfg=config.cfg,
            ch=3,
            verbose=config.verbose,
            stride=config.strides,
            backbone_verbose=False,
        )
        self.model.names = config.names
        self.cerber_schedule = config.cfg["cerber"]
        # deepcopy so sequential_split cannot mutate the schedule stored on the config
        self.model.sequential_split(deepcopy(self.cerber_schedule), device=None)

        # True when the config requests float16 weights.
        self.half = config.torch_dtype is torch.float16

        # Initialize weights and apply final processing
        self.post_init()

    def _combine_output(self, output_per_task: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Combines task results for a SINGLE image.

        Args:
            output_per_task: task name -> detections for one image. Each row is
                (x1, y1, x2, y2, conf, cls), where cls is the task-local class ID.

        Returns:
            A CPU tensor of shape (N, 6) with all tasks' detections concatenated
            and column 5 remapped to global class IDs via
            `config.categories_inds_map[task]`. The input tensors are not modified.
        """
        parts = [torch.zeros((0, 6))]
        for task, bboxes in output_per_task.items():
            # Always copy: `.cpu()` returns the *same* tensor when it already lives on
            # the CPU, and remapping in place would then rewrite the caller's per-task
            # NMS results (making the returned raw results device-dependent).
            bboxes = bboxes.to("cpu", copy=True)
            # Mapping local class IDs to global ones
            if bboxes.shape[0] > 0:
                cls_map = self.config.categories_inds_map[task]
                global_ids = [cls_map[int(cat)] for cat in bboxes[:, 5].tolist()]
                bboxes[:, 5] = torch.tensor(global_ids, dtype=bboxes.dtype)
            parts.append(bboxes)
        # Single concatenation instead of re-concatenating inside the loop.
        return torch.cat(parts, 0)

    def forward(
        self,
        pixel_values: torch.Tensor,
        original_shapes: Optional[Union[List[Tuple[int, int]], torch.Tensor]] = None,
        max_det: int = 300,
        agnostic_nms: Optional[bool] = False,
        conf_thres: Optional[float] = None,
        iou_thres: Optional[float] = None,
        iou_thres_between_tasks: Optional[float] = None,
        return_dict: Optional[bool] = None,
        postprocess: bool = True,
    ) -> Union[Tuple, CerberusOutput]:
        """
        Performs the forward pass of the model and the full post-processing pipeline,
        including multi-task inference, Non-Maximum Suppression (NMS), and coordinate rescaling.

        Args:
            pixel_values (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
                Pixel values.
            original_shapes (`List[Tuple[int, int]]` or `torch.Tensor`, *optional*):
                The original (height, width) of each image in the batch. If provided,
                bounding box coordinates will be rescaled from the model input size
                back to these original dimensions.
            max_det (`int`, *optional*, defaults to 300):
                The maximum number of detections for each task to return per image.
            agnostic_nms (`bool`, *optional*, defaults to `False`):
                If `True`, performs class-agnostic NMS (overlapping boxes of different
                classes can be suppressed). Pass `None` explicitly to use the config value.
            conf_thres (`float`, *optional*):
                The confidence threshold for filtering detections. If `None`, defaults to the value in the config.
            iou_thres (`float`, *optional*):
                The IoU threshold for NMS within a single task (head). If `None`, defaults to the value in the config.
            iou_thres_between_tasks (`float`, *optional*):
                The IoU threshold for NMS applied to detections overlapping between different tasks.
                If `None`, defaults to the value in the config.
            return_dict (`bool`, *optional*):
                Whether to return a [`~utils.ModelOutput`] object instead of a plain tuple.
                If `None`, defaults to `config.use_return_dict`.
            postprocess (`bool`, *optional*, defaults to `True`):
                If `False`, stop after per-task NMS and return only
                `nms_results_per_task_batch`. The box/score/label/task fields are then `None`.

        Returns:
            `Union[Tuple, CerberusOutput]`: A [`CerberusOutput`] with the per-task NMS results and,
            when post-processing runs, per-image boxes, scores, labels and task IDs. With
            `return_dict=False`, a tuple of
            `(nms_results_per_task_batch, boxes, scores, labels, tasks_ids)` in that order.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Use defaults from self/config if not provided.
        # NOTE: agnostic_nms defaults to False, so the config value is only used
        # when the caller passes None explicitly.
        agnostic_nms = agnostic_nms if agnostic_nms is not None else self.config.agnostic_nms
        conf_thres = conf_thres if conf_thres is not None else self.config.conf_thres
        iou_thres = iou_thres if iou_thres is not None else self.config.iou_thres
        iou_thres_between_tasks = (
            iou_thres_between_tasks
            if iou_thres_between_tasks is not None
            else self.config.iou_thres_between_tasks
        )

        # 1. Forward pass
        # model() returns dict: {task_name: (predictions, ...)}
        all_out = self.model(pixel_values)
        batch_size = pixel_values.shape[0]

        # Dictionary to store "raw" NMS results for each task across the entire batch
        # task_name -> List[Tensor] (list length = batch_size)
        nms_results_per_task_batch: Dict[str, List[torch.Tensor]] = {}
        for task in all_out.keys():
            task_pred, _ = all_out[task]
            # non_max_suppression returns a list of tensors of length batch_size
            preds_batch = non_max_suppression(
                task_pred, conf_thres, iou_thres, agnostic=agnostic_nms, max_det=max_det
            )
            nms_results_per_task_batch[task] = preds_batch

        if not postprocess:
            if not return_dict:
                return (nms_results_per_task_batch, None, None, None, None)
            else:
                return CerberusOutput(nms_results_per_task_batch=nms_results_per_task_batch)

        # Lists to collect results per image for batch padding later
        batch_boxes = []
        batch_scores = []
        batch_labels = []
        batch_tasks = []

        # 2. Process each image in the batch individually
        for i in range(batch_size):
            # Collect detections from all tasks for the current image i
            # task_name -> Tensor (detections for img[i])
            current_img_task_results = {
                task: preds_list[i].detach() for task, preds_list in nms_results_per_task_batch.items()
            }

            # 3. Combine & Cross-task NMS
            # det shape: [N, 6] -> (x1, y1, x2, y2, conf, global_cls)
            det = self._combine_output(current_img_task_results)
            det = nms_between_tasks(det, self.config.categories_inds_map, iou_thres=iou_thres_between_tasks)

            # 4. Scale coords to original image
            if len(det) > 0 and original_shapes is not None:
                if isinstance(original_shapes, torch.Tensor):
                    curr_shape = original_shapes[i].tolist()
                else:
                    curr_shape = original_shapes[i]

                # Scale: tensor.shape[2:] is (H_net, W_net) of the current tensor
                det[:, :4] = scale_boxes(pixel_values.shape[2:], det[:, :4], curr_shape).round()

            # 5. Collect tensors for this image
            if len(det) > 0:
                # Append tensors (move to CPU/Int as required by output spec)
                batch_boxes.append(det[:, :4].int())
                batch_scores.append(det[:, 4])
                batch_labels.append(det[:, 5].int())

                # Map global class ID back to Task ID (-1 when the class is unmapped)
                current_task_ids = [
                    self.config.global_cls_to_task_id_map.get(int(c.item()), -1) for c in det[:, 5]
                ]
                batch_tasks.append(torch.tensor(current_task_ids, dtype=torch.int32))
            else:
                # Empty tensors if no detection
                batch_boxes.append(torch.zeros((0, 4), dtype=torch.int32))
                batch_scores.append(torch.zeros((0,), dtype=torch.float32))
                batch_labels.append(torch.zeros((0,), dtype=torch.int32))
                batch_tasks.append(torch.zeros((0,), dtype=torch.int32))

        if not return_dict:
            return (nms_results_per_task_batch, batch_boxes, batch_scores, batch_labels, batch_tasks)

        return CerberusOutput(
            nms_results_per_task_batch=nms_results_per_task_batch,
            boxes=batch_boxes,
            scores=batch_scores,
            labels=batch_labels,
            tasks_ids=batch_tasks,
        )