{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "gpuType": "T4", "machine_shape": "hm" }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 🚀 Convert Gemma 4 E2B Uncensored-MAX to LiteRT-LM\n", "\n", "Converts the model to the `.litertlm` format for **Google AI Edge Gallery** on Android.\n", "\n", "**⚠️ IMPORTANT:** Use a runtime with **GPU + High-RAM**: Runtime → Change runtime type → T4 + High-RAM (hm)\n", "\n", "### Instructions:\n", "1. Run cell **1️⃣** → enter your token\n", "2. Run cell **2️⃣** → installs the dependencies. **The runtime will restart; this is expected.**\n", "3. After the restart, run **3️⃣**, **4️⃣** and **5️⃣** in order\n", "\n", "**Time:** ~30-45 min" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#@title 1️⃣ Configuration\n", "HF_TOKEN = \"\" #@param {type:\"string\"}\n", "OUTPUT_REPO = \"RedSparkie/gemma-4-E2B-it-Uncensored-MAX-litert-lm\" #@param {type:\"string\"}\n", "SOURCE_MODEL = \"prithivMLmods/gemma-4-E2B-it-Uncensored-MAX\" #@param {type:\"string\"}\n", "\n", "import json, os\n", "# Fail early, before an empty token is written to disk\n", "assert HF_TOKEN, '❌ Enter your Hugging Face token above!'\n", "# Persist the settings so they survive the runtime restart in cell 2\n", "os.makedirs('/content/cfg', exist_ok=True)\n", "with open('/content/cfg/config.json', 'w') as f:\n", "    json.dump({'HF_TOKEN': HF_TOKEN, 'OUTPUT_REPO': OUTPUT_REPO, 'SOURCE_MODEL': SOURCE_MODEL}, f)\n", "print('✅ Config saved')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#@title 2️⃣ Install dependencies (restarts the runtime)\n", "# Colab ships older torch/torchao/transformers builds that are incompatible.\n", "# We need recent versions that work together (minimum versions are pinned below).\n", "!pip install -q --upgrade \\\n", "    \"transformers>=5.7.0\" \\\n", "    \"torchao>=0.17.0\" \\\n", "    litert-torch \\\n", "    litert-lm \\\n", "    
huggingface_hub \\\n", "    sentencepiece \\\n", "    protobuf \\\n", "    safetensors \\\n", "    psutil\n", "\n", "# Verify the installed versions\n", "import torch, torchao, transformers\n", "print(f'torch {torch.__version__} | torchao {torchao.__version__} | transformers {transformers.__version__}')\n", "\n", "# Quick test: does torchao.quantization.pt2e import?\n", "try:\n", "    import torchao.quantization.pt2e.quantize_pt2e\n", "    print('✅ torchao.quantization.pt2e OK')\n", "except ImportError:\n", "    print('⚠️ torchao.quantization.pt2e not available, forcing reinstall...')\n", "    import subprocess\n", "    subprocess.check_call(['pip', 'install', '-q', '--force-reinstall', 'torchao>=0.17.0'])\n", "\n", "# Test: is Gemma4 available?\n", "try:\n", "    from transformers import Gemma4Config\n", "    print('✅ Gemma4 available')\n", "except ImportError:\n", "    print('⚠️ Gemma4 not available, forcing reinstall...')\n", "    import subprocess\n", "    subprocess.check_call(['pip', 'install', '-q', '--force-reinstall', 'transformers>=5.7.0'])\n", "\n", "# Restart the runtime so the new packages are loaded cleanly\n", "print('\\n🔄 Restarting runtime...')\n", "print('   After the restart, continue from cell 3️⃣')\n", "import IPython\n", "IPython.Application.instance().kernel.do_shutdown(True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#@title 3️⃣ Prepare model (extract text-only weights)\n", "# Reload the config saved in cell 1 (variables are lost on restart)\n", "import json, os\n", "with open('/content/cfg/config.json') as f:\n", "    cfg = json.load(f)\n", "HF_TOKEN = cfg['HF_TOKEN']\n", "OUTPUT_REPO = cfg['OUTPUT_REPO']\n", "SOURCE_MODEL = cfg['SOURCE_MODEL']\n", "\n", "# Check versions\n", "import torch, torchao, transformers\n", "print(f'torch {torch.__version__} | torchao {torchao.__version__} | transformers {transformers.__version__}')\n", "from transformers import Gemma4Config\n", "import torchao.quantization.pt2e.quantize_pt2e\n", "print('✅ All OK')\n", "\n", "import sys, gc, shutil, time\n", "from 
huggingface_hub import hf_hub_download\n", "from safetensors import safe_open\n", "from safetensors.torch import save_file\n", "import psutil\n", "\n", "def memlog(label=''):\n", "    # Print available / total system RAM in GB\n", "    m = psutil.virtual_memory()\n", "    print(f'  [{label}] RAM: {m.available/(1024**3):.1f}/{m.total/(1024**3):.1f} GB')\n", "\n", "MODEL_DIR = '/content/model'\n", "OUTPUT_DIR = '/content/output'\n", "os.makedirs(MODEL_DIR, exist_ok=True)\n", "os.makedirs(OUTPUT_DIR, exist_ok=True)\n", "start_time = time.time()\n", "memlog('start')\n", "\n", "print('📥 Downloading index...')\n", "idx_path = hf_hub_download(SOURCE_MODEL, 'model.safetensors.index.json', token=HF_TOKEN)\n", "with open(idx_path) as f:\n", "    index = json.load(f)\n", "\n", "# Group the language-model tensors by the shard that contains them\n", "shard_lm = {}\n", "for key, shard in index['weight_map'].items():\n", "    if key.startswith('model.language_model.'):\n", "        shard_lm.setdefault(shard, []).append(key)\n", "\n", "total_shards = len(shard_lm)\n", "print(f'  {sum(len(v) for v in shard_lm.values())} tensors in {total_shards} shards')\n", "\n", "# Rewrite each shard keeping only the language-model tensors\n", "weight_map = {}\n", "for i, sn in enumerate(sorted(shard_lm)):\n", "    keys = shard_lm[sn]\n", "    out_name = f'model-{i+1:05d}-of-{total_shards:05d}.safetensors'\n", "    out_path = os.path.join(MODEL_DIR, out_name)\n", "    \n", "    # Resume support: reuse a shard written by a previous run\n", "    if os.path.exists(out_path) and os.path.getsize(out_path) > 100:\n", "        print(f'  {out_name} already exists, skipping')\n", "        with safe_open(out_path, framework='pt') as f:\n", "            for k in f.keys(): weight_map[k] = out_name\n", "        continue\n", "    \n", "    print(f'  📦 {sn} → {out_name} ({len(keys)} tensors)')\n", "    shard_path = hf_hub_download(SOURCE_MODEL, sn, token=HF_TOKEN)\n", "    \n", "    tensors = {}\n", "    with safe_open(shard_path, framework='pt') as f:\n", "        for key in keys:\n", "            tensors[key] = f.get_tensor(key)\n", "    \n", "    # metadata format 'pt' marks the file as a PyTorch checkpoint for loaders that check it\n", "    save_file(tensors, out_path, metadata={'format': 'pt'})\n", "    for k in tensors: weight_map[k] = out_name\n", "    print(f'  💾 {os.path.getsize(out_path)/(1024**2):.0f} MB')\n", "    del tensors; gc.collect()\n", "    memlog(f'shard {i+1}')\n", "\n", "with 
open(os.path.join(MODEL_DIR, 'model.safetensors.index.json'), 'w') as f:\n", "    json.dump({'metadata': {}, 'weight_map': weight_map}, f)\n", "\n", "print('\\n📝 Config...')\n", "config = transformers.AutoConfig.from_pretrained(SOURCE_MODEL, token=HF_TOKEN)\n", "cd = config.to_dict()\n", "# Drop the vision/audio sub-configs and the multimodal token ids\n", "cd['vision_config'] = None\n", "cd['audio_config'] = None\n", "for k in ['vision_soft_tokens_per_image','image_token_id','boi_token_id',\n", "          'eoi_token_id','audio_token_id','boa_token_id','eoa_token_id',\n", "          'eoa_token_index','video_token_id']:\n", "    cd.pop(k, None)\n", "with open(os.path.join(MODEL_DIR, 'config.json'), 'w') as f:\n", "    json.dump(cd, f, indent=2)\n", "\n", "# Copy tokenizer and generation files; any file missing from the source repo is skipped\n", "for fn in ['tokenizer.json','tokenizer_config.json','chat_template.jinja','generation_config.json']:\n", "    try:\n", "        shutil.copy(hf_hub_download(SOURCE_MODEL, fn, token=HF_TOKEN), os.path.join(MODEL_DIR, fn))\n", "        print(f'  ✓ {fn}')\n", "    except Exception as e:\n", "        print(f'  ⚠️ {fn} skipped ({type(e).__name__})')\n", "\n", "# Free disk space: delete the cached Hub downloads now that the shards are rewritten\n", "del config; gc.collect()\n", "cache_dir = os.path.expanduser('~/.cache/huggingface/hub')\n", "if os.path.exists(cache_dir):\n", "    for d in os.listdir(cache_dir):\n", "        if d.startswith('models--'):\n", "            shutil.rmtree(os.path.join(cache_dir, d), ignore_errors=True)\n", "gc.collect()\n", "print('\\n✅ Model prepared')\n", "memlog('ready')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#@title 4️⃣ Convert to .litertlm\n", "from litert_torch.generative.export_hf import export as export_lib\n", "\n", "print('🚀 Converting to LiteRT-LM...')\n", "print('   This takes 15-30 min.')\n", "memlog('pre-export')\n", "conversion_start = time.time()\n", "\n", "export_lib.export(\n", "    model=MODEL_DIR,\n", "    output_dir=OUTPUT_DIR,\n", "    task='text_generation',\n", "    bundle_litert_lm=True,\n", "    quantization_recipe='dynamic_wi8_afp32',\n", "    cache_length=4096,\n", "    prefill_lengths=[256],\n", "    use_jinja_template=True,\n", "    keep_temporary_files=True,\n", "    trust_remote_code=False,\n", "    
experimental_lightweight_conversion=True,\n", "    externalize_embedder=True,\n", ")\n", "\n", "print(f'\\n✅ Conversion finished in {(time.time()-conversion_start)/60:.1f} min')\n", "memlog('post-export')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#@title 5️⃣ Verify and upload\n", "litertlm = os.path.join(OUTPUT_DIR, 'model.litertlm')\n", "\n", "if not os.path.exists(litertlm):\n", "    # List what the export produced to help diagnose the failure\n", "    print('❌ model.litertlm not found. Files:')\n", "    for r,d,fs in os.walk(OUTPUT_DIR):\n", "        for f in fs:\n", "            fp = os.path.join(r,f)\n", "            print(f'  {os.path.relpath(fp,OUTPUT_DIR)}: {os.path.getsize(fp)/(1024**2):.1f} MB')\n", "else:\n", "    size_gb = os.path.getsize(litertlm) / (1024**3)\n", "    print(f'📊 model.litertlm: {size_gb:.2f} GB')\n", "    if size_gb <= 2.0: print('✅ Fits in 2 GB!')\n", "    else: print(f'⚠️ {size_gb:.2f} GB. Switch to dynamic_wi4_afp32 in cell 4')\n", "    \n", "    print(f'\\n📤 Uploading to {OUTPUT_REPO}...')\n", "    from huggingface_hub import HfApi\n", "    api = HfApi(token=HF_TOKEN)\n", "    try: api.create_repo(OUTPUT_REPO, exist_ok=True)\n", "    except Exception as e: print(f'⚠️ create_repo failed: {e}')\n", "    \n", "    api.upload_file(path_or_fileobj=litertlm,\n", "        path_in_repo='gemma-4-E2B-it-Uncensored-MAX.litertlm',\n", "        repo_id=OUTPUT_REPO, commit_message='Add LiteRT-LM model')\n", "    \n", "    readme = f\"\"\"---\\nlicense: apache-2.0\\nbase_model:\\n- {SOURCE_MODEL}\\ntags:\\n - litert-lm\\n - uncensored\\n - edge-gallery\\nlanguage:\\n- en\\n---\\n\\n# gemma-4-E2B-it-Uncensored-MAX (LiteRT-LM)\\n\\nLiteRT-LM conversion for **Google AI Edge Gallery**.\\n\\n| | |\\n|---|---|\\n| **Base** | [{SOURCE_MODEL}](https://huggingface.co/{SOURCE_MODEL}) |\\n| **Format** | `.litertlm` |\\n| **Quant** | INT8 |\\n| **Context** | 4096 |\\n| **Size** | {size_gb:.2f} GB |\\n\\n## Usage\\n1. Install [Edge Gallery](https://play.google.com/store/apps/details?id=com.google.ai.edge.gallery)\\n2. Add model via HF URL\\n3. Chat!\\n\\n⚠️ Uncensored. 
Use responsibly.\\n\"\"\"\n", "    api.upload_file(path_or_fileobj=readme.encode(), path_in_repo='README.md',\n", "        repo_id=OUTPUT_REPO, commit_message='README')\n", "    \n", "    print('\\n🎉 DONE!')\n", "    print(f'📱 https://huggingface.co/{OUTPUT_REPO}')\n", "    print(f'📊 {size_gb:.2f} GB')\n", "    print(f'⏱️ {(time.time()-start_time)/60:.0f} min total')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 🔧 Troubleshooting\n", "\n", "| Error | Solution |\n", "|---|---|\n", "| `KeyError: 'gemma4'` | Outdated `transformers`. Re-run cell 2️⃣ and restart the runtime |\n", "| `No module 'torchao.quantization.pt2e'` | Outdated `torchao`. Re-run cell 2️⃣ and restart the runtime |\n", "| OOM / runs out of memory | Use a **High-RAM** (hm) runtime |\n", "| Model > 2 GB | Change `dynamic_wi8_afp32` → `dynamic_wi4_afp32` in cell 4️⃣ |\n", "| `External embedder required` | Already handled by `externalize_embedder=True` |" ] } ] }