Coverage for src/lilbee/server/handlers/models.py: 100%

200 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-05-15 20:55 +0000

1"""Model catalog, role assignment, install/delete, and external listing handlers.""" 

2 

3from __future__ import annotations 

4 

5import asyncio 

6import logging 

7import time 

8from collections.abc import AsyncGenerator 

9from typing import TYPE_CHECKING, Any, Literal 

10 

11from pydantic import BaseModel 

12 

13from lilbee.app.services import get_services 

14from lilbee.catalog import ( 

15 FEATURED_CHAT, 

16 FEATURED_EMBEDDING, 

17 FEATURED_RERANK, 

18 FEATURED_VISION, 

19 ModelFamily, 

20 enrich_catalog, 

21 find_catalog_entry, 

22 get_catalog, 

23 get_families, 

24) 

25from lilbee.catalog.refs import hf_repo_from_ref 

26from lilbee.catalog.types import ModelSource, ModelTask 

27from lilbee.core import settings 

28from lilbee.core.config import cfg, validate_model_task_assignment 

29from lilbee.core.config.validators import _MODEL_FIELD_TO_TASK 

30from lilbee.providers.model_ref import parse_model_ref 

31from lilbee.runtime.hardware import ( 

32 FitLevel, 

33 SizeVariantInfo, 

34 available_memory_for_fit, 

35 compute_fit, 

36 family_size_variants, 

37) 

38from lilbee.runtime.progress import SseEvent 

39from lilbee.server.handlers.sse import SseStream, sse_error, sse_event 

40from lilbee.server.models import ( 

41 CatalogEntryResponse, 

42 ExternalModelsResponse, 

43 InstalledModelEntry, 

44 ModelsCatalogResponse, 

45 ModelsDeleteResponse, 

46 ModelsInstalledResponse, 

47 ModelsShowResponse, 

48 SetModelResponse, 

49) 

50 

51if TYPE_CHECKING: 

52 from lilbee.catalog import CatalogModel 

53 from lilbee.catalog.formatting import EnrichedModel 

54 

55log = logging.getLogger(__name__) 

56 

57 

class ModelCatalogEntry(BaseModel):
    """A single model in the catalog."""

    name: str  # display name shown to the client
    size_gb: float  # download size in GB (see _BYTES_PER_GB for the byte conversion)
    min_ram_gb: float  # minimum recommended host memory in GB
    description: str  # short human-readable summary
    installed: bool  # True when some quant of the model's hf_repo has a manifest

66 

67 

class ModelCatalogSection(BaseModel):
    """A single-role catalog section with active model and installed list."""

    active: str  # currently-assigned model ref for this role ("" when unset)
    catalog: list[ModelCatalogEntry]  # featured rows for this role
    installed: list[str]  # sorted refs of all installed models

74 

75 

class ModelsResponse(BaseModel):
    """Response for GET /api/models: one catalog section per role."""

    chat: ModelCatalogSection
    embedding: ModelCatalogSection
    vision: ModelCatalogSection
    reranker: ModelCatalogSection

83 

84 

# ``ModelTask.RERANK.value`` is ``"rerank"`` but the route is ``/api/models/reranker``,
# so this mapping is needed to build correct redirect URLs in 422 responses.
TASK_ENDPOINT_PATH: dict[ModelTask, str] = {
    ModelTask.CHAT: "chat",
    ModelTask.EMBEDDING: "embedding",
    ModelTask.VISION: "vision",
    ModelTask.RERANK: "reranker",
}

93 

94 

def format_task_mismatch(ref: str, entry_task: ModelTask, expected_task: ModelTask) -> str:
    """Build the 422 body when a role slot is assigned a model of the wrong task.

    Routes to ``TASK_ENDPOINT_PATH`` so the suggested PUT URL matches the
    actual route name rather than the raw enum value.
    """
    path = TASK_ENDPOINT_PATH[entry_task]
    message = (
        f"Model '{ref}' is a {entry_task} model, not {expected_task}. "
        f"Set it via PUT /api/models/{path} instead."
    )
    return message

102 

103 

def _catalog_section(
    featured: tuple[CatalogModel, ...],
    active: str,
    installed: set[str],
) -> ModelCatalogSection:
    """Build a ModelCatalogSection from a featured-catalog tuple.

    A featured row counts as "installed" when at least one quant of its
    ``hf_repo`` has a manifest. Installed refs are full ``hf_repo/filename``
    strings, so the membership test compares only the leading ``hf_repo``
    segment. Bare ``hf_repo`` entries are accepted too (e.g. older clients
    that report just the repo).
    """
    repos_with_manifest = {hf_repo_from_ref(ref) for ref in installed}
    rows: list[ModelCatalogEntry] = []
    for model in featured:
        rows.append(
            ModelCatalogEntry(
                name=model.display_name,
                size_gb=model.size_gb,
                min_ram_gb=model.min_ram_gb,
                description=model.description,
                installed=model.hf_repo in repos_with_manifest,
            )
        )
    return ModelCatalogSection(active=active, catalog=rows, installed=sorted(installed))

132 

133 

async def list_models() -> ModelsResponse:
    """Return per-role catalogs (chat, embedding, vision, reranker) with active selections.

    Uses the unfiltered installed set so a single ref lights up in every
    catalog section it legitimately matches.
    """
    manager = get_services().model_manager
    installed = set(manager.list_installed())
    sections = {
        "chat": _catalog_section(FEATURED_CHAT, cfg.chat_model, installed),
        "embedding": _catalog_section(FEATURED_EMBEDDING, cfg.embedding_model, installed),
        "vision": _catalog_section(FEATURED_VISION, cfg.vision_model, installed),
        "reranker": _catalog_section(FEATURED_RERANK, cfg.reranker_model, installed),
    }
    return ModelsResponse(**sections)

148 

149 

async def _set_model(
    field: Literal["chat_model", "embedding_model", "vision_model", "reranker_model"],
    model: str,
) -> SetModelResponse:
    """Shared helper for switching a model field.

    Updates the live ``cfg`` object first, then writes the value through
    to persisted settings, and echoes the model back to the caller.
    """
    setattr(cfg, field, model)
    settings.set_value(cfg.data_root, field, model)
    return SetModelResponse(model=model)

158 

159 

def _resolve_via_catalog(model: str, available: set[str]) -> str | None:
    """Resolve a bare ``hf_repo`` to whichever quant of it is in *available*."""
    entry = find_catalog_entry(model)
    if entry is None:
        return None
    prefix = f"{entry.hf_repo}/"
    for ref in available:
        if ref.startswith(prefix):
            return ref
    return None

166 

167 

def _resolve_via_parse(model: str, available: set[str]) -> str | None:
    """Resolve a provider-prefixed ref to its bare provider name in *available*."""
    try:
        parsed = parse_model_ref(model)
    except ValueError:
        # Not a parseable ref at all; nothing to resolve.
        return None
    name = parsed.name
    if name in available:
        return name
    return None

175 

176 

def _require_model_available(model: str) -> str:
    """Return the installed-and-routable form of *model*, or raise ValueError."""

    def _unavailable() -> ValueError:
        # Fresh instance per raise site so tracebacks don't share state.
        return ValueError(
            f"Model '{model}' is not available. Pull it first or check the name."
        )

    if not model:
        raise _unavailable()
    available = set(get_services().provider.list_models())
    if model in available:
        return model
    resolved = _resolve_via_catalog(model, available) or _resolve_via_parse(model, available)
    if resolved is None:
        raise _unavailable()
    return resolved

191 

192 

def _build_task_to_field() -> dict[ModelTask, str]:
    """Invert config's ``_MODEL_FIELD_TO_TASK`` so the two maps stay in sync."""
    mapping: dict[ModelTask, str] = {}
    for field, task in _MODEL_FIELD_TO_TASK.items():
        mapping[ModelTask(task)] = field
    return mapping

196 

197 

# Derived at import time from config's mapping so the two can't drift apart.
_TASK_TO_FIELD: dict[ModelTask, str] = _build_task_to_field()

199 

200 

def _require_model_for_task(model: str, expected: ModelTask, *, allow_empty: bool = False) -> str:
    """Validate *model* is installed locally AND passes the catalog task check.

    An empty/blank string unsets the role when *allow_empty* is True.
    Catalog + task validation delegates to ``validate_model_task_assignment``
    so the handler and config paths share a single implementation.
    """
    if allow_empty and not model.strip():
        return ""
    resolved = _require_model_available(model)
    field = _TASK_TO_FIELD[expected]
    return validate_model_task_assignment(field, resolved, allow_bypass=False)

212 

213 

async def set_chat_model(model: str) -> SetModelResponse:
    """Switch active chat model. Validates installation and catalog task."""
    resolved = _require_model_for_task(model, ModelTask.CHAT)
    return await _set_model("chat_model", resolved)

218 

219 

async def set_embedding_model(model: str) -> SetModelResponse:
    """Switch embedding model. Validates installation and catalog task.

    Returns ``reindex_required=True`` when the new model differs from the
    embedding model that built the persisted vector store. The caller is
    expected to trigger a rebuild (``lilbee rebuild`` or ``POST /api/sync``
    with ``force_rebuild=true``). Search and ingest will refuse to operate
    until that happens.

    Pins a legacy store's identity to the OLD cfg before mutating it. Without
    this step, a pre-upgrade store with chunks but no ``_meta`` row would have
    its meta lazy-initialized from the NEW cfg on the next read, hiding the
    drift the caller just introduced.
    """
    # Local import; lance_helpers is only needed on this path.
    from lilbee.data.store.lance_helpers import refs_compatible

    normalized = _require_model_for_task(model, ModelTask.EMBEDDING)
    store = get_services().store
    # Must run BEFORE _set_model mutates cfg -- see docstring.
    store.initialize_meta_if_legacy()
    await _set_model("embedding_model", normalized)
    store.canonicalize_meta_if_legacy()
    meta = store.get_meta()
    # NOTE(review): ``meta["embedding_dim"]`` is passed for BOTH dim arguments,
    # so the dimension comparison inside refs_compatible can never differ on
    # this path -- confirm whether the second dim should be the NEW model's.
    reindex_required = meta is not None and not refs_compatible(
        meta["embedding_model"], normalized, meta["embedding_dim"], meta["embedding_dim"]
    )
    return SetModelResponse(model=normalized, reindex_required=reindex_required)

246 

247 

async def set_vision_model(model: str) -> SetModelResponse:
    """Switch vision OCR model. Empty string unsets it (vision OCR disabled)."""
    resolved = _require_model_for_task(model, ModelTask.VISION, allow_empty=True)
    return await _set_model("vision_model", resolved)

252 

253 

async def set_reranker_model(model: str) -> SetModelResponse:
    """Switch reranker model. Empty string unsets it (reranking disabled)."""
    resolved = _require_model_for_task(model, ModelTask.RERANK, allow_empty=True)
    return await _set_model("reranker_model", resolved)

258 

259 

async def models_show(model: str) -> ModelsShowResponse:
    """Return model metadata/parameters. Returns empty model if unavailable."""
    info = get_services().provider.show_model(model)
    if not info:
        # Provider had nothing for this ref; respond with an empty model.
        return ModelsShowResponse()
    return ModelsShowResponse(**info)

265 

266 

def _parse_source(source: str) -> ModelSource:
    """Convert a source string to its ModelSource enum member."""
    parsed = ModelSource(source)
    return parsed

270 

271 

# Binary gigabyte (GiB) in bytes; converts catalog size_gb values into bytes.
_BYTES_PER_GB = 1024**3

273 

274 

def _row_fit(enriched: EnrichedModel, available_bytes: int | None) -> FitLevel | None:
    """Fit level for *enriched*, or None when host memory or row size can't be measured."""
    measurable = (
        available_bytes is not None
        and enriched.source == ModelSource.NATIVE.value
        and enriched.size_gb > 0
    )
    if not measurable:
        return None
    size_bytes = int(enriched.size_gb * _BYTES_PER_GB)
    return compute_fit(size_bytes, available_bytes).level

284 

285 

def _families_by_repo() -> dict[str, ModelFamily]:
    """Index featured ModelFamilies by every variant's ``hf_repo`` for size-variant lookup."""
    # Later families overwrite earlier ones on a repo collision, matching
    # last-write-wins insertion order.
    return {
        variant.hf_repo: family
        for family in get_families()
        for variant in family.variants
    }

293 

294 

def _row_size_variants(
    enriched: EnrichedModel, families_by_repo: dict[str, ModelFamily]
) -> list[SizeVariantInfo]:
    """Size-variant strip for *enriched*; empty when the row isn't part of a family."""
    family = families_by_repo.get(enriched.hf_repo)
    return [] if family is None else family_size_variants(family)

303 

304 

def _build_catalog_entry(
    enriched: EnrichedModel,
    *,
    available_bytes: int | None,
    families_by_repo: dict[str, ModelFamily],
) -> CatalogEntryResponse:
    """Translate one enriched catalog model into its HTTP response row."""
    # Compute the two derived fields up front; everything else copies straight
    # from the enriched row.
    fit = _row_fit(enriched, available_bytes)
    variants = _row_size_variants(enriched, families_by_repo)
    return CatalogEntryResponse(
        hf_repo=enriched.hf_repo,
        gguf_filename=enriched.gguf_filename,
        task=enriched.task,
        display_name=enriched.display_name,
        param_count=enriched.param_count,
        size_gb=enriched.size_gb,
        min_ram_gb=enriched.min_ram_gb,
        description=enriched.description,
        quality_tier=enriched.quality_tier,
        featured=enriched.featured,
        downloads=enriched.downloads,
        installed=enriched.installed,
        source=enriched.source,
        fit=fit,
        size_variants=variants,
    )

329 

330 

async def models_catalog(
    task: str | None = None,
    search: str = "",
    size: str | None = None,
    installed: bool | None = None,
    featured: bool | None = None,
    sort: str = "featured",
    limit: int = 20,
    offset: int = 0,
) -> ModelsCatalogResponse:
    """Return paginated model catalog with installed status."""
    # ``task`` arrives as a raw query-string value; constructing the closed
    # enum here rejects unknown values at the HTTP boundary instead of
    # letting them silently short-circuit the catalog filter inside.
    parsed_task = None if not task else ModelTask(task)
    page = get_catalog(
        task=parsed_task,
        search=search,
        size=size,
        installed=installed,
        featured=featured,
        sort=sort,
        limit=limit,
        offset=offset,
    )

    installed_refs = {m.ref for m in get_services().registry.list_installed()}
    enriched_rows = enrich_catalog(page, installed_refs)

    available_bytes = available_memory_for_fit()
    families_by_repo = _families_by_repo()
    rows = [
        _build_catalog_entry(
            row, available_bytes=available_bytes, families_by_repo=families_by_repo
        )
        for row in enriched_rows
    ]
    return ModelsCatalogResponse(
        total=page.total,
        limit=page.limit,
        offset=page.offset,
        has_more=page.has_more,
        models=rows,
    )

376 

377 

async def models_installed() -> ModelsInstalledResponse:
    """Return list of installed models with their source."""
    manager = get_services().model_manager
    entries = [
        # Unknown sources fall back to REMOTE.
        InstalledModelEntry(name=name, source=manager.get_source(name) or ModelSource.REMOTE)
        for name in manager.list_installed()
    ]
    return ModelsInstalledResponse(models=entries)

387 

388 

async def models_pull(model: str, *, source: str = "native") -> AsyncGenerator[str, None]:
    """Yield SSE progress events while pulling a model in real time.

    Sets a cancel event on client disconnect so the pull stops. The
    blocking pull runs in a worker thread; callbacks hand events back to
    the event-loop queue via ``call_soon_threadsafe``.
    """
    manager = get_services().model_manager
    src = _parse_source(source)
    sse = SseStream()

    def _pull_blocking() -> None:
        # Runs inside asyncio.to_thread: never touch the queue directly,
        # always marshal through sse.loop.call_soon_threadsafe.
        def _on_progress(data: dict[str, Any]) -> None:
            if sse.cancel.is_set():
                return  # client disconnected; drop further events
            payload = sse_event(SseEvent.PROGRESS, data)
            sse.loop.call_soon_threadsafe(sse.queue.put_nowait, payload)

        def _on_bytes(downloaded: int, total: int) -> None:
            if sse.cancel.is_set():
                return  # client disconnected; drop further events
            payload = sse_event(SseEvent.PROGRESS, {"current": downloaded, "total": total})
            sse.loop.call_soon_threadsafe(sse.queue.put_nowait, payload)

        try:
            manager.pull(model, src, on_progress=_on_progress, on_bytes=_on_bytes)
        except Exception as exc:
            # Surface pull failures to the client as an SSE error event.
            sse.loop.call_soon_threadsafe(sse.queue.put_nowait, sse_error(str(exc)))
        finally:
            # None is the end-of-stream sentinel consumed by sse.drain.
            sse.loop.call_soon_threadsafe(sse.queue.put_nowait, None)

    task = asyncio.ensure_future(asyncio.to_thread(_pull_blocking))
    async for event in sse.drain(task, "Model pull stream"):
        yield event

420 

421 

async def models_delete(model: str, *, source: str = "native") -> ModelsDeleteResponse:
    """Delete a model. Returns deletion status, model name, and freed space."""
    src = _parse_source(source)
    was_deleted = get_services().model_manager.remove(model, src)
    # freed_gb is reported as 0.0 here, matching the existing response contract.
    return ModelsDeleteResponse(deleted=was_deleted, model=model, freed_gb=0.0)

428 

429 

# Seconds a cached external-model listing stays fresh.
_EXTERNAL_MODELS_TTL = 60

431 

432 

class _ExternalModelsCache:
    """TTL cache for external model listings (no module-level mutable global)."""

    def __init__(self) -> None:
        self._stamp: float = 0.0  # monotonic time of the last set()
        self._cache_key: str = ""  # key the cached result belongs to
        self._cached: ExternalModelsResponse | None = None

    def get(self, key: str) -> ExternalModelsResponse | None:
        """Return the cached response for *key* while fresh, else None."""
        if not self._cached or key != self._cache_key:
            return None
        age = time.monotonic() - self._stamp
        if age >= _EXTERNAL_MODELS_TTL:
            return None
        return self._cached

    def set(self, key: str, result: ExternalModelsResponse) -> None:
        """Store *result* under *key*, restarting the TTL clock."""
        self._stamp = time.monotonic()
        self._cache_key = key
        self._cached = result

451 

452 

# Module-level singleton; mutable state lives inside the class, not the module.
_external_cache = _ExternalModelsCache()

454 

455 

async def list_external_models() -> ExternalModelsResponse:
    """Query the provider for available models via its list_models() API."""
    # Cache key covers both the endpoint and the credential in use.
    cache_key = f"{cfg.remote_base_url}:{cfg.llm_api_key or ''}"
    hit = _external_cache.get(cache_key)
    if hit:
        return hit

    try:
        # list_models is blocking; run it off the event loop.
        models = await asyncio.to_thread(get_services().provider.list_models)
        fresh = ExternalModelsResponse(models=models)
        _external_cache.set(cache_key, fresh)
        return fresh
    except Exception as exc:
        # Best effort: report the failure in-band rather than raising.
        log.warning("Failed to list external models: %s", exc)
        return ExternalModelsResponse(models=[], error=str(exc))