Coverage for src / lilbee / cli / tui / widgets / model_bar.py: 100%
331 statements
« prev ^ index » next — coverage.py v7.13.4, created at 2026-05-15 20:55 +0000
1"""Model status bar: pill buttons for chat / embedding plus mode + scope."""
3from __future__ import annotations
5import contextlib
6import logging
7from pathlib import Path
8from typing import TYPE_CHECKING, ClassVar, Literal, NamedTuple
10if TYPE_CHECKING:
11 from lilbee.cli.tui.app import LilbeeApp
12 from lilbee.modelhub.registry import ModelRegistry
14from textual import events, work
15from textual.app import ComposeResult
16from textual.binding import Binding, BindingType
17from textual.containers import Horizontal
18from textual.widget import Widget
19from textual.widgets import Static
21from lilbee.app.services import get_services, reset_services
22from lilbee.catalog import clean_display_name, display_label_for_ref, extract_quant
23from lilbee.catalog.types import ModelTask
24from lilbee.cli.tui import messages as msg
25from lilbee.cli.tui.app import apply_active_model, apply_setting
26from lilbee.cli.tui.pill import pill
27from lilbee.cli.tui.thread_safe import call_from_thread
28from lilbee.core.config import cfg
29from lilbee.core.config.enums import ChatMode
30from lilbee.providers.model_ref import format_remote_ref, parse_model_ref
31from lilbee.providers.sdk_backend import PROVIDER_KEYS
32from lilbee.providers.worker.transport import WorkerRole
33from lilbee.retrieval.embedder import is_model_available
log = logging.getLogger(__name__)

# Substring that marks a GGUF file as a vision-projection ("mmproj") sidecar.
_MMPROJ_MARKER = "mmproj"

# Widget id and CSS classes for the cloud-provider warning label.
_CLOUD_WARNING_ID = "cloud-provider-warning"
_CLOUD_WARNING_CLASS = "cloud-warning"
_CLOUD_WARNING_VISIBLE_CLASS = "-visible"

# Widget ids for the two model picker pill buttons.
_CHAT_MODEL_BUTTON_ID = "chat-model-button"
_EMBED_MODEL_BUTTON_ID = "embed-model-button"

# Routing-name -> display-label map derived from PROVIDER_KEYS. Any new
# entry added there lights up the warning without further changes here.
_CLOUD_PROVIDER_LABELS: dict[str, str] = {name: label for name, _, _, label in PROVIDER_KEYS}
def _cloud_provider_label(chat_model: str) -> str | None:
    """Return the provider display label for cloud-routed models, else None."""
    if not chat_model:
        return None
    parsed = parse_model_ref(chat_model)
    return _CLOUD_PROVIDER_LABELS.get(parsed.provider) if parsed.is_api else None
class ModelOption(NamedTuple):
    """A selectable model: a human-readable picker *label* paired with the
    canonical *ref* that is persisted to config."""

    label: str
    ref: str
def _is_mmproj(name: str) -> bool:
    """Return True if a model name refers to an mmproj projection file."""
    lowered = name.lower()
    return _MMPROJ_MARKER in lowered
def _classify_installed_models() -> tuple[list[ModelOption], list[ModelOption]]:
    """Classify installed models into (chat, embedding) lists, dropping mmproj.

    The chat-bar surfaces only chat + embedding pickers; vision and rerank
    use ``classify_installed_models_full`` directly. Vision/rerank entries
    are still discovered by the full scan so their refs are claimed and
    later buckets don't duplicate them.
    """
    by_task = classify_installed_models_full()
    return by_task[ModelTask.CHAT], by_task[ModelTask.EMBEDDING]
def classify_installed_models_full() -> dict[ModelTask, list[ModelOption]]:
    """Classify installed models into per-task lists, dropping mmproj entries."""
    buckets: dict[ModelTask, list[ModelOption]] = {task: [] for task in ModelTask}
    claimed: set[str] = set()
    # Collection order matters: native installs claim refs first, then
    # remote backends, then frontier APIs.
    for collect in (_collect_native_models, _collect_remote_models, _collect_api_models):
        collect(buckets, claimed)
    return {
        task: sorted(options, key=lambda option: option.ref)
        for task, options in buckets.items()
    }
def _lookup_bucket(
    buckets: dict[ModelTask, list[ModelOption]], task: str, ref: str
) -> list[ModelOption] | None:
    """Return the bucket for *task*, or None if it is not a known ModelTask."""
    try:
        return buckets.get(ModelTask(task))
    except ValueError:
        # Unknown task string: drop the entry rather than guessing a bucket.
        log.debug("dropping %r with unknown task %r", ref, task)
        return None
def _native_label(hf_repo: str, gguf_filename: str, repo_count: int) -> str:
    """Build the picker label, appending the quant suffix only on collision."""
    base = clean_display_name(hf_repo)
    if repo_count > 1:
        # Same repo installed in several quantizations: disambiguate.
        quant = extract_quant(gguf_filename)
        if quant:
            return f"{base} ({quant})"
    return base
def _has_vision_sidecar(registry: ModelRegistry, ref: str) -> bool:
    """Return True if *ref* resolves to a model with an adjacent ``*mmproj*.gguf`` file.

    Models like ``google/gemma-3-12b-it`` carry their vision capability in
    a sibling ``mmproj`` GGUF; the ref's name alone gives no signal that
    the model is multimodal, so without this file-system probe the vision
    picker would silently miss it.
    """
    try:
        model_path = registry.resolve(ref)
    except (KeyError, ValueError):
        # Unknown or malformed ref: no sidecar to find.
        return False
    sidecars = model_path.parent.glob("*mmproj*.gguf")
    return any(sidecars)
def _collect_native_models(buckets: dict[ModelTask, list[ModelOption]], seen: set[str]) -> None:
    """Add native registry models to *buckets*, claiming their refs in *seen*.

    Best-effort: any failure reading the registry is logged at debug
    level and leaves *buckets* untouched.
    """
    try:
        from collections import Counter

        from lilbee.modelhub.registry import ModelRegistry

        registry = ModelRegistry(cfg.models_dir)
        manifests = registry.list_installed()
        # Quant suffixes are only needed when one repo is installed in
        # several quantizations; count installs per repo to decide.
        repo_counts = Counter(m.hf_repo for m in manifests)

        from lilbee.modelhub.model_manager.discovery import reclassify_by_name

        for manifest in manifests:
            ref = manifest.ref
            if _is_mmproj(manifest.gguf_filename) or ref in seen:
                continue
            task = reclassify_by_name(ref, manifest.task)
            label = _native_label(
                manifest.hf_repo, manifest.gguf_filename, repo_counts[manifest.hf_repo]
            )
            primary_bucket = _lookup_bucket(buckets, task, ref)
            if primary_bucket is None:
                continue
            seen.add(ref)
            primary_bucket.append(ModelOption(label=label, ref=ref))
            # If the model has an mmproj sidecar it is also vision-capable.
            # Surface it under the vision picker too without dropping its
            # primary classification, so a chat model with vision (e.g.
            # gemma-3 with mmproj) shows up in both pickers.
            if task != ModelTask.VISION and _has_vision_sidecar(registry, ref):
                buckets[ModelTask.VISION].append(ModelOption(label=label, ref=ref))
    except Exception:
        log.debug("Could not read native model registry", exc_info=True)
def _collect_remote_models(buckets: dict[ModelTask, list[ModelOption]], seen: set[str]) -> None:
    """Add remote (Ollama / OpenAI-compatible) models, prefixed for routing.

    Skipped when the litellm extra is not installed -- surfacing a model
    the SDK cannot route is a guaranteed runtime error.
    """
    from lilbee.providers.litellm_sdk import litellm_available

    if not litellm_available():
        return
    try:
        from lilbee.modelhub.model_manager import classify_remote_models

        for remote in classify_remote_models(cfg.remote_base_url):
            name = remote.name
            # Skip backend rows with a blank model name so the picker
            # doesn't render an empty " (Ollama)" row.
            if not name.strip():
                continue
            ref = format_remote_ref(name, remote.provider)
            if ref in seen or _is_mmproj(name):
                continue
            bucket = _lookup_bucket(buckets, remote.task, ref)
            if bucket is None:
                continue
            seen.add(ref)
            bucket.append(ModelOption(label=f"{name} ({remote.provider})", ref=ref))
    except Exception:
        log.debug("Could not classify remote models", exc_info=True)
def _collect_api_models(buckets: dict[ModelTask, list[ModelOption]], seen: set[str]) -> None:
    """Add frontier API chat models. Skipped without litellm (cannot route)."""
    from lilbee.providers.litellm_sdk import litellm_available

    if not litellm_available():
        return
    try:
        from lilbee.modelhub.model_manager import discover_api_models

        # API discovery returns only chat-capable refs; revisit if providers
        # expose embedding/vision/rerank.
        chat_bucket = buckets[ModelTask.CHAT]
        for display_name, provider_models in discover_api_models().items():
            for api_model in provider_models:
                qualified = format_remote_ref(api_model.name, api_model.provider)
                if qualified in seen:
                    continue
                seen.add(qualified)
                chat_bucket.append(
                    ModelOption(label=f"{api_model.name} ({display_name})", ref=qualified)
                )
    except Exception:
        log.debug("Could not discover API models", exc_info=True)
223def _options_fingerprint(opts: list[ModelOption], default: str) -> tuple[tuple[str, str], ...]:
224 """Hashable fingerprint of (options, active default) for cache hits."""
225 return ((default, default), *((o.label, o.ref) for o in opts))
# Stylesheet shipped next to this module; read once at class-definition time
# by ModelBar.DEFAULT_CSS below.
_CSS_FILE = Path(__file__).parent / "model_bar.tcss"

# Widget ids for the Search/Chat toggle container and its two inner pills.
_CHAT_MODE_TOGGLE_ID = "chat-mode-toggle"
_CHAT_MODE_SEARCH_PILL_ID = "chat-mode-search"
_CHAT_MODE_CHAT_PILL_ID = "chat-mode-chat"
# Shared CSS class for both pills, plus the state classes toggled at runtime.
_CHAT_MODE_PILL_CLASS = "chat-mode-pill"
_CHAT_MODE_DISABLED_CLASS = "-disabled"
_CHAT_MODE_ACTIVE_CLASS = "-active"
class ModelPickerButton(Static, can_focus=True):
    """Pill button that opens a ModelPickerModal scoped to chat or embed."""

    BINDINGS: ClassVar[list[BindingType]] = [
        Binding("enter", "open_picker", "Pick model", show=False),
        Binding("space", "open_picker", "Pick model", show=False),
    ]

    def __init__(self, *, scope: Literal["chat", "embed"], button_id: str) -> None:
        """Create a picker button for *scope* with the given widget id."""
        super().__init__(id=button_id)
        self._scope: Literal["chat", "embed"] = scope
        # Options pool shown in the modal; populated later via set_options.
        self._options: list[ModelOption] = []
        self.tooltip = (
            msg.MODEL_PICKER_CHAT_TOOLTIP if scope == "chat" else msg.MODEL_PICKER_EMBED_TOOLTIP
        )

    def on_mount(self) -> None:
        """Paint the initial label from config once mounted."""
        self._refresh()

    def set_options(self, options: list[ModelOption]) -> None:
        """Update the options pool. Repaints the label from cfg."""
        self._options = options
        if self.is_mounted:
            self._refresh()

    def _refresh(self) -> None:
        # Label falls back: display name -> raw ref -> "none" placeholder.
        ref = cfg.chat_model if self._scope == "chat" else cfg.embedding_model
        label = display_label_for_ref(ref) or ref or msg.MODEL_VALUE_NONE
        self.update(label)

    def on_click(self, event: events.Click) -> None:
        """Open the picker on click; stop the event so parents don't also react."""
        event.stop()
        self.open_picker()

    def action_open_picker(self) -> None:
        """Keyboard path (Enter / Space) to the picker modal."""
        self.open_picker()

    def open_picker(self) -> None:
        """Push a ModelPickerModal scoped to this button."""
        # Lazy import: model_picker imports ModelOption from this module.
        from lilbee.cli.tui.screens.model_picker import ModelPickerModal

        modal = ModelPickerModal(scope=self._scope, options=self._options)
        self.app.push_screen(modal, self._on_picker_dismissed)

    def _on_picker_dismissed(self, ref: str | None) -> None:
        """Handle the modal result, applying the chosen *ref* for this scope.

        No-ops when the picker was cancelled (``ref`` falsy) or the choice
        matches the current config value.
        """
        if not ref:
            return
        if self._scope == "chat":
            if ref == cfg.chat_model:
                return
            apply_active_model(self.app, "chat_model", ref)
            self._commit_after_change()
            return
        if ref == cfg.embedding_model:
            return
        # Embed-model swap invalidates a populated vector store. Confirm first.
        store = get_services().store
        if store.has_chunks():
            from lilbee.cli.tui.widgets.confirm_dialog import ConfirmDialog

            self.app.push_screen(
                ConfirmDialog(
                    msg.EMBED_SWAP_CONFIRM_TITLE,
                    msg.EMBED_SWAP_CONFIRM_MESSAGE,
                ),
                # Capture ref in the callback so the dialog result can apply it.
                lambda confirmed: self._on_embed_swap_confirmed(ref, confirmed),
            )
            return
        self._apply_embed_change(ref)

    def _on_embed_swap_confirmed(self, ref: str, confirmed: bool | None) -> None:
        """Apply the embed swap or notify cancel; ``confirmed`` mirrors ConfirmDialog."""
        if not confirmed:
            self.app.notify(msg.EMBED_SWAP_CANCELLED)
            return
        self._apply_embed_change(ref)

    def _apply_embed_change(self, ref: str) -> None:
        """Persist the new embed ref, refresh the bar, and respawn the embed worker."""
        get_services().store.initialize_meta_if_legacy()
        apply_active_model(self.app, "embedding_model", ref)
        self._commit_after_change()

    def _commit_after_change(self) -> None:
        """Repaint this button and fan ``_after_model_change`` out to each ModelBar."""
        self._refresh()
        bar = self.screen.query(ModelBar)
        for b in bar:
            b._after_model_change(self._scope)
class ChatModePill(Static, can_focus=True):
    """Single focusable mode pill; Enter / Space picks this pill's mode."""

    BINDINGS: ClassVar[list[BindingType]] = [
        Binding("enter", "select", "Pick mode", show=False),
        Binding("space", "select", "Pick mode", show=False),
    ]

    def action_select(self) -> None:
        """Ask the enclosing ChatModeToggle to switch to this pill's mode."""
        toggle: ChatModeToggle | None = None
        for node in self.ancestors_with_self:
            if isinstance(node, ChatModeToggle):
                toggle = node
                break
        if toggle is None:
            # Pill mounted outside a toggle (shouldn't happen): nothing to do.
            return
        if self.id == _CHAT_MODE_SEARCH_PILL_ID:
            toggle._set_mode(ChatMode.SEARCH.value)
        else:
            toggle._set_mode(ChatMode.CHAT.value)
class ChatModeToggle(Widget, can_focus=False):
    """Two-pill control toggling cfg.chat_mode between Search and Chat.

    The toggle itself is not focusable; the inner pills are. Tab walks
    Search then Chat, Enter / Space picks. The container keeps left /
    right arrow handling so the legacy keyboard flow still works.
    """

    BINDINGS: ClassVar[list[BindingType]] = [
        Binding("left", "select_search", "Search mode", show=False),
        Binding("right", "select_chat", "Chat mode", show=False),
    ]

    def __init__(self) -> None:
        super().__init__(id=_CHAT_MODE_TOGGLE_ID)

    def compose(self) -> ComposeResult:
        """Yield the Search and Chat pills side by side."""
        with Horizontal():
            yield ChatModePill(
                msg.CHAT_MODE_SEARCH_LABEL,
                id=_CHAT_MODE_SEARCH_PILL_ID,
                classes=_CHAT_MODE_PILL_CLASS,
            )
            yield ChatModePill(
                msg.CHAT_MODE_CHAT_LABEL,
                id=_CHAT_MODE_CHAT_PILL_ID,
                classes=_CHAT_MODE_PILL_CLASS,
            )

    def on_mount(self) -> None:
        """Paint the initial pill state once mounted."""
        self._refresh()

    def refresh_state(self) -> None:
        """Repaint label/state. Call after settings or embedding-model changes."""
        if self.is_mounted:
            self._refresh()

    def _embedding_ready(self) -> bool:
        # Search mode needs a usable embedding model; gate on its availability.
        return is_model_available(cfg.embedding_model, get_services().provider)

    def _refresh(self) -> None:
        """Sync pill CSS state and tooltip with cfg + embedding readiness."""
        ready = self._embedding_ready()
        # Without a ready embedding model, render as Chat mode even when
        # cfg still says Search.
        mode = cfg.chat_mode if ready else ChatMode.CHAT.value
        active_search = mode == ChatMode.SEARCH.value
        search_pill = self.query_one(f"#{_CHAT_MODE_SEARCH_PILL_ID}", ChatModePill)
        chat_pill = self.query_one(f"#{_CHAT_MODE_CHAT_PILL_ID}", ChatModePill)
        # Search half is disabled whenever embedding isn't ready; Chat is
        # always reachable so it never carries the disabled class.
        search_pill.set_class(active_search, _CHAT_MODE_ACTIVE_CLASS)
        search_pill.set_class(not ready, _CHAT_MODE_DISABLED_CLASS)
        chat_pill.set_class(not active_search, _CHAT_MODE_ACTIVE_CLASS)
        chat_pill.set_class(False, _CHAT_MODE_DISABLED_CLASS)
        # Parent carries the disabled class so external selectors can
        # disable interaction on the whole toggle when search is gated.
        self.set_class(not ready, _CHAT_MODE_DISABLED_CLASS)
        self.tooltip = (
            msg.CHAT_MODE_TOGGLE_DISABLED_TOOLTIP if not ready else msg.CHAT_MODE_TOGGLE_TOOLTIP
        )

    def _set_mode(self, target: str) -> bool:
        """Apply *target* if it differs from the current mode and Search is allowed."""
        if cfg.chat_mode == target:
            return False
        if target == ChatMode.SEARCH.value and not self._embedding_ready():
            return False
        apply_setting(self.app, "chat_mode", target)
        self._refresh()
        return True

    def toggle(self) -> bool:
        """Flip mode if embedding is ready. Returns True when the mode changed."""
        target = (
            ChatMode.CHAT.value if cfg.chat_mode == ChatMode.SEARCH.value else ChatMode.SEARCH.value
        )
        return self._set_mode(target)

    def on_click(self, event: events.Click) -> None:
        event.stop()
        # Click on a specific pill picks that side; click on the container
        # frame falls through to a toggle.
        widget = event.widget
        if widget is not None:
            wid = widget.id
            if wid == _CHAT_MODE_SEARCH_PILL_ID:
                self._set_mode(ChatMode.SEARCH.value)
                return
            if wid == _CHAT_MODE_CHAT_PILL_ID:
                self._set_mode(ChatMode.CHAT.value)
                return
        self.toggle()

    def action_flip_mode(self) -> None:
        """Binding hook: flip between Search and Chat."""
        self.toggle()

    def action_select_search(self) -> None:
        """Left-arrow binding: pick Search mode."""
        self._set_mode(ChatMode.SEARCH.value)

    def action_select_chat(self) -> None:
        """Right-arrow binding: pick Chat mode."""
        self._set_mode(ChatMode.CHAT.value)
class ModelBar(Widget, can_focus=False):
    """Compact bar with picker buttons for active model assignments + mode toggle."""

    # Declared for typing only: the running app is a LilbeeApp, so handlers
    # below can touch its signal attributes without isinstance guards.
    app: LilbeeApp  # type: ignore[assignment]
    # Stylesheet loaded once at class-definition time from the sibling file.
    DEFAULT_CSS: ClassVar[str] = _CSS_FILE.read_text(encoding="utf-8")

    def __init__(self, id: str | None = None) -> None:
        super().__init__(id=id)
        # _scan_models runs on every chat on_show but the install set rarely
        # changes between visits; fingerprint to skip redundant set_options.
        self._chat_options_cache: tuple[tuple[str, str], ...] = ()
        self._embed_options_cache: tuple[tuple[str, str], ...] = ()

    def compose(self) -> ComposeResult:
        """Lay out the label pills, picker buttons, mode toggle, and warning slot."""
        with Horizontal():
            yield Static(pill("Chat", "$primary", "$text"), classes="model-bar-pill")
            yield ModelPickerButton(scope="chat", button_id=_CHAT_MODEL_BUTTON_ID)
            yield Static(pill("Embed", "$secondary", "$text"), classes="model-bar-pill")
            yield ModelPickerButton(scope="embed", button_id=_EMBED_MODEL_BUTTON_ID)
            yield ChatModeToggle()
            yield Static("", id=_CLOUD_WARNING_ID, classes=_CLOUD_WARNING_CLASS)

    def on_mount(self) -> None:
        """Kick off the first model scan and subscribe to settings changes."""
        self._refresh_cloud_warning()
        self._scan_models()
        # External activation paths (Catalog screen, /model setting, settings UI)
        # publish on this signal but don't reach this widget's _refresh otherwise.
        # ``app: LilbeeApp`` is declared on the class; test hosts inherit
        # LilbeeAppHost so the signal attribute always exists. No isinstance
        # guard needed (AGENTS.md "no test-aware production branches").
        self.app.settings_changed_signal.subscribe(self, self._on_settings_changed)

    def _on_settings_changed(self, payload: tuple[str, object]) -> None:
        """Repaint the affected picker button when a model setting changes."""
        key, _ = payload
        if key == "chat_model":
            # suppress: the button may be gone while the bar is unmounting.
            with contextlib.suppress(Exception):
                self.query_one(f"#{_CHAT_MODEL_BUTTON_ID}", ModelPickerButton)._refresh()
            self._refresh_cloud_warning()
            self._refresh_chat_mode_toggle()
        elif key == "embedding_model":
            with contextlib.suppress(Exception):
                self.query_one(f"#{_EMBED_MODEL_BUTTON_ID}", ModelPickerButton)._refresh()

    @work(thread=True)
    def _scan_models(self) -> None:
        """Scan installed models in background, then populate buttons."""
        chat, embed = _classify_installed_models()
        # Widget mutation must happen on the UI thread.
        call_from_thread(self, self._populate, chat, embed)

    def _populate(
        self,
        chat_models: list[ModelOption],
        embed_models: list[ModelOption],
    ) -> None:
        """Push scan results into the buttons, skipping unchanged option sets."""
        # Empty scans still get a placeholder row so the picker isn't blank.
        chat_opts = list(chat_models) if chat_models else [ModelOption(msg.MODEL_VALUE_NONE, "")]
        embed_opts = list(embed_models) if embed_models else [ModelOption(msg.MODEL_VALUE_NONE, "")]
        chat_fingerprint = _options_fingerprint(chat_opts, cfg.chat_model)
        if chat_fingerprint != self._chat_options_cache:
            self.query_one(f"#{_CHAT_MODEL_BUTTON_ID}", ModelPickerButton).set_options(chat_opts)
            self._chat_options_cache = chat_fingerprint
        embed_fingerprint = _options_fingerprint(embed_opts, cfg.embedding_model)
        if embed_fingerprint != self._embed_options_cache:
            self.query_one(f"#{_EMBED_MODEL_BUTTON_ID}", ModelPickerButton).set_options(embed_opts)
            self._embed_options_cache = embed_fingerprint
        self._refresh_cloud_warning()
        self._refresh_chat_mode_toggle()

    def _refresh_cloud_warning(self) -> None:
        """Show a warning if the active chat model routes to a cloud API."""
        warning = self.query_one(f"#{_CLOUD_WARNING_ID}", Static)
        label = _cloud_provider_label(cfg.chat_model)
        if label is None:
            warning.remove_class(_CLOUD_WARNING_VISIBLE_CLASS)
            return
        warning.update(msg.MODEL_BAR_CLOUD_PROVIDER_WARNING.format(provider=label))
        warning.add_class(_CLOUD_WARNING_VISIBLE_CLASS)

    def _refresh_chat_mode_toggle(self) -> None:
        # suppress: the toggle may be absent during teardown or partial mounts.
        with contextlib.suppress(Exception):
            self.query_one(ChatModeToggle).refresh_state()

    def _after_model_change(self, scope: Literal["chat", "embed"]) -> None:
        """Apply the side-effect of the role's model swap.

        Chat-scope swaps route through :meth:`ChatScreen.apply_model_change`
        so the in-flight stream cancels under the same UX that ``/model``
        provides. Embed-scope swaps respawn only the embed worker via
        :meth:`Services.reload_role`; the chat worker and any active
        stream are untouched. Off-chat-screen chat swaps fall through to
        a full ``reset_services`` because the chat-cancel path needs the
        ChatScreen state machine.
        """
        if scope == "embed":
            get_services().reload_role(WorkerRole.EMBED)
            return

        from lilbee.cli.tui.screens.chat import ChatScreen

        screen = self.app.screen
        if isinstance(screen, ChatScreen):
            screen.apply_model_change()
        else:
            reset_services()

    def refresh_models(self) -> None:
        """Re-scan models (called after downloads complete)."""
        self._scan_models()