#!/usr/bin/env python3
"""
Buleyean RL -- Training Space

Fine-tunes a causal LM with QLoRA using the Buleyean rejection loss.
The model is selected with the BASE_MODEL environment variable (default:
Qwen/Qwen2.5-14B-Instruct; the script was written with Llama 3.3 70B in mind).
Runs on a HuggingFace A100 GPU Space.

The void doesn't need much compute. It just needs enough
capacity to express the nuance.
"""

import os
import sys
import json
import time
import math
import random
from operator import itemgetter
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from peft import LoraConfig, TaskType, get_peft_model
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    Trainer,
    TrainerCallback,
    TrainerControl,
    TrainerState,
    TrainingArguments,
)
from huggingface_hub import HfApi

# ============================================================================
# Config
# ============================================================================

# Every knob below can be overridden through an environment variable of the
# same name; the literal is the default used when the variable is unset.
BASE_MODEL = os.environ.get("BASE_MODEL", "Qwen/Qwen2.5-14B-Instruct")
MODEL_SHORT = os.environ.get("MODEL_SHORT", "qwen2.5-14b")
OUTPUT_REPO = os.environ.get("OUTPUT_REPO", f"forkjoin-ai/buleyean-{MODEL_SHORT}")
MAX_SAMPLES = int(os.environ.get("MAX_SAMPLES", "5000"))
EPOCHS = int(os.environ.get("EPOCHS", "1"))
ALPHA = float(os.environ.get("ALPHA", "0.7"))
LORA_RANK = int(os.environ.get("LORA_RANK", "16"))
MAX_SEQ_LEN = int(os.environ.get("MAX_SEQ_LEN", "512"))
LR = float(os.environ.get("LR", "1e-4"))
BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "1"))
GRAD_ACCUM = int(os.environ.get("GRAD_ACCUM", "8"))
DATA_REPO = "forkjoin-ai/buleyean-rl-data"


# ============================================================================
# Buleyean Loss (inline -- no external deps needed)
# ============================================================================

class SparseBuleyeanLoss(nn.Module):
    """Sparse-native Buleyean RL loss.

    The target distribution is described only through the per-position list
    of rejected tokens; all non-rejected tokens share one weight and are
    handled in aggregate.

    Args:
        vocab_size: Size of the model vocabulary (last dim of the logits).
        alpha: Mixing weight; loss = alpha * KL + (1 - alpha) * contrast.
        temperature: Stored on the module; not used by ``forward``.
    """

    def __init__(self, vocab_size, alpha=0.7, temperature=1.0):
        super().__init__()
        self.vocab_size = vocab_size
        self.alpha = alpha
        self.temperature = temperature
        self._log_vocab = math.log(vocab_size)

    def forward(self, logits, rejected_token_ids, rejected_token_counts,
                num_rejected_tokens, total_rounds, mask=None):
        """Compute the masked-mean Buleyean loss and its diagnostics.

        Args:
            logits: Model logits, indexed as (batch, seq, vocab).
            rejected_token_ids: (batch, seq, max_rej) token ids; only the
                first ``num_rejected_tokens`` slots per position are valid.
            rejected_token_counts: (batch, seq, max_rej) rejection counts
                aligned with ``rejected_token_ids``.
            num_rejected_tokens: (batch, seq) number of valid slots.
            total_rounds: (batch, seq) total rejection count per position.
            mask: Optional (batch, seq) weights; when given, the result is
                normalised by ``mask.sum()`` (clamped to at least 1),
                otherwise by ``batch * seq``.

        Returns:
            Dict with ``loss`` (differentiable) and the detached scalars
            ``buleyean``, ``contrast`` and ``optimality_gap``.
        """
        batch, seq_len, max_rej = rejected_token_ids.shape
        device = logits.device

        # Upcast before the softmax: bf16/fp16 logits lose too much precision
        # over a large vocabulary. No-op when the logits are already float32.
        log_probs = F.log_softmax(logits.float(), dim=-1)

        # valid[b, s, j] is True for the first num_rejected_tokens[b, s] slots.
        slot_index = torch.arange(max_rej, device=device)
        valid = slot_index.unsqueeze(0).unsqueeze(0) < num_rejected_tokens.unsqueeze(-1)
        valid_f = valid.float()

        # Clamp ids so padded / out-of-range slots cannot break the gather;
        # such slots are zeroed out by `valid_f` afterwards.
        safe_ids = rejected_token_ids.clamp(0, self.vocab_size - 1)
        rej_log_probs = log_probs.gather(2, safe_ids)

        # Target weights: a token rejected c times out of T gets T - c + 1,
        # a never-rejected token gets T + 1.
        T = total_rounds.unsqueeze(-1)
        capped_counts = rejected_token_counts.clamp(
            max=T.expand_as(rejected_token_counts)
        )
        rej_weights = T - capped_counts + 1
        non_rej_weight = total_rounds + 1

        # Normaliser over the full vocabulary without materialising it.
        K = num_rejected_tokens.float()
        rej_weight_sum = (rej_weights * valid_f).sum(dim=-1)
        Z = rej_weight_sum + (self.vocab_size - K) * non_rej_weight
        rej_probs = rej_weights / Z.unsqueeze(-1).clamp(min=1e-8)
        non_rej_prob = non_rej_weight / Z.clamp(min=1e-8)

        # KL contribution of the explicitly listed (rejected) tokens.
        rej_kl = rej_probs * (torch.log(rej_probs.clamp(min=1e-8)) - rej_log_probs)
        rej_kl = (rej_kl * valid_f).sum(dim=-1)

        # KL contribution of all other tokens, using the average model
        # probability of a non-rejected token as a stand-in for each of them.
        rej_model_prob_sum = (torch.exp(rej_log_probs).clamp(max=1.0) * valid_f).sum(dim=-1)
        non_rej_model_prob_total = (1.0 - rej_model_prob_sum).clamp(min=1e-8)
        non_rej_count = (self.vocab_size - K).clamp(min=1)
        avg_non_rej_log_prob = torch.log(non_rej_model_prob_total / non_rej_count)
        non_rej_kl = non_rej_count * non_rej_prob * (
            torch.log(non_rej_prob.clamp(min=1e-8)) - avg_non_rej_log_prob
        )
        total_kl = rej_kl + non_rej_kl

        # NOTE(review): as written, minimising this term *raises* the
        # log-probability of rejected tokens (it is -rate * log p). Confirm
        # the sign is intended before relying on (1 - alpha) > 0.
        rej_rate = rejected_token_counts / T.clamp(min=1)
        contrast = -(rej_rate * rej_log_probs * valid_f).sum(dim=-1)

        loss = self.alpha * total_kl + (1 - self.alpha) * contrast
        optimality_gap = total_kl / max(self._log_vocab, 1e-8)

        if mask is not None:
            loss = loss * mask
            total_kl = total_kl * mask
            contrast = contrast * mask
            optimality_gap = optimality_gap * mask
            denom = mask.sum().clamp(min=1)
        else:
            denom = torch.tensor(batch * seq_len, dtype=torch.float, device=device)

        return {
            "loss": loss.sum() / denom,
            "buleyean": (total_kl.sum() / denom).detach(),
            "contrast": (contrast.sum() / denom).detach(),
            "optimality_gap": (optimality_gap.sum() / denom).detach(),
        }


# ============================================================================
# Dataset
# ============================================================================

class SparseRejectionDataset(torch.utils.data.Dataset):
    """Turns rejection records into fixed-length sparse training samples.

    Each record must have a ``"prompt"`` string and may have
    ``"rejected_responses"`` (list of strings) and ``"rejection_counts"``
    (parallel list of counts; missing entries default to 1).
    """

    def __init__(self, records, tokenizer, max_seq_len=512, max_rej_per_pos=256):
        self.records = records
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.max_rej = max_rej_per_pos
        self.vocab_size = tokenizer.vocab_size

    def __len__(self):
        return len(self.records)

    def _empty_targets(self):
        """Return all-zero rejection targets for one sample."""
        return {
            "total_rounds": torch.zeros(self.max_seq_len),
            "rejected_token_ids": torch.zeros(self.max_seq_len, self.max_rej, dtype=torch.long),
            "rejected_token_counts": torch.zeros(self.max_seq_len, self.max_rej),
            "num_rejected_tokens": torch.zeros(self.max_seq_len, dtype=torch.long),
        }

    def __getitem__(self, idx):
        """Build one sample of length ``max_seq_len``.

        Returns a dict with ``input_ids``, ``attention_mask``,
        ``total_rounds``, ``rejected_token_ids``, ``rejected_token_counts``
        and ``num_rejected_tokens``.
        """
        record = self.records[idx]
        prompt = record["prompt"]
        rejected = record.get("rejected_responses", [])
        counts = record.get("rejection_counts", [])

        prompt_enc = self.tokenizer(
            prompt, truncation=True, max_length=self.max_seq_len, return_tensors="pt"
        )
        prompt_ids = prompt_enc["input_ids"].squeeze(0)
        prompt_len = prompt_ids.size(0)
        max_resp = self.max_seq_len - prompt_len

        if max_resp <= 0:
            # The prompt fills the whole window: no room for any response
            # token, so the sample carries no rejection signal.
            sample = {
                "input_ids": prompt_ids[:self.max_seq_len],
                "attention_mask": torch.ones(
                    min(prompt_len, self.max_seq_len), dtype=torch.long
                ),
            }
            sample.update(self._empty_targets())
            return sample

        # Aggregate, per absolute position, how often each token was rejected.
        pos_rej = {}
        pos_tot = {}
        for i, resp in enumerate(rejected[:50]):
            c = counts[i] if i < len(counts) else 1
            # Responses continue the prompt, so they must not get their own
            # BOS/EOS: with special tokens enabled, BOS-prepending tokenizers
            # would make BOS the "rejected token" at the first response
            # position and shift every real token by one.
            resp_ids = self.tokenizer(
                resp, truncation=True, max_length=max_resp, add_special_tokens=False
            )["input_ids"]
            for ap, tid in enumerate(resp_ids, start=prompt_len):
                if ap >= self.max_seq_len:
                    break
                bucket = pos_rej.setdefault(ap, {})
                bucket[tid] = bucket.get(tid, 0) + c
                pos_tot[ap] = pos_tot.get(ap, 0) + c

        targets = self._empty_targets()
        rej_ids = targets["rejected_token_ids"]
        rej_counts = targets["rejected_token_counts"]
        num_rej = targets["num_rejected_tokens"]
        total_rounds = targets["total_rounds"]

        for pos, token_counts in pos_rej.items():
            if pos >= self.max_seq_len:
                continue
            total_rounds[pos] = pos_tot.get(pos, 0)
            # Keep only the most frequently rejected tokens when a position
            # has more distinct tokens than slots.
            ranked = sorted(token_counts.items(), key=itemgetter(1), reverse=True)
            top = ranked[:self.max_rej]
            num_rej[pos] = len(top)
            for j, (tid, cnt) in enumerate(top):
                rej_ids[pos, j] = tid
                rej_counts[pos, j] = cnt

        # NOTE(review): response positions keep pad_id as their input token
        # but are attended whenever they carry rejection targets, and targets
        # at position p are scored against the logits at position p. Confirm
        # this alignment is intended.
        pad_id = self.tokenizer.pad_token_id or 0
        input_ids = torch.full((self.max_seq_len,), pad_id, dtype=torch.long)
        input_ids[:prompt_len] = prompt_ids
        attn = torch.zeros(self.max_seq_len, dtype=torch.long)
        attn[:prompt_len] = 1
        for pos in pos_rej:
            if pos < self.max_seq_len:
                attn[pos] = 1

        return {
            "input_ids": input_ids,
            "attention_mask": attn,
            "total_rounds": total_rounds,
            "rejected_token_ids": rej_ids,
            "rejected_token_counts": rej_counts,
            "num_rejected_tokens": num_rej,
        }


# ============================================================================
# Trainer
# ============================================================================

class ProgressCallback(TrainerCallback):
    """Prints one-line progress reports (step, elapsed, metrics, ETA)."""

    def __init__(self):
        self._start = None

    def on_train_begin(self, args: TrainingArguments, state: TrainerState,
                       control: TrainerControl, **kwargs):
        self._start = time.time()
        print(f"[buleyean-rl] Training started | {state.max_steps} steps", flush=True)

    def on_log(self, args: TrainingArguments, state: TrainerState,
               control: TrainerControl, logs=None, **kwargs):
        if not logs or not self._start:
            return
        elapsed = time.time() - self._start
        step = state.global_step
        total = state.max_steps
        pct = step / total * 100 if total else 0
        parts = [f"[buleyean-rl] step {step}/{total} ({pct:.1f}%)"]
        parts.append(f"elapsed={elapsed:.0f}s")
        for k in ["loss", "buleyean_kl", "contrast_loss", "optimality_gap"]:
            if k in logs:
                parts.append(f"{k}={logs[k]:.4f}")
        # Skip the ETA until both a step and measurable time have passed;
        # step / elapsed would otherwise divide by zero.
        if step > 0 and elapsed > 0:
            eta = (total - step) / (step / elapsed)
            parts.append(f"eta={eta/3600:.1f}h" if eta > 3600 else f"eta={eta/60:.0f}m")
        print(" | ".join(parts), flush=True)

    def on_train_end(self, args: TrainingArguments, state: TrainerState,
                     control: TrainerControl, **kwargs):
        elapsed = time.time() - self._start if self._start else 0
        print(
            f"[buleyean-rl] Training complete | {state.global_step} steps | {elapsed:.0f}s",
            flush=True,
        )


class BuleyeanTrainer(Trainer):
    """Trainer that optimises ``SparseBuleyeanLoss`` and logs its diagnostics."""

    # Batch keys consumed by the loss rather than by the model's forward().
    _KEYS = {"total_rounds", "rejected_token_ids", "rejected_token_counts", "num_rejected_tokens"}
    # Extra metrics averaged between two log() calls.
    _METRIC_KEYS = ("buleyean_kl", "contrast_loss", "optimality_gap")

    def __init__(self, *args, **kwargs):
        # Set before super().__init__ so the attributes exist for any hook
        # the base constructor may trigger.
        self._loss_fn = None
        self._metrics = self._fresh_metrics()
        super().__init__(*args, **kwargs)

    @classmethod
    def _fresh_metrics(cls):
        """Return an empty metric buffer, one list per metric key."""
        return {key: [] for key in cls._METRIC_KEYS}

    def _prepare_inputs(self, inputs):
        # Hold the loss-only tensors back from the base implementation, then
        # re-attach them on the same device as input_ids.
        stashed = {k: inputs.pop(k) for k in list(inputs.keys()) if k in self._KEYS}
        inputs = super()._prepare_inputs(inputs)
        device = inputs["input_ids"].device
        for k, v in stashed.items():
            inputs[k] = v.to(device)
        return inputs

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Run the model and score its logits with the Buleyean loss."""
        if self._loss_fn is None:
            # Built lazily: the vocabulary size comes from the model config.
            self._loss_fn = SparseBuleyeanLoss(
                vocab_size=model.config.vocab_size, alpha=ALPHA
            ).to(inputs["input_ids"].device)

        total_rounds = inputs["total_rounds"]
        attention_mask = inputs.get("attention_mask")
        outputs = model(input_ids=inputs["input_ids"], attention_mask=attention_mask)

        # Only positions that carry rejection data (and are attended) count.
        mask = (total_rounds > 0).float()
        if attention_mask is not None:
            mask = mask * attention_mask.float()

        loss_dict = self._loss_fn(
            logits=outputs.logits,
            rejected_token_ids=inputs["rejected_token_ids"],
            rejected_token_counts=inputs["rejected_token_counts"],
            num_rejected_tokens=inputs["num_rejected_tokens"],
            total_rounds=total_rounds,
            mask=mask,
        )

        self._metrics["buleyean_kl"].append(loss_dict["buleyean"].item())
        self._metrics["contrast_loss"].append(loss_dict["contrast"].item())
        self._metrics["optimality_gap"].append(loss_dict["optimality_gap"].item())

        if return_outputs:
            return loss_dict["loss"], outputs
        return loss_dict["loss"]

    def log(self, logs, *args, **kwargs):
        # Fold the running averages into the log record, then reset them.
        for k, vals in self._metrics.items():
            if vals:
                logs[k] = sum(vals) / len(vals)
        self._metrics = self._fresh_metrics()
        super().log(logs, *args, **kwargs)


# ============================================================================
# Main
# ============================================================================

def main():
    """Download the rejection data, train the LoRA adapter and upload it."""
    print("=" * 60)
    print("  Buleyean RL -- 70B Training")
    print(f"  Model: {BASE_MODEL}")
    print(f"  Samples: {MAX_SAMPLES} | Epochs: {EPOCHS} | Alpha: {ALPHA}")
    print(f"  Output: {OUTPUT_REPO}")
    print("=" * 60)
    sys.stdout.flush()

    # Load data from HF dataset
    print("Loading rejection data from HuggingFace...")
    from huggingface_hub import hf_hub_download
    data_path = hf_hub_download(
        repo_id=DATA_REPO,
        filename="rejections.jsonl",
        repo_type="dataset",
        cache_dir="/tmp/hf_cache",
    )
    # Explicit encoding: JSONL is UTF-8 regardless of the container locale.
    with open(data_path, encoding="utf-8") as f:
        records = [json.loads(line) for line in f if line.strip()]

    if not records:
        raise RuntimeError(f"No rejection records found in {DATA_REPO}/rejections.jsonl")

    if MAX_SAMPLES > 0 and len(records) > MAX_SAMPLES:
        random.seed(42)
        records = random.sample(records, MAX_SAMPLES)
    print(f"Loaded {len(records)} rejection records")

    # Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Dataset
    dataset = SparseRejectionDataset(records, tokenizer, MAX_SEQ_LEN)
    train_size = int(0.9 * len(dataset))
    # Seeded so the train/held-out split is reproducible, like the sampling
    # above. The held-out part is currently unused (eval_strategy="no").
    train_ds, _eval_ds = torch.utils.data.random_split(
        dataset,
        [train_size, len(dataset) - train_size],
        generator=torch.Generator().manual_seed(42),
    )
    print(f"Train: {train_size}, Eval: {len(dataset) - train_size}")

    # Model (QLoRA 4-bit)
    print(f"Loading {BASE_MODEL} with QLoRA...")
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
    # NOTE(security): trust_remote_code=True executes Python shipped in the
    # model repo. Only point BASE_MODEL at repositories you trust.
    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
    )
    model.config.use_cache = False

    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=LORA_RANK,
        lora_alpha=LORA_RANK * 2,
        lora_dropout=0.05,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        bias="none",
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    # Training
    output_dir = "/tmp/buleyean-70b-output"
    training_args = TrainingArguments(
        output_dir=output_dir,
        # The dataset is not a `datasets.Dataset`; keep the custom keys.
        remove_unused_columns=False,
        num_train_epochs=EPOCHS,
        per_device_train_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=GRAD_ACCUM,
        learning_rate=LR,
        weight_decay=0.01,
        warmup_steps=50,
        lr_scheduler_type="cosine",
        logging_steps=5,
        logging_first_step=True,
        eval_strategy="no",
        save_strategy="steps",
        save_steps=200,
        save_total_limit=2,
        bf16=True,
        report_to="none",
        dataloader_num_workers=2,
        dataloader_pin_memory=True,
    )

    def collator(batch):
        # Every sample has identical shapes, so plain stacking is enough.
        return {k: torch.stack([s[k] for s in batch]) for k in batch[0]}

    trainer = BuleyeanTrainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        processing_class=tokenizer,
        data_collator=collator,
        callbacks=[ProgressCallback()],
    )

    print("\nStarting Buleyean RL training on 70B...")
    print(f"  Alpha: {ALPHA} | LR: {LR} | LoRA rank: {LORA_RANK}")
    sys.stdout.flush()

    trainer.train()

    # Save
    print("Saving LoRA adapter...")
    lora_dir = f"{output_dir}/lora"
    model.save_pretrained(lora_dir)
    tokenizer.save_pretrained(lora_dir)

    # Upload to HuggingFace
    print(f"Uploading to {OUTPUT_REPO}...")
    api = HfApi()
    api.create_repo(OUTPUT_REPO, repo_type="model", exist_ok=True)
    api.upload_folder(
        folder_path=lora_dir,
        repo_id=OUTPUT_REPO,
        commit_message=f"Buleyean RL LoRA adapter for {BASE_MODEL}",
    )
    print(f"\nDone! https://huggingface.co/{OUTPUT_REPO}")


if __name__ == "__main__":
    main()