"""Flask service for block-wise masked-diffusion text generation.

Endpoints:
    GET  /health           -- liveness and whether the model is loaded.
    POST /generate         -- one JSON response with the generated text.
    POST /generate_stream  -- one JSON response that also lists intermediate
                              decoding states every ``capture_interval`` steps.
    POST /generate_sse     -- the same intermediate states pushed as
                              server-sent events while decoding runs.
"""
import copy
import json
import math
import os

import torch
import torch.nn.functional as F
from flask import Flask, Response, jsonify, request
from transformers import AutoModelForMaskedLM, AutoTokenizer

app = Flask(__name__)

# Populated by load_model(); the HTTP handlers answer 503 while these are None.
model = None
tokenizer = None
device = None

_DEFAULT_SYSTEM_PROMPT = "You are a helpful AI assistant."
_REMASKING_MODES = ("low_confidence", "random")


# ---------------------------------------------------------------------------
# Sampling primitives
# ---------------------------------------------------------------------------

def add_gumbel_noise(logits, temperature):
    """Perturb ``logits`` so that their argmax is a temperature-scaled sample.

    Returns ``exp(logits) / (-log u) ** temperature`` with ``u ~ U(0, 1)``,
    computed in float64. Taking the log shows this equals
    ``logits + temperature * g`` for standard Gumbel noise ``g``, so ``argmax``
    of the result samples from ``softmax(logits / temperature)``.
    With ``temperature == 0`` the logits are returned unchanged (greedy).
    """
    if temperature == 0:
        return logits
    # float64 so exp() and the noise powers keep enough precision.
    logits = logits.to(torch.float64)
    noise = torch.rand_like(logits, dtype=torch.float64)
    gumbel = (-torch.log(noise)) ** temperature
    return logits.exp() / gumbel


def get_num_transfer_tokens(mask_index, steps):
    """Split each row's masked-token count as evenly as possible over ``steps``.

    Args:
        mask_index: bool tensor ``(B, L)``; True marks a position still masked.
        steps: number of denoising steps (must be positive).

    Returns:
        Long tensor ``(B, steps)``; entry ``[i, t]`` is how many tokens row ``i``
        commits at step ``t``. Each row sums to its mask count, and the remainder
        of the division goes to the earliest steps.

    Raises:
        ValueError: if ``steps`` is not positive.
    """
    if steps <= 0:
        raise ValueError(f"steps must be positive, got {steps}")
    mask_num = mask_index.sum(dim=1, keepdim=True)  # (B, 1)
    base = mask_num // steps
    rem = mask_num % steps
    step_ids = torch.arange(steps, device=mask_index.device).unsqueeze(0)  # (1, steps)
    # The first `rem` steps of each row carry one extra token.
    return base + (step_ids < rem).long()


def build_staircase_attention_mask(x, block_size, pad_id):
    """Build a block-causal ("staircase") attention mask and position ids.

    Column ``c`` belongs to block ``c // block_size``. A query may attend to a key
    iff the key's block is not later than the query's block and neither token
    is ``pad_id``.

    Args:
        x: token ids ``(B, T)``.
        block_size: width of each attention block.
        pad_id: token id treated as padding (excluded as query and key).

    Returns:
        ``(attn, position_ids)``. ``attn`` is a bool tensor ``(B, 1, T, T)``.
        ``position_ids`` is ``(B, T)`` and counts only non-pad tokens; pad
        positions get 0.
    """
    B, T = x.shape
    dev = x.device
    valid = x != pad_id
    pos_raw = torch.cumsum(valid.long(), dim=-1)
    position_ids = torch.where(valid, pos_raw - 1, torch.zeros_like(pos_raw)).long()

    col = torch.arange(T, device=dev)
    block_ids = (col // block_size).view(1, T).expand(B, T)
    # Pads get block id -1 so the `>= 0` checks below exclude them.
    block_ids = torch.where(valid, block_ids, torch.full_like(block_ids, -1))

    q = block_ids.view(B, 1, T, 1)
    k = block_ids.view(B, 1, 1, T)
    attn = (k <= q) & (q >= 0) & (k >= 0)
    return attn, position_ids


def diffusion_step_block(logits, x_block, mask_block, num_transfer, temperature, remasking):
    """Commit the ``num_transfer[i]`` most confident predictions of each row.

    Args:
        logits: model logits ``(B, L, V)`` for the current block.
        x_block: current block token ids ``(B, L)``.
        mask_block: bool ``(B, L)``; True where ``x_block`` is still masked.
        num_transfer: ``(B,)`` number of tokens to commit per row this step.
        temperature: Gumbel temperature for the proposal (0 = greedy).
        remasking: ``"low_confidence"`` ranks masked positions by the
            probability of the proposed token; ``"random"`` ranks them randomly.

    Returns:
        A new ``(B, L)`` tensor with the selected masked positions filled in.
        If nothing is masked, ``x_block`` itself is returned.

    Raises:
        ValueError: for an unknown ``remasking`` mode.
    """
    B, L, _ = logits.shape
    if not mask_block.any():
        return x_block

    noisy = add_gumbel_noise(logits, temperature)
    x0 = noisy.argmax(dim=-1)

    if remasking == "low_confidence":
        p = F.softmax(logits, dim=-1)
        conf = p.gather(-1, x0.unsqueeze(-1)).squeeze(-1)
    elif remasking == "random":
        conf = torch.rand((B, L), device=logits.device)
    else:
        raise ValueError(remasking)

    # Only masked positions may change; already-committed tokens are kept.
    x0 = torch.where(mask_block, x0, x_block)
    neg_inf = torch.full_like(conf, -float("inf"))
    conf = torch.where(mask_block, conf, neg_inf)

    commit = torch.zeros_like(x_block, dtype=torch.bool)
    for i in range(B):
        k = int(num_transfer[i].item())
        if k > 0:
            # Never pick more positions than are actually masked in this row.
            valid = (conf[i] > -float("inf")).sum().item()
            k = min(k, valid)
            _, idx = torch.topk(conf[i], k)
            commit[i, idx] = True

    out = x_block.clone()
    out[commit] = x0[commit]
    return out


# ---------------------------------------------------------------------------
# Block-diffusion decoding loop (shared by generate and generate_stream)
# ---------------------------------------------------------------------------

def _resolve_special_ids(tok):
    """Return ``(mask_id, pad_id)`` for ``tok``.

    If the tokenizer has no pad token, EOS is used as padding, or failing
    that the mask token.

    Raises:
        ValueError: if the tokenizer has no mask token.
    """
    mask_id = tok.mask_token_id
    if mask_id is None:
        raise ValueError("tokenizer has no mask token; diffusion decoding requires one")
    pad_id = tok.pad_token_id
    if pad_id is None:
        pad_id = tok.eos_token_id if tok.eos_token_id is not None else mask_id
    return mask_id, pad_id


def _to_batch(prompt, dev, pad_id):
    """Turn ``prompt`` into a 2-D long tensor on ``dev``.

    Accepts a tensor, a flat list of token ids, or a list of id lists. Id lists
    are right-padded with ``pad_id``; a 1-D result becomes a batch of one.
    """
    if isinstance(prompt, torch.Tensor):
        x = prompt.to(dev).long()
    elif isinstance(prompt[0], (list, tuple)):
        max_len = max(len(p) for p in prompt)
        x = torch.full((len(prompt), max_len), pad_id, device=dev, dtype=torch.long)
        for i, p in enumerate(prompt):
            x[i, : len(p)] = torch.tensor(p, device=dev)
    else:
        x = torch.tensor(prompt, device=dev).long()
    if x.dim() == 1:
        x = x.unsqueeze(0)
    return x


def _validate_generation_args(steps, block_size, remasking):
    """Reject arguments that would otherwise fail deep inside the decoding loop."""
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")
    if steps <= 0:
        raise ValueError(f"steps must be positive, got {steps}")
    if remasking not in _REMASKING_MODES:
        raise ValueError(remasking)


def _prefix_caches(mdl, x, block_size, pad_id, mask_id, cfg_scale):
    """Encode the prefix ``x`` and return ``(cond_past, uncond_past)``.

    ``uncond_past`` is None unless ``cfg_scale > 0``. In that case it is the
    cache of an all-``mask_id`` copy of the prefix, encoded with the
    conditional attention mask.
    """
    attn, pos = build_staircase_attention_mask(x, block_size, pad_id)
    cond_past = mdl(x, attention_mask=attn, position_ids=pos, use_cache=True).past_key_values
    if cfg_scale > 0:
        # NOTE(review): this masks the whole prefix, including tokens generated
        # in earlier blocks; LLaDA's reference CFG masks only the prompt
        # positions. Confirm which one this model was trained for.
        un_x = torch.full_like(x, mask_id)
        out_un = mdl(un_x, attention_mask=attn, position_ids=pos, use_cache=True)
        return cond_past, out_un.past_key_values
    return cond_past, None


def _block_logits(mdl, x_blk, attn_blk, pos_blk, cond_past, uncond_past, cfg_scale):
    """Return logits for the current block, applying CFG when ``cfg_scale > 0``."""
    # The caches are deep-copied on every call, presumably because the forward
    # pass appends the block's keys/values to the cache object in place.
    logits = mdl(
        x_blk, attention_mask=attn_blk, position_ids=pos_blk,
        past_key_values=copy.deepcopy(cond_past), use_cache=False,
    ).logits
    if cfg_scale > 0:
        un_logits = mdl(
            x_blk, attention_mask=attn_blk, position_ids=pos_blk,
            past_key_values=copy.deepcopy(uncond_past), use_cache=False,
        ).logits
        logits = un_logits + (cfg_scale + 1.0) * (logits - un_logits)
    return logits


@torch.no_grad()
def _denoise_steps(mdl, tok, x, *, steps, max_new_tokens, block_size,
                   temperature, cfg_scale, remasking, mask_id, pad_id):
    """Run block-wise diffusion decoding, yielding the sequence after each step.

    New tokens are appended one block at a time. Each block starts fully
    masked and is filled in over ``ceil(steps / num_blocks)`` steps. The
    yielded tensor is the live ``(B, T_prefix + block)`` buffer, so clone it
    if you need to keep it. Rows that have produced EOS get ``pad_id`` blocks
    afterwards, and decoding stops once every row is finished.
    """
    if max_new_tokens <= 0:
        return
    batch = x.size(0)
    eos_id = tok.eos_token_id
    finished = torch.zeros(batch, dtype=torch.bool, device=x.device)
    num_blocks = math.ceil(max_new_tokens / block_size)
    steps_per_block = math.ceil(steps / num_blocks)
    generated = 0

    while generated < max_new_tokens and not finished.all():
        t_prefix = x.size(1)
        # Fill only up to the next multiple of block_size, so new blocks line up
        # with the absolute-column block ids used by the attention mask.
        cur_len = min(block_size - t_prefix % block_size, max_new_tokens - generated)

        cond_past, uncond_past = _prefix_caches(mdl, x, block_size, pad_id, mask_id, cfg_scale)

        block = torch.full((batch, cur_len), mask_id, device=x.device, dtype=torch.long)
        block[finished] = pad_id
        x = torch.cat([x, block], dim=1)
        t_total = x.size(1)

        num_transfer = get_num_transfer_tokens(x[:, t_prefix:] == mask_id, steps_per_block)
        full_attn, full_pos = build_staircase_attention_mask(x, block_size, pad_id)
        attn_blk = full_attn[:, :, t_prefix:t_total, :]
        pos_blk = full_pos[:, t_prefix:t_total]

        for t in range(num_transfer.size(1)):
            x_blk = x[:, t_prefix:t_total]
            logits = _block_logits(mdl, x_blk, attn_blk, pos_blk, cond_past, uncond_past, cfg_scale)
            x[:, t_prefix:t_total] = diffusion_step_block(
                logits, x_blk, x_blk == mask_id, num_transfer[:, t], temperature, remasking
            )
            yield x

        # Check for EOS only once the block is fully denoised. Committed tokens
        # never change, so this gives the same flags as a per-step check. It
        # avoids stopping while earlier positions are still masked, which
        # happens when a high-confidence trailing EOS is committed first.
        if eos_id is not None:
            finished |= (x[:, t_prefix:t_total] == eos_id).any(dim=1)
        generated += cur_len


def _decode_new_tokens(tok, ids, prompt_len, max_new_tokens):
    """Decode row 0 of ``ids`` after ``prompt_len``, skipping special tokens."""
    new_tokens = ids[0, prompt_len:prompt_len + max_new_tokens].tolist()
    return tok.decode(new_tokens, skip_special_tokens=True)


@torch.no_grad()
def generate(
    model,
    tokenizer,
    prompt,
    steps=128,
    max_new_tokens=128,
    block_size=32,
    temperature=0.0,
    cfg_scale=0.0,
    remasking="low_confidence",
    capture_interval=0,
):
    """Generate up to ``max_new_tokens`` tokens after ``prompt`` by block diffusion.

    Args:
        model: masked-LM that accepts ``attention_mask``, ``position_ids`` and
            ``past_key_values`` and returns ``.logits`` / ``.past_key_values``.
        tokenizer: supplies the mask/pad/eos token ids.
        prompt: tensor, flat id list, or list of id lists.
        steps: total denoising steps, spread evenly over the blocks.
        max_new_tokens: number of new positions to append.
        block_size: attention/generation block width.
        temperature: Gumbel sampling temperature (0 = greedy).
        cfg_scale: classifier-free guidance strength (<= 0 disables CFG).
        remasking: ``"low_confidence"`` or ``"random"``.
        capture_interval: if > 0, snapshot the sequence every this many steps.

    Returns:
        The final ``(B, T)`` token tensor. When ``capture_interval > 0``,
        returns ``(tensor, snapshots)`` instead.

    Raises:
        ValueError: for non-positive ``steps``/``block_size``, an unknown
            ``remasking`` mode, or a tokenizer without a mask token.
    """
    _validate_generation_args(steps, block_size, remasking)
    mask_id, pad_id = _resolve_special_ids(tokenizer)
    x = _to_batch(prompt, model.device, pad_id)

    step_iter = _denoise_steps(
        model, tokenizer, x,
        steps=steps, max_new_tokens=max_new_tokens, block_size=block_size,
        temperature=temperature, cfg_scale=cfg_scale, remasking=remasking,
        mask_id=mask_id, pad_id=pad_id,
    )
    intermediates = []
    # `x` is rebound to the latest (possibly grown) sequence after every step.
    for step, x in enumerate(step_iter):
        if capture_interval > 0 and step % capture_interval == 0:
            intermediates.append(x.clone())

    if capture_interval > 0:
        return x, intermediates
    return x


@torch.no_grad()
def generate_stream(
    model,
    tokenizer,
    prompt,
    steps=128,
    max_new_tokens=128,
    block_size=32,
    temperature=0.0,
    cfg_scale=0.0,
    remasking="low_confidence",
    capture_interval=10,
):
    """Like :func:`generate`, but yield progress dicts while decoding.

    Every ``capture_interval`` steps (when > 0) it yields
    ``{"type": "intermediate", "step", "text", "total_steps": steps}``. At the
    end it yields ``{"type": "final", "text", "total_steps": <steps executed>}``.
    Only row 0 of the batch is decoded.
    """
    _validate_generation_args(steps, block_size, remasking)
    mask_id, pad_id = _resolve_special_ids(tokenizer)
    x = _to_batch(prompt, model.device, pad_id)
    prompt_len = x.size(1)

    step_iter = _denoise_steps(
        model, tokenizer, x,
        steps=steps, max_new_tokens=max_new_tokens, block_size=block_size,
        temperature=temperature, cfg_scale=cfg_scale, remasking=remasking,
        mask_id=mask_id, pad_id=pad_id,
    )
    total_step = 0
    for step, x in enumerate(step_iter):
        total_step = step + 1
        # Guard against 0 / negative intervals instead of dividing by zero.
        if capture_interval > 0 and step % capture_interval == 0:
            yield {
                "type": "intermediate",
                "step": step,
                "text": _decode_new_tokens(tokenizer, x, prompt_len, max_new_tokens),
                "total_steps": steps,
            }

    yield {
        "type": "final",
        "text": _decode_new_tokens(tokenizer, x, prompt_len, max_new_tokens),
        "total_steps": total_step,
    }


# ---------------------------------------------------------------------------
# Model loading and HTTP layer
# ---------------------------------------------------------------------------

def load_model():
    """Load the model named by $MODEL_NAME into the module globals (bf16, eval mode)."""
    global model, tokenizer, device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_name = os.getenv("MODEL_NAME", "dllm-hub/Qwen3-0.6B-diffusion-bd3lm-v0.1")
    print(f"Loading model {model_name} on {device}...")
    model = AutoModelForMaskedLM.from_pretrained(
        model_name, dtype=torch.bfloat16, trust_remote_code=True
    ).to(device).eval()
    tokenizer = AutoTokenizer.from_pretrained(
        model_name, trust_remote_code=True
    )
    print("Model loaded successfully!")


class _BadRequest(ValueError):
    """Request-body validation failure; the endpoints map it to HTTP 400."""


def _int_field(data, name, default, minimum):
    """Read integer ``data[name]`` (or ``default``) and require it to be ``>= minimum``."""
    value = data.get(name, default)
    # bool subclasses int; reject it so JSON `true` is not silently read as 1.
    if isinstance(value, bool) or not isinstance(value, int):
        raise _BadRequest(f"'{name}' must be an integer")
    if value < minimum:
        raise _BadRequest(f"'{name}' must be >= {minimum}")
    return value


def _float_field(data, name, default, minimum=None):
    """Read finite numeric ``data[name]`` (or ``default``), optionally bounded below."""
    value = data.get(name, default)
    if isinstance(value, bool) or not isinstance(value, (int, float)) or not math.isfinite(value):
        raise _BadRequest(f"'{name}' must be a finite number")
    if minimum is not None and value < minimum:
        raise _BadRequest(f"'{name}' must be >= {minimum}")
    return value


def _parse_generation_request(data, with_capture=False):
    """Validate a JSON request body.

    Returns:
        ``(prompt, system_prompt, params)``, where ``params`` holds the keyword
        arguments for :func:`generate` / :func:`generate_stream`.
        ``capture_interval`` is included only when ``with_capture`` is set.

    Raises:
        _BadRequest: on a missing prompt or a malformed field.
    """
    if not isinstance(data, dict) or 'prompt' not in data:
        raise _BadRequest("Missing 'prompt' field")
    prompt = data['prompt']
    system_prompt = data.get('system_prompt', _DEFAULT_SYSTEM_PROMPT)
    for name, value in (("prompt", prompt), ("system_prompt", system_prompt)):
        if not isinstance(value, str):
            raise _BadRequest(f"'{name}' must be a string")

    remasking = data.get('remasking', 'low_confidence')
    if remasking not in _REMASKING_MODES:
        raise _BadRequest(f"'remasking' must be one of: {', '.join(_REMASKING_MODES)}")

    params = {
        "steps": _int_field(data, 'steps', 256, minimum=1),
        "max_new_tokens": _int_field(data, 'max_new_tokens', 256, minimum=0),
        "block_size": _int_field(data, 'block_size', 32, minimum=1),
        "temperature": _float_field(data, 'temperature', 0.0, minimum=0.0),
        "cfg_scale": _float_field(data, 'cfg_scale', 0.0),
        "remasking": remasking,
    }
    if with_capture:
        # 0 means "no intermediate states".
        params["capture_interval"] = _int_field(data, 'capture_interval', 10, minimum=0)
    return prompt, system_prompt, params


def _encode_chat(prompt, system_prompt):
    """Apply the chat template and return ``(token_id_list, (1, T) tensor on device)``."""
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]
    encoded = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=True, enable_thinking=False
    )
    input_ids = torch.tensor([encoded], dtype=torch.long, device=device)
    return encoded, input_ids


@app.route('/health', methods=['GET'])
def health():
    """Report liveness and whether the model has been loaded."""
    return jsonify({"status": "healthy", "model_loaded": model is not None})


@app.route('/generate', methods=['POST'])
def generate_text():
    """Generate a reply to ``prompt`` and return it as a single JSON object."""
    if model is None or tokenizer is None:
        return jsonify({"error": "Model not loaded"}), 503
    try:
        # silent=True: malformed or non-JSON bodies yield None -> our own 400.
        prompt, system_prompt, params = _parse_generation_request(request.get_json(silent=True))
    except _BadRequest as err:
        return jsonify({"error": str(err)}), 400

    encoded, input_ids = _encode_chat(prompt, system_prompt)
    output = generate(model, tokenizer, input_ids, **params)
    generated_text = _decode_new_tokens(tokenizer, output, len(encoded), params["max_new_tokens"])

    return jsonify({
        "prompt": prompt,
        "generated_text": generated_text,
        "parameters": params,
    })


@app.route('/generate_stream', methods=['POST'])
def generate_text_stream():
    """Generate a reply and also return decoded snapshots every ``capture_interval`` steps."""
    if model is None or tokenizer is None:
        return jsonify({"error": "Model not loaded"}), 503
    try:
        prompt, system_prompt, params = _parse_generation_request(
            request.get_json(silent=True), with_capture=True
        )
    except _BadRequest as err:
        return jsonify({"error": str(err)}), 400

    capture_interval = params["capture_interval"]
    max_new_tokens = params["max_new_tokens"]
    encoded, input_ids = _encode_chat(prompt, system_prompt)
    result = generate(model, tokenizer, input_ids, **params)
    # generate() returns an (output, snapshots) pair only when capturing.
    output, intermediates = result if capture_interval > 0 else (result, [])

    prompt_len = len(encoded)
    intermediate_states = [
        {
            "step": i * capture_interval,
            "text": _decode_new_tokens(tokenizer, snapshot, prompt_len, max_new_tokens),
        }
        for i, snapshot in enumerate(intermediates)
    ]
    generated_text = _decode_new_tokens(tokenizer, output, prompt_len, max_new_tokens)

    return jsonify({
        "prompt": prompt,
        "generated_text": generated_text,
        "intermediate_states": intermediate_states,
        "parameters": params,
    })


@app.route('/generate_sse', methods=['POST'])
def generate_text_sse():
    """Stream decoding progress as server-sent events (one JSON payload per event)."""
    if model is None or tokenizer is None:
        return jsonify({"error": "Model not loaded"}), 503
    try:
        prompt, system_prompt, params = _parse_generation_request(
            request.get_json(silent=True), with_capture=True
        )
    except _BadRequest as err:
        return jsonify({"error": str(err)}), 400

    _, input_ids = _encode_chat(prompt, system_prompt)

    def stream():
        try:
            for state in generate_stream(model, tokenizer, input_ids, **params):
                yield f"data: {json.dumps(state)}\n\n"
        except Exception:
            # The response status may already have been sent, so report the
            # failure in-band rather than silently truncating the stream.
            app.logger.exception("SSE generation failed")
            error_event = {"type": "error", "error": "Generation failed"}
            yield f"data: {json.dumps(error_event)}\n\n"

    return Response(
        stream(),
        mimetype='text/event-stream',
        headers={
            'Cache-Control': 'no-cache',
            'X-Accel-Buffering': 'no',
        },
    )


if __name__ == '__main__':
    load_model()
    app.run(host='0.0.0.0', port=int(os.getenv('PORT', 5000)))