Coverage for src / lilbee / providers / model_cache.py: 100%

82 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-06-28 01:01 +0000

1"""Llama-cpp loader-mode constants and dynamic-context / GPU-memory helpers.""" 

2 

3from __future__ import annotations 

4 

5import logging 

6import platform 

7from enum import StrEnum 

8from pathlib import Path 

9 

10log = logging.getLogger(__name__) 

11 

12 

13class LoaderMode(StrEnum): 

14 """Which task to configure llama.cpp for at load time.""" 

15 

16 CHAT = "chat" 

17 EMBED = "embed" 

18 RERANK = "rerank" 

19 

20 

# Per-token KV cache size (bytes) assumed when GGUF metadata is unreadable.
# This is deliberately coarse: modern models need far more (Gemma3-4B is
# ~640 KB/token at f16), but it is good enough as a pre-load eviction signal.
_KV_BYTES_PER_CTX_TOKEN = 2048

# Extra Metal/CUDA buffer space, expressed as a fraction of the weight size.
_BUFFER_OVERHEAD_FRACTION = 0.10

# Context length assumed for estimates when the metadata does not provide one.
_DEFAULT_CTX_LEN = 2048

# Lowest n_ctx the dynamic computation will return; below this chat is unusable.
_DYNAMIC_CTX_FLOOR = 512

# Dynamic n_ctx is rounded down to a multiple of this to keep batch sizes clean.
_DYNAMIC_CTX_QUANTUM = 256

# Bytes per KV cache element at f16 precision; a quantized KV cache uses less.
_KV_ELEM_BYTES_F16 = 2

40 

41 

def kv_bytes_per_token(
    meta: dict[str, str] | None,
    kv_elem_bytes: int = _KV_ELEM_BYTES_F16,
) -> int:
    """Return the estimated KV cache footprint of one context token, in bytes.

    The estimate is ``n_layers * n_kv_heads * (key_dim + value_dim) *
    kv_elem_bytes``. Explicit ``key_length`` / ``value_length`` entries are
    used when both are present; otherwise each is taken to be
    ``embedding_length // head_count``.

    Args:
        meta: GGUF metadata mapping, or ``None`` / empty when unavailable.
        kv_elem_bytes: Size of a single KV cache element in bytes.

    Returns:
        The computed size, or ``_KV_BYTES_PER_CTX_TOKEN`` when *meta* is
        missing, lacks a required key, holds a non-integer value, or reports
        a zero head count.
    """
    if not meta:
        return _KV_BYTES_PER_CTX_TOKEN

    try:
        layers = int(meta["block_count"])
        # Models without grouped-query attention omit head_count_kv; every
        # attention head then carries its own K/V.
        kv_heads = int(meta.get("head_count_kv") or meta["head_count"])
        if "key_length" not in meta or "value_length" not in meta:
            # Derive the per-head width from the hidden size; K and V share it.
            hidden = int(meta["embedding_length"])
            attn_heads = int(meta.get("head_count") or kv_heads)
            per_head_elems = 2 * (hidden // attn_heads)
        else:
            per_head_elems = int(meta["key_length"]) + int(meta["value_length"])
    except (KeyError, ValueError, ZeroDivisionError):
        return _KV_BYTES_PER_CTX_TOKEN

    return layers * kv_heads * per_head_elems * kv_elem_bytes

63 

64 

def estimate_model_memory(
    model_path: Path,
    n_ctx: int = _DEFAULT_CTX_LEN,
    kv_bytes_per_tok: int = _KV_BYTES_PER_CTX_TOKEN,
) -> int:
    """Estimate memory consumption for a GGUF model.

    Approximation: file size (weights) + KV cache + a buffer overhead of
    ``_BUFFER_OVERHEAD_FRACTION`` of the weights.

    Args:
        model_path: Location of the GGUF file on disk.
        n_ctx: Context length the KV cache is sized for.
        kv_bytes_per_tok: KV cache bytes needed per context token.

    Returns:
        Estimated total bytes. When the file cannot be stat'ed (missing,
        removed concurrently, unreadable path) the weights are counted as 0
        bytes, so only the KV cache term remains.
    """
    # One stat() call instead of exists() + stat(): avoids the second syscall
    # and the window where the file disappears between the check and the read.
    try:
        file_bytes = model_path.stat().st_size
    except OSError:
        file_bytes = 0
    kv_bytes = n_ctx * kv_bytes_per_tok
    overhead = int(file_bytes * _BUFFER_OVERHEAD_FRACTION)
    return file_bytes + kv_bytes + overhead

77 

78 

def compute_dynamic_ctx(
    *,
    model_bytes: int,
    available_bytes: int,
    training_ctx: int,
    kv_bytes_per_tok: int,
    ceiling: int,
    target: int | None = None,
    floor: int = _DYNAMIC_CTX_FLOOR,
    quantum: int = _DYNAMIC_CTX_QUANTUM,
) -> int:
    """Choose an n_ctx from the target, the caps, and the memory budget.

    How the value is picked:

    * ``min(training_ctx, ceiling)`` caps the candidate: the model's training
      window and the caller's own limit.
    * The memory left after the weights and their buffer overhead, divided by
      ``kv_bytes_per_tok``, is the largest context the budget can back.
    * When ``target`` is given it is a third cap, so a long-context model is
      loaded at the size chat actually needs rather than the largest that fits.
    * The smallest of those is raised to ``floor``, rounded down to a multiple
      of ``quantum``, and raised to ``floor`` again.

    Special cases:

    * ``kv_bytes_per_tok <= 0``: the memory budget is ignored and no rounding
      is applied; returns ``target`` limited by the cap and raised to
      ``floor``, or the cap itself when no target was given.
    * No memory left after weights and overhead: returns ``floor``.
    """
    hard_cap = min(training_ctx, ceiling)

    if kv_bytes_per_tok <= 0:
        # Without a per-token cost there is nothing to budget against.
        return hard_cap if target is None else max(floor, min(target, hard_cap))

    headroom = available_bytes - model_bytes - int(model_bytes * _BUFFER_OVERHEAD_FRACTION)
    if headroom <= 0:
        return floor

    affordable_ctx = headroom // kv_bytes_per_tok
    # Take the tightest of the applicable limits; the target only participates
    # when the caller supplied one.
    if target is None:
        limits = (affordable_ctx, hard_cap)
    else:
        limits = (target, affordable_ctx, hard_cap)
    wanted = max(floor, min(limits))

    whole_quanta, _ = divmod(wanted, quantum)
    # Rounding down can dip below the floor when floor is not a multiple of
    # quantum, so the floor is applied once more.
    return max(floor, whole_quanta * quantum)

119 

120 

def get_available_memory(fraction: float) -> int:
    """Return the total memory of the preferred device, scaled by *fraction*.

    - Linux / Windows: total memory reported by ``_try_nvidia_memory()`` when
      it yields a value; otherwise total system RAM via psutil.
    - macOS and every other platform: total system RAM via psutil.

    The result is truncated to an integer number of bytes.
    """
    import psutil

    if platform.system() in ("Linux", "Windows"):
        gpu_total = _try_nvidia_memory()
        if gpu_total is not None:
            return int(gpu_total * fraction)

    # Shared path for macOS, other systems, and hosts where the NVIDIA probe
    # returned nothing.
    return int(psutil.virtual_memory().total * fraction)

142 

143 

def _try_nvidia_memory() -> int | None:
    """Try to get NVIDIA GPU total memory via pynvml, then nvidia-smi.

    Only the first GPU is consulted: device index 0 for pynvml, the first
    output line for nvidia-smi. Detection is best-effort; any failure in one
    probe falls through to the next.

    Returns:
        Total memory in bytes, or ``None`` when neither probe succeeds.
    """
    try:
        import pynvml  # type: ignore[import-untyped]

        pynvml.nvmlInit()
        try:
            handle = pynvml.nvmlDeviceGetHandleByIndex(0)
            return int(pynvml.nvmlDeviceGetMemoryInfo(handle).total)
        finally:
            # Always release NVML once initialised, including when a device
            # query above raises.
            pynvml.nvmlShutdown()
    except Exception:  # noqa: S110 -- optional GPU detect; absence is expected on non-NVIDIA hosts
        pass

    try:
        import subprocess

        # nvidia-smi ships with the NVIDIA driver and is always on PATH when
        # present; fully-qualifying it would break on every install layout.
        result = subprocess.run(
            ["nvidia-smi", "--query-gpu=memory.total", "--format=csv,noheader,nounits"],  # noqa: S607
            capture_output=True,
            text=True,
            timeout=5,
        )
        if result.returncode == 0:
            # One line per GPU, value in MiB; use the first GPU.
            mib = int(result.stdout.strip().split("\n")[0])
            return mib * 1024 * 1024
    except Exception:  # noqa: S110 -- optional GPU detect; same rationale as above
        pass

    return None