# video-to-colmap-for-tttlrm / video_to_colmap.py
# notaneimu — commit f870955: "Auto-infer keyframes and improve preview"
from __future__ import annotations
import json
import math
import re
import shutil
import struct
import subprocess
import time
import uuid
import zipfile
from dataclasses import dataclass
from fractions import Fraction
from pathlib import Path
from typing import Final
import cv2
import numpy as np
# Directory containing this module; the work/ and outputs/ folders live next to it.
APP_DIR: Final[Path] = Path(__file__).resolve().parent
# Scratch space for per-job intermediates (normalized video, candidate frames, COLMAP scene).
WORK_DIR: Final[Path] = APP_DIR / "work"
# Final artifacts returned to callers (archive, report copy, contact sheet).
OUTPUTS_DIR: Final[Path] = APP_DIR / "outputs"
# (width, height) passed to cv2.resize for the grayscale motion thumbnails.
THUMB_SIZE: Final[tuple[int, int]] = (96, 96)
# JPEG quality used when writing candidate frames.
JPEG_QUALITY: Final[int] = 95
# Font used for contact-sheet labels.
FONT = cv2.FONT_HERSHEY_SIMPLEX
# Import-time side effect: make sure both directories exist before any job runs.
WORK_DIR.mkdir(parents=True, exist_ok=True)
OUTPUTS_DIR.mkdir(parents=True, exist_ok=True)
@dataclass(frozen=True)
class ProfileConfig:
    """Tuning knobs for one frame-sampling profile (see PROFILES)."""

    # Candidates extracted per requested keyframe (the total is capped at 240 in extract_candidates).
    candidate_multiplier: int
    # Bhattacharyya histogram distance at or above which a new segment starts.
    cut_threshold: float
    # Percentile of in-segment blur scores below which a frame is penalized as blurry.
    min_blur_percentile: float
    # Upper bound for COLMAP's --SequentialMatching.overlap.
    sequential_overlap: int
    # Segments with fewer candidates than this receive a score penalty in choose_best_segment.
    min_segment_frames: int
# Named sampling profiles, selected by profile_key. Relative to "balanced",
# "dense" extracts more candidates, splits segments on smaller histogram
# changes and matches more neighbours; "sparse" does the opposite.
PROFILES: Final[dict[str, ProfileConfig]] = {
    "balanced": ProfileConfig(
        candidate_multiplier=6,
        cut_threshold=0.42,
        min_blur_percentile=35.0,
        sequential_overlap=8,
        min_segment_frames=14,
    ),
    "dense": ProfileConfig(
        candidate_multiplier=8,
        cut_threshold=0.38,
        min_blur_percentile=30.0,
        sequential_overlap=12,
        min_segment_frames=18,
    ),
    "sparse": ProfileConfig(
        candidate_multiplier=5,
        cut_threshold=0.48,
        min_blur_percentile=40.0,
        sequential_overlap=6,
        min_segment_frames=12,
    ),
}
# Keyframe budgets infer_target_frames chooses from, ordered by increasing video duration.
AUTO_TARGET_FRAME_OPTIONS: Final[tuple[int, ...]] = (16, 24, 32, 48)
@dataclass(frozen=True)
class VideoMetadata:
    """Basic stream properties of a video, as read by ffprobe or OpenCV."""

    # Frames per second; both readers substitute 24.0 when none can be determined.
    fps: float
    # Total frame count; the ffprobe reader may leave this at 0 when it cannot be derived.
    frame_count: int
    # Clip length in seconds.
    duration_seconds: float
    # Pixel dimensions as reported by the reader.
    width: int
    height: int
@dataclass(frozen=True)
class FrameCandidate:
    """A sampled video frame saved to disk, together with the signals scored for it."""

    # Position among the extracted candidates (0-based, decode order).
    candidate_index: int
    # Index of the frame in the decoded video stream.
    frame_index: int
    # frame_index / fps.
    timestamp_seconds: float
    # Candidate JPEG written by extract_candidates.
    path: Path
    # Variance of the grayscale Laplacian; treated as sharpness (higher = sharper).
    blur_score: float
    # Mean absolute thumbnail difference from the previous candidate (0.0 for the first).
    motion_score: float
    # Bhattacharyya distance between this and the previous candidate's HSV histograms (0.0 for the first).
    cut_score: float
    # Grayscale thumbnail resized to THUMB_SIZE, float32 divided by 255.
    thumb: np.ndarray
@dataclass(frozen=True)
class ConversionOutputs:
    """Artifacts and summary numbers returned by convert_video_to_colmap_archive."""

    # Zip with images/, sparse/ and report.json under a top-level scene_name folder.
    archive_path: Path
    # Copy of the scene's report.json placed in OUTPUTS_DIR.
    report_path: Path
    # JPEG grid preview of the selected frames.
    contact_sheet_path: Path
    # Slugified video stem plus a millisecond timestamp.
    scene_name: str
    # Number of keyframes exported to COLMAP.
    selected_frames: int
    # Number of images COLMAP registered in the chosen sparse model.
    registered_frames: int
    # Duration of the (normalized) input video in seconds.
    duration_seconds: float
    # Result of quality_label(registered_frames, selected_frames).
    quality_label: str
def infer_target_frames(metadata: VideoMetadata) -> int:
    """Pick a keyframe budget from the clip length.

    Durations up to 6 s, 12 s and 20 s (inclusive) map to the first three
    entries of AUTO_TARGET_FRAME_OPTIONS; anything longer gets the fourth.
    """
    length = metadata.duration_seconds
    for option_position, upper_bound in enumerate((6.0, 12.0, 20.0)):
        if length <= upper_bound:
            return AUTO_TARGET_FRAME_OPTIONS[option_position]
    return AUTO_TARGET_FRAME_OPTIONS[3]
def _now_ms() -> int:
return int(time.time() * 1000)
def _ensure_dir(path: Path) -> Path:
path.mkdir(parents=True, exist_ok=True)
return path
def _unique_dir(parent: Path, prefix: str) -> Path:
    """Create and return ``<parent>/<prefix>-<epoch ms>-<8 random hex chars>``."""
    stamp = _now_ms()
    token = uuid.uuid4().hex[:8]
    fresh = parent / f"{prefix}-{stamp}-{token}"
    fresh.mkdir(parents=True, exist_ok=True)
    return fresh
def _slugify(value: str) -> str:
slug = re.sub(r"[^a-zA-Z0-9]+", "-", value).strip("-").lower()
return slug or "scene"
def _run(cmd: list[str], cwd: Path | None = None) -> None:
result = subprocess.run(
cmd,
cwd=str(cwd) if cwd else None,
text=True,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
check=False,
)
if result.returncode != 0:
raise RuntimeError(
f"Command failed ({result.returncode}): {' '.join(cmd)}\n{result.stdout.strip()}"
)
def _require_binary(binary_name: str) -> None:
if shutil.which(binary_name) is None:
raise RuntimeError(f"Required executable not found: {binary_name}")
def _read_video_metadata_ffprobe(video_path: Path) -> VideoMetadata | None:
if shutil.which("ffprobe") is None:
return None
result = subprocess.run(
[
"ffprobe",
"-v",
"error",
"-print_format",
"json",
"-show_streams",
"-show_format",
str(video_path),
],
text=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
check=False,
)
if result.returncode != 0 or not result.stdout.strip():
return None
try:
payload = json.loads(result.stdout)
except json.JSONDecodeError:
return None
video_stream = next(
(stream for stream in payload.get("streams", []) if stream.get("codec_type") == "video"),
None,
)
if not video_stream:
return None
width = int(video_stream.get("width") or 0)
height = int(video_stream.get("height") or 0)
fps_value = video_stream.get("avg_frame_rate") or video_stream.get("r_frame_rate") or "0/1"
try:
fps = float(Fraction(fps_value))
except (ValueError, ZeroDivisionError):
fps = 0.0
duration_value = video_stream.get("duration") or payload.get("format", {}).get("duration") or 0.0
try:
duration_seconds = float(duration_value)
except (TypeError, ValueError):
duration_seconds = 0.0
frame_count_value = video_stream.get("nb_frames")
try:
frame_count = int(frame_count_value) if frame_count_value is not None else 0
except (TypeError, ValueError):
frame_count = 0
if frame_count <= 0 and fps > 0 and duration_seconds > 0:
frame_count = max(1, int(round(fps * duration_seconds)))
if fps <= 0 and frame_count > 0 and duration_seconds > 0:
fps = frame_count / duration_seconds
if width <= 0 or height <= 0 or duration_seconds <= 0:
return None
if fps <= 0:
fps = 24.0
return VideoMetadata(
fps=fps,
frame_count=frame_count,
duration_seconds=duration_seconds,
width=width,
height=height,
)
def normalize_video_input(video_path: Path, work_dir: Path) -> Path:
    """Re-encode *video_path* to ``work_dir/normalized.mp4`` (H.264, yuv420p, no audio).

    Frame sizes are rounded down to even dimensions, because libx264 cannot
    encode 4:2:0 chroma at odd widths/heights.

    Args:
        video_path: Source video; may be relative to the current directory.
        work_dir: Existing directory that receives the normalized file.

    Returns:
        Absolute path of the normalized video.

    Raises:
        RuntimeError: if ffmpeg is missing or the encode fails.
    """
    _require_binary("ffmpeg")
    # ffmpeg runs with cwd=work_dir, so relative paths must be resolved first or
    # they would be interpreted against work_dir instead of the caller's directory.
    source = Path(video_path).resolve()
    work_dir = Path(work_dir).resolve()
    normalized_path = work_dir / "normalized.mp4"
    _run(
        [
            "ffmpeg",
            "-y",
            "-i",
            str(source),
            "-an",
            "-vf",
            "scale=trunc(iw/2)*2:trunc(ih/2)*2",
            "-movflags",
            "+faststart",
            "-pix_fmt",
            "yuv420p",
            "-c:v",
            "libx264",
            str(normalized_path),
        ],
        cwd=work_dir,
    )
    return normalized_path
def read_video_metadata(video_path: Path) -> VideoMetadata:
    """Return stream metadata for *video_path*, preferring ffprobe and falling back to OpenCV.

    In the OpenCV path a non-positive fps is replaced by 24.0 and the duration
    is derived as frame_count / fps.

    Raises:
        RuntimeError: if OpenCV cannot open the file or reports a non-positive
            frame count, width or height.
    """
    probed = _read_video_metadata_ffprobe(video_path)
    if probed is not None:
        return probed

    capture = cv2.VideoCapture(str(video_path))
    if not capture.isOpened():
        raise RuntimeError(f"Failed to open video: {video_path}")
    raw_fps, raw_count, raw_width, raw_height = (
        capture.get(prop) or 0
        for prop in (
            cv2.CAP_PROP_FPS,
            cv2.CAP_PROP_FRAME_COUNT,
            cv2.CAP_PROP_FRAME_WIDTH,
            cv2.CAP_PROP_FRAME_HEIGHT,
        )
    )
    capture.release()

    fps = float(raw_fps)
    frame_count = int(raw_count)
    width = int(raw_width)
    height = int(raw_height)
    if min(frame_count, width, height) <= 0:
        raise RuntimeError("Video metadata could not be read from the uploaded file.")
    if fps <= 0:
        fps = 24.0
    return VideoMetadata(
        fps=fps,
        frame_count=frame_count,
        duration_seconds=frame_count / fps,
        width=width,
        height=height,
    )
def _resize_max_edge(frame: np.ndarray, max_edge: int) -> np.ndarray:
height, width = frame.shape[:2]
current_max = max(height, width)
if current_max <= max_edge:
return frame
scale = max_edge / current_max
new_size = (max(2, int(round(width * scale))), max(2, int(round(height * scale))))
return cv2.resize(frame, new_size, interpolation=cv2.INTER_AREA)
def _compute_histogram(frame: np.ndarray) -> np.ndarray:
    """Return a 16x16 hue/saturation histogram of a BGR frame, normalized in place by cv2.normalize."""
    hsv_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
    hue_sat = cv2.calcHist([hsv_frame], [0, 1], None, [16, 16], [0, 180, 0, 256])
    cv2.normalize(hue_sat, hue_sat)
    return hue_sat
def _compute_thumb(gray_frame: np.ndarray) -> np.ndarray:
    """Shrink a grayscale frame to THUMB_SIZE and return it as float32 divided by 255 (0..1 for 8-bit input)."""
    small = cv2.resize(gray_frame, THUMB_SIZE, interpolation=cv2.INTER_AREA)
    return np.asarray(small, dtype=np.float32) / 255.0
def extract_candidates(
    video_path: Path,
    metadata: VideoMetadata,
    candidates_dir: Path,
    target_frames: int,
    max_image_edge: int,
    profile: ProfileConfig,
) -> list[FrameCandidate]:
    """Decode *video_path* and keep every ``stride``-th frame as a scored candidate JPEG.

    The stride aims at ``min(max(target_frames * candidate_multiplier,
    target_frames + 8), 240)`` candidates given the video's frame count. Each
    kept frame is downscaled to *max_image_edge* and written to *candidates_dir*.
    It is scored for sharpness (Laplacian variance), motion (thumbnail
    difference to the previous candidate) and scene change (HSV-histogram
    Bhattacharyya distance to the previous candidate).

    Raises:
        RuntimeError: if the video cannot be opened, a candidate JPEG cannot be
            written, or fewer than ``max(8, target_frames // 2)`` candidates result.
    """
    desired_candidates = min(max(target_frames * profile.candidate_multiplier, target_frames + 8), 240)
    # frame_count may be 0 (ffprobe without nb_frames or a usable rate); estimate it
    # from duration x fps rather than falling back to stride 1 and keeping every frame.
    frame_total = metadata.frame_count
    if frame_total <= 0:
        frame_total = int(round(metadata.duration_seconds * metadata.fps))
    stride = max(1, frame_total // desired_candidates)

    capture = cv2.VideoCapture(str(video_path))
    if not capture.isOpened():
        raise RuntimeError(f"Failed to open video for frame extraction: {video_path}")

    candidates: list[FrameCandidate] = []
    previous_hist: np.ndarray | None = None
    previous_thumb: np.ndarray | None = None
    frame_index = 0
    try:
        # grab() only advances the stream; retrieve() (which produces the image)
        # is done solely for frames on the stride, so skipped frames cost less.
        while capture.grab():
            if frame_index % stride == 0:
                ok, frame = capture.retrieve()
                if not ok:
                    break
                frame = _resize_max_edge(frame, max_image_edge)
                gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
                blur_score = float(cv2.Laplacian(gray, cv2.CV_32F).var())
                thumb = _compute_thumb(gray)
                hist = _compute_histogram(frame)
                motion_score = (
                    float(np.mean(np.abs(thumb - previous_thumb))) if previous_thumb is not None else 0.0
                )
                cut_score = (
                    float(cv2.compareHist(previous_hist, hist, cv2.HISTCMP_BHATTACHARYYA))
                    if previous_hist is not None
                    else 0.0
                )
                candidate_index = len(candidates)
                output_path = candidates_dir / f"candidate_{candidate_index:04d}.jpg"
                # imwrite reports failure via its return value rather than raising.
                if not cv2.imwrite(str(output_path), frame, [int(cv2.IMWRITE_JPEG_QUALITY), JPEG_QUALITY]):
                    raise RuntimeError(f"Failed to write candidate frame: {output_path}")
                candidates.append(
                    FrameCandidate(
                        candidate_index=candidate_index,
                        frame_index=frame_index,
                        timestamp_seconds=frame_index / metadata.fps,
                        path=output_path,
                        blur_score=blur_score,
                        motion_score=motion_score,
                        cut_score=cut_score,
                        thumb=thumb,
                    )
                )
                previous_hist = hist
                previous_thumb = thumb
            frame_index += 1
    finally:
        capture.release()

    if len(candidates) < max(8, target_frames // 2):
        raise RuntimeError(
            f"Video yielded only {len(candidates)} usable candidates; upload a longer or slower video."
        )
    return candidates
def segment_candidates(candidates: list[FrameCandidate], profile: ProfileConfig) -> list[list[FrameCandidate]]:
    """Split *candidates* into consecutive runs separated by scene cuts.

    A new run starts at every candidate (other than the first) whose
    cut_score is at least ``profile.cut_threshold``. Empty input yields ``[]``.
    """
    runs: list[list[FrameCandidate]] = []
    current_run: list[FrameCandidate] = []
    for position, candidate in enumerate(candidates):
        if position > 0 and candidate.cut_score >= profile.cut_threshold:
            runs.append(current_run)
            current_run = []
        current_run.append(candidate)
    if current_run:
        runs.append(current_run)
    return runs
def choose_best_segment(
    segments: list[list[FrameCandidate]],
    target_frames: int,
    profile: ProfileConfig,
) -> list[FrameCandidate]:
    """Return the highest-scoring segment (the earliest one on ties).

    score = (time span + 0.12 * size) * min(size / target, 1.5)
            * log1p(max(median blur, 1)) - 0.6 if size < min_segment_frames.

    Raises:
        RuntimeError: if *segments* is empty.
    """
    if not segments:
        raise RuntimeError("No coherent video segment was found for reconstruction.")

    def _segment_score(segment: list[FrameCandidate]) -> float:
        size = len(segment)
        span = segment[-1].timestamp_seconds - segment[0].timestamp_seconds if size > 1 else 0.0
        median_blur = float(np.median([candidate.blur_score for candidate in segment]))
        coverage = min(size / max(target_frames, 1), 1.5)
        short_penalty = 0.0 if size >= profile.min_segment_frames else 0.6
        return (span + size * 0.12) * coverage * math.log1p(max(median_blur, 1.0)) - short_penalty

    # max() keeps the first maximal element, matching a stable descending sort.
    return max(segments, key=_segment_score)
def select_keyframes(
    segment: list[FrameCandidate],
    target_frames: int,
    profile: ProfileConfig,
) -> list[FrameCandidate]:
    """Choose up to *target_frames* keyframes from *segment*, spread along accumulated motion.

    Returns *segment* itself when it already has at most *target_frames* entries.
    Otherwise *target_frames* evenly spaced marks are placed on the cumulative
    motion curve, or on frame indices when total motion is <= 1e-5. For each mark
    in order, the lowest-cost index after the previous pick is chosen:
        cost = |distance - mark| + 0.25 (below blur percentile)
               + 0.15 (adjacent to previous pick) - 0.08 * relative sharpness.
    If fewer than *target_frames* indices were picked, the sharpest remaining
    candidates fill the gap. The result is in segment order.
    """
    if len(segment) <= target_frames:
        return segment
    blur_scores = np.array([candidate.blur_score for candidate in segment], dtype=np.float32)
    # Frames below this percentile of in-segment sharpness get a fixed penalty.
    blur_threshold = float(np.percentile(blur_scores, profile.min_blur_percentile))
    # Sharpness relative to the sharpest frame, guarded against division by ~0.
    normalized_blur = blur_scores / max(float(blur_scores.max()), 1e-6)
    # Per-step motion with a tiny floor so the cumulative curve is strictly increasing.
    motion = np.array([0.0] + [max(candidate.motion_score, 1e-6) for candidate in segment[1:]], dtype=np.float32)
    cumulative_motion = np.cumsum(motion)
    selected_indices: list[int] = []
    # Half-width of the preferred search window around each mark.
    neighborhood = max(2, len(segment) // max(target_frames * 2, 1))
    if float(cumulative_motion[-1]) <= 1e-5:
        # Essentially static footage: space marks uniformly over frame indices.
        marks = np.linspace(0, len(segment) - 1, target_frames)
        mark_distances = np.arange(len(segment), dtype=np.float32)
    else:
        marks = np.linspace(float(cumulative_motion[0]), float(cumulative_motion[-1]), target_frames)
        mark_distances = cumulative_motion
    for mark in marks:
        center = int(np.searchsorted(mark_distances, mark))
        best_index: int | None = None
        best_score = float("inf")
        # Picks must be strictly increasing, so never look at or before the last pick.
        min_allowed = selected_indices[-1] + 1 if selected_indices else 0
        lower = max(min_allowed, center - neighborhood)
        upper = min(len(segment), center + neighborhood + 1)
        # Search the window around the mark first; if it is empty, widen to everything after the last pick.
        search_ranges = [(lower, upper), (min_allowed, len(segment))]
        for range_start, range_end in search_ranges:
            for idx in range(range_start, range_end):
                candidate = segment[idx]
                # NOTE(review): this term is in cumulative-motion units (or frame indices) while the
                # other terms are fixed constants, so their relative weight depends on motion scale.
                mark_penalty = abs(float(mark_distances[idx]) - float(mark))
                blur_penalty = 0.25 if candidate.blur_score < blur_threshold else 0.0
                spacing_penalty = 0.15 if selected_indices and idx - selected_indices[-1] < 2 else 0.0
                sharpness_bonus = 0.08 * float(normalized_blur[idx])
                score = mark_penalty + blur_penalty + spacing_penalty - sharpness_bonus
                if score < best_score:
                    best_score = score
                    best_index = idx
            if best_index is not None:
                break
        if best_index is not None and (not selected_indices or best_index > selected_indices[-1]):
            selected_indices.append(best_index)
    selected_indices = sorted(set(selected_indices))
    if len(selected_indices) < target_frames:
        # Top up with unused frames: sharpest first, then farthest from any existing pick.
        remaining = [idx for idx in range(len(segment)) if idx not in selected_indices]
        remaining.sort(
            key=lambda idx: (
                -segment[idx].blur_score,
                -(min(abs(idx - chosen) for chosen in selected_indices) if selected_indices else float("inf")),
            )
        )
        for idx in remaining:
            if len(selected_indices) >= target_frames:
                break
            selected_indices.append(idx)
        selected_indices.sort()
    trimmed = selected_indices[:target_frames]
    return [segment[idx] for idx in trimmed]
def export_selected_images(scene_dir: Path, selected_frames: list[FrameCandidate]) -> list[Path]:
    """Copy the selected frames to ``scene_dir/images/frame_NNNN.jpg`` in order; return the new paths."""
    images_dir = _ensure_dir(scene_dir / "images")
    destinations = [images_dir / f"frame_{position:04d}.jpg" for position in range(len(selected_frames))]
    for candidate, destination in zip(selected_frames, destinations):
        shutil.copy2(candidate.path, destination)
    return destinations
def run_colmap(scene_dir: Path, selected_count: int, profile: ProfileConfig, max_image_edge: int) -> Path:
    """Run COLMAP feature extraction, sequential matching and mapping on ``scene_dir/images``.

    Uses CPU SIFT with a single shared SIMPLE_RADIAL camera. Sequential matching
    uses quadratic overlap and no loop detection; the overlap is capped at
    ``min(profile.sequential_overlap, max(selected_count - 1, 1))``. The mapper
    builds a single model without colors. Results go to
    ``scene_dir/database.db`` and ``scene_dir/sparse/<model>``.

    Returns:
        The first (sorted by name) model directory under ``sparse``; it is
        absolute because *scene_dir* is resolved first.

    Raises:
        RuntimeError: if colmap is missing, a step fails, or no model is produced.
    """
    _require_binary("colmap")
    # Every command runs with cwd=scene_dir; resolving avoids a relative scene_dir
    # being re-applied on top of itself for the paths passed on the command line.
    scene_dir = Path(scene_dir).resolve()
    database_path = scene_dir / "database.db"
    images_dir = scene_dir / "images"
    sparse_dir = _ensure_dir(scene_dir / "sparse")
    _run(
        [
            "colmap",
            "feature_extractor",
            "--database_path",
            str(database_path),
            "--image_path",
            str(images_dir),
            "--ImageReader.single_camera",
            "1",
            "--ImageReader.camera_model",
            "SIMPLE_RADIAL",
            "--SiftExtraction.use_gpu",
            "0",
            "--SiftExtraction.max_image_size",
            str(max_image_edge),
        ],
        cwd=scene_dir,
    )
    _run(
        [
            "colmap",
            "sequential_matcher",
            "--database_path",
            str(database_path),
            "--SiftMatching.use_gpu",
            "0",
            "--SequentialMatching.overlap",
            str(min(profile.sequential_overlap, max(selected_count - 1, 1))),
            "--SequentialMatching.quadratic_overlap",
            "1",
            "--SequentialMatching.loop_detection",
            "0",
        ],
        cwd=scene_dir,
    )
    _run(
        [
            "colmap",
            "mapper",
            "--database_path",
            str(database_path),
            "--image_path",
            str(images_dir),
            "--output_path",
            str(sparse_dir),
            "--Mapper.multiple_models",
            "0",
            "--Mapper.extract_colors",
            "0",
            "--Mapper.min_model_size",
            str(min(8, max(selected_count // 3, 4))),
        ],
        cwd=scene_dir,
    )
    model_dirs = sorted(path for path in sparse_dir.iterdir() if path.is_dir())
    if not model_dirs:
        raise RuntimeError("COLMAP did not produce a sparse reconstruction.")
    return model_dirs[0]
def count_registered_images(model_dir: Path) -> int:
    """Return the number of registered images in the COLMAP model at *model_dir*.

    ``images.bin`` is preferred: its first 8 bytes are a little-endian uint64
    image count. Otherwise ``images.txt`` is used: its "# Number of images: N"
    header line if present, else two data lines per image. Returns 0 when
    neither file exists or ``images.bin`` is empty.

    Raises:
        RuntimeError: if ``images.bin`` is shorter than its 8-byte header.
    """
    image_bin = model_dir / "images.bin"
    image_txt = model_dir / "images.txt"
    if image_bin.exists():
        with image_bin.open("rb") as handle:
            header = handle.read(8)
        if not header:
            return 0
        if len(header) < 8:
            # struct.unpack would fail with an opaque struct.error here.
            raise RuntimeError(f"Truncated COLMAP images.bin (header is {len(header)} bytes): {image_bin}")
        return int(struct.unpack("<Q", header)[0])
    if image_txt.exists():
        lines = image_txt.read_text(encoding="utf-8").splitlines()
        for line in lines:
            match = re.match(r"#\s*Number of images:\s*(\d+)", line.strip())
            if match:
                return int(match.group(1))
        # Each image occupies a pose line plus a POINTS2D line, and the latter is
        # empty for images without observations, so blank lines must be counted.
        data_lines = [line for line in lines if not line.lstrip().startswith("#")]
        return len(data_lines) // 2
    return 0
def quality_label(registered_frames: int, selected_frames: int) -> str:
    """Grade the registered/selected ratio.

    The grades are ">= 0.85" -> "strong", ">= 0.6" -> "usable", otherwise "weak".
    "unknown" is returned when nothing was selected.
    """
    if selected_frames <= 0:
        return "unknown"
    ratio = registered_frames / selected_frames
    for floor, label in ((0.85, "strong"), (0.6, "usable")):
        if ratio >= floor:
            return label
    return "weak"
def create_contact_sheet(selected_frames: list[FrameCandidate], output_path: Path) -> Path:
    """Render the selected frames into a labelled grid JPEG (up to 4 columns) at *output_path*.

    Each tile is the frame's JPEG downscaled to at most 320 px on its longer
    edge. A dark banner shows its timestamp and blur score. Frames whose JPEG
    cannot be read are skipped.

    Raises:
        RuntimeError: if no frames were given, none could be read, or the sheet
            could not be written.
    """
    if not selected_frames:
        raise RuntimeError("No selected frames were available for the contact sheet.")
    thumbs: list[np.ndarray] = []
    for candidate in selected_frames:
        image = cv2.imread(str(candidate.path), cv2.IMREAD_COLOR)
        if image is None:
            continue
        image = _resize_max_edge(image, 320)
        overlay = image.copy()
        label = f"{candidate.timestamp_seconds:0.2f}s | blur {candidate.blur_score:0.0f}"
        # Dark banner over the top 32 rows; outside it overlay equals image, so the blend leaves pixels as-is.
        cv2.rectangle(overlay, (0, 0), (image.shape[1], 32), (12, 18, 28), -1)
        image = cv2.addWeighted(overlay, 0.72, image, 0.28, 0.0)
        cv2.putText(image, label, (10, 22), FONT, 0.55, (230, 235, 240), 1, cv2.LINE_AA)
        thumbs.append(image)
    if not thumbs:
        # Without this guard cols would be 0 and the grid maths below would divide by zero.
        raise RuntimeError("None of the selected frames could be read for the contact sheet.")
    cols = min(4, len(thumbs))
    rows = math.ceil(len(thumbs) / cols)
    cell_height = max(tile.shape[0] for tile in thumbs)
    cell_width = max(tile.shape[1] for tile in thumbs)
    canvas = np.full((rows * cell_height, cols * cell_width, 3), 18, dtype=np.uint8)
    for index, tile in enumerate(thumbs):
        row, col = divmod(index, cols)
        y = row * cell_height
        x = col * cell_width
        canvas[y : y + tile.shape[0], x : x + tile.shape[1]] = tile
    # imwrite signals failure through its return value rather than raising.
    if not cv2.imwrite(str(output_path), canvas, [int(cv2.IMWRITE_JPEG_QUALITY), 92]):
        raise RuntimeError(f"Failed to write contact sheet: {output_path}")
    return output_path
def write_report(
    scene_dir: Path,
    metadata: VideoMetadata,
    selected_frames: list[FrameCandidate],
    registered_frames: int,
    profile_key: str,
    max_image_edge: int,
) -> Path:
    """Write ``scene_dir/report.json`` (indent=2, UTF-8) and return its path.

    The report holds the scene name and the source video metadata. It also
    records the selection settings with the COLMAP registration result, plus
    one entry per exported frame named ``images/frame_NNNN.jpg``.
    """
    frame_total = len(selected_frames)
    video_section = {
        "fps": metadata.fps,
        "frame_count": metadata.frame_count,
        "duration_seconds": metadata.duration_seconds,
        "width": metadata.width,
        "height": metadata.height,
    }
    selection_section = {
        "profile": profile_key,
        "max_image_edge": max_image_edge,
        "selected_frames": frame_total,
        "registered_frames": registered_frames,
        "quality_label": quality_label(registered_frames, frame_total),
    }
    frame_entries = []
    for position, candidate in enumerate(selected_frames):
        frame_entries.append(
            {
                "filename": f"images/frame_{position:04d}.jpg",
                "timestamp_seconds": candidate.timestamp_seconds,
                "source_frame_index": candidate.frame_index,
                "blur_score": candidate.blur_score,
                "motion_score": candidate.motion_score,
                "cut_score": candidate.cut_score,
            }
        )
    document = {
        "scene_name": scene_dir.name,
        "video": video_section,
        "selection": selection_section,
        "frames": frame_entries,
    }
    destination = scene_dir / "report.json"
    destination.write_text(json.dumps(document, indent=2), encoding="utf-8")
    return destination
def build_archive(scene_dir: Path, output_archive: Path) -> Path:
    """Zip a scene's ``images/``, ``sparse/`` and (if present) ``report.json`` into *output_archive*.

    Entries are written in sorted path order under a top-level folder named
    after *scene_dir*: ``<scene>/images/...``, ``<scene>/report.json`` and
    ``<scene>/sparse/...``. Other files in *scene_dir* (e.g. ``database.db``)
    are not included. Files are read directly from *scene_dir*; no staging
    copy is written to disk.

    Returns:
        *output_archive*.

    Raises:
        FileNotFoundError: if ``images/`` or ``sparse/`` is missing.
    """
    root_name = scene_dir.name
    members: list[Path] = []
    for subdir_name in ("images", "sparse"):
        subdir = scene_dir / subdir_name
        if not subdir.is_dir():
            raise FileNotFoundError(f"Scene is missing its {subdir_name}/ directory: {subdir}")
        members.extend(path for path in subdir.rglob("*") if path.is_file())
    report_path = scene_dir / "report.json"
    if report_path.is_file():
        members.append(report_path)
    members.sort()
    with zipfile.ZipFile(output_archive, "w", compression=zipfile.ZIP_DEFLATED) as archive:
        for path in members:
            archive.write(path, f"{root_name}/{path.relative_to(scene_dir).as_posix()}")
    return output_archive
def convert_video_to_colmap_archive(
    video_path: str | Path,
    target_frames: int | None,
    profile_key: str,
    max_image_edge: int,
) -> ConversionOutputs:
    """Turn a video into a COLMAP sparse-reconstruction archive plus a report and preview.

    Pipeline: normalize with ffmpeg, read metadata, extract scored candidates,
    keep the best scene segment, pick keyframes, run COLMAP, then write the
    report, contact sheet and zip archive into OUTPUTS_DIR. The per-job scratch
    directory under WORK_DIR is removed afterwards, whether the job succeeds or fails.

    Args:
        video_path: Input video; may be relative to the current directory.
        target_frames: Keyframe budget. ``None`` or a value <= 0 means "infer
            from the clip duration" via infer_target_frames.
        profile_key: Key into PROFILES.
        max_image_edge: Longest image side used for frames and COLMAP SIFT.

    Raises:
        ValueError: for an unknown *profile_key*.
        FileNotFoundError: if *video_path* does not exist.
        RuntimeError: propagated from any pipeline stage.
    """
    if profile_key not in PROFILES:
        raise ValueError(f"Unknown sampling profile: {profile_key}")
    source_path = Path(video_path)
    if not source_path.exists():
        raise FileNotFoundError(f"Input video not found: {source_path}")
    profile = PROFILES[profile_key]
    job_dir = _unique_dir(WORK_DIR, "video-job")
    try:
        # Resolve before handing off: ffmpeg runs with cwd=job_dir, where a
        # relative input path would no longer point at the caller's file.
        normalized_path = normalize_video_input(source_path.resolve(), job_dir)
        metadata = read_video_metadata(normalized_path)
        if target_frames is None or target_frames <= 0:
            target_frames = infer_target_frames(metadata)
        candidates_dir = _ensure_dir(job_dir / "candidates")
        candidates = extract_candidates(
            video_path=normalized_path,
            metadata=metadata,
            candidates_dir=candidates_dir,
            target_frames=target_frames,
            max_image_edge=max_image_edge,
            profile=profile,
        )
        segment = choose_best_segment(segment_candidates(candidates, profile), target_frames, profile)
        selected = select_keyframes(segment, target_frames, profile)
        scene_name = f"{_slugify(source_path.stem)}-{_now_ms()}"
        scene_dir = _ensure_dir(job_dir / scene_name)
        export_selected_images(scene_dir, selected)
        model_dir = run_colmap(scene_dir, len(selected), profile, max_image_edge)
        registered_frames = count_registered_images(model_dir)
        report_path = write_report(scene_dir, metadata, selected, registered_frames, profile_key, max_image_edge)
        output_stem = f"{scene_name}-{profile_key}-{len(selected)}"
        contact_sheet_path = create_contact_sheet(selected, OUTPUTS_DIR / f"{output_stem}.jpg")
        archive_path = build_archive(scene_dir, OUTPUTS_DIR / f"{output_stem}.zip")
        output_report_path = OUTPUTS_DIR / f"{output_stem}.report.json"
        shutil.copy2(report_path, output_report_path)
    finally:
        # Everything returned to the caller lives in OUTPUTS_DIR; the normalized
        # video, candidate JPEGs and COLMAP database would otherwise pile up per job.
        shutil.rmtree(job_dir, ignore_errors=True)
    return ConversionOutputs(
        archive_path=archive_path,
        report_path=output_report_path,
        contact_sheet_path=contact_sheet_path,
        scene_name=scene_name,
        selected_frames=len(selected),
        registered_frames=registered_frames,
        duration_seconds=metadata.duration_seconds,
        quality_label=quality_label(registered_frames, len(selected)),
    )