Coverage for src / lilbee / modelhub / model_manager / discovery.py: 100%

65 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-05-15 20:55 +0000

1"""Remote model discovery and task classification.""" 

2 

3import logging 

4import os 

5 

6import httpx 

7 

8from lilbee.app.services import get_services 

9from lilbee.catalog.types import ModelTask 

10from lilbee.core.config.model import cfg 

11from lilbee.modelhub.model_manager.types import RemoteModel 

12from lilbee.providers.sdk_backend import ( 

13 PROVIDER_KEYS, 

14 detect_backend_name, 

15) 

16 

log = logging.getLogger(__name__)

# Substrings looked for in a model's family tag to flag it as an embedder.
_EMBEDDING_FAMILIES = frozenset(("bert", "nomic-bert", "e5", "bge"))
# Substrings looked for in a model's name to flag it as a vision model.
_VISION_NAME_PATTERNS = frozenset(("llava", "vision", "moondream", "ocr", "minicpm-v"))
# Substrings looked for in a model's name to flag it as a reranker. The
# reranker check is applied ahead of the embedding check so that
# ``bge-reranker-*`` (family "bge", yet plainly a reranker) is not filed
# under EMBEDDING. ``cross-encoder`` matches the SBERT-style naming used
# for cross-encoder rerankers.
_RERANKER_NAME_PATTERNS = frozenset(("reranker", "rerank", "cross-encoder"))

# Default HTTP timeout, in seconds, used by ``classify_remote_models``.
_CLASSIFY_DEFAULT_TIMEOUT_S = 5.0

28 

29 

def _classify_remote_task(name: str, family: str) -> ModelTask:
    """Map a remote model's name and family tag to a ``ModelTask``.

    Checks are applied in a fixed order and the first hit wins:

    1. RERANK when the lower-cased name contains a reranker pattern.
       This goes first so ``bge-reranker-base`` (family ``bge``) is not
       captured by the embedding family check.
    2. EMBEDDING when the lower-cased family contains an embedding family.
    3. VISION when the lower-cased name contains a vision pattern.
    4. CHAT otherwise.
    """
    lowered_name = name.lower()

    def name_mentions(patterns: frozenset[str]) -> bool:
        # Plain substring match against the lower-cased model name.
        return any(token in lowered_name for token in patterns)

    if name_mentions(_RERANKER_NAME_PATTERNS):
        return ModelTask.RERANK

    # The family is only inspected once the reranker check has failed.
    lowered_family = family.lower()
    for embedding_family in _EMBEDDING_FAMILIES:
        if embedding_family in lowered_family:
            return ModelTask.EMBEDDING

    return ModelTask.VISION if name_mentions(_VISION_NAME_PATTERNS) else ModelTask.CHAT

47 

48 

def reclassify_by_name(ref: str, declared_task: str) -> str:
    """Return RERANK or VISION when ``ref`` names such a model, else ``declared_task``.

    Guards against older manifests that recorded ``task="chat"`` for
    models whose ref plainly marks them as rerankers (for example
    ``bge-reranker-*``) or vision models. The model bar relies on this so
    that a historical mis-tag does not put a reranker in the chat picker.
    """
    lowered_ref = ref.lower()
    # Order matters: the reranker patterns are consulted before the vision ones.
    overrides = (
        (_RERANKER_NAME_PATTERNS, ModelTask.RERANK),
        (_VISION_NAME_PATTERNS, ModelTask.VISION),
    )
    for patterns, forced_task in overrides:
        if any(token in lowered_ref for token in patterns):
            return forced_task
    return declared_task

63 

64 

def classify_remote_models(
    base_url: str = "http://localhost:11434",
    *,
    timeout: float = _CLASSIFY_DEFAULT_TIMEOUT_S,
) -> list[RemoteModel]:
    """Discover and classify all models from the SDK backend by task.

    Fetches ``{base_url}/api/tags`` and classifies every listed model with
    ``_classify_remote_task``: family metadata drives embedding detection,
    name patterns drive reranker and vision detection.

    Args:
        base_url: Root URL of the backend; ``/api/tags`` is appended to it.
        timeout: Timeout passed to ``httpx.get``.

    Returns:
        One ``RemoteModel`` per well-formed entry in the response. An empty
        list is returned on any request/decoding error (including timeout)
        so callers in read-only code paths can stay responsive when the
        backend is down. Entries that are not JSON objects are skipped, and
        ``null`` values for ``models``, ``details``, ``name``, ``family`` or
        ``parameter_size`` are treated the same as a missing key.
    """
    try:
        resp = httpx.get(f"{base_url}/api/tags", timeout=timeout)
        resp.raise_for_status()
        # ``.get(key, default)`` only applies the default when the key is
        # absent; ``or []`` additionally covers an explicit ``"models": null``.
        raw_models = resp.json().get("models") or []
    except Exception:
        log.debug("Failed to classify remote models", exc_info=True)
        return []

    if not isinstance(raw_models, list):
        # Unexpected payload shape: honour the "empty list on error" contract
        # instead of raising while iterating.
        log.debug("Unexpected 'models' payload type: %s", type(raw_models).__name__)
        return []

    provider = detect_backend_name(base_url)
    result: list[RemoteModel] = []
    for model in raw_models:
        if not isinstance(model, dict):
            # A malformed entry must not abort discovery of the others.
            continue
        details = model.get("details")
        if not isinstance(details, dict):
            details = {}
        name = model.get("name") or ""
        family = details.get("family") or ""
        param_size = details.get("parameter_size") or ""
        result.append(
            RemoteModel(
                name=name,
                task=_classify_remote_task(name, family),
                family=family,
                parameter_size=param_size,
                provider=provider,
            )
        )
    return result

99 

100 

def _has_provider_key(cfg_field: str, env_var: str) -> bool:
    """Report whether an API key is available from the environment or lilbee config.

    The environment variable ``env_var`` is consulted first; only when it is
    unset or empty is the ``cfg_field`` attribute of ``cfg`` examined (a
    missing attribute counts as no key).
    """
    return bool(os.environ.get(env_var)) or bool(getattr(cfg, cfg_field, ""))

106 

107 

def discover_api_models() -> dict[str, list[RemoteModel]]:
    """Return chat models grouped by provider display name.

    For every ``PROVIDER_KEYS`` entry that has a usable API key, asks the
    active services provider for its chat models and wraps each name in a
    ``RemoteModel`` tagged as CHAT, with no curation. Providers that yield
    no models are left out of the mapping. When no key is configured at all
    the function returns ``{}`` without touching the services layer.
    """
    # Keep only the providers with a key, remembering what is needed later.
    configured: list[tuple] = []
    for backend_id, cfg_field, env_var, display_name in PROVIDER_KEYS:
        if _has_provider_key(cfg_field, env_var):
            configured.append((backend_id, display_name))

    if not configured:
        return {}

    sdk_provider = get_services().provider

    grouped: dict[str, list[RemoteModel]] = {}
    for backend_id, display_name in configured:
        models: list[RemoteModel] = []
        for model_name in sdk_provider.list_chat_models(backend_id):
            models.append(
                RemoteModel(
                    name=model_name,
                    task=ModelTask.CHAT,
                    family="",
                    parameter_size="",
                    provider=display_name,
                )
            )
        if models:
            grouped[display_name] = models
    return grouped

140 

141 

def detect_remote_embedding_models(base_url: str = "http://localhost:11434") -> list[str]:
    """List the names of backend models that classify as EMBEDDING.

    Delegates discovery to ``classify_remote_models(base_url)`` and keeps
    only the entries whose task equals ``ModelTask.EMBEDDING``.
    """
    embedding_names: list[str] = []
    for remote in classify_remote_models(base_url):
        if remote.task == ModelTask.EMBEDDING:
            embedding_names.append(remote.name)
    return embedding_names