#!/usr/bin/env python3 """ GPT-SoVITS Daniya trainer for Hugging Face Spaces. This Space is CPU-oriented. It prepares the dataset with the current GPT-SoVITS pipeline and can export fresh SoVITS and GPT checkpoints through the Gradio UI. """ import csv import fcntl import json import logging import os import shutil import signal import subprocess import sys import threading import time from dataclasses import dataclass from datetime import datetime from pathlib import Path import gradio as gr import yaml from huggingface_hub import hf_hub_download, snapshot_download HF_TOKEN = os.environ.get("HF_TOKEN", "") DATASET_REPO = "huanx/daniya-voice-gptsovits" GPT_SOVITS_REPO = "https://github.com/RVC-Boss/GPT-SoVITS.git" PERSISTENT_ROOT = Path(os.environ.get("PERSISTENT_ROOT", "/data")) EPHEMERAL_ROOT = Path("/tmp") ACTIVE_STORAGE_ROOT = ( PERSISTENT_ROOT if PERSISTENT_ROOT.is_dir() and os.access(PERSISTENT_ROOT, os.W_OK) else EPHEMERAL_ROOT ) STORAGE_MODE = "persistent" if ACTIVE_STORAGE_ROOT == PERSISTENT_ROOT else "ephemeral" WORK_DIR = ACTIVE_STORAGE_ROOT / "daniya_trainer" TMP_ROOT = WORK_DIR / "tmp" HF_HOME = WORK_DIR / "hf_home" GPT_SOVITS_DIR = WORK_DIR / "GPT-SoVITS" DATASET_DIR = WORK_DIR / "dataset" AUDIO_DIR = DATASET_DIR / "audio" LANGUAGE = "zh" SPEAKER = "daniya" EXP_NAME = "daniya" EXP_ROOT = WORK_DIR / "logs" OUTPUT_ROOT = WORK_DIR / "trained_models" PRETRAINED_DIR = GPT_SOVITS_DIR / "GPT_SoVITS" / "pretrained_models" BERT_DIR = PRETRAINED_DIR / "chinese-roberta-wwm-ext-large" CNHUBERT_DIR = PRETRAINED_DIR / "chinese-hubert-base" PRETRAINED_REPO = "lj1995/GPT-SoVITS" SV_PRETRAINED_REL = "sv/pretrained_eres2netv2w24s4ep4.ckpt" BERT_REPO = "hfl/chinese-roberta-wwm-ext-large" CNHUBERT_REPO = "TencentGameMate/chinese-hubert-base" MODEL_PATTERNS = [ "*.json", "*.txt", "*.bin", "*.safetensors", "*.model", ] UPSTREAM_PATCH_VERSION = "2026-05-26-s2-cpu-singleproc-v6" logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s") log = logging.getLogger(__name__) os.environ.setdefault("HF_HOME", str(HF_HOME)) os.environ.setdefault("TRANSFORMERS_CACHE", str(HF_HOME / "transformers")) os.environ.setdefault("HF_HUB_DISABLE_TELEMETRY", "1") os.environ.setdefault("TMPDIR", str(TMP_ROOT)) os.environ.setdefault("TMP", str(TMP_ROOT)) os.environ.setdefault("TEMP", str(TMP_ROOT)) @dataclass(frozen=True) class VersionSpec: version: str label: str stability: str pretrained_s1_rel: str pretrained_s2g_rel: str pretrained_s2d_rel: str sovits_config_name: str uses_sv: bool = False @dataclass(frozen=True) class VersionContext: spec: VersionSpec run_name: str input_list: Path exp_dir: Path text_path: Path semantic_path: Path prep_live_log: Path sovits_live_log: Path gpt_live_log: Path prep_status_path: Path sovits_status_path: Path gpt_status_path: Path gpt_train_log: Path sovits_ckpt_dir: Path gpt_log_dir: Path sovits_output_dir: Path gpt_output_dir: Path pretrained_s1: Path pretrained_s2g: Path pretrained_s2d: Path sv_pretrained: Path | None SUPPORTED_VERSIONS = { "v2": VersionSpec( version="v2", label="v2", stability="默认推荐。当前 Space 上最稳,训练和调试成本最低。", pretrained_s1_rel="gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt", pretrained_s2g_rel="gsv-v2final-pretrained/s2G2333k.pth", pretrained_s2d_rel="gsv-v2final-pretrained/s2D2333k.pth", sovits_config_name="s2.json", ), "v2Pro": VersionSpec( version="v2Pro", label="v2Pro", stability="官方 main 已支持。比 v2 稍重,但链路和 v2 家族一致。", pretrained_s1_rel="s1v3.ckpt", pretrained_s2g_rel="v2Pro/s2Gv2Pro.pth", pretrained_s2d_rel="v2Pro/s2Dv2Pro.pth", sovits_config_name="s2v2Pro.json", uses_sv=True, ), "v2ProPlus": VersionSpec( version="v2ProPlus", label="v2ProPlus", stability="官方 main 已支持,效果更强,但在 cpu-basic Space 上更慢,建议作为可切换实验版本。", pretrained_s1_rel="s1v3.ckpt", pretrained_s2g_rel="v2Pro/s2Gv2ProPlus.pth", pretrained_s2d_rel="v2Pro/s2Dv2ProPlus.pth", sovits_config_name="s2v2ProPlus.json", uses_sv=True, ), } DEFAULT_VERSION = "v2" def get_version_spec(version: str): if version not in SUPPORTED_VERSIONS: raise ValueError(f"不支持的版本: {version}") return SUPPORTED_VERSIONS[version] def get_version_context(version: str): spec = get_version_spec(version) if version == DEFAULT_VERSION: run_name = EXP_NAME exp_dir = EXP_ROOT / EXP_NAME input_list = WORK_DIR / f"{EXP_NAME}.list" else: run_name = f"{EXP_NAME}-{version}" exp_dir = EXP_ROOT / f"{EXP_NAME}_{version}" input_list = WORK_DIR / f"{EXP_NAME}.{version}.list" return VersionContext( spec=spec, run_name=run_name, input_list=input_list, exp_dir=exp_dir, text_path=exp_dir / "2-name2text.txt", semantic_path=exp_dir / "6-name2semantic.tsv", prep_live_log=exp_dir / "_live_prepare.log", sovits_live_log=exp_dir / "_live_sovits.log", gpt_live_log=exp_dir / "_live_gpt.log", prep_status_path=exp_dir / "_status_prepare.json", sovits_status_path=exp_dir / "_status_sovits.json", gpt_status_path=exp_dir / "_status_gpt.json", gpt_train_log=exp_dir / "train.log", sovits_ckpt_dir=exp_dir / f"logs_s2_{version}", gpt_log_dir=exp_dir / f"logs_s1_{version}", sovits_output_dir=OUTPUT_ROOT / f"SoVITS_weights_{version}", gpt_output_dir=OUTPUT_ROOT / f"GPT_weights_{version}", pretrained_s1=PRETRAINED_DIR / spec.pretrained_s1_rel, pretrained_s2g=PRETRAINED_DIR / spec.pretrained_s2g_rel, pretrained_s2d=PRETRAINED_DIR / spec.pretrained_s2d_rel, sv_pretrained=PRETRAINED_DIR / SV_PRETRAINED_REL if spec.uses_sv else None, ) def ensure_dirs(): base_dirs = [ WORK_DIR, TMP_ROOT, HF_HOME, DATASET_DIR, EXP_ROOT, OUTPUT_ROOT, PRETRAINED_DIR, ] for path in base_dirs: path.mkdir(parents=True, exist_ok=True) for version in SUPPORTED_VERSIONS: ctx = get_version_context(version) for path in [ ctx.exp_dir, ctx.sovits_ckpt_dir, ctx.gpt_log_dir, ctx.sovits_output_dir, ctx.gpt_output_dir, ]: path.mkdir(parents=True, exist_ok=True) def hf_kwargs(): return {"token": HF_TOKEN} if HF_TOKEN else {} def append_live_log(path: Path, message: str): path.parent.mkdir(parents=True, exist_ok=True) with path.open("a", encoding="utf-8") as handle: handle.write(message + "\n") def clear_live_log(path: Path): path.parent.mkdir(parents=True, exist_ok=True) path.write_text("", encoding="utf-8") def clear_status_file(path: Path): path.parent.mkdir(parents=True, exist_ok=True) if path.exists(): path.unlink() def read_text_tail(path: Path, max_lines=200): if not path.exists(): return "" lines = path.read_text(encoding="utf-8", errors="ignore").splitlines() return "\n".join(lines[-max_lines:]) def read_live_log(path: Path): return read_text_tail(path, max_lines=200) def write_status_file(path: Path, payload): path.parent.mkdir(parents=True, exist_ok=True) tmp_path = path.with_suffix(path.suffix + ".tmp") tmp_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8") tmp_path.replace(path) def read_status_file(path: Path): if not path.exists(): return None try: return json.loads(path.read_text(encoding="utf-8")) except Exception: return None def now_iso(): return datetime.now().isoformat(timespec="seconds") def child_pids(pid: int): children_path = Path(f"/proc/{pid}/task/{pid}/children") if not children_path.exists(): return [] try: direct_children = [int(item) for item in children_path.read_text(encoding="utf-8").split() if item.strip()] except Exception: return [] descendants = [] seen = set() queue = list(direct_children) while queue: current = queue.pop(0) if current in seen: continue seen.add(current) descendants.append(current) queue.extend(child_pids(current)) return descendants def process_metrics(pid: int, previous=None): pids = [pid, *child_pids(pid)] live_pids = [] cpu_ticks = 0 rss_pages = 0 for current_pid in pids: stat_path = Path(f"/proc/{current_pid}/stat") statm_path = Path(f"/proc/{current_pid}/statm") if not stat_path.exists() or not statm_path.exists(): continue stat_fields = stat_path.read_text(encoding="utf-8").split() cpu_ticks += int(stat_fields[13]) + int(stat_fields[14]) rss_pages += int(statm_path.read_text(encoding="utf-8").split()[1]) live_pids.append(current_pid) if not live_pids: return {"alive": False, "cpu_percent": 0.0, "rss_mb": 0.0, "cpu_time_seconds": 0.0} clock_ticks = os.sysconf(os.sysconf_names["SC_CLK_TCK"]) cpu_time_seconds = cpu_ticks / float(clock_ticks) page_size = os.sysconf("SC_PAGE_SIZE") rss_mb = rss_pages * page_size / (1024 * 1024) cpu_percent = 0.0 if previous is not None: prev_ts, prev_cpu_seconds = previous elapsed = max(time.time() - prev_ts, 1e-6) cpu_percent = max(0.0, ((cpu_time_seconds - prev_cpu_seconds) / elapsed) * 100.0) return { "alive": True, "cpu_percent": round(cpu_percent, 2), "rss_mb": round(rss_mb, 2), "cpu_time_seconds": round(cpu_time_seconds, 2), "process_count": len(live_pids), } def refresh_task_status(task): if not task: return None pid = task.get("pid") if pid is None: return task metrics = process_metrics(int(pid)) refreshed = dict(task) refreshed.update(metrics) refreshed["refreshed_at"] = now_iso() if not metrics.get("alive"): if refreshed.get("state") in {"starting", "running"} and refreshed.get("exit_code") is None: refreshed["state"] = "stale" return refreshed def status_summary(status): if not status: return "" parts = [ f"阶段: {status.get('state', 'unknown')}", f"更新时间: {status.get('updated_at', '-')}", ] if status.get("pid") is not None: parts.append(f"pid={status['pid']}") if status.get("process_count") is not None: parts.append(f"proc={status['process_count']}") if status.get("cpu_percent") is not None: parts.append(f"cpu={status['cpu_percent']}%") if status.get("rss_mb") is not None: parts.append(f"rss={status['rss_mb']} MB") if status.get("silent_for_seconds") is not None: parts.append(f"静默={status['silent_for_seconds']}s") if status.get("last_output_line"): parts.append(f"最后输出: {status['last_output_line'][:160]}") return "[状态]\n" + " | ".join(parts) def combined_log_text(primary_path: Path, status_path: Path | None = None, extra_paths=None): sections = [] status = read_status_file(status_path) if status_path else None summary = status_summary(status) if summary: sections.append(summary) primary = read_text_tail(primary_path, max_lines=200) if primary: sections.append(primary) for label, path in extra_paths or []: content = read_text_tail(path, max_lines=120) if content: sections.append(f"[{label}]\n{content}") return "\n\n".join(section for section in sections if section) def push(logs, message, live_path: Path | None = None): logs.append(message) if live_path is not None: append_live_log(live_path, message) return "\n".join(logs[-200:]) def run_cmd(command, cwd=None, env=None, status_path: Path | None = None, live_path: Path | None = None, status_label=None): proc = subprocess.Popen( command, cwd=str(cwd) if cwd else None, env=env, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1, ) command_text = " ".join(command) shared_state = { "started_at": now_iso(), "last_output_at_ts": time.time(), "last_output_line": f"$ {command_text}", "last_heartbeat_at_ts": 0.0, } stop_event = threading.Event() def write_runtime_status(state, exit_code=None): if status_path is None: return payload = { "label": status_label or "任务", "state": state, "pid": proc.pid, "command": command_text, "started_at": shared_state["started_at"], "updated_at": now_iso(), "last_output_at": datetime.fromtimestamp(shared_state["last_output_at_ts"]).isoformat(timespec="seconds"), "last_output_line": shared_state["last_output_line"], "silent_for_seconds": round(max(time.time() - shared_state["last_output_at_ts"], 0.0), 1), "exit_code": exit_code, } payload.update(process_metrics(proc.pid)) write_status_file(status_path, payload) def monitor_process(): previous = None while not stop_event.is_set() and proc.poll() is None: metrics = process_metrics(proc.pid, previous) previous = (time.time(), metrics["cpu_time_seconds"]) if status_path is not None: payload = { "label": status_label or "任务", "state": "running", "pid": proc.pid, "command": command_text, "started_at": shared_state["started_at"], "updated_at": now_iso(), "last_output_at": datetime.fromtimestamp(shared_state["last_output_at_ts"]).isoformat(timespec="seconds"), "last_output_line": shared_state["last_output_line"], "silent_for_seconds": round(max(time.time() - shared_state["last_output_at_ts"], 0.0), 1), "exit_code": None, } payload.update(metrics) write_status_file(status_path, payload) silent_for = time.time() - shared_state["last_output_at_ts"] if live_path is not None and silent_for >= 30 and time.time() - shared_state["last_heartbeat_at_ts"] >= 30: append_live_log( live_path, ( f"[状态] {status_label or '任务'} 仍在运行 | pid={proc.pid} | " f"cpu={metrics['cpu_percent']}% | rss={metrics['rss_mb']} MB | 静默 {int(silent_for)}s" ), ) shared_state["last_heartbeat_at_ts"] = time.time() stop_event.wait(10) write_runtime_status("starting") monitor_thread = threading.Thread(target=monitor_process, daemon=True) monitor_thread.start() yield f"$ {command_text}" for raw in proc.stdout: line = raw.rstrip() if line: shared_state["last_output_at_ts"] = time.time() shared_state["last_output_line"] = line write_runtime_status("running") yield line code = proc.wait() stop_event.set() monitor_thread.join(timeout=1) write_runtime_status("completed" if code == 0 else "failed", exit_code=code) if code != 0: raise RuntimeError(f"命令失败 (exit={code}): {command_text}") def python_script_command(script_path, *script_args, file_system_sharing=False): if not file_system_sharing: return [sys.executable, "-s", script_path, *[str(arg) for arg in script_args]] launcher = ( "import runpy, sys;" "import torch.multiprocessing as mp;" "mp.set_sharing_strategy('file_system');" "script = sys.argv[1];" "sys.argv = sys.argv[1:];" "runpy.run_path(script, run_name='__main__')" ) return [sys.executable, "-s", "-c", launcher, script_path, *[str(arg) for arg in script_args]] def has_transformers_model(path: Path): return path.exists() and (path / "config.json").exists() and ( any(path.glob("*.bin")) or any(path.glob("*.safetensors")) ) def metadata_rows(): metadata = DATASET_DIR / "metadata.csv" if not metadata.exists(): return [] with metadata.open("r", encoding="utf-8", newline="") as handle: return list(csv.DictReader(handle)) def latest_file(directory: Path, suffix: str): files = sorted(directory.glob(f"*{suffix}"), key=lambda item: item.stat().st_mtime) return str(files[-1]) if files else None def list_files(directory: Path, suffix: str): return sorted(directory.glob(f"*{suffix}"), key=lambda item: item.stat().st_mtime, reverse=True) def format_size(size_bytes: int): value = float(size_bytes) units = ["B", "KB", "MB", "GB"] for unit in units: if value < 1024 or unit == units[-1]: if unit == "B": return f"{int(value)} {unit}" return f"{value:.1f} {unit}" value /= 1024 return f"{int(size_bytes)} B" def format_mtime(path: Path): return datetime.fromtimestamp(path.stat().st_mtime).strftime("%Y-%m-%d %H:%M:%S") def directory_overview(): ctx = get_version_context(DEFAULT_VERSION) return directory_overview_for(ctx) def directory_overview_for(ctx: VersionContext): return "\n".join( [ f"存储模式: {STORAGE_MODE}", f"存储根目录: {ACTIVE_STORAGE_ROOT}", f"当前版本: {ctx.spec.version}", f"工作目录: {WORK_DIR}", f"临时目录: {TMP_ROOT}", f"数据集目录: {DATASET_DIR}", f"日志目录: {ctx.exp_dir}", f"SoVITS 导出目录: {ctx.sovits_output_dir}", f"GPT 导出目录: {ctx.gpt_output_dir}", ] ) def artifact_lines(label: str, files): if not files: return [f"{label}: 暂无"] lines = [f"{label}: 共 {len(files)} 个"] for index, path in enumerate(files[:10], start=1): lines.append(f"{index}. {path.name} | {format_size(path.stat().st_size)} | {format_mtime(path)} | {path}") return lines def artifacts_summary(version=DEFAULT_VERSION): ctx = get_version_context(version) sovits_files = list_files(ctx.sovits_output_dir, ".pth") gpt_files = list_files(ctx.gpt_output_dir, ".ckpt") lines = [ "训练结果总览", f"版本: {ctx.spec.version}", f"稳定性: {ctx.spec.stability}", *artifact_lines("SoVITS", sovits_files), *artifact_lines("GPT", gpt_files), ] return ("\n".join(lines), directory_overview_for(ctx)) def path_is_under(path: Path, root: Path): try: path.resolve().relative_to(root.resolve()) return True except ValueError: return False def managed_roots(ctx: VersionContext): return [ ("版本日志目录", ctx.exp_dir), ("SoVITS 输出目录", ctx.sovits_output_dir), ("GPT 输出目录", ctx.gpt_output_dir), ("数据音频目录", AUDIO_DIR), ("临时目录", TMP_ROOT), ] def managed_single_paths(ctx: VersionContext): return [ ("metadata", DATASET_DIR / "metadata.csv"), ("训练清单", ctx.input_list), ("文本特征总表", ctx.text_path), ("语义特征总表", ctx.semantic_path), ("预处理实时日志", ctx.prep_live_log), ("预处理状态", ctx.prep_status_path), ("SoVITS 实时日志", ctx.sovits_live_log), ("SoVITS 状态", ctx.sovits_status_path), ("GPT 实时日志", ctx.gpt_live_log), ("GPT 状态", ctx.gpt_status_path), ("GPT train.log", ctx.gpt_train_log), ] def managed_path_choices(ctx: VersionContext): items = set() for _, root in managed_roots(ctx): if not root.exists(): continue for child in sorted(root.iterdir(), key=lambda item: item.name): rel = child.relative_to(WORK_DIR).as_posix() if child.is_dir(): rel = f"{rel}/" items.add(rel) for _, path in managed_single_paths(ctx): if path.exists(): items.add(path.relative_to(WORK_DIR).as_posix()) return sorted(items) def describe_path(path: Path): if not path.exists(): return f"不存在: {path}" lines = [ f"路径: {path}", f"类型: {'目录' if path.is_dir() else '文件'}", f"修改时间: {format_mtime(path)}", ] if path.is_file(): lines.append(f"大小: {format_size(path.stat().st_size)}") else: children = list(path.iterdir()) lines.append(f"直接子项: {len(children)}") return "\n".join(lines) def file_manager_summary(version=DEFAULT_VERSION): ctx = get_version_context(version) lines = [ "文件管理", f"版本: {ctx.spec.version}", "说明: 这里只暴露安全清理范围,不包含 GPT-SoVITS 仓库和预训练缓存。", ] for label, root in managed_roots(ctx): if not root.exists(): lines.append(f"{label}: 不存在") continue children = sorted(root.iterdir(), key=lambda item: item.name) lines.append(f"{label}: {root} | 直接子项 {len(children)}") for child in children[:12]: rel = child.relative_to(WORK_DIR).as_posix() kind = "D" if child.is_dir() else "F" suffix = "/" if child.is_dir() else "" if child.is_dir(): extra = f"{len(list(child.iterdir()))} 项" else: extra = format_size(child.stat().st_size) lines.append(f"[{kind}] {rel}{suffix} | {extra} | {format_mtime(child)}") if len(children) > 12: lines.append(f"... 其余 {len(children) - 12} 项省略") for label, path in managed_single_paths(ctx): if path.exists(): lines.append(f"{label}: {path.relative_to(WORK_DIR).as_posix()} | {format_size(path.stat().st_size)}") return "\n".join(lines) def file_manager_state(version=DEFAULT_VERSION, selected_path=None): ctx = get_version_context(version) choices = managed_path_choices(ctx) if selected_path not in choices: selected_path = choices[0] if choices else None detail = describe_managed_path(version, selected_path) return ( file_manager_summary(version), gr.update(choices=choices, value=selected_path), detail, ) def describe_managed_path(version=DEFAULT_VERSION, selected_path=None): if not selected_path: return "未选择路径。" target = (WORK_DIR / selected_path.rstrip("/")).resolve() return describe_path(target) def allowed_delete_roots(ctx: VersionContext): return [ ctx.exp_dir, ctx.sovits_output_dir, ctx.gpt_output_dir, AUDIO_DIR, TMP_ROOT, DATASET_DIR / "metadata.csv", ctx.input_list, ctx.text_path, ctx.semantic_path, ctx.prep_live_log, ctx.prep_status_path, ctx.sovits_live_log, ctx.sovits_status_path, ctx.gpt_live_log, ctx.gpt_status_path, ctx.gpt_train_log, ] def is_allowed_delete_target(ctx: VersionContext, target: Path): return any(path_is_under(target, root) for root in allowed_delete_roots(ctx)) def manager_action_result(version, status, selected_path=None): return ( status, *file_manager_state(version, selected_path), *sync_live_state(version), ) def delete_managed_path(version=DEFAULT_VERSION, selected_path=None): ctx = get_version_context(version) if not selected_path: return manager_action_result(version, "未选择要删除的路径。", selected_path) target = (WORK_DIR / selected_path.rstrip("/")).resolve() if not target.exists(): return manager_action_result(version, f"路径不存在: {selected_path}", None) if not is_allowed_delete_target(ctx, target): return manager_action_result(version, f"拒绝删除: {selected_path} 不在安全范围内。", selected_path) if target.is_dir(): shutil.rmtree(target) status = f"已删除目录: {selected_path}" else: target.unlink() status = f"已删除文件: {selected_path}" return manager_action_result(version, status, None) def cleanup_preprocess(version=DEFAULT_VERSION, selected_path=None): ctx = get_version_context(version) reset_preprocess_outputs(ctx) clear_live_log(ctx.prep_live_log) clear_status_file(ctx.prep_status_path) status = f"已清理 {ctx.spec.version} 的预处理产物。" return manager_action_result(version, status, selected_path) def cleanup_model_outputs(version=DEFAULT_VERSION, selected_path=None): ctx = get_version_context(version) for path in ctx.sovits_output_dir.glob("*.pth"): path.unlink() for path in ctx.gpt_output_dir.glob("*.ckpt"): path.unlink() if ctx.sovits_ckpt_dir.exists(): shutil.rmtree(ctx.sovits_ckpt_dir) ctx.sovits_ckpt_dir.mkdir(parents=True, exist_ok=True) if ctx.gpt_log_dir.exists(): shutil.rmtree(ctx.gpt_log_dir) ctx.gpt_log_dir.mkdir(parents=True, exist_ok=True) status = f"已清理 {ctx.spec.version} 的模型输出和训练断点。" return manager_action_result(version, status, selected_path) def cleanup_sovits_outputs(version=DEFAULT_VERSION): ctx = get_version_context(version) sovits_status = read_status_file(ctx.sovits_status_path) if sovits_status and sovits_status.get("alive") and sovits_status.get("pid") is not None: terminate_pid(int(sovits_status["pid"])) for path in ctx.sovits_output_dir.glob("*.pth"): path.unlink() removed_dirs = [] for directory in sorted({ctx.sovits_ckpt_dir, *ctx.exp_dir.glob("logs_s2_*")}): if directory.exists(): shutil.rmtree(directory) removed_dirs.append(directory.name) ctx.sovits_ckpt_dir.mkdir(parents=True, exist_ok=True) clear_live_log(ctx.sovits_live_log) clear_status_file(ctx.sovits_status_path) suffix = f" 已删除目录: {', '.join(removed_dirs)}" if removed_dirs else "" status = f"已清理 {ctx.spec.version} 的 SoVITS 输出和训练断点。{suffix}" return manager_action_result(version, status, None) def cleanup_live_logs(version=DEFAULT_VERSION, selected_path=None): ctx = get_version_context(version) clear_live_log(ctx.prep_live_log) clear_live_log(ctx.sovits_live_log) clear_live_log(ctx.gpt_live_log) clear_status_file(ctx.prep_status_path) clear_status_file(ctx.sovits_status_path) clear_status_file(ctx.gpt_status_path) status = f"已清理 {ctx.spec.version} 的实时日志。" return manager_action_result(version, status, selected_path) def terminate_pid(pid: int, grace_seconds: float = 5.0): try: os.kill(pid, signal.SIGTERM) except ProcessLookupError: return True, "进程不存在" deadline = time.time() + grace_seconds while time.time() < deadline: if not Path(f"/proc/{pid}").exists(): return True, "SIGTERM" time.sleep(0.2) try: os.kill(pid, signal.SIGKILL) except ProcessLookupError: return True, "SIGTERM" deadline = time.time() + 2.0 while time.time() < deadline: if not Path(f"/proc/{pid}").exists(): return True, "SIGKILL" time.sleep(0.2) return False, "进程仍存活" def stop_task_by_kind(version=DEFAULT_VERSION, kind="sovits"): ctx = get_version_context(version) mapping = { "prepare": ("预处理", ctx.prep_status_path, ctx.prep_live_log), "sovits": ("SoVITS", ctx.sovits_status_path, ctx.sovits_live_log), "gpt": ("GPT", ctx.gpt_status_path, ctx.gpt_live_log), } label, status_path, live_path = mapping[kind] status = read_status_file(status_path) if not status or status.get("pid") is None: return (f"{label} 当前没有可停止的任务。", *sync_live_state(version)) pid = int(status["pid"]) ok, method = terminate_pid(pid) status["updated_at"] = now_iso() status["alive"] = Path(f"/proc/{pid}").exists() status["state"] = "stopped" if ok else "stop_failed" status["exit_code"] = status.get("exit_code") if status.get("exit_code") is not None else -15 write_status_file(status_path, status) message = f"{label} 停止结果: {method} | pid={pid}" append_live_log(live_path, f"[控制] {message}") return (message, *sync_live_state(version)) def stop_sovits_task(version=DEFAULT_VERSION): return stop_task_by_kind(version, "sovits") def stop_gpt_task(version=DEFAULT_VERSION): return stop_task_by_kind(version, "gpt") def dataset_prepared(ctx: VersionContext): prepared = ( ctx.text_path.exists() and ctx.semantic_path.exists() and (ctx.exp_dir / "3-bert").exists() and (ctx.exp_dir / "4-cnhubert").exists() and (ctx.exp_dir / "5-wav32k").exists() ) if ctx.spec.uses_sv: prepared = prepared and (ctx.exp_dir / "7-sv_cn").exists() return prepared def build_process_env(ctx: VersionContext): env = os.environ.copy() python_path_parts = [str(GPT_SOVITS_DIR), str(GPT_SOVITS_DIR / "GPT_SoVITS")] if env.get("PYTHONPATH"): python_path_parts.append(env["PYTHONPATH"]) env.update( { "PYTHONPATH": os.pathsep.join(python_path_parts), "inp_text": str(ctx.input_list), "inp_wav_dir": str(AUDIO_DIR), "exp_name": ctx.run_name, "opt_dir": str(ctx.exp_dir), "i_part": "0", "all_parts": "1", "_CUDA_VISIBLE_DEVICES": "0", "is_half": "False", "version": ctx.spec.version, "hz": "25hz", "bert_pretrained_dir": str(BERT_DIR), "bert_path": str(BERT_DIR), "cnhubert_base_dir": str(CNHUBERT_DIR), "pretrained_s2G": str(ctx.pretrained_s2g), "s2config_path": f"GPT_SoVITS/configs/{ctx.spec.sovits_config_name}", "sv_path": str(ctx.sv_pretrained) if ctx.sv_pretrained else "", } ) return env def ensure_upstream_repo(): if (GPT_SOVITS_DIR / "webui.py").exists(): return if GPT_SOVITS_DIR.exists(): shutil.rmtree(GPT_SOVITS_DIR) subprocess.run( ["git", "clone", "--depth", "1", GPT_SOVITS_REPO, str(GPT_SOVITS_DIR)], check=True, capture_output=True, text=True, timeout=900, ) def patch_upstream_repo(): patch_marker = GPT_SOVITS_DIR / ".hf_space_patch_applied" if patch_marker.exists(): applied_version = patch_marker.read_text(encoding="utf-8").strip() if applied_version == UPSTREAM_PATCH_VERSION: return chinese2 = GPT_SOVITS_DIR / "GPT_SoVITS" / "text" / "chinese2.py" content = chinese2.read_text(encoding="utf-8") old = "is_g2pw = True # True if is_g2pw_str.lower() == 'true' else False" new = "is_g2pw = False # patched for CPU Space training; avoids extra G2PW asset" if old in content: content = content.replace(old, new, 1) chinese2.write_text(content, encoding="utf-8") sv_script = GPT_SOVITS_DIR / "GPT_SoVITS" / "prepare_datasets" / "2-get-sv.py" sv_content = sv_script.read_text(encoding="utf-8") if "from scipy.io import wavfile" not in sv_content: sv_content = sv_content.replace( "import torch\n", "import torch\nimport numpy as np\nfrom scipy.io import wavfile\n", 1, ) old_load = """ wav32k, sr0 = torchaudio.load(wav_path) assert sr0 == 32000 wav32k = wav32k.to(device) """ new_load = """ sr0, wav32k_np = wavfile.read(wav_path) assert sr0 == 32000 if wav32k_np.ndim == 1: wav32k_np = wav32k_np[None, :] else: wav32k_np = wav32k_np.T if np.issubdtype(wav32k_np.dtype, np.integer): wav32k_np = wav32k_np.astype("float32") / float(np.iinfo(wav32k_np.dtype).max) else: wav32k_np = wav32k_np.astype("float32") wav32k = torch.from_numpy(wav32k_np).to(device) """ if old_load in sv_content: sv_content = sv_content.replace(old_load, new_load, 1) sv_script.write_text(sv_content, encoding="utf-8") process_ckpt = GPT_SOVITS_DIR / "GPT_SoVITS" / "process_ckpt.py" process_ckpt_content = process_ckpt.read_text(encoding="utf-8") if 'os.makedirs(dir, exist_ok=True)' not in process_ckpt_content: process_ckpt_content = process_ckpt_content.replace( ' name = os.path.basename(path)\n', ' name = os.path.basename(path)\n os.makedirs(dir, exist_ok=True)\n', 1, ) process_ckpt_content = process_ckpt_content.replace( 'def my_save2(fea, path, model_version):\n bio = BytesIO()\n', 'def my_save2(fea, path, model_version):\n os.makedirs(os.path.dirname(path), exist_ok=True)\n bio = BytesIO()\n', 1, ) process_ckpt.write_text(process_ckpt_content, encoding="utf-8") utils_py = GPT_SOVITS_DIR / "GPT_SoVITS" / "utils.py" utils_content = utils_py.read_text(encoding="utf-8") if 'os.makedirs(dir, exist_ok=True)' not in utils_content: utils_content = utils_content.replace( ' name = os.path.basename(path)\n', ' name = os.path.basename(path)\n os.makedirs(dir, exist_ok=True)\n', 1, ) utils_py.write_text(utils_content, encoding="utf-8") data_module = GPT_SOVITS_DIR / "GPT_SoVITS" / "AR" / "data" / "data_module.py" data_module.write_text( """# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/data/data_module.py # reference: https://github.com/lifeiteng/vall-e from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader from AR.data.bucket_sampler import DistributedBucketSampler from AR.data.dataset import Text2SemanticDataset class Text2SemanticDataModule(LightningDataModule): def __init__( self, config, train_semantic_path, train_phoneme_path, dev_semantic_path=None, dev_phoneme_path=None, ): super().__init__() self.config = config self.train_semantic_path = train_semantic_path self.train_phoneme_path = train_phoneme_path self.dev_semantic_path = dev_semantic_path self.dev_phoneme_path = dev_phoneme_path self.num_workers = self.config["data"]["num_workers"] def prepare_data(self): pass def setup(self, stage=None, output_logs=False): self._train_dataset = Text2SemanticDataset( phoneme_path=self.train_phoneme_path, semantic_path=self.train_semantic_path, max_sec=self.config["data"]["max_sec"], pad_val=self.config["data"]["pad_val"], ) self._dev_dataset = self._train_dataset # self._dev_dataset = Text2SemanticDataset( # phoneme_path=self.dev_phoneme_path, # semantic_path=self.dev_semantic_path, # max_sample=self.config['data']['max_eval_sample'], # max_sec=self.config['data']['max_sec'], # pad_val=self.config['data']['pad_val']) def train_dataloader(self): batch_size = ( self.config["train"]["batch_size"] // 2 if self.config["train"].get("if_dpo", False) is True else self.config["train"]["batch_size"] ) batch_size = max(min(batch_size, len(self._train_dataset) // 4), 1) # 防止不保存 sampler = DistributedBucketSampler(self._train_dataset, batch_size=batch_size) loader_kwargs = dict( batch_size=batch_size, sampler=sampler, collate_fn=self._train_dataset.collate, num_workers=self.num_workers, ) if self.num_workers > 0: loader_kwargs["persistent_workers"] = True loader_kwargs["prefetch_factor"] = 16 return DataLoader(self._train_dataset, **loader_kwargs) def val_dataloader(self): num_workers = self.num_workers loader_kwargs = dict( batch_size=1, shuffle=False, collate_fn=self._train_dataset.collate, num_workers=num_workers, ) if num_workers > 0: loader_kwargs["persistent_workers"] = True loader_kwargs["prefetch_factor"] = 16 return DataLoader(self._dev_dataset, **loader_kwargs) # 这个会使用到嘛? def test_dataloader(self): return DataLoader( self._dev_dataset, batch_size=1, shuffle=False, collate_fn=self._train_dataset.collate, ) """, encoding="utf-8", ) s1_train = GPT_SOVITS_DIR / "GPT_SoVITS" / "s1_train.py" s1_train_content = s1_train.read_text(encoding="utf-8") trainer_anchor = " callbacks=[ckpt_callback],\n use_distributed_sampler=False," if "log_every_n_steps=1" not in s1_train_content and trainer_anchor in s1_train_content: s1_train_content = s1_train_content.replace( trainer_anchor, " callbacks=[ckpt_callback],\n log_every_n_steps=1,\n use_distributed_sampler=False,", 1, ) s1_train.write_text(s1_train_content, encoding="utf-8") s2_train = GPT_SOVITS_DIR / "GPT_SoVITS" / "s2_train.py" s2_train_content = s2_train.read_text(encoding="utf-8") old_main = """def main(): if torch.cuda.is_available(): n_gpus = torch.cuda.device_count() else: n_gpus = 1 os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = str(randint(20000, 55555)) mp.spawn( run, nprocs=n_gpus, args=( n_gpus, hps, ), ) """ new_main = """def main(): if torch.cuda.is_available(): n_gpus = torch.cuda.device_count() else: n_gpus = 1 os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = str(randint(20000, 55555)) if not torch.cuda.is_available() and n_gpus == 1: run(0, n_gpus, hps) return mp.spawn( run, nprocs=n_gpus, args=( n_gpus, hps, ), ) """ if old_main in s2_train_content: s2_train_content = s2_train_content.replace(old_main, new_main, 1) old_loader = """ train_loader = DataLoader( train_dataset, num_workers=5, shuffle=False, pin_memory=True, collate_fn=collate_fn, batch_sampler=train_sampler, persistent_workers=True, prefetch_factor=3, ) """ new_loader = """ train_loader_kwargs = dict( shuffle=False, collate_fn=collate_fn, batch_sampler=train_sampler, num_workers=0 if not torch.cuda.is_available() else 5, pin_memory=torch.cuda.is_available(), ) if train_loader_kwargs[\"num_workers\"] > 0: train_loader_kwargs[\"persistent_workers\"] = True train_loader_kwargs[\"prefetch_factor\"] = 3 train_loader = DataLoader(train_dataset, **train_loader_kwargs) """ if old_loader in s2_train_content: s2_train_content = s2_train_content.replace(old_loader, new_loader, 1) s2_train.write_text(s2_train_content, encoding="utf-8") patch_marker.write_text(f"{UPSTREAM_PATCH_VERSION}\n", encoding="utf-8") def ensure_base_assets(version=DEFAULT_VERSION): ctx = get_version_context(version) if not has_transformers_model(BERT_DIR): snapshot_download( repo_id=BERT_REPO, local_dir=str(BERT_DIR), allow_patterns=MODEL_PATTERNS, **hf_kwargs(), ) if not has_transformers_model(CNHUBERT_DIR): snapshot_download( repo_id=CNHUBERT_REPO, local_dir=str(CNHUBERT_DIR), allow_patterns=MODEL_PATTERNS, **hf_kwargs(), ) if not ctx.pretrained_s1.exists(): hf_hub_download( repo_id=PRETRAINED_REPO, filename=ctx.spec.pretrained_s1_rel, local_dir=str(PRETRAINED_DIR), **hf_kwargs(), ) if not ctx.pretrained_s2g.exists(): hf_hub_download( repo_id=PRETRAINED_REPO, filename=ctx.spec.pretrained_s2g_rel, local_dir=str(PRETRAINED_DIR), **hf_kwargs(), ) if not ctx.pretrained_s2d.exists(): try: hf_hub_download( repo_id=PRETRAINED_REPO, filename=ctx.spec.pretrained_s2d_rel, local_dir=str(PRETRAINED_DIR), **hf_kwargs(), ) except Exception: log.warning("Optional pretrained discriminator not found: %s", ctx.spec.pretrained_s2d_rel) if ctx.sv_pretrained and not ctx.sv_pretrained.exists(): hf_hub_download( repo_id=PRETRAINED_REPO, filename=SV_PRETRAINED_REL, local_dir=str(PRETRAINED_DIR), **hf_kwargs(), ) def reset_preprocess_outputs(ctx: VersionContext): for path in [ ctx.input_list, ctx.text_path, ctx.semantic_path, ctx.exp_dir / "2-name2text-0.txt", ctx.exp_dir / "6-name2semantic-0.tsv", ]: if path.exists(): path.unlink() directories = [ctx.exp_dir / "3-bert", ctx.exp_dir / "4-cnhubert", ctx.exp_dir / "5-wav32k"] if ctx.spec.uses_sv: directories.append(ctx.exp_dir / "7-sv_cn") for directory in directories: if directory.exists(): shutil.rmtree(directory) def ensure_preprocess_dirs(ctx: VersionContext): directories = [ctx.exp_dir / "3-bert", ctx.exp_dir / "4-cnhubert", ctx.exp_dir / "5-wav32k"] if ctx.spec.uses_sv: directories.append(ctx.exp_dir / "7-sv_cn") for directory in directories: directory.mkdir(parents=True, exist_ok=True) def reset_gpt_training_outputs(ctx: VersionContext): if ctx.gpt_log_dir.exists(): shutil.rmtree(ctx.gpt_log_dir) ctx.gpt_log_dir.mkdir(parents=True, exist_ok=True) for path in ctx.gpt_output_dir.glob("*.ckpt"): path.unlink() def build_manifest(ctx: VersionContext): rows = metadata_rows() audio_files = {item.name for item in AUDIO_DIR.glob("*.wav")} listed = set() output = [] for row in rows: wav_name = (row.get("file") or "").strip() text = (row.get("text") or "").strip() if not wav_name or not text or wav_name not in audio_files: continue listed.add(wav_name) output.append(f"{wav_name}|{SPEAKER}|{LANGUAGE}|{text}") if not output: raise RuntimeError("metadata.csv 里没有可用训练样本") ctx.input_list.write_text("\n".join(output) + "\n", encoding="utf-8") unlisted = sorted(audio_files - listed) return len(output), len(audio_files), unlisted def create_sovits_config(ctx: VersionContext, epochs, batch_size, save_every_epoch, learning_rate): config_path = GPT_SOVITS_DIR / "GPT_SoVITS" / "configs" / ctx.spec.sovits_config_name with config_path.open("r", encoding="utf-8") as handle: data = json.load(handle) data["train"]["fp16_run"] = False data["train"]["batch_size"] = int(batch_size) data["train"]["epochs"] = int(epochs) data["train"]["learning_rate"] = float(learning_rate) data["train"]["pretrained_s2G"] = str(ctx.pretrained_s2g) data["train"]["pretrained_s2D"] = str(ctx.pretrained_s2d) if ctx.pretrained_s2d.exists() else "" data["train"]["if_save_latest"] = False data["train"]["if_save_every_weights"] = True data["train"]["save_every_epoch"] = int(save_every_epoch) data["train"]["gpu_numbers"] = "0" data["train"]["grad_ckpt"] = False data["data"]["exp_dir"] = str(ctx.exp_dir) data["s2_ckpt_dir"] = str(ctx.exp_dir) data["save_weight_dir"] = str(ctx.sovits_output_dir) data["name"] = ctx.run_name data["version"] = ctx.spec.version data["model"]["version"] = ctx.spec.version tmp_config = WORK_DIR / f"tmp_s2_{ctx.spec.version}.json" tmp_config.write_text(json.dumps(data), encoding="utf-8") return tmp_config def create_gpt_config(ctx: VersionContext, epochs, batch_size, save_every_epoch): config_path = GPT_SOVITS_DIR / "GPT_SoVITS" / "configs" / "s1longer-v2.yaml" with config_path.open("r", encoding="utf-8") as handle: data = yaml.safe_load(handle) data["train"]["batch_size"] = int(batch_size) data["train"]["epochs"] = int(epochs) data["train"]["precision"] = "32" data["train"]["save_every_n_epoch"] = int(save_every_epoch) data["train"]["if_save_every_weights"] = True data["train"]["if_save_latest"] = False data["train"]["if_dpo"] = False data["train"]["exp_name"] = ctx.run_name data["train"]["half_weights_save_dir"] = str(ctx.gpt_output_dir) # cpu-basic Space 上 worker>0 会走共享内存路径,batch 4 时不稳定。 data["data"]["num_workers"] = 0 data["pretrained_s1"] = str(ctx.pretrained_s1) data["train_semantic_path"] = str(ctx.semantic_path) data["train_phoneme_path"] = str(ctx.text_path) data["output_dir"] = str(ctx.gpt_log_dir) tmp_config = WORK_DIR / f"tmp_s1_{ctx.spec.version}.yaml" tmp_config.write_text(yaml.safe_dump(data, allow_unicode=True, sort_keys=False), encoding="utf-8") return tmp_config def setup_environment_steps(logs, version, live_path: Path | None = None): ctx = get_version_context(version) ensure_dirs() yield push(logs, f"目标版本:{ctx.spec.version}({ctx.spec.stability})", live_path) if (GPT_SOVITS_DIR / "webui.py").exists(): yield push(logs, "GPT-SoVITS 仓库已存在,跳过克隆。", live_path) else: yield push(logs, "克隆 GPT-SoVITS 仓库...", live_path) ensure_upstream_repo() yield push(logs, "✅ GPT-SoVITS 仓库已就绪。", live_path) patch_upstream_repo() yield push(logs, "✅ 已应用 Space 兼容补丁。", live_path) if not has_transformers_model(BERT_DIR): yield push(logs, "下载中文 BERT 特征模型...", live_path) if not has_transformers_model(CNHUBERT_DIR): yield push(logs, "下载 CN-HuBERT 特征模型...", live_path) if not ctx.pretrained_s1.exists() or not ctx.pretrained_s2g.exists(): yield push(logs, f"下载 GPT-SoVITS {ctx.spec.version} 底模...", live_path) if ctx.sv_pretrained and not ctx.sv_pretrained.exists(): yield push(logs, f"下载 {ctx.spec.version} 的 speaker embedding 底模...", live_path) ensure_base_assets(version) yield push( logs, f"✅ 环境就绪:GPT-SoVITS 仓库、中文特征模型和 {ctx.spec.version} 底模均已准备完成。", live_path, ) def download_dataset_steps(logs, version, live_path: Path | None = None): ensure_dirs() yield from setup_environment_steps(logs, version, live_path) yield push(logs, "下载 Daniya 数据集...", live_path) snapshot_download( repo_id=DATASET_REPO, repo_type="dataset", local_dir=str(DATASET_DIR), **hf_kwargs(), ) rows = metadata_rows() audio_count = len(list(AUDIO_DIR.glob("*.wav"))) yield push( logs, f"✅ 数据集已下载:音频 {audio_count} 个,metadata {len(rows)} 条。", live_path, ) def prepare_data_steps(logs, version, live_path: Path | None = None): ctx = get_version_context(version) ensure_dirs() lock_path = ctx.exp_dir / ".preprocess.lock" wait_notice_at = 0.0 with lock_path.open("a+", encoding="utf-8") as lock_handle: while True: try: fcntl.flock(lock_handle.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB) break except BlockingIOError: now = time.time() if now >= wait_notice_at: yield push(logs, f"{ctx.spec.version} 预处理中,等待已有任务完成...", live_path) wait_notice_at = now + 10 time.sleep(2) try: if dataset_prepared(ctx): return f"✅ 预处理已就绪({ctx.spec.version}),无需重复执行。" yield from download_dataset_steps(logs, version, live_path) reset_preprocess_outputs(ctx) sample_count, audio_count, unlisted = build_manifest(ctx) ensure_preprocess_dirs(ctx) yield push( logs, f"训练清单已生成:metadata 可用样本 {sample_count} 条,音频总数 {audio_count} 个,未标注音频 {len(unlisted)} 个。", live_path, ) env = build_process_env(ctx) for line in run_cmd( python_script_command("GPT_SoVITS/prepare_datasets/1-get-text.py"), cwd=GPT_SOVITS_DIR, env=env, ): yield push(logs, line, live_path) part_text = ctx.exp_dir / "2-name2text-0.txt" if not part_text.exists(): raise RuntimeError("文本特征提取完成后未生成 2-name2text-0.txt") part_text.replace(ctx.text_path) yield push(logs, "✅ 文本分词与 BERT 特征提取完成。", live_path) for line in run_cmd( python_script_command("GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py"), cwd=GPT_SOVITS_DIR, env=env, ): yield push(logs, line, live_path) yield push(logs, "✅ CN-HuBERT 特征与 32k wav 已生成。", live_path) if ctx.spec.uses_sv: for line in run_cmd( python_script_command("GPT_SoVITS/prepare_datasets/2-get-sv.py"), cwd=GPT_SOVITS_DIR, env=env, ): yield push(logs, line, live_path) yield push(logs, "✅ speaker embedding 特征已生成。", live_path) for line in run_cmd( python_script_command("GPT_SoVITS/prepare_datasets/3-get-semantic.py"), cwd=GPT_SOVITS_DIR, env=env, ): yield push(logs, line, live_path) part_semantic = ctx.exp_dir / "6-name2semantic-0.tsv" if not part_semantic.exists(): raise RuntimeError("语义 token 提取完成后未生成 6-name2semantic-0.tsv") semantic_rows = part_semantic.read_text(encoding="utf-8").strip() ctx.semantic_path.write_text( "item_name\tsemantic_audio\n" + semantic_rows + ("\n" if semantic_rows else ""), encoding="utf-8", ) part_semantic.unlink() yield push(logs, "✅ 语义 token 提取完成。", live_path) return f"✅ 预处理完成({ctx.spec.version}),可用于训练的样本 {sample_count} 条。" finally: fcntl.flock(lock_handle.fileno(), fcntl.LOCK_UN) def check_environment(version=DEFAULT_VERSION): logs = [] ctx = get_version_context(version) clear_live_log(ctx.prep_live_log) clear_status_file(ctx.prep_status_path) try: final = None final = yield from setup_environment_steps(logs, version, ctx.prep_live_log) if final: yield final except Exception as exc: log.exception("check_environment") yield push(logs, f"❌ 环境准备失败: {exc}", ctx.prep_live_log) def download_dataset(version=DEFAULT_VERSION): logs = [] ctx = get_version_context(version) clear_live_log(ctx.prep_live_log) clear_status_file(ctx.prep_status_path) try: final = None final = yield from download_dataset_steps(logs, version, ctx.prep_live_log) if final: yield final except Exception as exc: log.exception("download_dataset") yield push(logs, f"❌ 数据集下载失败: {exc}", ctx.prep_live_log) def prepare_data(version=DEFAULT_VERSION): logs = [] ctx = get_version_context(version) clear_live_log(ctx.prep_live_log) clear_status_file(ctx.prep_status_path) try: final = yield from prepare_data_steps(logs, version, ctx.prep_live_log) yield push(logs, final, ctx.prep_live_log) except Exception as exc: log.exception("prepare_data") yield push(logs, f"❌ 预处理失败: {exc}", ctx.prep_live_log) def start_training(version=DEFAULT_VERSION, epochs=2, batch_size=1, save_every_epoch=1, lr=0.0001): logs = [] ctx = get_version_context(version) clear_live_log(ctx.sovits_live_log) clear_status_file(ctx.sovits_status_path) try: yield push(logs, f"当前版本:{ctx.spec.version}({ctx.spec.stability})", ctx.sovits_live_log), None for update in setup_environment_steps(logs, version, ctx.sovits_live_log): yield update, None if not dataset_prepared(ctx): yield push(logs, f"{ctx.spec.version} 缺少预处理产物,开始自动补齐...", ctx.sovits_live_log), None for update in prepare_data_steps(logs, version, ctx.sovits_live_log): yield update, None config_path = create_sovits_config(ctx, epochs, batch_size, save_every_epoch, lr) env = build_process_env(ctx) yield push(logs, f"开始 SoVITS 训练({ctx.spec.version})...", ctx.sovits_live_log), None for line in run_cmd( python_script_command( "GPT_SoVITS/s2_train.py", "--config", str(config_path), file_system_sharing=True, ), cwd=GPT_SOVITS_DIR, env=env, status_path=ctx.sovits_status_path, live_path=ctx.sovits_live_log, status_label=f"{ctx.spec.version} SoVITS", ): yield push(logs, line, ctx.sovits_live_log), None latest = latest_file(ctx.sovits_output_dir, ".pth") if not latest: raise RuntimeError("训练结束后没有找到导出的 SoVITS 权重文件") yield push(logs, f"✅ SoVITS 训练完成,最新权重:{latest}", ctx.sovits_live_log), latest except Exception as exc: log.exception("start_training") yield push(logs, f"❌ SoVITS 训练失败: {exc}", ctx.sovits_live_log), None def start_gpt_training(version=DEFAULT_VERSION, epochs=1, batch_size=1, save_every_epoch=1): logs = [] ctx = get_version_context(version) clear_live_log(ctx.gpt_live_log) clear_status_file(ctx.gpt_status_path) try: yield push(logs, f"当前版本:{ctx.spec.version}({ctx.spec.stability})", ctx.gpt_live_log), None for update in setup_environment_steps(logs, version, ctx.gpt_live_log): yield update, None if not dataset_prepared(ctx): yield push(logs, f"{ctx.spec.version} 缺少预处理产物,开始自动补齐...", ctx.gpt_live_log), None for update in prepare_data_steps(logs, version, ctx.gpt_live_log): yield update, None reset_gpt_training_outputs(ctx) yield push(logs, "已清理旧 GPT 断点与导出文件,避免恢复到不兼容 checkpoint。", ctx.gpt_live_log), None config_path = create_gpt_config(ctx, epochs, batch_size, save_every_epoch) env = build_process_env(ctx) yield push(logs, f"开始 GPT 训练({ctx.spec.version})...", ctx.gpt_live_log), None for line in run_cmd( python_script_command( "GPT_SoVITS/s1_train.py", "--config_file", str(config_path), file_system_sharing=True, ), cwd=GPT_SOVITS_DIR, env=env, status_path=ctx.gpt_status_path, live_path=ctx.gpt_live_log, status_label=f"{ctx.spec.version} GPT", ): yield push(logs, line, ctx.gpt_live_log), None latest = latest_file(ctx.gpt_output_dir, ".ckpt") if not latest: raise RuntimeError("训练结束后没有找到导出的 GPT 权重文件") yield push(logs, f"✅ GPT 训练完成,最新权重:{latest}", ctx.gpt_live_log), latest except Exception as exc: log.exception("start_gpt_training") yield push(logs, f"❌ GPT 训练失败: {exc}", ctx.gpt_live_log), None def refresh_outputs(version=DEFAULT_VERSION): return artifacts_summary(version) def live_logs(version=DEFAULT_VERSION): ctx = get_version_context(version) return ( combined_log_text(ctx.prep_live_log, ctx.prep_status_path), combined_log_text(ctx.sovits_live_log, ctx.sovits_status_path), combined_log_text( ctx.gpt_live_log, ctx.gpt_status_path, extra_paths=[("train.log", ctx.gpt_train_log)], ), ) def sync_live_state(version=DEFAULT_VERSION): return ( runtime_status_text(version), *live_logs(version), *refresh_outputs(version), agent_status(version), ) def version_markdown(version=DEFAULT_VERSION): spec = get_version_spec(version) return ( f"### 0. 版本设置\n" f"- 当前选择:`{spec.version}`\n" f"- 稳定性:{spec.stability}\n" f"- 默认建议:`{DEFAULT_VERSION}` 作为常规训练;`v2Pro` / `v2ProPlus` 作为可切换版本。" ) def file_record(path: Path): return { "name": path.name, "path": str(path), "size_bytes": path.stat().st_size, "size_human": format_size(path.stat().st_size), "modified_at": format_mtime(path), } def agent_status(version=DEFAULT_VERSION): ctx = get_version_context(version) sovits_files = list_files(ctx.sovits_output_dir, ".pth") gpt_files = list_files(ctx.gpt_output_dir, ".ckpt") return { "version": ctx.spec.version, "recommended_default": DEFAULT_VERSION, "stability": ctx.spec.stability, "dataset_prepared": dataset_prepared(ctx), "storage": { "mode": STORAGE_MODE, "root": str(ACTIVE_STORAGE_ROOT), "persistent_root": str(PERSISTENT_ROOT), "tmp_root": str(TMP_ROOT), }, "paths": { "work_dir": str(WORK_DIR), "tmp_root": str(TMP_ROOT), "dataset_dir": str(DATASET_DIR), "exp_dir": str(ctx.exp_dir), "sovits_output_dir": str(ctx.sovits_output_dir), "gpt_output_dir": str(ctx.gpt_output_dir), }, "counts": { "sovits": len(sovits_files), "gpt": len(gpt_files), }, "latest": { "sovits": file_record(sovits_files[0]) if sovits_files else None, "gpt": file_record(gpt_files[0]) if gpt_files else None, }, "files": { "sovits": [file_record(path) for path in sovits_files[:10]], "gpt": [file_record(path) for path in gpt_files[:10]], }, "tasks": { "prepare": refresh_task_status(read_status_file(ctx.prep_status_path)), "sovits": refresh_task_status(read_status_file(ctx.sovits_status_path)), "gpt": refresh_task_status(read_status_file(ctx.gpt_status_path)), }, } def task_brief(label: str, task): if not task: return f"{label}: 空闲" parts = [f"{label}: {task.get('state', 'unknown')}"] if task.get("pid") is not None: parts.append(f"pid={task['pid']}") if task.get("cpu_percent") is not None: parts.append(f"cpu={task['cpu_percent']}%") if task.get("rss_mb") is not None: parts.append(f"rss={task['rss_mb']} MB") if task.get("silent_for_seconds") is not None: parts.append(f"静默={task['silent_for_seconds']}s") if task.get("exit_code") is not None: parts.append(f"exit={task['exit_code']}") if task.get("last_output_line"): parts.append(f"最后输出={task['last_output_line'][:120]}") return " | ".join(parts) def runtime_status_text(version=DEFAULT_VERSION): status = agent_status(version) lines = [ f"版本: {status['version']}", f"数据预处理: {'已就绪' if status['dataset_prepared'] else '未就绪'}", task_brief("预处理", status["tasks"]["prepare"]), task_brief("SoVITS", status["tasks"]["sovits"]), task_brief("GPT", status["tasks"]["gpt"]), ] latest_gpt = status["latest"]["gpt"] latest_sovits = status["latest"]["sovits"] if latest_gpt: lines.append(f"最新 GPT: {latest_gpt['name']} | {latest_gpt['size_human']} | {latest_gpt['path']}") if latest_sovits: lines.append(f"最新 SoVITS: {latest_sovits['name']} | {latest_sovits['size_human']} | {latest_sovits['path']}") return "\n".join(lines) def load_dashboard(): version = DEFAULT_VERSION return (version_markdown(version), *sync_live_state(version)) def create_ui(): with gr.Blocks(title="GPT-SoVITS 训练器 — 达妮娅", theme=gr.themes.Soft()) as demo: gr.Markdown( "# 🎤 GPT-SoVITS 训练器 — 达妮娅语音\n" "这个 Space 按当前 GPT-SoVITS 训练链路执行。默认推荐 v2,也支持切到 v2Pro / v2ProPlus。" ) gr.Markdown( "打开页面会自动加载当前输出。训练完成后,最新模型会直接出现在下载框里," "下面的“输出与目录”也会自动刷新,避免找不到文件。" ) gr.Markdown("日志会同步写入 `/data`,页面每 2 秒自动轮询;即使刷新页面,也会把当前训练日志重新拉回来。") version_select = gr.Dropdown( choices=list(SUPPORTED_VERSIONS), value=DEFAULT_VERSION, label="模型版本", info="v2 最稳;v2Pro / v2ProPlus 按官方 main 接法切换。", ) version_note = gr.Markdown(value=version_markdown(DEFAULT_VERSION)) runtime_status = gr.Textbox(label="运行状态", lines=8, interactive=False, autoscroll=True) control_status = gr.Textbox(label="任务控制", lines=2, interactive=False) with gr.Row(): stop_sovits_btn = gr.Button("停止 SoVITS", variant="stop") stop_gpt_btn = gr.Button("停止 GPT", variant="stop") with gr.Row(): with gr.Column(scale=1): gr.Markdown("### 1. 环境") env_btn = gr.Button("准备环境", variant="secondary") env_out = gr.Textbox(label="环境状态", lines=8, interactive=False, autoscroll=True) gr.Markdown("### 2. 数据集") dataset_btn = gr.Button("下载数据集", variant="secondary") dataset_out = gr.Textbox(label="数据状态", lines=8, interactive=False, autoscroll=True) gr.Markdown("### 3. 预处理") prep_btn = gr.Button("生成训练特征", variant="primary") prep_out = gr.Textbox(label="预处理日志", lines=16, interactive=False, autoscroll=True) with gr.Column(scale=1): gr.Markdown("### 4. SoVITS 训练") sovits_epochs = gr.Slider(1, 50, value=2, step=1, label="训练轮数") sovits_batch = gr.Slider(1, 8, value=1, step=1, label="批次大小") sovits_save_every = gr.Slider(1, 5, value=1, step=1, label="每隔多少轮导出") sovits_lr = gr.Slider(1e-5, 5e-4, value=1e-4, step=1e-5, label="学习率") sovits_btn = gr.Button("开始 SoVITS 训练", variant="primary", size="lg") gr.Markdown("### 5. GPT 训练") gpt_epochs = gr.Slider(1, 50, value=1, step=1, label="训练轮数") gpt_batch = gr.Slider(1, 8, value=1, step=1, label="批次大小") gpt_save_every = gr.Slider(1, 5, value=1, step=1, label="每隔多少轮导出") gpt_btn = gr.Button("开始 GPT 训练", variant="secondary") gr.Markdown("### 6. SoVITS 实时日志与下载") sovits_log = gr.Textbox(label="SoVITS 训练日志", lines=22, interactive=False, autoscroll=True) sovits_file = gr.File(label="最新 SoVITS 权重", interactive=False) gr.Markdown("### 7. GPT 实时日志与下载") gpt_log = gr.Textbox(label="GPT 训练日志", lines=22, interactive=False, autoscroll=True) gpt_file = gr.File(label="最新 GPT 权重", interactive=False) gr.Markdown("### 8. 输出与目录") refresh_btn = gr.Button("刷新输出与目录", variant="secondary") refresh_text = gr.Textbox(label="模型列表与状态", lines=14, interactive=False, autoscroll=True) output_dirs = gr.Textbox(label="工作目录与输出目录", lines=6, interactive=False) gr.Markdown("### 9. Agent API") gr.Markdown( "`/gradio_api/info` 会公开 schema。主要 endpoint:" "`/check_environment(version)`、`/download_dataset(version)`、`/prepare_data(version)`、" "`/start_training(version, epochs, batch_size, save_every_epoch, lr)`、" "`/start_gpt_training(version, epochs, batch_size, save_every_epoch)`、" "`/refresh_outputs(version)`、`/agent_status(version)`、" "`/live_logs(version)`、`/sync_live_state(version)`。" ) agent_status_out = gr.JSON(label="Agent 状态", value=agent_status(DEFAULT_VERSION)) agent_status_btn = gr.Button("刷新 Agent 状态", variant="secondary") sync_timer = gr.Timer(value=2, active=True) refresh_api_btn = gr.Button(visible=False) live_logs_api_btn = gr.Button(visible=False) sync_state_api_btn = gr.Button(visible=False) gr.Markdown("### 10. 文件管理") gr.Markdown("支持安全清理当前版本的预处理、模型输出和实时日志,也可以删除选中的关键路径。") manager_status = gr.Textbox(label="文件管理操作结果", lines=3, interactive=False) manager_text = gr.Textbox(label="文件管理概览", lines=18, interactive=False, autoscroll=True) manager_select = gr.Dropdown(choices=[], value=None, label="可管理路径") manager_detail = gr.Textbox(label="所选路径详情", lines=5, interactive=False) with gr.Row(): manager_refresh_btn = gr.Button("刷新文件列表", variant="secondary") manager_delete_btn = gr.Button("删除所选路径", variant="stop") cleanup_prep_btn = gr.Button("清理本版本预处理", variant="secondary") cleanup_sovits_btn = gr.Button("只清 SoVITS 输出", variant="secondary") cleanup_models_btn = gr.Button("清理本版本模型输出", variant="secondary") cleanup_live_logs_btn = gr.Button("清理本版本实时日志", variant="secondary") refresh_outputs_targets = [refresh_text, output_dirs] sync_targets = [runtime_status, prep_out, sovits_log, gpt_log, *refresh_outputs_targets, agent_status_out] control_targets = [control_status, *sync_targets] manager_targets = [manager_text, manager_select, manager_detail] manager_action_targets = [manager_status, *manager_targets, *sync_targets] env_btn.click(check_environment, inputs=[version_select], outputs=env_out, api_name="check_environment") dataset_btn.click(download_dataset, inputs=[version_select], outputs=dataset_out, api_name="download_dataset") prep_btn.click(prepare_data, inputs=[version_select], outputs=prep_out, api_name="prepare_data") sovits_event = sovits_btn.click( start_training, inputs=[version_select, sovits_epochs, sovits_batch, sovits_save_every, sovits_lr], outputs=[sovits_log, sovits_file], api_name="start_training", ) gpt_event = gpt_btn.click( start_gpt_training, inputs=[version_select, gpt_epochs, gpt_batch, gpt_save_every], outputs=[gpt_log, gpt_file], api_name="start_gpt_training", ) refresh_btn.click(sync_live_state, inputs=[version_select], outputs=sync_targets, api_name=False) refresh_api_btn.click(refresh_outputs, inputs=[version_select], outputs=refresh_outputs_targets, api_name="refresh_outputs") live_logs_api_btn.click(live_logs, inputs=[version_select], outputs=[prep_out, sovits_log, gpt_log], api_name="live_logs") sync_state_api_btn.click(sync_live_state, inputs=[version_select], outputs=sync_targets, api_name="sync_live_state") agent_status_btn.click(agent_status, inputs=[version_select], outputs=agent_status_out, api_name="agent_status") stop_sovits_btn.click(stop_sovits_task, inputs=[version_select], outputs=control_targets, api_name="stop_sovits_task") stop_gpt_btn.click(stop_gpt_task, inputs=[version_select], outputs=control_targets, api_name="stop_gpt_task") manager_refresh_btn.click(file_manager_state, inputs=[version_select, manager_select], outputs=manager_targets, api_name="file_manager_state") manager_select.change(describe_managed_path, inputs=[version_select, manager_select], outputs=manager_detail, api_name=False) manager_delete_btn.click( delete_managed_path, inputs=[version_select, manager_select], outputs=manager_action_targets, api_name="delete_managed_path", ) cleanup_prep_btn.click( cleanup_preprocess, inputs=[version_select, manager_select], outputs=manager_action_targets, api_name="cleanup_preprocess", ) cleanup_sovits_btn.click( cleanup_sovits_outputs, inputs=[version_select], outputs=manager_action_targets, api_name="cleanup_sovits_outputs", ) cleanup_models_btn.click( cleanup_model_outputs, inputs=[version_select, manager_select], outputs=manager_action_targets, api_name="cleanup_model_outputs", ) cleanup_live_logs_btn.click( cleanup_live_logs, inputs=[version_select, manager_select], outputs=manager_action_targets, api_name="cleanup_live_logs", ) version_select.change(version_markdown, inputs=[version_select], outputs=version_note, api_name=False) version_select.change( sync_live_state, inputs=[version_select], outputs=sync_targets, api_name=False, ) version_select.change( file_manager_state, inputs=[version_select, manager_select], outputs=manager_targets, api_name=False, ) sovits_event.then( sync_live_state, inputs=[version_select], outputs=sync_targets, api_name=False, ) gpt_event.then( sync_live_state, inputs=[version_select], outputs=sync_targets, api_name=False, ) sync_timer.tick( sync_live_state, inputs=[version_select], outputs=sync_targets, api_name=False, queue=False, show_progress="hidden", ) demo.load( load_dashboard, outputs=[version_note, *sync_targets], api_name=False, ) return demo if __name__ == "__main__": ensure_dirs() demo = create_ui() demo.launch(server_name="0.0.0.0", server_port=7860)