Coverage for src / lilbee / cli / tui / screens / catalog_utils.py: 100%
141 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-06-28 01:01 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-06-28 01:01 +0000
"""Catalog data types, row builders, and formatting helpers.

The catalog renders two distinct row shapes side by side: locally
installed / installable GGUFs (``LocalCatalogRow``) and cloud chat
models accessed through a provider's API (``FrontierCatalogRow``).
``LocalCatalogRow`` is also used for models reported as a
``RemoteModel`` (see ``remote_to_row``). The two shapes share enough
surface area that grouping and search reuse the same helpers. They
carry different metadata and pull from different sources, though, so
they are separate types under a sealed ``CatalogRow`` union rather than
a single optional-fields dataclass.
"""
12from __future__ import annotations
14import re
15from collections.abc import Callable
16from dataclasses import dataclass, field
17from enum import StrEnum
18from typing import Any, Literal
20from lilbee.catalog import PARAM_COUNT_RE, CatalogModel, ModelFamily, ModelVariant, extract_quant
21from lilbee.catalog.types import KeyStatus, ModelCompat, ModelTask
22from lilbee.modelhub.model_manager import RemoteModel
23from lilbee.providers.model_ref import format_remote_ref
24from lilbee.runtime.hardware import FitChip
27class CatalogRowKind(StrEnum):
28 """Discriminator for the sealed CatalogRow union."""
30 LOCAL = "local"
31 FRONTIER = "frontier"
# Tab ids for the six-tab catalog shell:
#   * Discover: the curated landing page.
#   * Chat / Embed / Vision / Rerank: each renders a single per-task grid.
#   * Library: the personal view of installed local models plus
#     activated cloud APIs.
TAB_DISCOVER = "discover"
TAB_CHAT = "chat"
TAB_EMBED = "embed"
TAB_VISION = "vision"
TAB_RERANK = "rerank"
TAB_LIBRARY = "library"

# The four task tabs, in display order.
TASK_TAB_IDS: tuple[str, ...] = (TAB_CHAT, TAB_EMBED, TAB_VISION, TAB_RERANK)

# Display order of every tab. Order matters: the numbered shortcuts 1-6
# follow this tuple.
ALL_TAB_IDS: tuple[str, ...] = (TAB_DISCOVER, *TASK_TAB_IDS, TAB_LIBRARY)
# ModelTask -> id of the per-task tab that renders rows for that task.
# Featured items are still pinned at the top of their own task tab;
# browsing across tasks is what the Discover landing is for.
TASK_TO_TAB_ID: dict[ModelTask, str] = {
    ModelTask.CHAT: TAB_CHAT,
    ModelTask.EMBEDDING: TAB_EMBED,
    ModelTask.VISION: TAB_VISION,
    ModelTask.RERANK: TAB_RERANK,
}
# Reverse lookup: per-task tab id -> ModelTask.
TAB_ID_TO_TASK: dict[str, ModelTask] = dict(
    zip(TASK_TO_TAB_ID.values(), TASK_TO_TAB_ID.keys())
)
67class SourceMode(StrEnum):
68 """Per-task tab filter for which row source backs the visible cards.
70 LOCAL hides frontier rows (the default; mirrors the legacy mega-grid
71 behavior). CLOUD shows only frontier rows for the active task. BOTH
72 unions them so users can compare a local Llama with a cloud Llama
73 side by side. Cycled via the ``c`` keybinding.
74 """
76 LOCAL = "local"
77 CLOUD = "cloud"
78 BOTH = "both"
81_SOURCE_MODE_CYCLE: tuple[SourceMode, ...] = (SourceMode.LOCAL, SourceMode.CLOUD, SourceMode.BOTH)
84def next_source_mode(current: SourceMode) -> SourceMode:
85 """Return the next SourceMode in the LOCAL -> CLOUD -> BOTH -> LOCAL cycle."""
86 idx = _SOURCE_MODE_CYCLE.index(current)
87 return _SOURCE_MODE_CYCLE[(idx + 1) % len(_SOURCE_MODE_CYCLE)]
def task_to_tab_id(task: ModelTask | str) -> str:
    """Return the per-task tab id for a ModelTask or its string value.

    Accepts the string form because catalog rows carry ``task`` as a
    raw string (matching how HF API and the row builders return it),
    while the routing tables are keyed on the enum.

    Raises:
        KeyError: with message ``"unknown task: ..."`` when *task* is not a
            valid ModelTask value, or is a ModelTask member that has no
            entry in ``TASK_TO_TAB_ID``.
    """
    try:
        resolved = task if isinstance(task, ModelTask) else ModelTask(task)
        return TASK_TO_TAB_ID[resolved]
    except (KeyError, ValueError) as exc:
        # Give both failure modes (unparseable string, unmapped enum member)
        # the same readable KeyError instead of a bare one for the enum path.
        raise KeyError(f"unknown task: {task!r}") from exc
105# SI thresholds for short download counts ("12.3M" / "456K") and binary
106# thresholds for sizes ("4.2 GB" / "768 MB").
107_DOWNLOADS_PER_M = 1_000_000
108_DOWNLOADS_PER_K = 1_000
109_MB_PER_GB = 1024
@dataclass(frozen=True)
class SizeVariant:
    """Immutable description of one size/quant chip in a family card's strip.

    Attributes:
        label: Text rendered inline on the card, e.g. ``"8B Q4_K_M"``.
        quant: Quantization name of this variant.
        size_gb: Size of this variant in gigabytes.
        ref: Canonical pull target for exactly this variant.
        fit: Fit chip computed against the host's available memory, or
            ``None`` while the hardware probe has not run yet.
    """

    label: str
    quant: str
    size_gb: float
    ref: str
    fit: FitChip | None = None
@dataclass
class LocalCatalogRow:
    """Catalog row for a local GGUF (installable or installed) or a RemoteModel.

    ``name`` is the human-readable label (e.g. "Qwen3 0.6B"). ``ref`` is
    the canonical identifier persisted to config: ``hf_repo`` for catalog
    rows, ``hf_repo/filename`` for installed native models, and the
    provider's ref shape for remote/API rows. For a family-aggregated
    row, ``size_variants`` lists every quant so the card can draw an
    inline chip strip and the detail drawer can show all sizes; ``fit``
    is the chip for the row's primary variant.
    """

    # Pre-formatted display columns; the row builders use "--" when a
    # params/size/quant/downloads value is unknown.
    name: str
    task: str
    params: str
    size: str
    quant: str
    downloads: str
    featured: bool
    installed: bool
    # Raw values that the Downloads / Size columns sort on.
    sort_downloads: int
    sort_size: float
    ref: str = ""
    backend: str = ""
    # Source objects the row was built from; each builder sets whichever apply.
    variant: ModelVariant | None = None
    family: ModelFamily | None = None
    catalog_model: CatalogModel | None = None
    remote_model: RemoteModel | None = None
    size_variants: list[SizeVariant] = field(default_factory=list)
    fit: FitChip | None = None
    compat: ModelCompat = ModelCompat.UNKNOWN
    # Discriminator for the CatalogRow union; defaults to LOCAL.
    kind: Literal[CatalogRowKind.LOCAL] = CatalogRowKind.LOCAL
@dataclass
class FrontierCatalogRow:
    """Catalog row for a chat model served through a cloud provider's API.

    The model lives on the provider's infrastructure, so the local-only
    fields (on-disk size, quant, GGUF filename) are deliberately absent.
    """

    name: str
    ref: str
    task: str
    # Display label for the provider, e.g. "Gemini" / "OpenAI" / "Anthropic".
    provider: str
    # Canonical provider id used for the API key field, e.g. "gemini".
    provider_id: str
    key_status: KeyStatus
    kind: Literal[CatalogRowKind.FRONTIER] = CatalogRowKind.FRONTIER
# Sealed union of every catalog row type, discriminated on ``.kind``.
# Dispatch by comparing (or pattern-matching on) ``row.kind`` instead of
# using isinstance, so adding a new row type only touches one place.
CatalogRow = LocalCatalogRow | FrontierCatalogRow
def parse_param_label(name: str) -> str:
    """Extract the parameter-count label from a model name.

    Returns the first ``PARAM_COUNT_RE`` capture upper-cased (e.g. ``'8B'``,
    ``'0.6B'``), or ``"--"`` when the name carries no parameter count.
    """
    # PARAM_COUNT_RE is already imported at module level; no local import needed.
    match = PARAM_COUNT_RE.search(name)
    return match.group(1).upper() if match else "--"
def _format_downloads(n: int) -> str:
    """Render a download count compactly: ``"12.3M"``, ``"456K"``, or ``"789"``.

    Counts just under one million that would round to ``"1000K"`` are
    promoted to the millions form (``"1.0M"``) instead.
    """
    if n >= _DOWNLOADS_PER_M:
        return f"{n / _DOWNLOADS_PER_M:.1f}M"
    if n >= _DOWNLOADS_PER_K:
        thousands = f"{n / _DOWNLOADS_PER_K:.0f}"
        # 999_500..999_999 rounds to "1000"; show "1.0M" rather than "1000K".
        if int(thousands) * _DOWNLOADS_PER_K < _DOWNLOADS_PER_M:
            return f"{thousands}K"
        return f"{n / _DOWNLOADS_PER_M:.1f}M"
    return str(n)
def _format_size_mb(size_mb: int) -> str:
    """Format a size in MB as ``"768 MB"`` or ``"4.2 GB"``, or ``"--"`` when unknown.

    Non-positive sizes are treated as unknown, matching ``format_size_gb``,
    instead of rendering nonsense such as ``"-5 MB"``.
    """
    if size_mb <= 0:
        return "--"
    if size_mb >= _MB_PER_GB:
        return f"{size_mb / _MB_PER_GB:.1f} GB"
    return f"{size_mb} MB"
def format_size_gb(size_gb: float) -> str:
    """Render *size_gb* with one decimal and a ``GB`` suffix; ``"--"`` if it is <= 0."""
    return "--" if size_gb <= 0 else f"{size_gb:.1f} GB"
def _is_param_count(label: str) -> bool:
    """Report whether all of *label* is a parameter count such as '8B' or '0.6B'."""
    return PARAM_COUNT_RE.fullmatch(label) is not None
def family_to_size_variants(family: ModelFamily) -> list[SizeVariant]:
    """Build the size-chip strip for a featured ModelFamily.

    Variants are returned in increasing ``size_mb`` order so the chip
    strip reads compact-to-large left-to-right. ``fit`` is left ``None``;
    the catalog screen fills it in once the hardware probe has run.
    """
    return [
        SizeVariant(
            label=_size_variant_label(variant),
            quant=variant.quant or "--",
            # Use the shared MB->GB divisor rather than a bare 1024.
            size_gb=variant.size_mb / _MB_PER_GB,
            ref=variant.hf_repo,
            fit=None,
        )
        for variant in sorted(family.variants, key=lambda v: v.size_mb)
    ]
245def _size_variant_label(v: ModelVariant) -> str:
246 """Render a compact label for a ModelVariant chip (e.g. '8B Q4_K_M')."""
247 pieces = [p for p in (v.param_count, v.quant) if p]
248 return " ".join(pieces) if pieces else "--"
def variant_to_row(v: ModelVariant, f: ModelFamily, installed: bool) -> LocalCatalogRow:
    """Convert a featured ModelVariant and its family into a LocalCatalogRow.

    The display name is ``"<family name> <param_count>"`` unless the family
    name already ends with the param count. ``params`` keeps the variant's
    param count only when it looks like one (e.g. ``'8B'``); otherwise,
    including when it is empty or missing, it is ``"--"``.
    """
    # Avoid duplicating the param count when the family name already ends with it.
    if v.param_count and not f.name.endswith(v.param_count):
        label = f"{f.name} {v.param_count}"
    else:
        label = f.name
    # param_count may be falsy (see the check above); don't hand None to the regex.
    params = v.param_count if v.param_count and _is_param_count(v.param_count) else "--"
    return LocalCatalogRow(
        name=label,
        task=f.task,
        params=params,
        size=_format_size_mb(v.size_mb),
        quant=v.quant or "--",
        downloads="--",
        featured=True,
        installed=installed,
        sort_downloads=0,
        # Use the shared MB->GB divisor rather than a bare 1024.
        sort_size=v.size_mb / _MB_PER_GB,
        ref=v.hf_repo,
        backend="native",
        variant=v,
        family=f,
    )
def catalog_to_row(m: CatalogModel, installed: bool) -> LocalCatalogRow:
    """Build a LocalCatalogRow from a CatalogModel entry.

    A quant that cannot be extracted from the GGUF filename, and a download
    count that is not positive, both render as ``"--"``. The raw
    ``downloads`` and ``size_gb`` values are kept on the row for sorting.
    """
    quant_label = extract_quant(m.gguf_filename) or "--"
    downloads_label = _format_downloads(m.downloads) if m.downloads > 0 else "--"
    return LocalCatalogRow(
        name=m.display_name,
        task=m.task,
        params=parse_param_label(m.display_name),
        size=format_size_gb(m.size_gb),
        quant=quant_label,
        downloads=downloads_label,
        featured=m.featured,
        installed=installed,
        sort_downloads=m.downloads,
        sort_size=m.size_gb,
        ref=m.ref,
        backend="native",
        catalog_model=m,
        compat=m.compat,
    )
def remote_to_row(rm: RemoteModel) -> LocalCatalogRow:
    """Build a LocalCatalogRow for a model reported as a RemoteModel.

    The row is always marked installed and has no size, quant or download
    data. Its ``ref`` is the canonical ``provider/name`` form (built by
    ``format_remote_ref``) so it round-trips through ``Config.chat_model``'s
    validator without a per-call-site fixup.
    """
    placeholder = "--"
    return LocalCatalogRow(
        name=rm.name,
        task=rm.task,
        params=rm.parameter_size or placeholder,
        size=placeholder,
        quant=placeholder,
        downloads=placeholder,
        featured=False,
        installed=True,
        sort_downloads=0,
        sort_size=0.0,
        ref=format_remote_ref(rm.name, rm.provider),
        backend=rm.provider.lower(),
        remote_model=rm,
    )
def frontier_row_from_remote(
    rm: RemoteModel, *, provider_id: str, key_status: KeyStatus
) -> FrontierCatalogRow:
    """Wrap a discovered cloud chat model as a FrontierCatalogRow.

    *provider_id* and *key_status* come from the caller; every other field
    is read from *rm*. The ``ref`` is already in canonical ``provider/name``
    form, so callers can pass it straight to ``Config.chat_model``.
    """
    canonical_ref = format_remote_ref(rm.name, rm.provider)
    return FrontierCatalogRow(
        name=rm.name,
        ref=canonical_ref,
        task=rm.task,
        provider=rm.provider,
        provider_id=provider_id,
        key_status=key_status,
    )
# Sort-key extractor per column header. These apply to LocalCatalogRow
# only: every column except Name reads a field FrontierCatalogRow lacks,
# and the catalog screen sorts local and frontier rows separately before
# concatenating them.
SORT_KEYS: dict[str, Callable[[LocalCatalogRow], Any]] = dict(
    Name=lambda row: row.name.lower(),
    Task=lambda row: row.task,
    Backend=lambda row: row.backend.lower(),
    Params=lambda row: _param_sort_value(row.params),
    Size=lambda row: row.sort_size,
    Quant=lambda row: row.quant,
    Downloads=lambda row: row.sort_downloads,
)
354def _param_sort_value(params: str) -> float:
355 """Convert param label to sortable float (e.g. '8B' -> 8.0)."""
356 match = re.search(r"(\d+\.?\d*)", params)
357 return float(match.group(1)) if match else 0.0
def row_delete_id(row: CatalogRow) -> str | None:
    """Return the model_manager-compatible identifier for *row*, or ``None``.

    Local rows backed by a RemoteModel return its bare ``name``, because the
    Ollama HTTP API keys models by bare name while ``ref`` carries the
    canonical ``ollama/<name>`` chat_model form. Every other row, frontier
    rows included, returns its ``ref``. An empty identifier becomes ``None``.
    """
    # Frontier rows have no remote_model field, so check the kind first.
    backed_by_remote = row.kind != CatalogRowKind.FRONTIER and row.remote_model is not None
    identifier = row.remote_model.name if backed_by_remote else row.ref
    return identifier or None
def matches_search(row: CatalogRow, search: str) -> bool:
    """Return True if *row* matches *search* (case-, hyphen- and underscore-insensitive).

    Local rows match against name/task/params/quant/backend. Frontier rows
    match against name, provider label and provider id, so users can type
    "gemini" and see every Gemini model regardless of suffix. A search that
    is empty, or becomes empty once normalized and trimmed (only spaces,
    hyphens or underscores), matches every row.
    """
    if not search:
        return True
    # Trim so a stray leading/trailing separator doesn't hide every row.
    needle = _normalize_for_search(search).strip()
    if not needle:
        return True
    if row.kind == CatalogRowKind.FRONTIER:
        haystacks = (row.name, row.provider, row.provider_id)
    else:
        haystacks = (row.name, row.task, row.params, row.quant, row.backend)
    # Loop variable is not named `field`, which would shadow the dataclasses import.
    return any(needle in _normalize_for_search(text) for text in haystacks)
395def _normalize_for_search(value: str) -> str:
396 return value.lower().replace("-", " ").replace("_", " ")