Coverage for src / lilbee / catalog / query.py: 100%

106 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-05-15 20:55 +0000

1"""Catalog filtering, sorting, lookup, and ad-hoc HF resolution.""" 

2 

3import functools 

4from typing import Any, NamedTuple 

5 

6from huggingface_hub.utils import HFValidationError, validate_repo_id 

7 

8from lilbee.app.services import get_services 

9from lilbee.catalog.featured import FEATURED_ALL 

10from lilbee.catalog.models import CatalogModel, CatalogResult 

11from lilbee.catalog.refs import format_native_gguf_ref, hf_repo_from_ref 

12from lilbee.catalog.types import ModelTask 

13 

14 

def _search_blob(m: CatalogModel) -> str:
    """Build the lowercase haystack that catalog search matches against.

    The display name, repo id and description are glued together with a
    NUL character, so a search term cannot match across a field boundary.
    """
    searchable_fields = (m.display_name, m.hf_repo, m.description)
    joined = "\0".join(f"{field}" for field in searchable_fields)
    return joined.lower()

21 

22 

# Size bucket name -> (lower bound, upper bound) compared against a model's
# ``size_gb``. The lower bound is inclusive and the upper bound exclusive, so
# adjacent buckets never overlap.
_SIZE_RANGES: dict[str, tuple[float, float]] = {
    "small": (0.0, 3.0),
    "medium": (3.0, 10.0),
    "large": (10.0, float("inf")),  # open-ended top bucket
}

28 

# Minimum number of ``/`` separators in a native GGUF ref, which is written
# ``<owner>/<repo>/<file>.gguf``. A ref with a single ``/`` is a bare repo ID.
_NATIVE_GGUF_REF_MIN_SLASHES = 2

32 

33 

def get_catalog(
    task: ModelTask | None = None,
    *,
    search: str = "",
    size: str | None = None,
    installed: bool | None = None,
    featured: bool | None = None,
    sort: str = "featured",
    limit: int = 20,
    offset: int = 0,
    model_manager: Any = None,
) -> CatalogResult:
    """Get a paginated, filtered catalog of models.

    Args:
        task: Keep only models for this task. When ``None`` no task filter
            is applied locally and the HF query uses the chat pipeline.
        search: Case-insensitive substring matched against each model's
            search blob; it is also forwarded to the HF query.
        size: A ``_SIZE_RANGES`` key; any other value disables the filter.
        installed: ``True`` keeps installed repos, ``False`` keeps the
            rest. Ignored unless ``model_manager`` is also given.
        featured: ``True`` serves featured models only and skips the HF
            request, ``False`` hides featured models, ``None`` mixes both.
        sort: A ``_SORT_KEYS`` name; unknown names sort as ``"featured"``.
        limit: Page size.
        offset: Index of the first row of the requested page.
        model_manager: Object exposing ``list_installed()``.

    Returns:
        A ``CatalogResult`` holding the page. ``total`` counts the rows left
        after filtering (for HF-backed requests that is the current page
        only), and ``has_more`` is what the HF page reported, or ``False``
        when HF was not queried.
    """
    featured_only = bool(featured)

    # Featured rows live locally. A featured-only request pages through them
    # by slicing below, so it needs the full list at every offset; seeding it
    # only at offset 0 made every later featured page come back empty.
    # Requests that also hit HF are paged by the HF API and show the featured
    # rows on the first page only.
    all_models = list(FEATURED_ALL) if featured_only or offset == 0 else []
    hf_has_more = False

    # Fetch one page from the HF API unless only featured models were asked for.
    if not featured_only:
        hf_task, hf_library = _task_to_pipeline(task)
        hf_page = get_services().hf_client.fetch_models(
            pipeline_tag=hf_task,
            limit=limit,
            offset=offset,
            library=hf_library,
            search=search,
        )
        hf_has_more = hf_page.has_more
        # Deduplicate: skip HF models whose repo matches a featured model.
        featured_repos = {m.hf_repo for m in FEATURED_ALL}
        all_models.extend(m for m in hf_page.models if m.hf_repo not in featured_repos)

    # Filter by task.
    if task:
        all_models = [m for m in all_models if m.task == task]

    # Filter by search: one join+lower per model rather than a separate
    # lower() and substring check for every searchable field.
    if search:
        search_lower = search.lower()
        all_models = [m for m in all_models if search_lower in _search_blob(m)]

    # Filter by size bucket (lower bound inclusive, upper bound exclusive).
    if size and size in _SIZE_RANGES:
        lo, hi = _SIZE_RANGES[size]
        all_models = [m for m in all_models if lo <= m.size_gb < hi]

    # A repo is "installed" if any of its quants has a manifest.
    if installed is not None and model_manager is not None:
        installed_repos = {hf_repo_from_ref(ref) for ref in _get_installed_models(model_manager)}
        if installed:
            all_models = [m for m in all_models if m.hf_repo in installed_repos]
        else:
            all_models = [m for m in all_models if m.hf_repo not in installed_repos]

    # Filter by featured status.
    if featured is not None:
        all_models = [m for m in all_models if m.featured == featured]

    all_models = _sort_models(all_models, sort)
    total = len(all_models)

    # The HF API already applied the offset to its page, so slicing by offset
    # again would skip rows twice. Only the purely local featured list is
    # sliced by offset here.
    if featured_only:
        paginated = all_models[offset : offset + limit]
    else:
        paginated = all_models[:limit]

    return CatalogResult(
        total=total, limit=limit, offset=offset, models=paginated, has_more=hf_has_more
    )

107 

108 

def _task_to_pipeline(task: ModelTask | None) -> tuple[str, str | None]:
    """Translate an internal task into ``(pipeline tag, library filter)``.

    A missing task, or one without an entry in the table, resolves to the
    chat pipeline with no library filter.
    """
    chat_pipeline: tuple[str, str | None] = ("text-generation", None)
    pipelines: dict[ModelTask, tuple[str, str | None]] = {
        ModelTask.CHAT: chat_pipeline,
        ModelTask.EMBEDDING: ("feature-extraction", "sentence-transformers"),
        ModelTask.VISION: ("image-text-to-text", None),
        ModelTask.RERANK: ("text-classification", None),
    }
    effective_task = task if task else ModelTask.CHAT
    return pipelines.get(effective_task, chat_pipeline)

118 

119 

# HuggingFace pipeline tag -> internal task. Several tags can map onto the
# same task, so the table is written grouped by task and flattened here.
_PIPELINE_TO_TASK: dict[str, ModelTask] = {
    pipeline_tag: model_task
    for model_task, pipeline_tags in (
        (ModelTask.CHAT, ("text-generation",)),
        (ModelTask.EMBEDDING, ("feature-extraction", "sentence-similarity")),
        (ModelTask.VISION, ("image-text-to-text", "image-to-text")),
        (ModelTask.RERANK, ("text-classification", "text-ranking")),
    )
    for pipeline_tag in pipeline_tags
}

129 

130 

def pipeline_to_task(pipeline_tag: str) -> ModelTask:
    """Translate a HuggingFace pipeline tag into the internal task.

    Tags missing from ``_PIPELINE_TO_TASK`` are treated as chat models.
    """
    try:
        return _PIPELINE_TO_TASK[pipeline_tag]
    except KeyError:
        return ModelTask.CHAT

134 

135 

def _get_installed_models(model_manager: Any) -> set[str]:
    """Return what ``model_manager.list_installed()`` reports, as a set.

    This is best-effort: any failure from the manager is logged with its
    traceback and treated as "nothing installed" instead of propagating.
    Previously the failure was swallowed without a trace, which made a
    broken manager indistinguishable from an empty one.

    Args:
        model_manager: Object exposing ``list_installed()``.

    Returns:
        The installed entries, or an empty set if listing failed.
    """
    import logging  # kept local so the block brings its own dependency

    try:
        return set(model_manager.list_installed())
    except Exception:
        logging.getLogger(__name__).warning(
            "Could not list installed models; treating none as installed",
            exc_info=True,
        )
        return set()

142 

143 

# Sort name -> (key function, reverse flag), unpacked straight into
# ``sorted(..., key=..., reverse=...)``.
_SORT_KEYS: dict[str, tuple] = {
    # Most-downloaded first.
    "downloads": (lambda model: model.downloads, True),
    # Alphabetical, ignoring case.
    "name": (lambda model: model.display_name.lower(), False),
    "size_asc": (lambda model: model.size_gb, False),
    "size_desc": (lambda model: model.size_gb, True),
    # Featured rows first (``False`` sorts before ``True``), then by
    # descending downloads within each group.
    "featured": (lambda model: (not model.featured, -model.downloads), False),
}

151 

152 

def _sort_models(models: list[CatalogModel], sort: str) -> list[CatalogModel]:
    """Return a new list of *models* ordered by the named sort.

    A name that is not in ``_SORT_KEYS`` falls back to the ``"featured"``
    ordering. The input list is left untouched.
    """
    sort_spec = _SORT_KEYS.get(sort)
    if sort_spec is None:
        sort_spec = _SORT_KEYS["featured"]
    key_fn, descending = sort_spec
    ordered = list(models)
    ordered.sort(key=key_fn, reverse=descending)
    return ordered

157 

158 

class CatalogIndex(NamedTuple):
    """Lookup tables that let find_catalog_entry match case-insensitively."""

    # Repo id -> catalog entry.
    by_hf_repo: dict[str, CatalogModel]
    # Repo plus a concrete (non-glob) filename -> catalog entry.
    by_full_ref: dict[str, CatalogModel]

164 

165 

@functools.cache
def _build_catalog_index() -> CatalogIndex:
    """Index the featured catalog by lowercased repo and lowercased full ref.

    The result is cached, so the featured list is scanned once per process
    (until the cache is cleared).
    """
    repo_index: dict[str, CatalogModel] = {}
    ref_index: dict[str, CatalogModel] = {}
    for entry in FEATURED_ALL:
        # When several entries share a repo, the first one listed wins.
        repo_key = entry.hf_repo.lower()
        if repo_key not in repo_index:
            repo_index[repo_key] = entry
        # A glob filename names no concrete file, so it has no full ref.
        if "*" in entry.gguf_filename:
            continue
        full_ref = format_native_gguf_ref(entry.hf_repo, entry.gguf_filename)
        ref_index[full_ref.lower()] = entry
    return CatalogIndex(by_hf_repo=repo_index, by_full_ref=ref_index)

176 

177 

def find_catalog_entry(query: str) -> CatalogModel | None:
    """Look up a featured model by ``hf_repo`` or ``hf_repo/filename`` ref.

    Candidates are tried in this order: the query itself, the query with a
    trailing ``/<filename>.gguf`` removed, and the query with a leading
    non-HF provider prefix (``ollama/``, etc.) removed. Matching ignores
    case. Returns ``None`` when the query is empty or nothing matches.
    """
    if not query:
        return None
    index = _build_catalog_index()
    needle = query.lower()
    attempts = [needle]

    # Most featured entries use a glob gguf_filename and so are only in the
    # bare-repo index; dropping the filename lets a full ref reach them.
    if needle.count("/") >= _NATIVE_GGUF_REF_MIN_SLASHES and needle.endswith(".gguf"):
        repo_part, _, _ = needle.rpartition("/")
        attempts.append(repo_part)

    # A first segment that is not the owner of any featured repo is taken to
    # be a provider prefix and stripped.
    head, slash, tail = needle.partition("/")
    if slash:
        known_owners = {repo.partition("/")[0] for repo in index.by_hf_repo if "/" in repo}
        if head not in known_owners:
            attempts.append(tail)

    hits = (index.by_full_ref.get(a) or index.by_hf_repo.get(a) for a in attempts)
    return next((hit for hit in hits if hit is not None), None)

205 

206 

def is_rerank_ref(model_ref: str) -> bool:
    """Tell whether *model_ref* names a catalog entry whose task is rerank.

    An empty ref, or one with no catalog match, yields ``False``.
    """
    match = find_catalog_entry(model_ref) if model_ref else None
    if match is None:
        return False
    return match.task == ModelTask.RERANK

213 

214 

def _is_hf_repo_id(value: str) -> bool:
    """Check that *value* is a well-formed ``owner/name`` HuggingFace repo id.

    A value with no ``/`` is rejected up front; otherwise the verdict comes
    from ``validate_repo_id``.
    """
    has_owner = "/" in value
    if not has_owner:
        return False
    try:
        validate_repo_id(value)
    except HFValidationError:
        well_formed = False
    else:
        well_formed = True
    return well_formed

224 

225 

def build_adhoc_entry(hf_repo: str, *, task: ModelTask = ModelTask.CHAT) -> CatalogModel:
    """Create a bare-bones CatalogModel for a HuggingFace GGUF repo that is not featured.

    Only the repo and task come from the caller; every other field is set
    to a fixed placeholder.
    """
    entry_fields = {
        "hf_repo": hf_repo,
        "gguf_filename": "*.gguf",  # glob rather than one concrete file
        "size_gb": 0.0,
        "min_ram_gb": 2.0,
        "description": "",
        "featured": False,
        "downloads": 0,
        "task": task,
    }
    return CatalogModel(**entry_fields)

238 

239 

def resolve_pull_target(model: str) -> CatalogModel | None:
    """Turn *model* into something pullable, or ``None`` if it cannot be.

    A featured catalog match wins. Otherwise a well-formed HuggingFace repo
    id gets an ad-hoc entry.
    """
    entry = find_catalog_entry(model)
    if entry is None and _is_hf_repo_id(model):
        entry = build_adhoc_entry(model)
    return entry