File size: 17,679 Bytes
a56eb73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
"""
OmniGen-v1 LoRA Fine-Tuning for Thumbnail Generation

Model: Shitao/OmniGen-v1 (3.8B, Phi-3 based)
Method: LoRA fine-tuning (default rank=8, alpha=16) via accelerate
Dataset: PosterCraft/Poster100K + synthetic thumbnail prompts
Output: LoRA adapter weights (or a full model state_dict when LoRA is disabled) for thumbnail image generation

Input modes supported:
  - Text only → Thumbnail image
  - Image only → Thumbnail image  
  - Text + Image → Thumbnail image

Based on OmniGen official fine-tuning recipe:
  https://github.com/VectorSpaceLab/OmniGen
"""

import os
import sys
import json
import math
import random
import logging
import argparse
from pathlib import Path
from typing import Optional, List, Dict, Any

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm import tqdm

# OmniGen imports
from OmniGen import OmniGenPipeline
from OmniGen.model import OmniGen
from OmniGen.processor import OmniGenProcessor
from OmniGen.scheduler import OmniGenScheduler

from diffusers import AutoencoderKL
from transformers import get_cosine_schedule_with_warmup
from peft import LoraConfig, get_peft_model
from accelerate import Accelerator
from accelerate.utils import set_seed

import trackio

logger = logging.getLogger(__name__)


class ThumbnailDataset(Dataset):
    """Thumbnail-generation training samples read from a JSONL manifest.

    Each non-blank line of ``jsonl_path`` is a JSON object with an
    ``instruction`` string, an ``output_image`` filename and an optional
    ``input_images`` list of filenames (``null`` is treated as empty).
    Filenames are resolved relative to ``image_dir``. Items are returned as
    dicts of raw PIL images; all tensor conversion happens in the training loop.
    """

    # How many random replacement indices __getitem__ tries after the requested
    # sample's target image fails to load, before giving up with RuntimeError.
    # Bounded so a dataset with many missing images cannot recurse/loop forever.
    max_resample_attempts: int = 100

    def __init__(
        self,
        jsonl_path: str,
        image_dir: str,
        processor: OmniGenProcessor,
        max_image_size: int = 1024,
        max_input_length_limit: int = 18000,
        keep_raw_resolution: bool = True,
        condition_dropout_prob: float = 0.01,
    ):
        """Load every JSONL entry into memory.

        Args:
            jsonl_path: Path to the UTF-8 JSONL manifest.
            image_dir: Directory that image filenames are relative to.
            processor: OmniGen processor; stored on the instance, not used here.
            max_image_size: Stored on the instance; not read by this class.
            max_input_length_limit: Stored on the instance; not read by this class.
            keep_raw_resolution: Stored on the instance; not read by this class.
            condition_dropout_prob: Probability that ``__getitem__`` replaces the
                instruction with ``""`` (condition dropout for CFG training).

        Raises:
            json.JSONDecodeError: If a non-blank line is not valid JSON.
        """
        self.image_dir = image_dir
        self.processor = processor
        self.max_image_size = max_image_size
        self.max_input_length_limit = max_input_length_limit
        self.keep_raw_resolution = keep_raw_resolution
        self.condition_dropout_prob = condition_dropout_prob

        # Load JSONL entries; explicit UTF-8 so parsing does not depend on the locale.
        self.entries: List[Dict[str, Any]] = []
        with open(jsonl_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line:
                    self.entries.append(json.loads(line))

        logger.info(f"Loaded {len(self.entries)} training samples from {jsonl_path}")

    def __len__(self) -> int:
        """Return the number of JSONL entries."""
        return len(self.entries)

    def _load_image(self, filename: str) -> Optional[Image.Image]:
        """Load ``image_dir/filename`` as an RGB image.

        Returns ``None`` if the file does not exist or cannot be decoded
        (decode failures are logged as warnings).
        """
        path = os.path.join(self.image_dir, filename)
        if not os.path.exists(path):
            return None
        try:
            # ``convert`` returns a fully loaded copy, so the source file can be
            # closed immediately instead of leaking the handle until GC.
            with Image.open(path) as img:
                return img.convert("RGB")
        except Exception as e:
            logger.warning(f"Failed to load image {path}: {e}")
            return None

    def _build_sample(self, idx: int) -> Optional[Dict[str, Any]]:
        """Build the sample dict for entry ``idx``, or ``None`` if its target image is unusable."""
        entry = self.entries[idx]
        instruction = entry["instruction"]
        output_image_name = entry["output_image"]
        # ``or []`` also covers an explicit ``"input_images": null`` in the JSONL.
        input_image_names = entry.get("input_images") or []

        # Apply condition dropout for CFG training
        if random.random() < self.condition_dropout_prob:
            instruction = ""

        output_image = self._load_image(output_image_name)
        if output_image is None:
            return None

        input_images = []
        for img_name in input_image_names:
            img = self._load_image(img_name)
            if img is None:
                # NOTE(review): the instruction may reference this image by
                # position; dropping it can desync those references.
                logger.warning(f"Skipping missing/unreadable input image {img_name} (entry {idx})")
                continue
            input_images.append(img)

        return {
            "instruction": instruction,
            "output_image": output_image,
            "input_images": input_images if input_images else None,
        }

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Return ``{"instruction", "output_image", "input_images"}`` for ``idx``.

        If the target image of ``idx`` cannot be loaded, a random other entry is
        used instead, up to ``max_resample_attempts`` times.

        Raises:
            RuntimeError: If no loadable target image is found within the attempt budget.
        """
        candidate = idx
        for _ in range(self.max_resample_attempts + 1):
            sample = self._build_sample(candidate)
            if sample is not None:
                return sample
            # Return a random other sample if the target image is missing.
            candidate = random.randint(0, len(self) - 1)
        raise RuntimeError(
            f"Could not load a target image after {self.max_resample_attempts + 1} "
            f"attempts (starting from index {idx}); check --image_path"
        )


def collate_fn(batch):
    """Transpose a list of sample dicts into a dict of per-field lists.

    No tensor stacking is done: images stay as PIL objects (and ``input_images``
    entries stay as a list or ``None``), because samples can differ in size and
    in how many input images they carry.
    """
    columns = {"instructions": [], "output_images": [], "input_images": []}
    for sample in batch:
        columns["instructions"].append(sample["instruction"])
        columns["output_images"].append(sample["output_image"])
        columns["input_images"].append(sample["input_images"])
    return columns


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse command-line options for the thumbnail LoRA fine-tuning script.

    Args:
        argv: Arguments to parse; ``None`` (default) parses ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``.

    ``--bf16``, ``--use_lora`` and ``--push_to_hub`` default to on. Each has a
    ``--no_*`` counterpart writing to the same destination, because a
    ``store_true`` flag with ``default=True`` alone can never be switched off.
    """
    parser = argparse.ArgumentParser(description="OmniGen LoRA Fine-Tuning for Thumbnails")

    # Model
    parser.add_argument("--model_name_or_path", type=str, default="Shitao/OmniGen-v1")

    # Data
    parser.add_argument("--json_file", type=str, required=True, help="Path to JSONL training data")
    parser.add_argument("--image_path", type=str, required=True, help="Root dir for images")
    parser.add_argument("--max_image_size", type=int, default=1024,
                        help="Side length (pixels) that target images are resized to")
    parser.add_argument("--max_input_length_limit", type=int, default=18000)
    parser.add_argument("--keep_raw_resolution", action="store_true")

    # Training
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--batch_size_per_device", type=int, default=1)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--weight_decay", type=float, default=0.01)
    parser.add_argument("--warmup_steps", type=int, default=100)
    parser.add_argument("--condition_dropout_prob", type=float, default=0.01,
                        help="Probability of blanking the instruction (CFG training)")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--bf16", dest="bf16", action="store_true", default=True,
                        help="Use bf16 mixed precision (default: on)")
    parser.add_argument("--no_bf16", dest="bf16", action="store_false",
                        help="Disable bf16 mixed precision")

    # LoRA
    parser.add_argument("--use_lora", dest="use_lora", action="store_true", default=True,
                        help="Train LoRA adapters instead of the full model (default: on)")
    parser.add_argument("--no_use_lora", dest="use_lora", action="store_false",
                        help="Fine-tune the full model instead of LoRA adapters")
    parser.add_argument("--lora_rank", type=int, default=8)
    parser.add_argument("--lora_alpha", type=int, default=16)
    parser.add_argument("--lora_dropout", type=float, default=0.05)

    # Output
    parser.add_argument("--results_dir", type=str, default="./results/thumbnail_lora")
    parser.add_argument("--ckpt_every", type=int, default=500, help="Checkpoint interval in steps")
    parser.add_argument("--log_every", type=int, default=10, help="Logging interval in steps")
    parser.add_argument("--push_to_hub", dest="push_to_hub", action="store_true", default=True,
                        help="Push the final LoRA adapters to the Hub (default: on)")
    parser.add_argument("--no_push_to_hub", dest="push_to_hub", action="store_false",
                        help="Do not push anything to the Hub")
    parser.add_argument("--hub_model_id", type=str, default="asats/thumbnail-vlm-omnigen-lora")

    return parser.parse_args(argv)


def _build_target_transform(image_size: int):
    """Return a transform mapping a PIL image to an ``image_size``-square tensor in [-1, 1].

    Resizes to a square (aspect ratio is not preserved), converts to a tensor and
    normalizes with mean 0.5 / std 0.5. Built once per run instead of per sample.
    """
    from torchvision import transforms  # torchvision is not imported at module level

    return transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])


def _encode_latents(vae, pixels: torch.Tensor) -> torch.Tensor:
    """Encode ``pixels`` with the frozen VAE and return scaled latents (no gradient).

    Subtracts ``vae.config.shift_factor`` first when the config defines a non-None
    value, then multiplies by ``vae.config.scaling_factor``.
    """
    with torch.no_grad():
        latents = vae.encode(pixels).latent_dist.sample()
        shift_factor = getattr(vae.config, "shift_factor", None)
        if shift_factor is not None:
            latents = latents - shift_factor
        return latents * vae.config.scaling_factor


def _sample_flow_matching_loss(model, processor, vae, transform, instruction, output_img,
                               input_imgs, image_size, device, dtype) -> torch.Tensor:
    """Return the rectified-flow MSE loss (fp32 scalar) for one training sample.

    Uses OmniGen's convention: ``x_t = t * x1 + (1 - t) * x0`` with ``x0`` noise,
    ``x1`` the clean latent and ``t ~ U(0, 1)``; the model predicts the velocity
    ``x1 - x0``. (OmniGen's sampler integrates from t=0 noise to t=1 image.)
    """
    target_pixels = transform(output_img).unsqueeze(0).to(device, dtype=dtype)
    latents = _encode_latents(vae, target_pixels)

    noise = torch.randn_like(latents)
    timesteps = torch.rand(latents.shape[0], device=device)
    t = timesteps.view(-1, *([1] * (latents.dim() - 1)))  # broadcast per-sample t over latent dims
    noisy_latents = t * latents + (1 - t) * noise
    target = latents - noise

    # NOTE(review): OmniGenProcessor.__call__ may be the inference entry point
    # (it may emit CFG rows and pixel values rather than "input_img_latents");
    # confirm its output matches a single training sample.
    input_data = processor(
        instruction,
        input_images=input_imgs,
        height=image_size,
        width=image_size,
    )

    # x and timestep are passed positionally: OmniGen's forward is (x, timestep, ...).
    model_output = model(
        noisy_latents,
        timesteps,
        input_ids=input_data["input_ids"].to(device),
        input_img_latents=input_data.get("input_img_latents"),
        input_image_sizes=input_data.get("input_image_sizes"),
        attention_mask=input_data["attention_mask"].to(device),
        position_ids=input_data["position_ids"].to(device),
    )
    # The forward may return (latents, past_key_values); only the prediction is needed.
    if isinstance(model_output, (tuple, list)):
        model_output = model_output[0]
    # Compute the loss in fp32 so bf16 rounding does not distort small errors.
    return F.mse_loss(model_output.float(), target.float())


def _save_model(accelerator, model, out_dir: str, use_lora: bool):
    """Save the unwrapped model to ``out_dir`` and return it.

    LoRA runs use ``save_pretrained``; full fine-tuning writes ``model.pt``.
    """
    os.makedirs(out_dir, exist_ok=True)
    unwrapped_model = accelerator.unwrap_model(model)
    if use_lora:
        unwrapped_model.save_pretrained(out_dir)
    else:
        torch.save(unwrapped_model.state_dict(), os.path.join(out_dir, "model.pt"))
    return unwrapped_model


def main():
    """Fine-tune OmniGen (LoRA by default) on the thumbnail JSONL dataset.

    The run proceeds as follows:

    * Parse CLI args, load the OmniGen pipeline and freeze its VAE.
    * Optionally wrap the transformer with PEFT LoRA.
    * Train with a rectified-flow MSE loss under ``accelerate`` (gradient
      accumulation, optional bf16) and log to trackio.
    * Write ``checkpoint-<step>``, ``best`` and ``final`` directories under
      ``args.results_dir``.
    * For LoRA runs with ``push_to_hub``, push the adapters to
      ``args.hub_model_id`` using the ``HF_TOKEN`` environment variable.
    """
    args = parse_args()
    weight_dtype = torch.bfloat16 if args.bf16 else torch.float32

    # Metrics go to trackio directly, so no accelerate tracker is configured
    # (the former log_with="all" enabled trackers that were never initialised).
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision="bf16" if args.bf16 else "no",
    )

    # trackio is initialised (and finished) only on the main process.
    if accelerator.is_main_process:
        trackio.init(
            project="thumbnail-vlm",
            name="omnigen-lora-finetune",
        )

    set_seed(args.seed)

    logger.info(f"Loading OmniGen model from {args.model_name_or_path}...")

    # Load the OmniGen pipeline components
    pipe = OmniGenPipeline.from_pretrained(args.model_name_or_path)
    model = pipe.model
    processor = pipe.processor
    vae = pipe.vae

    # The VAE only encodes targets; it is never trained.
    vae.requires_grad_(False)
    vae.eval()

    if args.use_lora:
        logger.info(f"Applying LoRA (rank={args.lora_rank}, alpha={args.lora_alpha})...")
        # No task_type: OmniGen is not a transformers causal LM. task_type="CAUSAL_LM"
        # makes PEFT wrap it in PeftModelForCausalLM, which expects a causal-LM
        # interface (prepare_inputs_for_generation, labels/inputs_embeds kwargs).
        lora_config = LoraConfig(
            r=args.lora_rank,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            target_modules=["qkv_proj", "o_proj", "gate_up_proj", "down_proj"],
            bias="none",
        )
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()
    else:
        model.train()

    # Dataset
    logger.info(f"Loading dataset from {args.json_file}...")
    dataset = ThumbnailDataset(
        jsonl_path=args.json_file,
        image_dir=args.image_path,
        processor=processor,
        max_image_size=args.max_image_size,
        max_input_length_limit=args.max_input_length_limit,
        keep_raw_resolution=args.keep_raw_resolution,
        condition_dropout_prob=args.condition_dropout_prob,
    )

    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size_per_device,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=2,
        pin_memory=True,
    )

    # Optimizer over trainable (LoRA) parameters only
    optimizer = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=args.lr,
        weight_decay=args.weight_decay,
        betas=(0.9, 0.999),
    )

    # Optimizer updates per epoch round *up*: by default accelerate also syncs at
    # the end of the dataloader. Flooring let the cosine schedule overshoot its end.
    num_training_steps = math.ceil(len(dataloader) / args.gradient_accumulation_steps) * args.epochs
    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=num_training_steps,
    )

    # Prepare with accelerator
    model, optimizer, dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, dataloader, lr_scheduler
    )

    # Move VAE to device
    vae = vae.to(accelerator.device, dtype=weight_dtype)
    transform = _build_target_transform(args.max_image_size)

    os.makedirs(args.results_dir, exist_ok=True)

    logger.info("=" * 60)
    logger.info("Training Configuration:")
    logger.info(f"  Model: {args.model_name_or_path}")
    logger.info(f"  LoRA: rank={args.lora_rank}, alpha={args.lora_alpha}")
    logger.info(f"  Dataset: {len(dataset)} samples")
    logger.info(f"  Epochs: {args.epochs}")
    logger.info(f"  Batch size: {args.batch_size_per_device}")
    logger.info(f"  Grad accum: {args.gradient_accumulation_steps}")
    logger.info(f"  Effective batch: {args.batch_size_per_device * args.gradient_accumulation_steps * accelerator.num_processes}")
    logger.info(f"  LR: {args.lr}")
    logger.info(f"  Total steps: {num_training_steps}")
    logger.info(f"  Hub model: {args.hub_model_id}")
    logger.info("=" * 60)

    # Training loop
    global_step = 0  # counts optimizer updates, not micro-batches
    best_loss = float("inf")
    any_sample_succeeded = False

    for epoch in range(args.epochs):
        model.train()
        epoch_loss = 0.0
        num_batches = 0

        for batch in dataloader:
            with accelerator.accumulate(model):
                total_loss = torch.tensor(0.0, device=accelerator.device)
                valid_samples = 0

                samples = zip(batch["instructions"], batch["output_images"], batch["input_images"])
                for i, (instruction, output_img, input_imgs) in enumerate(samples):
                    try:
                        loss = _sample_flow_matching_loss(
                            model, processor, vae, transform, instruction, output_img,
                            input_imgs, args.max_image_size, accelerator.device, weight_dtype,
                        )
                    except Exception:
                        # If nothing has succeeded yet, this is almost certainly a
                        # systematic bug (signature/shape mismatch), not bad data:
                        # fail loudly instead of "training" on zero samples.
                        if not any_sample_succeeded:
                            raise
                        logger.warning(f"Error processing sample {i}", exc_info=True)
                        continue
                    any_sample_succeeded = True
                    total_loss = total_loss + loss
                    valid_samples += 1

                # NOTE(review): under DDP, skipping backward on one rank while
                # others run it can stall gradient synchronisation.
                if valid_samples == 0:
                    continue

                avg_loss = total_loss / valid_samples
                accelerator.backward(avg_loss)

                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)

                # The prepared optimizer/scheduler skip real updates on accumulation micro-steps.
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            loss_value = avg_loss.item()
            epoch_loss += loss_value
            num_batches += 1

            # Logging/checkpoint cadence is measured in optimizer updates.
            if not accelerator.sync_gradients:
                continue
            global_step += 1

            # Logging
            if global_step % args.log_every == 0 and accelerator.is_main_process:
                avg_epoch_loss = epoch_loss / max(num_batches, 1)
                current_lr = lr_scheduler.get_last_lr()[0]
                print(f"step={global_step}, epoch={epoch+1}/{args.epochs}, "
                      f"loss={loss_value:.4f}, avg_loss={avg_epoch_loss:.4f}, "
                      f"lr={current_lr:.2e}")
                trackio.log({
                    "train/loss": loss_value,
                    "train/avg_loss": avg_epoch_loss,
                    "train/lr": current_lr,
                    "train/epoch": epoch + 1,
                    "train/step": global_step,
                })

            # Checkpoint
            if global_step % args.ckpt_every == 0 and accelerator.is_main_process:
                ckpt_dir = os.path.join(args.results_dir, f"checkpoint-{global_step}")
                _save_model(accelerator, model, ckpt_dir, args.use_lora)
                logger.info(f"Saved checkpoint to {ckpt_dir}")

        # End of epoch logging; "best" is chosen by mean micro-batch loss of the epoch.
        if accelerator.is_main_process:
            avg_epoch_loss = epoch_loss / max(num_batches, 1)
            print(f"\n{'='*60}")
            print(f"Epoch {epoch+1}/{args.epochs} complete. Avg loss: {avg_epoch_loss:.4f}")
            print(f"{'='*60}\n")

            if avg_epoch_loss < best_loss:
                best_loss = avg_epoch_loss
                best_dir = os.path.join(args.results_dir, "best")
                _save_model(accelerator, model, best_dir, args.use_lora)
                logger.info(f"New best model saved (loss={best_loss:.4f})")

    # Final save and push to hub
    if accelerator.is_main_process:
        final_dir = os.path.join(args.results_dir, "final")
        unwrapped_model = _save_model(accelerator, model, final_dir, args.use_lora)

        # Only LoRA adapters are pushed; full-model runs are saved locally only.
        pushed = False
        if args.use_lora and args.push_to_hub:
            logger.info(f"Pushing LoRA adapters to {args.hub_model_id}...")
            unwrapped_model.push_to_hub(args.hub_model_id, token=os.environ.get("HF_TOKEN"))
            pushed = True

        logger.info(f"Training complete! Final model saved to {final_dir}")
        logger.info(f"Best loss: {best_loss:.4f}")

        if pushed:
            print(f"\nModel pushed to: https://huggingface.co/{args.hub_model_id}")

        trackio.finish()


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
    main()