Coverage for src / lilbee / retrieval / query / dedup.py: 100%
81 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-05-15 20:55 +0000
1"""Result filtering, sorting, deduplication, and source-diversity helpers."""
3from __future__ import annotations
5from lilbee.core.config import cfg
6from lilbee.data.store import CitationRecord, SearchChunk
7from lilbee.retrieval.query.formatting import format_source
9_DEFAULT_RELEVANCE_WEIGHT = 0.5
12def _relevance_weight(result: SearchChunk) -> float:
13 """Return a [0, 1] relevance weight for distance-aware selection.
15 Hybrid results (relevance_score set): use directly.
16 Vector results (distance set): invert cosine distance.
17 Neither: neutral default.
18 """
19 if result.relevance_score is not None:
20 return min(1.0, max(0.0, result.relevance_score))
21 if result.distance is not None:
22 return max(0.0, 1.0 - result.distance)
23 return _DEFAULT_RELEVANCE_WEIGHT
26def _greedy_cover(
27 chunk_tokens: list[set[str]],
28 question_terms: set[str],
29 term_weights: dict[str, float],
30 budget: int,
31 relevance_weights: list[float] | None = None,
32) -> list[int]:
33 """Greedy weighted set cover: pick chunks that add the most uncovered weight.
35 Standard (1 - 1/e) approximation for weighted set cover. Budget is
36 always filled, falling back to retrieval order once no chunk can
37 contribute any new weight. When *relevance_weights* is provided,
38 each chunk's IDF gain is scaled by its relevance so that far-away
39 chunks are penalised even when they share query terms.
40 """
41 selected: list[int] = []
42 covered: set[str] = set()
43 remaining = list(range(len(chunk_tokens)))
44 while remaining and len(selected) < budget:
45 best_pos = -1
46 best_gain = 0.0
47 for pos, idx in enumerate(remaining):
48 new_terms = (chunk_tokens[idx] & question_terms) - covered
49 gain = sum(term_weights[t] for t in new_terms)
50 if relevance_weights is not None:
51 gain *= relevance_weights[idx]
52 if gain > best_gain:
53 best_gain = gain
54 best_pos = pos
55 if best_pos < 0:
56 break
57 chosen = remaining.pop(best_pos)
58 selected.append(chosen)
59 covered |= chunk_tokens[chosen] & question_terms
61 for idx in remaining:
62 if len(selected) >= budget:
63 break
64 selected.append(idx)
65 return selected
def filter_results(
    results: list[SearchChunk],
    max_distance: float,
    min_relevance_score: float = 0.0,
) -> list[SearchChunk]:
    """Filter out results that fail their applicable score threshold.

    Hybrid results (``relevance_score`` set) must reach
    *min_relevance_score*; vector results (``distance`` set) must stay
    within *max_distance*. Unscored results always pass. A result with
    both scores is judged only by relevance_score, since hybrid RRF
    scores supersede raw cosine distance. Thresholds of 0 (or less)
    disable the corresponding filter.
    """
    check_distance = max_distance > 0
    check_relevance = min_relevance_score > 0
    if not check_distance and not check_relevance:
        return results

    kept: list[SearchChunk] = []
    for result in results:
        if result.relevance_score is not None:
            # Hybrid result: relevance wins even when distance is set.
            if check_relevance and result.relevance_score < min_relevance_score:
                continue
        elif (
            check_distance
            and result.distance is not None
            and result.distance > max_distance
        ):
            continue
        kept.append(result)
    return kept
def deduplicate_sources(
    results: list[SearchChunk],
    max_citations: int = 5,
    citations_map: dict[str, list[CitationRecord]] | None = None,
) -> list[str]:
    """Collapse results that render to the same citation line.

    Formats each result (with any citation records for its source) and
    keeps the first occurrence of each distinct line, stopping once
    *max_citations* lines have been collected.
    """
    lookup = citations_map or {}
    emitted: set[str] = set()
    lines: list[str] = []
    for result in results:
        rendered = format_source(result, citations=lookup.get(result.source))
        if rendered in emitted:
            continue
        emitted.add(rendered)
        lines.append(rendered)
        # Check after appending, so the cap is only enforced once reached.
        if len(lines) >= max_citations:
            break
    return lines
114def _sort_key(r: SearchChunk) -> float:
115 """Sort key: lower = more relevant."""
116 if r.relevance_score is not None:
117 return -r.relevance_score
118 if r.distance is not None:
119 return r.distance
120 return float("inf")
def sort_by_relevance(results: list[SearchChunk]) -> list[SearchChunk]:
    """Return a new list ordered most-relevant first (hybrid or vector)."""
    ordered = list(results)
    # Stable sort: ties keep their original retrieval order.
    ordered.sort(key=_sort_key)
    return ordered
def diversify_sources(
    results: list[SearchChunk], max_per_source: int | None = None
) -> list[SearchChunk]:
    """Limit how many results any single source document may contribute.

    Per-source capping (cf. Zhai 2008, "Statistical Language Models for
    Information Retrieval") keeps one document from crowding out the
    rest of the result list. Defaults to the configured cap when
    *max_per_source* is None.
    """
    if max_per_source is None:
        max_per_source = cfg.diversity_max_per_source
    per_source: dict[str, int] = {}
    kept: list[SearchChunk] = []
    for result in results:
        taken = per_source.get(result.source, 0)
        if taken >= max_per_source:
            continue
        per_source[result.source] = taken + 1
        kept.append(result)
    return kept
def prepare_results(results: list[SearchChunk]) -> list[SearchChunk]:
    """Order results by relevance, then enforce the per-source cap."""
    ordered = sort_by_relevance(results)
    return diversify_sources(ordered)