Coverage for src / lilbee / data / export.py: 100%
110 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-06-28 01:01 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-06-28 01:01 +0000
1"""Per-page text dataset: build/write from a store, and import one back."""
3from __future__ import annotations
5import json
6from dataclasses import dataclass
7from enum import StrEnum
8from pathlib import Path
9from typing import TYPE_CHECKING, cast
11from lilbee.data.ingest.extract import chunk_and_embed_pages
12from lilbee.data.store import PageTextRecord, SourceType
13from lilbee.runtime.progress import DetailedProgressCallback, noop_callback
15if TYPE_CHECKING:
16 from lilbee.data.store import Store
19class DatasetFormat(StrEnum):
20 """On-disk format for the per-page text dataset."""
22 PARQUET = "parquet"
23 JSONL = "jsonl"
@dataclass
class ImportResult:
    """Totals reported back by an `import_dataset` run."""

    # Names of the sources that were imported.
    sources: list[str]
    # Number of page rows imported, summed over all sources.
    pages: int
    # Number of chunks stored, summed over all sources.
    chunks: int
def decode_format(value: str) -> DatasetFormat:
    """Turn an explicit format string into a `DatasetFormat`.

    Raises ``ValueError`` with a user-facing message when *value* is not one
    of the supported format names.
    """
    try:
        fmt = DatasetFormat(value)
    except ValueError:
        # Replace the enum's own error with a message meant for end users.
        message = f"Unsupported format: {value!r} (expected parquet or jsonl)"
        raise ValueError(message) from None
    else:
        return fmt
def resolve_format(value: str, path: Path) -> DatasetFormat:
    """Decide which dataset format to use.

    A non-empty explicit *value* always wins; otherwise the format is taken
    from the (case-insensitive) file extension of *path*.

    Raises ``ValueError`` with a user-facing message when neither yields a
    known format.
    """
    if value:
        return decode_format(value)

    # No explicit format: fall back to the extension, without its leading dot.
    extension = path.suffix.lower().lstrip(".")
    try:
        inferred = DatasetFormat(extension)
    except ValueError:
        hint = f"Could not infer format from {path.name!r}; use a .parquet or .jsonl path"
        raise ValueError(hint) from None
    return inferred
def build_page_dataset(store: Store, source: str | None = None) -> list[PageTextRecord]:
    """Gather per-page text rows for one source, or for every source in the store.

    When *source* is ``None`` all sources known to the store are exported.
    Sources whose page text was captured are read straight from the page-text
    table. Any other source is rebuilt from its chunks, which joins chunk text
    per page, so overlapping chunks may repeat some text.

    The returned rows are ordered by source name, then page number.
    """
    if source is None:
        wanted = sorted(entry["filename"] for entry in store.get_sources())
    else:
        wanted = [source]

    with_page_text = store.page_text_sources()
    collected: list[PageTextRecord] = []
    for filename in wanted:
        page_rows = (
            store.get_page_texts(filename)
            if filename in with_page_text
            else _reconstruct_from_chunks(store, filename)
        )
        collected.extend(page_rows)
    return sorted(collected, key=lambda record: (record["source"], record["page"]))
def _reconstruct_from_chunks(store: Store, source: str) -> list[PageTextRecord]:
    """Approximate per-page rows for *source* from its stored chunks.

    Chunks are grouped by their starting page, ordered by chunk index within
    each page, and joined with newlines. Every row carries the last non-empty
    content type seen among the chunks (empty string if there was none).
    """
    pieces_by_page: dict[int, list[tuple[int, str]]] = {}
    last_content_type = ""
    for piece in store.get_chunks_by_source(source):
        if piece.content_type:
            last_content_type = piece.content_type
        page_no = piece.page_start
        if page_no not in pieces_by_page:
            pieces_by_page[page_no] = []
        pieces_by_page[page_no].append((piece.chunk_index, piece.chunk))

    return [
        PageTextRecord(
            source=source,
            page=page_no,
            # Sorting the (index, text) pairs restores chunk order within the page.
            text="\n".join(text for _, text in sorted(pieces_by_page[page_no])),
            content_type=last_content_type,
        )
        for page_no in sorted(pieces_by_page)
    ]
def _serialize_parquet(rows: list[PageTextRecord]) -> bytes:
    """Encode *rows* as an in-memory Parquet file and return its bytes."""
    # Imported inside the function so pyarrow is only loaded when Parquet is used.
    import io

    import pyarrow as pa
    import pyarrow.parquet as pq

    plain_rows = [dict(record) for record in rows]
    sink = io.BytesIO()
    pq.write_table(pa.Table.from_pylist(plain_rows), sink)
    return sink.getvalue()
110def _serialize_jsonl(rows: list[PageTextRecord]) -> bytes:
111 return "".join(json.dumps(dict(row)) + "\n" for row in rows).encode("utf-8")
# Dispatch table: dataset format -> function that encodes rows to bytes.
_SERIALIZERS = {
    DatasetFormat.PARQUET: _serialize_parquet,
    DatasetFormat.JSONL: _serialize_jsonl,
}
def serialize_dataset(rows: list[PageTextRecord], fmt: DatasetFormat) -> bytes:
    """Return *rows* encoded as dataset bytes in format *fmt*."""
    encoder = _SERIALIZERS[fmt]
    return encoder(rows)
def write_dataset(rows: list[PageTextRecord], path: Path, fmt: DatasetFormat) -> None:
    """Serialize *rows* in format *fmt* and write the bytes to *path*."""
    payload = serialize_dataset(rows, fmt)
    path.write_bytes(payload)
def _coerce_row(raw: dict) -> PageTextRecord:
    """Validate one raw dataset row into a `PageTextRecord`.

    ``source``, ``page`` and ``text`` are required. ``content_type`` is
    optional and defaults to an empty string.

    Null values are handled explicitly, because both JSON ``null`` and null
    Parquet cells arrive here as ``None``:

    * a null ``content_type`` becomes ``""`` rather than the literal string
      ``"None"``, which callers would mistake for a real content type;
    * a null ``source`` or ``text`` is rejected like a missing field, in line
      with how a null ``page`` is already rejected.

    Raises ``ValueError`` with a user-facing message when a required field is
    missing, null, or (for ``page``) not convertible to an integer.
    """
    try:
        source, page, text = raw["source"], raw["page"], raw["text"]
        if source is None or text is None:
            # Handled by the except clause below, so the user sees one message.
            raise ValueError("required field is null")
        content_type = raw.get("content_type")
        return PageTextRecord(
            source=str(source),
            page=int(page),
            text=str(text),
            content_type="" if content_type is None else str(content_type),
        )
    except (KeyError, TypeError, ValueError):
        raise ValueError("Dataset row is missing required source/page/text fields") from None
def _deserialize_parquet(data: bytes) -> list[PageTextRecord]:
    """Decode Parquet bytes into validated page-text rows."""
    # Imported inside the function so pyarrow is only loaded when Parquet is used.
    import io

    import pyarrow.parquet as pq

    table = pq.read_table(io.BytesIO(data))
    return [_coerce_row(record) for record in table.to_pylist()]
def _deserialize_jsonl(data: bytes) -> list[PageTextRecord]:
    """Decode UTF-8 JSON Lines bytes into validated page-text rows.

    Blank lines are skipped. Records are separated only by ``\\n``, ``\\r\\n``
    or ``\\r``. ``str.splitlines()`` is deliberately not used: it also splits
    on characters such as U+2028, U+2029 and U+0085, which may appear
    unescaped inside JSON strings written by tools that do not ASCII-escape
    their output, and would cut such a record in half.

    Raises ``ValueError`` (including ``json.JSONDecodeError`` and
    ``UnicodeDecodeError``, both subclasses) when the data cannot be decoded
    or a row fails validation.
    """
    text = data.decode("utf-8")
    # A raw CR or LF can never occur inside a valid JSON string (control
    # characters must be escaped), so normalizing them cannot damage a record.
    normalized = text.replace("\r\n", "\n").replace("\r", "\n")
    rows: list[PageTextRecord] = []
    for line in normalized.split("\n"):
        stripped = line.strip()
        if stripped:
            rows.append(_coerce_row(json.loads(stripped)))
    return rows
# Dispatch table: dataset format -> function that decodes bytes back into rows.
_DESERIALIZERS = {
    DatasetFormat.PARQUET: _deserialize_parquet,
    DatasetFormat.JSONL: _deserialize_jsonl,
}
def deserialize_dataset(data: bytes, fmt: DatasetFormat) -> list[PageTextRecord]:
    """Return the rows encoded in *data*, interpreting it as format *fmt*."""
    decoder = _DESERIALIZERS[fmt]
    return decoder(data)
def load_page_dataset(path: Path, fmt: DatasetFormat) -> list[PageTextRecord]:
    """Load a per-page text dataset from *path*, decoding it as format *fmt*.

    Raises ``ValueError`` when *path* does not exist.
    """
    if path.exists():
        return deserialize_dataset(path.read_bytes(), fmt)
    raise ValueError(f"Dataset not found: {path}")
async def import_dataset(
    store: Store,
    rows: list[PageTextRecord],
    *,
    on_progress: DetailedProgressCallback = noop_callback,
) -> ImportResult:
    """Import page-text *rows* into *store*, chunking and embedding them afresh.

    Rows are grouped by source. For each source the pages are processed in
    page order, any existing data for that source is replaced, and the source
    is registered as ``IMPORTED``. The content type of a source is taken from
    its first page, falling back to ``"text"`` when that is empty.

    Raises ``EmbeddingModelMismatchError`` (before any write) when the store
    was built by a different embedder.
    """
    store.assert_embedding_compatible()

    grouped: dict[str, list[PageTextRecord]] = {}
    for record in rows:
        grouped.setdefault(record["source"], []).append(record)

    done: list[str] = []
    page_count = 0
    chunk_count = 0
    for source_name, records in grouped.items():
        ordered = sorted(records, key=lambda record: record["page"])
        content_type = ordered[0]["content_type"] or "text"
        pages = [(record["page"], record["text"]) for record in ordered]
        # Embed before deleting: if embedding raises, the delete is never reached.
        embedded = await chunk_and_embed_pages(pages, source_name, content_type, on_progress)
        store.delete_by_source(source_name)
        store.add_chunks(cast(list[dict], embedded))
        store.add_page_texts([dict(record) for record in ordered])
        store.upsert_source(source_name, "", len(embedded), source_type=SourceType.IMPORTED)
        done.append(source_name)
        page_count += len(ordered)
        chunk_count += len(embedded)

    done.sort()
    return ImportResult(sources=done, pages=page_count, chunks=chunk_count)