Coverage for src / lilbee / retrieval / reasoning.py: 100%
130 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-05-15 20:55 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-05-15 20:55 +0000
1"""Reasoning token filter and cap-aware chat orchestrator.
3Reasoning models (Qwen3, DeepSeek-R1) wrap their thinking process in
4``<think>...</think>`` tags. This module provides:
6- ``filter_reasoning``: a stateful streaming filter that classifies
7 tokens as reasoning vs response and signals when reasoning exceeds a
8 caller-supplied cap.
9- ``stream_chat_with_cap``: the high-level orchestrator. Wraps a
10 provider call with the filter; when the cap fires, re-issues the
11 chat with a "stop thinking, answer directly" nudge. All chat surfaces
12 (HTTP/SSE, CLI, TUI) consume this so cap behavior is uniform.
13- ``effective_reasoning_cap``: resolves the cap from the global config
14 with per-model ``ModelDefaults`` overrides.
15"""
17from __future__ import annotations
19import contextlib
20import re
21from collections.abc import Callable, Generator, Iterator
22from dataclasses import dataclass
23from typing import TYPE_CHECKING, Any
25from lilbee.core.config import cfg
26from lilbee.providers.base import ClosableIterator
28if TYPE_CHECKING:
29 from lilbee.providers.base import LLMProvider
# Literal tags reasoning models (Qwen3, DeepSeek-R1) wrap their thinking in.
_OPEN_TAG = "<think>"
_CLOSE_TAG = "</think>"
# First alternative: a complete <think>...</think> block plus any trailing
# whitespace. Second alternative: an unterminated <think> running to the end
# of the string (streaming output can be cut off mid-thought).
_THINK_BLOCK_RE = re.compile(r"<think>[\s\S]*?</think>\s*|<think>[\s\S]*$")
_PROGRESS_TICK_CHARS = 256
"""Coarseness of the progress callback: fire when reasoning grows by at least this many chars."""

CAP_CONTINUATION_PROMPT = (
    "Stop thinking now. Give your final answer directly, without any further <think> blocks."
)
"""The user-message nudge appended on the continuation call after the cap fires."""

CAP_NOTICE_TEMPLATE = "\n[reasoning capped at {chars} chars, asking for a direct answer]\n"
"""User-visible marker emitted between the truncated reasoning and the continuation answer."""
@dataclass
class StreamToken:
    """A classified token from the stream."""

    content: str  # token text; callers generally drop empty-content tokens
    is_reasoning: bool  # True when the text came from inside a <think> block
@dataclass
class CapNotice:
    """Emitted once when the reasoning cap fires, before the continuation stream."""

    cap_chars: int  # the reasoning-character cap that was exceeded
@dataclass
class _TagParser:
    """Stateful parser that tracks whether we're inside a thinking block.

    Streaming tokens can split ``<think>`` / ``</think>`` across chunk
    boundaries, so unconsumed text is buffered until a tag can be ruled
    in or out.
    """

    show: bool  # when False, reasoning is counted but its text is suppressed
    buf: str = ""  # unconsumed text carried across feed() calls
    in_thinking: bool = False  # currently inside a <think>...</think> block
    reasoning_chars: int = 0  # running total of reasoning characters seen

    def feed(self, token: str) -> list[StreamToken]:
        """Feed a token and return any complete StreamTokens."""
        self.buf += token
        result: list[StreamToken] = []
        while self.buf:
            emitted = self._process_thinking() if self.in_thinking else self._process_normal()
            if emitted is None:
                # Buffer ends with what could be a partial tag; wait for more input.
                break
            if emitted.content:
                result.append(emitted)
        return result

    def flush(self) -> StreamToken | None:
        """Flush remaining buffer at end of stream."""
        if not self.buf:
            return None
        if self.in_thinking:
            # Unterminated <think> block: still counts toward the reasoning total.
            self.reasoning_chars += len(self.buf)
            return StreamToken(content=self.buf, is_reasoning=True) if self.show else None
        return StreamToken(content=self.buf, is_reasoning=False)

    def _process_thinking(self) -> StreamToken | None:
        # Inside a <think> block: scan for the closing tag.
        close_idx = self.buf.find(_CLOSE_TAG)
        if close_idx == -1:
            if _could_be_partial(_CLOSE_TAG, self.buf):
                return None  # tag may be split across chunks; keep buffering
            content = self.buf
            self.reasoning_chars += len(content)
            self.buf = ""
            # When show is False, return an empty token (not None) so feed()
            # sees progress; it drops empty-content tokens but the chars above
            # were still counted.
            return (
                StreamToken(content=content, is_reasoning=True)
                if self.show
                else StreamToken(content="", is_reasoning=True)
            )
        thinking_content = self.buf[:close_idx]
        self.reasoning_chars += len(thinking_content)
        self.buf = self.buf[close_idx + len(_CLOSE_TAG) :]
        self.in_thinking = False
        if thinking_content and self.show:
            return StreamToken(content=thinking_content, is_reasoning=True)
        return StreamToken(content="", is_reasoning=True)

    def _process_normal(self) -> StreamToken | None:
        # Outside a thinking block: scan for an opening tag.
        open_idx = self.buf.find(_OPEN_TAG)
        if open_idx == -1:
            if _could_be_partial(_OPEN_TAG, self.buf):
                return None  # tag may be split across chunks; keep buffering
            content = self.buf
            self.buf = ""
            return StreamToken(content=content, is_reasoning=False)
        before = self.buf[:open_idx]
        self.buf = self.buf[open_idx + len(_OPEN_TAG) :]
        self.in_thinking = True
        return StreamToken(content=before, is_reasoning=False)
def filter_reasoning(
    tokens: Iterator[str],
    *,
    show: bool,
    cap_chars: int,
    on_cap: Callable[[], None] | None = None,
    on_progress: Callable[[int], None] | None = None,
) -> Iterator[StreamToken]:
    """Classify ``<think>...</think>`` tokens and stop when reasoning exceeds the cap.

    *cap_chars* bounds the total reasoning content. Once exceeded, ``on_cap``
    is invoked (no payload), the upstream iterator is closed, and iteration
    ends; the caller decides what happens next via the higher-level
    ``stream_chat_with_cap`` orchestrator. *on_progress* receives the running
    reasoning-chars count each time it has grown by at least 256 characters
    since the last report. A non-positive *cap_chars* disables the cap.
    """
    tag_parser = _TagParser(show=show)
    reported = 0
    try:
        for raw in tokens:
            for piece in tag_parser.feed(raw):
                if piece.content:
                    yield piece
            count = tag_parser.reasoning_chars
            if on_progress is not None and count - reported >= _PROGRESS_TICK_CHARS:
                reported = count
                on_progress(count)
            if 0 < cap_chars < count:
                if on_cap is not None:
                    on_cap()
                return
        leftover = tag_parser.flush()
        if leftover is not None and leftover.content:
            yield leftover
        if on_progress is not None and tag_parser.reasoning_chars > reported:
            on_progress(tag_parser.reasoning_chars)
    finally:
        # Always release the upstream iterator, including on cap-fire and
        # early generator close.
        _close_iterator(tokens)
def effective_reasoning_cap() -> int:
    """Return the active reasoning cap; 0 means unlimited.

    A per-model ``ModelDefaults.max_reasoning_chars`` value (including ``0``
    for "this model is allowed to think forever") beats the global
    ``cfg.max_reasoning_chars`` setting. Only ``None`` (or a negative value)
    falls through to the global, so an explicit per-model 0 means the user
    opted that model out of the cap.
    """
    defaults = cfg.model_defaults
    if defaults is not None:
        per_model = defaults.max_reasoning_chars
        # A non-negative int is an explicit per-model choice — honor it.
        if isinstance(per_model, int) and per_model >= 0:
            return per_model
    return cfg.max_reasoning_chars
def stream_chat_with_cap(
    provider: LLMProvider,
    messages: list[dict[str, Any]],
    *,
    options: dict[str, Any] | None,
    model: str,
    show_reasoning: bool,
    cap_chars: int,
) -> Generator[StreamToken | CapNotice, None, None]:
    """Stream chat tokens; on cap-fire, re-issue with a stop-thinking nudge.

    Yields ``StreamToken`` events for both reasoning and response tokens in
    the first pass. If reasoning exceeds *cap_chars*, the upstream iterator
    is closed, a single ``CapNotice`` is yielded, and the continuation
    stream starts (same messages plus a user message asking the model to
    answer directly). Continuation tokens stream as
    ``StreamToken(is_reasoning=False)``.
    """
    capped = False

    def _note_cap() -> None:
        nonlocal capped
        capped = True

    # First pass: run the provider stream through the reasoning filter so
    # the cap can fire mid-stream.
    initial = provider.chat(messages, stream=True, options=options or None, model=model)
    yield from filter_reasoning(
        initial,
        show=show_reasoning,
        cap_chars=cap_chars,
        on_cap=_note_cap,
    )
    if not capped:
        return

    # Cap fired: announce it, then ask the model for a direct answer.
    yield CapNotice(cap_chars=cap_chars)
    nudged_messages = [*messages, {"role": "user", "content": CAP_CONTINUATION_PROMPT}]
    continuation = provider.chat(nudged_messages, stream=True, options=options or None, model=model)
    try:
        for chunk in continuation:
            if chunk:
                yield StreamToken(content=chunk, is_reasoning=False)
    finally:
        _close_iterator(continuation)
def cap_events_as_stream_tokens(
    events: Iterator[StreamToken | CapNotice],
) -> Iterator[StreamToken]:
    """Render ``CapNotice`` events as user-visible reasoning ``StreamToken``s.

    Library and CLI surfaces speak ``StreamToken`` only. This helper lets
    them consume the orchestrator's union output without a per-call
    isinstance dance for the cap notice.
    """
    for event in events:
        if isinstance(event, CapNotice):
            notice_text = CAP_NOTICE_TEMPLATE.format(chars=event.cap_chars)
            yield StreamToken(content=notice_text, is_reasoning=True)
            continue
        # Plain stream token: forward it, skipping empty-content tokens.
        if event.content:
            yield event
def _close_iterator(tokens: Iterator[str]) -> None:
    """Best-effort close of *tokens* when it implements the ClosableIterator protocol."""
    if not isinstance(tokens, ClosableIterator):
        return
    try:
        tokens.close()
    except Exception:  # noqa: BLE001 — closing is best-effort; never raise here
        pass
def strip_reasoning(text: str) -> str:
    """Remove ``<think>...</think>`` blocks from a complete (non-streaming) string.

    Handles both terminated blocks (including any trailing whitespace) and a
    final unterminated ``<think>`` that runs to the end of the string. The
    ``re`` module caches compiled patterns, so inlining the pattern here is
    equivalent to a module-level compile.
    """
    return re.sub(r"<think>[\s\S]*?</think>\s*|<think>[\s\S]*$", "", text)
258def _could_be_partial(tag: str, buf: str) -> bool:
259 """Check if the end of buf could be the start of the given tag."""
260 return any(buf.endswith(tag[:length]) for length in range(1, len(tag)))