Coverage for src/lilbee/providers/llama_cpp/batching.py: 100%
65 statements
coverage.py v7.13.4, created at 2026-05-15 20:55 +0000
1"""Batched llama-cpp embed and rerank helpers used inside worker subprocesses."""
3from __future__ import annotations
5import logging
6from collections.abc import Iterator
7from typing import Any
9from lilbee.providers.base import ProviderError
11log = logging.getLogger(__name__)

_RERANK_PAIR_SEPARATOR = "</s></s>"

EMBED_N_SEQ_MAX = 64
"""Max parallel sequences per ``create_embedding`` call.

llama-cpp-python's inner ``Llama.embed`` flushes its batch on token
budget but not on sequence count, so caller-side batches above the
context's ``n_seq_max`` trip a C-level assertion. Workaround for
upstream issue #2051 / PR #2058 (still open as of May 2026).
"""


def _truncate_to_budget(llm: Any, text: str, token_cap: int) -> str:
    """Tokenize *text*, keep the first ``token_cap`` tokens, detokenize back.

    Token-aware truncation is needed because the chunker's 4-chars-per-token
    heuristic underestimates dense input (medical codes, JSON, source code).
    Byte-level slicing would split mid-token and confuse the embedder.
    """
    tokens = llm.tokenize(text.encode("utf-8"))
    if len(tokens) <= token_cap:
        return text
    truncated: bytes = llm.detokenize(tokens[:token_cap])
    return truncated.decode("utf-8", errors="replace")
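
# A minimal sketch of the truncation behavior, using a stand-in object whose
# tokenizer maps one byte to one token (purely illustrative; real Llama
# tokenizers and this module's callers do not work byte-per-token):
#
#     class _ByteLlm:
#         def tokenize(self, data: bytes) -> list[int]:
#             return list(data)
#
#         def detokenize(self, tokens: list[int]) -> bytes:
#             return bytes(tokens)
#
#     _truncate_to_budget(_ByteLlm(), "hello world", token_cap=5)  # -> "hello"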


def _split_into_sub_batches(llm: Any, items: list[str]) -> Iterator[list[str]]:
    """Yield sub-batches respecting both the token budget and ``EMBED_N_SEQ_MAX``.

    A single item longer than ``llm.n_batch`` tokens is truncated to that
    budget with a warning. Without the truncation, llama-cpp's ``llama_decode``
    returns -1 and the whole call fails, which used to surface as
    "Embedding worker reported an error: RuntimeError: llama_decode
    returned -1" on every file in a token-dense corpus.
    """
    token_cap = max(1, int(llm.n_batch))
    sub_batch: list[str] = []
    sub_tokens = 0
    for raw_item in items:
        raw_tokens = max(1, len(llm.tokenize(raw_item.encode("utf-8"))))
        if raw_tokens > token_cap:
            log.warning(
                "Truncating oversize input: %d tokens > cap %d (chars/token heuristic too loose)",
                raw_tokens,
                token_cap,
            )
            item = _truncate_to_budget(llm, raw_item, token_cap)
            token_count = token_cap
        else:
            item = raw_item
            token_count = raw_tokens
        if sub_batch and (
            sub_tokens + token_count > token_cap or len(sub_batch) >= EMBED_N_SEQ_MAX
        ):
            yield sub_batch
            sub_batch = []
            sub_tokens = 0
        sub_batch.append(item)
        sub_tokens += token_count
    if sub_batch:
        yield sub_batch
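
# Illustration of the greedy split, reusing the hypothetical byte-per-token
# stand-in above with n_batch=8 (real contexts expose n_batch themselves):
#
#     class _ByteLlm8(_ByteLlm):
#         n_batch = 8
#
#     list(_split_into_sub_batches(_ByteLlm8(), ["abc", "defgh", "ij"]))
#     # -> [["abc", "defgh"], ["ij"]]  (3 + 5 tokens fill the cap; "ij" spills over)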


def embed_batch(llm: Any, texts: list[str]) -> list[list[float]]:
    """Embed *texts* in as few llama-cpp calls as the model's batch budget allows.

    One ``llm.create_embedding(input=sub_batch)`` per sub-batch instead of
    one per text. Vectors come back in input order. Caller must run inside
    a worker subprocess where ``redirect_stdio_to_devnull()`` ran at
    startup so fd 2 is already redirected.
    """
    if not texts:
        return []
    vectors: list[list[float]] = []
    for sub_batch in _split_into_sub_batches(llm, texts):
        vectors.extend(_embed_one_call(llm, sub_batch))
    return vectors
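
# Usage sketch, assuming llama-cpp-python's ``Llama`` loaded in embedding mode
# (the model path is a placeholder, not a file this project ships):
#
#     from llama_cpp import Llama
#
#     llm = Llama(model_path="models/embedder.gguf", embedding=True, verbose=False)
#     vectors = embed_batch(llm, ["first chunk", "second chunk"])
#     assert len(vectors) == 2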


def compute_rerank_scores(llm: Any, query: str, candidates: list[str]) -> list[float]:
    """Score *candidates* against *query* via llama.cpp reranker embeddings.

    ``pooling_type=LLAMA_POOLING_TYPE_RANK`` requires the pair pre-joined
    as ``query</s></s>candidate``; passing them as two inputs makes
    ``llama_decode`` fail with ``-1``. Pairs share ``embed_batch``'s
    sub-batching so one rerank call decodes many candidates per graph.
    """
    if not candidates:
        return []
    pairs = [f"{query}{_RERANK_PAIR_SEPARATOR}{candidate}" for candidate in candidates]
    scores: list[float] = []
    for sub_batch in _split_into_sub_batches(llm, pairs):
        scores.extend(_rerank_one_call(llm, sub_batch))
    return scores
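
# Sketch of the reranker setup this function expects, assuming llama-cpp-python
# exports ``LLAMA_POOLING_TYPE_RANK`` and accepts ``pooling_type`` at load time
# (the model path is a placeholder):
#
#     import llama_cpp
#
#     reranker = llama_cpp.Llama(
#         model_path="models/reranker.gguf",
#         embedding=True,
#         pooling_type=llama_cpp.LLAMA_POOLING_TYPE_RANK,
#         verbose=False,
#     )
#     compute_rerank_scores(reranker, "what is rust?", ["Rust is a language.", "Paris."])
#     # -> two floats; higher means more relevant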


def _embed_one_call(llm: Any, sub_batch: list[str]) -> list[list[float]]:
    # create_embedding returns an OpenAI-style dict:
    # {"data": [{"embedding": [...]}, ...]} in input order.
    response = llm.create_embedding(input=sub_batch)
    data = response.get("data") or []
    return [item["embedding"] for item in data]


def _rerank_one_call(llm: Any, sub_batch: list[str]) -> list[float]:
    response = llm.create_embedding(input=sub_batch)
    data = response.get("data") or []
    if len(data) != len(sub_batch):
        # A count mismatch would silently misalign scores with candidates,
        # so fail loudly instead of zipping short.
        raise ProviderError(
            f"Reranker returned {len(data)} entries for {len(sub_batch)} pairs; "
            "llama-cpp-python may have changed its response format",
            provider="llama-cpp",
        )
    return [_extract_rerank_score(item) for item in data]


def _extract_rerank_score(item: dict[str, Any]) -> float:
    """Pull a single relevance score from one ``pooling_type=RANK`` response item."""
    embedding = item.get("embedding")
    if isinstance(embedding, list) and embedding and isinstance(embedding[0], (int, float)):
        return float(embedding[0])
    raise ProviderError(
        "Reranker returned unexpected score shape "
        f"(got {type(embedding).__name__}: {embedding!r}); "
        "llama-cpp-python may have changed its response format",
        provider="llama-cpp",
    )
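
# Shape sketch with illustrative values: under RANK pooling the "embedding"
# slot carries a single score, so one well-formed and one malformed item
# behave like this:
#
#     _extract_rerank_score({"embedding": [0.87]})  # -> 0.87
#     _extract_rerank_score({"embedding": None})    # raises ProviderError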