Spaces:
Running
Running
| import json | |
| import os | |
| from typing import Any, Dict | |
| import pandas as pd | |
| from huggingface_hub import HfApi, hf_hub_download, metadata_load | |
| from .dataset_handler import DEPRECATED_VIDORE_2_DATASETS_KEYWORDS, DEPRECATED_VIDORE_DATASETS_KEYWORDS, deprecated_get_datasets_nickname | |
# Hub organizations whose repositories are skipped when collecting results.
BLOCKLIST = ["impactframes"]
class DeprecatedModelHandler:
    """Collect ViDoRe result files from the Hugging Face Hub and tabulate them.

    Results live in ``self.model_infos``, a dict keyed by sanitized model name
    whose values have the shape ``{"meta": <model card metadata>, "results":
    {<dataset>: {<metric>: <value>}}}``.
    """

    def __init__(self, model_infos_path="model_infos.json"):
        """Create the Hub client and load previously saved model infos, if any.

        Args:
            model_infos_path: Path of the JSON file used to persist
                ``model_infos``.
        """
        self.api = HfApi()
        self.model_infos_path = model_infos_path
        self.model_infos = self._load_model_infos()

    def _load_model_infos(self) -> Dict:
        """Return the infos stored at ``model_infos_path``, or ``{}`` if the file is absent."""
        if not os.path.exists(self.model_infos_path):
            return {}
        with open(self.model_infos_path, encoding="utf-8") as f:
            return json.load(f)

    def _save_model_infos(self) -> None:
        """Write ``model_infos`` as JSON to ``model_infos_path``."""
        with open(self.model_infos_path, "w", encoding="utf-8") as f:
            json.dump(self.model_infos, f)

    def _are_results_in_new_vidore_format(self, results: Dict[str, Any]) -> bool:
        """Tell whether ``results`` has both the ``metadata`` and ``metrics`` top-level keys."""
        return "metadata" in results and "metrics" in results

    def _is_baseline_repo(self, repo_id: str) -> bool:
        """Tell whether ``repo_id`` is the shared baseline results repository."""
        return repo_id == "vidore/baseline-results"

    @staticmethod
    def _benchmark_keywords(benchmark_version):
        """Return the dataset keywords of benchmark version 1, or of version 2 for any other value."""
        if benchmark_version == 1:
            return DEPRECATED_VIDORE_DATASETS_KEYWORDS
        return DEPRECATED_VIDORE_2_DATASETS_KEYWORDS

    def sanitize_model_name(self, model_name):
        """Return ``model_name`` with ``/`` replaced by ``_`` and ``.`` by ``-thisisapoint-``."""
        return model_name.replace("/", "_").replace(".", "-thisisapoint-")

    def fuze_model_infos(self, model_name, results):
        """Add to an existing model the datasets of ``results`` it does not have yet.

        Datasets already recorded for ``model_name`` are left untouched.

        Args:
            model_name: Key of an existing entry of ``model_infos``.
            results: Mapping of dataset name to metrics.

        Raises:
            KeyError: If ``model_name`` has no entry in ``model_infos``.
        """
        known_results = self.model_infos[model_name]["results"]
        for dataset, metrics in results.items():
            known_results.setdefault(dataset, metrics)

    def _store_results(self, repo_id, model_name, meta, results):
        """Record ``results`` under ``model_name`` without letting the baseline repo override.

        When the results come from the baseline repo and the model already has
        an entry, only the missing datasets are merged in and the existing
        metadata is kept. In every other case the entry is (re)written.
        """
        if self._is_baseline_repo(repo_id) and model_name in self.model_infos:
            self.fuze_model_infos(model_name, results)
        else:
            self.model_infos[model_name] = {"meta": meta, "results": results}

    def get_vidore_data(self, metric="ndcg_at_5"):
        """Populate ``model_infos`` from the result files of Hub repos matching the ``vidore`` filter.

        Repos owned by an organization in ``BLOCKLIST`` are skipped. For every
        ``results.json`` or ``*_metrics.json`` file, the results are stored
        under a sanitized model name: the repo id for ``results.json``, the
        file name without its ``_metrics.json`` suffix otherwise. Failures are
        printed and the offending repo or file is skipped.

        Args:
            metric: Unused; kept so existing callers keep working.
        """
        repositories = [model.modelId for model in self.api.list_models(filter="vidore")]  # type: ignore

        # False sorts before True: non-baseline repos are processed first so
        # that their results take priority over the baseline ones.
        repositories.sort(key=self._is_baseline_repo)

        for repo_id in repositories:
            org_name = repo_id.split("/")[0]
            if org_name in BLOCKLIST:
                continue

            files = [
                f
                for f in self.api.list_repo_files(repo_id)
                if f.endswith("_metrics.json") or f == "results.json"
            ]
            if not files:
                continue

            # The model card is shared by every result file of the repo, so it
            # is fetched once; a missing or broken card only skips this repo.
            try:
                meta = metadata_load(hf_hub_download(repo_id, filename="README.md"))
            except Exception as e:
                print(f"Error loading {repo_id} - {e}")
                continue

            for file in files:
                if file == "results.json":
                    model_name = self.sanitize_model_name(repo_id)
                else:
                    model_name = self.sanitize_model_name(file.split("_metrics.json")[0])

                try:
                    result_path = hf_hub_download(repo_id, filename=file)
                    with open(result_path) as f:
                        results = json.load(f)

                    # The new format nests the per-dataset metrics under "metrics".
                    if self._are_results_in_new_vidore_format(results):
                        results = results["metrics"]

                    self._store_results(repo_id, model_name, meta, results)
                except Exception as e:
                    print(f"Error loading {model_name} - {e}")
                    continue

    def filter_models_by_benchmark(self, benchmark_version=1):
        """Return the models having at least one dataset of the given benchmark.

        A dataset belongs to the benchmark when its name contains one of the
        benchmark keywords.

        Args:
            benchmark_version: ``1`` selects the first keyword set; any other
                value selects the second one.

        Returns:
            The subset of ``model_infos`` restricted to the matching models.
        """
        keywords = self._benchmark_keywords(benchmark_version)
        return {
            model: info
            for model, info in self.model_infos.items()
            if any(keyword in dataset for dataset in info["results"] for keyword in keywords)
        }

    def render_df(self, metric="ndcg_at_5", benchmark_version=1):
        """Build a table of one metric with one row per model and one column per dataset nickname.

        Only the datasets of the requested benchmark are kept.

        Args:
            metric: Name of the metric read from each dataset's results.
            benchmark_version: Benchmark selector, as in
                ``filter_models_by_benchmark``.

        Returns:
            A ``pandas.DataFrame`` indexed by model name, or an empty one when
            no model matches the benchmark.

        Raises:
            KeyError: If a kept dataset has no value for ``metric``.
        """
        filtered_model_infos = self.filter_models_by_benchmark(benchmark_version)
        if not filtered_model_infos:
            return pd.DataFrame()

        keywords = self._benchmark_keywords(benchmark_version)
        model_res = {}
        for model, info in filtered_model_infos.items():
            model_res[model] = {
                deprecated_get_datasets_nickname(dataset): scores[metric]
                for dataset, scores in info["results"].items()
                if any(keyword in dataset for keyword in keywords)
            }
        # The outer keys become columns, hence the transpose to get models as rows.
        return pd.DataFrame(model_res).T