| import io |
| import re |
| import struct |
| from enum import IntEnum |
| from math import floor |
|
|
| import requests |
|
|
| import gradio as gr |
|
|
|
|
class GGUFValueType(IntEnum):
    """Numeric type tags that precede each value in the GGUF key/value section.

    The integer of each member is the tag as it is read from the file
    (a little-endian uint32 in ``load_metadata_from_file``).
    """

    # Small integers
    UINT8 = 0
    INT8 = 1
    UINT16 = 2
    INT16 = 3
    UINT32 = 4
    INT32 = 5

    # Remaining 32-bit-era tags
    FLOAT32 = 6
    BOOL = 7
    STRING = 8
    ARRAY = 9

    # 64-bit tags
    UINT64 = 10
    INT64 = 11
    FLOAT64 = 12
|
|
|
|
# ``struct`` format string used to decode each fixed-size GGUF scalar type.
# STRING and ARRAY are absent on purpose: they are variable-length and are
# decoded separately by the readers below.
_simple_value_packing = {
    GGUFValueType.UINT8: "<B",
    GGUFValueType.INT8: "<b",
    GGUFValueType.UINT16: "<H",
    GGUFValueType.INT16: "<h",
    GGUFValueType.UINT32: "<I",
    GGUFValueType.INT32: "<i",
    GGUFValueType.FLOAT32: "<f",
    GGUFValueType.UINT64: "<Q",
    GGUFValueType.INT64: "<q",
    GGUFValueType.FLOAT64: "<d",
    GGUFValueType.BOOL: "?",
}


# Number of bytes to read for each scalar type. Derived from the format
# strings above so the two tables always agree (1, 2, 4 or 8 bytes).
value_type_info = {
    value_type: struct.calcsize(fmt)
    for value_type, fmt in _simple_value_packing.items()
}
|
|
|
|
def _read_exact(file, size):
    """Read exactly ``size`` bytes from ``file`` or raise ``EOFError``."""
    data = file.read(size)
    if len(data) != size:
        raise EOFError(
            f"Unexpected end of GGUF data: wanted {size} bytes, got {len(data)}"
        )
    return data


def get_single(value_type, file):
    """Read one non-array GGUF metadata value from ``file``.

    Args:
        value_type: The ``GGUFValueType`` tag of the value to decode.
        file: Binary file-like object positioned at the start of the value.

    Returns:
        For STRING, the decoded ``str`` (or the raw ``bytes`` when the payload
        is not valid UTF-8). For scalar types, the unpacked Python value.

    Raises:
        ValueError: If ``value_type`` has no scalar decoding (e.g. a nested
            ARRAY), instead of failing inside ``struct.unpack``.
        EOFError: If the stream ends before the value is complete.
    """
    if value_type == GGUFValueType.STRING:
        # Strings are stored as a uint64 byte length followed by the payload.
        value_length = struct.unpack("<Q", _read_exact(file, 8))[0]
        value = _read_exact(file, value_length)
        try:
            value = value.decode('utf-8')
        except UnicodeDecodeError:
            # Best effort: keep the raw bytes when the payload is not UTF-8.
            pass
        return value

    type_str = _simple_value_packing.get(value_type)
    bytes_length = value_type_info.get(value_type)
    if type_str is None or bytes_length is None:
        raise ValueError(f"Unsupported GGUF value type: {value_type!r}")

    return struct.unpack(type_str, _read_exact(file, bytes_length))[0]
|
|
|
|
def load_metadata_from_file(file_obj):
    """Parse the key/value metadata section of a GGUF file.

    Args:
        file_obj: Binary file-like object positioned at the start of the
            GGUF data (it may hold only a prefix of the file).

    Returns:
        A dict with every metadata key found, plus the convenience keys
        ``n_layers``, ``n_kv_heads``, ``embedding_dim``, ``context_length``
        and ``feed_forward_dim`` whenever a matching ``*.suffix`` key exists.

    Raises:
        ValueError: If the data is too short for a header or does not start
            with the ``GGUF`` magic bytes (e.g. an HTML error page).
        Exception: If the file uses GGUF version 1.
    """
    metadata = {}

    # Header layout: magic (4 bytes), version (uint32), tensor count (uint64),
    # key/value count (uint64) -- 24 bytes in total.
    header = file_obj.read(24)
    if len(header) < 24:
        raise ValueError('File is too small to contain a GGUF header.')
    if header[:4] != b'GGUF':
        raise ValueError('Not a GGUF file (missing "GGUF" magic bytes).')

    gguf_version, _ti_data_count, kv_data_count = struct.unpack("<IQQ", header[4:])

    if gguf_version == 1:
        raise Exception('You are using an outdated GGUF, please download a new one.')

    for _ in range(kv_data_count):
        # Each entry is: key length (uint64), key bytes, value type tag (uint32), value.
        key_length = struct.unpack("<Q", file_obj.read(8))[0]
        key = file_obj.read(key_length).decode()

        value_type = GGUFValueType(struct.unpack("<I", file_obj.read(4))[0])
        if value_type == GGUFValueType.ARRAY:
            # Arrays carry their element type and element count up front.
            ltype = GGUFValueType(struct.unpack("<I", file_obj.read(4))[0])
            length = struct.unpack("<Q", file_obj.read(8))[0]
            metadata[key] = [get_single(ltype, file_obj) for _ in range(length)]
        else:
            metadata[key] = get_single(value_type, file_obj)

    # Keys are prefixed with an architecture name, so match on the suffix.
    extracted_fields = {}
    for key, value in metadata.items():
        if key.endswith('.block_count'):
            extracted_fields['n_layers'] = value
        elif key.endswith('.attention.head_count_kv'):
            # The value may be a list; the largest entry is used in that case.
            extracted_fields['n_kv_heads'] = max(value) if isinstance(value, list) else value
        elif key.endswith('.embedding_length'):
            extracted_fields['embedding_dim'] = value
        elif key.endswith('.context_length'):
            extracted_fields['context_length'] = value
        elif key.endswith('.feed_forward_length'):
            extracted_fields['feed_forward_dim'] = value

    metadata.update(extracted_fields)
    return metadata
|
|
|
|
def download_gguf_partial(url, max_bytes=25 * 1024 * 1024, timeout=30):
    """Download at most the first ``max_bytes`` bytes of a GGUF URL.

    A ``Range`` header asks the server for just the prefix. The body is also
    read in chunks and cut off at ``max_bytes``, so a server that ignores the
    header cannot make this function pull a whole multi-gigabyte file into
    memory.

    Args:
        url: Direct download URL of the GGUF file.
        max_bytes: Upper bound on the number of bytes to keep.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        An ``io.BytesIO`` positioned at offset 0 holding the downloaded prefix.

    Raises:
        Exception: On any network or HTTP error, with the original error chained.
    """
    headers = {'Range': f'bytes=0-{max_bytes-1}'}
    buffer = io.BytesIO()

    try:
        with requests.get(url, headers=headers, stream=True, timeout=timeout) as response:
            response.raise_for_status()

            remaining = max_bytes
            for chunk in response.iter_content(chunk_size=1024 * 1024):
                buffer.write(chunk[:remaining])
                remaining -= len(chunk)
                if remaining <= 0:
                    # Enough data collected; leaving the ``with`` closes the connection.
                    break
    except Exception as e:
        raise Exception(f"Failed to download GGUF file: {str(e)}") from e

    buffer.seek(0)
    return buffer
|
|
|
|
def load_metadata(model_url, current_metadata):
    """Fetch GGUF metadata for ``model_url`` and prepare the UI updates.

    Args:
        model_url: URL of the GGUF file as typed by the user.
        current_metadata: Previously stored metadata. It is not read here;
            the parameter is kept because the UI passes it as an input.

    Returns:
        A 3-tuple ``(metadata, gpu_layers_update, status_message)``. On
        failure ``metadata`` is ``{}``, the slider update is a no-op, and the
        message describes the error.
    """
    # Strip once up front so stray whitespace from copy/paste does not end up
    # in the request URL (the original only stripped for the emptiness check).
    model_url = (model_url or "").strip()
    if not model_url:
        return {}, gr.update(), "Please enter a model URL"

    try:
        model_size_mb = get_model_size_mb_from_url(model_url)
        normalized_url = normalize_huggingface_url(model_url)
        file_obj = download_gguf_partial(normalized_url)
        metadata = load_metadata_from_file(file_obj)

        n_layers = metadata.get("n_layers")
        if n_layers is None:
            # Fail with a readable message instead of a bare KeyError text.
            raise ValueError("the GGUF metadata has no '*.block_count' entry")

        # File name shown in the status line (query string removed).
        gguf_filename = model_url.split('/')[-1].split('?')[0]

        # For Hugging Face URLs use "<owner>/<repo>" as the display name.
        model_name = model_url
        if "huggingface.co/" in model_url:
            parts = model_url.split("huggingface.co/")[1].split("/")
            if len(parts) >= 2:
                model_name = f"{parts[0]}/{parts[1]}"

        metadata['url'] = model_url
        metadata['model_name'] = model_name
        metadata['model_size_mb'] = model_size_mb
        metadata['loaded'] = True

        # Cap the GPU-layers slider at the model's layer count and select all layers.
        slider_update = gr.update(value=n_layers, maximum=n_layers)
        return metadata, slider_update, f"Metadata loaded successfully for: {gguf_filename}"

    except Exception as e:
        error_msg = f"Error loading metadata: {str(e)}"
        return {}, gr.update(), error_msg
|
|
|
|
def normalize_huggingface_url(url: str) -> str:
    """Turn a Hugging Face file URL into its direct-download (``resolve``) form.

    URLs that do not mention ``huggingface.co`` are returned untouched. For
    Hugging Face URLs the query string is dropped and the ``/blob/`` route
    segment is rewritten to ``/resolve/``.

    Only the first ``/blob/`` is rewritten, and only when it appears before
    any ``/resolve/``. A directory named ``blob`` inside the file path is
    therefore left alone.

    Args:
        url: The URL as entered by the user.

    Returns:
        The normalized URL.
    """
    if 'huggingface.co' not in url:
        return url

    base_url = url.split('?')[0]

    blob_at = base_url.find('/blob/')
    resolve_at = base_url.find('/resolve/')
    is_blob_route = blob_at != -1 and (resolve_at == -1 or blob_at < resolve_at)
    if is_blob_route:
        base_url = base_url.replace('/blob/', '/resolve/', 1)

    return base_url
|
|
|
|
def get_model_size_mb_from_url(model_url: str) -> float:
    """Return the model size in MB using HEAD requests only.

    For a file named ``<base>-<i>-of-<n>.gguf`` the sizes of all ``n`` parts
    are summed. If one part cannot be queried, the remaining parts are
    estimated from the average of the parts seen so far (or from the given
    file's size when nothing has been summed yet).

    Args:
        model_url: URL of the GGUF file (or of any one part of a split model).

    Returns:
        The total size in MB, or ``0.0`` if the size could not be determined.
    """
    request_timeout = 30  # seconds; keeps the UI from hanging on a dead host

    try:
        normalized_url = normalize_huggingface_url(model_url)

        response = requests.head(normalized_url, allow_redirects=True, timeout=request_timeout)
        response.raise_for_status()
        main_file_size = int(response.headers.get('content-length', 0))

        parent_url, _, filename = normalized_url.rpartition('/')
        base_url = parent_url + '/'

        match = re.match(r'(.+)-(\d+)-of-(\d+)\.gguf$', filename)
        if not match:
            return main_file_size / (1024 ** 2)

        base_pattern = match.group(1)
        current_part = int(match.group(2))
        total_parts = int(match.group(3))
        # Rebuild sibling names with the same zero padding as the given file
        # rather than assuming five digits.
        index_width = len(match.group(2))
        total_label = match.group(3)

        total_size = 0
        for part_num in range(1, total_parts + 1):
            if part_num == current_part and main_file_size > 0:
                # This part was already measured by the request above.
                total_size += main_file_size
                continue

            part_filename = f"{base_pattern}-{part_num:0{index_width}d}-of-{total_label}.gguf"
            part_url = base_url + part_filename

            try:
                part_response = requests.head(part_url, allow_redirects=True, timeout=request_timeout)
                part_response.raise_for_status()
                total_size += int(part_response.headers.get('content-length', 0))
            except requests.RequestException:
                print(f"Warning: Could not get size of {part_filename}, estimating...")

                if total_size > 0:
                    # Extrapolate from the average of the parts summed so far.
                    avg_size = total_size / (part_num - 1)
                    remaining_parts = total_parts - (part_num - 1)
                    total_size += avg_size * remaining_parts
                else:
                    # Nothing summed yet: assume every part matches the given file.
                    total_size = main_file_size * total_parts
                break

        return total_size / (1024 ** 2)

    except Exception as e:
        print(f"Error getting model size: {e}")
        return 0.0
|
|
|
|
def estimate_memory(metadata, gpu_layers, ctx_size, cache_type):
    """Evaluate the fitted memory formula for the given settings.

    Args:
        metadata: Dict providing ``n_layers``, ``n_kv_heads``,
            ``embedding_dim`` and ``context_length`` (all required) and
            optionally ``model_size_mb`` (defaults to 0).
        gpu_layers: Requested number of GPU layers; capped at ``n_layers``.
        ctx_size: Context size used in the formula.
        cache_type: ``'q4_0'`` or ``'q8_0'``; anything else is treated as 16-bit.

    Returns:
        The raw formula result as a number.

    Raises:
        ValueError: If any required metadata field is missing. Every error is
            printed before being re-raised.
    """
    try:
        required_names = ('n_layers', 'n_kv_heads', 'embedding_dim', 'context_length')
        found = {name: metadata.get(name) for name in required_names}
        size_in_mb = metadata.get('model_size_mb', 0)

        missing = [name for name in required_names if found[name] is None]
        if missing:
            raise ValueError(f"Missing required metadata fields: {missing}")

        n_layers = found['n_layers']
        n_kv_heads = found['n_kv_heads']
        embedding_dim = found['embedding_dim']

        # Never offload more layers than the model has.
        layers_on_gpu = n_layers if gpu_layers > n_layers else gpu_layers

        # Map the cache type name to the numeric width used by the formula.
        if cache_type == 'q4_0':
            cache_bits = 4
        else:
            cache_bits = 8 if cache_type == 'q8_0' else 16

        size_per_layer = size_in_mb / max(n_layers, 1e-6)
        kv_cache_factor = n_kv_heads * cache_bits * ctx_size
        embedding_per_context = embedding_dim / ctx_size

        # The constants below are fitted values; keep the arithmetic order intact.
        per_layer_cost = size_per_layer - 17.99552795246051 + 3.148552680382576e-05 * kv_cache_factor
        cache_offset = floor(50.77817218646521 * embedding_per_context) + 9.987899908205632
        layer_multiplier = layers_on_gpu + max(0.9690636483914102, cache_bits - cache_offset)

        return per_layer_cost * layer_multiplier + 1516.522943869404

    except Exception as e:
        print(f"Error in memory calculation: {e}")
        raise
|
|
|
|
def estimate_memory_wrapper(model_metadata, gpu_layers, ctx_size, cache_type):
    """Render the memory estimate as the HTML snippet shown in the UI.

    Returns the bare placeholder when no model has been loaded (empty
    metadata or no ``model_name`` key). Otherwise returns the estimate plus a
    fixed 577 margin, formatted as whole MiB. Any error from the estimate is
    shown inline instead of being raised.
    """
    metadata_ready = bool(model_metadata) and 'model_name' in model_metadata
    if not metadata_ready:
        return '<div id="memory-info">Estimated memory usage:</div>'

    try:
        estimate = estimate_memory(model_metadata, gpu_layers, ctx_size, cache_type)
    except Exception as e:
        return (
            '<div id="memory-info">Estimated memory usage: '
            f'<span class="value">Error: {str(e)}</span></div>'
        )

    # Add a fixed safety margin on top of the raw formula result.
    conservative = estimate + 577
    return (
        '<div id="memory-info">\n'
        f'<div>Estimated memory usage: <span class="value">{conservative:.0f} MiB</span></div>\n'
        '</div>'
    )
|
|
|
|
def create_ui():
    """Build and return the Gradio Blocks app for the calculator.

    The layout is a single column: model URL, load button, the three
    estimate inputs (GPU layers, context length, cache type), the rendered
    estimate, and a status line. The estimate is recomputed after metadata
    is loaded and whenever any input or the stored metadata changes.
    """
    css = """
    .gradio-container {
        max-width: 810px !important;
        margin: 0 auto !important;
    }

    #memory-info {
        padding: 10px;
        border-radius: 4px;
        background-color: var(--background-fill-secondary);
    }

    #memory-info .value {
        font-weight: bold;
        color: var(--primary-500);
    }
    """

    intro_markdown = (
        "# Accurate GGUF Memory Calculator\n\n"
        "Estimate memory usage for GGUF models based on GPU layers, context length, and cache type.\n\n"
        "The formula was discovered through [symbolic regression](https://en.wikipedia.org/wiki/Symbolic_regression) "
        "using [TuringBot](https://turingbotsoftware.com/), evaluating over a billion candidate formulas "
        "against 19,517 real measurements. For details, see this "
        "[blog post](https://oobabooga.github.io/blog/posts/gguf-vram-formula/)."
    )

    with gr.Blocks(css=css) as demo:
        # Holds the metadata dict of the currently loaded model.
        model_metadata = gr.State(value={})

        gr.Markdown(intro_markdown)
        with gr.Row():
            with gr.Column():
                model_url = gr.Textbox(
                    label="GGUF Model URL",
                    value="https://huggingface.co/unsloth/Qwen3-235B-A22B-GGUF/blob/main/UD-Q2_K_XL/Qwen3-235B-A22B-UD-Q2_K_XL-00001-of-00002.gguf",
                )

                load_metadata_btn = gr.Button("Load metadata", elem_classes='refresh-button')

                gpu_layers = gr.Slider(
                    label="GPU Layers",
                    minimum=0,
                    maximum=256,
                    value=256,
                    info='`--gpu-layers` in llama.cpp.',
                )

                ctx_size = gr.Slider(
                    label='Context Length',
                    minimum=512,
                    maximum=262144,
                    step=256,
                    value=8192,
                    info='`--ctx-size` in llama.cpp.',
                )

                cache_type = gr.Radio(
                    choices=['fp16', 'q8_0', 'q4_0'],
                    value='fp16',
                    label="Cache Type",
                    info='Cache quantization.',
                )

                memory_info = gr.HTML(
                    value="<div id=\"memory-info\">Estimated memory usage:</div>"
                )

                status = gr.Textbox(
                    label="Status",
                    value="No model loaded",
                    interactive=False,
                )

        def refresh_estimate_args():
            """Return fresh keyword arguments for an estimate-refresh event."""
            return dict(
                fn=estimate_memory_wrapper,
                inputs=[model_metadata, gpu_layers, ctx_size, cache_type],
                outputs=[memory_info],
                show_progress=False,
            )

        # Load metadata first, then refresh the estimate with the new values.
        load_metadata_btn.click(
            load_metadata,
            inputs=[model_url, model_metadata],
            outputs=[model_metadata, gpu_layers, status],
            show_progress=True,
        ).then(**refresh_estimate_args())

        # Any change to an input or to the stored metadata refreshes the estimate.
        for component in (gpu_layers, ctx_size, cache_type, model_metadata):
            component.change(**refresh_estimate_args())

    return demo
|
|
|
|
# Build the interface and start the Gradio server only when this file is run
# as a script (not when it is imported).
if __name__ == "__main__":
    demo = create_ui()
    demo.launch()
|
|