Coverage for src/lilbee/retrieval/reasoning.py: 100% (130 statements)

1"""Reasoning token filter and cap-aware chat orchestrator. 

2 

3Reasoning models (Qwen3, DeepSeek-R1) wrap their thinking process in 

4``<think>...</think>`` tags. This module provides: 

5 

6- ``filter_reasoning``: a stateful streaming filter that classifies 

7 tokens as reasoning vs response and signals when reasoning exceeds a 

8 caller-supplied cap. 

9- ``stream_chat_with_cap``: the high-level orchestrator. Wraps a 

10 provider call with the filter; when the cap fires, re-issues the 

11 chat with a "stop thinking, answer directly" nudge. All chat surfaces 

12 (HTTP/SSE, CLI, TUI) consume this so cap behavior is uniform. 

13- ``effective_reasoning_cap``: resolves the cap from the global config 

14 with per-model ``ModelDefaults`` overrides. 

15""" 

16 

17from __future__ import annotations 

18 

19import contextlib 

20import re 

21from collections.abc import Callable, Generator, Iterator 

22from dataclasses import dataclass 

23from typing import TYPE_CHECKING, Any 

24 

25from lilbee.core.config import cfg 

26from lilbee.providers.base import ClosableIterator 

27 

28if TYPE_CHECKING: 

29 from lilbee.providers.base import LLMProvider 

30 

_OPEN_TAG = "<think>"
_CLOSE_TAG = "</think>"
_THINK_BLOCK_RE = re.compile(r"<think>[\s\S]*?</think>\s*|<think>[\s\S]*$")
_PROGRESS_TICK_CHARS = 256
"""Coarseness of the progress callback: fire when reasoning grows by at least this many chars."""

CAP_CONTINUATION_PROMPT = (
    "Stop thinking now. Give your final answer directly, without any further <think> blocks."
)
"""The user-message nudge appended on the continuation call after the cap fires."""

CAP_NOTICE_TEMPLATE = "\n[reasoning capped at {chars} chars, asking for a direct answer]\n"
"""User-visible marker emitted between the truncated reasoning and the continuation answer."""


@dataclass
class StreamToken:
    """A classified token from the stream."""

    content: str
    is_reasoning: bool


@dataclass
class CapNotice:
    """Emitted once when the reasoning cap fires, before the continuation stream."""

    cap_chars: int


@dataclass
class _TagParser:
    """Stateful parser that tracks whether we're inside a thinking block."""

    show: bool
    buf: str = ""
    in_thinking: bool = False
    reasoning_chars: int = 0

    def feed(self, token: str) -> list[StreamToken]:
        """Feed a token and return any complete StreamTokens."""
        self.buf += token
        result: list[StreamToken] = []
        while self.buf:
            emitted = self._process_thinking() if self.in_thinking else self._process_normal()
            if emitted is None:
                break
            if emitted.content:
                result.append(emitted)
        return result

    def flush(self) -> StreamToken | None:
        """Flush remaining buffer at end of stream."""
        if not self.buf:
            return None
        if self.in_thinking:
            self.reasoning_chars += len(self.buf)
            return StreamToken(content=self.buf, is_reasoning=True) if self.show else None
        return StreamToken(content=self.buf, is_reasoning=False)

    def _process_thinking(self) -> StreamToken | None:
        close_idx = self.buf.find(_CLOSE_TAG)
        if close_idx == -1:
            if _could_be_partial(_CLOSE_TAG, self.buf):
                return None
            content = self.buf
            self.reasoning_chars += len(content)
            self.buf = ""
            return (
                StreamToken(content=content, is_reasoning=True)
                if self.show
                else StreamToken(content="", is_reasoning=True)
            )
        thinking_content = self.buf[:close_idx]
        self.reasoning_chars += len(thinking_content)
        self.buf = self.buf[close_idx + len(_CLOSE_TAG) :]
        self.in_thinking = False
        if thinking_content and self.show:
            return StreamToken(content=thinking_content, is_reasoning=True)
        return StreamToken(content="", is_reasoning=True)

    def _process_normal(self) -> StreamToken | None:
        open_idx = self.buf.find(_OPEN_TAG)
        if open_idx == -1:
            if _could_be_partial(_OPEN_TAG, self.buf):
                return None
            content = self.buf
            self.buf = ""
            return StreamToken(content=content, is_reasoning=False)
        before = self.buf[:open_idx]
        self.buf = self.buf[open_idx + len(_OPEN_TAG) :]
        self.in_thinking = True
        return StreamToken(content=before, is_reasoning=False)
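
# A hypothetical usage sketch (doctest-style, not part of the module): the
# parser holds back a suffix that could be a partial tag instead of leaking
# it into the response stream, then reclassifies once the tag completes.
#
#   >>> p = _TagParser(show=True)
#   >>> p.feed("<th")  # could be a partial "<think>", so nothing is emitted yet
#   []
#   >>> [(t.content, t.is_reasoning) for t in p.feed("ink>hmm</think>ok")]
#   [('hmm', True), ('ok', False)]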

def filter_reasoning(
    tokens: Iterator[str],
    *,
    show: bool,
    cap_chars: int,
    on_cap: Callable[[], None] | None = None,
    on_progress: Callable[[int], None] | None = None,
) -> Iterator[StreamToken]:
    """Classify ``<think>...</think>`` tokens and stop when reasoning exceeds the cap.

    *cap_chars* bounds reasoning content. When exceeded, ``on_cap`` is
    fired (no payload), the upstream iterator is closed, and iteration
    stops. The caller decides what to do next via the higher-level
    ``stream_chat_with_cap`` orchestrator. *on_progress* is fired with
    the running reasoning-chars count each time it grows by at least
    ``_PROGRESS_TICK_CHARS`` (256) characters. A non-positive
    *cap_chars* disables the cap.
    """
    parser = _TagParser(show=show)
    last_progress_tick = 0
    try:
        for token in tokens:
            for st in parser.feed(token):
                if st.content:
                    yield st
            if (
                on_progress is not None
                and parser.reasoning_chars >= last_progress_tick + _PROGRESS_TICK_CHARS
            ):
                last_progress_tick = parser.reasoning_chars
                on_progress(parser.reasoning_chars)
            if cap_chars > 0 and parser.reasoning_chars > cap_chars:
                if on_cap is not None:
                    on_cap()
                return
        final = parser.flush()
        if final and final.content:
            yield final
        if on_progress is not None and parser.reasoning_chars > last_progress_tick:
            on_progress(parser.reasoning_chars)
    finally:
        _close_iterator(tokens)
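
# A hypothetical usage sketch (not part of the module): with reasoning hidden,
# only response tokens come through; ``cap_chars=0`` disables the cap entirely.
#
#   >>> toks = iter(["<think>mulling it over</think>", "Answer: 42"])
#   >>> [t.content for t in filter_reasoning(toks, show=False, cap_chars=0)]
#   ['Answer: 42']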

def effective_reasoning_cap() -> int:
    """Return the active reasoning cap; 0 means unlimited.

    A per-model ``ModelDefaults.max_reasoning_chars`` value (including
    ``0`` for "this model is allowed to think forever") beats the global
    ``cfg.max_reasoning_chars`` setting. Only ``None`` falls through to
    the global, so a per-model ``0`` means the user explicitly opted that
    model out of the cap.
    """
    defaults = cfg.model_defaults
    override = defaults.max_reasoning_chars if defaults is not None else None
    return override if isinstance(override, int) and override >= 0 else cfg.max_reasoning_chars
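
# Resolution sketch (column values are illustrative, not real defaults): a
# non-negative per-model int always wins; only ``None`` defers to the global.
#
#   ModelDefaults.max_reasoning_chars | cfg.max_reasoning_chars | effective cap
#   None                              | 8000                    | 8000
#   2000                              | 8000                    | 2000
#   0                                 | 8000                    | 0 (unlimited)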

def stream_chat_with_cap(
    provider: LLMProvider,
    messages: list[dict[str, Any]],
    *,
    options: dict[str, Any] | None,
    model: str,
    show_reasoning: bool,
    cap_chars: int,
) -> Generator[StreamToken | CapNotice, None, None]:
    """Stream chat tokens; on cap-fire, re-issue with a stop-thinking nudge.

    Yields ``StreamToken`` events for both reasoning and response tokens
    in the first pass. If reasoning exceeds *cap_chars*, the upstream
    iterator is closed, a single ``CapNotice`` is yielded, and the
    continuation stream starts (same messages plus a user message asking
    the model to answer directly). Continuation tokens stream as
    ``StreamToken(is_reasoning=False)``.
    """
    cap_fired = False

    def _on_cap() -> None:
        nonlocal cap_fired
        cap_fired = True

    first_stream = provider.chat(messages, stream=True, options=options or None, model=model)
    yield from filter_reasoning(
        first_stream,
        show=show_reasoning,
        cap_chars=cap_chars,
        on_cap=_on_cap,
    )
    if not cap_fired:
        return
    yield CapNotice(cap_chars=cap_chars)
    nudged = [*messages, {"role": "user", "content": CAP_CONTINUATION_PROMPT}]
    second_stream = provider.chat(nudged, stream=True, options=options or None, model=model)
    try:
        for chunk in second_stream:
            if chunk:
                yield StreamToken(content=chunk, is_reasoning=False)
    finally:
        _close_iterator(second_stream)
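
# A hypothetical end-to-end sketch with a duck-typed provider stub
# (``_StubProvider`` is not part of lilbee): once reasoning blows past the cap,
# the event stream is first-pass tokens, one ``CapNotice``, then continuation
# response tokens.
#
#   >>> class _StubProvider:
#   ...     calls = 0
#   ...     def chat(self, messages, *, stream, options, model):
#   ...         self.calls += 1
#   ...         if self.calls == 1:
#   ...             return iter(["<think>" + "x" * 999])  # runaway thinking
#   ...         return iter(["Direct answer."])  # continuation after the nudge
#   >>> events = list(stream_chat_with_cap(
#   ...     _StubProvider(), [], options=None, model="m",
#   ...     show_reasoning=False, cap_chars=100,
#   ... ))
#   >>> [type(e).__name__ for e in events]
#   ['CapNotice', 'StreamToken']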

def cap_events_as_stream_tokens(
    events: Iterator[StreamToken | CapNotice],
) -> Iterator[StreamToken]:
    """Render ``CapNotice`` events as user-visible reasoning ``StreamToken``s.

    Library and CLI surfaces speak ``StreamToken`` only. This helper lets
    them consume the orchestrator's union output without a per-call
    ``isinstance`` dance for the cap notice.
    """
    for event in events:
        if isinstance(event, CapNotice):
            yield StreamToken(
                content=CAP_NOTICE_TEMPLATE.format(chars=event.cap_chars),
                is_reasoning=True,
            )
        elif event.content:
            yield event
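
# A hypothetical usage sketch (not part of the module): the notice is rendered
# through ``CAP_NOTICE_TEMPLATE`` as a reasoning token, so plain-token surfaces
# can print it inline.
#
#   >>> events = iter([CapNotice(cap_chars=100), StreamToken("hi", is_reasoning=False)])
#   >>> [t.content for t in cap_events_as_stream_tokens(events)]
#   ['\n[reasoning capped at 100 chars, asking for a direct answer]\n', 'hi']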

def _close_iterator(tokens: Iterator[str]) -> None:
    """Close *tokens* if it satisfies the ``ClosableIterator`` protocol."""
    if isinstance(tokens, ClosableIterator):
        with contextlib.suppress(Exception):
            tokens.close()


def strip_reasoning(text: str) -> str:
    """Remove ``<think>...</think>`` blocks from a complete (non-streaming) string."""
    return _THINK_BLOCK_RE.sub("", text)
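
# Doctest-style examples (not part of the module): both closed and unclosed
# ``<think>`` blocks are removed, along with whitespace after a closed block.
#
#   >>> strip_reasoning("<think>hmm</think>  The answer is 42.")
#   'The answer is 42.'
#   >>> strip_reasoning("Before. <think>still going")
#   'Before. '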

def _could_be_partial(tag: str, buf: str) -> bool:
    """Check if the end of buf could be the start of the given tag."""
    return any(buf.endswith(tag[:length]) for length in range(1, len(tag)))
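
# Doctest-style examples (not part of the module): a proper suffix of the tag
# counts as partial; a complete tag does not, since ``find()`` already handles it.
#
#   >>> _could_be_partial("</think>", "reasoning about</thi")
#   True
#   >>> _could_be_partial("</think>", "done</think>")
#   False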