using UnityEngine;
using Unity.InferenceEngine;
using FF = Unity.InferenceEngine.Functional;
using System.IO;
using System.Collections.Generic;
using System.Diagnostics;
using Debug = UnityEngine.Debug;
using System.Threading.Tasks;
using System;

/// <summary>
/// Runs a vision-language model split into three ONNX-style assets (vision encoder,
/// token embedder, decoder) with Unity Inference Engine, and streams greedily decoded
/// text back to callers through events.
/// </summary>
public class ModelVLM : MonoBehaviour
{
    [Header("Model Settings")]
    public BackendType BACKEND = BackendType.GPUCompute;
    private const int MAX_GENERATE_TOKENS = 256;

    [Header("Model Assets")]
    [SerializeField] private ModelAsset visionEncoderAsset;
    [SerializeField] private ModelAsset embedTokensAsset;
    [SerializeField] private ModelAsset decoderAsset;

    private Worker _visionEncoder;
    private Worker _embedTokens;
    private Worker _decoder;
    private Worker _greedyDecoder;
    private Worker _concatEmbeddings;
    private Qwen2Tokenizer _tokenizer;
    private int _concatEmbeddingDimension = -1;

    // Backend the current workers were created with. BACKEND is a public (inspector-editable)
    // field, so workers and KV tensors created after Initialize() use this snapshot instead,
    // keeping every tensor/worker on the same backend.
    private BackendType _activeBackend;

    private const int MAX_LAYERS = 24;
    private const int NUM_KEY_VALUE_HEADS = 2;
    private const int HEAD_DIM = 64;
    private const int VOCAB_SIZE = 151646;

    // Width and height (pixels) the input image is resized to before the vision encoder.
    private const int VISION_INPUT_SIZE = 256;

    private readonly Tensor<float>[] _pastKeys = new Tensor<float>[MAX_LAYERS];
    private readonly Tensor<float>[] _pastValues = new Tensor<float>[MAX_LAYERS];
    private readonly List<int> _outputTokens = new List<int>();

    public bool IsInitialized { get; private set; }
    public bool IsGenerating { get; private set; }

    /// <summary>Raised after each generated token with the full text decoded so far (cumulative, not a delta).</summary>
    public event Action<string> OnTokenGenerated;

    /// <summary>Raised when generation stops: (final text, generated token count, elapsed milliseconds).</summary>
    public event Action<string, int, long> OnGenerationComplete;

    /// <summary>Raised with a message when generation cannot start or fails.</summary>
    public event Action<string> OnGenerationError;

    /// <summary>
    /// Loads the tokenizer from StreamingAssets/fastvlm and creates workers for all models.
    /// Runs synchronously; the returned task is always already completed. Failures are logged
    /// and leave <see cref="IsInitialized"/> false.
    /// </summary>
    public Task Initialize()
    {
        // Re-initializing would dispose workers that an in-flight generation is still using.
        if (IsGenerating)
        {
            Debug.LogError("Cannot initialize models while generation is in progress.");
            return Task.CompletedTask;
        }

        IsInitialized = false;
        try
        {
            // Release workers from any previous Initialize() first, so a failed
            // re-initialization does not keep stale models alive.
            DisposeWorkers();

            if (visionEncoderAsset == null || embedTokensAsset == null || decoderAsset == null)
            {
                Debug.LogError("One or more ModelAssets are not assigned in Inspector!");
                return Task.CompletedTask;
            }

            // NOTE(review): File.Exists/ReadAllText on streamingAssetsPath assumes StreamingAssets is a
            // plain directory; confirm this holds on every target platform (e.g. Android, WebGL).
            string tokenizerPath = Path.Combine(Application.streamingAssetsPath, "fastvlm");
            string vocabPath = Path.Combine(tokenizerPath, "vocab.json");
            string mergesPath = Path.Combine(tokenizerPath, "merges.txt");
            string configPath = Path.Combine(tokenizerPath, "tokenizer_config.json");
            if (!File.Exists(vocabPath) || !File.Exists(mergesPath) || !File.Exists(configPath))
            {
                Debug.LogError($"Tokenizer files not found at: {tokenizerPath}");
                return Task.CompletedTask;
            }

            _tokenizer = new Qwen2Tokenizer(
                File.ReadAllText(vocabPath),
                File.ReadAllText(mergesPath),
                File.ReadAllText(configPath)
            );

            _activeBackend = BACKEND;

            Model visionModel = ModelLoader.Load(visionEncoderAsset);
            _visionEncoder = new Worker(visionModel, _activeBackend);

            Model embedModel = ModelLoader.Load(embedTokensAsset);
            _embedTokens = new Worker(embedModel, _activeBackend);

            Model decoderModel = ModelLoader.Load(decoderAsset);
            _decoder = new Worker(decoderModel, _activeBackend);

            // Small graph that reduces logits (1, seq, VOCAB_SIZE) to argmax token ids (1, seq).
            FunctionalGraph graph = new FunctionalGraph();
            FunctionalTensor logitsInput = graph.AddInput<float>(new DynamicTensorShape(1, -1, VOCAB_SIZE));
            FunctionalTensor argMax = FF.ArgMax(logitsInput, 2, false);
            Model greedyModel = graph.Compile(argMax);
            _greedyDecoder = new Worker(greedyModel, _activeBackend);

            IsInitialized = true;
        }
        catch (Exception e)
        {
            Debug.LogError($"Failed to initialize models: {e.Message}\n{e.StackTrace}");
            DisposeWorkers();
        }
        return Task.CompletedTask;
    }

    /// <summary>
    /// Greedily generates a reply to <paramref name="prompt"/>, optionally conditioned on
    /// <paramref name="image"/>. Progress and failures are reported through the events
    /// rather than by return value or thrown exceptions.
    /// </summary>
    /// <param name="prompt">User text; a null or empty prompt is ignored.</param>
    /// <param name="image">Optional image; when null a text-only chat prompt is used.</param>
    /// <param name="maxTokens">Upper bound on generated tokens (clamped to at least 1).</param>
    public async Task GenerateFromPrompt(string prompt, Texture image = null, int maxTokens = MAX_GENERATE_TOKENS)
    {
        if (string.IsNullOrEmpty(prompt)) return;
        if (!IsInitialized)
        {
            OnGenerationError?.Invoke("Models not initialized.");
            return;
        }
        if (IsGenerating)
        {
            OnGenerationError?.Invoke("Generation already in progress.");
            return;
        }

        maxTokens = Math.Max(1, maxTokens);
        IsGenerating = true;
        _outputTokens.Clear();
        ClearKVCache();

        Stopwatch sw = Stopwatch.StartNew();
        Tensor<float> visionEmbeddings = null;
        Tensor<float> mergedEmbeddings = null;
        try
        {
            mergedEmbeddings = BuildPromptEmbeddings(prompt, image, out visionEmbeddings);
            int mergedSeqLen = mergedEmbeddings.shape[1];

            // KV capacity: the whole prompt plus at most one cache entry per generated token.
            int maxKvSequenceLength = mergedSeqLen + maxTokens;

            int nextToken = DecoderPrefill(mergedEmbeddings, mergedSeqLen, maxKvSequenceLength);
            int currentPos = mergedSeqLen;

            // A stop token is never appended to the output (including one produced directly by
            // prefill), and no decoder pass is run once the token budget has been reached.
            while (!IsStopToken(nextToken))
            {
                _outputTokens.Add(nextToken);
                OnTokenGenerated?.Invoke(_tokenizer.Decode(_outputTokens));

                if (_outputTokens.Count >= maxTokens)
                    break;

                nextToken = DecoderDecode(nextToken, currentPos);
                UpdateKVCache();
                currentPos++;

                await Task.Yield();

                // The component may have been destroyed (workers disposed) while we were yielded.
                if (!IsInitialized)
                    throw new InvalidOperationException("Models were disposed during generation.");
            }

            sw.Stop();
            string finalText = _tokenizer.Decode(_outputTokens);
            OnGenerationComplete?.Invoke(finalText, _outputTokens.Count, sw.ElapsedMilliseconds);
        }
        catch (Exception e)
        {
            Debug.LogError($"Generation error: {e.Message}\n{e.StackTrace}");
            OnGenerationError?.Invoke(e.Message);
        }
        finally
        {
            visionEmbeddings?.Dispose();
            mergedEmbeddings?.Dispose();
            // The cache is rebuilt from scratch on every generation, so release it now.
            ClearKVCache();
            IsGenerating = false;
        }
    }

    /// <summary>True when <paramref name="tokenId"/> is the tokenizer's EOS or PAD id.</summary>
    private bool IsStopToken(int tokenId)
    {
        return tokenId == _tokenizer.EosTokenId || tokenId == _tokenizer.PadTokenId;
    }

    /// <summary>
    /// Builds the chat-formatted prompt embeddings. With an image, the vision embeddings are
    /// spliced between the "user" header and the prompt text. The caller owns both the returned
    /// tensor and <paramref name="visionEmbeddings"/> (null when there is no image).
    /// </summary>
    private Tensor<float> BuildPromptEmbeddings(string prompt, Texture image, out Tensor<float> visionEmbeddings)
    {
        visionEmbeddings = image != null ? EncodeVision(image) : null;

        if (visionEmbeddings == null)
        {
            string fullPrompt = $"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n";
            var tokenIds = _tokenizer.Encode(fullPrompt);
            return EmbedTokens(tokenIds.ToArray());
        }

        string promptPrefix = "<|im_start|>user\n";
        string promptSuffix = $"\n{prompt}<|im_end|>\n<|im_start|>assistant\n";
        var prefixIds = _tokenizer.Encode(promptPrefix);
        var suffixIds = _tokenizer.Encode(promptSuffix);

        using var prefixEmbeddings = EmbedTokens(prefixIds.ToArray());
        using var suffixEmbeddings = EmbedTokens(suffixIds.ToArray());
        return ConcatenateEmbeddings(prefixEmbeddings, visionEmbeddings, suffixEmbeddings);
    }

    /// <summary>Resizes the image to VISION_INPUT_SIZE² (3 channels) and returns a caller-owned copy of the encoder output.</summary>
    private Tensor<float> EncodeVision(Texture image)
    {
        using var imageTensor = TextureConverter.ToTensor(image, width: VISION_INPUT_SIZE, height: VISION_INPUT_SIZE, channels: 3);
        _visionEncoder.SetInput(0, imageTensor);
        _visionEncoder.Schedule();
        return CopyFloatOutput(_visionEncoder);
    }

    /// <summary>Embeds a single batch of token ids (shape (1, n)) and returns a caller-owned copy of the embeddings.</summary>
    private Tensor<float> EmbedTokens(int[] tokenIds)
    {
        using var inputTensor = new Tensor<int>(new TensorShape(1, tokenIds.Length), tokenIds);
        _embedTokens.SetInput(0, inputTensor);
        _embedTokens.Schedule();
        return CopyFloatOutput(_embedTokens);
    }

    /// <summary>
    /// Concatenates three embedding tensors along axis 1 and returns a caller-owned result.
    /// </summary>
    /// <exception cref="InvalidOperationException">The last dimensions of the inputs differ.</exception>
    private Tensor<float> ConcatenateEmbeddings(Tensor<float> t1, Tensor<float> t2, Tensor<float> t3)
    {
        int embeddingDimension = t1.shape[2];
        if (t2.shape[2] != embeddingDimension || t3.shape[2] != embeddingDimension)
            throw new InvalidOperationException("Embedding dimensions must match for concatenation.");

        EnsureConcatWorker(embeddingDimension);

        _concatEmbeddings.SetInput(0, t1);
        _concatEmbeddings.SetInput(1, t2);
        _concatEmbeddings.SetInput(2, t3);
        _concatEmbeddings.Schedule();
        return CopyFloatOutput(_concatEmbeddings);
    }

    /// <summary>
    /// Copies output 0 of <paramref name="worker"/> into a new tensor owned by the caller.
    /// Throws (after disposing the copy) if the output is not a float tensor, instead of
    /// silently returning null.
    /// </summary>
    private static Tensor<float> CopyFloatOutput(Worker worker)
    {
        Tensor copied = null;
        worker.CopyOutput(0, ref copied);
        if (copied is Tensor<float> result)
            return result;

        copied?.Dispose();
        throw new InvalidOperationException("Expected a float tensor output from the model.");
    }

    /// <summary>(Re)builds the concat worker only when the embedding dimension changes.</summary>
    private void EnsureConcatWorker(int embeddingDimension)
    {
        if (_concatEmbeddings != null && _concatEmbeddingDimension == embeddingDimension)
            return;

        _concatEmbeddings?.Dispose();
        _concatEmbeddings = null;

        var funcGraph = new FunctionalGraph();
        var inputShape = new DynamicTensorShape(1, -1, embeddingDimension);
        var input1 = funcGraph.AddInput<float>(inputShape);
        var input2 = funcGraph.AddInput<float>(inputShape);
        var input3 = funcGraph.AddInput<float>(inputShape);
        var concatenated = FF.Concat(new[] { input1, input2, input3 }, 1);
        var model = funcGraph.Compile(concatenated);

        _concatEmbeddings = new Worker(model, _activeBackend);
        _concatEmbeddingDimension = embeddingDimension;
    }

    /// <summary>
    /// Runs the decoder over the full prompt starting from an empty KV cache sized for
    /// <paramref name="maxKvSequenceLength"/>, stores the resulting cache, and returns the
    /// argmax token at the last prompt position.
    /// </summary>
    private int DecoderPrefill(Tensor<float> embeddings, int sequenceLength, int maxKvSequenceLength)
    {
        _decoder.SetInput("inputs_embeds", embeddings);

        using var positionIds = new Tensor<int>(new TensorShape(1, sequenceLength), BuildRangeArray(sequenceLength, 0));
        _decoder.SetInput("position_ids", positionIds);

        using var attentionMask = new Tensor<int>(new TensorShape(1, sequenceLength), BuildFilledArray(sequenceLength, 1));
        _decoder.SetInput("attention_mask", attentionMask);

        SetEmptyKVCache(maxKvSequenceLength);
        _decoder.Schedule();

        var logits = _decoder.PeekOutput("logits") as Tensor<float>;
        var firstToken = ProcessLogits(logits, sequenceLength - 1);

        UpdateKVCache();
        return firstToken;
    }

    /// <summary>
    /// Runs one decode step for <paramref name="tokenId"/> at <paramref name="position"/>,
    /// attending over position + 1 cached entries, and returns the argmax next token.
    /// The caller is responsible for calling UpdateKVCache afterwards.
    /// </summary>
    private int DecoderDecode(int tokenId, int position)
    {
        using var embeddings = EmbedTokens(new[] { tokenId });
        _decoder.SetInput("inputs_embeds", embeddings);

        using var positionIds = new Tensor<int>(new TensorShape(1, 1), new[] { position });
        _decoder.SetInput("position_ids", positionIds);

        using var attentionMask = new Tensor<int>(new TensorShape(1, position + 1), BuildFilledArray(position + 1, 1));
        _decoder.SetInput("attention_mask", attentionMask);

        _decoder.Schedule();

        var logits = _decoder.PeekOutput("logits") as Tensor<float>;
        return ProcessLogits(logits, 0);
    }

    /// <summary>Returns the argmax token id at sequence position <paramref name="index"/> of the logits.</summary>
    private int ProcessLogits(Tensor<float> logits, int index)
    {
        _greedyDecoder.SetInput(0, logits);
        _greedyDecoder.Schedule();

        var argMaxTensor = _greedyDecoder.PeekOutput() as Tensor<int>;
        // Synchronous readback into a CPU clone; disposed immediately after reading one element.
        using var resultTensor = argMaxTensor.ReadbackAndClone();
        return resultTensor[index];
    }

    /// <summary>
    /// Gives every layer fresh key/value tensors with a logical sequence length of 0 but storage
    /// for <paramref name="maxKvSequenceLength"/> positions, and binds them as decoder inputs.
    /// The spare capacity is intended to let CopyOutput in UpdateKVCache reuse the same tensors
    /// as the cache grows.
    /// </summary>
    private void SetEmptyKVCache(int maxKvSequenceLength)
    {
        var shape = new TensorShape(1, NUM_KEY_VALUE_HEADS, 0, HEAD_DIM);
        int maxTensorLength = Math.Max(0, NUM_KEY_VALUE_HEADS * maxKvSequenceLength * HEAD_DIM);

        for (int i = 0; i < MAX_LAYERS; i++)
        {
            _pastKeys[i]?.Dispose();
            _pastValues[i]?.Dispose();

            if (_activeBackend == BackendType.GPUCompute && maxTensorLength > 0)
            {
                _pastKeys[i] = new Tensor<float>(shape, new ComputeTensorData(maxTensorLength, clearOnInit: false));
                _pastValues[i] = new Tensor<float>(shape, new ComputeTensorData(maxTensorLength, clearOnInit: false));
            }
            else
            {
                // Allocate at full capacity, then shrink the logical shape to 0 positions.
                var preallocatedShape = new TensorShape(1, NUM_KEY_VALUE_HEADS, maxKvSequenceLength, HEAD_DIM);
                _pastKeys[i] = new Tensor<float>(preallocatedShape, clearOnInit: false);
                _pastValues[i] = new Tensor<float>(preallocatedShape, clearOnInit: false);
                _pastKeys[i].Reshape(shape);
                _pastValues[i].Reshape(shape);
            }

            _decoder.SetInput($"past_key_values.{i}.key", _pastKeys[i]);
            _decoder.SetInput($"past_key_values.{i}.value", _pastValues[i]);
        }
    }

    /// <summary>
    /// Copies each layer's "present" key/value outputs into the cache tensors and rebinds them as
    /// the next step's "past_key_values" inputs. A previous cache tensor is disposed only when
    /// CopyOutput replaced it with a different instance.
    /// </summary>
    private void UpdateKVCache()
    {
        for (int i = 0; i < MAX_LAYERS; i++)
        {
            string keyName = $"present.{i}.key";
            string valueName = $"present.{i}.value";

            Tensor<float> previousKey = _pastKeys[i];
            Tensor<float> previousValue = _pastValues[i];

            Tensor copiedKey = previousKey;
            Tensor copiedValue = previousValue;
            _decoder.CopyOutput(keyName, ref copiedKey);
            _decoder.CopyOutput(valueName, ref copiedValue);

            _pastKeys[i] = copiedKey as Tensor<float>;
            _pastValues[i] = copiedValue as Tensor<float>;

            if (!ReferenceEquals(previousKey, _pastKeys[i])) previousKey?.Dispose();
            if (!ReferenceEquals(previousValue, _pastValues[i])) previousValue?.Dispose();

            _decoder.SetInput($"past_key_values.{i}.key", _pastKeys[i]);
            _decoder.SetInput($"past_key_values.{i}.value", _pastValues[i]);
        }
    }

    /// <summary>Disposes and forgets all cached key/value tensors (safe to call repeatedly).</summary>
    private void ClearKVCache()
    {
        for (int i = 0; i < MAX_LAYERS; i++)
        {
            _pastKeys[i]?.Dispose();
            _pastValues[i]?.Dispose();
            _pastKeys[i] = null;
            _pastValues[i] = null;
        }
    }

    /// <summary>Returns [start, start + 1, ..., start + length - 1].</summary>
    private static int[] BuildRangeArray(int length, int start)
    {
        var values = new int[length];
        for (int i = 0; i < length; i++)
            values[i] = start + i;
        return values;
    }

    /// <summary>Returns an array of <paramref name="length"/> copies of <paramref name="value"/>.</summary>
    private static int[] BuildFilledArray(int length, int value)
    {
        var values = new int[length];
        if (value != 0)
            Array.Fill(values, value);
        return values;
    }

    /// <summary>
    /// Disposes every worker and the KV cache, and marks the model as uninitialized so that
    /// callers (including an in-flight generation) stop using the disposed workers.
    /// </summary>
    private void DisposeWorkers()
    {
        IsInitialized = false;

        _visionEncoder?.Dispose();
        _visionEncoder = null;
        _embedTokens?.Dispose();
        _embedTokens = null;
        _decoder?.Dispose();
        _decoder = null;
        _greedyDecoder?.Dispose();
        _greedyDecoder = null;
        _concatEmbeddings?.Dispose();
        _concatEmbeddings = null;
        _concatEmbeddingDimension = -1;

        ClearKVCache();
    }

    private void OnDestroy()
    {
        DisposeWorkers();
    }
}