| from transformers.modeling_outputs import ModelOutput |
| from transformers import PreTrainedModel |
| from transformers.utils import logging |
|
|
| from typing import Optional, Tuple, Union, List, Dict |
| from dataclasses import dataclass |
|
|
| import torch.nn as nn |
| import torch |
| from copy import deepcopy |
|
|
| from cerberusdet.models.cerberus import CerberusDet |
| from cerberusdet.utils.general import ( |
| nms_between_tasks, |
| non_max_suppression, |
| scale_boxes, |
| ) |
|
|
| from .configuration_cerberus import CerberusDetConfig |
|
|
|
|
# Module-level logger obtained from the transformers logging utilities.
logger = logging.get_logger(__name__)
|
|
|
|
@dataclass
class CerberusOutput(ModelOutput):
    """Container for the results returned by the CerberusDet detection model.

    Every field defaults to ``None``; when post-processing is skipped only
    ``nms_results_per_task_batch`` is populated.
    """

    # Per-task NMS results: task name -> list holding one tensor per image.
    nms_results_per_task_batch: Optional[Dict[str, List[torch.Tensor]]] = None
    # One integer tensor of box coordinates (4 columns) per image.
    boxes: Optional[List[torch.IntTensor]] = None
    # One tensor of confidence scores per image.
    scores: Optional[List[torch.FloatTensor]] = None
    # One integer tensor of class labels per image.
    labels: Optional[List[torch.IntTensor]] = None
    # One integer tensor of task ids per image, aligned with `labels`.
    tasks_ids: Optional[List[torch.IntTensor]] = None
|
|
|
|
class CerberusDetPreTrainedModel(PreTrainedModel):
    """Base class that plugs CerberusDet into the transformers `PreTrainedModel` API."""

    config_class = CerberusDetConfig
    base_model_prefix = 'model'
    # NOTE(review): transformers documents `_no_split_modules` as a list of module
    # *class* names; 'model' looks like an attribute name -- confirm this is intended.
    _no_split_modules = ['model']

    def _init_weights(self, module: nn.Module) -> None:
        """Apply per-layer initialization settings to a single module.

        Convolutions are left untouched, 2D batch-norm layers get
        ``eps=1e-3`` / ``momentum=0.03`` and the listed activation layers are
        switched to in-place mode.

        Args:
            module (`nn.Module`): The module to initialize.
        """
        # Dispatch on the exact type of the instance. Comparing the instance
        # itself against the classes (`module is nn.Conv2d`) is never true, which
        # would leave every branch below unreachable.
        module_type = type(module)

        if module_type is nn.Conv2d:
            pass
        elif module_type is nn.BatchNorm2d:
            module.eps = 1e-3
            module.momentum = 0.03
        elif module_type in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU):
            module.inplace = True
|
|
|
|
class CerberusDetForObjectDetection(CerberusDetPreTrainedModel):
    """CerberusDet multi-task detector exposed through the transformers model API.

    Wraps a `CerberusDet` network. `forward` runs the network, applies NMS per
    task, merges the tasks into one detection set per image and optionally
    rescales the boxes to the original image sizes.
    """

    def __init__(self, config: CerberusDetConfig):
        """
        Initializes the CerberusDet object detection model based on the provided configuration.

        Args:
            config (CerberusDetConfig): The configuration object containing model parameters.
        """
        super().__init__(config)
        self.config = config

        self.model = CerberusDet(
            task_ids=config.task_ids,
            nc=config.tasks_nc,
            cfg=config.cfg,
            ch=3,
            verbose=config.verbose,
            stride=config.strides,
            backbone_verbose=False,
        )
        self.model.names = config.names
        self.cerber_schedule = config.cfg["cerber"]
        # A deep copy is handed over so the schedule kept on this wrapper stays
        # independent of whatever `sequential_split` does with its argument.
        self.model.sequential_split(deepcopy(self.cerber_schedule), device=None)

        # NOTE(review): this boolean instance attribute shadows `nn.Module.half()`,
        # so calling `model.half()` on an instance will fail. Kept under the same
        # name for backward compatibility; consider renaming.
        self.half = config.torch_dtype is torch.float16

        self.post_init()

    def _combine_output(self, output_per_task: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Combines task results for a SINGLE image.

        Class ids in column 5 are translated through
        `config.categories_inds_map[task]` and the rows of all tasks are
        concatenated into one CPU tensor with 6 columns.

        Args:
            output_per_task (`Dict[str, torch.Tensor]`): Detections of one image
                keyed by task. The input tensors are never modified.

        Returns:
            `torch.Tensor`: The merged detections; shape `(0, 6)` when every task is empty.
        """
        merged = torch.zeros((0, 6))
        for task, bboxes in output_per_task.items():
            if bboxes.shape[0] == 0:
                continue

            # Always work on a private CPU copy. A plain `.cpu()` returns the very
            # same tensor when the data already lives on the CPU, and the in-place
            # remapping below would then leak into the caller's per-task results.
            bboxes = bboxes.to(device="cpu", copy=True)

            class_map = self.config.categories_inds_map[task]
            bboxes[:, 5].apply_(lambda cat: class_map[int(cat)])
            merged = torch.cat((merged, bboxes), 0)
        return merged

    def forward(
        self,
        pixel_values: torch.Tensor,
        original_shapes: Optional[Union[List[Tuple[int, int]], torch.Tensor]] = None,
        max_det: int = 300,
        agnostic_nms: bool = False,
        conf_thres: Optional[float] = None,
        iou_thres: Optional[float] = None,
        iou_thres_between_tasks: Optional[float] = None,
        return_dict: Optional[bool] = None,
        postprocess: bool = True,
    ) -> Union[Tuple, CerberusOutput]:
        """
        Performs the forward pass of the model and the full post-processing pipeline, including
        multi-task inference, Non-Maximum Suppression (NMS), and coordinate rescaling.

        Args:
            pixel_values (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
                Pixel values.
            original_shapes (`List[Tuple[int, int]]` or `torch.Tensor`, *optional*):
                The original (height, width) of each image in the batch. If provided, bounding box
                coordinates will be rescaled from the model input size back to these original dimensions.
            max_det (`int`, *optional*, defaults to 300):
                The maximum number of detections for each task to return per image.
            agnostic_nms (`bool`, *optional*, defaults to `False`):
                If `True`, performs class-agnostic NMS (overlapping boxes of different classes can be suppressed).
                Passing `None` explicitly falls back to the value in the config.
            conf_thres (`float`, *optional*):
                The confidence threshold for filtering detections. If `None`, defaults to the value in the config.
            iou_thres (`float`, *optional*):
                The IoU threshold for NMS within a single task (head). If `None`, defaults to the value in the config.
            iou_thres_between_tasks (`float`, *optional*):
                The IoU threshold for NMS applied to detections overlapping between different tasks.
                If `None`, defaults to the value in the config.
            return_dict (`bool`, *optional*):
                Whether to return a [`~utils.ModelOutput`] object instead of a plain tuple.
            postprocess (`bool`, *optional*, defaults to `True`):
                If `False`, stops after the per-task NMS: only `nms_results_per_task_batch` is filled
                and the remaining outputs are `None`.

        Returns:
            `Union[Tuple, CerberusOutput]`:
                An instance of [`CerberusOutput`] containing the per-task NMS results, detected boxes,
                scores, labels, and task IDs, or the tuple
                `(nms_results_per_task_batch, boxes, scores, labels, tasks_ids)` if `return_dict=False`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Any threshold left as None is taken from the config.
        if agnostic_nms is None:
            agnostic_nms = self.config.agnostic_nms
        if conf_thres is None:
            conf_thres = self.config.conf_thres
        if iou_thres is None:
            iou_thres = self.config.iou_thres
        if iou_thres_between_tasks is None:
            iou_thres_between_tasks = self.config.iou_thres_between_tasks

        all_out = self.model(pixel_values)
        batch_size = pixel_values.shape[0]

        # Stage 1: NMS inside every task; each entry holds one tensor per image.
        nms_results_per_task_batch: Dict[str, List[torch.Tensor]] = {}
        for task, (task_pred, _) in all_out.items():
            nms_results_per_task_batch[task] = non_max_suppression(
                task_pred,
                conf_thres,
                iou_thres,
                agnostic=agnostic_nms,
                max_det=max_det,
            )

        if not postprocess:
            if not return_dict:
                return (nms_results_per_task_batch, None, None, None, None)
            return CerberusOutput(nms_results_per_task_batch=nms_results_per_task_batch)

        # Convert a shapes tensor once instead of once per image.
        if isinstance(original_shapes, torch.Tensor):
            original_shapes = original_shapes.tolist()
        input_hw = pixel_values.shape[2:]

        batch_boxes = []
        batch_scores = []
        batch_labels = []
        batch_tasks = []

        # Stage 2: merge the tasks image by image.
        for i in range(batch_size):
            current_img_task_results = {
                task: preds_list[i].detach()
                for task, preds_list in nms_results_per_task_batch.items()
            }

            det = self._combine_output(current_img_task_results)
            det = nms_between_tasks(det, self.config.categories_inds_map, iou_thres=iou_thres_between_tasks)

            if len(det) == 0:
                # Keep the outputs aligned with the batch: empty placeholders.
                batch_boxes.append(torch.zeros((0, 4), dtype=torch.int32))
                batch_scores.append(torch.zeros((0,), dtype=torch.float32))
                batch_labels.append(torch.zeros((0,), dtype=torch.int32))
                batch_tasks.append(torch.zeros((0,), dtype=torch.int32))
                continue

            if original_shapes is not None:
                # Map boxes from the network input size back to the original image.
                det[:, :4] = scale_boxes(input_hw, det[:, :4], original_shapes[i]).round()

            batch_boxes.append(det[:, :4].int())
            batch_scores.append(det[:, 4])
            batch_labels.append(det[:, 5].int())

            # Task id of every detection; -1 marks a class missing from the map.
            cls_to_task = self.config.global_cls_to_task_id_map
            current_task_ids = [cls_to_task.get(int(cls_id), -1) for cls_id in det[:, 5].tolist()]
            batch_tasks.append(torch.tensor(current_task_ids, dtype=torch.int32))

        if not return_dict:
            return (nms_results_per_task_batch, batch_boxes, batch_scores, batch_labels, batch_tasks)

        return CerberusOutput(
            nms_results_per_task_batch=nms_results_per_task_batch,
            boxes=batch_boxes,
            scores=batch_scores,
            labels=batch_labels,
            tasks_ids=batch_tasks,
        )
|
|