Coverage for src / lilbee / cli / tui / screens / chat_helpers.py: 100%
105 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-06-28 01:01 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-06-28 01:01 +0000
1"""Module-level helpers used by ChatScreen: progress callbacks, file cleanup, stream close."""
3from __future__ import annotations
5import contextlib
6import logging
7import shutil
8import time
9from collections.abc import Callable
10from dataclasses import dataclass
11from typing import Any
13from lilbee.cli.tui import messages as msg
14from lilbee.cli.tui.widgets.task_bar_controller import ProgressReporter
15from lilbee.core.config import cfg
16from lilbee.providers.base import ClosableIterator
17from lilbee.runtime.progress import (
18 BatchProgressEvent,
19 BatchStatus,
20 DetailedProgressCallback,
21 EmbedEvent,
22 EventType,
23 ExtractEvent,
24 FileDoneEvent,
25 FileStartEvent,
26 ProgressEvent,
27 SyncDoneEvent,
28)
# Module logger; remove_copied_files reports cleanup failures here at debug level.
log = logging.getLogger(__name__)

_ADD_EMBED_THROTTLE_SECONDS = 0.15
"""Minimum interval, in seconds, between EMBED reporter updates.

Shared by the /add and sync progress callbacks to avoid TaskBar update
storms: the embed worker fires one EmbedEvent per sub-batch, which on a fast
laptop can be dozens per second. The Task Center reportedly repaints at about
10 Hz, so updates faster than that are wasted work; 0.15 s (~6.7 Hz) keeps
EMBED updates a little below that cadence rather than exactly at it.
"""
def close_stream(stream: Any) -> None:
    """Best-effort close of *stream* when it implements ``ClosableIterator``.

    Objects that do not satisfy the protocol are left untouched, and any
    exception raised by ``close()`` is swallowed.
    """
    if not isinstance(stream, ClosableIterator):
        return
    with contextlib.suppress(Exception):
        stream.close()
def detail_for_batch_progress(data: BatchProgressEvent, in_flight: list[str]) -> str:
    """Choose the user-facing detail label for one BATCH_PROGRESS tick.

    Only per-page rasterization (vision OCR) reports BatchStatus.RASTERIZING.
    It puts an absolute path in ``data.file``, which never equals the
    relative source names held in ``in_flight``, so matching on file
    identity would never detect it. Dispatching on the status is the
    reliable way to tell per-page ticks from per-file ticks.

    Returns:
        A page-progress label for rasterization ticks; otherwise a "syncing"
        label pinned to the oldest in-flight file, or a "done" label for
        ``data.file`` when nothing is in flight.
    """
    if data.status == BatchStatus.RASTERIZING:
        return msg.ADD_PAGE_PROGRESS.format(
            status=data.status.capitalize(),
            current=data.current,
            total=data.total,
        )
    if not in_flight:
        return msg.ADD_FILE_DONE.format(file=data.file)
    oldest = in_flight[0]
    return msg.ADD_SYNCING_FILE.format(file=oldest)
# Leading marker (compared case-insensitively by remember_from_input) that
# stores /remember text as a preference instead of a fact.
_PREFERENCE_PREFIX = "pref:"


@dataclass(frozen=True)
class RememberOutcome:
    """A /remember result: the toast message plus the notify severity to use."""

    message: str  # text shown in the toast
    # Notify severity; remember_from_input keeps the "information" default for
    # success and uses "warning" on every rejection path.
    severity: str = "information"
def remember_from_input(raw: str) -> RememberOutcome:
    """Handle a ``/remember`` command end to end and describe the toast to show.

    Keeps the ``@work`` worker body to a single call and lets the
    parse/gate/store path be tested without a running TUI. Text beginning
    with ``pref:`` (any letter case) is saved as an always-recalled
    preference; everything else is saved as a fact.

    Rejections (memory disabled, empty text, no embedding backend) return a
    "warning" outcome without storing anything.
    """
    from lilbee.app.memory import MEMORY_DISABLED_HINT, memory_enabled, remember
    from lilbee.app.services import get_services
    from lilbee.data.store import MemoryKind

    if not memory_enabled():
        return RememberOutcome(MEMORY_DISABLED_HINT, "warning")

    body = raw.strip()
    prefix_len = len(_PREFERENCE_PREFIX)
    is_preference = body[:prefix_len].lower() == _PREFERENCE_PREFIX
    kind = MemoryKind.PREFERENCE if is_preference else MemoryKind.FACT
    if is_preference:
        body = body[prefix_len:].strip()

    if not body:
        return RememberOutcome(msg.CMD_REMEMBER_USAGE, "warning")
    if not get_services().embedder.embedding_available():
        return RememberOutcome(msg.CMD_REMEMBER_NO_EMBED, "warning")

    remember(body, kind=kind)
    return RememberOutcome(msg.CMD_REMEMBER_SUCCESS.format(kind=kind.value))
def remove_copied_files(names: list[str]) -> None:
    """Delete files previously copied into documents/ by a /add invocation.

    Called on cancel or failure of the add task so a cancelled file does not
    re-appear on the next sync. Missing entries are tolerated silently: the
    user may have removed them concurrently, and the goal is only to prevent
    accidental indexing. Any other OSError is logged at debug level and the
    remaining names are still processed.

    Args:
        names: Entry names relative to ``cfg.documents_dir``.
    """
    for name in names:
        target = cfg.documents_dir / name
        try:
            # Test for a symlink first. is_dir() follows links, and
            # shutil.rmtree refuses to operate on a symlink; that refusal
            # would be swallowed by ignore_errors and leave the link behind.
            # A dangling link also reports exists() == False, so it would
            # otherwise never be removed.
            if target.is_symlink():
                target.unlink(missing_ok=True)
            elif target.is_dir():
                shutil.rmtree(target, ignore_errors=True)
            else:
                # missing_ok absorbs the entry vanishing between the checks
                # above and this call (no separate exists() race).
                target.unlink(missing_ok=True)
        except OSError:
            log.debug("Could not remove copied file %s", target, exc_info=True)
def build_add_progress_callback(reporter: ProgressReporter) -> DetailedProgressCallback:
    """Build the on_progress callback used by /add.

    Tracks files in flight in start order so the displayed filename pins
    to the oldest unfinished file (the pipeline runs files concurrently;
    without pinning the label flips around the queue). EXTRACT surfaces
    "extracted N pages" once per file so a 44MB scanned PDF doesn't read
    as a hang; EMBED ticks per chunk, throttled to a steady cadence, but the
    final chunk is always reported so the bar can reach 100%.

    Args:
        reporter: Task-bar reporter receiving the updates. Its
            ``check_cancelled()`` runs on every event so a cancelled task
            stops at the next tick.

    Returns:
        A callback taking ``(event_type, data)``. Events of other types, or
        whose payload is not the expected event class, only trigger the
        cancellation check.
    """
    # Relative names of started-but-unfinished files, oldest first.
    in_flight: list[str] = []
    # time.monotonic() timestamp of the last EMBED update forwarded.
    last_embed_update = 0.0

    def on_progress(event_type: EventType, data: ProgressEvent) -> None:
        nonlocal last_embed_update
        reporter.check_cancelled()
        if event_type == EventType.FILE_START and isinstance(data, FileStartEvent):
            in_flight.append(data.file)
            reporter.update(0, msg.ADD_SYNCING_FILE.format(file=in_flight[0]), indeterminate=True)
        elif event_type == EventType.FILE_DONE and isinstance(data, FileDoneEvent):
            # Tolerate a FILE_DONE for a name that is not being tracked.
            with contextlib.suppress(ValueError):
                in_flight.remove(data.file)
        elif event_type == EventType.BATCH_PROGRESS and isinstance(data, BatchProgressEvent):
            pct = (data.current / data.total * 100.0) if data.total else 0.0
            reporter.update(pct, detail_for_batch_progress(data, in_flight), indeterminate=False)
        elif event_type == EventType.EXTRACT and isinstance(data, ExtractEvent):
            reporter.update(
                0,
                msg.SYNC_FILE_PROGRESS.format(
                    current=data.page, total=data.total_pages, file=data.file
                ),
                indeterminate=True,
            )
        elif event_type == EventType.EMBED and isinstance(data, EmbedEvent):
            now = time.monotonic()
            # chunk == total_chunks is the 100% tick (see the pct formula
            # below). Never throttle it away, or the bar can stay parked
            # below 100% until some unrelated event arrives.
            is_final_chunk = bool(data.total_chunks) and data.chunk >= data.total_chunks
            if not is_final_chunk and now - last_embed_update < _ADD_EMBED_THROTTLE_SECONDS:
                return
            last_embed_update = now
            pct = int(data.chunk * 100 / data.total_chunks) if data.total_chunks else 0
            reporter.update(pct, msg.SYNC_EMBEDDING.format(file=data.file), indeterminate=False)

    return on_progress
def build_sync_progress_callback(
    reporter: ProgressReporter,
) -> Callable[[EventType, ProgressEvent], None]:
    """Return the on_progress shim used by ``_do_sync``.

    FILE_START reports the share of files already processed. EXTRACT mirrors
    the /add path: a 44MB scanned PDF needs a per-page tick or the row reads
    as frozen. EMBED is throttled to ``_ADD_EMBED_THROTTLE_SECONDS``, except
    for the final chunk, so the bar can reach 100%. DONE reports the combined
    count of added, updated and removed files.

    Args:
        reporter: Task-bar reporter receiving the updates; its
            ``check_cancelled()`` runs on every event.

    Returns:
        A callback taking ``(event_type, data)``. Events of other types, or
        whose payload is not the expected event class, only trigger the
        cancellation check.
    """
    # time.monotonic() timestamp of the last EMBED update forwarded.
    last_embed_update = 0.0

    def on_progress(event_type: EventType, data: ProgressEvent) -> None:
        nonlocal last_embed_update
        # Mirror /add: explicit cancel check on every event so a SYNC task
        # cancelled mid-batch stops at the next progress tick instead of
        # finishing the current file. update() also checks, but events
        # without a reporter.update call (e.g. BATCH_PROGRESS in the
        # ingest_batch path) would otherwise miss the cooperative checkpoint.
        reporter.check_cancelled()
        if event_type == EventType.FILE_START and isinstance(data, FileStartEvent):
            # current_file appears to be 1-based (hence the -1, so the first
            # file starts at 0%). Guard total_files the same way EMBED guards
            # total_chunks: a zero total must not raise ZeroDivisionError
            # from inside the progress callback.
            pct = (
                int((data.current_file - 1) * 100 / data.total_files)
                if data.total_files
                else 0
            )
            status = msg.SYNC_FILE_PROGRESS.format(
                current=data.current_file, total=data.total_files, file=data.file
            )
            reporter.update(pct, status, indeterminate=False)
        elif event_type == EventType.FILE_DONE and isinstance(data, FileDoneEvent):
            reporter.update(0, msg.SYNC_FILE_DONE.format(file=data.file), indeterminate=False)
        elif event_type == EventType.EXTRACT and isinstance(data, ExtractEvent):
            reporter.update(
                0,
                msg.SYNC_FILE_PROGRESS.format(
                    current=data.page, total=data.total_pages, file=data.file
                ),
                indeterminate=True,
            )
        elif event_type == EventType.EMBED and isinstance(data, EmbedEvent):
            now = time.monotonic()
            # chunk == total_chunks is the 100% tick (see the pct formula
            # below); never throttle it away.
            is_final_chunk = bool(data.total_chunks) and data.chunk >= data.total_chunks
            if not is_final_chunk and now - last_embed_update < _ADD_EMBED_THROTTLE_SECONDS:
                return
            last_embed_update = now
            pct = int(data.chunk * 100 / data.total_chunks) if data.total_chunks else 0
            reporter.update(pct, msg.SYNC_EMBEDDING.format(file=data.file), indeterminate=False)
        elif event_type == EventType.DONE and isinstance(data, SyncDoneEvent):
            total = data.added + data.updated + data.removed
            reporter.update(100, msg.SYNC_STATUS_DONE.format(count=total), indeterminate=False)

    return on_progress