Coverage for src / lilbee / modelhub / model_manager / core.py: 100%

162 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-05-15 20:55 +0000

1"""ModelManager: native and SDK-backed model lifecycle operations.""" 

2 

3import json 

4import logging 

5import time 

6from collections.abc import Callable 

7from http import HTTPStatus 

8from pathlib import Path 

9 

10import httpx 

11 

12from lilbee.catalog.types import ModelSource 

13from lilbee.core.config import DEFAULT_HTTP_TIMEOUT 

14from lilbee.core.security import validate_path_within 

15from lilbee.modelhub.model_manager.types import ModelNotFoundError 

16from lilbee.modelhub.registry import ModelRegistry 

17 

18log = logging.getLogger(__name__) 

19 

20_INSTALLED_CACHE_TTL_SECONDS = 30.0 

21 

22 

23class ModelManager: 

24 """Manages model lifecycle with distinct sources.""" 

25 

26 def __init__(self, models_dir: Path, remote_base_url: str = "http://localhost:11434") -> None: 

27 self._models_dir = models_dir 

28 self._remote_base_url = remote_base_url.rstrip("/") 

29 self._registry = ModelRegistry(self._models_dir) 

30 # Memoize list_installed results to avoid walking the registry 

31 # filesystem and hitting the backend HTTP endpoint on every call. 

32 # The catalog filter path fires this per request. Time-based TTL 

33 # plus explicit invalidation on pull/remove keeps freshness. 

34 self._installed_cache: dict[ModelSource | None, tuple[float, list[str]]] = {} 

35 # Identity cache: refs + hf_repos of installed natives. The catalog 

36 # screen reads this to mark rows as installed without re-walking 

37 # the registry on every screen mount (~150-300 ms saved). 

38 self._native_identities_cache: tuple[float, frozenset[str]] | None = None 

39 

40 def list_installed(self, source: ModelSource | None = None) -> list[str]: 

41 """List installed model names. ``source=None`` lists all sources. 

42 

43 Memoized with a ``_INSTALLED_CACHE_TTL_SECONDS`` TTL and 

44 invalidated eagerly by ``pull``/``remove``. 

45 """ 

46 now = time.monotonic() 

47 cached = self._installed_cache.get(source) 

48 if cached is not None: 

49 cached_at, cached_result = cached 

50 if now - cached_at < _INSTALLED_CACHE_TTL_SECONDS: 

51 return cached_result 

52 

53 if source is None: 

54 native = set(self._list_native()) 

55 remote = set(self._list_remote()) 

56 result = sorted(native | remote) 

57 elif source is ModelSource.NATIVE: 

58 result = self._list_native() 

59 else: 

60 result = self._list_remote() 

61 

62 self._installed_cache[source] = (now, result) 

63 return result 

64 

65 def list_native_identities(self) -> frozenset[str]: 

66 """Return refs + hf_repos of installed native models. 

67 

68 Same TTL as ``list_installed``. The catalog screen reads this to 

69 mark catalog rows as installed without re-walking the registry 

70 on every screen mount. 

71 """ 

72 now = time.monotonic() 

73 if self._native_identities_cache is not None: 

74 cached_at, cached_result = self._native_identities_cache 

75 if now - cached_at < _INSTALLED_CACHE_TTL_SECONDS: 

76 return cached_result 

77 identities: set[str] = set() 

78 try: 

79 for m in self._registry.list_installed(): 

80 identities.add(m.ref) 

81 identities.add(m.hf_repo) 

82 except Exception: 

83 log.debug("ModelRegistry.list_installed failed", exc_info=True) 

84 result = frozenset(identities) 

85 self._native_identities_cache = (now, result) 

86 return result 

87 

88 def _invalidate_installed_cache(self) -> None: 

89 """Drop all cached list_installed results.""" 

90 self._installed_cache.clear() 

91 self._native_identities_cache = None 

92 

93 def _list_native(self) -> list[str]: 

94 """List native models from the registry only.""" 

95 return sorted(m.ref for m in self._registry.list_installed()) 

96 

97 def _list_remote(self) -> list[str]: 

98 """List models from the SDK backend via its HTTP API.""" 

99 url = f"{self._remote_base_url}/api/tags" 

100 try: 

101 resp = httpx.get(url, timeout=DEFAULT_HTTP_TIMEOUT) 

102 resp.raise_for_status() 

103 data = resp.json() 

104 return [m["name"] for m in data.get("models", [])] 

105 except httpx.HTTPStatusError as exc: 

106 log.warning("SDK backend HTTP error listing models: %s", exc) 

107 return [] 

108 except (httpx.ConnectError, httpx.TimeoutException) as exc: 

109 log.debug("SDK backend not reachable: %s", exc) 

110 return [] 

111 

112 def is_installed(self, model: str, source: ModelSource | None = None) -> bool: 

113 """Check if model exists in specified source.""" 

114 if source is None: 

115 return self._is_native(model) or self._is_remote(model) 

116 if source is ModelSource.NATIVE: 

117 return self._is_native(model) 

118 return self._is_remote(model) 

119 

120 def _is_native(self, model: str) -> bool: 

121 if self._registry.is_installed(model): 

122 return True 

123 try: 

124 validate_path_within(self._models_dir / model, self._models_dir) 

125 except ValueError: 

126 return False 

127 return (self._models_dir / model).is_file() 

128 

129 def _is_remote(self, model: str) -> bool: 

130 return model in self.list_installed(ModelSource.REMOTE) 

131 

132 def get_source(self, model: str) -> ModelSource | None: 

133 """Find which source a model lives in. Native takes precedence.""" 

134 if self._is_native(model): 

135 return ModelSource.NATIVE 

136 if self._is_remote(model): 

137 return ModelSource.REMOTE 

138 return None 

139 

140 def pull( 

141 self, 

142 model: str, 

143 source: ModelSource, 

144 *, 

145 on_progress: Callable[[dict], None] | None = None, 

146 on_bytes: Callable[[int, int], None] | None = None, 

147 ) -> Path | None: 

148 """Pull/download model to specified source. 

149 

150 Returns the Path for native downloads, None for backend-managed pulls. 

151 

152 *on_progress* receives dict events from the SDK backend. 

153 *on_bytes* receives (downloaded_bytes, total_bytes) from native 

154 HuggingFace downloads. The two sources report progress in different 

155 shapes, so callers pass whichever matches the chosen source. 

156 """ 

157 try: 

158 if source is ModelSource.NATIVE: 

159 return self._pull_native(model, on_bytes=on_bytes) 

160 self._pull_remote(model, on_progress=on_progress) 

161 return None 

162 finally: 

163 self._invalidate_installed_cache() 

164 

165 def _pull_native( 

166 self, 

167 model: str, 

168 *, 

169 on_bytes: Callable[[int, int], None] | None = None, 

170 ) -> Path: 

171 """Download a featured or ad-hoc HuggingFace model to the native GGUF directory.""" 

172 # heavy: lilbee.catalog (>50ms; huggingface_hub fanout) 

173 from lilbee.catalog import download_model, resolve_pull_target 

174 from lilbee.modelhub.registry import register_downloaded_model 

175 

176 entry = resolve_pull_target(model) 

177 if entry is None: 

178 raise ModelNotFoundError( 

179 f"Model '{model}' not recognized. " 

180 "Pass a HuggingFace repo id (owner/name) or a featured model name." 

181 ) 

182 path = download_model(entry, on_progress=on_bytes, on_complete=register_downloaded_model) 

183 log.info("Downloaded %s to %s", model, path) 

184 return path 

185 

186 def _pull_remote( 

187 self, model: str, *, on_progress: Callable[[dict], None] | None = None 

188 ) -> None: 

189 """Pull model via the SDK backend's HTTP API with streaming progress.""" 

190 url = f"{self._remote_base_url}/api/pull" 

191 try: 

192 with ( 

193 # Model pulls stream progress over minutes; an overall 

194 # timeout would cut the download. Connect/write timeouts 

195 # still apply via httpx defaults when timeout=None. 

196 httpx.Client(timeout=None) as client, # noqa: S113 

197 client.stream("POST", url, json={"name": model, "stream": True}) as resp, 

198 ): 

199 resp.raise_for_status() 

200 for line in resp.iter_lines(): 

201 if not line: 

202 continue 

203 data = json.loads(line) 

204 if "error" in data: 

205 raise RuntimeError(f"Failed to pull '{model}': {data['error']}") 

206 if on_progress is not None: 

207 on_progress(data) 

208 except httpx.ConnectError as exc: 

209 raise RuntimeError( 

210 f"Cannot connect to SDK backend: {exc}. Is the server running?" 

211 ) from exc 

212 log.info("Pulled %s via SDK backend", model) 

213 

214 def remove(self, model: str, source: ModelSource | None = None) -> bool: 

215 """Remove installed model. Returns True if removed.""" 

216 try: 

217 if source is None: 

218 native_removed = self._remove_native(model) 

219 backend_removed = self._remove_remote(model) 

220 return native_removed or backend_removed 

221 if source is ModelSource.NATIVE: 

222 return self._remove_native(model) 

223 return self._remove_remote(model) 

224 finally: 

225 self._invalidate_installed_cache() 

226 

227 def _remove_native(self, model: str) -> bool: 

228 if self._registry.remove(model): 

229 log.info("Removed native model %s from registry", model) 

230 return True 

231 try: 

232 path = validate_path_within(self._models_dir / model, self._models_dir) 

233 except ValueError: 

234 log.warning("Path traversal blocked: %s escapes %s", model, self._models_dir) 

235 return False 

236 if path.is_file(): 

237 path.unlink() 

238 log.info("Removed native model %s", model) 

239 return True 

240 return False 

241 

242 def _remove_remote(self, model: str) -> bool: 

243 url = f"{self._remote_base_url}/api/delete" 

244 try: 

245 resp = httpx.request( 

246 "DELETE", 

247 url, 

248 content=json.dumps({"model": model}).encode(), 

249 headers={"Content-Type": "application/json"}, 

250 timeout=DEFAULT_HTTP_TIMEOUT, 

251 ) 

252 if resp.status_code == HTTPStatus.OK: 

253 log.info("Removed backend model %s", model) 

254 return True 

255 if resp.status_code == HTTPStatus.NOT_FOUND: 

256 return False 

257 log.warning("Unexpected status %d removing %s", resp.status_code, model) 

258 return False 

259 except httpx.ConnectError as exc: 

260 raise RuntimeError( 

261 f"Cannot connect to SDK backend: {exc}. Is the server running?" 

262 ) from exc