Coverage for src / lilbee / cli / tui / screens / catalog_utils.py: 100%

143 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-05-15 20:55 +0000

1"""Catalog data types, row builders, and formatting helpers. 

2 

3The catalog renders two distinct row shapes side by side: locally 

4installed / installable GGUFs (``LocalCatalogRow``) and cloud chat 

5models accessed through a provider's API (``FrontierCatalogRow``). 

6They share enough surface area that grouping and search reuse the 

7same helpers, but they carry different metadata and pull from 

8different sources, so they're separate types under a sealed 

9``CatalogRow`` union rather than a single optional-fields dataclass. 

10""" 

11 

12from __future__ import annotations 

13 

14import re 

15from collections.abc import Callable 

16from dataclasses import dataclass, field 

17from enum import Enum, StrEnum 

18from typing import Any, Literal 

19 

20from lilbee.catalog import PARAM_COUNT_RE, CatalogModel, ModelFamily, ModelVariant, extract_quant 

21from lilbee.catalog.types import ModelTask 

22from lilbee.modelhub.model_manager import RemoteModel 

23from lilbee.providers.model_ref import format_remote_ref 

24from lilbee.runtime.hardware import FitChip 

25 

26 

27class CatalogRowKind(StrEnum): 

28 """Discriminator for the sealed CatalogRow union.""" 

29 

30 LOCAL = "local" 

31 FRONTIER = "frontier" 

32 

33 

# Tab IDs for the six-tab catalog shell. Discover is the curated landing
# page, each of the four task tabs renders a single per-task grid, and
# Library shows installed local models plus activated cloud APIs.
TAB_DISCOVER = "discover"
TAB_CHAT = "chat"
TAB_EMBED = "embed"
TAB_VISION = "vision"
TAB_RERANK = "rerank"
TAB_LIBRARY = "library"

# The per-task tabs, in display order.
TASK_TAB_IDS: tuple[str, ...] = (TAB_CHAT, TAB_EMBED, TAB_VISION, TAB_RERANK)

# Order matters: numbered shortcuts 1-6 follow this sequence
# (Discover first, then the task tabs in order, then Library).
ALL_TAB_IDS: tuple[str, ...] = (TAB_DISCOVER, *TASK_TAB_IDS, TAB_LIBRARY)

# ModelTask -> the per-task tab that renders its rows. Featured items stay
# pinned at the top of their own task tab; cross-task discovery happens on
# the Discover landing.
TASK_TO_TAB_ID: dict[ModelTask, str] = {
    ModelTask.CHAT: TAB_CHAT,
    ModelTask.EMBEDDING: TAB_EMBED,
    ModelTask.VISION: TAB_VISION,
    ModelTask.RERANK: TAB_RERANK,
}
# Reverse lookup: tab id -> ModelTask.
TAB_ID_TO_TASK: dict[str, ModelTask] = {
    tab_id: task for task, tab_id in TASK_TO_TAB_ID.items()
}

65 

66 

67class SourceMode(StrEnum): 

68 """Per-task tab filter for which row source backs the visible cards. 

69 

70 LOCAL hides frontier rows (the default; mirrors the legacy mega-grid 

71 behavior). CLOUD shows only frontier rows for the active task. BOTH 

72 unions them so users can compare a local Llama with a cloud Llama 

73 side by side. Cycled via the ``c`` keybinding. 

74 """ 

75 

76 LOCAL = "local" 

77 CLOUD = "cloud" 

78 BOTH = "both" 

79 

80 

_SOURCE_MODE_CYCLE: tuple[SourceMode, ...] = (SourceMode.LOCAL, SourceMode.CLOUD, SourceMode.BOTH)


def next_source_mode(current: SourceMode) -> SourceMode:
    """Advance *current* one step around LOCAL -> CLOUD -> BOTH -> LOCAL.

    Raises ValueError (from ``tuple.index``) if *current* is not in the cycle.
    """
    position = _SOURCE_MODE_CYCLE.index(current)
    # The last mode wraps back to the first.
    if position == len(_SOURCE_MODE_CYCLE) - 1:
        return _SOURCE_MODE_CYCLE[0]
    return _SOURCE_MODE_CYCLE[position + 1]

88 

89 

def task_to_tab_id(task: ModelTask | str) -> str:
    """Return the per-task tab id for a ModelTask or its string value.

    Accepts the string form because catalog rows carry ``task`` as a raw
    string (matching how the HF API and the row builders return it), while
    the routing tables are keyed on the enum.

    Raises:
        KeyError: ``"unknown task: ..."`` when *task* is not a valid
            ModelTask value, or is a ModelTask with no task tab. Both
            failure paths now raise the same message.
    """
    try:
        key = task if isinstance(task, ModelTask) else ModelTask(task)
        return TASK_TO_TAB_ID[key]
    except (KeyError, ValueError) as exc:
        raise KeyError(f"unknown task: {task!r}") from exc

103 

104 

105# SI thresholds for short download counts ("12.3M" / "456K") and binary 

106# thresholds for sizes ("4.2 GB" / "768 MB"). 

107_DOWNLOADS_PER_M = 1_000_000 

108_DOWNLOADS_PER_K = 1_000 

109_MB_PER_GB = 1024 

110 

111 

@dataclass(frozen=True)
class SizeVariant:
    """An immutable size/quant option of a model family, drawn as a chip on the family card."""

    # Text shown on the chip, e.g. "8B Q4_K_M".
    label: str
    # Quantization label, e.g. "Q4_K_M".
    quant: str
    # Size of this variant in gigabytes.
    size_gb: float
    # Canonical pull target for this exact variant.
    ref: str
    # Fit chip computed against the host's available memory; stays None
    # until the hardware probe has run.
    fit: FitChip | None = field(default=None)

127 

128 

@dataclass
class LocalCatalogRow:
    """Catalog row for a locally runnable model (an installable or installed GGUF).

    Within this module, rows are built by ``catalog_to_row``,
    ``variant_to_row`` and ``remote_to_row``. Each builder fills in only its
    own source-object field(s) below.
    """

    # Human-readable display label, e.g. "Qwen3 0.6B".
    name: str
    task: str
    # Pre-formatted display columns (the builders here use "--" for unknown).
    params: str
    size: str
    quant: str
    downloads: str
    featured: bool
    installed: bool
    # Raw numeric values behind the Downloads / Size columns, used for sorting.
    sort_downloads: int
    sort_size: float
    # Canonical identifier used for config persistence: ``hf_repo`` for
    # catalog rows, ``hf_repo/filename`` for installed native models, and
    # the provider's ref shape for remote/API rows.
    ref: str = field(default="")
    backend: str = field(default="")
    # Source objects the row was built from.
    variant: ModelVariant | None = field(default=None)
    family: ModelFamily | None = field(default=None)
    catalog_model: CatalogModel | None = field(default=None)
    remote_model: RemoteModel | None = field(default=None)
    # Every quant of a family-aggregated row, so the card can render an
    # inline chip strip and the detail drawer can list all sizes.
    size_variants: list[SizeVariant] = field(default_factory=list)
    # Fit chip for the row's primary variant.
    fit: FitChip | None = field(default=None)
    kind: Literal[CatalogRowKind.LOCAL] = field(default=CatalogRowKind.LOCAL)

162 

163 

class KeyStatus(Enum):
    """API-key readiness for a frontier (cloud) model."""

    READY = "ready"  # the needed API key is present
    MISSING_KEY = "missing_key"  # the user still has to supply the key

169 

170 

@dataclass
class FrontierCatalogRow:
    """Catalog row for a chat model served through a cloud provider's API.

    Local-model metadata (size on disk, quant, GGUF filename) is omitted
    because it does not apply: the model runs on the provider's
    infrastructure.
    """

    name: str
    ref: str
    task: str
    # Human-facing provider label, e.g. "Gemini" / "OpenAI" / "Anthropic".
    provider: str
    # Canonical provider id used for the API key field, e.g. "gemini".
    provider_id: str
    # Whether the user has the API key this model needs.
    key_status: KeyStatus
    kind: Literal[CatalogRowKind.FRONTIER] = field(default=CatalogRowKind.FRONTIER)

187 

188 

# Sealed union of every catalog row type, discriminated on ``.kind``
# (a CatalogRowKind value). Dispatch by pattern-matching on, or comparing,
# ``row.kind`` instead of using isinstance, so adding a new row type is a
# change in one place.
CatalogRow = LocalCatalogRow | FrontierCatalogRow

192 

193 

def parse_param_label(name: str) -> str:
    """Extract the parameter-count label from a model name (e.g. '8B', '0.6B').

    Returns the first capture group of ``PARAM_COUNT_RE`` upper-cased, or
    ``"--"`` when the name contains no parameter count. ``PARAM_COUNT_RE``
    comes from the module-level import, so no import happens per call.
    """
    match = PARAM_COUNT_RE.search(name)
    return match.group(1).upper() if match else "--"

200 

201 

def _format_downloads(n: int) -> str:
    """Render a download count compactly: "12.3M", "456K" or the plain number.

    Counts from 999_500 up to 999_999 would round to "1000K" in the K tier.
    They are promoted to the M tier instead ("1.0M"), so the label never
    shows four K digits.
    """
    # Thousands value the K tier would display. round() uses the same
    # half-even rounding as the ":.0f" format below.
    rounded_k = round(n / _DOWNLOADS_PER_K)
    if n >= _DOWNLOADS_PER_M or rounded_k * _DOWNLOADS_PER_K >= _DOWNLOADS_PER_M:
        return f"{n / _DOWNLOADS_PER_M:.1f}M"
    if n >= _DOWNLOADS_PER_K:
        return f"{n / _DOWNLOADS_PER_K:.0f}K"
    return str(n)

208 

209 

def _format_size_mb(size_mb: int) -> str:
    """Format a size given in MB as "N MB" or "X.Y GB" (binary 1024 MB per GB).

    Non-positive sizes render as ``"--"``, matching ``format_size_gb``.
    Previously only exactly 0 did, so a negative size showed as "-5 MB".
    """
    if size_mb <= 0:
        return "--"
    if size_mb >= _MB_PER_GB:
        return f"{size_mb / _MB_PER_GB:.1f} GB"
    return f"{size_mb} MB"

217 

218 

def format_size_gb(size_gb: float) -> str:
    """Render a size in gigabytes with one decimal place; ``"--"`` when it is zero or negative."""
    return "--" if size_gb <= 0 else f"{size_gb:.1f} GB"

224 

225 

def _is_param_count(label: str) -> bool:
    """Report whether the whole of *label* matches ``PARAM_COUNT_RE`` (e.g. '8B', '0.6B')."""
    return PARAM_COUNT_RE.fullmatch(label) is not None

229 

230 

def family_to_size_variants(family: ModelFamily) -> list[SizeVariant]:
    """Build the size-chip strip for a featured ModelFamily.

    Variants are sorted by ``size_mb`` ascending, so the chip strip reads
    compact-to-large left to right. ``fit`` is left ``None``; the catalog
    screen fills it in once the hardware probe has run.

    MB are converted to GB with the module's ``_MB_PER_GB`` constant rather
    than a hard-coded 1024, so chip sizes agree with the size formatters.
    """
    return [
        SizeVariant(
            label=_size_variant_label(variant),
            quant=variant.quant or "--",
            size_gb=variant.size_mb / _MB_PER_GB,
            ref=variant.hf_repo,
            fit=None,
        )
        for variant in sorted(family.variants, key=lambda v: v.size_mb)
    ]

249 

250 

251def _size_variant_label(v: ModelVariant) -> str: 

252 """Render a compact label for a ModelVariant chip (e.g. '8B Q4_K_M').""" 

253 pieces = [p for p in (v.param_count, v.quant) if p] 

254 return " ".join(pieces) if pieces else "--" 

255 

256 

257def variant_to_row(v: ModelVariant, f: ModelFamily, installed: bool) -> LocalCatalogRow: 

258 """Convert a ModelVariant + family to a LocalCatalogRow.""" 

259 # Avoid duplicating the param count when the family name already ends with it. 

260 if v.param_count and not f.name.endswith(v.param_count): 

261 label = f"{f.name} {v.param_count}" 

262 else: 

263 label = f.name 

264 params = v.param_count if _is_param_count(v.param_count) else "--" 

265 return LocalCatalogRow( 

266 name=label, 

267 task=f.task, 

268 params=params, 

269 size=_format_size_mb(v.size_mb), 

270 quant=v.quant or "--", 

271 downloads="--", 

272 featured=True, 

273 installed=installed, 

274 sort_downloads=0, 

275 sort_size=v.size_mb / 1024, 

276 ref=v.hf_repo, 

277 backend="native", 

278 variant=v, 

279 family=f, 

280 ) 

281 

282 

def catalog_to_row(m: CatalogModel, installed: bool) -> LocalCatalogRow:
    """Build a native-backend LocalCatalogRow from a CatalogModel.

    The quant comes from the GGUF filename and the params label from the
    display name. A download count of zero or less shows as ``"--"``.
    """
    quant_label = extract_quant(m.gguf_filename) or "--"
    download_count = m.downloads
    return LocalCatalogRow(
        name=m.display_name,
        task=m.task,
        params=parse_param_label(m.display_name),
        size=format_size_gb(m.size_gb),
        quant=quant_label,
        downloads="--" if download_count <= 0 else _format_downloads(download_count),
        featured=m.featured,
        installed=installed,
        sort_downloads=download_count,
        sort_size=m.size_gb,
        ref=m.ref,
        backend="native",
        catalog_model=m,
    )

301 

302 

def remote_to_row(rm: RemoteModel) -> LocalCatalogRow:
    """Wrap a provider-served RemoteModel as an installed LocalCatalogRow.

    ``ref`` is built with ``format_remote_ref`` in the canonical
    ``provider/name`` form, so it round-trips through
    ``Config.chat_model``'s validator without a per-call-site fixup. Size,
    quant and downloads are not populated and render as ``"--"``.
    """
    placeholder = "--"
    return LocalCatalogRow(
        name=rm.name,
        task=rm.task,
        params=rm.parameter_size or placeholder,
        size=placeholder,
        quant=placeholder,
        downloads=placeholder,
        featured=False,
        installed=True,
        sort_downloads=0,
        sort_size=0.0,
        ref=format_remote_ref(rm.name, rm.provider),
        backend=rm.provider.lower(),
        remote_model=rm,
    )

325 

326 

def frontier_row_from_remote(
    rm: RemoteModel, *, provider_id: str, key_status: KeyStatus
) -> FrontierCatalogRow:
    """Build a FrontierCatalogRow for a discovered cloud chat model.

    The ref is produced by ``format_remote_ref`` in the canonical
    ``provider/name`` form, so callers can hand it straight to
    ``Config.chat_model`` without re-prefixing.
    """
    canonical_ref = format_remote_ref(rm.name, rm.provider)
    return FrontierCatalogRow(
        name=rm.name,
        ref=canonical_ref,
        task=rm.task,
        provider=rm.provider,
        provider_id=provider_id,
        key_status=key_status,
    )

343 

344 

# Column header -> sort-key extractor. These apply to LocalCatalogRow only:
# every column except Name reads a field FrontierCatalogRow does not have,
# and the catalog screen sorts local and frontier rows independently before
# concatenating them.
SORT_KEYS: dict[str, Callable[[LocalCatalogRow], Any]] = {
    "Name": lambda row: row.name.lower(),
    "Task": lambda row: row.task,
    "Backend": lambda row: row.backend.lower(),
    # Param labels are strings like "8B"; compare them numerically.
    "Params": lambda row: _param_sort_value(row.params),
    "Size": lambda row: row.sort_size,
    "Quant": lambda row: row.quant,
    "Downloads": lambda row: row.sort_downloads,
}

357 

358 

359def _param_sort_value(params: str) -> float: 

360 """Convert param label to sortable float (e.g. '8B' -> 8.0).""" 

361 match = re.search(r"(\d+\.?\d*)", params) 

362 return float(match.group(1)) if match else 0.0 

363 

364 

def row_delete_id(row: CatalogRow) -> str | None:
    """Return the identifier model_manager expects for *row*, or None if it is empty.

    A local row that wraps a RemoteModel yields the bare ``RemoteModel.name``.
    The Ollama HTTP API keys models by bare name, whereas ``ref`` holds the
    canonical ``ollama/<name>`` chat_model form. Every other row, frontier
    rows included, yields its ``ref``.
    """
    if row.kind != CatalogRowKind.FRONTIER and row.remote_model is not None:
        return row.remote_model.name or None
    return row.ref or None

377 

378 

def matches_search(row: CatalogRow, search: str) -> bool:
    """Report whether *row* matches *search*, ignoring case, hyphens and underscores.

    An empty query matches every row. Local rows are searched on
    name/task/params/quant/backend. Frontier rows are searched on name and
    provider, so typing "gemini" lists every Gemini model regardless of
    suffix.
    """
    if not search:
        return True
    needle = _normalize_for_search(search)
    if row.kind == CatalogRowKind.FRONTIER:
        haystacks = (row.name, row.provider, row.provider_id)
    else:
        haystacks = (row.name, row.task, row.params, row.quant, row.backend)
    return any(needle in _normalize_for_search(text) for text in haystacks)

397 ) 

398 

399 

400def _normalize_for_search(value: str) -> str: 

401 return value.lower().replace("-", " ").replace("_", " ")