Coverage for src / lilbee / catalog / download.py: 100%

155 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-05-15 20:55 +0000

1"""GGUF download, mmproj resolution, post-download hooks.""" 

2 

3import fnmatch 

4import logging 

5from collections.abc import Callable 

6from http import HTTPStatus 

7from pathlib import Path 

8from typing import Any 

9 

10import httpx 

11from pydantic import BaseModel 

12 

13from lilbee.catalog.download_progress import ProgressCallback, _ProgressTracker 

14from lilbee.catalog.featured import DEFAULT_MMPROJ_PATTERN, VISION_MMPROJ_FILES 

15from lilbee.catalog.hf_client import DEFAULT_TIMEOUT, hf_headers, hf_token 

16from lilbee.catalog.models import CatalogModel 

17from lilbee.catalog.types import ModelTask 

18from lilbee.core.config.model import cfg 

19from lilbee.runtime.cancellation import TaskCancelledError 

20 

# Signature of the post-download hook: it is called with the catalog entry and
# the on-disk path of the model file.
CompleteCallback = Callable[[CatalogModel, Path], None]

# Module-level logger, named after this module.
log = logging.getLogger(__name__)

24 

25 

class DownloadConfig(BaseModel):
    """Keyword arguments handed to ``huggingface_hub.hf_hub_download``.

    The model is dumped with ``exclude_none=True`` before the call, so any
    field left as ``None`` is simply not passed along.
    """

    # ``tqdm_class`` carries a class object, not a plain value.
    model_config = {"arbitrary_types_allowed": True}

    # Repository id and the file inside it to fetch.
    repo_id: str
    filename: str
    # Required (no default); ``None`` means no token is forwarded.
    token: str | None
    force_download: bool = False
    # Directory passed as the HuggingFace cache location.
    cache_dir: str | None = None
    # Optional tqdm subclass used as the progress hook.
    tqdm_class: Any = None

35 

36 

def _hf_download_or_translate(entry: CatalogModel, config: DownloadConfig) -> Path:
    """Download one file via ``hf_hub_download``, mapping failures to plain exceptions.

    ``config`` is dumped with ``exclude_none=True`` and passed as keyword
    arguments; the location returned by HuggingFace is wrapped in a ``Path``.
    The original exceptions are deliberately dropped from the chain
    (``from None``) so callers only see the translated message.

    Raises:
        TaskCancelledError: re-raised unchanged.
        PermissionError: the repo is gated.
        RuntimeError: repo not found, network or I/O failure, or anything else.
    """
    from huggingface_hub import hf_hub_download
    from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError

    try:
        # HF_HUB_DISABLE_XET has to be set by lilbee/__init__.py at import
        # time. By the time this runs, huggingface_hub.constants has already
        # captured the value, so setting it here would be too late.
        hf_kwargs = config.model_dump(exclude_none=True)
        return Path(hf_hub_download(**hf_kwargs))
    except TaskCancelledError:
        raise
    except GatedRepoError:
        gated_msg = (
            f"{entry.hf_repo} requires HuggingFace authentication. "
            "Set HF_TOKEN env var or visit the repo page to request access."
        )
        raise PermissionError(gated_msg) from None
    except RepositoryNotFoundError:
        missing_msg = f"Repository {entry.hf_repo!r} not found on HuggingFace."
        raise RuntimeError(missing_msg) from None
    except (httpx.TimeoutException, httpx.ConnectError) as exc:
        network_msg = f"Network error downloading {entry.hf_repo}: {exc}"
        raise RuntimeError(network_msg) from None
    except OSError as exc:
        io_msg = f"I/O error downloading {entry.hf_repo}: {exc}"
        raise RuntimeError(io_msg) from None
    except Exception as exc:
        generic_msg = f"Failed to download {entry.hf_repo}: {type(exc).__name__}: {exc}"
        raise RuntimeError(generic_msg) from None

64 

65 

def download_model(
    entry: CatalogModel,
    *,
    on_progress: ProgressCallback | None = None,
    on_complete: CompleteCallback | None = None,
) -> Path:
    """Fetch the GGUF file for *entry* into ``cfg.models_dir`` and run the post-download hooks.

    The concrete filename comes from ``resolve_filename``. If that file already
    sits directly under ``cfg.models_dir`` nothing is downloaded; otherwise the
    download goes through ``_hf_download_or_translate`` with ``cfg.models_dir``
    as the cache directory. In both cases ``_finalize_download`` runs afterwards
    (``on_complete`` hook, plus the mmproj fetch for vision models).

    *on_progress(downloaded, total)* receives byte counts. When no bytes were
    streamed through the tracker (file already present or cache hit) it is
    called once with the file size for both arguments.

    Raises:
        PermissionError: gated repo requiring authentication
        RuntimeError: repo not found or download failure with details
    """
    models_dir = cfg.models_dir
    models_dir.mkdir(parents=True, exist_ok=True)

    filename = resolve_filename(entry)
    local_copy = models_dir / filename

    if not local_copy.exists():
        log.info("Downloading %s/%s → %s", entry.hf_repo, filename, models_dir)
        tracker = _ProgressTracker(on_progress) if on_progress else None
        token = hf_token()
        progress_cls = tracker.make_tqdm_class() if tracker else None
        request = DownloadConfig(
            repo_id=entry.hf_repo,
            filename=filename,
            token=token,
            cache_dir=str(models_dir),
            tqdm_class=progress_cls,
        )

        fetched = _hf_download_or_translate(entry, request)

        if on_progress:
            final_size = fetched.stat().st_size
            served_from_cache = not tracker or not tracker.was_used
            if served_from_cache:
                log.info("Model found in HuggingFace cache: %s", fetched)
            on_progress(final_size, final_size)
        return _finalize_download(
            entry, fetched, on_progress=on_progress, on_complete=on_complete
        )

    # Already present: report completion straight away and skip the network.
    log.info("Model already downloaded: %s", local_copy)
    if on_progress is not None:
        existing_size = local_copy.stat().st_size
        on_progress(existing_size, existing_size)
    return _finalize_download(
        entry, local_copy, on_progress=on_progress, on_complete=on_complete
    )

112 

113 

def _finalize_download(
    entry: CatalogModel,
    dest: Path,
    *,
    on_progress: ProgressCallback | None = None,
    on_complete: CompleteCallback | None = None,
) -> Path:
    """Run the post-download steps for *entry* and hand back *dest*.

    First invokes ``on_complete(entry, dest)`` when a hook was supplied, then,
    for entries whose task is ``ModelTask.VISION``, fetches the mmproj file,
    forwarding ``on_progress`` to that download.
    """
    hook = on_complete
    if hook is not None:
        hook(entry, dest)

    is_vision = entry.task == ModelTask.VISION
    if is_vision:
        _download_mmproj(entry, on_progress=on_progress)

    return dest

127 

128 

def _download_mmproj(
    entry: CatalogModel,
    *,
    on_progress: ProgressCallback | None = None,
) -> Path | None:
    """Download the mmproj (CLIP projection) file for a vision model.

    The filename pattern comes from ``VISION_MMPROJ_FILES`` (falling back to
    ``DEFAULT_MMPROJ_PATTERN``) and is resolved to a concrete name via
    ``_resolve_mmproj_filename``.

    The download goes through ``_hf_download_or_translate``, the same path the
    main model file uses, so failures surface as the ``PermissionError`` /
    ``RuntimeError`` that ``download_model`` documents rather than as raw
    huggingface_hub or httpx exceptions.

    The optional ``on_progress`` callback receives ``(downloaded, total)`` byte
    counts and is wired through the same tqdm hook used by the main download.

    Returns:
        The path to the downloaded file, or ``None`` if no mmproj filename
        could be resolved.

    Raises:
        PermissionError: gated repo requiring authentication.
        RuntimeError: repo not found, network/I/O failure, or other download error.
    """
    mmproj_pattern = VISION_MMPROJ_FILES.get(entry.hf_repo, DEFAULT_MMPROJ_PATTERN)

    mmproj_filename = _resolve_mmproj_filename(entry.hf_repo, mmproj_pattern)
    if not mmproj_filename:
        log.warning("Could not resolve mmproj file for %s", entry.hf_repo)
        return None

    tracker = _ProgressTracker(on_progress) if on_progress else None
    log.info("Downloading mmproj %s/%s → %s", entry.hf_repo, mmproj_filename, cfg.models_dir)
    config = DownloadConfig(
        repo_id=entry.hf_repo,
        filename=mmproj_filename,
        token=hf_token(),
        cache_dir=str(cfg.models_dir),
        tqdm_class=tracker.make_tqdm_class() if tracker else None,
    )
    path = _hf_download_or_translate(entry, config)

    if on_progress is not None and (not tracker or not tracker.was_used):
        # Cache hit: HF returned the cached path without invoking tqdm.
        size = path.stat().st_size
        on_progress(size, size)
    return path

164 

165 

166def _resolve_mmproj_filename(hf_repo: str, pattern: str) -> str | None: 

167 """Resolve an mmproj filename pattern to a concrete filename via the HF API.""" 

168 if "*" not in pattern: 

169 return pattern 

170 

171 try: 

172 resp = httpx.get( 

173 f"https://huggingface.co/api/models/{hf_repo}", 

174 timeout=DEFAULT_TIMEOUT, 

175 headers=hf_headers(), 

176 ) 

177 resp.raise_for_status() 

178 siblings = resp.json().get("siblings", []) 

179 except Exception as exc: 

180 log.warning("Cannot query mmproj files for %s: %s", hf_repo, exc) 

181 return None 

182 

183 mmproj_files: list[str] = [ 

184 s.get("rfilename", "") for s in siblings if fnmatch.fnmatch(s.get("rfilename", ""), pattern) 

185 ] 

186 if not mmproj_files: 

187 return None 

188 

189 # Prefer F16 over F32 (smaller), and any over BF16 

190 for preference in ("f16", "F16"): 

191 for f in mmproj_files: 

192 if preference in f: 

193 return f 

194 return mmproj_files[0] 

195 

196 

def _mmproj_in_models_dir_matching(pattern: str) -> Path | None:
    """Find an mmproj candidate among the ``*.gguf`` files under ``cfg.models_dir``.

    The directory is searched recursively. A file qualifies when its name
    matches *pattern* or when its lower-cased name contains ``"mmproj"``, so
    any mmproj-named file is accepted regardless of the pattern. The first
    qualifying file in ``rglob`` order is returned, or ``None`` if there is none.
    """
    root: Path = cfg.models_dir
    candidates = (
        gguf
        for gguf in root.rglob("*.gguf")
        if fnmatch.fnmatch(gguf.name, pattern) or "mmproj" in gguf.name.lower()
    )
    return next(candidates, None)

204 

205 

206def find_mmproj_file(model_ref: str) -> Path | None: 

207 """Find the mmproj for a ``FEATURED_VISION`` entry under ``cfg.models_dir``. 

208 

209 *model_ref* is matched against each featured vision entry's 

210 ``hf_repo``. Returns ``None`` when nothing matches. Never falls back 

211 to an arbitrary mmproj: that cross-contaminates non-vision chat 

212 models (e.g. a chat model would inherit a vision model's mmproj and 

213 be misreported as vision-capable). 

214 """ 

215 # Local import to avoid pulling featured.py into hf_client/ etc. 

216 from lilbee.catalog.featured import FEATURED_VISION 

217 

218 if not cfg.models_dir.exists(): 

219 return None 

220 for entry in FEATURED_VISION: 

221 if model_ref not in entry.hf_repo and entry.hf_repo not in model_ref: 

222 continue 

223 pattern = VISION_MMPROJ_FILES.get(entry.hf_repo, DEFAULT_MMPROJ_PATTERN) 

224 match = _mmproj_in_models_dir_matching(pattern) 

225 if match is not None: 

226 return match 

227 return None 

228 

229 

# Quantization tags to look for in GGUF filenames, most preferred first.
_QUANT_PREFERENCE = ("Q4_K_M", "Q4_K_S", "Q5_K_M", "Q5_K_S", "Q8_0", "Q6_K", "Q3_K_M")


def resolve_filename(entry: CatalogModel) -> str:
    """Resolve a GGUF filename pattern to the best concrete filename.

    An exact filename (no ``*``) is returned as-is without any network call.
    For wildcards, the HF API is queried for the repo's files. Those matching
    the pattern are preferred, falling back to every ``.gguf`` file when none
    match, and the best quantization is picked (Q4_K_M first, for balance of
    size/quality).

    Raises:
        PermissionError: the API answered 401 (authentication required).
        RuntimeError: the API query failed, or the repo has no GGUF files.
    """
    pattern = entry.gguf_filename
    if "*" not in pattern:
        return pattern

    try:
        resp = httpx.get(
            f"https://huggingface.co/api/models/{entry.hf_repo}",
            timeout=DEFAULT_TIMEOUT,
            headers=hf_headers(),
        )
        if resp.status_code == HTTPStatus.UNAUTHORIZED:
            raise PermissionError(
                f"{entry.hf_repo} requires HuggingFace authentication. "
                "Set HF_TOKEN env var or visit the repo page to request access."
            )
        resp.raise_for_status()
        siblings = resp.json().get("siblings", [])
    except PermissionError:
        # Keep the auth error distinct; everything else becomes RuntimeError.
        raise
    except Exception as exc:
        raise RuntimeError(f"Cannot query files for {entry.hf_repo}: {exc}") from exc

    gguf_files = _gguf_candidates(siblings, pattern)
    if not gguf_files:
        raise RuntimeError(f"No GGUF files found in {entry.hf_repo}")

    return _pick_best_gguf(gguf_files)


def _gguf_candidates(siblings: list[dict[str, Any]], pattern: str) -> list[str]:
    """Return the repo's GGUF filenames, narrowed to *pattern* when any match.

    Matching is case-insensitive. If no GGUF file matches the pattern, all
    GGUF files are returned so a too-specific pattern still resolves.
    """
    gguf_files = [
        s.get("rfilename", "") for s in siblings if s.get("rfilename", "").endswith(".gguf")
    ]
    wanted = pattern.lower()
    matching = [name for name in gguf_files if fnmatch.fnmatchcase(name.lower(), wanted)]
    return matching or gguf_files


def _pick_best_gguf(filenames: list[str]) -> str:
    """Pick the best GGUF file by quantization preference.

    Walks ``_QUANT_PREFERENCE`` in order and returns the first filename that
    contains the tag, ignoring case (some repos publish lowercase names such
    as ``...-q4_k_m.gguf``). Falls back to the first filename when no tag
    matches. *filenames* must be non-empty.
    """
    lowered = [(name, name.lower()) for name in filenames]
    for quant in _QUANT_PREFERENCE:
        tag = quant.lower()
        for name, low in lowered:
            if tag in low:
                return name
    return filenames[0]

275 

276 

def fetch_model_file_size(hf_repo: str) -> float:
    """Look up the size of the repo's preferred GGUF file via the HuggingFace tree API.

    Lists ``tree/main`` for *hf_repo*, keeps the ``.gguf`` entries, and picks
    one with ``_pick_best_gguf``. Its size is read from ``size``, falling back
    to ``lfs.size``.

    Returns the size in GB rounded to one decimal. Returns 0.0 if the request
    fails, no GGUF file is listed, or the size is missing or zero.
    """
    try:
        response = httpx.get(
            f"https://huggingface.co/api/models/{hf_repo}/tree/main",
            timeout=DEFAULT_TIMEOUT,
            headers=hf_headers(),
        )
        response.raise_for_status()
        listing = response.json()
    except Exception:
        return 0.0

    names = []
    sizes = []
    for item in listing:
        if not isinstance(item, dict):
            continue
        path = item.get("path", "")
        if not path.endswith(".gguf"):
            continue
        names.append(path)
        sizes.append(item.get("size", 0) or item.get("lfs", {}).get("size", 0))

    if not names:
        return 0.0

    chosen = _pick_best_gguf(names)
    # First listed entry with the chosen name wins.
    size_bytes = next((sz for nm, sz in zip(names, sizes) if nm == chosen), 0)
    if not size_bytes:
        return 0.0
    return round(size_bytes / (1024**3), 1)

302 return round(size_bytes / (1024**3), 1) if size_bytes else 0.0