Coverage for src / lilbee / providers / model_ref.py: 100%
53 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-05-15 20:55 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-05-15 20:55 +0000
1"""Model reference parsing and option translation.
3Single source of truth for classifying model strings and translating
4generation options per provider type. This module must NOT import from
5lilbee.config or lilbee.models to avoid circular imports.
6"""
8from __future__ import annotations
10from dataclasses import dataclass
11from typing import Any
13from lilbee.providers.base import filter_options
# Hosted API providers recognised as a leading ``<provider>/`` prefix.
_API_PROVIDERS = frozenset(
    ("openrouter", "gemini", "anthropic", "openai", "mistral", "deepseek")
)

# Every provider prefix that routes a ref away from the local registry:
# the API providers above plus ollama (which keeps its own name:tag shape).
PROVIDER_PREFIXES: frozenset[str] = _API_PROVIDERS | frozenset({"ollama"})

# Prefix put in front of Ollama model names when rendering a canonical ref.
OLLAMA_PREFIX = "ollama/"
@dataclass(frozen=True)
class ProviderModelRef:
    """Immutable result of classifying a model string for provider routing."""

    raw: str  # the model string this ref was built from
    provider: str  # "local", "ollama", or any value in PROVIDER_PREFIXES
    name: str  # provider-specific name with tag normalization applied

    @property
    def is_api(self) -> bool:
        """True when the provider is one of the hosted API providers."""
        return self.provider in _API_PROVIDERS

    @property
    def is_local(self) -> bool:
        """True when the provider is ``"local"``."""
        return self.provider == "local"

    @property
    def is_remote(self) -> bool:
        """True if this model must route through a remote SDK (API or Ollama).

        "Remote" simply means "not a locally-loaded GGUF": an Ollama
        server on localhost and a hosted API provider take the same
        dispatch path through whichever SDK backend is wired up.
        """
        return not self.is_local

    def for_openai_prefix(self) -> str:
        """Return the name in canonical ``provider/model`` form.

        This is the prefix convention used by OpenAI-compatible SDKs
        (``openai/gpt-4o``, ``ollama/llama3.2:1b``, ...). Refs that are
        neither Ollama nor API-backed are returned as the bare name.
        """
        if self.provider == "ollama":
            prefix = OLLAMA_PREFIX
        elif self.is_api:
            prefix = f"{self.provider}/"
        else:
            # Local refs carry no provider prefix.
            return self.name
        return f"{prefix}{self.name}"

    def for_display(self) -> str:
        """Return the human-readable name for UI: the raw string."""
        return self.raw

    @property
    def needs_api_base(self) -> bool:
        """True if the SDK needs an explicit api_base (Ollama/local)."""
        return self.provider not in _API_PROVIDERS
def format_remote_ref(name: str, provider: str) -> str:
    """Build the canonical ``provider/name`` string for a remote model.

    The provider is lower-cased before it is used; ``name`` is passed
    through unchanged as both the raw and the provider-specific name.
    """
    ref = ProviderModelRef(raw=name, provider=provider.lower(), name=name)
    return ref.for_openai_prefix()
def parse_model_ref(raw: str) -> ProviderModelRef:
    """Turn a model string into a routing ref based on its leading prefix.

    Native HuggingFace refs look like ``<org>/<repo>/<file>.gguf``.
    Remote providers are selected by a prefix from
    :data:`PROVIDER_PREFIXES` (``ollama/`` plus every API provider).

    Raises:
        ValueError: if ``raw`` contains no ``/`` at all.
    """
    prefix, slash, rest = raw.partition("/")
    if not slash:
        known = ", ".join(f"{p}/" for p in sorted(PROVIDER_PREFIXES))
        message = (
            f"Model ref {raw!r} must be a HuggingFace ref "
            f"('<org>/<repo>/<filename>.gguf') or carry a known provider prefix ({known})."
        )
        raise ValueError(message)

    if prefix in _API_PROVIDERS:
        return ProviderModelRef(raw=raw, provider=prefix, name=rest)

    if prefix == "ollama":
        # Ollama names without an explicit tag get ":latest" appended.
        if ":" in rest:
            name = rest
        else:
            name = f"{rest}:latest"
        return ProviderModelRef(raw=raw, provider="ollama", name=name)

    # Any other prefix is treated as a local ref; the full string is the name.
    return ProviderModelRef(raw=raw, provider="local", name=raw)
def translate_options(options: dict[str, Any], ref: ProviderModelRef) -> dict[str, Any]:
    """Adapt generation options to what the ref's provider accepts.

    Options are first passed through ``filter_options``. For non-API refs
    that result is returned as-is; for API refs the option names are
    adjusted as described inline.
    """
    translated = filter_options(options)
    if not ref.is_api:
        return translated

    # API providers use max_tokens, not num_predict
    if "num_predict" in translated:
        translated["max_tokens"] = translated.pop("num_predict")

    # num_ctx is a model-load param, not per-call; top_k is not supported
    # by most API providers.
    for dropped in ("num_ctx", "top_k"):
        translated.pop(dropped, None)
    return translated