Coverage for src/lilbee/modelhub/models.py: 100%

162 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-05-15 20:55 +0000

1"""RAM detection, model selection, interactive picker, and auto-install for chat models.""" 

2 

3import functools 

4import logging 

5import os 

6import shutil 

7import sys 

8from dataclasses import dataclass 

9from pathlib import Path 

10 

11from rich.console import Console 

12from rich.progress import BarColumn, DownloadColumn, Progress, SpinnerColumn, TextColumn 

13from rich.table import Table 

14 

15from lilbee.catalog.types import ModelTask 

16from lilbee.core import settings 

17from lilbee.core.config.model import cfg 

18from lilbee.modelhub.registry import ModelRegistry 

19 

20log = logging.getLogger(__name__) 

21 

22FEATURED_STAR = "★" 

23 

24# Extra headroom required beyond model size (GB) 

25_DISK_HEADROOM_GB = 2 

26 

27MODELS_BROWSE_URL = "https://huggingface.co/models?library=gguf&sort=trending" 

28 

29 

@dataclass(frozen=True)
class ModelInfo:
    """A curated chat model with metadata for the picker UI."""

    ref: str  # canonical HF ref (e.g. "Qwen/Qwen3-0.6B-GGUF")
    display_name: str  # UI label (e.g. "Qwen3 0.6B")
    size_gb: float  # approximate download size in GB (shown in the picker, checked against free disk)
    min_ram_gb: float  # RAM threshold used by pick_default_model() when recommending
    description: str  # one-line blurb shown in the picker table

39 

40 

def _catalog_from_featured(featured: tuple) -> tuple[ModelInfo, ...]:
    """Convert ``lilbee.catalog`` CatalogModel entries into ModelInfo records."""
    entries: list[ModelInfo] = []
    for item in featured:
        entries.append(
            ModelInfo(item.ref, item.display_name, item.size_gb, item.min_ram_gb, item.description)
        )
    return tuple(entries)

46 

47 

@functools.cache
def _get_model_catalog() -> tuple[ModelInfo, ...]:
    """Return the curated chat catalog, importing lazily and caching the result."""
    # Deferred import keeps module load cheap and avoids import cycles.
    from lilbee.catalog import FEATURED_CHAT

    catalog = _catalog_from_featured(FEATURED_CHAT)
    return catalog

53 

54 

def __getattr__(name: str) -> tuple[ModelInfo, ...]:
    """Expose ``MODEL_CATALOG`` lazily so the catalog import happens on first access."""
    if name != "MODEL_CATALOG":
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    return _get_model_catalog()

59 

60 

def get_system_ram_gb() -> float:
    """Return total system RAM in GB. Falls back to 8.0 if detection fails."""
    try:
        if sys.platform == "win32":
            import ctypes

            # Mirrors the Win32 MEMORYSTATUSEX layout expected by
            # GlobalMemoryStatusEx; field order and widths must match.
            class _MEMORYSTATUSEX(ctypes.Structure):
                _fields_ = [
                    ("dwLength", ctypes.c_ulong),
                    ("dwMemoryLoad", ctypes.c_ulong),
                    ("ullTotalPhys", ctypes.c_ulonglong),
                    ("ullAvailPhys", ctypes.c_ulonglong),
                    ("ullTotalPageFile", ctypes.c_ulonglong),
                    ("ullAvailPageFile", ctypes.c_ulonglong),
                    ("ullTotalVirtual", ctypes.c_ulonglong),
                    ("ullAvailVirtual", ctypes.c_ulonglong),
                    ("ullAvailExtendedVirtual", ctypes.c_ulonglong),
                ]

            stat = _MEMORYSTATUSEX()
            # The API requires dwLength to hold the struct size before the call.
            stat.dwLength = ctypes.sizeof(stat)
            ctypes.windll.kernel32.GlobalMemoryStatusEx(ctypes.byref(stat))  # type: ignore[attr-defined]
            return stat.ullTotalPhys / (1024**3)
        # POSIX path: total bytes = physical page count * page size.
        pages = os.sysconf("SC_PHYS_PAGES")
        page_size = os.sysconf("SC_PAGE_SIZE")
        return (pages * page_size) / (1024**3)
    except (OSError, AttributeError, ValueError):
        # sysconf keys missing/unsupported, or ctypes call unavailable:
        # use a conservative default rather than failing model selection.
        log.debug("RAM detection failed, falling back to 8.0 GB")
        return 8.0

90 

91 

def get_free_disk_gb(path: Path) -> float:
    """Return free disk space in GB for the filesystem containing *path*.

    Walks up to the nearest existing ancestor before querying, so a
    not-yet-created data directory still reports its filesystem's space.

    Raises:
        OSError: if no ancestor of *path* exists (e.g. a nonexistent
            drive on Windows), via ``shutil.disk_usage``.
    """
    check_path = path
    while not check_path.exists():
        parent = check_path.parent
        # At the filesystem anchor, Path.parent returns itself; bail out
        # here to avoid an infinite loop (e.g. a nonexistent Windows drive)
        # and let shutil.disk_usage raise a meaningful OSError.
        if parent == check_path:
            break
        check_path = parent
    usage = shutil.disk_usage(check_path)
    return usage.free / (1024**3)

99 

100 

def pick_default_model(ram_gb: float) -> ModelInfo:
    """Choose the largest catalog model that fits in *ram_gb*.

    Last fitting entry wins, so this assumes the catalog is ordered from
    smallest to largest requirement; falls back to the first entry when
    nothing fits.
    """
    catalog = _get_model_catalog()
    best = catalog[0]
    for candidate in catalog:
        if candidate.min_ram_gb <= ram_gb:
            best = candidate
    return best

108 

109 

def _model_download_size_gb(model: str) -> float:
    """Estimated download size in GiB for an HF model ref."""
    # Unknown refs get a flat 5 GiB estimate.
    sizes = {entry.ref: entry.size_gb for entry in _get_model_catalog()}
    return sizes.get(model, 5.0)

115 

116 

def display_model_picker(
    ram_gb: float, free_disk_gb: float, *, console: Console | None = None
) -> ModelInfo:
    """Render the catalog as a Rich table on stderr and return the recommended model."""
    if console is None:
        console = Console(stderr=True)
    recommended = pick_default_model(ram_gb)

    table = Table(title="Available Models", show_lines=False)
    table.add_column("#", justify="right", style="bold")
    table.add_column("Model", style="cyan")
    table.add_column("Size", justify="right")
    table.add_column("Description")

    for position, entry in enumerate(_get_model_catalog(), 1):
        number = str(position)
        name = entry.display_name
        size_text = f"{entry.size_gb:.1f} GB"
        blurb = entry.description

        # Embolden the row that matches the system recommendation.
        if entry == recommended:
            name = f"[bold]{name} ★[/bold]"
            blurb = f"[bold]{blurb}[/bold]"
            number = f"[bold]{number}[/bold]"

        # Warn (red size) when the download plus headroom exceeds free disk.
        if free_disk_gb < entry.size_gb + _DISK_HEADROOM_GB:
            size_text = f"[red]{entry.size_gb:.1f} GB[/red]"

        table.add_row(number, name, size_text, blurb)

    console.print()
    console.print("[bold]No chat model found.[/bold] Pick one to download:\n")
    console.print(table)
    console.print(f"\n System: {ram_gb:.0f} GB RAM, {free_disk_gb:.1f} GB free disk")
    console.print(f" {FEATURED_STAR} = recommended for your system")
    console.print(f" Browse more models at {MODELS_BROWSE_URL}\n")

    return recommended

157 

158 

def prompt_model_choice(ram_gb: float) -> ModelInfo:
    """Prompt the user to pick a model by number. Returns the chosen ModelInfo."""
    free_disk_gb = get_free_disk_gb(cfg.data_dir)
    recommended = display_model_picker(ram_gb, free_disk_gb)
    catalog = _get_model_catalog()
    default_idx = list(catalog).index(recommended) + 1

    while True:
        try:
            raw = input(f"Choice [{default_idx}]: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Closed stdin or Ctrl-C: silently accept the recommendation.
            return recommended

        # Empty input accepts the default.
        if not raw:
            return recommended

        try:
            choice = int(raw)
        except ValueError:
            sys.stderr.write(f"Enter a number 1-{len(catalog)}.\n")
            continue

        if 1 <= choice <= len(catalog):
            return catalog[choice - 1]

        sys.stderr.write(f"Enter a number 1-{len(catalog)}.\n")

184 

185 

def validate_disk_and_pull(
    model_info: ModelInfo, free_gb: float, *, console: Console | None = None
) -> None:
    """Check disk space, pull the model, and persist the choice.

    Raises:
        RuntimeError: when free disk is below model size plus headroom.
    """
    # Require the model size plus fixed headroom before starting the download.
    needed = model_info.size_gb + _DISK_HEADROOM_GB
    if free_gb < needed:
        raise RuntimeError(
            f"Not enough disk space to download '{model_info.display_name}': "
            f"need {needed:.1f} GB, have {free_gb:.1f} GB free. "
            f"Free up space or choose a smaller model."
        )

    pull_with_progress(model_info.ref, console=console)
    # Persist the selection so it becomes the configured default chat model.
    cfg.chat_model = model_info.ref
    settings.set_value(cfg.data_root, "chat_model", model_info.ref)

201 

202 

def pull_with_progress(model: str, *, console: Console | None = None) -> None:
    """Pull a model via model_manager, showing a Rich progress bar."""
    # Deferred imports: services/catalog are heavier and import-cycle-prone.
    from lilbee.app.services import get_services
    from lilbee.catalog.types import ModelSource

    if console is None:
        # Prefer the original stderr even if sys.stderr was redirected.
        console = Console(file=sys.__stderr__ or sys.stderr)
    manager = get_services().model_manager
    columns = (
        SpinnerColumn(),
        TextColumn("{task.description}"),
        BarColumn(),
        DownloadColumn(),
        TextColumn("{task.percentage:>3.0f}%"),
    )
    with Progress(*columns, transient=True, console=console) as progress:
        task_id = progress.add_task(f"Downloading model '{model}'...", total=None)

        def _report(downloaded: int, total: int) -> None:
            # Total stays unknown (spinner only) until a positive size arrives.
            if total > 0:
                progress.update(task_id, total=total, completed=downloaded)

        manager.pull(model, ModelSource.NATIVE, on_bytes=_report)
    console.print(f"Model '{model}' ready.")

229 

230 

def ensure_chat_model() -> None:
    """If no chat models are installed, pick and pull one.
    Interactive (TTY): show catalog picker with descriptions and sizes.
    Non-interactive (CI/pipes): auto-pick recommended model silently.
    Persists the chosen model in config.toml so it becomes the default.
    """
    from lilbee.app.services import get_services

    manager = get_services().model_manager
    try:
        installed = manager.list_installed()
    except RuntimeError as exc:
        raise RuntimeError(f"Cannot list models: {exc}") from exc

    # The configured embedding model is not a chat candidate; its ref points
    # at one specific manifest, so an exact match (not family stem) suffices.
    embed_ref = cfg.embedding_model
    if any(name != embed_ref for name in installed):
        # At least one chat-capable model is already present.
        return

    ram_gb = get_system_ram_gb()
    free_gb = get_free_disk_gb(cfg.data_dir)

    if sys.stdin.isatty():
        chosen = prompt_model_choice(ram_gb)
    else:
        chosen = pick_default_model(ram_gb)
        sys.stderr.write(
            f"No chat model found. Auto-installing '{chosen.display_name}' "
            f"(detected {ram_gb:.0f} GB RAM)...\n"
        )

    validate_disk_and_pull(chosen, free_gb)

266 

267 

def list_installed_models() -> list[str]:
    """Return installed chat-task model names.

    Sources both the native registry (manifest ``task`` field) and the
    SDK backend catalog (classified by name/family). Non-chat roles
    (embedding, vision, rerank) are excluded so TUI pickers don't offer
    refs that fail pydantic task validation at assignment time.
    """
    # circular: modelhub.model_manager.discovery imports modelhub.models at top
    from lilbee.modelhub.model_manager import classify_remote_models
    from lilbee.modelhub.model_manager.discovery import reclassify_by_name

    try:
        registry = ModelRegistry(cfg.models_dir)
        local = [
            manifest.ref
            for manifest in registry.list_installed()
            if reclassify_by_name(manifest.ref, manifest.task) == ModelTask.CHAT
        ]
        remote = [
            entry.name
            for entry in classify_remote_models(cfg.remote_base_url)
            if entry.task == ModelTask.CHAT
        ]
        return sorted(set(local + remote))
    except Exception:
        # Best-effort: listing failures degrade to an empty picker, with the
        # cause preserved at debug level.
        log.debug("Failed to list installed models", exc_info=True)
        return []