Coverage for src / lilbee / providers / model_defaults.py: 100%

67 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-05-15 20:55 +0000

1"""Per-model default generation settings. 

2 

3Parses and caches generation parameters from key-value parameter text 

4or GGUF file metadata so that model-specific defaults (temperature, num_ctx, 

5etc.) are applied automatically when switching models. 

6""" 

7 

8from __future__ import annotations 

9 

10import contextlib 

11import logging 

12from dataclasses import dataclass, fields 

13from typing import Any 

14 

log = logging.getLogger(__name__)

# Every generation parameter we understand, mapped to the callable used to
# coerce its raw textual value. Keys are also ModelDefaults field names.
_KNOWN_PARAM_TYPES: dict[str, type] = dict(
    temperature=float,
    top_p=float,
    top_k=int,
    repeat_penalty=float,
    num_ctx=int,
    max_tokens=int,
    max_reasoning_chars=int,
)

# GGUF metadata keys (all in the ``general.`` namespace) and the
# ModelDefaults field each one populates.
_GGUF_KEY_MAP: dict[str, str] = {
    f"general.{field_name}": field_name
    for field_name in (
        "temperature",
        "top_p",
        "top_k",
        "repeat_penalty",
        "max_reasoning_chars",
    )
}

# A well-formed ``key value`` line yields exactly this many tokens from
# ``str.split(None, 1)``; anything shorter has no value to parse.
_KV_PARTS = 2

39 

40 

@dataclass(frozen=True)
class ModelDefaults:
    """Immutable bundle of per-model generation defaults.

    Every field is optional and defaults to ``None``, meaning no value was
    supplied for that setting. Instances are frozen: assigning to a field
    raises ``dataclasses.FrozenInstanceError``.
    """

    # Field order defines the positional constructor signature; keep it stable.
    temperature: float | None = None
    top_p: float | None = None
    top_k: int | None = None
    repeat_penalty: float | None = None
    num_ctx: int | None = None
    max_tokens: int | None = None
    max_reasoning_chars: int | None = None

52 

53 

class _DefaultsCache:
    """Small in-memory store of ModelDefaults keyed by model name.

    Wrapping the dict in an object keeps mutable state out of the module
    namespace (no module-level mutable global).
    """

    def __init__(self) -> None:
        # Keys are model names exactly as passed by callers.
        self._data: dict[str, ModelDefaults] = {}

    def get(self, model_name: str) -> ModelDefaults | None:
        """Return the stored defaults for *model_name*, or ``None`` if absent."""
        try:
            return self._data[model_name]
        except KeyError:
            return None

    def set(self, model_name: str, defaults: ModelDefaults) -> None:
        """Store *defaults* under *model_name*, replacing any earlier entry."""
        self._data.update({model_name: defaults})

    def clear(self) -> None:
        """Remove every stored entry (the underlying dict is emptied in place)."""
        self._data.clear()

68 

69 

# One process-wide cache instance. The public helpers are its bound methods,
# which preserves the existing call sites (get_defaults / set_defaults /
# clear_cache).
_defaults_cache = _DefaultsCache()

get_defaults, set_defaults, clear_cache = (
    _defaults_cache.get,
    _defaults_cache.set,
    _defaults_cache.clear,
)

76 

77 

def parse_kv_parameters(text: str) -> ModelDefaults:
    """Parse multiline ``key value`` parameter text into a ModelDefaults.

    Each line is stripped and split on its first run of whitespace into a key
    and a raw value. Example input::

        temperature 0.7
        top_p 0.9
        num_ctx 4096
        stop <|im_end|>

    Blank lines, lines without a value, unknown keys (like ``stop``), and
    values that fail type conversion are skipped; conversion failures are
    logged at debug level. If a key repeats, the last parseable occurrence wins.
    """
    parsed: dict[str, Any] = {}
    for stripped in (raw_line.strip() for raw_line in text.splitlines()):
        tokens = stripped.split(None, 1)
        # A blank line splits to [] and a bare key to [key]; neither has a value.
        if len(tokens) != _KV_PARTS:
            continue
        key, raw_value = tokens
        converter = _KNOWN_PARAM_TYPES.get(key)
        if converter is None:
            continue
        try:
            parsed[key] = converter(raw_value)
        except (ValueError, TypeError):
            log.debug("Skipping unparseable param %s=%r", key, raw_value)
    return ModelDefaults(**parsed)

105 

106 

def read_gguf_defaults(metadata: dict[str, str]) -> ModelDefaults:
    """Build a ModelDefaults from a GGUF metadata dict.

    Every key from ``_GGUF_KEY_MAP`` present in *metadata* is converted to
    the type of its target field. Values that fail conversion are logged at
    debug level and skipped.

    ``num_ctx`` is taken from a plain ``context_length`` key. The caller is
    expected to have already resolved the architecture-prefixed GGUF key
    into that name. The value is used only when it parses as a positive
    integer; otherwise it is ignored without logging.
    """
    collected: dict[str, Any] = {}
    for gguf_key, field_name in _GGUF_KEY_MAP.items():
        if gguf_key not in metadata:
            continue
        raw = metadata[gguf_key]
        try:
            convert = _field_type(field_name)
            collected[field_name] = convert(raw)
        except (ValueError, TypeError):
            log.debug("Skipping unparseable GGUF key %s=%r", gguf_key, raw)

    if "context_length" in metadata:
        # Malformed context lengths are dropped silently (no debug log).
        with contextlib.suppress(ValueError, TypeError):
            context_len = int(metadata["context_length"])
            if context_len > 0:
                collected["num_ctx"] = context_len
    return ModelDefaults(**collected)

127 

128 

def _field_type(field_name: str) -> type:
    """Return the scalar type (``int`` or ``float``) of a ModelDefaults field.

    With ``from __future__ import annotations`` active, ``Field.type`` is the
    annotation *string* (e.g. ``"int | None"``). That string is split into
    identifier tokens and checked for an exact ``int`` token, so an
    annotation that merely contains those letters (e.g. ``"Point | None"``)
    is not mistaken for ``int``. If the annotation is a real type object,
    ``int`` itself or a union/Optional containing ``int`` counts as ``int``.

    Anything else, including an unknown *field_name*, falls back to ``float``.
    """
    import re

    for f in fields(ModelDefaults):
        if f.name != field_name:
            continue
        annotation = f.type
        if isinstance(annotation, str):
            is_int = "int" in re.findall(r"\w+", annotation)
        else:
            # Union / Optional annotations expose their members via __args__.
            is_int = annotation is int or int in getattr(annotation, "__args__", ())
        return int if is_int else float
    return float  # pragma: no cover