Coverage for src / lilbee / modelhub / role_validator.py: 100%

59 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-06-28 01:01 +0000

1"""Role-slot assignment validation for the four model config fields.""" 

2 

3import os 

4import sys 

5from typing import Any 

6 

7from lilbee.catalog import find_catalog_entry 

8from lilbee.catalog.refs import is_bare_hf_repo 

9from lilbee.catalog.types import ModelTask 

10from lilbee.core.config import cfg 

11from lilbee.modelhub.model_manager.discovery import reclassify_by_name 

12from lilbee.modelhub.registry import ModelRegistry 

13from lilbee.providers.model_ref import PROVIDER_PREFIXES 

14 

# Name of the environment variable that requests the test-only bypass of
# role/task validation. The variable alone is not sufficient: pytest must
# also be loaded, so a leaked env var cannot disable validation in
# production.
_SKIP_MODEL_TASK_VALIDATION_ENV = "LILBEE_SKIP_MODEL_TASK_VALIDATION"

# Maps each role-slot config field name to the task value a model must
# carry to be assignable to that slot.
_MODEL_FIELD_TO_TASK: dict[str, str] = {
    "chat_model": "chat",
    "embedding_model": "embedding",
    "vision_model": "vision",
    "reranker_model": "rerank",
}

# A native GGUF ref of the form ``<owner>/<repo>/<file>.gguf`` has at least
# two ``/`` separators; one-slash refs are bare repo IDs.
_NATIVE_GGUF_REF_MIN_SLASHES = 2

29 

30 

class TaskMismatchError(ValueError):
    """Raised when a model's task does not fit the role slot it was assigned to.

    The offending ref and both tasks are kept as attributes (``ref``,
    ``entry_task``, ``expected_task``) so that each surface (HTTP, CLI, TUI,
    MCP) can build its own user-facing wording. The default ``str()`` text
    is surface-neutral, so it is safe to show unmodified.
    """

    def __init__(self, ref: str, entry_task: ModelTask, expected_task: ModelTask) -> None:
        message = f"Model '{ref}' is a {entry_task} model, not {expected_task}."
        super().__init__(message)
        # Structured copies of the inputs for callers that format their own message.
        self.ref = ref
        self.entry_task = entry_task
        self.expected_task = expected_task

44 

45 

def _model_task_validation_bypassed() -> bool:
    """Report whether the test-only validation bypass is in effect.

    Both conditions must hold: the environment variable named by
    ``_SKIP_MODEL_TASK_VALIDATION_ENV`` has a non-empty value, and a
    ``pytest`` module is present in ``sys.modules``.
    """
    # Short-circuits exactly like the guard form: sys.modules is only
    # consulted once the env var is known to be set.
    return (
        bool(os.environ.get(_SKIP_MODEL_TASK_VALIDATION_ENV))
        and sys.modules.get("pytest") is not None
    )

50 

51 

def _resolve_installed_task(registry: ModelRegistry, ref: str) -> ModelTask | None:
    """Look up the task of the installed model *ref* in *registry*.

    Returns ``None`` when the registry has no manifest for *ref*. Otherwise
    the manifest's task is passed through ``reclassify_by_name`` together
    with *ref*, and the result is wrapped in ``ModelTask``.
    """
    found = registry.get_manifest(ref)
    if found is not None:
        return ModelTask(reclassify_by_name(ref, found.task))
    return None

58 

59 

def _skips_catalog_check(ref: str, *, allow_bypass: bool) -> bool:
    """Decide whether *ref* is exempt from the catalog check.

    A ref is exempt when it is empty or whitespace-only, when the test-only
    bypass is both permitted by *allow_bypass* and currently active, or when
    the text before its first ``/`` is one of ``PROVIDER_PREFIXES``.
    """
    is_blank = not (ref and ref.strip())
    if is_blank or (allow_bypass and _model_task_validation_bypassed()):
        return True
    # Everything up to the first "/" (the whole ref when there is none).
    prefix, _, _ = ref.partition("/")
    return prefix in PROVIDER_PREFIXES

67 

68 

def _canonical_featured_ref(ref: str, entry: Any, want: ModelTask) -> str:
    """Verify a featured catalog entry fits the role and choose the ref to persist.

    Raises ``TaskMismatchError`` when ``entry.task`` differs from *want*.
    Returns *ref* itself when it is a full ``<repo>/<file>.gguf`` ref, and
    the catalog entry's own ``ref`` otherwise.
    """
    if entry.task != want:
        raise TaskMismatchError(ref, ModelTask(entry.task), want)
    has_gguf_suffix = ref.endswith(".gguf")
    names_exact_file = has_gguf_suffix and ref.count("/") >= _NATIVE_GGUF_REF_MIN_SLASHES
    if not names_exact_file:
        # No specific file was named: fall back to the catalog's ref.
        catalog_ref: str = entry.ref
        return catalog_ref
    # A full file ref is kept verbatim so resolve_model_path lands on the
    # exact installed quant.
    return ref

79 

80 

def _validate_installed_ref(ref: str, want: ModelTask) -> str:
    """Check a non-featured ref against the installed-model registry.

    A bare ``<org>/<repo>`` ref is first swapped for the full ref of its
    installed quant (when the registry knows one), so the returned value
    names the exact GGUF file. Raises ``ValueError`` when the model is not
    installed and ``TaskMismatchError`` when its task differs from *want*.
    """
    registry = ModelRegistry(cfg.models_dir)
    resolved = ref
    if is_bare_hf_repo(ref):
        # Keep the bare ref when the registry has no installed quant for it.
        resolved = registry.installed_ref_for_repo(ref) or ref
    task = _resolve_installed_task(registry, resolved)
    if task is None:
        not_installed = (
            f"Model '{resolved}' is not installed. "
            "Install it with 'lilbee model pull <ref>' "
            "(or POST /api/models/pull) before assigning it to a role."
        )
        raise ValueError(not_installed)
    if task != want:
        raise TaskMismatchError(resolved, task, want)
    return resolved

100 

101 

def validate_model_task_assignment(field_name: str, ref: str, *, allow_bypass: bool = True) -> str:
    """Validate that *ref* may fill the role slot *field_name*; return the canonical ref.

    Refs exempted by ``_skips_catalog_check`` are returned unchanged. Any
    other ref is accepted if it is a featured catalog entry or an installed
    non-featured model (any model the user has pulled). Raises
    ``TaskMismatchError`` on role mismatch and ``ValueError`` when the model
    is neither featured nor installed. For a non-exempt ref, a *field_name*
    missing from ``_MODEL_FIELD_TO_TASK`` raises ``KeyError``.
    """
    if _skips_catalog_check(ref, allow_bypass=allow_bypass):
        return ref
    expected = ModelTask(_MODEL_FIELD_TO_TASK[field_name])
    featured: Any = find_catalog_entry(ref)
    if featured is None:
        # Not in the featured catalog: fall back to the installed registry.
        return _validate_installed_ref(ref, expected)
    return _canonical_featured_ref(ref, featured, expected)