Coverage for src/lilbee/providers/model_ref.py: 100%
56 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-06-28 01:01 +0000
1"""Model reference parsing and option translation.
3Single source of truth for classifying model strings and translating
4generation options per provider type. This module must NOT import from
5lilbee.config or lilbee.models to avoid circular imports.
6"""
8from __future__ import annotations
10from dataclasses import dataclass
11from typing import Any
13from lilbee.providers.base import filter_options
14from lilbee.providers.local_servers import (
15 LOCAL_SERVER_KEYS,
16 local_server_for_key,
17 local_server_for_label,
18)
# Prefixes of the hosted API providers; a ref starting with one of these keeps
# the remainder of the string as its model name.
_API_PROVIDERS = frozenset(
    ("openrouter", "gemini", "anthropic", "openai", "mistral", "deepseek")
)

# Every prefix that steers a ref away from the local registry: the API
# providers above together with the local OpenAI-compatible servers
# (ollama, lm_studio) exported by ``local_servers``.
PROVIDER_PREFIXES: frozenset[str] = frozenset(_API_PROVIDERS | LOCAL_SERVER_KEYS)
@dataclass(frozen=True)
class ProviderModelRef:
    """A classified model string plus the provider it routes to.

    Built by ``parse_model_ref`` and ``format_remote_ref``. ``provider`` is
    either ``"local"`` (the native registry) or one of ``PROVIDER_PREFIXES``.
    """

    raw: str  # the model string this ref was built from
    provider: str  # "local" or any value in PROVIDER_PREFIXES
    name: str  # provider-specific name with tag normalization applied

    @property
    def is_api(self) -> bool:
        """True when ``provider`` is one of the API providers."""
        return self.provider in _API_PROVIDERS

    @property
    def is_local(self) -> bool:
        """True when the ref is served by the local (native) registry."""
        return self.provider == "local"

    @property
    def is_remote(self) -> bool:
        """True for any provider other than ``local`` (API or local server)."""
        return not self.is_local

    def for_openai_prefix(self) -> str:
        """Return the name qualified with its ``provider/`` prefix.

        Local servers qualify through their spec (``ollama/llama3.2:1b``),
        API providers become ``"<provider>/<name>"``, and anything else
        (the ``local`` registry) is returned unprefixed.
        """
        server = local_server_for_key(self.provider)
        if server is None:
            return f"{self.provider}/{self.name}" if self.is_api else self.name
        return server.qualify(self.name)

    def for_display(self) -> str:
        """Return the label shown in the UI, i.e. the raw string."""
        return self.raw

    @property
    def needs_api_base(self) -> bool:
        """True unless this is an API provider (Ollama/local need an explicit api_base)."""
        return self.provider not in _API_PROVIDERS
def format_remote_ref(name: str, provider: str) -> str:
    """Build the canonical ``provider/name`` ref for a remote model.

    *provider* may be a routing key (``"ollama"``) or a backend display name
    (``"LM Studio"``). A local-server display name is mapped back to that
    server's routing key so the prefix survives; any other value is simply
    lowercased (``"OpenAI"`` -> ``"openai"``). If the resulting key is neither
    a local server nor an API provider, *name* comes back without a prefix.
    """
    server = local_server_for_label(provider)
    routing_key = provider.lower() if server is None else server.key
    ref = ProviderModelRef(raw=name, provider=routing_key, name=name)
    return ref.for_openai_prefix()
def parse_model_ref(raw: str) -> ProviderModelRef:
    """Classify *raw* by the segment before its first ``/``.

    - An API-provider prefix keeps everything after the first ``/`` as the name.
    - A local-server prefix (``ollama/``, ``lm_studio/``) routes to that server,
      with the remainder passed through the server's name normalisation.
    - Anything else is treated as a native HuggingFace ref
      (``<org>/<repo>/<file>.gguf``) and routed to ``local`` with the whole
      string as its name; the HuggingFace shape is not validated here.

    Raises:
        ValueError: if *raw* contains no ``/`` at all.
    """
    prefix, separator, rest = raw.partition("/")
    if not separator:
        # Build the hint list lazily; it is only needed for the error message.
        known = ", ".join(f"{p}/" for p in sorted(PROVIDER_PREFIXES))
        raise ValueError(
            f"Model ref {raw!r} must be a HuggingFace ref "
            f"('<org>/<repo>/<filename>.gguf') or carry a known provider prefix ({known})."
        )

    if prefix in _API_PROVIDERS:
        return ProviderModelRef(raw=raw, provider=prefix, name=rest)

    server = local_server_for_key(prefix)
    if server is None:
        # No recognised prefix: the full string is a native registry ref.
        return ProviderModelRef(raw=raw, provider="local", name=raw)
    return ProviderModelRef(raw=raw, provider=server.key, name=server.normalize_name(rest))
def translate_options(options: dict[str, Any], ref: ProviderModelRef) -> dict[str, Any]:
    """Adapt generation *options* to what *ref*'s provider accepts.

    Options are first passed through ``filter_options``. Non-API refs get that
    result unchanged. For API providers, ``num_predict`` is renamed to
    ``max_tokens``, and the local-only ``num_ctx`` and ``top_k`` keys are dropped.
    """
    # NOTE(review): the dict returned by filter_options is modified in place
    # below; this assumes it is a fresh copy rather than *options* itself.
    translated = filter_options(options)
    if not ref.is_api:
        return translated

    # API providers spell the generation-length cap max_tokens, not num_predict.
    if "num_predict" in translated:
        translated["max_tokens"] = translated.pop("num_predict")

    # num_ctx is a model-load parameter, not a per-call one. top_k stays for
    # local llama.cpp, but hosted APIs ignore it (litellm forwards it into
    # extra_body for OpenAI-compatible endpoints), so drop it to keep the
    # wire request clean.
    for local_only_key in ("num_ctx", "top_k"):
        translated.pop(local_only_key, None)
    return translated