Coverage for src / lilbee / cli / commands / search_chat.py: 100%

179 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-06-28 01:01 +0000

1"""Search, ask, chat, and topics commands.""" 

2 

3from __future__ import annotations 

4 

5import sys 

6from pathlib import Path 

7from typing import NoReturn 

8 

9import typer 

10from rich.table import Table 

11 

12from lilbee.app.search import clean_result 

13from lilbee.app.services import get_services 

14from lilbee.cli import theme 

15from lilbee.cli.app import ( 

16 apply_overrides, 

17 console, 

18 data_dir_option, 

19 global_option, 

20 model_option, 

21 num_ctx_option, 

22 repeat_penalty_option, 

23 seed_option, 

24 temperature_option, 

25 top_k_sampling_option, 

26 top_p_option, 

27) 

28from lilbee.cli.commands._shared import CHUNK_PREVIEW_LEN 

29from lilbee.cli.helpers import ( 

30 auto_sync, 

31 json_output, 

32) 

33from lilbee.core.config import cfg 

34from lilbee.data.store import EmbeddingModelMismatchError, SearchScope, scope_to_chunk_type 

35from lilbee.providers.base import ProviderError 

36 

# Number of leading concepts listed per community row; anything beyond this is
# summarised as a ``+N more`` suffix.
_TOPIC_PREVIEW_LIMIT = 5

# Hint used when the index's embedder has matching dimensions and can simply be
# adopted. ``{model}`` is filled with the persisted model name.
_EMBED_MISMATCH_ADOPT_HINT = (
    "Run `lilbee use-embedder {model}` "
    "to search this index with its embedder."
)
# Hint used when dimensions differ and the index has to be re-embedded.
# ``{dim}`` is filled with the persisted embedding dimension.
_EMBED_MISMATCH_REBUILD_HINT = (
    "This index needs a {dim}-dim embedder; "
    "run `lilbee rebuild` to re-embed it under your current model."
)

47 

48 

def _exit_embedding_mismatch(exc: EmbeddingModelMismatchError) -> NoReturn:
    """Report an embedding-model mismatch and terminate with exit status 1.

    The embedder is never switched from here. When ``exc.dims_match`` is true
    the hint names the ``use-embedder`` command for the persisted model;
    otherwise it points at ``lilbee rebuild`` for the persisted dimension.

    In JSON mode a single object with ``error``, ``hint`` and
    ``persisted_model`` is emitted; otherwise the error line and the hint are
    printed to the console.

    Raises:
        SystemExit: always, with code 1.
    """
    if exc.dims_match:
        hint = _EMBED_MISMATCH_ADOPT_HINT.format(model=exc.persisted_model)
    else:
        hint = _EMBED_MISMATCH_REBUILD_HINT.format(dim=exc.persisted_dim)

    if cfg.json_mode:
        payload = {"error": str(exc), "hint": hint, "persisted_model": exc.persisted_model}
        json_output(payload)
    else:
        console.print(f"[{theme.ERROR}]Error:[/{theme.ERROR}] {exc}")
        console.print(hint)
    raise SystemExit(1)

66 

67 

# Shared ``--scope`` / ``-s`` option reused by the commands below; defaults to
# searching both pools and matches the enum value case-insensitively.
_scope_option = typer.Option(
    SearchScope.BOTH,
    "--scope",
    "-s",
    case_sensitive=False,
    help="Restrict the pool to raw chunks, wiki pages, or both (default).",
)

75 

76 

def search(
    query: str = typer.Argument(..., help="Search query"),
    top_k: int | None = typer.Option(None, "--top-k", "-k", help="Number of results"),
    scope: SearchScope = _scope_option,
    data_dir: Path | None = data_dir_option,
    use_global: bool = global_option,
) -> None:
    """Search the knowledge base for relevant chunks."""
    # NOTE(review): the docstring is deliberately left as the original
    # one-liner — presumably this function is registered as a Typer command
    # elsewhere, in which case the docstring is the user-visible --help text.
    #
    # Flow: apply CLI overrides, reject a blank query, run the search (falling
    # back to ``cfg.top_k`` when ``top_k`` is unset or 0), then emit either a
    # JSON object or a Rich table. A blank query or any search failure ends the
    # process with ``SystemExit(1)``.

    # Function-scope import, in line with the other lazy imports in this file.
    from rich.markup import escape

    apply_overrides(data_dir=data_dir, use_global=use_global)

    if not query or not query.strip():
        if cfg.json_mode:
            json_output({"error": "query must not be empty"})
        else:
            console.print(f"[{theme.ERROR}]Error:[/{theme.ERROR}] query must not be empty")
        raise SystemExit(1)

    try:
        results = get_services().searcher.search(
            query,
            top_k=top_k or cfg.top_k,
            chunk_type=scope_to_chunk_type(scope),
        )
    except EmbeddingModelMismatchError as exc:
        _exit_embedding_mismatch(exc)
    except Exception as exc:
        if cfg.json_mode:
            json_output({"error": str(exc)})
        else:
            # Escape the message: square brackets in arbitrary exception text
            # would otherwise be parsed as Rich markup (dropped or MarkupError).
            console.print(f"[{theme.ERROR}]Error:[/{theme.ERROR}] {escape(str(exc))}")
        raise SystemExit(1) from None
    cleaned = [clean_result(r) for r in results]

    if cfg.json_mode:
        json_output({"command": "search", "query": query, "results": cleaned})
        return

    if not cleaned:
        console.print("No results found.")
        return

    has_relevance = any("relevance_score" in r for r in cleaned)
    score_label = "Score" if has_relevance else "Distance"
    table = Table(title="Search Results")
    table.add_column("Source", style=theme.ACCENT)
    table.add_column("Chunk", max_width=80)
    table.add_column(score_label, justify="right", style=theme.MUTED)

    for row in cleaned:
        chunk_text = row["chunk"]
        preview = chunk_text[:CHUNK_PREVIEW_LEN]
        if len(chunk_text) > CHUNK_PREVIEW_LEN:
            preview += "..."

        # Explicit ``None`` checks: a relevance score of exactly 0.0 is a real
        # value and must not fall through to the distance.
        score = row.get("relevance_score")
        if score is None:
            score = row.get("distance")
        if score is None:
            score = 0

        # Table cells are markup-rendered too, so document text such as
        # ``[link]`` has to be escaped to be shown literally.
        source = row["source"]
        source_cell = escape(source) if isinstance(source, str) else source
        table.add_row(source_cell, escape(preview), f"{score:.4f}")
    console.print(table)

133 

134 

def ask(
    question: str = typer.Argument(..., help="Question to ask"),
    scope: SearchScope = _scope_option,
    data_dir: Path | None = data_dir_option,
    model: str | None = model_option,
    use_global: bool = global_option,
    temperature: float | None = temperature_option,
    top_p: float | None = top_p_option,
    top_k_sampling: int | None = top_k_sampling_option,
    repeat_penalty: float | None = repeat_penalty_option,
    num_ctx: int | None = num_ctx_option,
    seed: int | None = seed_option,
) -> None:
    """Ask a one-shot question (auto-syncs first)."""
    # NOTE(review): the docstring is deliberately left as the original
    # one-liner — presumably this function is registered as a Typer command
    # elsewhere, in which case the docstring is the user-visible --help text.
    #
    # Flow: apply CLI overrides, make sure a chat model is available (persisting
    # a newly pulled one as ``chat_model``), validate the embedder, auto-sync,
    # then either emit one JSON object (``ask_raw``) or stream the answer to the
    # console (``ask_stream``). Mismatch, RuntimeError and ProviderError all end
    # the process with ``SystemExit(1)``.

    # Function-scope import, in line with the other lazy imports in this file.
    from rich.markup import escape

    apply_overrides(
        data_dir=data_dir,
        model=model,
        use_global=use_global,
        temperature=temperature,
        top_p=top_p,
        top_k_sampling=top_k_sampling,
        repeat_penalty=repeat_penalty,
        num_ctx=num_ctx,
        seed=seed,
    )

    try:
        from lilbee.app.settings import apply_settings_update
        from lilbee.modelhub.models import ensure_chat_model

        pulled = ensure_chat_model()
        if pulled is not None:
            apply_settings_update({"chat_model": pulled})
        get_services().embedder.validate_model()

        if cfg.json_mode:
            from rich.console import Console as _QuietConsole

            # A quiet console keeps sync progress out of the JSON stream.
            auto_sync(_QuietConsole(quiet=True))
        else:
            auto_sync(console)

        chunk_type = scope_to_chunk_type(scope)

        if cfg.json_mode:
            result = get_services().searcher.ask_raw(question, chunk_type=chunk_type)
            json_output(
                {
                    "command": "ask",
                    "question": question,
                    "answer": result.answer,
                    "sources": [clean_result(s) for s in result.sources],
                }
            )
            return

        for token in get_services().searcher.ask_stream(question, chunk_type=chunk_type):
            # Model output is data, not markup: with markup enabled, bracketed
            # text like ``[note]`` is swallowed and a stray ``[/x]`` raises
            # MarkupError, which none of the handlers below catch.
            console.print(token.content, end="", markup=False)
        console.print()
    except EmbeddingModelMismatchError as exc:
        _exit_embedding_mismatch(exc)
    except (RuntimeError, ProviderError) as exc:
        if cfg.json_mode:
            json_output({"error": str(exc)})
        else:
            # Escape so brackets in the provider's message are shown literally.
            console.print(f"[{theme.ERROR}]Error:[/{theme.ERROR}] {escape(str(exc))}")
        raise SystemExit(1) from None

201 

202 

def use_embedder(
    ref: str = typer.Argument(
        ..., help="Embedding model ref to adopt (copy it from a downloaded index's error)."
    ),
    data_dir: Path | None = data_dir_option,
    use_global: bool = global_option,
) -> None:
    """Switch to embedder REF, downloading it if needed, without rebuilding the index."""
    # NOTE(review): docstring kept verbatim — presumably it is the Typer --help
    # text for this command.
    apply_overrides(data_dir=data_dir, use_global=use_global)

    # Imported lazily, after overrides have been applied.
    from lilbee.app.models import adopt_embedder
    from lilbee.catalog.compat import UnsupportedArchError

    adoption_errors = (RuntimeError, ValueError, OSError, UnsupportedArchError)
    try:
        result = adopt_embedder(ref)
    except adoption_errors as exc:
        # Report on the active surface, then exit non-zero without a traceback.
        if cfg.json_mode:
            json_output({"error": str(exc)})
        else:
            console.print(f"[{theme.ERROR}]Error:[/{theme.ERROR}] {exc}")
        raise SystemExit(1) from None

    if not cfg.json_mode:
        console.print(f"Now embedding with [{theme.ACCENT}]{result.model}[/{theme.ACCENT}].")
        return
    json_output(
        {"command": "use-embedder", "model": result.model, "status": result.status.value}
    )

231 

232 

def chat(
    data_dir: Path | None = data_dir_option,
    model: str | None = model_option,
    use_global: bool = global_option,
    temperature: float | None = temperature_option,
    top_p: float | None = top_p_option,
    top_k_sampling: int | None = top_k_sampling_option,
    repeat_penalty: float | None = repeat_penalty_option,
    num_ctx: int | None = num_ctx_option,
    seed: int | None = seed_option,
) -> None:
    """Interactive chat loop. Press S in the TUI to sync pending documents."""
    # NOTE(review): docstring kept verbatim — presumably it is the Typer --help
    # text for this command.
    overrides = {
        "data_dir": data_dir,
        "model": model,
        "use_global": use_global,
        "temperature": temperature,
        "top_p": top_p,
        "top_k_sampling": top_k_sampling,
        "repeat_penalty": repeat_penalty,
        "num_ctx": num_ctx,
        "seed": seed,
    }
    apply_overrides(**overrides)

    # The TUI cannot run in JSON mode; report that as a JSON error object.
    if cfg.json_mode:
        json_output({"error": "Chat requires a terminal, not --json"})
        raise SystemExit(1)

    # Both ends must be attached to a terminal (stdin is checked first).
    attached_to_terminal = sys.stdin.isatty() and sys.stdout.isatty()
    if not attached_to_terminal:
        console.print(f"[{theme.ERROR}]Error:[/{theme.ERROR}] Chat requires a terminal.")
        raise SystemExit(1)

    # Imported only once we know the TUI will actually be started.
    from lilbee.cli.tui import run_tui

    run_tui()

266 

267 

def topics(
    query: str | None = typer.Argument(None, help="Optional query to find related concepts."),
    top_k: int = typer.Option(10, "--top-k", "-k", help="Number of results."),
    data_dir: Path | None = data_dir_option,
    use_global: bool = global_option,
) -> None:
    """Show top concept communities or concepts related to a query."""
    # NOTE(review): docstring kept verbatim — presumably it is the Typer --help
    # text for this command.
    #
    # Three preconditions are checked in order (graph extra installed, feature
    # enabled in config, graph actually present); each failure is reported on
    # the active surface and exits with status 1. With a query the related
    # concepts are shown, otherwise the top ``top_k`` communities.

    # Function-scope import, in line with the other lazy imports in this file.
    from rich.markup import escape

    apply_overrides(data_dir=data_dir, use_global=use_global)

    from lilbee.retrieval.concepts import concepts_available

    def _fail(json_message: str, console_markup: str) -> NoReturn:
        """Emit the failure as JSON or console markup, then exit with status 1."""
        if cfg.json_mode:
            json_output({"error": json_message})
        else:
            console.print(console_markup)
        raise SystemExit(1)

    if not concepts_available():
        msg = "Concept graph requires: pip install 'lilbee[graph]'"
        # ``[graph]`` must be escaped: Rich would otherwise parse it as a style
        # tag and silently drop it, printing ``pip install 'lilbee'``.
        _fail(msg, f"[{theme.ERROR}]{escape(msg)}[/{theme.ERROR}]")

    if not cfg.concept_graph:
        _fail(
            "Concept graph is disabled (LILBEE_CONCEPT_GRAPH=false)",
            f"[{theme.ERROR}]Concept graph is disabled.[/{theme.ERROR}] "
            "Enable with LILBEE_CONCEPT_GRAPH=true",
        )

    if not get_services().concepts.get_graph():
        _fail(
            "Concept graph not available",
            f"[{theme.ERROR}]Concept graph not available.[/{theme.ERROR}]",
        )

    if query:
        _topics_for_query(query)
    else:
        _topics_overview(top_k)

308 

309 

def _topics_for_query(query: str) -> None:
    """Show the concepts extracted from *query* plus its expansion terms.

    Concepts returned by ``extract_concepts`` come first, followed by the
    entries from ``expand_query`` that are not already among them. JSON mode
    emits the combined list; otherwise the list is printed one concept per
    line, or a "no concepts" notice when it is empty.

    Args:
        query: The user's query text.
    """
    # Function-scope import, in line with the other lazy imports in this file.
    from rich.markup import escape

    cg = get_services().concepts
    concepts = cg.extract_concepts(query)
    related = cg.expand_query(query)
    all_concepts = concepts + [r for r in related if r not in concepts]

    if cfg.json_mode:
        json_output({"command": "topics", "query": query, "concepts": all_concepts})
        return
    if not all_concepts:
        console.print("No concepts found for this query.")
        return

    # The query is user input and the concepts derive from it; escape both so
    # square brackets are shown literally instead of being parsed as Rich
    # markup (which would drop them or raise MarkupError on a stray ``[/x]``).
    console.print(f"Concepts related to [{theme.ACCENT}]{escape(query)}[/{theme.ACCENT}]:")
    for c in all_concepts:
        console.print(f" {escape(str(c))}")

326 

327 

def _topics_overview(top_k: int) -> None:
    """Render the ``top_k`` largest-ranked concept communities.

    JSON mode emits every community converted with ``dataclasses.asdict``.
    Otherwise a table with cluster id, size and a truncated concept preview is
    printed, or a notice when there are no communities.

    Args:
        top_k: Passed through as ``k`` to ``top_communities``.
    """
    from dataclasses import asdict

    clusters = get_services().concepts.top_communities(k=top_k)

    if cfg.json_mode:
        serialised = [asdict(cluster) for cluster in clusters]
        json_output({"command": "topics", "communities": serialised})
        return

    if not clusters:
        console.print("No concept communities found. Try syncing some documents first.")
        return

    table = Table(title="Concept Communities")
    table.add_column("Cluster", justify="right", style=theme.MUTED)
    table.add_column("Size", justify="right")
    table.add_column("Top Concepts", style=theme.ACCENT)

    for cluster in clusters:
        names = cluster.concepts
        summary = ", ".join(names[:_TOPIC_PREVIEW_LIMIT])
        # Concepts beyond the preview limit are only counted, not listed.
        hidden = len(names) - _TOPIC_PREVIEW_LIMIT
        if hidden > 0:
            summary += f" (+{hidden} more)"
        table.add_row(str(cluster.cluster_id), str(cluster.size), summary)

    console.print(table)