Coverage for src / lilbee / retrieval / query / dedup.py: 100%

81 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-05-15 20:55 +0000

1"""Result filtering, sorting, deduplication, and source-diversity helpers.""" 

2 

3from __future__ import annotations 

4 

5from lilbee.core.config import cfg 

6from lilbee.data.store import CitationRecord, SearchChunk 

7from lilbee.retrieval.query.formatting import format_source 

8 

9_DEFAULT_RELEVANCE_WEIGHT = 0.5 

10 

11 

12def _relevance_weight(result: SearchChunk) -> float: 

13 """Return a [0, 1] relevance weight for distance-aware selection. 

14 

15 Hybrid results (relevance_score set): use directly. 

16 Vector results (distance set): invert cosine distance. 

17 Neither: neutral default. 

18 """ 

19 if result.relevance_score is not None: 

20 return min(1.0, max(0.0, result.relevance_score)) 

21 if result.distance is not None: 

22 return max(0.0, 1.0 - result.distance) 

23 return _DEFAULT_RELEVANCE_WEIGHT 

24 

25 

26def _greedy_cover( 

27 chunk_tokens: list[set[str]], 

28 question_terms: set[str], 

29 term_weights: dict[str, float], 

30 budget: int, 

31 relevance_weights: list[float] | None = None, 

32) -> list[int]: 

33 """Greedy weighted set cover: pick chunks that add the most uncovered weight. 

34 

35 Standard (1 - 1/e) approximation for weighted set cover. Budget is 

36 always filled, falling back to retrieval order once no chunk can 

37 contribute any new weight. When *relevance_weights* is provided, 

38 each chunk's IDF gain is scaled by its relevance so that far-away 

39 chunks are penalised even when they share query terms. 

40 """ 

41 selected: list[int] = [] 

42 covered: set[str] = set() 

43 remaining = list(range(len(chunk_tokens))) 

44 while remaining and len(selected) < budget: 

45 best_pos = -1 

46 best_gain = 0.0 

47 for pos, idx in enumerate(remaining): 

48 new_terms = (chunk_tokens[idx] & question_terms) - covered 

49 gain = sum(term_weights[t] for t in new_terms) 

50 if relevance_weights is not None: 

51 gain *= relevance_weights[idx] 

52 if gain > best_gain: 

53 best_gain = gain 

54 best_pos = pos 

55 if best_pos < 0: 

56 break 

57 chosen = remaining.pop(best_pos) 

58 selected.append(chosen) 

59 covered |= chunk_tokens[chosen] & question_terms 

60 

61 for idx in remaining: 

62 if len(selected) >= budget: 

63 break 

64 selected.append(idx) 

65 return selected 

66 

67 

def filter_results(
    results: list[SearchChunk],
    max_distance: float,
    min_relevance_score: float = 0.0,
) -> list[SearchChunk]:
    """Drop results above max_distance or below min_relevance_score.

    Hybrid results (relevance_score set) are checked against
    min_relevance_score; vector results (distance set) against
    max_distance; results carrying neither score always pass. When both
    scores are present, relevance_score wins because hybrid results are
    ranked by RRF, not cosine distance. Pass max_distance=0 to disable
    distance filtering.
    """
    check_relevance = min_relevance_score > 0
    check_distance = max_distance > 0
    if not check_relevance and not check_distance:
        # Both filters disabled: hand the list back untouched.
        return results

    def _keep(r: SearchChunk) -> bool:
        # Hybrid path takes priority over the vector-distance path.
        if r.relevance_score is not None:
            return not (check_relevance and r.relevance_score < min_relevance_score)
        if r.distance is not None:
            return not (check_distance and r.distance > max_distance)
        return True

    return [r for r in results if _keep(r)]

93 

94 

def deduplicate_sources(
    results: list[SearchChunk],
    max_citations: int = 5,
    citations_map: dict[str, list[CitationRecord]] | None = None,
) -> list[str]:
    """Collapse results that format to the same citation line.

    First occurrence wins; collection stops once *max_citations*
    distinct lines have been gathered.
    """
    lookup = citations_map or {}
    emitted: set[str] = set()
    lines: list[str] = []
    for result in results:
        line = format_source(result, citations=lookup.get(result.source))
        if line in emitted:
            continue
        emitted.add(line)
        lines.append(line)
        if len(lines) >= max_citations:
            break
    return lines

112 

113 

114def _sort_key(r: SearchChunk) -> float: 

115 """Sort key: lower = more relevant.""" 

116 if r.relevance_score is not None: 

117 return -r.relevance_score 

118 if r.distance is not None: 

119 return r.distance 

120 return float("inf") 

121 

122 

def sort_by_relevance(results: list[SearchChunk]) -> list[SearchChunk]:
    """Return a new list ordered most-relevant-first.

    Handles hybrid (relevance_score) and vector (distance) results
    alike; the sort is stable, so ties keep retrieval order.
    """
    ordered = list(results)
    ordered.sort(key=_sort_key)
    return ordered

126 

127 

def diversify_sources(
    results: list[SearchChunk], max_per_source: int | None = None
) -> list[SearchChunk]:
    """Cap results per source document to ensure diversity.

    Source diversity filtering: Zhai 2008, "Statistical Language Models
    for Information Retrieval" -- caps per-source representation to
    prevent any single document from dominating results.
    """
    # Fall back to the configured cap when the caller does not override it.
    cap = cfg.diversity_max_per_source if max_per_source is None else max_per_source
    per_source: dict[str, int] = {}
    kept: list[SearchChunk] = []
    for result in results:
        seen = per_source.get(result.source, 0)
        if seen >= cap:
            continue
        per_source[result.source] = seen + 1
        kept.append(result)
    return kept

146 

147 

def prepare_results(results: list[SearchChunk]) -> list[SearchChunk]:
    """Relevance-sort the results, then enforce the per-source diversity cap."""
    ordered = sort_by_relevance(results)
    return diversify_sources(ordered)