Coverage for src / lilbee / cli / tui / screens / catalog_utils.py: 100%
143 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-05-15 20:55 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-05-15 20:55 +0000
1"""Catalog data types, row builders, and formatting helpers.
3The catalog renders two distinct row shapes side by side: locally
4installed / installable GGUFs (``LocalCatalogRow``) and cloud chat
5models accessed through a provider's API (``FrontierCatalogRow``).
6They share enough surface area that grouping and search reuse the
7same helpers, but they carry different metadata and pull from
8different sources, so they're separate types under a sealed
9``CatalogRow`` union rather than a single optional-fields dataclass.
10"""
12from __future__ import annotations
14import re
15from collections.abc import Callable
16from dataclasses import dataclass, field
17from enum import Enum, StrEnum
18from typing import Any, Literal
20from lilbee.catalog import PARAM_COUNT_RE, CatalogModel, ModelFamily, ModelVariant, extract_quant
21from lilbee.catalog.types import ModelTask
22from lilbee.modelhub.model_manager import RemoteModel
23from lilbee.providers.model_ref import format_remote_ref
24from lilbee.runtime.hardware import FitChip
class CatalogRowKind(StrEnum):
    """Discriminator for the sealed CatalogRow union.

    Each row dataclass pins its ``kind`` field to exactly one member, so
    callers can branch on ``row.kind`` rather than on ``isinstance``.
    """

    LOCAL = "local"  # Tag carried by LocalCatalogRow.
    FRONTIER = "frontier"  # Tag carried by FrontierCatalogRow.
# Tab IDs for the 6-tab catalog shell. Discover is the curated landing,
# the four task tabs each render a single per-task grid, Library is the
# personal-encyclopedia view of installed local + activated cloud APIs.
TAB_DISCOVER = "discover"
TAB_CHAT = "chat"
TAB_EMBED = "embed"
TAB_VISION = "vision"
TAB_RERANK = "rerank"
TAB_LIBRARY = "library"

# Order matters: numbered shortcuts 1-6 follow this sequence.
ALL_TAB_IDS: tuple[str, ...] = (
    TAB_DISCOVER,
    TAB_CHAT,
    TAB_EMBED,
    TAB_VISION,
    TAB_RERANK,
    TAB_LIBRARY,
)
# The per-task subset of ALL_TAB_IDS: every tab except Discover and Library.
TASK_TAB_IDS: tuple[str, ...] = (TAB_CHAT, TAB_EMBED, TAB_VISION, TAB_RERANK)
# Maps ModelTask -> the per-task tab id that renders its rows. Featured
# items still appear pinned at the top of their task tab; cross-task
# discovery happens on the Discover landing.
TASK_TO_TAB_ID: dict[ModelTask, str] = {
    ModelTask.CHAT: TAB_CHAT,
    ModelTask.EMBEDDING: TAB_EMBED,
    ModelTask.VISION: TAB_VISION,
    ModelTask.RERANK: TAB_RERANK,
}
# Reverse lookup (tab id -> ModelTask), derived from TASK_TO_TAB_ID so the
# two tables cannot drift apart.
TAB_ID_TO_TASK: dict[str, ModelTask] = {v: k for k, v in TASK_TO_TAB_ID.items()}
class SourceMode(StrEnum):
    """Per-task tab filter for which row source backs the visible cards.

    LOCAL hides frontier rows (the default; mirrors the legacy mega-grid
    behavior). CLOUD shows only frontier rows for the active task. BOTH
    unions them so users can compare a local Llama with a cloud Llama
    side by side. Cycled via the ``c`` keybinding.
    """

    LOCAL = "local"  # Local rows only.
    CLOUD = "cloud"  # Frontier (cloud) rows only.
    BOTH = "both"  # Local and frontier rows together.
# Order in which next_source_mode() steps through the modes; after the last
# entry it wraps back to the first.
_SOURCE_MODE_CYCLE: tuple[SourceMode, ...] = (SourceMode.LOCAL, SourceMode.CLOUD, SourceMode.BOTH)
def next_source_mode(current: SourceMode) -> SourceMode:
    """Advance *current* one step along LOCAL -> CLOUD -> BOTH, wrapping back to LOCAL."""
    following = _SOURCE_MODE_CYCLE.index(current) + 1
    # Stepping past the final entry wraps around to the start of the cycle.
    following %= len(_SOURCE_MODE_CYCLE)
    return _SOURCE_MODE_CYCLE[following]
def task_to_tab_id(task: ModelTask | str) -> str:
    """Return the per-task tab id for a ModelTask or its string value.

    Accepts the string form because catalog rows carry ``task`` as a
    raw string (matching how HF API and the row builders return it),
    while the routing tables are keyed on the enum.

    Raises:
        KeyError: *task* is not a valid ModelTask value, or it is a
            ModelTask that has no per-task tab. Both cases carry the
            same ``unknown task`` message.
    """
    try:
        # Enum lookup by value hands an existing member back unchanged, so
        # enum and string inputs share one lookup and one error path.
        return TASK_TO_TAB_ID[ModelTask(task)]
    except (KeyError, ValueError) as exc:
        raise KeyError(f"unknown task: {task!r}") from exc
# SI thresholds for short download counts ("12.3M" / "456K") and binary
# thresholds for sizes ("4.2 GB" / "768 MB").
_DOWNLOADS_PER_M = 1_000_000  # Threshold and divisor for the "M" suffix.
_DOWNLOADS_PER_K = 1_000  # Threshold and divisor for the "K" suffix.
_MB_PER_GB = 1024  # Megabytes per gigabyte; sizes at or above this render in GB.
@dataclass(frozen=True)
class SizeVariant:
    """One size/quant variant of a model family for the family-as-card strip.

    ``label`` renders inline on the card (e.g. "8B Q4_K_M"). ``ref`` is
    the canonical pull target for this specific variant. ``fit`` is the
    fit chip computed against the host's available memory; ``None``
    when the hardware probe has not yet run.
    """

    label: str
    # Quantization tag; family_to_size_variants stores "--" when the variant has none.
    quant: str
    # Size in gigabytes (family_to_size_variants derives it from the variant's size_mb).
    size_gb: float
    ref: str
    fit: FitChip | None = None
@dataclass
class LocalCatalogRow:
    """A row in the catalog backed by a local GGUF (installable or installed).

    ``name`` is the human-readable display label (e.g. "Qwen3 0.6B").
    ``ref`` is the canonical identifier used for config persistence:
    ``hf_repo`` for catalog rows, ``hf_repo/filename`` for installed
    native models, and the provider's ref shape for remote/API rows.
    ``size_variants`` carries every quant for a family-aggregated row,
    so the card can render an inline chip strip and the detail drawer
    can list all sizes. ``fit`` is the chip for the row's primary
    variant.
    """

    name: str
    task: str  # Task as a raw string; task_to_tab_id() maps it to a tab.
    params: str  # Parameter-count label such as "8B", or "--" when unknown.
    size: str  # Display size ("4.2 GB" / "768 MB"), or "--".
    quant: str  # Quantization label, or "--".
    downloads: str  # Display download count ("12.3M" / "456K"), or "--".
    featured: bool
    installed: bool
    sort_downloads: int  # Numeric download count behind the "Downloads" sort key.
    sort_size: float  # Size in GB behind the "Size" sort key.
    ref: str = ""
    backend: str = ""  # "native" for GGUF rows; the lower-cased provider for remote rows.
    variant: ModelVariant | None = None  # Populated by variant_to_row().
    family: ModelFamily | None = None  # Populated by variant_to_row().
    catalog_model: CatalogModel | None = None  # Populated by catalog_to_row().
    remote_model: RemoteModel | None = None  # Populated by remote_to_row().
    size_variants: list[SizeVariant] = field(default_factory=list)
    fit: FitChip | None = None
    # Union discriminator; always CatalogRowKind.LOCAL for this row type.
    kind: Literal[CatalogRowKind.LOCAL] = CatalogRowKind.LOCAL
class KeyStatus(Enum):
    """Whether the user has the API key needed to use a frontier model."""

    READY = "ready"  # The required API key is available.
    MISSING_KEY = "missing_key"  # The required API key has not been provided.
@dataclass
class FrontierCatalogRow:
    """A row in the catalog backed by a cloud provider's chat API.

    Frontier rows skip the local-model fields (size on disk, quant,
    GGUF filename) because they don't apply: the model lives on the
    provider's infrastructure.
    """

    name: str
    ref: str  # Canonical ``provider/name`` form (see frontier_row_from_remote()).
    task: str
    provider: str  # Display label, e.g. "Gemini" / "OpenAI" / "Anthropic".
    provider_id: str  # Canonical id used for the API key field, e.g. "gemini".
    key_status: KeyStatus
    # Union discriminator; always CatalogRowKind.FRONTIER for this row type.
    kind: Literal[CatalogRowKind.FRONTIER] = CatalogRowKind.FRONTIER
# Sealed union discriminated on .kind. Pattern-match (or compare) on row.kind
# to dispatch instead of isinstance, so adding a new row type is one place.
# Fields present on only one member (e.g. ``remote_model``, ``provider``) must
# only be read after checking ``kind``.
CatalogRow = LocalCatalogRow | FrontierCatalogRow
def parse_param_label(name: str) -> str:
    """Extract parameter count label from model name (e.g. '8B', '0.6B').

    Returns the first capture group of ``PARAM_COUNT_RE`` upper-cased, or
    ``"--"`` when the pattern does not occur anywhere in *name*.
    """
    # Uses the module-level PARAM_COUNT_RE import; no per-call import needed.
    match = PARAM_COUNT_RE.search(name)
    return match.group(1).upper() if match else "--"
def _format_downloads(n: int) -> str:
    """Abbreviate a download count for display (e.g. "12.3M", "456K", "87").

    Millions keep one decimal place and thousands are rounded to a whole
    number. Counts below the "K" threshold are returned verbatim.
    """
    if n >= _DOWNLOADS_PER_M:
        return f"{n / _DOWNLOADS_PER_M:.1f}M"
    if n >= _DOWNLOADS_PER_K:
        thousands = round(n / _DOWNLOADS_PER_K)
        if thousands * _DOWNLOADS_PER_K >= _DOWNLOADS_PER_M:
            # Rounding reached the next unit (999_600 would read "1000K");
            # promote it so the label reads "1.0M" instead.
            return f"{n / _DOWNLOADS_PER_M:.1f}M"
        return f"{thousands}K"
    return str(n)
def _format_size_mb(size_mb: int) -> str:
    """Render a megabyte count as "<n> MB" or "<x.y> GB"; zero means unknown ("--")."""
    if size_mb == 0:
        return "--"
    shown_in_gb = size_mb >= _MB_PER_GB
    return f"{size_mb / _MB_PER_GB:.1f} GB" if shown_in_gb else f"{size_mb} MB"
def format_size_gb(size_gb: float) -> str:
    """Render a gigabyte figure with one decimal place, or "--" when it is not positive."""
    return "--" if size_gb <= 0 else f"{size_gb:.1f} GB"
def _is_param_count(label: str) -> bool:
    """Report whether *label* in its entirety is a parameter count such as '8B' or '0.6B'."""
    return PARAM_COUNT_RE.fullmatch(label) is not None
def family_to_size_variants(family: ModelFamily) -> list[SizeVariant]:
    """Produce the size chips for a featured ModelFamily, smallest variant first.

    Ordering by ``size_mb`` makes the strip read compact-to-large from
    left to right. Every chip is created without a fit value (``None``);
    the catalog screen fills it in once the hardware probe has run.
    """
    return [
        SizeVariant(
            label=_size_variant_label(variant),
            quant=variant.quant or "--",
            size_gb=variant.size_mb / _MB_PER_GB,
            ref=variant.hf_repo,
        )
        for variant in sorted(family.variants, key=lambda item: item.size_mb)
    ]
def _size_variant_label(v: ModelVariant) -> str:
    """Build the chip text for one variant: param count then quant (e.g. '8B Q4_K_M')."""
    # filter(None, ...) drops whichever part is empty or missing.
    text = " ".join(filter(None, (v.param_count, v.quant)))
    return text or "--"
def variant_to_row(v: ModelVariant, f: ModelFamily, installed: bool) -> LocalCatalogRow:
    """Build the featured, native-backend LocalCatalogRow for one variant of a family."""
    param_count = v.param_count
    # Append the param count to the family name unless the name already ends with it.
    needs_suffix = bool(param_count) and not f.name.endswith(param_count)
    display_name = f"{f.name} {param_count}" if needs_suffix else f.name
    return LocalCatalogRow(
        name=display_name,
        task=f.task,
        # Only a label that fully matches the param-count pattern is shown.
        params=param_count if _is_param_count(param_count) else "--",
        size=_format_size_mb(v.size_mb),
        quant=v.quant or "--",
        downloads="--",
        featured=True,
        installed=installed,
        sort_downloads=0,
        sort_size=v.size_mb / _MB_PER_GB,
        ref=v.hf_repo,
        backend="native",
        variant=v,
        family=f,
    )
def catalog_to_row(m: CatalogModel, installed: bool) -> LocalCatalogRow:
    """Build the native-backend LocalCatalogRow for a single CatalogModel."""
    download_count = m.downloads
    # A non-positive count is shown as "--" rather than "0".
    downloads_label = _format_downloads(download_count) if download_count > 0 else "--"
    return LocalCatalogRow(
        name=m.display_name,
        task=m.task,
        params=parse_param_label(m.display_name),
        size=format_size_gb(m.size_gb),
        quant=extract_quant(m.gguf_filename) or "--",
        downloads=downloads_label,
        featured=m.featured,
        installed=installed,
        sort_downloads=download_count,
        sort_size=m.size_gb,
        ref=m.ref,
        backend="native",
        catalog_model=m,
    )
def remote_to_row(rm: RemoteModel) -> LocalCatalogRow:
    """Build the LocalCatalogRow for a RemoteModel.

    The row is always marked installed and never featured. Its ``ref``
    is the canonical ``provider/name`` form, so it round-trips through
    ``Config.chat_model``'s validator without a per-call-site fixup.
    """
    provider = rm.provider
    canonical_ref = format_remote_ref(rm.name, provider)
    return LocalCatalogRow(
        name=rm.name,
        task=rm.task,
        params=rm.parameter_size or "--",
        # Size, quant and download figures are all shown as "--" for remote rows.
        size="--",
        quant="--",
        downloads="--",
        featured=False,
        installed=True,
        sort_downloads=0,
        sort_size=0.0,
        ref=canonical_ref,
        backend=provider.lower(),
        remote_model=rm,
    )
def frontier_row_from_remote(
    rm: RemoteModel, *, provider_id: str, key_status: KeyStatus
) -> FrontierCatalogRow:
    """Build the FrontierCatalogRow for a discovered cloud chat model.

    ``ref`` is stored in the canonical ``provider/name`` form so callers
    can pass it straight to ``Config.chat_model`` without re-prefixing.
    """
    provider_label = rm.provider
    canonical_ref = format_remote_ref(rm.name, provider_label)
    return FrontierCatalogRow(
        name=rm.name,
        ref=canonical_ref,
        task=rm.task,
        provider=provider_label,
        provider_id=provider_id,
        key_status=key_status,
    )
# Column sort key extractors, keyed by column title. Local-only because
# every column except Name and Task reads a field FrontierCatalogRow
# doesn't carry, and the catalog screen sorts local and frontier rows
# independently before concat.
SORT_KEYS: dict[str, Callable[[LocalCatalogRow], Any]] = {
    "Name": lambda r: r.name.lower(),  # Case-insensitive.
    "Task": lambda r: r.task,
    "Backend": lambda r: r.backend.lower(),  # Case-insensitive.
    "Params": lambda r: _param_sort_value(r.params),  # Numeric, not lexical.
    "Size": lambda r: r.sort_size,  # Numeric field, not the display string.
    "Quant": lambda r: r.quant,
    "Downloads": lambda r: r.sort_downloads,  # Numeric field, not the display string.
}
def _param_sort_value(params: str) -> float:
    """Convert a param label to a sortable float, expressed in billions.

    The first number in *params* is scaled by the unit letter directly
    after it (case-insensitive), so '8B' -> 8.0, '0.6B' -> 0.6,
    '560M' -> 0.56 and '1T' -> 1000.0. A number with no recognized unit
    is returned unchanged, and a label with no number at all (e.g. '--')
    sorts as 0.0.
    """
    match = re.search(r"(\d+\.?\d*)([KMBT])?", params, re.IGNORECASE)
    if match is None:
        return 0.0
    value = float(match.group(1))
    unit = (match.group(2) or "").upper()
    # Without this scaling a 560M model would sort above an 8B one.
    if unit == "K":
        return value / 1_000_000
    if unit == "M":
        return value / 1_000
    if unit == "T":
        return value * 1_000
    return value
def row_delete_id(row: CatalogRow) -> str | None:
    """Return the model_manager-compatible identifier for *row*, or ``None`` if it is empty.

    Local rows that wrap a ``RemoteModel`` yield the bare
    ``RemoteModel.name``: the Ollama HTTP API keys models by bare name,
    whereas ``ref`` holds the canonical ``ollama/<name>`` chat_model
    form. Every other row, frontier rows included, yields its ``ref``.
    """
    # Only local rows carry ``remote_model``, so check the kind before reading it.
    is_local = row.kind != CatalogRowKind.FRONTIER
    if is_local and row.remote_model is not None:
        return row.remote_model.name or None
    return row.ref or None
def matches_search(row: CatalogRow, search: str) -> bool:
    """Tell whether *row* matches *search*, treating hyphens and underscores as spaces.

    An empty search matches every row. Local rows are tested on
    name/task/params/quant/backend. Frontier rows are tested on name,
    provider and provider id, so typing "gemini" surfaces every Gemini
    model regardless of suffix.
    """
    if not search:
        return True
    needle = _normalize_for_search(search)
    if row.kind == CatalogRowKind.FRONTIER:
        haystacks = (row.name, row.provider, row.provider_id)
    else:
        haystacks = (row.name, row.task, row.params, row.quant, row.backend)
    # A substring hit in any one searchable field is enough.
    for text in haystacks:
        if needle in _normalize_for_search(text):
            return True
    return False
def _normalize_for_search(value: str) -> str:
    """Lower-case *value* and turn every hyphen and underscore into a space."""
    return value.lower().translate(str.maketrans("-_", "  "))