Coverage for src / lilbee / cli / tui / screens / catalog_utils.py: 100%

141 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-06-28 01:01 +0000

1"""Catalog data types, row builders, and formatting helpers. 

2 

3The catalog renders two distinct row shapes side by side: locally 

4installed / installable GGUFs (``LocalCatalogRow``) and cloud chat 

5models accessed through a provider's API (``FrontierCatalogRow``). 

6They share enough surface area that grouping and search reuse the 

7same helpers, but they carry different metadata and pull from 

8different sources, so they're separate types under a sealed 

9``CatalogRow`` union rather than a single optional-fields dataclass. 

10""" 

11 

12from __future__ import annotations 

13 

14import re 

15from collections.abc import Callable 

16from dataclasses import dataclass, field 

17from enum import StrEnum 

18from typing import Any, Literal 

19 

20from lilbee.catalog import PARAM_COUNT_RE, CatalogModel, ModelFamily, ModelVariant, extract_quant 

21from lilbee.catalog.types import KeyStatus, ModelCompat, ModelTask 

22from lilbee.modelhub.model_manager import RemoteModel 

23from lilbee.providers.model_ref import format_remote_ref 

24from lilbee.runtime.hardware import FitChip 

25 

26 

class CatalogRowKind(StrEnum):
    """Discriminator for the sealed CatalogRow union.

    Each member is the default ``kind`` of one row dataclass, so callers
    can branch on ``row.kind`` rather than on the concrete class.
    """

    LOCAL = "local"  # default kind of LocalCatalogRow
    FRONTIER = "frontier"  # default kind of FrontierCatalogRow

32 

33 

# Tab IDs for the 6-tab catalog shell. Discover is the curated landing,
# each of the four task tabs renders a single per-task grid, and Library
# is the personal view of installed local models plus activated cloud APIs.
TAB_DISCOVER = "discover"
TAB_CHAT = "chat"
TAB_EMBED = "embed"
TAB_VISION = "vision"
TAB_RERANK = "rerank"
TAB_LIBRARY = "library"

# The four per-task tabs, in display order.
TASK_TAB_IDS: tuple[str, ...] = (TAB_CHAT, TAB_EMBED, TAB_VISION, TAB_RERANK)

# Order matters: numbered shortcuts 1-6 follow this sequence. The task
# tabs sit between the Discover landing and the Library tab.
ALL_TAB_IDS: tuple[str, ...] = (TAB_DISCOVER, *TASK_TAB_IDS, TAB_LIBRARY)

54 

# Routing table from ModelTask to the per-task tab id that renders its
# rows. Featured items still appear pinned at the top of their task tab;
# cross-task discovery happens on the Discover landing.
TASK_TO_TAB_ID: dict[ModelTask, str] = {
    ModelTask.CHAT: TAB_CHAT,
    ModelTask.EMBEDDING: TAB_EMBED,
    ModelTask.VISION: TAB_VISION,
    ModelTask.RERANK: TAB_RERANK,
}
# Reverse lookup derived from the table above: tab id -> ModelTask.
TAB_ID_TO_TASK: dict[str, ModelTask] = {
    tab_id: model_task for model_task, tab_id in TASK_TO_TAB_ID.items()
}

65 

66 

class SourceMode(StrEnum):
    """Per-task tab filter for which row source backs the visible cards.

    LOCAL hides frontier rows (the default; mirrors the legacy mega-grid
    behavior). CLOUD shows only frontier rows for the active task. BOTH
    unions them so users can compare a local Llama with a cloud Llama
    side by side. Cycled via the ``c`` keybinding.
    """

    LOCAL = "local"  # local rows only
    CLOUD = "cloud"  # frontier (cloud) rows only
    BOTH = "both"  # local and frontier rows together

79 

80 

# Rotation order used by next_source_mode; the last entry wraps to the first.
_SOURCE_MODE_CYCLE: tuple[SourceMode, ...] = (
    SourceMode.LOCAL,
    SourceMode.CLOUD,
    SourceMode.BOTH,
)


def next_source_mode(current: SourceMode) -> SourceMode:
    """Advance one step through the LOCAL -> CLOUD -> BOTH -> LOCAL rotation."""
    successor = _SOURCE_MODE_CYCLE.index(current) + 1
    if successor == len(_SOURCE_MODE_CYCLE):
        # Stepping past the final entry wraps around to the start.
        successor = 0
    return _SOURCE_MODE_CYCLE[successor]

88 

89 

def task_to_tab_id(task: ModelTask | str) -> str:
    """Look up the per-task tab id for a ModelTask or its string value.

    The string form is accepted because catalog rows carry ``task`` as a
    raw string (matching how HF API and the row builders return it),
    while the routing table is keyed on the enum.

    Raises ``KeyError`` when a string cannot be resolved to a routed task.
    """
    if not isinstance(task, ModelTask):
        try:
            return TASK_TO_TAB_ID[ModelTask(task)]
        except (KeyError, ValueError) as exc:
            # ValueError: not a ModelTask value; KeyError: task has no tab.
            raise KeyError(f"unknown task: {task!r}") from exc
    return TASK_TO_TAB_ID[task]

103 

104 

# SI (decimal) thresholds for short download counts ("12.3M" / "456K").
_DOWNLOADS_PER_M = 1_000_000
_DOWNLOADS_PER_K = 1_000
# Binary threshold for sizes ("4.2 GB" / "768 MB"): megabytes per gigabyte.
_MB_PER_GB = 1024

110 

111 

@dataclass(frozen=True)
class SizeVariant:
    """One size/quant variant of a model family for the family-as-card strip.

    ``label`` renders inline on the card (e.g. "8B Q4_K_M"). ``ref`` is
    the canonical pull target for this specific variant. ``fit`` is the
    fit chip computed against the host's available memory; ``None``
    when the hardware probe has not yet run.
    """

    label: str  # compact chip text, e.g. "8B Q4_K_M"
    quant: str  # quantization label; family_to_size_variants stores "--" when absent
    size_gb: float  # variant size in gigabytes
    ref: str  # canonical pull target for this variant
    fit: FitChip | None = None  # None until the hardware probe has run

127 

128 

@dataclass
class LocalCatalogRow:
    """A row in the catalog backed by a local GGUF (installable or installed).

    ``name`` is the human-readable display label (e.g. "Qwen3 0.6B").
    ``ref`` is the canonical identifier used for config persistence:
    ``hf_repo`` for catalog rows, ``hf_repo/filename`` for installed
    native models, and the provider's ref shape for remote/API rows.
    ``size_variants`` carries every quant for a family-aggregated row,
    so the card can render an inline chip strip and the detail drawer
    can list all sizes. ``fit`` is the chip for the row's primary
    variant.
    """

    name: str  # human-readable display label
    task: str  # task as a raw string (see task_to_tab_id for enum routing)
    params: str  # parameter-count label such as "8B"; builders use "--" when unknown
    size: str  # display size such as "4.2 GB"; builders use "--" when unknown
    quant: str  # quantization label; builders use "--" when unknown
    downloads: str  # display download count such as "12.3M"; "--" when unknown
    featured: bool
    installed: bool
    sort_downloads: int  # numeric value behind the "Downloads" sort key
    sort_size: float  # numeric value behind the "Size" sort key (builders store GB)
    ref: str = ""  # canonical identifier used for config persistence
    backend: str = ""  # "native" from the catalog builders; lowercased provider for remote rows
    variant: ModelVariant | None = None  # populated by variant_to_row
    family: ModelFamily | None = None  # populated by variant_to_row
    catalog_model: CatalogModel | None = None  # populated by catalog_to_row
    remote_model: RemoteModel | None = None  # populated by remote_to_row
    size_variants: list[SizeVariant] = field(default_factory=list)  # fresh list per row
    fit: FitChip | None = None  # fit chip for the row's primary variant
    compat: ModelCompat = ModelCompat.UNKNOWN
    kind: Literal[CatalogRowKind.LOCAL] = CatalogRowKind.LOCAL  # union discriminator

163 

164 

@dataclass
class FrontierCatalogRow:
    """A row in the catalog backed by a cloud provider's chat API.

    Frontier rows skip the local-model fields (size on disk, quant,
    GGUF filename) because they don't apply: the model lives on the
    provider's infrastructure.
    """

    name: str  # model name as reported for the provider
    ref: str  # canonical identifier; frontier_row_from_remote stores the provider/name form
    task: str  # task as a raw string
    provider: str  # Display label, e.g. "Gemini" / "OpenAI" / "Anthropic".
    provider_id: str  # Canonical id used for the API key field, e.g. "gemini".
    key_status: KeyStatus
    kind: Literal[CatalogRowKind.FRONTIER] = CatalogRowKind.FRONTIER  # union discriminator

181 

182 

# Sealed union discriminated on .kind. Pattern-match (or compare) on row.kind
# to dispatch instead of using isinstance, so that adding a new row type only
# requires changes in one place.
CatalogRow = LocalCatalogRow | FrontierCatalogRow

186 

187 

def parse_param_label(name: str) -> str:
    """Extract parameter count label from model name (e.g. '8B', '0.6B').

    Returns the first capture group of ``PARAM_COUNT_RE`` upper-cased, or
    ``"--"`` when the pattern does not occur anywhere in *name*.
    """
    found = PARAM_COUNT_RE.search(name)
    return found.group(1).upper() if found else "--"

194 

195 

def _format_downloads(n: int) -> str:
    """Abbreviate a download count for display ("12.3M" / "456K" / "789").

    Counts below one thousand are shown verbatim, thousands are rounded
    to a whole number of K, and millions keep one decimal place. Counts
    that would round up to "1000K" are shown as "1.0M" instead.
    """
    if n < _DOWNLOADS_PER_K:
        return str(n)
    thousands = round(n / _DOWNLOADS_PER_K)
    if thousands * _DOWNLOADS_PER_K < _DOWNLOADS_PER_M:
        return f"{thousands}K"
    # Either n is already in the millions, or it sits just below the
    # boundary (e.g. 999_999) and rounding carried it over; clamp so the
    # latter reads "1.0M" rather than "1000K".
    return f"{max(n, _DOWNLOADS_PER_M) / _DOWNLOADS_PER_M:.1f}M"

202 

203 

def _format_size_mb(size_mb: int) -> str:
    """Format size in MB to a human-readable string.

    Non-positive sizes are treated as unknown and render as ``"--"``,
    matching ``format_size_gb``. Sizes of at least one gigabyte are shown
    in GB with one decimal place; smaller sizes are shown in whole MB.
    """
    if size_mb <= 0:
        return "--"
    if size_mb >= _MB_PER_GB:
        return f"{size_mb / _MB_PER_GB:.1f} GB"
    return f"{size_mb} MB"

211 

212 

def format_size_gb(size_gb: float) -> str:
    """Render a size given in GB with one decimal place, or "--" when not positive."""
    return "--" if size_gb <= 0 else f"{size_gb:.1f} GB"

218 

219 

def _is_param_count(label: str) -> bool:
    """Report whether the whole label is a parameter count (e.g. '8B', '0.6B')."""
    return PARAM_COUNT_RE.fullmatch(label) is not None

223 

224 

def family_to_size_variants(family: ModelFamily) -> list[SizeVariant]:
    """Build the size-chip strip for a featured ModelFamily.

    The strip is ordered smallest-to-largest so the chips read
    compact-to-large from left to right. Every chip's ``fit`` starts as
    ``None``; the catalog screen fills it in once the hardware probe has
    run.
    """
    ordered = list(family.variants)
    ordered.sort(key=lambda variant: variant.size_mb)
    return [
        SizeVariant(
            label=_size_variant_label(variant),
            quant=variant.quant or "--",
            size_gb=variant.size_mb / _MB_PER_GB,
            ref=variant.hf_repo,
            fit=None,
        )
        for variant in ordered
    ]

243 

244 

def _size_variant_label(v: ModelVariant) -> str:
    """Build the compact chip text for a ModelVariant (e.g. '8B Q4_K_M').

    Empty parts are skipped; "--" is returned when both are empty.
    """
    return " ".join(filter(None, (v.param_count, v.quant))) or "--"

249 

250 

def variant_to_row(v: ModelVariant, f: ModelFamily, installed: bool) -> LocalCatalogRow:
    """Build the LocalCatalogRow for one featured variant of a family."""
    param_count = v.param_count
    # Only append the param count when the family name does not already end with it.
    needs_suffix = bool(param_count) and not f.name.endswith(param_count)
    display_name = f"{f.name} {param_count}" if needs_suffix else f.name
    # Labels that are not a clean parameter count are shown as unknown.
    params_label = param_count if _is_param_count(param_count) else "--"
    return LocalCatalogRow(
        name=display_name,
        task=f.task,
        params=params_label,
        size=_format_size_mb(v.size_mb),
        quant=v.quant or "--",
        downloads="--",
        featured=True,
        installed=installed,
        sort_downloads=0,
        sort_size=v.size_mb / _MB_PER_GB,
        ref=v.hf_repo,
        backend="native",
        variant=v,
        family=f,
    )

275 

276 

def catalog_to_row(m: CatalogModel, installed: bool) -> LocalCatalogRow:
    """Build the LocalCatalogRow for a CatalogModel."""
    download_count = m.downloads
    # A non-positive count is shown as unknown rather than as "0".
    if download_count > 0:
        downloads_label = _format_downloads(download_count)
    else:
        downloads_label = "--"
    return LocalCatalogRow(
        name=m.display_name,
        task=m.task,
        params=parse_param_label(m.display_name),
        size=format_size_gb(m.size_gb),
        quant=extract_quant(m.gguf_filename) or "--",
        downloads=downloads_label,
        featured=m.featured,
        installed=installed,
        sort_downloads=download_count,
        sort_size=m.size_gb,
        ref=m.ref,
        backend="native",
        catalog_model=m,
        compat=m.compat,
    )

296 

297 

def remote_to_row(rm: RemoteModel) -> LocalCatalogRow:
    """Build the LocalCatalogRow for a RemoteModel.

    ``ref`` uses the canonical ``provider/name`` form so it round-trips
    through ``Config.chat_model``'s validator without a per-call-site
    fixup. Size, quant and download fields are not available for remote
    models and are filled with placeholders.
    """
    provider = rm.provider
    canonical_ref = format_remote_ref(rm.name, provider)
    return LocalCatalogRow(
        name=rm.name,
        task=rm.task,
        params=rm.parameter_size or "--",
        size="--",
        quant="--",
        downloads="--",
        featured=False,
        installed=True,
        sort_downloads=0,
        sort_size=0.0,
        ref=canonical_ref,
        backend=provider.lower(),
        remote_model=rm,
    )

320 

321 

def frontier_row_from_remote(
    rm: RemoteModel, *, provider_id: str, key_status: KeyStatus
) -> FrontierCatalogRow:
    """Build the FrontierCatalogRow for a discovered cloud chat model.

    ``ref`` uses the canonical ``provider/name`` form so callers can pass
    it straight to ``Config.chat_model`` without re-prefixing.
    """
    provider_label = rm.provider
    canonical_ref = format_remote_ref(rm.name, provider_label)
    return FrontierCatalogRow(
        name=rm.name,
        ref=canonical_ref,
        task=rm.task,
        provider=provider_label,
        provider_id=provider_id,
        key_status=key_status,
    )

338 

339 

# Sort key extractors, keyed by column header. They are typed for local
# rows only: every column except Name reads a field FrontierCatalogRow
# doesn't carry, and the catalog screen sorts local and frontier rows
# independently before concatenating them.
SORT_KEYS: dict[str, Callable[[LocalCatalogRow], Any]] = {
    # Text columns that sort case-insensitively are lowercased.
    "Name": lambda row: row.name.lower(),
    "Task": lambda row: row.task,
    "Backend": lambda row: row.backend.lower(),
    # Params sorts numerically on the label, not alphabetically.
    "Params": lambda row: _param_sort_value(row.params),
    "Size": lambda row: row.sort_size,
    "Quant": lambda row: row.quant,
    "Downloads": lambda row: row.sort_downloads,
}

352 

353 

# Multipliers that express a parameter-count unit suffix in billions.
_PARAM_UNIT_IN_BILLIONS: dict[str, float] = {"K": 1e-6, "M": 1e-3, "B": 1.0, "T": 1e3}


def _param_sort_value(params: str) -> float:
    """Convert param label to sortable float in billions (e.g. '8B' -> 8.0).

    The first number in the label is scaled by its unit suffix so that
    labels with different units order correctly ('135M' -> 0.135 sorts
    below '8B' -> 8.0). A number without a recognised suffix is taken to
    be in billions. Labels containing no number (e.g. '--') map to 0.0.
    """
    # The suffix only counts when it ends a word, so "4bit" is not read as 4B.
    found = re.search(r"(\d+\.?\d*)(?:([KMBT])\b)?", params, re.IGNORECASE)
    if found is None:
        return 0.0
    unit = (found.group(2) or "B").upper()
    return float(found.group(1)) * _PARAM_UNIT_IN_BILLIONS[unit]

358 

359 

def row_delete_id(row: CatalogRow) -> str | None:
    """Return the model_manager-compatible identifier for *row*.

    Rows backed by a ``RemoteModel`` yield the bare ``RemoteModel.name``
    because the Ollama HTTP API keys models by bare name, while ``ref``
    carries the canonical ``ollama/<name>`` chat_model form. Every other
    row, frontier rows included, yields its ``ref``. An empty identifier
    is reported as ``None``.
    """
    if row.kind != CatalogRowKind.FRONTIER and row.remote_model is not None:
        identifier = row.remote_model.name
    else:
        identifier = row.ref
    return identifier or None

372 

373 

def matches_search(row: CatalogRow, search: str) -> bool:
    """Return True if the row matches the search text (hyphen/underscore-insensitive).

    An empty search matches every row. Local rows are matched against
    name/task/params/quant/backend; frontier rows are matched against
    name, provider label and provider id, so users can type "gemini" and
    see every Gemini model regardless of suffix.
    """
    if not search:
        return True
    needle = _normalize_for_search(search)
    if row.kind == CatalogRowKind.FRONTIER:
        haystacks = (row.name, row.provider, row.provider_id)
    else:
        haystacks = (row.name, row.task, row.params, row.quant, row.backend)
    return any(needle in _normalize_for_search(text) for text in haystacks)

393 

394 

def _normalize_for_search(value: str) -> str:
    """Lowercase *value* and turn hyphens and underscores into spaces."""
    return value.lower().translate(str.maketrans("-_", "  "))