Coverage for src / lilbee / runtime / ingest_lock.py: 100%
47 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-05-15 20:55 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-05-15 20:55 +0000
"""Per-process ingest lock registry.

A runtime concurrency primitive shared by the HTTP add-files handler, the
TUI ingest task, and any other surface that wants to serialize concurrent
ingest of the same source file. It lives at the runtime layer so callers in
core/server/cli/tui can all use it without pulling in HTTP-layer code.
"""
9from __future__ import annotations
11import asyncio
12from pathlib import Path
class IngestLockRegistry:
    """Per-source ingest locks with a serialized check-and-acquire step.

    The registry lock serializes lock creation and the check-and-acquire
    so concurrent ``/api/add`` calls cannot TOCTOU between
    ``locked()`` and ``acquire()``. One instance is held by ``Services``
    and discarded by ``reset_services()``.

    Locks are keyed by :meth:`canonical_source_name` (the path's basename),
    so two paths that share a basename share one lock.
    """

    def __init__(self) -> None:
        # Canonical source name -> its lock. Entries are only removed by
        # reset(), so this grows with the number of distinct names seen.
        self._locks: dict[str, asyncio.Lock] = {}
        # Created lazily by _get_registry_lock(); presumably so constructing
        # the registry does not require a running event loop (TODO confirm).
        self._registry_lock: asyncio.Lock | None = None

    def _get_registry_lock(self) -> asyncio.Lock:
        """Return the registry lock, creating it on first use."""
        if self._registry_lock is None:
            self._registry_lock = asyncio.Lock()
        return self._registry_lock

    def reset(self) -> None:
        """Test hook: clear per-source locks and the registry lock."""
        self._locks.clear()
        self._registry_lock = None

    async def try_acquire(self, name: str) -> asyncio.Lock | None:
        """Acquire the lock for ``name`` or return ``None`` if already held.

        Args:
            name: Canonical source name (see :meth:`canonical_source_name`).

        Returns:
            The now-held lock (the caller must release it), or ``None`` when
            another holder already owns it. This never queues behind the
            current holder.
        """
        async with self._get_registry_lock():
            lock = self._locks.get(name)
            if lock is None:
                lock = asyncio.Lock()
                self._locks[name] = lock
            if lock.locked():
                return None
            # The lock is free and we hold the registry lock, so this acquire
            # is expected to succeed immediately.
            await lock.acquire()
            return lock

    @staticmethod
    def canonical_source_name(p_str: str) -> str:
        """Match the basename ``copy_files`` writes under ``cfg.documents_dir``."""
        return Path(p_str).name

    async def acquire(
        self, paths: list[str]
    ) -> tuple[list[tuple[str, asyncio.Lock]], list[str]]:
        """Try to lock every source in ``paths`` without waiting on holders.

        Each path is mapped through :meth:`canonical_source_name`, and
        duplicate names are skipped while preserving first-seen order.

        Args:
            paths: Source paths as supplied by the caller.

        Returns:
            ``(acquired, busy)``. ``acquired`` holds the ``(name, lock)`` pairs
            the caller now owns and must hand to :meth:`release`. ``busy``
            holds the canonical names whose lock was already held.

        Raises:
            Whatever ``try_acquire`` raises, including
            ``asyncio.CancelledError``. Locks taken earlier in the same call
            are released before the exception propagates.
        """
        acquired: list[tuple[str, asyncio.Lock]] = []
        busy: list[str] = []
        seen: set[str] = set()
        try:
            for p_str in paths:
                name = self.canonical_source_name(p_str)
                if name in seen:
                    continue
                seen.add(name)
                lock = await self.try_acquire(name)
                if lock is None:
                    busy.append(name)
                else:
                    acquired.append((name, lock))
        except BaseException:
            # On failure the caller never receives ``acquired`` and so cannot
            # release it. Without this cleanup those sources would stay
            # locked (and report busy) indefinitely. BaseException also covers
            # task cancellation.
            self.release(acquired)
            raise
        return acquired, busy

    @staticmethod
    def release(acquired: list[tuple[str, asyncio.Lock]]) -> None:
        """Release every lock in ``acquired``, emptying the list in place.

        Safe to call multiple times: the list is drained as locks are
        released, and any lock that is no longer held is skipped.
        """
        while acquired:
            _, lock = acquired.pop()
            if lock.locked():
                lock.release()