Coverage for src / lilbee / server / auth.py: 100%

84 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-05-15 20:55 +0000

1"""Session token auth middleware with decorator-based read-only marking.""" 

2 

3from __future__ import annotations 

4 

import hmac
import json
import logging
import os
import secrets
import sys

10from collections.abc import Callable 

11from pathlib import Path 

12from typing import Any, TypeVar 

13 

14from litestar.exceptions import NotAuthorizedException 

15from litestar.types import ASGIApp, Receive, Scope, Send 

16 

17from lilbee.core.config import cfg 

18 

log = logging.getLogger(__name__)

# Entropy (in bytes) requested from secrets.token_urlsafe for a new token;
# also reused as the minimum character length of a persisted token.
_TOKEN_BYTES: int = 32

# Type variable bound to callables so @read_only hands back the exact type it
# was given, keeping route-handler signatures intact for type checkers.
F = TypeVar("F", bound=Callable[..., Any])


# Registry of route-handler functions that skip bearer-token auth. The
# @read_only decorator fills it while modules are imported, and AuthMiddleware
# consults it with a plain membership test. Keeping it as a module-level set is
# deliberate: handlers are defined exactly once at import time, the registry
# has no lifecycle of its own, and tagging the function object with an extra
# attribute instead would force a "# type: ignore" at every check, since mypy
# cannot see dynamically added attributes on Callable.
_READ_ONLY_HANDLERS: set[Callable[..., Any]] = set()

33 

34 

def read_only(fn: F) -> F:
    """Exempt a route handler from bearer-token authentication.

    The handler is recorded in the module-level registry and handed back
    unchanged, so the decorator neither wraps nor alters the function.
    """
    registry = _READ_ONLY_HANDLERS
    registry.add(fn)
    return fn

39 

40 

def is_read_only(fn: Callable[..., Any]) -> bool:
    """Report whether *fn* was registered through :func:`read_only`."""
    registered = fn in _READ_ONLY_HANDLERS
    return registered

44 

45 

def server_json_path() -> Path:
    """Locate the file that persists the server session token.

    The file is named ``server.json`` and sits directly inside ``cfg.data_dir``.
    """
    data_dir = cfg.data_dir
    return data_dir / "server.json"

49 

50 

class SessionManager:
    """Manage the server session token lifecycle.

    Replaces the old module-level ``_session_token`` global so that auth
    state is explicit and injectable rather than hidden mutable state.

    ``token`` is ``None`` until :meth:`load_or_generate` runs and again after
    :meth:`cleanup`. While it is ``None``, :meth:`validate` accepts every
    request.
    """

    def __init__(self) -> None:
        # None means "no token loaded"; validate() then lets everything through.
        self.token: str | None = None

    def load_or_generate(self) -> str:
        """Return the persisted token if shape-valid; generate a new one otherwise.

        A freshly generated token is written to :func:`server_json_path` as
        ``{"token": ...}``. It is stored on ``self.token`` only after that
        write succeeds.

        Returns:
            The active session token.

        Raises:
            OSError: if a new token file cannot be created or written.
        """
        path = server_json_path()
        existing = self._read_persisted_token(path)
        if existing is not None:
            self.token = existing
            return existing
        token = secrets.token_urlsafe(_TOKEN_BYTES)
        self._write_token_file(path, token)
        self.token = token
        return token

    @staticmethod
    def _write_token_file(path: Path, token: str) -> None:
        """Write *token* to *path* as JSON, owner-only on non-Windows platforms."""
        path.parent.mkdir(parents=True, exist_ok=True)
        # Create the file with mode 0o600 from the start. The previous
        # write-then-chmod sequence briefly left the secret readable under
        # the process umask.
        fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, "w", encoding="utf-8") as fh:
            fh.write(json.dumps({"token": token}))
        if sys.platform != "win32":
            # os.open's mode only applies when the file is created. A
            # pre-existing file (e.g. one holding a rejected token) keeps its
            # old permissions, so tighten them explicitly.
            path.chmod(0o600)

    @staticmethod
    def _read_persisted_token(path: Path) -> str | None:
        """Return a previously-persisted token if shape-valid, else None.

        A missing, unreadable, undecodable, non-JSON or wrongly-shaped file
        yields None, so the caller regenerates a token instead of failing.
        """
        try:
            raw = path.read_text(encoding="utf-8")
        except (OSError, UnicodeDecodeError):
            # OSError already covers FileNotFoundError and permission errors.
            return None
        try:
            data = json.loads(raw)
        except json.JSONDecodeError:
            return None
        if not isinstance(data, dict):
            return None
        token = data.get("token")
        # NOTE(review): this floor counts characters of the encoded string,
        # not raw bytes; token_urlsafe(_TOKEN_BYTES) output is longer than it.
        if not isinstance(token, str) or len(token) < _TOKEN_BYTES:
            return None
        return token

    def cleanup(self) -> None:
        """Remove server.json on shutdown and clear the in-memory token.

        NOTE(review): clearing the token puts :meth:`validate` back into its
        accept-everything mode; confirm no requests are served after this.
        """
        self.token = None
        server_json_path().unlink(missing_ok=True)

    def validate(self, auth_header: str) -> bool:
        """Check whether *auth_header* carries a valid bearer token.

        Args:
            auth_header: Raw ``Authorization`` header value, e.g. ``"Bearer x"``.

        Returns:
            True when no token is loaded (auth disabled) or when the header
            equals ``"Bearer <token>"``; False otherwise.

        Raises:
            NotAuthorizedException: if the token was set to an empty string.
        """
        if self.token is None:
            return True  # auth disabled (tests)
        if not self.token:
            raise NotAuthorizedException("Server token not initialized")
        expected = f"Bearer {self.token}".encode("utf-8")
        # Compare bytes, not str: hmac.compare_digest raises TypeError for
        # non-ASCII str operands, and the header comes from the client.
        presented = auth_header.encode("utf-8", "surrogatepass")
        return hmac.compare_digest(presented, expected)

105 

106 

# Process-wide SessionManager instance, shared by AuthMiddleware and the
# application lifespan.
session_manager: SessionManager = SessionManager()

109 

110 

class AuthMiddleware:
    """Bearer token auth middleware for mutating endpoints.

    These are forwarded without a token check:
    - non-HTTP scopes,
    - ``OPTIONS`` requests,
    - requests whose route handler is registered with :func:`read_only`.

    Every other HTTP request must present an ``Authorization`` header that
    ``session_manager.validate`` accepts; otherwise
    :class:`NotAuthorizedException` is raised.
    """

    def __init__(self, app: ASGIApp) -> None:
        # Downstream ASGI application that receives permitted requests.
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Forward the request to ``self.app`` if it is exempt or authorised.

        Raises:
            NotAuthorizedException: when the bearer token is missing or wrong.
        """
        # NOTE(review): every non-HTTP scope (e.g. websocket) skips auth
        # entirely; confirm no websocket route needs protection.
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        # Presumably OPTIONS is exempt so CORS preflights (sent without
        # credentials) succeed; confirm.
        if scope.get("method", "") == "OPTIONS":
            await self.app(scope, receive, send)
            return

        handler = scope.get("route_handler")
        if handler and is_read_only(handler.fn):
            await self.app(scope, receive, send)
            return

        if session_manager.validate(self._authorization_header(scope)):
            await self.app(scope, receive, send)
            return
        raise NotAuthorizedException("Missing or invalid bearer token")

    @staticmethod
    def _authorization_header(scope: Scope) -> str:
        """Return the request's ``Authorization`` header as ASCII text, or ``""``."""
        headers = dict(scope.get("headers", []))
        raw = headers.get(b"authorization", b"")
        # latin-1 maps every byte value, so malformed client bytes cannot
        # raise UnicodeDecodeError here and escape as an unhandled error
        # instead of a 401.
        value = raw.decode("latin-1")
        # hmac.compare_digest raises TypeError on non-ASCII str operands, and
        # such a header cannot equal "Bearer <token>", so treat it as absent.
        return value if value.isascii() else ""