Coverage for src / lilbee / catalog / download.py: 100%
155 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-05-15 20:55 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-05-15 20:55 +0000
1"""GGUF download, mmproj resolution, post-download hooks."""
3import fnmatch
4import logging
5from collections.abc import Callable
6from http import HTTPStatus
7from pathlib import Path
8from typing import Any
10import httpx
11from pydantic import BaseModel
13from lilbee.catalog.download_progress import ProgressCallback, _ProgressTracker
14from lilbee.catalog.featured import DEFAULT_MMPROJ_PATTERN, VISION_MMPROJ_FILES
15from lilbee.catalog.hf_client import DEFAULT_TIMEOUT, hf_headers, hf_token
16from lilbee.catalog.models import CatalogModel
17from lilbee.catalog.types import ModelTask
18from lilbee.core.config.model import cfg
19from lilbee.runtime.cancellation import TaskCancelledError
# Signature of the on_complete hook: called as on_complete(entry, file_path)
# once a model file is on disk (see download_model / _finalize_download).
CompleteCallback = Callable[[CatalogModel, Path], None]

log = logging.getLogger(__name__)
class DownloadConfig(BaseModel):
    """Keyword arguments for ``huggingface_hub.hf_hub_download``.

    ``_hf_download_or_translate`` passes ``model_dump(exclude_none=True)``
    straight through as ``**kwargs``, so every field name must be a valid
    ``hf_hub_download`` parameter, and fields left as ``None`` are omitted.
    """

    # Permits non-pydantic field values; tqdm_class receives the class object
    # built by _ProgressTracker.make_tqdm_class().
    model_config = {"arbitrary_types_allowed": True}

    repo_id: str
    filename: str
    # Required field, but None is allowed (and then dropped by exclude_none).
    token: str | None
    force_download: bool = False
    cache_dir: str | None = None
    tqdm_class: Any = None
def _hf_download_or_translate(entry: CatalogModel, config: DownloadConfig) -> Path:
    """Download ``config`` via ``hf_hub_download`` and map failures to clean errors.

    ``TaskCancelledError`` passes through untouched. A gated repo becomes a
    ``PermissionError``; every other ``Exception`` becomes a ``RuntimeError``
    whose message names ``entry.hf_repo`` and the failure category. The
    original exception context is suppressed (``from None``).
    """
    from huggingface_hub import hf_hub_download
    from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError

    try:
        # No HF_HUB_DISABLE_XET tweak here on purpose: lilbee/__init__.py sets
        # it at import time, and huggingface_hub.constants already captured the
        # value when this module first imported it, so setting it now is too late.
        downloaded = hf_hub_download(**config.model_dump(exclude_none=True))
        return Path(downloaded)
    except TaskCancelledError:
        raise
    except GatedRepoError:
        raise PermissionError(
            f"{entry.hf_repo} requires HuggingFace authentication. "
            "Set HF_TOKEN env var or visit the repo page to request access."
        ) from None
    except Exception as exc:
        repo = entry.hf_repo
        # Checked in the same order as a cascading except chain: most specific first.
        if isinstance(exc, RepositoryNotFoundError):
            message = f"Repository {repo!r} not found on HuggingFace."
        elif isinstance(exc, (httpx.TimeoutException, httpx.ConnectError)):
            message = f"Network error downloading {repo}: {exc}"
        elif isinstance(exc, OSError):
            message = f"I/O error downloading {repo}: {exc}"
        else:
            message = f"Failed to download {repo}: {type(exc).__name__}: {exc}"
        raise RuntimeError(message) from None
def download_model(
    entry: CatalogModel,
    *,
    on_progress: ProgressCallback | None = None,
    on_complete: CompleteCallback | None = None,
) -> Path:
    """Fetch the GGUF weights for *entry* into ``cfg.models_dir``.

    The transfer is delegated to huggingface_hub (resumable downloads,
    caching, auth). ``on_progress(downloaded, total)`` receives byte counts.
    ``on_complete(entry, file_path)`` runs once the file is on disk; modelhub
    uses it to write a registry manifest. Vision entries additionally get
    their mmproj (CLIP projection) file.

    Raises:
        PermissionError: gated repo requiring authentication
        RuntimeError: repo not found or download failure with details
    """
    models_dir = cfg.models_dir
    models_dir.mkdir(parents=True, exist_ok=True)

    filename = resolve_filename(entry)
    local_copy = models_dir / filename
    if local_copy.exists():
        log.info("Model already downloaded: %s", local_copy)
        if on_progress is not None:
            # Nothing to transfer, so report completion straight away.
            nbytes = local_copy.stat().st_size
            on_progress(nbytes, nbytes)
        return _finalize_download(
            entry, local_copy, on_progress=on_progress, on_complete=on_complete
        )

    log.info("Downloading %s/%s → %s", entry.hf_repo, filename, models_dir)
    tracker = None
    if on_progress:
        tracker = _ProgressTracker(on_progress)
    config = DownloadConfig(
        repo_id=entry.hf_repo,
        filename=filename,
        token=hf_token(),
        cache_dir=str(models_dir),
        tqdm_class=tracker.make_tqdm_class() if tracker else None,
    )

    downloaded = _hf_download_or_translate(entry, config)

    if on_progress:
        nbytes = downloaded.stat().st_size
        if not (tracker and tracker.was_used):
            # HF served the file from its cache without ever driving tqdm,
            # so the caller has not seen any progress yet.
            log.info("Model found in HuggingFace cache: %s", downloaded)
            on_progress(nbytes, nbytes)
    return _finalize_download(
        entry, downloaded, on_progress=on_progress, on_complete=on_complete
    )
def _finalize_download(
    entry: CatalogModel,
    dest: Path,
    *,
    on_progress: ProgressCallback | None = None,
    on_complete: CompleteCallback | None = None,
) -> Path:
    """Apply the post-download steps for *entry* and hand back *dest*.

    First invokes ``on_complete(entry, dest)`` when supplied (the registry
    write), then fetches the mmproj companion file for vision models,
    forwarding ``on_progress`` to that download.
    """
    callback = on_complete
    if callback is not None:
        callback(entry, dest)

    needs_projector = entry.task == ModelTask.VISION
    if needs_projector:
        _download_mmproj(entry, on_progress=on_progress)
    return dest
def _download_mmproj(
    entry: CatalogModel,
    *,
    on_progress: ProgressCallback | None = None,
) -> Path | None:
    """Download the mmproj (CLIP projection) file for a vision model.

    The filename pattern comes from ``VISION_MMPROJ_FILES`` (falling back to
    ``DEFAULT_MMPROJ_PATTERN``) and is resolved to a concrete filename via
    ``_resolve_mmproj_filename``. The download goes through
    ``_hf_download_or_translate``, the same path as the main GGUF download,
    so failures surface as ``PermissionError`` / ``RuntimeError`` (as
    ``download_model`` documents) instead of raw huggingface_hub exceptions.

    Args:
        entry: catalog entry of the vision model.
        on_progress: optional ``(downloaded, total)`` byte-count callback,
            wired through the same tqdm hook used by the main download.

    Returns:
        Path to the downloaded file, or ``None`` if no mmproj could be resolved.

    Raises:
        PermissionError: gated repo requiring authentication.
        RuntimeError: repo not found or download failure with details.
    """
    mmproj_pattern = VISION_MMPROJ_FILES.get(entry.hf_repo, DEFAULT_MMPROJ_PATTERN)

    mmproj_filename = _resolve_mmproj_filename(entry.hf_repo, mmproj_pattern)
    if not mmproj_filename:
        log.warning("Could not resolve mmproj file for %s", entry.hf_repo)
        return None

    tracker = _ProgressTracker(on_progress) if on_progress else None
    log.info("Downloading mmproj %s/%s → %s", entry.hf_repo, mmproj_filename, cfg.models_dir)
    config = DownloadConfig(
        repo_id=entry.hf_repo,
        filename=mmproj_filename,
        token=hf_token(),
        cache_dir=str(cfg.models_dir),
        tqdm_class=tracker.make_tqdm_class() if tracker else None,
    )
    path = _hf_download_or_translate(entry, config)
    if on_progress is not None and (not tracker or not tracker.was_used):
        # Cache hit: HF returned the cached path without invoking tqdm.
        size = path.stat().st_size
        on_progress(size, size)
    return path
166def _resolve_mmproj_filename(hf_repo: str, pattern: str) -> str | None:
167 """Resolve an mmproj filename pattern to a concrete filename via the HF API."""
168 if "*" not in pattern:
169 return pattern
171 try:
172 resp = httpx.get(
173 f"https://huggingface.co/api/models/{hf_repo}",
174 timeout=DEFAULT_TIMEOUT,
175 headers=hf_headers(),
176 )
177 resp.raise_for_status()
178 siblings = resp.json().get("siblings", [])
179 except Exception as exc:
180 log.warning("Cannot query mmproj files for %s: %s", hf_repo, exc)
181 return None
183 mmproj_files: list[str] = [
184 s.get("rfilename", "") for s in siblings if fnmatch.fnmatch(s.get("rfilename", ""), pattern)
185 ]
186 if not mmproj_files:
187 return None
189 # Prefer F16 over F32 (smaller), and any over BF16
190 for preference in ("f16", "F16"):
191 for f in mmproj_files:
192 if preference in f:
193 return f
194 return mmproj_files[0]
def _mmproj_in_models_dir_matching(pattern: str) -> Path | None:
    """Return the first ``*.gguf`` under ``cfg.models_dir`` whose name matches *pattern*.

    Matching is case-insensitive, and candidates are visited in sorted path
    order so the result does not depend on filesystem iteration order. Only
    *pattern* decides a match. Accepting any file with "mmproj" in its name
    would hand one vision model another model's projector. Files fetched by
    ``_download_mmproj`` always match, because it resolves the same pattern.
    """
    models_dir: Path = cfg.models_dir
    wanted = pattern.lower()
    for candidate in sorted(models_dir.rglob("*.gguf")):
        if fnmatch.fnmatchcase(candidate.name.lower(), wanted):
            return candidate
    return None
def find_mmproj_file(model_ref: str) -> Path | None:
    """Find the mmproj for a ``FEATURED_VISION`` entry under ``cfg.models_dir``.

    *model_ref* is matched against each featured vision entry's ``hf_repo``
    by substring in either direction. Returns ``None`` when *model_ref* is
    empty, ``cfg.models_dir`` does not exist, or nothing matches. Never falls
    back to an arbitrary mmproj: that cross-contaminates non-vision chat
    models (e.g. a chat model would inherit a vision model's mmproj and
    be misreported as vision-capable).
    """
    # Local import to avoid pulling featured.py into hf_client/ etc.
    from lilbee.catalog.featured import FEATURED_VISION

    # An empty ref is a substring of every repo id and would match the first
    # featured entry, handing its mmproj to an unrelated model.
    if not model_ref or not cfg.models_dir.exists():
        return None
    for entry in FEATURED_VISION:
        repo = entry.hf_repo
        if model_ref not in repo and repo not in model_ref:
            continue
        pattern = VISION_MMPROJ_FILES.get(repo, DEFAULT_MMPROJ_PATTERN)
        match = _mmproj_in_models_dir_matching(pattern)
        if match is not None:
            return match
    return None
230_QUANT_PREFERENCE = ("Q4_K_M", "Q4_K_S", "Q5_K_M", "Q5_K_S", "Q8_0", "Q6_K", "Q3_K_M")
def resolve_filename(entry: CatalogModel) -> str:
    """Resolve a GGUF filename pattern to the best concrete filename.

    For exact filenames (no ``*``), return as-is. For wildcards, query the HF
    model API for the repo's files and keep the ``.gguf`` ones. Files matching
    the pattern are preferred; if none match, every ``.gguf`` file is a
    candidate. Among the candidates, the best quantization is picked (Q4_K_M
    preferred for balance of size/quality).

    Raises:
        PermissionError: the API answered 401 (authentication required).
        RuntimeError: the API query failed or the repo has no GGUF files.
    """
    pattern = entry.gguf_filename
    if "*" not in pattern:
        return pattern

    try:
        resp = httpx.get(
            f"https://huggingface.co/api/models/{entry.hf_repo}",
            timeout=DEFAULT_TIMEOUT,
            headers=hf_headers(),
        )
        if resp.status_code == HTTPStatus.UNAUTHORIZED:
            raise PermissionError(
                f"{entry.hf_repo} requires HuggingFace authentication. "
                "Set HF_TOKEN env var or visit the repo page to request access."
            )
        resp.raise_for_status()
        siblings = resp.json().get("siblings", [])
    except PermissionError:
        raise
    except Exception as exc:
        raise RuntimeError(f"Cannot query files for {entry.hf_repo}: {exc}") from exc

    gguf_files = [
        s.get("rfilename", "") for s in siblings if s.get("rfilename", "").endswith(".gguf")
    ]
    if not gguf_files:
        raise RuntimeError(f"No GGUF files found in {entry.hf_repo}")

    # Honour the catalog pattern (e.g. "*Q8_0.gguf" must not resolve to a
    # Q4_K_M file). Case-insensitive because repos mix "Q8_0" and "q8_0";
    # fall back to every GGUF when the pattern matches nothing.
    wanted = pattern.lower()
    matching = [name for name in gguf_files if fnmatch.fnmatchcase(name.lower(), wanted)]
    return _pick_best_gguf(matching or gguf_files)
268def _pick_best_gguf(filenames: list[str]) -> str:
269 """Pick the best GGUF file by quantization preference."""
270 for quant in _QUANT_PREFERENCE:
271 for f in filenames:
272 if quant in f:
273 return f
274 return filenames[0]
277def fetch_model_file_size(hf_repo: str) -> float:
278 """Fetch the best GGUF file size from HuggingFace tree API.
279 Returns size in GB, or 0.0 if unavailable.
280 """
281 try:
282 resp = httpx.get(
283 f"https://huggingface.co/api/models/{hf_repo}/tree/main",
284 timeout=DEFAULT_TIMEOUT,
285 headers=hf_headers(),
286 )
287 resp.raise_for_status()
288 files = resp.json()
289 except Exception:
290 return 0.0
292 gguf_files = [
293 (f.get("path", ""), f.get("size", 0) or f.get("lfs", {}).get("size", 0))
294 for f in files
295 if isinstance(f, dict) and f.get("path", "").endswith(".gguf")
296 ]
297 if not gguf_files:
298 return 0.0
300 best_name = _pick_best_gguf([name for name, _ in gguf_files])
301 size_bytes = next((s for n, s in gguf_files if n == best_name), 0)
302 return round(size_bytes / (1024**3), 1) if size_bytes else 0.0