RefCheck / src /space_service.py
voidful's picture
Add RefCheck Gradio Space
11a28db verified
Raw
History Blame
12.4 kB
"""
Non-interactive RefCheck workflow for Hugging Face Spaces.
"""
from __future__ import annotations
import tempfile
from dataclasses import dataclass, field
from functools import lru_cache
from pathlib import Path
from typing import Any
from concurrent.futures import ThreadPoolExecutor, as_completed
from main import (
apply_fix,
apply_local_fix,
get_default_workflow,
validate_entry,
)
from src.comparator import EntryReport, MetadataComparator
from src.fetcher import (
ArxivFetcher,
CrossRefFetcher,
DBLPFetcher,
OpenAlexFetcher,
ScholarFetcher,
SemanticScholarFetcher,
)
from src.local_db import LocalConferenceDB
from src.parser import BibEntry, BibParser
from src.sanitizer import BibSanitizer, SanitizeFix
@dataclass
class RefCheckOptions:
    """User-selectable settings for one unattended RefCheck run.

    Every field has a default, so ``RefCheckOptions()`` describes the
    standard run.
    """

    # When true, entries for which no source produced matching metadata are
    # dropped from the output; when false they are kept and listed for review.
    remove_unverified: bool = True

    # Copied onto the ``enabled`` flag of the workflow step named
    # "google_scholar".
    enable_google_scholar: bool = False

    # Requested thread-pool size for metadata lookups.
    max_workers: int = 4
@dataclass
class RefCheckResult:
    """Summary counters, per-entry details and output paths of one run.

    All fields default to an "empty" value so the result can be created
    up front and filled in as the run progresses.
    """

    # --- Counters -----------------------------------------------------
    # Number of entries parsed from the uploaded file.
    total_input: int = 0
    # Number of entries written to the fixed file.
    total_output: int = 0
    # Post-fix verification: comparisons flagged as a match.
    verified: int = 0
    # Post-fix verification: comparisons flagged as having issues.
    issues: int = 0
    # Post-fix verification: comparisons that neither match nor have issues.
    not_found: int = 0

    # --- Per-entry details ---------------------------------------------
    # Citation key -> human-readable descriptions of the changes applied.
    fixed_details: dict[str, list[str]] = field(default_factory=dict)
    # (citation key, title, reason) for each entry dropped from the output.
    removed_details: list[tuple[str, str, str]] = field(default_factory=list)
    # One dict per entry needing manual review, with the keys
    # "key", "title", "reason" and "candidates".
    review_details: list[dict[str, Any]] = field(default_factory=list)
    # Duplicate groups as reported by the sanitizer; the report renders each
    # item as title -> citation keys.
    duplicate_details: dict[str, list[str]] = field(default_factory=dict)
    # Raw sanitizer output, kept alongside the flattened ``fixed_details``.
    sanitize_fixes: dict[str, list[SanitizeFix]] = field(default_factory=dict)

    # --- Local conference database --------------------------------------
    # Entries resolved from the local database.
    local_matches: int = 0
    # Whether the local database was available for this run.
    local_db_loaded: bool = False

    # --- Artifacts ------------------------------------------------------
    # Path of the written .bib file.
    fixed_bib_path: str = ""
    # Path of the written Markdown report.
    report_path: str = ""
    # The Markdown report text itself.
    report_markdown: str = ""
# Confidence above which a fetched record is applied without manual review.
_AUTO_FIX_CONFIDENCE = 0.85
# Reason recorded for entries that have a lookup result but no match and no
# alternative candidates.
_NO_MATCH_REASON = "No matching metadata found in any source"


def _decide_action(
    entry: BibEntry,
    best_result: Any,
    candidates: list[Any],
) -> tuple[str, Any, list[Any]]:
    """Choose the unattended action ("keep", "fix", "review", "remove") for one entry."""
    if not best_result:
        # No lookup result at all: leave the entry untouched.
        return "keep", None, []
    if getattr(entry, "_force_api_lookup", False) and best_result.fetched_data:
        return "fix", best_result, candidates
    if best_result.confidence > _AUTO_FIX_CONFIDENCE and best_result.fetched_data:
        return "fix", best_result, candidates
    if best_result.is_match:
        return "keep", best_result, candidates
    if candidates:
        return "review", best_result, candidates
    return "remove", best_result, candidates


def _apply_actions(
    entries: list[BibEntry],
    actions: dict[int, tuple[str, Any, list[Any]]],
    options: RefCheckOptions,
    result: RefCheckResult,
) -> list[BibEntry]:
    """Apply the chosen actions in file order and return the entries to write out."""
    kept: list[BibEntry] = []
    for entry in entries:
        # Entries with no recorded action (e.g. resolved by the local
        # database) are kept as they are.
        action, best_result, candidates = actions.get(id(entry), ("keep", None, []))
        if action == "fix":
            changes = apply_fix(entry, best_result.fetched_data, all_candidates=candidates)
            if changes:
                result.fixed_details.setdefault(entry.key, []).extend(changes)
        elif action == "review":
            result.review_details.append(_review_payload(entry, best_result, candidates))
        elif action == "remove":
            if options.remove_unverified:
                result.removed_details.append((entry.key, entry.title, _NO_MATCH_REASON))
                continue  # dropped from the output
            result.review_details.append(
                {
                    "key": entry.key,
                    "title": entry.title,
                    "reason": _NO_MATCH_REASON,
                    "candidates": [],
                }
            )
        kept.append(entry)
    return kept


def run_refcheck_file(file_path: str | Path, options: RefCheckOptions | None = None) -> RefCheckResult:
    """Validate and fix an uploaded BibTeX file without interactive prompts.

    The run parses the file, sanitizes the entries, resolves what it can from
    the local conference database, looks the remaining entries up through the
    configured fetchers, applies an action to each entry, writes the fixed
    ``.bib`` file, re-parses that file for a verification pass and finally
    writes a Markdown report.

    Args:
        file_path: Location of the BibTeX file to process.
        options: Run settings; defaults to ``RefCheckOptions()`` when omitted.

    Returns:
        A ``RefCheckResult`` holding the counters, per-entry details, the
        report text and the paths of the written files.
    """
    options = options or RefCheckOptions()
    source_path = Path(file_path)
    parser = BibParser()
    entries = parser.parse_file(str(source_path))
    result = RefCheckResult(total_input=len(entries))

    if not entries:
        # Still produce both artifacts so the caller always has files to offer.
        result.report_markdown = "## RefCheck Report\n\nNo BibTeX entries were found."
        result.report_path = _write_report(result.report_markdown)
        result.fixed_bib_path = _write_bib(parser, [], source_path.stem)
        return result

    sanitizer = BibSanitizer()
    result.sanitize_fixes = sanitizer.sanitize_all(entries)
    _record_sanitize_fixes(result.fixed_details, result.sanitize_fixes)
    result.duplicate_details = sanitizer.find_duplicates(entries)

    result.local_db_loaded, api_entries, result.local_matches = _apply_local_db(entries, result.fixed_details)

    fetchers = _build_fetchers()
    workflow = get_default_workflow()
    for step in workflow.steps:
        if step.name == "google_scholar":
            step.enabled = options.enable_google_scholar
    comparator = MetadataComparator()

    analysis = _analyze_entries(api_entries, workflow, fetchers, comparator, options.max_workers)
    # Key by object identity rather than citation key: two entries that share
    # a key must not overwrite each other's decision, and an entry that was
    # not analysed must not inherit the action of a namesake.
    actions: dict[int, tuple[str, Any, list[Any]]] = {
        id(entry): _decide_action(entry, best_result, candidates)
        for entry, best_result, candidates in analysis
    }

    updated_entries = _apply_actions(entries, actions, options, result)
    result.total_output = len(updated_entries)
    fixed_path = _write_bib(parser, updated_entries, source_path.stem)
    result.fixed_bib_path = fixed_path

    # Verification pass over the file exactly as it was written.
    verified_entries = parser.parse_file(fixed_path)
    verification_reports = _verify_entries(
        verified_entries,
        workflow,
        fetchers,
        comparator,
        options.max_workers,
    )
    comparisons = [report.comparison for report in verification_reports if report.comparison]
    result.verified = sum(1 for comparison in comparisons if comparison.is_match)
    result.issues = sum(1 for comparison in comparisons if comparison.has_issues)
    result.not_found = sum(
        1
        for comparison in comparisons
        if not comparison.is_match and not comparison.has_issues
    )

    result.report_markdown = _build_report(result, verification_reports)
    result.report_path = _write_report(result.report_markdown)
    return result
def _build_fetchers() -> dict[str, Any]:
    """Create one fresh fetcher instance per metadata source, keyed by source name."""
    factories = (
        ("arxiv", ArxivFetcher),
        ("crossref", CrossRefFetcher),
        ("scholar", ScholarFetcher),
        ("semantic", SemanticScholarFetcher),
        ("openalex", OpenAlexFetcher),
        ("dblp", DBLPFetcher),
    )
    # Instantiated in the listed order; the mapping keeps that order.
    return {source: factory() for source, factory in factories}
def _analyze_entries(
entries: list[BibEntry],
workflow: Any,
fetchers: dict[str, Any],
comparator: MetadataComparator,
max_workers: int,
) -> list[tuple[BibEntry, Any, list[Any]]]:
if not entries:
return []
analysis: list[tuple[BibEntry, Any, list[Any]]] = []
worker_count = min(max(1, max_workers), len(entries))
with ThreadPoolExecutor(max_workers=worker_count) as executor:
futures = {
executor.submit(validate_entry, entry, workflow, fetchers, comparator): entry
for entry in entries
}
for future in as_completed(futures):
entry = futures[future]
try:
best_result, candidates = future.result()
except Exception:
best_result, candidates = None, []
analysis.append((entry, best_result, candidates))
return analysis
def _verify_entries(
    entries: list[BibEntry],
    workflow: Any,
    fetchers: dict[str, Any],
    comparator: MetadataComparator,
    max_workers: int,
) -> list[EntryReport]:
    """Look every entry up again and wrap each outcome in an ``EntryReport``.

    The best lookup result becomes the report's ``comparison``; the list of
    alternative candidates is discarded. Reports follow the order in which
    ``_analyze_entries`` returns its results.
    """
    analysed = _analyze_entries(entries, workflow, fetchers, comparator, max_workers)
    return [
        EntryReport(entry=entry, comparison=comparison)
        for entry, comparison, _candidates in analysed
    ]
def _record_sanitize_fixes(
fixed_details: dict[str, list[str]],
sanitize_fixes: dict[str, list[SanitizeFix]],
) -> None:
for key, fixes in sanitize_fixes.items():
fixed_details.setdefault(key, [])
fixed_details[key].extend(fix.description for fix in fixes)
def _apply_local_db(
    entries: list[BibEntry],
    fixed_details: dict[str, list[str]],
) -> tuple[bool, list[BibEntry], int]:
    """Resolve entries against the local conference database by title.

    Entries found in the database are corrected in place through
    ``apply_local_fix`` and any resulting changes are appended to
    ``fixed_details``.

    Returns:
        ``(loaded, remaining, matches)``: whether the database was loaded,
        the entries that were not found in it, and how many entries were
        found. When the database is not loaded the input list itself is
        returned as ``remaining`` and ``matches`` is 0.
    """
    database = _load_local_db()
    if not database.is_loaded:
        return False, entries, 0

    unmatched: list[BibEntry] = []
    matched = 0
    for entry in entries:
        official = database.lookup(entry.title)
        if official:
            changes = apply_local_fix(entry, official)
            # A hit counts as a match even when nothing had to change.
            matched += 1
            if changes:
                fixed_details.setdefault(entry.key, []).extend(changes)
        else:
            unmatched.append(entry)
    return True, unmatched, matched
@lru_cache(maxsize=1)
def _load_local_db() -> LocalConferenceDB:
    """Return the shared local conference database, building it on first call.

    The result is memoized, so later calls reuse the same object. The object
    is returned (and therefore cached) regardless of what ``load()`` did;
    callers check ``is_loaded`` themselves.
    """
    database = LocalConferenceDB()
    database.load()
    return database
def _review_payload(entry: BibEntry, best_result: Any, candidates: list[Any]) -> dict[str, Any]:
return {
"key": entry.key,
"title": entry.title,
"reason": "; ".join(best_result.issues) if best_result and best_result.issues else "Ambiguous match",
"candidates": [
{
"source": candidate.source,
"confidence": candidate.confidence,
"title": getattr(candidate.fetched_data, "title", ""),
"year": getattr(candidate.fetched_data, "year", ""),
"doi": getattr(candidate.fetched_data, "doi", ""),
}
for candidate in candidates[:5]
],
}
def _write_bib(parser: BibParser, entries: list[BibEntry], original_stem: str) -> str:
out_dir = Path(tempfile.mkdtemp(prefix="refcheck_"))
out_path = out_dir / f"{original_stem or 'references'}_refcheck_fixed.bib"
parser.save_entries(str(out_path), entries)
return str(out_path)
def _write_report(markdown: str) -> str:
out_dir = Path(tempfile.mkdtemp(prefix="refcheck_report_"))
out_path = out_dir / "refcheck_report.md"
out_path.write_text(markdown, encoding="utf-8")
return str(out_path)
def _build_report(result: RefCheckResult, reports: list[EntryReport]) -> str:
lines = [
"## RefCheck Report",
"",
"### Summary",
"",
f"- Input entries: {result.total_input}",
f"- Output entries: {result.total_output}",
f"- Verified after fix: {result.verified}",
f"- Remaining issues: {result.issues}",
f"- Not found after fix: {result.not_found}",
f"- Local DB loaded: {'yes' if result.local_db_loaded else 'no'}",
f"- Local DB matches: {result.local_matches}",
"",
]
if result.removed_details:
lines.extend(["### Removed", ""])
for key, title, reason in result.removed_details:
lines.append(f"- `{key}`: {title} ({reason})")
lines.append("")
if result.fixed_details:
lines.extend(["### Fixed", ""])
for key, changes in sorted(result.fixed_details.items()):
lines.append(f"- `{key}`")
for change in changes:
lines.append(f" - {change}")
lines.append("")
if result.duplicate_details:
lines.extend(["### Duplicate Titles", ""])
for title, keys in result.duplicate_details.items():
lines.append(f"- `{', '.join(keys)}`: {title}")
lines.append("")
if result.review_details:
lines.extend(["### Needs Review", ""])
for item in result.review_details:
lines.append(f"- `{item['key']}`: {item['title']}")
lines.append(f" - Reason: {item['reason']}")
for candidate in item["candidates"]:
lines.append(
" - Candidate: "
f"{candidate['source']} "
f"(confidence {candidate['confidence']:.2f}) "
f"{candidate['title']} "
f"{candidate['year']} "
f"{candidate['doi']}".strip()
)
lines.append("")
remaining = [
report
for report in reports
if report.comparison and not report.comparison.is_match
]
if remaining:
lines.extend(["### Verification Issues", ""])
for report in remaining:
comparison = report.comparison
issues = "; ".join(comparison.issues) if comparison.issues else "Not matched"
lines.append(
f"- `{report.entry.key}` via {comparison.source} "
f"(confidence {comparison.confidence:.2f}): {issues}"
)
lines.append("")
return "\n".join(lines).strip() + "\n"