| """A HuggingFace-style model configuration.""" |
from typing import Any, Dict, List, Optional, Tuple

from transformers import PretrainedConfig
|
|
|
|
class CerberusDetConfig(PretrainedConfig):
    """Configuration for the multi-task CerberusDet detector (``model_type='cerberus_v8x'``).

    Besides storing the constructor arguments, ``__init__`` derives a flat,
    global class-index space from ``tasks_names``. The classes of each task
    are numbered consecutively, in the mapping's iteration order, continuing
    from where the previous task left off.

    Args:
        tasks_names: Mapping ``task name -> list of class names``. Its key
            order defines the task indices used in ``task_ids`` and
            ``global_cls_to_task_id_map``. When falsy, all derived fields
            are set to empty containers.
        tasks_nc: Stored as-is. Presumably the number of classes per task;
            it is not validated against ``tasks_names`` here.
        cfg: Stored as-is. Presumably the model architecture description.
        cerber_schedule: Stored as-is; its format is not interpreted here.
        stride: Sequence of strides, stored as ``self.strides``. ``self.stride``
            is set to their maximum. A bare ``int`` is also accepted, since that
            is what ``to_dict()`` emits for ``stride``. In that case the full
            list is taken from a ``strides`` keyword argument if one is given,
            otherwise it becomes ``[stride]``.
        config_name: Stored as-is.
        agnostic_nms, conf_thres, iou_thres, iou_thres_between_tasks:
            Stored as-is. Judging by their names, these are presumably
            post-processing (NMS) settings.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``. The one
            exception is ``strides``, which is consumed here (see ``stride``).
    """

    model_type = 'cerberus_v8x'

    def __init__(
        self,
        tasks_names: Optional[Dict[str, List[str]]] = None,
        tasks_nc: Optional[List[int]] = None,
        cfg: Optional[dict] = None,
        cerber_schedule: Any = None,
        stride: Optional[List[int]] = None,
        config_name: Optional[str] = None,
        agnostic_nms: bool = False,
        conf_thres: float = 0.3,
        iou_thres: float = 0.45,
        iou_thres_between_tasks: float = 0.8,
        **kwargs: Any,
    ):
        # A serialized config carries both ``strides`` (the list) and ``stride``
        # (its max, an int). Take ``strides`` out of kwargs so it is handled
        # here, and use it to restore the list when ``stride`` is a bare int.
        saved_strides = kwargs.pop('strides', None)
        super().__init__(**kwargs)

        self.tasks_nc = tasks_nc
        self.tasks_names = tasks_names
        self.cfg = cfg
        self.cerber_schedule = cerber_schedule
        self.config_name = config_name

        if stride is None or isinstance(stride, int):
            if saved_strides is not None:
                stride = saved_strides
            elif stride is not None:
                stride = [stride]
        self.strides = stride
        # ``default=None`` keeps an empty sequence from raising in ``max``.
        self.stride = max(self.strides, default=None) if self.strides is not None else None

        self.agnostic_nms = agnostic_nms
        self.conf_thres = conf_thres
        self.iou_thres = iou_thres
        self.iou_thres_between_tasks = iou_thres_between_tasks

        if tasks_names:
            # The derived fields below are always recomputed from ``tasks_names``.
            # Any same-named values that arrived through ``kwargs`` (e.g. from a
            # saved config, where JSON turns int dict keys into strings) are
            # overwritten here.
            self.task_ids = list(tasks_names)
            self.names: Dict[str, List[str]] = self.tasks_names

            self.categories_inds_map, self.all_class_names = self._get_categories_map(self.names)

            # Inverse lookup: global class id -> index of the task that owns it.
            self.global_cls_to_task_id_map = {}
            for task_idx, task_name in enumerate(self.task_ids):
                for global_id in self.categories_inds_map.get(task_name, {}).values():
                    self.global_cls_to_task_id_map[int(global_id)] = task_idx
        else:
            self.global_cls_to_task_id_map = {}
            self.names = {}
            self.task_ids = []
            self.all_class_names = []
            self.categories_inds_map = {}

    def _get_categories_map(
        self, class_names: Dict[str, List[str]]
    ) -> Tuple[Dict[str, Dict[int, int]], List[str]]:
        """Assign consecutive global class ids across all tasks.

        Args:
            class_names: Mapping ``task name -> list of class names``.

        Returns:
            A tuple ``(categories_inds_map, all_class_names)``:

            - ``categories_inds_map[task][local_id]`` is the global id of the
              task's ``local_id``-th class.
            - ``all_class_names`` is the concatenation of every task's class
              names, so ``all_class_names[global_id]`` is that class's name.

        Example:
            ``{'a': ['x', 'y'], 'b': ['z']}`` returns
            ``({'a': {0: 0, 1: 1}, 'b': {0: 2}}, ['x', 'y', 'z'])``.
        """
        categories_inds_map: Dict[str, Dict[int, int]] = {}
        all_class_names: List[str] = []

        for task_name, task_categories in class_names.items():
            # The offset is the number of classes in earlier tasks. Unlike
            # reading the previous task's last id, this also works when a
            # task has no classes.
            offset = len(all_class_names)
            categories_inds_map[task_name] = {
                local_id: offset + local_id for local_id in range(len(task_categories))
            }
            all_class_names.extend(task_categories)

        return categories_inds_map, all_class_names
|
|