import argparse
import json
import logging
import os
from datetime import datetime
from pathlib import Path

import torch
import yaml
from torch.utils.data import DataLoader

from dataset.lidar_dataset import LidarFusionDataset
from dataset.preprocess_multi_processing import run_parallel
from model.model import SimpleMLP
from train.test import run_test_inference
from train.train import evaluate_on_val, train_model
from utils.logging_utils import init_logging


def _load_model(cfg, weights_path):
    """Build a SimpleMLP from *cfg* and load trained weights from *weights_path*.

    The model is moved to ``cfg['training']['device']`` and switched to eval
    mode, because in this script it is only used for val/test inference.

    Args:
        cfg: Parsed YAML configuration dict.
        weights_path: Path to a ``state_dict`` checkpoint saved with ``torch.save``.

    Returns:
        The ready-to-use model.
    """
    device = cfg["training"]["device"]
    n_classes = cfg["dataset"]["n_classes"]
    model = SimpleMLP(
        input_dim=n_classes * 2,
        hidden_dims=cfg["model"]["hidden_dims"],
        n_classes=n_classes,
    ).to(device)
    # map_location remaps stored tensors onto the configured device, so a
    # checkpoint saved on GPU still loads on a CPU-only machine (and vice versa).
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    state_dict = torch.load(weights_path, map_location=device)
    model.load_state_dict(state_dict)
    # Inference only: disable dropout / batch-norm statistic updates in case
    # the evaluation helpers do not switch modes themselves.
    model.eval()
    return model


def _make_loader(cfg, split):
    """Return a deterministic, single-process, batch-size-1 DataLoader for *split*."""
    dataset = LidarFusionDataset(cfg, split=split, shuffle_zones=False)
    return DataLoader(dataset, batch_size=1, num_workers=0, shuffle=False)


def main():
    """Command-line entry point: preprocess, train, validate or test.

    Every run creates a timestamped output directory derived from
    ``cfg['logging']['save_dir']`` and initialises logging into ``run.log``
    inside it. ``val`` and ``test`` load weights from ``--weights_path``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=Path, default="config.yaml",
                        help="Path to YAML configuration.")
    parser.add_argument("--mode", choices=["preprocess", "train", "val", "test"],
                        default="train")
    parser.add_argument("--weights_path", default="best_model.pth",
                        help="Path to model weights.")
    args = parser.parse_args()

    cfg = yaml.safe_load(args.config.read_text(encoding="utf-8"))

    # Timestamp suffix keeps repeated runs from overwriting each other's output.
    now = datetime.now()
    out_dir = Path(cfg["logging"]["save_dir"] + now.strftime("_%m_%d_%H_%M_%S"))
    os.makedirs(out_dir, exist_ok=True)
    init_logging(out_dir / "run.log")

    mode = args.mode
    if mode == "preprocess":
        split_file = Path(cfg["dataset"]["root"]) / cfg["dataset"]["split_file"]
        with open(split_file, encoding="utf-8") as f:
            split_dict = json.load(f)
        # Preprocessing covers every zone regardless of which split it belongs to.
        all_zones = split_dict["train"] + split_dict["val"] + split_dict["test"]
        run_parallel(cfg, all_zones,
                     num_workers=cfg["dataset"]["pre-processing_num_workers"])
    elif mode == "train":
        train_model(cfg, out_dir)
    elif mode == "val":
        model = _load_model(cfg, args.weights_path)
        miou, mf1, ious = evaluate_on_val(model, _make_loader(cfg, "val"), cfg)
        # Previously these metrics were computed and silently discarded.
        # NOTE(review): assumes init_logging configures the root logger -- confirm.
        logging.info("Validation mIoU: %s | mF1: %s", miou, mf1)
        logging.info("Validation per-class IoU: %s", ious)
    elif mode == "test":
        model = _load_model(cfg, args.weights_path)
        run_test_inference(model, _make_loader(cfg, "test"), cfg, out_dir)


if __name__ == "__main__":
    main()