Coverage for src/lilbee/server/handlers/models.py: 100%
200 statements
« prev ^ index » next  —  coverage.py v7.13.4, created at 2026-05-15 20:55 +0000
1"""Model catalog, role assignment, install/delete, and external listing handlers."""
3from __future__ import annotations
5import asyncio
6import logging
7import time
8from collections.abc import AsyncGenerator
9from typing import TYPE_CHECKING, Any, Literal
11from pydantic import BaseModel
13from lilbee.app.services import get_services
14from lilbee.catalog import (
15 FEATURED_CHAT,
16 FEATURED_EMBEDDING,
17 FEATURED_RERANK,
18 FEATURED_VISION,
19 ModelFamily,
20 enrich_catalog,
21 find_catalog_entry,
22 get_catalog,
23 get_families,
24)
25from lilbee.catalog.refs import hf_repo_from_ref
26from lilbee.catalog.types import ModelSource, ModelTask
27from lilbee.core import settings
28from lilbee.core.config import cfg, validate_model_task_assignment
29from lilbee.core.config.validators import _MODEL_FIELD_TO_TASK
30from lilbee.providers.model_ref import parse_model_ref
31from lilbee.runtime.hardware import (
32 FitLevel,
33 SizeVariantInfo,
34 available_memory_for_fit,
35 compute_fit,
36 family_size_variants,
37)
38from lilbee.runtime.progress import SseEvent
39from lilbee.server.handlers.sse import SseStream, sse_error, sse_event
40from lilbee.server.models import (
41 CatalogEntryResponse,
42 ExternalModelsResponse,
43 InstalledModelEntry,
44 ModelsCatalogResponse,
45 ModelsDeleteResponse,
46 ModelsInstalledResponse,
47 ModelsShowResponse,
48 SetModelResponse,
49)
51if TYPE_CHECKING:
52 from lilbee.catalog import CatalogModel
53 from lilbee.catalog.formatting import EnrichedModel
55log = logging.getLogger(__name__)
class ModelCatalogEntry(BaseModel):
    """A single model in the catalog."""

    name: str  # human-readable display name
    size_gb: float  # download size in gigabytes
    min_ram_gb: float  # minimum host RAM (GB) recommended to run the model
    description: str  # short catalog blurb shown to the user
    installed: bool  # True when at least one quant of the repo is installed locally
class ModelCatalogSection(BaseModel):
    """A single-role catalog section with active model and installed list."""

    active: str  # currently-selected model ref for this role (empty when unset)
    catalog: list[ModelCatalogEntry]  # featured catalog rows for this role
    installed: list[str]  # sorted refs of all locally installed models
class ModelsResponse(BaseModel):
    """Response for GET /api/models: one catalog section per role.

    Field names mirror the ``/api/models/<role>`` route segments
    (note ``reranker``, not ``rerank``).
    """

    chat: ModelCatalogSection
    embedding: ModelCatalogSection
    vision: ModelCatalogSection
    reranker: ModelCatalogSection
# Maps each model task to its /api/models/<segment> path segment.
# ``ModelTask.RERANK.value`` is ``"rerank"`` but the route is ``/api/models/reranker``,
# so this mapping is needed to build correct redirect URLs in 422 responses.
TASK_ENDPOINT_PATH: dict[ModelTask, str] = {
    ModelTask.CHAT: "chat",
    ModelTask.EMBEDDING: "embedding",
    ModelTask.VISION: "vision",
    ModelTask.RERANK: "reranker",
}
def format_task_mismatch(ref: str, entry_task: ModelTask, expected_task: ModelTask) -> str:
    """Build the 422 body when a role slot is assigned a model of the wrong task."""
    correct_path = TASK_ENDPOINT_PATH[entry_task]
    message = (
        f"Model '{ref}' is a {entry_task} model, not {expected_task}. "
        f"Set it via PUT /api/models/{correct_path} instead."
    )
    return message
def _catalog_section(
    featured: tuple[CatalogModel, ...],
    active: str,
    installed: set[str],
) -> ModelCatalogSection:
    """Build a ModelCatalogSection from a featured-catalog tuple.

    A featured row counts as "installed" when at least one quant of its
    ``hf_repo`` has a manifest. Installed refs are full
    ``hf_repo/filename`` strings, so membership is tested on the leading
    ``hf_repo`` segment; bare ``hf_repo`` entries (older clients) work too.
    """
    repos_with_manifest = {hf_repo_from_ref(ref) for ref in installed}
    rows: list[ModelCatalogEntry] = []
    for model in featured:
        rows.append(
            ModelCatalogEntry(
                name=model.display_name,
                size_gb=model.size_gb,
                min_ram_gb=model.min_ram_gb,
                description=model.description,
                installed=model.hf_repo in repos_with_manifest,
            )
        )
    return ModelCatalogSection(active=active, catalog=rows, installed=sorted(installed))
async def list_models() -> ModelsResponse:
    """Return per-role catalogs (chat, embedding, vision, reranker) with active selections.

    Uses the unfiltered installed set so a single ref lights up in every
    catalog section it legitimately matches.
    """
    installed_refs = set(get_services().model_manager.list_installed())
    sections = {
        "chat": _catalog_section(FEATURED_CHAT, cfg.chat_model, installed_refs),
        "embedding": _catalog_section(FEATURED_EMBEDDING, cfg.embedding_model, installed_refs),
        "vision": _catalog_section(FEATURED_VISION, cfg.vision_model, installed_refs),
        "reranker": _catalog_section(FEATURED_RERANK, cfg.reranker_model, installed_refs),
    }
    return ModelsResponse(**sections)
async def _set_model(
    field: Literal["chat_model", "embedding_model", "vision_model", "reranker_model"],
    model: str,
) -> SetModelResponse:
    """Shared helper for switching a model field.

    Updates the in-memory config AND persists the value to settings so
    the selection survives a restart.
    """
    setattr(cfg, field, model)
    settings.set_value(cfg.data_root, field, model)
    return SetModelResponse(model=model)
def _resolve_via_catalog(model: str, available: set[str]) -> str | None:
    """Resolve a bare ``hf_repo`` to whichever quant of it is in *available*."""
    entry = find_catalog_entry(model)
    if entry is None:
        return None
    quant_prefix = f"{entry.hf_repo}/"
    for ref in available:
        if ref.startswith(quant_prefix):
            return ref
    return None
def _resolve_via_parse(model: str, available: set[str]) -> str | None:
    """Resolve a provider-prefixed ref to its bare provider name in *available*."""
    try:
        parsed = parse_model_ref(model)
    except ValueError:
        # Not a parseable ref at all -> cannot resolve this way.
        return None
    if parsed.name in available:
        return parsed.name
    return None
def _require_model_available(model: str) -> str:
    """Return the installed-and-routable form of *model*, or raise ValueError."""

    def _unavailable() -> ValueError:
        # Built lazily so the exception is only constructed on failure paths.
        return ValueError(f"Model '{model}' is not available. Pull it first or check the name.")

    if not model:
        raise _unavailable()
    available = set(get_services().provider.list_models())
    if model in available:
        return model
    resolved = _resolve_via_catalog(model, available) or _resolve_via_parse(model, available)
    if resolved is None:
        raise _unavailable()
    return resolved
def _build_task_to_field() -> dict[ModelTask, str]:
    """Invert config's ``_MODEL_FIELD_TO_TASK`` so the two maps stay in sync."""
    inverted: dict[ModelTask, str] = {}
    for field_name, task_value in _MODEL_FIELD_TO_TASK.items():
        inverted[ModelTask(task_value)] = field_name
    return inverted


_TASK_TO_FIELD: dict[ModelTask, str] = _build_task_to_field()
def _require_model_for_task(model: str, expected: ModelTask, *, allow_empty: bool = False) -> str:
    """Validate *model* is installed locally AND passes the catalog task check.

    Empty string unsets the role when *allow_empty* is True. Catalog +
    task validation delegates to ``validate_model_task_assignment`` so
    the handler and config paths share a single implementation.
    """
    if allow_empty and not model.strip():
        return ""
    installed_ref = _require_model_available(model)
    field = _TASK_TO_FIELD[expected]
    return validate_model_task_assignment(field, installed_ref, allow_bypass=False)
async def set_chat_model(model: str) -> SetModelResponse:
    """Switch active chat model. Validates installation and catalog task."""
    checked = _require_model_for_task(model, ModelTask.CHAT)
    return await _set_model("chat_model", checked)
async def set_embedding_model(model: str) -> SetModelResponse:
    """Switch embedding model. Validates installation and catalog task.

    Returns ``reindex_required=True`` when the new model differs from the
    embedding model that built the persisted vector store. The caller is
    expected to trigger a rebuild (``lilbee rebuild`` or ``POST /api/sync``
    with ``force_rebuild=true``). Search and ingest will refuse to operate
    until that happens.

    Pins a legacy store's identity to the OLD cfg before mutating it. Without
    this step, a pre-upgrade store with chunks but no ``_meta`` row would have
    its meta lazy-initialized from the NEW cfg on the next read, hiding the
    drift the caller just introduced.
    """
    # Local import keeps the data-store dependency off this module's import path.
    from lilbee.data.store.lance_helpers import refs_compatible

    normalized = _require_model_for_task(model, ModelTask.EMBEDDING)
    store = get_services().store
    # Must run BEFORE _set_model mutates cfg (see docstring).
    store.initialize_meta_if_legacy()
    await _set_model("embedding_model", normalized)
    store.canonicalize_meta_if_legacy()
    meta = store.get_meta()
    # NOTE(review): both dim arguments below are meta["embedding_dim"]. If
    # refs_compatible expects (old_ref, new_ref, old_dim, new_dim), the second
    # dim should presumably be the NEW model's dimension; as written the dim
    # comparison can never detect a change. Confirm against refs_compatible.
    reindex_required = meta is not None and not refs_compatible(
        meta["embedding_model"], normalized, meta["embedding_dim"], meta["embedding_dim"]
    )
    return SetModelResponse(model=normalized, reindex_required=reindex_required)
async def set_vision_model(model: str) -> SetModelResponse:
    """Switch vision OCR model. Empty string unsets it (vision OCR disabled)."""
    checked = _require_model_for_task(model, ModelTask.VISION, allow_empty=True)
    return await _set_model("vision_model", checked)
async def set_reranker_model(model: str) -> SetModelResponse:
    """Switch reranker model. Empty string unsets it (reranking disabled)."""
    checked = _require_model_for_task(model, ModelTask.RERANK, allow_empty=True)
    return await _set_model("reranker_model", checked)
async def models_show(model: str) -> ModelsShowResponse:
    """Return model metadata/parameters. Returns empty model if unavailable."""
    raw = get_services().provider.show_model(model)
    if not raw:
        # Provider returned nothing usable; respond with an empty model.
        return ModelsShowResponse()
    return ModelsShowResponse(**raw)
def _parse_source(source: str) -> ModelSource:
    """Convert a source string to ModelSource enum.

    Raises ValueError (from the enum constructor) on unknown sources.
    """
    return ModelSource(source)
_BYTES_PER_GB = 1024**3


def _row_fit(enriched: EnrichedModel, available_bytes: int | None) -> FitLevel | None:
    """Fit level for *enriched*, or None when host memory or row size can't be measured."""
    measurable = (
        available_bytes is not None
        and enriched.source == ModelSource.NATIVE.value
        and enriched.size_gb > 0
    )
    if not measurable:
        return None
    size_bytes = int(enriched.size_gb * _BYTES_PER_GB)
    return compute_fit(size_bytes, available_bytes).level
def _families_by_repo() -> dict[str, ModelFamily]:
    """Index featured ModelFamilies by every variant's ``hf_repo`` for size-variant lookup."""
    # Later variants overwrite earlier ones on repo collision, matching the
    # original insertion-order loop behavior.
    return {
        variant.hf_repo: family
        for family in get_families()
        for variant in family.variants
    }
def _row_size_variants(
    enriched: EnrichedModel, families_by_repo: dict[str, ModelFamily]
) -> list[SizeVariantInfo]:
    """Size-variant strip for *enriched*; empty when the row isn't part of a family."""
    family = families_by_repo.get(enriched.hf_repo)
    return [] if family is None else family_size_variants(family)
def _build_catalog_entry(
    enriched: EnrichedModel,
    *,
    available_bytes: int | None,
    families_by_repo: dict[str, ModelFamily],
) -> CatalogEntryResponse:
    """Translate one enriched catalog model into its HTTP response row."""
    # Derived fields computed up front; the rest is a straight field copy.
    fit = _row_fit(enriched, available_bytes)
    variants = _row_size_variants(enriched, families_by_repo)
    return CatalogEntryResponse(
        hf_repo=enriched.hf_repo,
        gguf_filename=enriched.gguf_filename,
        task=enriched.task,
        display_name=enriched.display_name,
        param_count=enriched.param_count,
        size_gb=enriched.size_gb,
        min_ram_gb=enriched.min_ram_gb,
        description=enriched.description,
        quality_tier=enriched.quality_tier,
        featured=enriched.featured,
        downloads=enriched.downloads,
        installed=enriched.installed,
        source=enriched.source,
        fit=fit,
        size_variants=variants,
    )
async def models_catalog(
    task: str | None = None,
    search: str = "",
    size: str | None = None,
    installed: bool | None = None,
    featured: bool | None = None,
    sort: str = "featured",
    limit: int = 20,
    offset: int = 0,
) -> ModelsCatalogResponse:
    """Return paginated model catalog with installed status."""
    # Validate the raw query-string task against the closed enum at the
    # HTTP boundary; an unknown value must not silently disable the
    # catalog filter further in.
    parsed_task = ModelTask(task) if task else None
    page = get_catalog(
        task=parsed_task,
        search=search,
        size=size,
        installed=installed,
        featured=featured,
        sort=sort,
        limit=limit,
        offset=offset,
    )

    installed_refs = {m.ref for m in get_services().registry.list_installed()}
    enriched_rows = enrich_catalog(page, installed_refs)

    available_bytes = available_memory_for_fit()
    family_index = _families_by_repo()
    rows: list[CatalogEntryResponse] = []
    for row in enriched_rows:
        rows.append(
            _build_catalog_entry(
                row, available_bytes=available_bytes, families_by_repo=family_index
            )
        )
    return ModelsCatalogResponse(
        total=page.total,
        limit=page.limit,
        offset=page.offset,
        has_more=page.has_more,
        models=rows,
    )
async def models_installed() -> ModelsInstalledResponse:
    """Return list of installed models with their source."""
    manager = get_services().model_manager
    # Models with an unknown source fall back to REMOTE.
    entries = [
        InstalledModelEntry(name=name, source=manager.get_source(name) or ModelSource.REMOTE)
        for name in manager.list_installed()
    ]
    return ModelsInstalledResponse(models=entries)
async def models_pull(model: str, *, source: str = "native") -> AsyncGenerator[str, None]:
    """Yield SSE progress events while pulling a model in real time.

    Sets a cancel event on client disconnect so the pull stops.
    """
    manager = get_services().model_manager
    src = _parse_source(source)
    sse = SseStream()

    def _pull_blocking() -> None:
        # Runs in a worker thread (via asyncio.to_thread below); every queue
        # write hops back to the event loop with call_soon_threadsafe so the
        # async consumer side stays single-threaded.
        def _on_progress(data: dict[str, Any]) -> None:
            if sse.cancel.is_set():
                # Client disconnected; drop further progress events.
                return
            payload = sse_event(SseEvent.PROGRESS, data)
            sse.loop.call_soon_threadsafe(sse.queue.put_nowait, payload)

        def _on_bytes(downloaded: int, total: int) -> None:
            if sse.cancel.is_set():
                return
            payload = sse_event(SseEvent.PROGRESS, {"current": downloaded, "total": total})
            sse.loop.call_soon_threadsafe(sse.queue.put_nowait, payload)

        try:
            manager.pull(model, src, on_progress=_on_progress, on_bytes=_on_bytes)
        except Exception as exc:
            # Surface the failure to the client as an SSE error event rather
            # than letting it die silently in the worker thread.
            sse.loop.call_soon_threadsafe(sse.queue.put_nowait, sse_error(str(exc)))
        finally:
            # None is presumably the end-of-stream sentinel consumed by
            # sse.drain — confirm against SseStream.
            sse.loop.call_soon_threadsafe(sse.queue.put_nowait, None)

    task = asyncio.ensure_future(asyncio.to_thread(_pull_blocking))
    async for event in sse.drain(task, "Model pull stream"):
        yield event
async def models_delete(model: str, *, source: str = "native") -> ModelsDeleteResponse:
    """Delete a model. Returns deletion status, model name, and freed space.

    ``freed_gb`` is currently always reported as 0.0.
    """
    src = _parse_source(source)
    deleted = get_services().model_manager.remove(model, src)
    return ModelsDeleteResponse(deleted=deleted, model=model, freed_gb=0.0)
_EXTERNAL_MODELS_TTL = 60  # seconds a cached external listing stays fresh


class _ExternalModelsCache:
    """TTL cache for external model listings (no module-level mutable global).

    Holds a single (key, result) pair; a lookup hits only when the key
    matches and the entry is younger than ``_EXTERNAL_MODELS_TTL``.
    """

    def __init__(self) -> None:
        self._time: float = 0.0  # monotonic timestamp of the last set()
        self._key: str = ""  # cache key of the stored result
        self._result: ExternalModelsResponse | None = None

    def get(self, key: str) -> ExternalModelsResponse | None:
        """Return the cached result for *key*, or None on miss or expiry."""
        now = time.monotonic()
        fresh = (now - self._time) < _EXTERNAL_MODELS_TTL
        # `is not None` rather than truthiness: a cached-but-falsy response
        # must still count as a hit instead of forcing a refetch.
        if self._result is not None and key == self._key and fresh:
            return self._result
        return None

    def set(self, key: str, result: ExternalModelsResponse) -> None:
        """Store *result* under *key*, resetting the TTL clock."""
        self._time = time.monotonic()
        self._key = key
        self._result = result


_external_cache = _ExternalModelsCache()
async def list_external_models() -> ExternalModelsResponse:
    """Query the provider for available models via its list_models() API."""
    # Cache key ties the result to the current endpoint + credentials.
    cache_key = f"{cfg.remote_base_url}:{cfg.llm_api_key or ''}"
    hit = _external_cache.get(cache_key)
    if hit:
        return hit

    try:
        # Provider call is blocking; run it off the event loop.
        models = await asyncio.to_thread(get_services().provider.list_models)
        result = ExternalModelsResponse(models=models)
        _external_cache.set(cache_key, result)
        return result
    except Exception as exc:
        log.warning("Failed to list external models: %s", exc)
        return ExternalModelsResponse(models=[], error=str(exc))