Coverage for src / lilbee / server / handlers / models.py: 100%
264 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-06-28 01:01 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-06-28 01:01 +0000
"""Model catalog, role assignment, install/delete, and external listing handlers.

Covers the per-role featured catalogs, the paginated catalog (native rows plus
discovered hosted rows), the chat/embedding/vision/reranker role setters,
SSE-streamed model pulls, deletes, and the provider's own model listing.
"""
3from __future__ import annotations
5import asyncio
6import logging
7import time
8from collections.abc import AsyncGenerator
9from typing import TYPE_CHECKING, Literal
11from pydantic import BaseModel
13from lilbee.app.services import get_services
14from lilbee.app.settings import apply_settings_update
15from lilbee.catalog import (
16 FEATURED_CHAT,
17 FEATURED_EMBEDDING,
18 FEATURED_RERANK,
19 FEATURED_VISION,
20 ModelFamily,
21 enrich_catalog,
22 find_catalog_entry,
23 get_catalog,
24 get_families,
25)
26from lilbee.catalog.refs import hf_repo_from_ref, is_bare_hf_repo
27from lilbee.catalog.types import CatalogSize, CatalogSort, KeyStatus, ModelSource, ModelTask
28from lilbee.core.config import cfg
29from lilbee.modelhub.model_manager import classify_all_remote_models, discover_api_models
30from lilbee.modelhub.model_manager.types import RemoteModel
31from lilbee.modelhub.role_validator import _MODEL_FIELD_TO_TASK, validate_model_task_assignment
32from lilbee.providers.local_servers import canonical_local_ref, local_server_for_label
33from lilbee.providers.model_ref import format_remote_ref, parse_model_ref
34from lilbee.providers.sdk_backend import PROVIDER_KEYS, get_provider_api_key
35from lilbee.runtime.hardware import (
36 FitLevel,
37 SizeVariantInfo,
38 available_memory_for_fit,
39 compute_fit,
40 family_size_variants,
41)
42from lilbee.runtime.progress import SseEvent
43from lilbee.server.handlers.sse import SseStream, sse_error, sse_event
44from lilbee.server.models import (
45 CatalogEntryResponse,
46 ExternalModelsResponse,
47 InstalledModelEntry,
48 ModelsCatalogResponse,
49 ModelsDeleteResponse,
50 ModelsInstalledResponse,
51 ModelsShowResponse,
52 SetModelResponse,
53)
55if TYPE_CHECKING:
56 from lilbee.catalog import CatalogModel
57 from lilbee.catalog.formatting import EnrichedModel
# Module-level logger; used for best-effort warnings (e.g. failed external listings).
log: logging.Logger = logging.getLogger(__name__)
class ModelCatalogEntry(BaseModel):
    """A single featured model row inside one role's catalog section."""

    name: str  # the featured model's display name
    size_gb: float  # size in GB, copied from the catalog entry
    min_ram_gb: float  # RAM floor in GB, copied from the catalog entry
    description: str
    installed: bool  # True when some installed ref shares this model's hf_repo
class ModelCatalogSection(BaseModel):
    """A single-role catalog section with active model and installed list."""

    active: str  # the role's currently configured model; may be "" when the role is unset
    catalog: list[ModelCatalogEntry]  # the role's featured models
    installed: list[str]  # every installed ref, sorted; not filtered by role
class ModelsResponse(BaseModel):
    """Response for GET /api/models: one catalog section per role."""

    chat: ModelCatalogSection  # featured chat models; active = cfg.chat_model
    embedding: ModelCatalogSection  # featured embedding models; active = cfg.embedding_model
    vision: ModelCatalogSection  # featured vision models; active = cfg.vision_model
    reranker: ModelCatalogSection  # featured rerank models; active = cfg.reranker_model
# Task -> route segment under ``/api/models/``. This can't be derived from the enum:
# ``ModelTask.RERANK.value`` is ``"rerank"`` but the route is ``/api/models/reranker``.
# The table is used to name the correct endpoint in 422 task-mismatch messages.
TASK_ENDPOINT_PATH: dict[ModelTask, str] = {
    ModelTask.CHAT: "chat",
    ModelTask.EMBEDDING: "embedding",
    ModelTask.VISION: "vision",
    ModelTask.RERANK: "reranker",
}
def format_task_mismatch(ref: str, entry_task: ModelTask, expected_task: ModelTask) -> str:
    """Build the 422 body when a role slot is assigned a model of the wrong task.

    Args:
        ref: The model reference the client tried to assign.
        entry_task: The task the catalog records for *ref*.
        expected_task: The task the targeted role slot requires.

    Returns:
        A message naming both tasks by their wire values (e.g. ``rerank``) and
        the ``PUT /api/models/<endpoint>`` route that accepts *ref*'s task.
    """
    actual = ModelTask(entry_task)
    expected = ModelTask(expected_task)
    endpoint = TASK_ENDPOINT_PATH[actual]
    # Use ``.value`` explicitly. On Python 3.11+, formatting a ``(str, Enum)``
    # member gives ``ModelTask.CHAT``; only ``StrEnum`` formats as the value.
    return (
        f"Model '{ref}' is a {actual.value} model, not {expected.value}. "
        f"Set it via PUT /api/models/{endpoint} instead."
    )
def _catalog_section(
    featured: tuple[CatalogModel, ...],
    active: str,
    installed: set[str],
) -> ModelCatalogSection:
    """Assemble one role's catalog section from its featured models.

    A featured model counts as installed when any installed ref shares its
    ``hf_repo``. Installed refs are normally full ``hf_repo/filename``
    strings, so each is reduced to its repo with ``hf_repo_from_ref`` before
    comparing. Bare ``hf_repo`` refs (e.g. from older clients) match too.
    """
    repos_on_disk = set(map(hf_repo_from_ref, installed))

    def _entry(model: CatalogModel) -> ModelCatalogEntry:
        return ModelCatalogEntry(
            name=model.display_name,
            size_gb=model.size_gb,
            min_ram_gb=model.min_ram_gb,
            description=model.description,
            installed=model.hf_repo in repos_on_disk,
        )

    rows = [_entry(model) for model in featured]
    return ModelCatalogSection(active=active, catalog=rows, installed=sorted(installed))
async def list_models() -> ModelsResponse:
    """Build the per-role catalogs (chat, embedding, vision, reranker) plus active picks.

    The installed set is deliberately not filtered by role. One installed ref
    can mark a model as installed in every section where it appears.
    """
    everything_installed = set(get_services().model_manager.list_installed())

    def section(featured: tuple[CatalogModel, ...], active: str) -> ModelCatalogSection:
        return _catalog_section(featured, active, everything_installed)

    return ModelsResponse(
        chat=section(FEATURED_CHAT, cfg.chat_model),
        embedding=section(FEATURED_EMBEDDING, cfg.embedding_model),
        vision=section(FEATURED_VISION, cfg.vision_model),
        reranker=section(FEATURED_RERANK, cfg.reranker_model),
    )
async def _set_model(
    field: Literal["chat_model", "embedding_model", "vision_model", "reranker_model"],
    model: str,
) -> SetModelResponse:
    """Write one role's model setting through ``apply_settings_update`` and echo it."""
    update = {field: model}
    apply_settings_update(update)
    return SetModelResponse(model=model)
def _resolve_via_catalog(model: str, available: set[str]) -> str | None:
    """Map a catalog-known *model* to an available quant of its ``hf_repo``.

    Candidates are scanned in sorted order, so the pick is deterministic when
    several quants are present. Returns None when the catalog has no entry
    for *model* or no available ref lives under its repo.
    """
    entry = find_catalog_entry(model)
    if entry is None:
        return None
    prefix = f"{entry.hf_repo}/"
    for candidate in sorted(available):
        if candidate.startswith(prefix):
            return candidate
    return None
def _resolve_via_installed_repo(model: str, available: set[str]) -> str | None:
    """Map a bare ``hf_repo`` to the quant the registry has installed, featured or not.

    The registry's ref is returned only when the provider also lists it, so
    remote-only provider modes never activate a model they can't serve.
    """
    if is_bare_hf_repo(model):
        installed_ref = get_services().registry.installed_ref_for_repo(model)
        if installed_ref in available:
            return installed_ref
    return None
def _resolve_via_parse(model: str, available: set[str]) -> str | None:
    """Keep a provider-prefixed *model* when its bare name appears in *available*.

    Providers list hosted models by bare name, while selections carry a
    routing prefix. Returning the original prefixed ref (not the bare name)
    preserves provider routing instead of falling through to the native
    check. Unparseable refs resolve to None.
    """
    try:
        parsed = parse_model_ref(model)
    except ValueError:
        return None
    else:
        visible = parsed.name in available
    return model if visible else None
def _resolve_via_provider_key(model: str) -> str | None:
    """Accept an API-provider ref outright when that provider's key is configured.

    Frontier models are found through ``discover_api_models``, not the
    provider's default ``list_models()``, so they never appear in the
    available set. With the key set, litellm routes the ref and validates
    the model name when it is actually called.
    """
    try:
        parsed = parse_model_ref(model)
    except ValueError:
        return None
    if not parsed.is_api:
        return None
    has_key = get_provider_api_key(parsed.provider) is not None
    return model if has_key else None
def _require_model_available(model: str) -> str:
    """Resolve *model* to its installed-and-routable form, or raise.

    Tries the following in order:

    1. An exact match in the provider's listing.
    2. A catalog ``hf_repo`` match.
    3. A registry-installed bare repo.
    4. A provider-prefixed ref whose bare name is listed.
    5. An API-provider ref with a configured key.

    Raises:
        ValueError: *model* is empty or no resolver accepts it.
    """
    message = f"Model '{model}' is not available. Pull it first or check the name."
    if not model:
        raise ValueError(message)
    available = set(get_services().provider.list_models())
    if model in available:
        return model
    # Lazily evaluated so later (possibly costlier) resolvers only run on a miss.
    resolvers = (
        lambda: _resolve_via_catalog(model, available),
        lambda: _resolve_via_installed_repo(model, available),
        lambda: _resolve_via_parse(model, available),
        lambda: _resolve_via_provider_key(model),
    )
    for resolve in resolvers:
        resolved = resolve()
        if resolved:
            return resolved
    raise ValueError(message)
def _build_task_to_field() -> dict[ModelTask, str]:
    """Derive the task -> config-field map by inverting ``_MODEL_FIELD_TO_TASK``.

    Deriving it, rather than hand-writing it, keeps both directions in sync.
    """
    inverted: dict[ModelTask, str] = {}
    for field_name, task_name in _MODEL_FIELD_TO_TASK.items():
        inverted[ModelTask(task_name)] = field_name
    return inverted


_TASK_TO_FIELD: dict[ModelTask, str] = _build_task_to_field()
def _require_model_for_task(model: str, expected: ModelTask, *, allow_empty: bool = False) -> str:
    """Check that *model* is available and that its catalog task matches *expected*.

    When *allow_empty* is set, a blank or whitespace-only *model* returns
    ``""``, which unsets the role. Otherwise availability is checked first.
    Then ``validate_model_task_assignment`` runs the catalog task check, so
    handlers and the config path share a single implementation.
    """
    is_unset_request = allow_empty and not model.strip()
    if is_unset_request:
        return ""
    resolved = _require_model_available(model)
    field_name = _TASK_TO_FIELD[expected]
    return validate_model_task_assignment(field_name, resolved, allow_bypass=False)
async def set_chat_model(model: str) -> SetModelResponse:
    """Make *model* the active chat model after availability and task checks."""
    checked = _require_model_for_task(model, ModelTask.CHAT)
    response = await _set_model("chat_model", checked)
    return response
async def set_embedding_model(model: str) -> SetModelResponse:
    """Switch the embedding model after availability and task checks.

    The response's ``reindex_required`` comes straight from
    ``apply_settings_update``. It is True when the new model differs from the
    one that built the persisted vector store. The caller is then expected to
    trigger a rebuild (``lilbee rebuild``, or ``POST /api/sync`` with
    ``force_rebuild=true``); search and ingest refuse to run until that
    happens. The settings boundary pins legacy store meta to the old ref
    before writing and computes ``reindex_required`` afterwards.
    """
    checked = _require_model_for_task(model, ModelTask.EMBEDDING)
    outcome = apply_settings_update({"embedding_model": checked})
    return SetModelResponse(model=checked, reindex_required=outcome.reindex_required)
async def set_vision_model(model: str) -> SetModelResponse:
    """Set the vision OCR model; a blank *model* clears it, disabling vision OCR."""
    checked = _require_model_for_task(model, ModelTask.VISION, allow_empty=True)
    response = await _set_model("vision_model", checked)
    return response
async def set_reranker_model(model: str) -> SetModelResponse:
    """Set the reranker model; a blank *model* clears it, disabling reranking."""
    checked = _require_model_for_task(model, ModelTask.RERANK, allow_empty=True)
    response = await _set_model("reranker_model", checked)
    return response
async def models_show(model: str) -> ModelsShowResponse:
    """Return the provider's metadata/parameters for *model*.

    A falsy result from ``show_model`` (e.g. the model is unavailable) yields
    an empty ``ModelsShowResponse`` rather than an error.
    """
    details = get_services().provider.show_model(model) or {}
    return ModelsShowResponse(**details)
def _parse_source(source: str) -> ModelSource:
    """Parse a wire-level source string into ``ModelSource``.

    Raises:
        ValueError: *source* is not a valid ``ModelSource`` value.
    """
    parsed = ModelSource(source)
    return parsed
# Multiplier that converts a row's ``size_gb`` into bytes (1 << 30 == 1024**3).
_BYTES_PER_GB = 1 << 30


def _row_fit(enriched: EnrichedModel, available_bytes: int | None) -> FitLevel | None:
    """Return how well *enriched* fits in host memory, or None when it can't be judged.

    Returns None when host memory is unknown, when the row is not a native
    model, or when the row has no positive size.
    """
    if (
        available_bytes is None
        or enriched.source != ModelSource.NATIVE.value
        or enriched.size_gb <= 0
    ):
        return None
    row_bytes = int(enriched.size_gb * _BYTES_PER_GB)
    return compute_fit(row_bytes, available_bytes).level
def _families_by_repo() -> dict[str, ModelFamily]:
    """Map every featured variant's ``hf_repo`` to the ModelFamily containing it.

    Used for size-variant lookup. If two families listed the same repo, the
    family iterated last would win.
    """
    return {
        variant.hf_repo: family
        for family in get_families()
        for variant in family.variants
    }
def _row_size_variants(
    enriched: EnrichedModel, families_by_repo: dict[str, ModelFamily]
) -> list[SizeVariantInfo]:
    """Return the size-variant strip for *enriched*'s family, or [] when it has none."""
    family = families_by_repo.get(enriched.hf_repo)
    return [] if family is None else family_size_variants(family)
def _build_catalog_entry(
    enriched: EnrichedModel,
    *,
    available_bytes: int | None,
    families_by_repo: dict[str, ModelFamily],
) -> CatalogEntryResponse:
    """Convert one enriched catalog model into a ``CatalogEntryResponse`` row.

    Catalog fields are copied through unchanged. ``fit`` is derived from host
    memory, and ``size_variants`` from the featured-family index.
    """
    fit = _row_fit(enriched, available_bytes)
    variants = _row_size_variants(enriched, families_by_repo)
    return CatalogEntryResponse(
        hf_repo=enriched.hf_repo,
        gguf_filename=enriched.gguf_filename,
        task=enriched.task,
        display_name=enriched.display_name,
        param_count=enriched.param_count,
        size_gb=enriched.size_gb,
        min_ram_gb=enriched.min_ram_gb,
        description=enriched.description,
        quality_tier=enriched.quality_tier,
        featured=enriched.featured,
        downloads=enriched.downloads,
        installed=enriched.installed,
        source=enriched.source,
        fit=fit,
        size_variants=variants,
        architecture=enriched.architecture,
        compat=enriched.compat,
    )
def _hosted_entry(rm: RemoteModel, source: ModelSource) -> CatalogEntryResponse:
    """Describe a discovered hosted model as a selectable, download-free catalog row.

    Hosted rows are always reported as installed and carry no size or fit
    data. Their ``hf_repo`` is the provider-prefixed remote ref. Only
    frontier rows get a ``key_status`` (READY); other sources leave it None.
    """
    key_status = KeyStatus.READY if source is ModelSource.FRONTIER else None
    return CatalogEntryResponse(
        hf_repo=format_remote_ref(rm.name, rm.provider),
        gguf_filename="",
        task=rm.task,
        display_name=rm.name,
        param_count=rm.parameter_size,
        size_gb=0,
        min_ram_gb=0,
        description="",
        quality_tier="",
        featured=False,
        downloads=0,
        installed=True,
        source=source,
        fit=None,
        size_variants=[],
        provider=rm.provider,
        key_status=key_status,
    )
# Seconds a discovered hosted-row list stays fresh (compared against time.monotonic()).
_HOSTED_MODELS_TTL = 60


class _HostedModelsCache:
    """Single-slot TTL cache for discovered hosted rows.

    Holds one (key, rows) pair. A lookup hits only when the key matches and
    the entry is younger than ``_HOSTED_MODELS_TTL`` seconds on the monotonic
    clock. Keeping the state on an instance avoids a module-level mutable
    global.
    """

    def __init__(self) -> None:
        self._time: float = 0.0
        self._key: str = ""
        self._result: list[CatalogEntryResponse] | None = None

    def get(self, key: str) -> list[CatalogEntryResponse] | None:
        """Return the cached rows for *key* when present and fresh, else None."""
        if self._result is None or key != self._key:
            return None
        age = time.monotonic() - self._time
        return self._result if age < _HOSTED_MODELS_TTL else None

    def set(self, key: str, result: list[CatalogEntryResponse]) -> None:
        """Replace the cached entry with *result* for *key*, stamped with the current time."""
        self._time = time.monotonic()
        self._key = key
        self._result = result


_hosted_cache = _HostedModelsCache()
def _discover_hosted_sync() -> list[CatalogEntryResponse]:
    """Collect every hosted row (frontier APIs plus the configured local server), unfiltered.

    This blocks on discovery, so call it via ``asyncio.to_thread``. Frontier
    rows come first, then local-server rows. A local row takes its detected
    server's ``ModelSource`` (Ollama or LM Studio), or ``ModelSource.REMOTE``
    when the label matches no known server. Both discovery calls fail soft
    when no keys are set or the endpoint is unreachable, so the catalog
    degrades to native-only.
    """
    frontier = [
        _hosted_entry(remote, ModelSource.FRONTIER)
        for discovered in discover_api_models().values()
        for remote in discovered
    ]
    local: list[CatalogEntryResponse] = []
    for remote in classify_all_remote_models():
        server = local_server_for_label(remote.provider)
        origin = ModelSource.REMOTE if server is None else ModelSource(server.key)
        local.append(_hosted_entry(remote, origin))
    return frontier + local
def _hosted_cache_key() -> str:
    """Build the hosted-rows cache key from every setting that changes discovery output.

    Provider-key fields are read generically from ``PROVIDER_KEYS`` rather
    than listed by hand. Adding a provider therefore can't silently reuse a
    stale cache entry.
    """
    key_parts: list[str] = []
    for _provider, attr, *_rest in PROVIDER_KEYS:
        key_parts.append(getattr(cfg, attr) or "")
    joined_keys = ":".join(key_parts)
    return f"{cfg.ollama_base_url}:{cfg.lm_studio_base_url}:{joined_keys}"
async def _collect_hosted_entries(
    *, task: ModelTask | None, search: str
) -> list[CatalogEntryResponse]:
    """Return hosted catalog rows matching *task* and *search*.

    Discovery blocks, so it runs in a worker thread. Its unfiltered result is
    memoised in ``_hosted_cache`` under ``_hosted_cache_key()``. The task
    filter is an exact match. The search filter is a case-insensitive
    substring match on ``display_name``.
    """
    cache_key = _hosted_cache_key()
    discovered = _hosted_cache.get(cache_key)
    if discovered is None:
        discovered = await asyncio.to_thread(_discover_hosted_sync)
        _hosted_cache.set(cache_key, discovered)

    matches = discovered
    if task is not None:
        matches = [row for row in matches if row.task == task]
    if search:
        lowered = search.lower()
        matches = [row for row in matches if lowered in row.display_name.lower()]
    return matches
async def models_catalog(
    task: str | None = None,
    search: str = "",
    size: str | None = None,
    installed: bool | None = None,
    featured: bool | None = None,
    sort: str = "featured",
    limit: int = 20,
    offset: int = 0,
) -> ModelsCatalogResponse:
    """Return one page of the model catalog, annotated with install status.

    The closed-set params (``task``, ``size``, ``sort``) are parsed into their
    enums at the HTTP boundary. Unknown values therefore raise instead of
    silently short-circuiting a filter.

    Hosted rows (frontier APIs plus local servers) need no download and are
    selectable as-is. They are prepended only on the first page
    (``offset == 0``). They are skipped when ``featured`` is truthy or
    ``installed is False``, and are added to ``total`` on that page.
    """
    task_filter = ModelTask(task) if task else None
    size_filter = CatalogSize(size) if size else None
    page = get_catalog(
        task=task_filter,
        search=search,
        size=size_filter,
        installed=installed,
        featured=featured,
        sort=CatalogSort(sort),
        limit=limit,
        offset=offset,
    )

    installed_refs = {entry.ref for entry in get_services().registry.list_installed()}
    enriched_rows = enrich_catalog(page, installed_refs)
    memory_budget = available_memory_for_fit()
    family_index = _families_by_repo()
    native_rows = [
        _build_catalog_entry(row, available_bytes=memory_budget, families_by_repo=family_index)
        for row in enriched_rows
    ]

    show_hosted = offset == 0 and not featured and installed is not False
    hosted_rows: list[CatalogEntryResponse] = (
        await _collect_hosted_entries(task=task_filter, search=search) if show_hosted else []
    )

    # NOTE(review): on later pages (offset > 0) hosted rows are neither returned
    # nor counted, so ``total`` differs between page one and subsequent pages.
    # Confirm clients don't rely on ``total`` being stable while paginating.
    return ModelsCatalogResponse(
        total=page.total + len(hosted_rows),
        limit=page.limit,
        offset=page.offset,
        has_more=page.has_more,
        models=[*hosted_rows, *native_rows],
    )
async def models_installed() -> ModelsInstalledResponse:
    """List installed models with their granular source and canonical ref.

    Models whose source the manager can't determine are reported as
    ``ModelSource.REMOTE``.
    """
    manager = get_services().model_manager

    def _entry_for(name: str) -> InstalledModelEntry:
        origin = manager.get_source(name) or ModelSource.REMOTE
        return InstalledModelEntry(name=canonical_local_ref(name, origin.value), source=origin)

    return ModelsInstalledResponse(models=[_entry_for(name) for name in manager.list_installed()])
async def models_pull(
    model: str, *, source: str = "native", allow_unsupported: bool = False
) -> AsyncGenerator[str, None]:
    """Yield SSE progress events while pulling a model in real time.

    The blocking ``manager.pull`` runs in a worker thread. Its byte-progress
    callback and its terminal outcome (an error event, then a ``None``
    marker) are handed back to the event loop through ``sse.queue``, and
    ``sse.drain`` re-yields them. Progress events stop once ``sse.cancel``
    is set (presumably by ``SseStream`` on client disconnect; confirm).

    Architecture compatibility is pre-checked before the SSE stream is
    opened, so refusals surface as an HTTP 409 from the route rather than an
    in-stream error.

    Args:
        model: Model reference to pull.
        source: ``ModelSource`` value. Only ``"native"`` pulls get the arch
            pre-check.
        allow_unsupported: Skip the arch pre-check and pass the flag through
            to ``manager.pull``.

    Raises:
        HTTPException: 409 ``unsupported_arch`` when the native pre-check fails.
        ValueError: *source* is not a valid ``ModelSource`` value.
        Because this is an async generator, both are raised when it is first
        iterated, not when it is called.
    """
    from litestar.exceptions import HTTPException

    from lilbee.catalog.compat import SUPPORTED_ARCHS, UnsupportedArchError

    manager = get_services().model_manager
    src = _parse_source(source)

    # Only native pulls are arch-checked. The check is a blocking call, so it
    # runs off the event loop.
    if src is ModelSource.NATIVE and not allow_unsupported:
        try:
            await asyncio.to_thread(manager._enforce_arch_compat, model)
        except UnsupportedArchError as exc:
            raise HTTPException(
                status_code=409,
                detail="unsupported_arch",
                extra={
                    "code": "unsupported_arch",
                    "arch": exc.architecture,
                    "ref": exc.ref,
                    # Only the first five (alphabetical) examples, plus the total count.
                    "supported_examples": sorted(SUPPORTED_ARCHS)[:5],
                    "total_supported": len(SUPPORTED_ARCHS),
                },
            ) from exc

    sse = SseStream()

    def _pull_blocking() -> None:
        # Runs in a worker thread. Every hand-off to the event loop goes through
        # call_soon_threadsafe; sse.queue is presumably an asyncio.Queue, which
        # is not thread-safe.
        def _on_bytes(downloaded: int, total: int) -> None:
            # NOTE(review): returning early only suppresses progress events.
            # Nothing visible here aborts manager.pull, so the download may keep
            # running after a disconnect. Confirm whether SseStream/manager stop it.
            if sse.cancel.is_set():
                return
            payload = sse_event(SseEvent.PROGRESS, {"current": downloaded, "total": total})
            sse.loop.call_soon_threadsafe(sse.queue.put_nowait, payload)

        try:
            manager.pull(
                model,
                src,
                on_bytes=_on_bytes,
                allow_unsupported=allow_unsupported,
            )
        except Exception as exc:
            # Report pull failures as an in-stream SSE error event.
            sse.loop.call_soon_threadsafe(sse.queue.put_nowait, sse_error(str(exc)))
        finally:
            # Queued after success or failure; presumably the end-of-stream
            # sentinel that sse.drain waits for (TODO confirm).
            sse.loop.call_soon_threadsafe(sse.queue.put_nowait, None)

    task = asyncio.ensure_future(asyncio.to_thread(_pull_blocking))
    async for event in sse.drain(task, "Model pull stream"):
        yield event
async def models_delete(model: str, *, source: str = "native") -> ModelsDeleteResponse:
    """Delete an installed model and report the outcome.

    lilbee removes only native models it downloaded. Removing a read-only
    local-server model (Ollama, LM Studio) is refused with a 409.

    Args:
        model: Model reference to remove.
        source: ``ModelSource`` value identifying where *model* lives.

    Returns:
        ``ModelsDeleteResponse`` with the manager's ``deleted`` result and the
        model name. ``freed_gb`` is always ``0.0``: freed space is not
        measured here.

    Raises:
        HTTPException: 409 when ``manager.remove`` rejects the request with
            ``ValueError``.
        ValueError: *source* is not a valid ``ModelSource`` value. This is
            raised before removal is attempted and is not mapped to a 409.
    """
    from litestar.exceptions import HTTPException

    manager = get_services().model_manager
    src = _parse_source(source)
    try:
        deleted = manager.remove(model, src)
    except ValueError as exc:
        raise HTTPException(status_code=409, detail=str(exc)) from exc
    return ModelsDeleteResponse(deleted=deleted, model=model, freed_gb=0.0)
# Seconds a cached external model listing stays fresh (compared against time.monotonic()).
_EXTERNAL_MODELS_TTL = 60


class _ExternalModelsCache:
    """Single-slot TTL cache for external model listings.

    Holds one (key, response) pair. A lookup hits only when the key matches
    and the entry is younger than ``_EXTERNAL_MODELS_TTL`` seconds on the
    monotonic clock. Keeping the state on an instance avoids a module-level
    mutable global.
    """

    def __init__(self) -> None:
        self._time: float = 0.0
        self._key: str = ""
        self._result: ExternalModelsResponse | None = None

    def get(self, key: str) -> ExternalModelsResponse | None:
        """Return the cached listing for *key* when present and fresh, else None."""
        # Explicit ``is None`` check, matching the hosted-row cache. A cached
        # value's validity must not depend on its truthiness.
        if self._result is None or key != self._key:
            return None
        if time.monotonic() - self._time >= _EXTERNAL_MODELS_TTL:
            return None
        return self._result

    def set(self, key: str, result: ExternalModelsResponse) -> None:
        """Replace the cached entry with *result* for *key*, stamped with the current time."""
        self._time = time.monotonic()
        self._key = key
        self._result = result


_external_cache = _ExternalModelsCache()
async def list_external_models() -> ExternalModelsResponse:
    """Return the provider's model listing from its ``list_models()`` API.

    Successful listings are cached in ``_external_cache`` for
    ``_EXTERNAL_MODELS_TTL`` seconds. The cache key is built from the two
    local-server base URLs and ``cfg.llm_api_key``. A failure is logged and
    returned as an empty listing with ``error`` set, rather than raised, and
    is not cached.
    """
    cache_key = f"{cfg.ollama_base_url}:{cfg.lm_studio_base_url}:{cfg.llm_api_key or ''}"
    cached = _external_cache.get(cache_key)
    if cached is not None:
        return cached

    try:
        # Blocking provider call; run it off the event loop.
        models = await asyncio.to_thread(get_services().provider.list_models)
        result = ExternalModelsResponse(models=models)
    except Exception as exc:  # best-effort boundary: report the failure in the body
        log.warning("Failed to list external models: %s", exc)
        return ExternalModelsResponse(models=[], error=str(exc))
    _external_cache.set(cache_key, result)
    return result