import os
import sys

# Fork-specific search path; presumably where the compat shims below live
# (TODO confirm). It is inserted first so it wins over other sys.path entries.
sys.path.insert(0, r'D:\AI_Training\spock_lora')

# Both shims are best effort: an import failure is reported and start-up goes on.
try:
    import spock_compat_shim  # noqa: F401  # Spock fork compat shim
except Exception as _e:
    print(f"[spock_compat_shim] warn: {_e}")
try:
    import torchao_compat  # noqa: F401  # backport UIntXWeightOnlyConfig for torchao 0.17+
except Exception as _e:
    print(f"[torchao_compat] warn: {_e}")

from dotenv import load_dotenv

# Load the .env file if it exists
load_dotenv()

# Keep a value that is already set (shell or .env); otherwise default to "1".
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1"

# Optional global seed taken from the environment. A missing SEED leaves it as
# None; a non-integer SEED is reported and likewise leaves it as None.
seed = None
try:
    seed = int(os.environ["SEED"])
except KeyError:
    pass
except ValueError:
    print(f"Invalid SEED value: {os.environ['SEED']}. SEED must be an integer.")

sys.path.insert(0, os.getcwd())
# must come before ANY torch or fastai imports
# import toolkit.cuda_malloc

# turn off diffusers telemetry until I can figure out how to make it opt-in
os.environ['DISABLE_TELEMETRY'] = 'YES'

import torch

# Spock fork: pin CPU thread count to match 16-core Ryzen 9 9950X.
# Without this, PyTorch defaults to the OS's view which can underutilize.
def _env_int(name, default):
    """Return environment variable *name* parsed as an int, or *default*.

    An unset variable yields *default*. A value that does not parse as an
    integer is reported on stdout and *default* is used, so a typo in the
    environment does not abort start-up with a bare ValueError.
    """
    raw = os.environ.get(name)
    if raw is None:
        return default
    try:
        return int(raw)
    except ValueError:
        print(f"Invalid {name} value: {raw}. {name} must be an integer.")
        return default


# Intra-op and inter-op thread counts default to 16 and can be overridden
# through TORCH_NUM_THREADS / TORCH_NUM_INTEROP_THREADS.
torch.set_num_threads(_env_int("TORCH_NUM_THREADS", 16))
torch.set_num_interop_threads(_env_int("TORCH_NUM_INTEROP_THREADS", 16))
print(f"[spock_fork] torch threads: {torch.get_num_threads()} intraop, {torch.get_num_interop_threads()} interop")

# check if we have DEBUG_TOOLKIT in env
if os.environ.get("DEBUG_TOOLKIT", "0") == "1":
    torch.autograd.set_detect_anomaly(True)

# Seed every RNG this script touches when an integer SEED was supplied.
if seed is not None:
    import random
    import numpy as np
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

import argparse
from toolkit.job import get_job
from toolkit.accelerator import get_accelerator
from toolkit.print import print_acc, setup_log_to_file

accelerator = get_accelerator()


def print_end_message(jobs_completed, jobs_failed):
    """Print the completed/failed job summary on the main process only.

    Args:
        jobs_completed: Number of jobs that ran to completion.
        jobs_failed: Number of jobs that raised; the failure line is omitted
            when this is not greater than zero.
    """
    if not accelerator.is_main_process:
        return

    completed_string = f"{jobs_completed} completed job{'' if jobs_completed == 1 else 's'}"
    failure_string = ""
    if jobs_failed > 0:
        failure_string = f"{jobs_failed} failure{'' if jobs_failed == 1 else 's'}"

    print_acc("")
    print_acc("========================================")
    print_acc("Result:")
    print_acc(f" - {completed_string}")
    if failure_string:
        print_acc(f" - {failure_string}")
    print_acc("========================================")


def _notify_job_error(job, error):
    """Best-effort call of ``on_error`` on the first process of *job*.

    Does nothing when *job* is None, i.e. when ``get_job`` itself failed and
    there is no job object to notify. A failure inside the hook is reported
    and swallowed so it cannot mask the original error.
    """
    if job is None:
        return
    try:
        job.process[0].on_error(error)
    except Exception as e2:
        print_acc(f"Error running on_error: {e2}")


def main():
    """Parse the command line and run every listed config file in order.

    Each config is loaded with ``get_job``, run, and cleaned up. When a job
    raises, its ``on_error`` hook is invoked; without ``--recover`` the summary
    is printed and the error is re-raised, with ``--recover`` the loop moves on
    to the next config.

    NOTE(review): with ``--recover`` a KeyboardInterrupt also continues with
    the next job, and no summary is printed after a fully completed run. Both
    match the previous behavior and are left unchanged here; confirm intent.

    Raises:
        Exception: If no config file was supplied, or the error of the first
            failing job when ``--recover`` is not set.
        KeyboardInterrupt: When interrupted and ``--recover`` is not set.
    """
    parser = argparse.ArgumentParser()

    # require at least one config file
    parser.add_argument(
        'config_file_list',
        nargs='+',
        type=str,
        help='Name of config file (eg: person_v1 for config/person_v1.json/yaml), or full path if it is not in config folder, you can pass multiple config files and run them all sequentially'
    )

    # flag to continue if failed job
    parser.add_argument(
        '-r', '--recover',
        action='store_true',
        help='Continue running additional jobs even if a job fails'
    )

    # replacement for the [name] tag in the config file
    parser.add_argument(
        '-n', '--name',
        type=str,
        default=None,
        help='Name to replace [name] tag in config file, useful for shared config file'
    )

    # optional log file
    parser.add_argument(
        '-l', '--log',
        type=str,
        default=None,
        help='Log file to write output to'
    )
    args = parser.parse_args()

    if args.log is not None:
        setup_log_to_file(args.log)

    config_file_list = args.config_file_list
    if not config_file_list:
        raise Exception("You must provide at least one config file")

    jobs_completed = 0
    jobs_failed = 0

    if accelerator.is_main_process:
        print_acc(f"Running {len(config_file_list)} job{'' if len(config_file_list) == 1 else 's'}")

    for config_file in config_file_list:
        # Reset per iteration so a failure inside get_job() can never notify
        # the previous (already finished) job or hit an unbound name.
        job = None
        try:
            job = get_job(config_file, args.name)
            job.run()
            job.cleanup()
            jobs_completed += 1
        except Exception as e:
            print_acc(f"Error running job: {e}")
            jobs_failed += 1
            _notify_job_error(job, e)
            if not args.recover:
                print_end_message(jobs_completed, jobs_failed)
                raise
        except KeyboardInterrupt as e:
            _notify_job_error(job, e)
            if not args.recover:
                print_end_message(jobs_completed, jobs_failed)
                raise


if __name__ == '__main__':
    main()