Coverage for src / lilbee / providers / model_ref.py: 100%

56 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-06-28 01:01 +0000

"""Model reference parsing and option translation.

Single source of truth for classifying model strings and translating
generation options per provider type. This module must NOT import from
lilbee.config or lilbee.models to avoid circular imports.

Public entry points: ``parse_model_ref`` (string -> ``ProviderModelRef``),
``format_remote_ref`` (name + provider -> canonical ``provider/name``) and
``translate_options`` (generation options -> provider-specific options).
"""

7 

8from __future__ import annotations 

9 

10from dataclasses import dataclass 

11from typing import Any 

12 

13from lilbee.providers.base import filter_options 

14from lilbee.providers.local_servers import ( 

15 LOCAL_SERVER_KEYS, 

16 local_server_for_key, 

17 local_server_for_label, 

18) 

19 

# Hosted API providers; refs with one of these prefixes report ``is_api``.
_API_PROVIDERS = frozenset(
    (
        "openrouter",
        "gemini",
        "anthropic",
        "openai",
        "mistral",
        "deepseek",
    )
)

# Every provider prefix that routes a ref away from the local registry:
# the hosted API providers above plus the local OpenAI-compatible server
# keys (ollama, lm_studio).
PROVIDER_PREFIXES: frozenset[str] = frozenset(_API_PROVIDERS | LOCAL_SERVER_KEYS)

34 

35 

@dataclass(frozen=True)
class ProviderModelRef:
    """Immutable, routing-aware view of a parsed model reference.

    ``raw`` is the string the ref was built from, ``provider`` is either
    ``"local"`` or one of :data:`PROVIDER_PREFIXES`, and ``name`` is the
    provider-specific model name (after any tag normalization).
    """

    raw: str
    provider: str  # "local" or any value in PROVIDER_PREFIXES
    name: str  # provider-specific name with tag normalization applied

    @property
    def is_api(self) -> bool:
        """True when the provider is one of the hosted API providers."""
        return self.provider in _API_PROVIDERS

    @property
    def is_local(self) -> bool:
        """True when the ref resolves through the local registry."""
        return self.provider == "local"

    @property
    def is_remote(self) -> bool:
        """True if this model routes through a remote SDK (any non-``local`` provider)."""
        return not self.is_local

    def for_openai_prefix(self) -> str:
        """Name with its canonical ``provider/model`` prefix (``ollama/llama3.2:1b``).

        Local-server providers delegate prefixing to their server spec, API
        providers get ``provider/name``, and local-registry refs return the
        bare name.
        """
        server = local_server_for_key(self.provider)
        if server is None:
            return f"{self.provider}/{self.name}" if self.is_api else self.name
        return server.qualify(self.name)

    def for_display(self) -> str:
        """Human-readable name for UI (the original ref string)."""
        return self.raw

    @property
    def needs_api_base(self) -> bool:
        """True if the SDK needs an explicit api_base (Ollama/local), i.e. any non-API provider."""
        return self.provider not in _API_PROVIDERS

74 

75 

def format_remote_ref(name: str, provider: str) -> str:
    """Render a remote model as a canonical ``provider/name`` ref.

    *provider* may be a routing key (``"ollama"``) or a backend display
    name (``"LM Studio"``). A label recognised as a local server is
    mapped to that server's routing key so the prefix survives; any other
    value is lowercased and used as the key (API providers end up with
    their lowercase key unchanged).
    """
    server = local_server_for_label(provider)
    routing_key = provider.lower() if server is None else server.key
    ref = ProviderModelRef(raw=name, provider=routing_key, name=name)
    return ref.for_openai_prefix()

87 

88 

def parse_model_ref(raw: str) -> ProviderModelRef:
    """Classify a model string by its prefix and return the routing ref.

    Native HuggingFace refs are ``<org>/<repo>/<file>.gguf``. Remote
    providers use prefixes from :data:`PROVIDER_PREFIXES` (the local
    servers ``ollama/`` and ``lm_studio/`` plus every API provider).
    Any string containing ``/`` whose first segment is not a known
    provider prefix is classified as ``"local"``, with the whole string
    kept as its name.

    Raises:
        ValueError: if *raw* contains no ``/`` at all, or if it carries a
            known provider prefix with no model name after it
            (e.g. ``"openai/"``).
    """
    if "/" not in raw:
        known = ", ".join(f"{p}/" for p in sorted(PROVIDER_PREFIXES))
        raise ValueError(
            f"Model ref {raw!r} must be a HuggingFace ref "
            f"('<org>/<repo>/<filename>.gguf') or carry a known provider prefix ({known})."
        )
    prefix, rest = raw.split("/", 1)
    is_api = prefix in _API_PROVIDERS
    # Only consult the local-server table when the prefix is not an API
    # provider, matching the original lookup order.
    spec = None if is_api else local_server_for_key(prefix)
    # A provider prefix with nothing after it would otherwise produce a ref
    # with an empty model name that only fails later, at request time.
    if (is_api or spec is not None) and not rest.strip():
        raise ValueError(f"Model ref {raw!r} has provider prefix '{prefix}/' but no model name.")
    if is_api:
        return ProviderModelRef(raw=raw, provider=prefix, name=rest)
    if spec is not None:
        return ProviderModelRef(raw=raw, provider=spec.key, name=spec.normalize_name(rest))
    return ProviderModelRef(raw=raw, provider="local", name=raw)

109 

110 

def translate_options(options: dict[str, Any], ref: ProviderModelRef) -> dict[str, Any]:
    """Translate generation options for the target provider.

    The options are first passed through ``filter_options``. Non-API refs
    (local registry or local servers) get that result back unchanged. For
    API providers, ``num_predict`` is renamed to ``max_tokens``, and
    ``num_ctx`` and ``top_k`` are dropped.
    """
    translated = filter_options(options)
    if not ref.is_api:
        return translated
    # API providers use max_tokens, not num_predict.
    if "num_predict" in translated:
        translated["max_tokens"] = translated.pop("num_predict")
    # num_ctx is a model-load param, not per-call. top_k is kept for local
    # llama.cpp but stripped here: litellm forwards it (into extra_body for
    # OpenAI-compatible) without erroring, yet hosted APIs ignore it, so
    # dropping it keeps the wire request clean.
    for dropped in ("num_ctx", "top_k"):
        translated.pop(dropped, None)
    return translated