Coverage for src / lilbee / providers / model_defaults.py: 100%
67 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-05-15 20:55 +0000
1"""Per-model default generation settings.
3Parses and caches generation parameters from key-value parameter text
4or GGUF file metadata so that model-specific defaults (temperature, num_ctx,
5etc.) are applied automatically when switching models.
6"""
8from __future__ import annotations
10import contextlib
11import logging
12from dataclasses import dataclass, fields
13from typing import Any
log = logging.getLogger(__name__)

# Parameter keys we recognise, mapped to the callable used to coerce the raw
# string value of each one.
_KNOWN_PARAM_TYPES: dict[str, type] = {
    "temperature": float,
    "top_p": float,
    "top_k": int,
    "repeat_penalty": float,
    "num_ctx": int,
    "max_tokens": int,
    "max_reasoning_chars": int,
}

# GGUF metadata keys mapped to ModelDefaults field names.
_GGUF_KEY_MAP: dict[str, str] = {
    "general.temperature": "temperature",
    "general.top_p": "top_p",
    "general.top_k": "top_k",
    "general.repeat_penalty": "repeat_penalty",
    "general.max_reasoning_chars": "max_reasoning_chars",
}

# ``str.split(None, 1)`` returns at most this many tokens; a valid
# ``key value`` line yields exactly this many.
_KV_PARTS = 2
@dataclass(frozen=True)
class ModelDefaults:
    """Frozen snapshot of a model's default generation parameters.

    Every field is optional and defaults to ``None``; a ``None`` field means
    no value was supplied for that parameter when the snapshot was built.
    Instances are immutable (``frozen=True``), so a snapshot cannot be
    modified after construction.
    """

    # Floating-point parameters.
    temperature: float | None = None
    top_p: float | None = None
    # Integer parameter.
    top_k: int | None = None
    # Floating-point parameter.
    repeat_penalty: float | None = None
    # Integer parameters.
    num_ctx: int | None = None
    max_tokens: int | None = None
    max_reasoning_chars: int | None = None
class _DefaultsCache:
    """Encapsulates the per-model defaults cache (no module-level mutable global).

    Entries are kept in a plain dict keyed by model name.
    """

    def __init__(self) -> None:
        """Create an empty cache."""
        self._data: dict[str, ModelDefaults] = {}

    def get(self, model_name: str) -> ModelDefaults | None:
        """Return the defaults stored for *model_name*, or ``None`` if absent."""
        return self._data.get(model_name)

    def set(self, model_name: str, defaults: ModelDefaults) -> None:
        """Store *defaults* for *model_name*, replacing any existing entry."""
        self._data[model_name] = defaults

    def clear(self) -> None:
        """Remove every cached entry."""
        self._data.clear()
# The single cache instance backing the module-level functions below.
_defaults_cache = _DefaultsCache()

# Public API: preserves existing call sites. Each name is the bound method of
# the instance above, so all three operate on the same underlying cache.
get_defaults = _defaults_cache.get
set_defaults = _defaults_cache.set
clear_cache = _defaults_cache.clear
def parse_kv_parameters(text: str) -> ModelDefaults:
    """Parse multiline ``key value`` parameter format.

    Example input::

        temperature 0.7
        top_p 0.9
        num_ctx 4096
        stop <|im_end|>

    Each line is stripped and split on its first run of whitespace into a
    key and a raw value. Blank lines, lines with no value, and unknown keys
    (like ``stop``) are silently skipped. A value that cannot be converted
    to the key's target type is logged at debug level and skipped. When a
    key appears more than once, the last convertible value wins.

    Args:
        text: The raw parameter text, one ``key value`` pair per line.

    Returns:
        A :class:`ModelDefaults` holding the recognised parameters; fields
        that were not present (or not convertible) remain ``None``.
    """
    values: dict[str, Any] = {}
    for line in text.splitlines():
        # A blank line splits to [] and a bare key to a single token; both
        # fail the part-count check, so no separate emptiness test is needed.
        parts = line.strip().split(None, 1)
        if len(parts) != _KV_PARTS:
            continue
        key, raw_value = parts
        caster = _KNOWN_PARAM_TYPES.get(key)
        if caster is None:
            continue
        try:
            values[key] = caster(raw_value)
        except (ValueError, TypeError):
            log.debug("Skipping unparseable param %s=%r", key, raw_value)
    return ModelDefaults(**values)
def read_gguf_defaults(metadata: dict[str, str]) -> ModelDefaults:
    """Extract generation defaults from a GGUF metadata dict.

    Two sources are consulted:

    * every key listed in ``_GGUF_KEY_MAP`` (such as ``general.temperature``),
      converted to the type of the matching ``ModelDefaults`` field;
    * the plain ``context_length`` key, which populates ``num_ctx`` when it
      parses as a positive integer. The caller is expected to have already
      resolved the architecture-prefixed GGUF key into ``context_length``.

    Values that cannot be converted, and non-positive context lengths, are
    logged at debug level and skipped; they never raise.

    Args:
        metadata: GGUF metadata as a flat mapping of key to raw value.

    Returns:
        A :class:`ModelDefaults` with the fields that could be extracted;
        all other fields remain ``None``.
    """
    values: dict[str, Any] = {}
    for gguf_key, field_name in _GGUF_KEY_MAP.items():
        if gguf_key not in metadata:
            continue
        raw_value = metadata[gguf_key]
        try:
            values[field_name] = _field_type(field_name)(raw_value)
        except (ValueError, TypeError):
            log.debug("Skipping unparseable GGUF key %s=%r", gguf_key, raw_value)

    if "context_length" in metadata:
        raw_ctx = metadata["context_length"]
        try:
            ctx = int(raw_ctx)
        except (ValueError, TypeError):
            # Best-effort: a bad context length must not prevent the other
            # defaults from being returned.
            log.debug("Skipping unparseable GGUF key %s=%r", "context_length", raw_ctx)
        else:
            if ctx > 0:
                values["num_ctx"] = ctx
            else:
                log.debug("Skipping non-positive GGUF context_length %r", raw_ctx)
    return ModelDefaults(**values)
def _field_type(field_name: str) -> type:
    """Return the base type for a ModelDefaults field.

    The type is inferred from the textual form of the field's annotation:
    if it contains the substring ``"int"`` the result is ``int``, otherwise
    ``float``. This is sufficient for the ``int | None`` / ``float | None``
    annotations used on ``ModelDefaults``.

    Args:
        field_name: Name of a ``ModelDefaults`` field.

    Returns:
        ``int`` or ``float``. An unknown *field_name* falls back to ``float``.
    """
    for f in fields(ModelDefaults):
        if f.name == field_name:
            # str() handles the annotation whether it is stored as a string
            # or as a type object.
            return int if "int" in str(f.type) else float
    return float  # pragma: no cover