Coverage for src / lilbee / providers / model_ref.py: 100%

53 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-05-15 20:55 +0000

1"""Model reference parsing and option translation. 

2 

3Single source of truth for classifying model strings and translating 

4generation options per provider type. This module must NOT import from 

5lilbee.config or lilbee.models to avoid circular imports. 

6""" 

7 

8from __future__ import annotations 

9 

10from dataclasses import dataclass 

11from typing import Any 

12 

13from lilbee.providers.base import filter_options 

14 

# Hosted API providers whose refs use the ``provider/model`` prefix shape.
_API_PROVIDERS = frozenset(
    (
        "anthropic",
        "deepseek",
        "gemini",
        "mistral",
        "openai",
        "openrouter",
    )
)

# All provider prefixes that route a ref away from the local registry.
# Includes API providers and ollama (which keeps its own name:tag shape).
PROVIDER_PREFIXES: frozenset[str] = _API_PROVIDERS | frozenset({"ollama"})

OLLAMA_PREFIX = "ollama/"

31 

32 

@dataclass(frozen=True)
class ProviderModelRef:
    """Parsed model reference with provider routing information."""

    raw: str
    provider: str  # "local", "ollama", or any value in PROVIDER_PREFIXES
    name: str  # provider-specific name with tag normalization applied

    @property
    def is_api(self) -> bool:
        """True when the ref targets a hosted API provider."""
        return self.provider in _API_PROVIDERS

    @property
    def is_local(self) -> bool:
        """True when the ref targets a locally-loaded model."""
        return self.provider == "local"

    @property
    def is_remote(self) -> bool:
        """True if this model must route through a remote SDK (API or Ollama).

        Remote means "not a locally-loaded GGUF". Both Ollama (HTTP
        localhost server) and hosted API providers share the same
        dispatch path; they go through whichever SDK backend is wired
        up.
        """
        return not self.is_local

    def for_openai_prefix(self) -> str:
        """Name formatted with canonical ``provider/model`` prefix.

        The prefix convention is the same one used by OpenAI-compatible
        SDKs: ``openai/gpt-4o``, ``ollama/llama3.2:1b``, etc. Every
        dispatching SDK accepts this shape.
        """
        if self.provider == "ollama":
            return OLLAMA_PREFIX + self.name
        if self.provider in _API_PROVIDERS:
            return f"{self.provider}/{self.name}"
        # Local refs carry no prefix.
        return self.name

    def for_display(self) -> str:
        """Human-readable name for UI."""
        return self.raw

    @property
    def needs_api_base(self) -> bool:
        """True if the SDK needs an explicit api_base (Ollama/local)."""
        return self.provider not in _API_PROVIDERS

81 

82 

def format_remote_ref(name: str, provider: str) -> str:
    """Render a remote model as a canonical ``provider/name`` ref."""
    ref = ProviderModelRef(raw=name, provider=provider.lower(), name=name)
    return ref.for_openai_prefix()

86 

87 

def parse_model_ref(raw: str) -> ProviderModelRef:
    """Classify a model string by its prefix and return the routing ref.

    Native HuggingFace refs are ``<org>/<repo>/<file>.gguf``. Remote
    providers use prefixes from :data:`PROVIDER_PREFIXES` (``ollama/``
    plus every API provider listed there).
    """
    prefix, sep, rest = raw.partition("/")
    if not sep:
        known = ", ".join(f"{p}/" for p in sorted(PROVIDER_PREFIXES))
        raise ValueError(
            f"Model ref {raw!r} must be a HuggingFace ref "
            f"('<org>/<repo>/<filename>.gguf') or carry a known provider prefix ({known})."
        )
    if prefix == "ollama":
        # Ollama names default to the :latest tag when none is given.
        tagged = rest if ":" in rest else rest + ":latest"
        return ProviderModelRef(raw=raw, provider="ollama", name=tagged)
    if prefix in _API_PROVIDERS:
        return ProviderModelRef(raw=raw, provider=prefix, name=rest)
    # Anything else with a slash is treated as a HuggingFace/local ref.
    return ProviderModelRef(raw=raw, provider="local", name=raw)

108 

109 

def translate_options(options: dict[str, Any], ref: ProviderModelRef) -> dict[str, Any]:
    """Translate generation options for the target provider."""
    translated = filter_options(options)
    if ref.is_api:
        # API providers use max_tokens, not num_predict.
        if "num_predict" in translated:
            translated["max_tokens"] = translated.pop("num_predict")
        # num_ctx is a model-load param (not per-call); top_k is
        # unsupported by most API providers.
        for unsupported in ("num_ctx", "top_k"):
            translated.pop(unsupported, None)
    return translated