cerberusdet-yolov8x-voc-o365-full / configuration_cerberus.py
iitolstykh's picture
Upload 8 files
56b6e17 verified
Raw
History Blame Contribute Delete
3.01 kB
"""A HuggingFace-style model configuration."""
from typing import Any, Dict, List, Optional, Tuple

from transformers import PretrainedConfig
class CerberusDetConfig(PretrainedConfig):
    """Configuration for a multi-task CerberusDet detector.

    Stores the model description (per-task class names and counts, model
    ``cfg``, schedule, strides) and inference settings. It also derives lookup
    tables that flatten the per-task class lists into one global class-id space.

    Derived attributes (recomputed on every construction):
        task_ids: task names, in the iteration order of ``tasks_names``.
        names: the same object as ``tasks_names`` (``{}`` when none given).
        categories_inds_map: ``{task_name: {local_id: global_id}}``.
        all_class_names: all tasks' class names concatenated in task order,
            so ``all_class_names[global_id]`` is that class's name.
        global_cls_to_task_id_map: ``{global_id: index into task_ids}``.
        strides / stride: the stride list as given, and its maximum.
    """

    model_type = 'cerberus_v8x'

    def __init__(
        self,
        tasks_names: Optional[Dict[str, List[str]]] = None,
        tasks_nc: Optional[List[int]] = None,
        cfg: Optional[dict] = None,
        cerber_schedule: Any = None,
        stride: Optional[List[int]] = None,
        config_name: Optional[str] = None,
        agnostic_nms: bool = False,
        conf_thres: float = 0.3,
        iou_thres: float = 0.45,
        iou_thres_between_tasks: float = 0.8,
        **kwargs: Any,
    ):
        """Build the configuration.

        Args:
            tasks_names: mapping ``task_name -> list of class names``. The
                dict's iteration order defines task indices and global ids.
            tasks_nc: per-task class counts, stored as given (not checked
                against ``tasks_names``).
            cfg: model definition, stored as given.
            cerber_schedule: stored as given.
            stride: list of strides; ``self.stride`` becomes its maximum.
                A scalar is also accepted and wrapped into a one-element list.
            config_name: stored as given.
            agnostic_nms, conf_thres, iou_thres, iou_thres_between_tasks:
                inference settings; this class only stores them.
            **kwargs: forwarded to ``PretrainedConfig.__init__``.
        """
        # Both ``strides`` (the list) and ``stride`` (its max, a scalar) are
        # instance attributes, so a saved config contains both. On reload they
        # come back as ``stride=<scalar>`` plus ``strides=[...]`` in kwargs.
        # Recover the list instead of later calling ``max`` on a scalar.
        saved_strides = kwargs.pop("strides", None)
        if saved_strides is not None and not isinstance(stride, (list, tuple)):
            stride = saved_strides
        elif isinstance(stride, int):
            stride = [stride]

        super().__init__(**kwargs)

        # model configuration
        self.tasks_nc = tasks_nc
        self.tasks_names = tasks_names
        self.cfg = cfg
        self.cerber_schedule = cerber_schedule
        self.config_name = config_name
        self.strides = stride
        # ``default=None`` keeps an empty stride list from raising ValueError.
        self.stride = (
            max(self.strides, default=None) if self.strides is not None else None
        )

        # model inference
        self.agnostic_nms = agnostic_nms
        self.conf_thres = conf_thres
        self.iou_thres = iou_thres
        self.iou_thres_between_tasks = iou_thres_between_tasks

        # additional (derived lookup tables)
        if tasks_names:
            self.task_ids = list(tasks_names)
            self.names: Dict[str, List[str]] = self.tasks_names
            # 1. Task name -> {local class id -> global class id}, plus the
            #    flat list of all class names.
            self.categories_inds_map, self.all_class_names = self._get_categories_map(self.names)
            # 2. Global class id -> task index. Every task appears in
            #    categories_inds_map (built from the same dict), so no
            #    membership check is needed.
            self.global_cls_to_task_id_map = {
                int(global_id): task_idx
                for task_idx, task_name in enumerate(self.task_ids)
                for global_id in self.categories_inds_map[task_name].values()
            }
        else:
            self.global_cls_to_task_id_map = {}
            self.names = {}
            self.task_ids = []
            self.all_class_names = []
            self.categories_inds_map = {}

    def _get_categories_map(
        self, class_names: Dict[str, List[str]]
    ) -> Tuple[Dict[str, Dict[int, int]], List[str]]:
        """Assign consecutive global class ids across all tasks.

        Tasks are processed in the iteration order of ``class_names``. Each
        task's classes take the next ``len(task_categories)`` global ids.

        Returns:
            ``(categories_inds_map, all_class_names)``, where
            ``categories_inds_map[task][local_id] == global_id`` and
            ``all_class_names[global_id]`` is that class's name.
        """
        categories_inds_map: Dict[str, Dict[int, int]] = {}
        all_class_names: List[str] = []
        for task_name, task_categories in class_names.items():
            # The next free global id equals the number of classes seen so far.
            # Unlike reading the previous task's last id, this also works when
            # a task has no classes.
            offset = len(all_class_names)
            categories_inds_map[task_name] = {
                local_id: offset + local_id for local_id in range(len(task_categories))
            }
            all_class_names.extend(task_categories)
        return categories_inds_map, all_class_names