Coverage for src / lilbee / modelhub / models.py: 100%

168 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-06-28 01:01 +0000

1"""RAM detection, model selection, interactive picker, and auto-install for chat models.""" 

2 

3import functools 

4import logging 

5import os 

6import shutil 

7import sys 

8from dataclasses import dataclass 

9from pathlib import Path 

10 

11from rich.console import Console 

12from rich.progress import BarColumn, DownloadColumn, Progress, SpinnerColumn, TextColumn 

13from rich.table import Table 

14 

15from lilbee.catalog.types import ModelTask 

16from lilbee.core.config.model import cfg 

17from lilbee.modelhub.registry import ModelRegistry 

18 

# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Marker explained in the picker legend as "recommended for your system".
FEATURED_STAR = "★"

# Extra headroom required beyond model size (GB)
_DISK_HEADROOM_GB = 2

# Hugging Face GGUF listing printed under the picker so users can browse more models.
MODELS_BROWSE_URL = "https://huggingface.co/models?library=gguf&sort=trending"

27 

28 

@dataclass(frozen=True)
class ModelInfo:
    """Immutable record describing one curated chat model offered by the picker.

    Fields, in positional order:
        ref: canonical HF ref (e.g. ``"Qwen/Qwen3-0.6B-GGUF"``).
        display_name: UI label (e.g. ``"Qwen3 0.6B"``).
        size_gb: model size in GB; compared against free disk space before a pull.
        min_ram_gb: RAM threshold (GB) compared against detected system RAM.
        description: short text shown in the picker's "Description" column.
    """

    ref: str
    display_name: str
    size_gb: float
    min_ram_gb: float
    description: str

38 

39 

def _catalog_from_featured(featured: tuple) -> tuple[ModelInfo, ...]:
    """Convert ``lilbee.catalog`` CatalogModel entries into picker ``ModelInfo`` records.

    Only the five attributes ``ModelInfo`` carries are read from each entry;
    the input order is preserved.
    """
    converted = [
        ModelInfo(
            entry.ref,
            entry.display_name,
            entry.size_gb,
            entry.min_ram_gb,
            entry.description,
        )
        for entry in featured
    ]
    return tuple(converted)

45 

46 

@functools.cache
def _get_model_catalog() -> tuple[ModelInfo, ...]:
    """Return the featured chat models as ``ModelInfo`` records.

    The result is computed on the first call and cached for every later call.
    """
    # Imported on first use rather than at module import time.
    from lilbee.catalog import FEATURED_CHAT

    catalog = _catalog_from_featured(FEATURED_CHAT)
    return catalog

52 

53 

def __getattr__(name: str) -> tuple[ModelInfo, ...]:
    """Module-level attribute hook that serves ``MODEL_CATALOG`` lazily.

    Any other missing attribute raises the usual ``AttributeError``.
    """
    if name != "MODEL_CATALOG":
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    return _get_model_catalog()

58 

59 

def get_system_ram_gb() -> float:
    """Return total system RAM in GB. Falls back to 8.0 if detection fails.

    On Windows the value comes from ``GlobalMemoryStatusEx`` (via ``ctypes``);
    on other platforms it is ``SC_PHYS_PAGES * SC_PAGE_SIZE`` from
    ``os.sysconf``. A failed Win32 call, or a non-positive ``sysconf`` value
    (``-1`` means "indeterminate"), is treated as a detection failure so the
    caller never sees zero or negative RAM.
    """
    try:
        if sys.platform == "win32":
            import ctypes

            class _MEMORYSTATUSEX(ctypes.Structure):
                _fields_ = [
                    ("dwLength", ctypes.c_ulong),
                    ("dwMemoryLoad", ctypes.c_ulong),
                    ("ullTotalPhys", ctypes.c_ulonglong),
                    ("ullAvailPhys", ctypes.c_ulonglong),
                    ("ullTotalPageFile", ctypes.c_ulonglong),
                    ("ullAvailPageFile", ctypes.c_ulonglong),
                    ("ullTotalVirtual", ctypes.c_ulonglong),
                    ("ullAvailVirtual", ctypes.c_ulonglong),
                    ("ullAvailExtendedVirtual", ctypes.c_ulonglong),
                ]

            stat = _MEMORYSTATUSEX()
            # The API requires dwLength to be set before the call.
            stat.dwLength = ctypes.sizeof(stat)
            kernel32 = ctypes.windll.kernel32  # type: ignore[attr-defined]
            # A zero return means the call failed and the struct was not filled in.
            if not kernel32.GlobalMemoryStatusEx(ctypes.byref(stat)) or not stat.ullTotalPhys:
                raise OSError("GlobalMemoryStatusEx failed")
            return stat.ullTotalPhys / (1024**3)
        pages = os.sysconf("SC_PHYS_PAGES")
        page_size = os.sysconf("SC_PAGE_SIZE")
        # sysconf reports -1 for values it cannot determine.
        if pages <= 0 or page_size <= 0:
            raise ValueError("sysconf returned a non-positive memory value")
        return (pages * page_size) / (1024**3)
    except (OSError, AttributeError, ValueError):
        log.debug("RAM detection failed, falling back to 8.0 GB")
        return 8.0

89 

90 

def get_free_disk_gb(path: Path) -> float:
    """Return free disk space in GB for the filesystem containing *path*.

    *path* does not have to exist yet: the nearest existing ancestor is
    measured instead.

    Raises:
        OSError: if neither *path* nor any of its ancestors exists (for
            example an unavailable drive), or the usage query itself fails.
    """
    check_path = path
    while not check_path.exists():
        parent = check_path.parent
        if parent == check_path:
            # Reached the anchor and it does not exist. Stop here instead of
            # looping forever; disk_usage below reports the error.
            break
        check_path = parent
    usage = shutil.disk_usage(check_path)
    return usage.free / (1024**3)

98 

99 

def pick_default_model(ram_gb: float) -> ModelInfo:
    """Pick the catalog model to recommend for a machine with *ram_gb* of RAM.

    Returns the last catalog entry whose ``min_ram_gb`` does not exceed
    *ram_gb*. When no entry qualifies, the first catalog entry is returned.
    """
    # NOTE(review): "last fitting" equals "largest fitting" only if the catalog
    # is ordered smallest-first -- confirm against lilbee.catalog.
    catalog = _get_model_catalog()
    fallback = catalog[0]
    fitting = [candidate for candidate in catalog if candidate.min_ram_gb <= ram_gb]
    return fitting[-1] if fitting else fallback

107 

108 

def _model_download_size_gb(model: str) -> float:
    """Look up the catalog size (GiB) for an HF model ref; unknown refs yield 5.0."""
    size_by_ref = {}
    for entry in _get_model_catalog():
        size_by_ref[entry.ref] = entry.size_gb
    if model in size_by_ref:
        return size_by_ref[model]
    # reasonable default for unknown models
    return 5.0

114 

115 

def display_model_picker(
    ram_gb: float, free_disk_gb: float, *, console: Console | None = None
) -> ModelInfo:
    """Show a Rich table of catalog models and return the recommended model.

    Args:
        ram_gb: Detected system RAM; decides which row is marked as recommended.
        free_disk_gb: Free disk space; a model whose size plus the disk
            headroom exceeds it gets its size printed in red.
        console: Console to print to. Defaults to a new stderr console.

    Returns:
        The model chosen by ``pick_default_model(ram_gb)``.
    """
    if console is None:
        console = Console(stderr=True)
    recommended = pick_default_model(ram_gb)

    table = Table(title="Available Models", show_lines=False)
    table.add_column("#", justify="right", style="bold")
    table.add_column("Model", style="cyan")
    table.add_column("Size", justify="right")
    table.add_column("Description")

    for idx, model in enumerate(_get_model_catalog(), 1):
        num_str = str(idx)
        label = model.display_name
        size_str = f"{model.size_gb:.1f} GB"
        desc = model.description

        if model == recommended:
            # Use the same constant as the legend below so the two cannot drift.
            label = f"[bold]{label} {FEATURED_STAR}[/bold]"
            desc = f"[bold]{desc}[/bold]"
            num_str = f"[bold]{num_str}[/bold]"

        if free_disk_gb < model.size_gb + _DISK_HEADROOM_GB:
            size_str = f"[red]{size_str}[/red]"

        table.add_row(num_str, label, size_str, desc)

    console.print()
    console.print("[bold]No chat model found.[/bold] Pick one to download:\n")
    console.print(table)
    console.print(f"\n  System: {ram_gb:.0f} GB RAM, {free_disk_gb:.1f} GB free disk")
    console.print(f"  {FEATURED_STAR} = recommended for your system")
    console.print(f"  Browse more models at {MODELS_BROWSE_URL}\n")

    return recommended

156 

157 

def prompt_model_choice(ram_gb: float) -> ModelInfo:
    """Show the picker, then read a 1-based catalog number from the user.

    Empty input, EOF, or Ctrl-C selects the recommended model. Anything that
    is not a number within the catalog range prints a hint to stderr and
    prompts again.
    """
    recommended = display_model_picker(ram_gb, get_free_disk_gb(cfg.data_dir))
    catalog = _get_model_catalog()
    default_idx = catalog.index(recommended) + 1
    hint = f"Enter a number 1-{len(catalog)}.\n"

    while True:
        try:
            raw = input(f"Choice [{default_idx}]: ").strip()
        except (EOFError, KeyboardInterrupt):
            return recommended

        if not raw:
            return recommended

        try:
            choice = int(raw)
        except ValueError:
            choice = None

        if choice is not None and 1 <= choice <= len(catalog):
            return catalog[choice - 1]

        sys.stderr.write(hint)

183 

184 

def validate_disk_and_pull(
    model_info: ModelInfo, free_gb: float, *, console: Console | None = None
) -> str:
    """Pull *model_info* after confirming there is room for it on disk.

    The model's size plus the module's disk headroom must fit in *free_gb*;
    otherwise ``RuntimeError`` is raised and nothing is downloaded. Returns
    the pulled ref; persisting it is left to the caller.
    """
    name = model_info.display_name
    needed_gb = model_info.size_gb + _DISK_HEADROOM_GB
    if free_gb < needed_gb:
        shortfall = (
            f"Not enough disk space to download '{name}': "
            f"need {needed_gb:.1f} GB, have {free_gb:.1f} GB free. "
            "Free up space or choose a smaller model."
        )
        raise RuntimeError(shortfall)

    ref = model_info.ref
    pull_with_progress(ref, console=console)
    return ref

199 

200 

def pull_with_progress(model: str, *, console: Console | None = None) -> None:
    """Download *model* through the model manager while rendering a Rich progress bar.

    When no console is supplied, output goes to the interpreter's original
    stderr, or to the current ``sys.stderr`` if the original is unavailable.
    """
    from lilbee.app.services import get_services
    from lilbee.catalog.types import ModelSource

    if console is None:
        console = Console(file=sys.__stderr__ or sys.stderr)
    manager = get_services().model_manager
    columns = (
        SpinnerColumn(),
        TextColumn("{task.description}"),
        BarColumn(),
        DownloadColumn(),
        TextColumn("{task.percentage:>3.0f}%"),
    )
    with Progress(*columns, transient=True, console=console) as progress:
        # total=None leaves the total unset until the first byte report arrives.
        task_id = progress.add_task(f"Downloading model '{model}'...", total=None)

        def _on_bytes(downloaded: int, total: int) -> None:
            # Reports without a positive total are ignored.
            if total > 0:
                progress.update(task_id, total=total, completed=downloaded)

        manager.pull(model, ModelSource.NATIVE, on_bytes=_on_bytes)
    console.print(f"Model '{model}' ready.")

227 

228 

def ensure_chat_model() -> str | None:
    """If no chat models are installed, pick and pull one. Returns the pulled ref or None.

    Interactive (TTY): show catalog picker with descriptions and sizes.
    Non-interactive (CI/pipes, or no stdin at all): auto-pick recommended
    model and announce the choice on stderr.
    The caller is responsible for persisting the returned ref via the
    settings boundary; this function only handles the pull side.

    Raises:
        RuntimeError: if the installed models cannot be listed, or there is
            not enough free disk space for the chosen model.
    """
    from lilbee.app.services import get_services

    manager = get_services().model_manager
    try:
        installed = manager.list_installed()
    except RuntimeError as exc:
        raise RuntimeError(f"Cannot list models: {exc}") from exc

    # Filter out the configured embedding model so we only check for chat
    # candidates. The embedding ref points at one specific manifest; we
    # match it exactly rather than by family stem.
    embed_ref = cfg.embedding_model
    if any(m != embed_ref for m in installed):
        return None

    ram_gb = get_system_ram_gb()
    free_gb = get_free_disk_gb(cfg.data_dir)

    # sys.stdin can be None (e.g. GUI launches without a console); treat that
    # as non-interactive instead of failing on ``None.isatty()``.
    stdin = sys.stdin
    if stdin is not None and stdin.isatty():
        model_info = prompt_model_choice(ram_gb)
    else:
        model_info = pick_default_model(ram_gb)
        sys.stderr.write(
            f"No chat model found. Auto-installing '{model_info.display_name}' "
            f"(detected {ram_gb:.0f} GB RAM)...\n"
        )

    return validate_disk_and_pull(model_info, free_gb)

266 

267 

def list_installed_models() -> list[str]:
    """Return the sorted, de-duplicated names of installed chat-task models.

    Two sources are merged: the native registry (manifest ``task`` field,
    passed through ``reclassify_by_name``) and the classified remote models
    from the SDK backend. Non-chat roles (embedding, vision, rerank) are left
    out so TUI pickers don't offer refs that fail pydantic task validation at
    assignment time. Any failure is logged at debug level and yields ``[]``.
    """
    # circular: modelhub.model_manager.discovery imports modelhub.models at top
    from lilbee.modelhub.model_manager import classify_all_remote_models
    from lilbee.modelhub.model_manager.discovery import reclassify_by_name

    try:
        native_refs = {
            manifest.ref
            for manifest in ModelRegistry(cfg.models_dir).list_installed()
            if reclassify_by_name(manifest.ref, manifest.task) == ModelTask.CHAT
        }
        remote_names = {
            remote.name
            for remote in classify_all_remote_models()
            if remote.task == ModelTask.CHAT
        }
        return sorted(native_refs | remote_names)
    except Exception:
        log.debug("Failed to list installed models", exc_info=True)
        return []