Coverage for src / lilbee / providers / llama_cpp / gguf_meta.py: 100%
85 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-05-15 20:55 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-05-15 20:55 +0000
1"""GGUF metadata helpers: header reads, mmproj sidecar lookup, projector type."""
3from __future__ import annotations
5import logging
6from pathlib import Path
7from typing import Any
9from gguf import GGUFReader, GGUFValueType
11from lilbee.providers.base import ProviderError
12from lilbee.providers.llama_cpp.abort_signal import abort_callback, clear_abort
13from lilbee.providers.llama_cpp.log_dispatch import (
14 import_llama_cpp,
15 install_llama_log_handler,
16 suppress_native_stderr,
17)
19log = logging.getLogger(__name__)
21_HF_BLOBS_DIR_NAME = "blobs"
22_HF_SNAPSHOTS_DIR_NAME = "snapshots"
23_CLIP_PROJECTOR_TYPE_KEY = "clip.projector_type"
26def train_ctx_from_meta(
27 meta: dict[str, str] | None,
28 *,
29 fallback: int,
30 model_path: Path,
31) -> int:
32 """Resolve ``<arch>.context_length`` from GGUF metadata, clamping junk to ``fallback``.
34 Some published GGUFs (nomic-embed, certain Qwen3 and vision builds)
35 report ``context_length=0`` in their headers. Passing zero into
36 ``Llama(n_ctx=...)`` cascades into ``n_batch=0`` / ``n_ubatch=0``,
37 which trips ggml's Vulkan dispatch into undefined behaviour and
38 surfaces as STATUS_HEAP_CORRUPTION on Windows. Unparseable values
39 and non-positive integers both route to ``fallback``.
40 """
41 if not meta:
42 return fallback
43 raw = meta.get("context_length", str(fallback))
44 try:
45 value = int(raw)
46 except (TypeError, ValueError):
47 log.warning(
48 "GGUF %s has unparseable context_length=%r; using %d",
49 model_path.name,
50 raw,
51 fallback,
52 )
53 return fallback
54 if value <= 0:
55 log.warning(
56 "GGUF %s reports context_length=%d; using %d to avoid n_batch=0 crash",
57 model_path.name,
58 value,
59 fallback,
60 )
61 return fallback
62 return value
def read_gguf_metadata(model_path: Path) -> dict[str, str] | None:
    """Read metadata from a GGUF file's headers via llama-cpp-python.

    Returns a dict with keys like ``architecture``, ``context_length``,
    ``embedding_length``, ``chat_template``, ``file_type``, plus the
    KV-cache-shape fields (``block_count``, ``head_count_kv``,
    ``head_count``, ``key_length``, ``value_length``) used to size n_ctx
    against host memory. Returns ``None`` when no known keys are present.
    """
    Llama = import_llama_cpp().Llama  # noqa: N806
    # Fresh abort flag: a prior request_abort() must not latch and break
    # this metadata read, which is on the path of every model swap.
    clear_abort()
    install_llama_log_handler()
    load_args: dict[str, Any] = {
        "model_path": str(model_path),
        "vocab_only": True,  # header-only load: no weights, no GPU
        "verbose": False,
        "n_gpu_layers": 0,
        "abort_callback": abort_callback,
    }
    llm = suppress_native_stderr(Llama, **load_args)
    try:
        raw = llm.metadata or {}
        # Most interesting keys are namespaced under the architecture name.
        arch = raw.get("general.architecture", "llama")
        wanted: tuple[tuple[str, str], ...] = (
            ("general.architecture", "architecture"),
            (f"{arch}.context_length", "context_length"),
            (f"{arch}.embedding_length", "embedding_length"),
            (f"{arch}.block_count", "block_count"),
            (f"{arch}.attention.head_count_kv", "head_count_kv"),
            (f"{arch}.attention.head_count", "head_count"),
            (f"{arch}.attention.key_length", "key_length"),
            (f"{arch}.attention.value_length", "value_length"),
            ("tokenizer.chat_template", "chat_template"),
            ("general.file_type", "file_type"),
            ("general.name", "name"),
        )
        extracted = {out: str(raw[key]) for key, out in wanted if key in raw}
        return extracted or None
    finally:
        llm.close()
def _find_mmproj_in_hf_snapshots(model_dir: Path) -> Path | None:
    """Walk an HF-cache ``blobs/`` dir up to its sibling ``snapshots/`` tree.

    Returns the lexicographically-first ``*mmproj*.gguf`` found in the
    lexicographically-first snapshot that has one, or ``None`` when
    ``model_dir`` is not a ``blobs/`` directory, the ``snapshots/``
    sibling is missing, or no snapshot carries an mmproj file.
    """
    if model_dir.name != _HF_BLOBS_DIR_NAME:
        return None
    snapshots_dir = model_dir.parent / _HF_SNAPSHOTS_DIR_NAME
    if not snapshots_dir.is_dir():
        return None
    # Sort snapshots: iterdir() yields entries in arbitrary filesystem
    # order, so without sorting the winning snapshot (and therefore the
    # returned mmproj path) is nondeterministic across runs/platforms.
    for snapshot in sorted(snapshots_dir.iterdir()):
        candidates = sorted(snapshot.glob("*mmproj*.gguf"))
        if candidates:
            return candidates[0]
    return None
def _find_mmproj_in_flat_dir(model_dir: Path) -> Path | None:
    """Glob ``*mmproj*.gguf`` siblings of a model GGUF (sideloaded layout)."""
    matches = sorted(model_dir.glob("*mmproj*.gguf"))
    if not matches:
        return None
    # Lexicographically-first match wins when several mmproj files exist.
    return matches[0]
def find_mmproj_for_model(model_path: Path) -> Path:
    """Find the mmproj (CLIP projection) file for a vision model.

    Resolution order: (1) catalog lookup scoped to ``FEATURED_VISION``,
    (2) HuggingFace-cache ``snapshots/`` sibling of ``blobs/``,
    (3) same-directory glob for flat sideloaded layouts.
    Raises ``ProviderError`` if none find a file.
    """
    # Local import keeps the provider module importable without the catalog.
    from lilbee.catalog import find_mmproj_file

    model_dir = model_path.parent
    # Try each resolver in priority order; first non-None result wins.
    lookups = (
        lambda: find_mmproj_file(model_path.stem),
        lambda: _find_mmproj_in_hf_snapshots(model_dir),
        lambda: _find_mmproj_in_flat_dir(model_dir),
    )
    for lookup in lookups:
        found = lookup()
        if found is not None:
            return found

    raise ProviderError(
        f"No mmproj (CLIP projection) file found for vision model {model_path.name}. "
        f"Download the mmproj file to {model_path.parent} or re-download the vision "
        "model through the catalog to get both files.",
        provider="llama-cpp",
    )
def read_mmproj_projector_type(mmproj_path: Path) -> str | None:
    """Read ``clip.projector_type`` from a GGUF mmproj without loading the model."""
    # Best-effort read: any parser failure is logged at debug and treated
    # as "type unknown" rather than propagated to the caller.
    try:
        field = GGUFReader(str(mmproj_path)).get_field(_CLIP_PROJECTOR_TYPE_KEY)
    except Exception:
        log.debug("Failed to read mmproj metadata from %s", mmproj_path, exc_info=True)
        return None
    if field is None:
        return None
    if field.types[-1] != GGUFValueType.STRING:
        return None
    raw_bytes = bytes(field.parts[field.data[0]])
    return raw_bytes.decode("utf-8", errors="replace")