Coverage for src / lilbee / catalog / download.py: 100%

167 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-06-28 01:01 +0000

1"""GGUF download, mmproj resolution, post-download hooks.""" 

2 

3import fnmatch 

4import logging 

5from collections.abc import Callable 

6from http import HTTPStatus 

7from pathlib import Path 

8from typing import Any 

9 

10import httpx 

11from pydantic import BaseModel 

12 

13from lilbee.catalog.download_progress import ProgressCallback, _ProgressTracker 

14from lilbee.catalog.featured import DEFAULT_MMPROJ_PATTERN, VISION_MMPROJ_FILES 

15from lilbee.catalog.hf_client import DEFAULT_TIMEOUT, HF_API_URL, hf_headers, hf_token 

16from lilbee.catalog.models import CatalogModel 

17from lilbee.catalog.refs import pick_best_gguf 

18from lilbee.catalog.types import ModelTask 

19from lilbee.core.config.model import cfg 

20from lilbee.runtime.cancellation import TaskCancelledError 

21 

# Post-download hook signature: invoked as ``on_complete(entry, file_path)``
# once the model file is on disk.
CompleteCallback = Callable[[CatalogModel, Path], None]

# Module-level logger, named after this module.
log = logging.getLogger(__name__)

25 

26 

class DownloadConfig(BaseModel):
    """Keyword arguments for a single ``hf_hub_download`` call.

    The model is dumped with ``exclude_none=True`` and splatted into
    ``hf_hub_download``, so any field left as ``None`` is simply not passed.
    """

    # ``tqdm_class`` holds a class object rather than plain data, so pydantic
    # must be told to accept arbitrary types.
    model_config = {"arbitrary_types_allowed": True}

    # Repository to download from (callers pass ``entry.hf_repo``).
    repo_id: str
    # Concrete file name inside the repository.
    filename: str
    # Auth token; required field, but ``None`` is allowed and then omitted.
    token: str | None
    force_download: bool = False
    # Callers pass ``str(cfg.models_dir)`` here.
    cache_dir: str | None = None
    # Progress-bar class produced by ``_ProgressTracker.make_tqdm_class()``.
    tqdm_class: Any = None

36 

37 

def _hf_download_or_translate(entry: CatalogModel, config: DownloadConfig) -> Path:
    """Download via ``hf_hub_download`` and normalise every failure mode.

    *config* is dumped (dropping ``None`` fields) and passed as keyword
    arguments to ``hf_hub_download``; the returned location is wrapped in a
    ``Path``. Error messages name ``entry.hf_repo``.

    Raises:
        TaskCancelledError: re-raised untouched so cancellation propagates.
        PermissionError: the repo is gated and needs authentication.
        RuntimeError: repo not found, network error, I/O error, or any other
            failure (the original exception is deliberately not chained).
    """
    from huggingface_hub import hf_hub_download
    from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError

    # NOTE(review): handlers are listed specific-first; keep this order if the
    # exception types are related to one another.
    try:
        # HF_HUB_DISABLE_XET is set in lilbee/__init__.py at import time.
        # Setting it here is too late: huggingface_hub.constants already
        # captured the value when this module first imported it.
        hf_kwargs = config.model_dump(exclude_none=True)
        local_path = hf_hub_download(**hf_kwargs)
        return Path(local_path)
    except TaskCancelledError:
        raise
    except GatedRepoError:
        message = (
            f"{entry.hf_repo} requires HuggingFace authentication. "
            "Set HF_TOKEN env var or visit the repo page to request access."
        )
        raise PermissionError(message) from None
    except RepositoryNotFoundError:
        message = f"Repository {entry.hf_repo!r} not found on HuggingFace."
        raise RuntimeError(message) from None
    except (httpx.TimeoutException, httpx.ConnectError) as exc:
        message = f"Network error downloading {entry.hf_repo}: {exc}"
        raise RuntimeError(message) from None
    except OSError as exc:
        message = f"I/O error downloading {entry.hf_repo}: {exc}"
        raise RuntimeError(message) from None
    except Exception as exc:
        message = f"Failed to download {entry.hf_repo}: {type(exc).__name__}: {exc}"
        raise RuntimeError(message) from None

65 

66 

def download_model(
    entry: CatalogModel,
    *,
    on_progress: ProgressCallback | None = None,
    on_complete: CompleteCallback | None = None,
) -> Path:
    """Fetch the GGUF file for *entry* from HuggingFace into ``cfg.models_dir``.

    The download goes through huggingface_hub (resumable, cached,
    authenticated). If a file with the resolved name already sits directly in
    ``cfg.models_dir`` and passes the size check, no download happens.

    *on_progress(downloaded, total)*, when given, receives byte counts; on
    either kind of cache hit it is called once with ``(size, size)``.
    *on_complete(entry, file_path)*, when given, runs after the file is on
    disk; modelhub uses it to write a registry manifest. For vision models the
    mmproj (CLIP projection) file is downloaded as well.

    Raises:
        PermissionError: gated repo requiring authentication
        RuntimeError: repo not found or download failure with details
    """

    def _finish(path: Path) -> Path:
        # Shared tail of both exit paths: run the post-download hooks.
        return _finalize_download(entry, path, on_progress=on_progress, on_complete=on_complete)

    cfg.models_dir.mkdir(parents=True, exist_ok=True)

    filename = resolve_filename(entry)
    local_copy = cfg.models_dir / filename
    if local_copy.exists() and _cached_file_is_complete(entry.hf_repo, filename, local_copy):
        log.info("Model already downloaded: %s", local_copy)
        if on_progress is not None:
            total = local_copy.stat().st_size
            on_progress(total, total)  # nothing to transfer: report 100% at once
        return _finish(local_copy)

    log.info("Downloading %s/%s → %s", entry.hf_repo, filename, cfg.models_dir)
    if on_progress:
        tracker = _ProgressTracker(on_progress)
    else:
        tracker = None
    request = DownloadConfig(
        repo_id=entry.hf_repo,
        filename=filename,
        token=hf_token(),
        cache_dir=str(cfg.models_dir),
        tqdm_class=tracker.make_tqdm_class() if tracker else None,
    )

    downloaded = _hf_download_or_translate(entry, request)

    if on_progress:
        final_size = downloaded.stat().st_size
        tqdm_ran = tracker and tracker.was_used
        if not tqdm_ran:
            # The progress hook never fired, so the file came from the HF cache.
            log.info("Model found in HuggingFace cache: %s", downloaded)
        on_progress(final_size, final_size)
    return _finish(downloaded)

113 

114 

def _finalize_download(
    entry: CatalogModel,
    dest: Path,
    *,
    on_progress: ProgressCallback | None = None,
    on_complete: CompleteCallback | None = None,
) -> Path:
    """Run the post-download steps for *entry* and hand back *dest*.

    First calls *on_complete(entry, dest)* when one was supplied (the registry
    write), then fetches the mmproj file if the entry's task is
    ``ModelTask.VISION``, forwarding *on_progress* to that download.
    """
    if on_complete is not None:
        on_complete(entry, dest)

    is_vision = entry.task == ModelTask.VISION
    if not is_vision:
        return dest

    _download_mmproj(entry, on_progress=on_progress)
    return dest

128 

129 

def _download_mmproj(
    entry: CatalogModel,
    *,
    on_progress: ProgressCallback | None = None,
) -> Path | None:
    """Download the mmproj (CLIP projection) file for a vision model.

    The file name comes from ``VISION_MMPROJ_FILES`` for the entry's repo,
    falling back to ``DEFAULT_MMPROJ_PATTERN``, and is resolved to a concrete
    name before downloading into ``cfg.models_dir``.

    The optional ``on_progress`` callback receives ``(downloaded, total)`` byte
    counts and is wired through the same tqdm hook used by the main download.
    On a cache hit it is called once with ``(size, size)``.

    Returns:
        The path to the downloaded file, or ``None`` if no mmproj file name
        could be resolved.

    Raises:
        PermissionError: gated repo requiring authentication
        RuntimeError: repo not found or download failure with details
    """
    mmproj_pattern = VISION_MMPROJ_FILES.get(entry.hf_repo, DEFAULT_MMPROJ_PATTERN)

    mmproj_filename = _resolve_mmproj_filename(entry.hf_repo, mmproj_pattern)
    if not mmproj_filename:
        log.warning("Could not resolve mmproj file for %s", entry.hf_repo)
        return None

    tracker = _ProgressTracker(on_progress) if on_progress else None
    log.info("Downloading mmproj %s/%s → %s", entry.hf_repo, mmproj_filename, cfg.models_dir)
    config = DownloadConfig(
        repo_id=entry.hf_repo,
        filename=mmproj_filename,
        token=hf_token(),
        cache_dir=str(cfg.models_dir),
        tqdm_class=tracker.make_tqdm_class() if tracker else None,
    )
    # Go through the same translation layer as the main model download so that
    # callers of download_model only ever see the exception types it documents
    # (PermissionError / RuntimeError), never raw huggingface_hub/httpx errors.
    path = _hf_download_or_translate(entry, config)

    if on_progress is not None and (not tracker or not tracker.was_used):
        # Cache hit: HF returned the cached path without invoking tqdm.
        size = path.stat().st_size
        on_progress(size, size)
    return path

165 

166 

167def _resolve_mmproj_filename(hf_repo: str, pattern: str) -> str | None: 

168 """Resolve an mmproj filename pattern to a concrete filename via the HF API.""" 

169 if "*" not in pattern: 

170 return pattern 

171 

172 try: 

173 resp = httpx.get( 

174 f"{HF_API_URL}/{hf_repo}", 

175 timeout=DEFAULT_TIMEOUT, 

176 headers=hf_headers(), 

177 ) 

178 resp.raise_for_status() 

179 siblings = resp.json().get("siblings", []) 

180 except Exception as exc: 

181 log.warning("Cannot query mmproj files for %s: %s", hf_repo, exc) 

182 return None 

183 

184 mmproj_files: list[str] = [ 

185 s.get("rfilename", "") for s in siblings if fnmatch.fnmatch(s.get("rfilename", ""), pattern) 

186 ] 

187 if not mmproj_files: 

188 return None 

189 

190 # Prefer F16 over F32 (smaller), and any over BF16 

191 for preference in ("f16", "F16"): 

192 for f in mmproj_files: 

193 if preference in f: 

194 return f 

195 return mmproj_files[0] 

196 

197 

def _mmproj_in_models_dir_matching(pattern: str) -> Path | None:
    """Find an mmproj-looking ``*.gguf`` anywhere below ``cfg.models_dir``.

    A file qualifies when its name matches *pattern* (``fnmatch``) or contains
    ``mmproj`` in any letter case. The first qualifying file in ``rglob``
    order is returned, or ``None`` when there is none.
    """
    models_dir: Path = cfg.models_dir
    hits = (
        candidate
        for candidate in models_dir.rglob("*.gguf")
        if fnmatch.fnmatch(candidate.name, pattern) or "mmproj" in candidate.name.lower()
    )
    return next(hits, None)

205 

206 

207def find_mmproj_file(model_ref: str) -> Path | None: 

208 """Find the mmproj for a ``FEATURED_VISION`` entry under ``cfg.models_dir``. 

209 

210 *model_ref* is matched against each featured vision entry's 

211 ``hf_repo``. Returns ``None`` when nothing matches. Never falls back 

212 to an arbitrary mmproj: that cross-contaminates non-vision chat 

213 models (e.g. a chat model would inherit a vision model's mmproj and 

214 be misreported as vision-capable). 

215 """ 

216 # Local import to avoid pulling featured.py into hf_client/ etc. 

217 from lilbee.catalog.featured import FEATURED_VISION 

218 

219 if not cfg.models_dir.exists(): 

220 return None 

221 for entry in FEATURED_VISION: 

222 if model_ref not in entry.hf_repo and entry.hf_repo not in model_ref: 

223 continue 

224 pattern = VISION_MMPROJ_FILES.get(entry.hf_repo, DEFAULT_MMPROJ_PATTERN) 

225 match = _mmproj_in_models_dir_matching(pattern) 

226 if match is not None: 

227 return match 

228 return None 

229 

230 

def resolve_filename(entry: CatalogModel) -> str:
    """Turn the entry's GGUF filename (possibly a wildcard) into a real name.

    A name without ``*`` is returned unchanged and no request is made. For a
    wildcard, the repo's file listing is fetched from the HF API and
    ``pick_best_gguf`` chooses among all ``.gguf`` files in it.

    Raises:
        PermissionError: the API answered 401 (authentication required).
        RuntimeError: the listing could not be fetched, or it holds no GGUF
            files.
    """
    wanted = entry.gguf_filename
    if "*" not in wanted:
        return wanted

    try:
        response = httpx.get(
            f"{HF_API_URL}/{entry.hf_repo}",
            timeout=DEFAULT_TIMEOUT,
            headers=hf_headers(),
        )
        if response.status_code == HTTPStatus.UNAUTHORIZED:
            raise PermissionError(
                f"{entry.hf_repo} requires HuggingFace authentication. "
                "Set HF_TOKEN env var or visit the repo page to request access."
            )
        response.raise_for_status()
        siblings = response.json().get("siblings", [])
    except PermissionError:
        # Must not be folded into the generic RuntimeError below.
        raise
    except Exception as exc:
        raise RuntimeError(f"Cannot query files for {entry.hf_repo}: {exc}") from exc

    gguf_names = []
    for sibling in siblings:
        rfilename = sibling.get("rfilename", "")
        if rfilename.endswith(".gguf"):
            gguf_names.append(rfilename)
    if not gguf_names:
        raise RuntimeError(f"No GGUF files found in {entry.hf_repo}")

    return pick_best_gguf(gguf_names)

264 

265 

# Sentinel byte count meaning "no size could be obtained from HuggingFace";
# returned by fetch_expected_file_size and treated as "accept the cached file".
_SIZE_UNKNOWN = 0

267 

268 

def _cached_file_is_complete(hf_repo: str, filename: str, dest: Path) -> bool:
    """Tell whether the file already at *dest* can be reused as-is.

    Compares the size on disk with the size HuggingFace reports for
    *filename*. Differing sizes indicate a truncated / corrupt download: a
    warning is logged and ``False`` is returned so the caller re-fetches.
    If no expected size is available (offline, API error) there is nothing to
    check against, so the file is accepted; rejecting it would make offline
    reuse impossible.
    """
    reported_size = fetch_expected_file_size(hf_repo, filename)
    if reported_size == _SIZE_UNKNOWN:
        return True

    on_disk_size = dest.stat().st_size
    sizes_match = on_disk_size == reported_size
    if not sizes_match:
        log.warning(
            "Cached %s is %d bytes but HuggingFace reports %d; re-downloading",
            dest,
            on_disk_size,
            reported_size,
        )
    return sizes_match

291 

292 

def _hf_file_size(hf_repo: str, filename: str) -> int | None:
    """Ask huggingface_hub for the byte size of *filename* (None if unreported).

    Builds the file URL with ``hf_hub_url`` and reads ``.size`` from the
    metadata returned by ``get_hf_file_metadata``, authenticating with
    ``hf_token()``. Errors are not handled here.
    """
    from huggingface_hub import get_hf_file_metadata, hf_hub_url

    file_url = hf_hub_url(hf_repo, filename)
    metadata = get_hf_file_metadata(file_url, token=hf_token())
    return metadata.size

298 

299 

def fetch_expected_file_size(hf_repo: str, filename: str) -> int:
    """Byte size huggingface_hub reports for *filename*, else ``_SIZE_UNKNOWN``.

    The size comes from hf_hub's own file metadata (correct revision,
    redirects, and LFS/Xet handled uniformly) rather than from scraping the
    repo tree. Any failure, or a missing/zero size, yields ``_SIZE_UNKNOWN``
    (0), in which case the caller keeps the cached file.
    """
    try:
        size = _hf_file_size(hf_repo, filename)
    except Exception:
        # Offline or unresolvable: report "unknown" instead of failing.
        return _SIZE_UNKNOWN
    return size if size else _SIZE_UNKNOWN

311 

312 

def _gguf_sizes_from_tree(files: object) -> list[tuple[str, int]]:
    """Extract ``(path, size_in_bytes)`` for every ``.gguf`` entry of a tree payload."""

    def _positive_number(value: object) -> bool:
        return isinstance(value, (int, float)) and not isinstance(value, bool) and value > 0

    if not isinstance(files, list):
        # Error payloads are not lists; treat them as "no files".
        return []

    pairs: list[tuple[str, int]] = []
    for item in files:
        if not isinstance(item, dict):
            continue
        path = item.get("path")
        if not isinstance(path, str) or not path.endswith(".gguf"):
            continue
        size = item.get("size")
        if not _positive_number(size):
            # Fall back to the LFS size; "lfs" may be absent or null.
            lfs = item.get("lfs")
            size = lfs.get("size") if isinstance(lfs, dict) else 0
        pairs.append((path, size if _positive_number(size) else 0))
    return pairs


def fetch_model_file_size(hf_repo: str) -> float:
    """Fetch the best GGUF file size from HuggingFace tree API.

    Queries ``{HF_API_URL}/{hf_repo}/tree/main``, keeps the ``.gguf`` entries,
    lets ``pick_best_gguf`` choose one and reports that file's size.

    Returns:
        Size in GB (GiB, rounded to one decimal), or 0.0 if unavailable: the
        request failed, the payload is malformed, there are no GGUF files, or
        the chosen file has no usable size.
    """
    try:
        resp = httpx.get(
            f"{HF_API_URL}/{hf_repo}/tree/main",
            timeout=DEFAULT_TIMEOUT,
            headers=hf_headers(),
        )
        resp.raise_for_status()
        files = resp.json()
    except Exception:
        # Best-effort lookup: any failure just means "size unknown".
        return 0.0

    gguf_files = _gguf_sizes_from_tree(files)
    if not gguf_files:
        return 0.0

    best_name = pick_best_gguf([name for name, _ in gguf_files])
    size_bytes = next((size for name, size in gguf_files if name == best_name), 0)
    return round(size_bytes / (1024**3), 1) if size_bytes else 0.0