Coverage for src / lilbee / retrieval / reranker.py: 100%
68 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-06-28 01:01 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-06-28 01:01 +0000
"""Cross-encoder reranking for search results.

Optional precision pass that scores each (query, chunk) pair through the
active provider's ``rerank`` method. Only active when
``cfg.reranker_model`` is set.

Core technique: Nogueira & Cho 2019, "Passage Re-ranking with BERT"
(https://arxiv.org/abs/1901.04085).

Position-aware blending: derived from learning-to-rank literature
(Burges et al. 2005). Top positions trust hybrid fusion more, lower
positions trust the reranker more.
"""
15from __future__ import annotations
17import logging
18from typing import NamedTuple
20from lilbee.core.config import Config
21from lilbee.data.store import SearchChunk
23log = logging.getLogger(__name__)
class ScoredChunk(NamedTuple):
    """Pair a search chunk with the score it is ranked by.

    Being a ``NamedTuple``, instances unpack positionally as
    ``(score, chunk)``.
    """

    score: float  # blended ranking score for ``chunk``
    chunk: SearchChunk  # the search hit the score belongs to
33_TOP_POSITION_CUTOFF = 3
34_MID_POSITION_CUTOFF = 10
36_BLEND_SCHEDULE = {
37 "top": (0.70, 0.30),
38 "mid": (0.50, 0.50),
39 "bottom": (0.30, 0.70),
40}
43def _normalize_scores(scores: list[float]) -> list[float]:
44 """Min-max normalize raw cross-encoder scores to [0, 1]."""
45 min_score = min(scores)
46 max_score = max(scores)
47 score_range = max_score - min_score
48 if score_range > 0:
49 return [(s - min_score) / score_range for s in scores]
50 return [0.5] * len(scores)
def _blend_scores(to_rerank: list[SearchChunk], norm_scores: list[float]) -> list[ScoredChunk]:
    """Blend fusion scores with reranker scores using position-aware weights.

    Each chunk is copied with ``rerank_score`` set to its blended score;
    the input chunks are left untouched.

    Args:
        to_rerank: Candidate chunks in their original (pre-rerank) order;
            the index in this list selects the blend weights.
        norm_scores: Normalized reranker scores, parallel to ``to_rerank``.

    Returns:
        One ``ScoredChunk`` per candidate, in input order (not yet sorted).

    Raises:
        ValueError: If the two inputs differ in length (``zip(strict=True)``).
    """
    blended: list[ScoredChunk] = []
    for position, (chunk, rerank_score) in enumerate(zip(to_rerank, norm_scores, strict=True)):
        # A falsy relevance_score (None or 0.0) falls back to a similarity
        # derived from the distance.
        fusion_score = chunk.relevance_score
        if not fusion_score:
            # Only a *missing* distance gets the neutral 0.5 default. A
            # distance of 0.0 is a real value and must map to 1.0, not be
            # replaced by the default.
            distance = 0.5 if chunk.distance is None else chunk.distance
            fusion_score = 1.0 - distance
        # Clamp so both blend inputs live in [0, 1].
        fusion_norm = max(0.0, min(1.0, fusion_score))

        if position < _TOP_POSITION_CUTOFF:
            fusion_weight, rerank_weight = _BLEND_SCHEDULE["top"]
        elif position < _MID_POSITION_CUTOFF:
            fusion_weight, rerank_weight = _BLEND_SCHEDULE["mid"]
        else:
            fusion_weight, rerank_weight = _BLEND_SCHEDULE["bottom"]

        final_score = fusion_weight * fusion_norm + rerank_weight * rerank_score
        scored = chunk.model_copy(update={"rerank_score": final_score})
        blended.append(ScoredChunk(final_score, scored))
    return blended
77def _pin_original_top(
78 blended: list[ScoredChunk],
79 skip_threshold: float,
80) -> list[ScoredChunk]:
81 """Pin the original top result if its relevance exceeds the skip threshold."""
82 original_top = blended[0].chunk
83 top_score = original_top.relevance_score or 0
84 blended_sorted = sorted(blended, key=lambda x: x.score, reverse=True)
85 if top_score >= skip_threshold and blended_sorted[0].chunk is not original_top:
86 blended_sorted = [ScoredChunk(999.0, original_top)] + [
87 ScoredChunk(s, c) for s, c in blended_sorted if c is not original_top
88 ]
89 return blended_sorted
class Reranker:
    """Cross-encoder reranker with position-aware blending.

    Delegates scoring to the active provider's ``rerank``; handles result
    blending and the BM25-protection pin (Nogueira & Cho 2019,
    https://arxiv.org/abs/1901.04085).
    """

    def __init__(self, config: Config) -> None:
        """Store the configuration consulted on every ``rerank`` call."""
        self._config = config

    def rerank(
        self,
        query: str,
        results: list[SearchChunk],
        candidates: int | None = None,
    ) -> list[SearchChunk]:
        """Rerank search results through the provider's ``rerank`` method.

        Args:
            query: The search query the results were retrieved for.
            results: Search hits in their current order.
            candidates: How many leading results to rerank; ``None`` uses
                ``rerank_candidates`` from the config.

        Returns:
            The reranked head followed by the untouched remainder. The
            input list is returned unchanged when no reranker model is
            configured, there is nothing to rerank, scoring fails, or the
            provider returns a score count that does not match the
            candidates.
        """
        if not self._config.reranker_model:
            return results
        if candidates is None:
            candidates = self._config.rerank_candidates
        to_rerank = results[:candidates]
        remainder = results[candidates:]

        if not to_rerank:
            return results

        scores = _score_candidates(query, to_rerank)
        if scores is None:
            return results
        if len(scores) != len(to_rerank):
            # A mismatched (or empty) score list would raise ValueError in
            # normalization/blending; degrade to the original order instead,
            # matching how scoring failures are handled.
            log.warning(
                "Reranker returned %d scores for %d candidates; skipping rerank pass",
                len(scores),
                len(to_rerank),
            )
            return results

        norm_scores = _normalize_scores(scores)
        blended = _blend_scores(to_rerank, norm_scores)
        blended_sorted = _pin_original_top(blended, self._config.expansion_skip_threshold)

        reranked = [chunk for _, chunk in blended_sorted]
        return reranked + remainder
def _score_candidates(query: str, to_rerank: list[SearchChunk]) -> list[float] | None:
    """Score candidates via the active provider's ``rerank``.

    Returns the provider's scores, or ``None`` if anything raised while
    resolving the provider or scoring; the failure is logged as a warning
    with its traceback.
    """
    # circular: services -> reranker via Searcher; deferred so test-time
    # monkeypatching of ``lilbee.services.get_services`` stays effective.
    from lilbee.app.services import get_services

    try:
        scorer = get_services().provider.rerank
        passages = [candidate.chunk for candidate in to_rerank]
        scores = scorer(query, passages)
    except Exception as exc:
        log.warning("Reranker failed; skipping rerank pass: %s", exc, exc_info=True)
        return None
    else:
        return scores