Coverage for src / lilbee / data / export.py: 100%

110 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-06-28 01:01 +0000

1"""Per-page text dataset: build/write from a store, and import one back.""" 

2 

3from __future__ import annotations 

4 

5import json 

6from dataclasses import dataclass 

7from enum import StrEnum 

8from pathlib import Path 

9from typing import TYPE_CHECKING, cast 

10 

11from lilbee.data.ingest.extract import chunk_and_embed_pages 

12from lilbee.data.store import PageTextRecord, SourceType 

13from lilbee.runtime.progress import DetailedProgressCallback, noop_callback 

14 

15if TYPE_CHECKING: 

16 from lilbee.data.store import Store 

17 

18 

19class DatasetFormat(StrEnum): 

20 """On-disk format for the per-page text dataset.""" 

21 

22 PARQUET = "parquet" 

23 JSONL = "jsonl" 

24 

25 

@dataclass()
class ImportResult:
    """Outcome of one `import_dataset` call: what was imported and how much."""

    # Source names that were imported (import_dataset reports them sorted).
    sources: list[str]
    # Total number of page rows stored across all imported sources.
    pages: int
    # Total number of embedded chunks stored across all imported sources.
    chunks: int

33 

34 

35def decode_format(value: str) -> DatasetFormat: 

36 """Decode an explicit format string, raising a user-facing ``ValueError``.""" 

37 try: 

38 return DatasetFormat(value) 

39 except ValueError: 

40 raise ValueError(f"Unsupported format: {value!r} (expected parquet or jsonl)") from None 

41 

42 

def resolve_format(value: str, path: Path) -> DatasetFormat:
    """Choose the dataset format for *path*.

    A non-empty explicit *value* wins and is decoded with `decode_format`.
    Otherwise the format comes from the file extension of *path*, compared
    case-insensitively.

    Raises:
        ValueError: with a user-facing message when the explicit value is
            unknown or the extension is not ``.parquet``/``.jsonl``.
    """
    if not value:
        extension = path.suffix.lower().lstrip(".")
        try:
            return DatasetFormat(extension)
        except ValueError:
            raise ValueError(
                f"Could not infer format from {path.name!r}; use a .parquet or .jsonl path"
            ) from None
    return decode_format(value)

58 

59 

def build_page_dataset(store: Store, source: str | None = None) -> list[PageTextRecord]:
    """Gather one text row per page, for every source or only *source*.

    A source whose page text was captured at ingest is read verbatim from the
    page-text table. Any other source (older indexes, code) is rebuilt from
    its chunks by `_reconstruct_from_chunks`. Because that rebuild joins chunk
    text page by page, chunk overlap can repeat a little text across page
    boundaries.

    Rows are returned ordered by ``(source, page)``.
    """
    if source is None:
        wanted = sorted(entry["filename"] for entry in store.get_sources())
    else:
        wanted = [source]
    have_text = store.page_text_sources()

    collected: list[PageTextRecord] = []
    for filename in wanted:
        fetched = (
            store.get_page_texts(filename)
            if filename in have_text
            else _reconstruct_from_chunks(store, filename)
        )
        collected.extend(fetched)
    return sorted(collected, key=lambda row: (row["source"], row["page"]))

78 

79 

def _reconstruct_from_chunks(store: Store, source: str) -> list[PageTextRecord]:
    """Rebuild page rows for *source* from its stored chunks.

    Chunks are grouped by ``page_start``. Within a page they are joined with
    newlines in ``(chunk_index, chunk)`` order. Every row carries the last
    non-empty ``content_type`` seen among the chunks (``""`` if none).
    Pages are emitted in ascending order.
    """
    pages: dict[int, list[tuple[int, str]]] = {}
    kind = ""
    for piece in store.get_chunks_by_source(source):
        if piece.content_type:
            kind = piece.content_type
        pages.setdefault(piece.page_start, []).append((piece.chunk_index, piece.chunk))

    return [
        PageTextRecord(
            source=source,
            page=page_no,
            text="\n".join(body for _, body in sorted(pages[page_no])),
            content_type=kind,
        )
        for page_no in sorted(pages)
    ]

96 

97 

def _serialize_parquet(rows: list[PageTextRecord]) -> bytes:
    """Encode *rows* as an in-memory Parquet file and return its bytes.

    The imports are function-local, presumably so pyarrow is only required
    when Parquet output is actually requested.
    """
    from io import BytesIO

    import pyarrow as pa
    import pyarrow.parquet as pq

    sink = BytesIO()
    pq.write_table(pa.Table.from_pylist([dict(record) for record in rows]), sink)
    return sink.getvalue()

108 

109 

def _serialize_jsonl(rows: list[PageTextRecord]) -> bytes:
    """Encode *rows* as UTF-8 JSON Lines, one ``json.dumps`` object per line.

    Every record, including the last, is followed by a newline; empty *rows*
    yields ``b""``.
    """
    lines = [json.dumps(dict(record)) for record in rows]
    return "".join(f"{line}\n" for line in lines).encode("utf-8")

112 

113 

# Dispatch table for serialize_dataset: dataset format -> bytes encoder.
_SERIALIZERS = {
    DatasetFormat.PARQUET: _serialize_parquet,
    DatasetFormat.JSONL: _serialize_jsonl,
}

115 

116 

def serialize_dataset(rows: list[PageTextRecord], fmt: DatasetFormat) -> bytes:
    """Return *rows* encoded as bytes in format *fmt*.

    Dispatches through ``_SERIALIZERS``; an *fmt* with no entry raises
    ``KeyError``.
    """
    encode = _SERIALIZERS[fmt]
    return encode(rows)

120 

121 

def write_dataset(rows: list[PageTextRecord], path: Path, fmt: DatasetFormat) -> None:
    """Serialize *rows* as *fmt* and write the bytes to *path*.

    The payload is fully encoded before the file is opened, so an encoding
    error leaves *path* untouched; an existing file is overwritten.
    """
    payload = serialize_dataset(rows, fmt)
    path.write_bytes(payload)

125 

126 

def _coerce_row(raw: dict) -> PageTextRecord:
    """Validate one raw dataset row into a `PageTextRecord`.

    ``source``, ``page`` and ``text`` are required. A key that is missing or
    holds ``None`` (e.g. a null Parquet cell or JSON ``null``) is rejected
    rather than being stringified to the literal text ``"None"``.
    ``content_type`` is optional; a missing or null value becomes ``""``.

    Raises:
        ValueError: when *raw* is not a usable mapping, or a required field is
            missing, null, or not convertible (e.g. a non-numeric page).
    """
    try:
        source, page, text = raw["source"], raw["page"], raw["text"]
        if source is None or page is None or text is None:
            raise ValueError("null required field")
        content_type = raw.get("content_type")
        return PageTextRecord(
            source=str(source),
            page=int(page),
            text=str(text),
            # A null content_type must not become "None": import_dataset
            # only falls back to "text" when this value is falsy.
            content_type="" if content_type is None else str(content_type),
        )
    except (AttributeError, KeyError, TypeError, ValueError):
        raise ValueError("Dataset row is missing required source/page/text fields") from None

138 

139 

def _deserialize_parquet(data: bytes) -> list[PageTextRecord]:
    """Parse Parquet *data* and validate every row with `_coerce_row`."""
    from io import BytesIO

    import pyarrow.parquet as pq

    table = pq.read_table(BytesIO(data))
    return list(map(_coerce_row, table.to_pylist()))

146 

147 

def _deserialize_jsonl(data: bytes) -> list[PageTextRecord]:
    """Parse JSON Lines *data* and validate every row with `_coerce_row`.

    Records are split on LF characters only. ``str.splitlines`` would also
    split on U+0085, U+2028 and U+2029, which JSON allows unescaped inside
    string values (e.g. files written with ``ensure_ascii=False``), tearing
    such a record in two. A leading UTF-8 byte-order mark is tolerated, and
    blank or whitespace-only lines (including a trailing CR from CRLF files)
    are skipped.

    Raises:
        ValueError: on invalid UTF-8, on a line that is not valid JSON (the
            message names its 1-based line number), or on a row that
            `_coerce_row` rejects.
    """
    rows: list[PageTextRecord] = []
    text = data.decode("utf-8-sig")
    for lineno, line in enumerate(text.split("\n"), start=1):
        stripped = line.strip()
        if not stripped:
            continue
        try:
            raw = json.loads(stripped)
        except json.JSONDecodeError as exc:
            # Each line is parsed alone, so the decoder's own position is
            # always "line 1"; report the position within the file instead.
            raise ValueError(f"Invalid JSON on line {lineno}: {exc.msg}") from None
        rows.append(_coerce_row(raw))
    return rows

155 

156 

# Dispatch table for deserialize_dataset: dataset format -> bytes decoder.
_DESERIALIZERS = {DatasetFormat.PARQUET: _deserialize_parquet, DatasetFormat.JSONL: _deserialize_jsonl}

161 

162 

def deserialize_dataset(data: bytes, fmt: DatasetFormat) -> list[PageTextRecord]:
    """Decode dataset bytes in format *fmt* back into validated rows.

    Dispatches through ``_DESERIALIZERS``; an *fmt* with no entry raises
    ``KeyError``.
    """
    decode = _DESERIALIZERS[fmt]
    return decode(data)

166 

167 

def load_page_dataset(path: Path, fmt: DatasetFormat) -> list[PageTextRecord]:
    """Read a per-page text dataset from *path* and decode it as *fmt*.

    Raises:
        ValueError: with a user-facing message when *path* is not an existing
            regular file (missing, or e.g. a directory). Decoding errors
            propagate from `deserialize_dataset`.
    """
    # ``is_file`` rather than ``exists``: a directory passes ``exists`` and
    # would otherwise surface as a raw OSError (e.g. IsADirectoryError) from
    # ``read_bytes`` instead of this user-facing message.
    if not path.is_file():
        raise ValueError(f"Dataset not found: {path}")
    return deserialize_dataset(path.read_bytes(), fmt)

173 

174 

async def import_dataset(
    store: Store,
    rows: list[PageTextRecord],
    *,
    on_progress: DetailedProgressCallback = noop_callback,
) -> ImportResult:
    """Re-chunk and re-embed dataset *rows* under the current embedder.

    Rows are grouped by ``source`` in first-seen order, and each group is
    handled in ascending page order. Its pages are chunked and embedded, then
    any existing data for that source is deleted. Finally the new chunks,
    page texts and a detached ``IMPORTED`` source entry are written. A
    source's content type comes from its first page, falling back to
    ``"text"``.

    Raises ``EmbeddingModelMismatchError`` (before any write) when the store
    was built by a different embedder.
    """
    store.assert_embedding_compatible()

    grouped: dict[str, list[PageTextRecord]] = {}
    for record in rows:
        grouped.setdefault(record["source"], []).append(record)

    done: list[str] = []
    page_count = chunk_count = 0
    for source_name, pages in grouped.items():
        pages.sort(key=lambda rec: rec["page"])
        kind = pages[0]["content_type"] or "text"
        embedded = await chunk_and_embed_pages(
            [(rec["page"], rec["text"]) for rec in pages], source_name, kind, on_progress
        )
        # Deletion happens only after embedding succeeded, so an embedding
        # failure for this source leaves its existing data in place.
        store.delete_by_source(source_name)
        store.add_chunks(cast(list[dict], embedded))
        store.add_page_texts([dict(rec) for rec in pages])
        store.upsert_source(source_name, "", len(embedded), source_type=SourceType.IMPORTED)
        done.append(source_name)
        page_count += len(pages)
        chunk_count += len(embedded)
    return ImportResult(sources=sorted(done), pages=page_count, chunks=chunk_count)