Coverage for src / lilbee / server / auth.py: 100%
84 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-05-15 20:55 +0000
1"""Session token auth middleware with decorator-based read-only marking."""
3from __future__ import annotations
5import hmac
6import json
7import logging
8import secrets
9import sys
10from collections.abc import Callable
11from pathlib import Path
12from typing import Any, TypeVar
14from litestar.exceptions import NotAuthorizedException
15from litestar.types import ASGIApp, Receive, Scope, Send
17from lilbee.core.config import cfg
log = logging.getLogger(__name__)

# Number of random bytes fed to secrets.token_urlsafe() when minting a new
# session token; also used as the minimum length a persisted token must have.
_TOKEN_BYTES = 32

# Generic callable type so @read_only hands back exactly the type it received.
F = TypeVar("F", bound=Callable[..., Any])

# Registry of route-handler callables that skip bearer-token auth. The
# @read_only decorator fills it while handler modules are imported, and
# AuthMiddleware consults it with a plain membership test.
# A module-level set is deliberate: handlers are defined exactly once at import
# time and the registry has no other lifecycle. Tagging the function objects
# with an attribute instead would force a `# type: ignore` on every check,
# since mypy cannot see dynamic attributes on Callable.
_READ_ONLY_HANDLERS: set[Callable[..., Any]] = set()
def read_only(fn: F) -> F:
    """Decorator: exempt the route handler *fn* from bearer-token auth.

    *fn* is recorded in the module-level registry and returned as-is (no
    wrapper is created), so :func:`is_read_only` can later find this exact
    object by membership.
    """
    registry = _READ_ONLY_HANDLERS
    registry.add(fn)
    return fn
def is_read_only(fn: Callable[..., Any]) -> bool:
    """Report whether *fn* was registered through :func:`read_only`.

    The check is a set-membership lookup, so only the exact object that was
    decorated matches. A wrapper around it does not.
    """
    registered = _READ_ONLY_HANDLERS
    return fn in registered
def server_json_path() -> Path:
    """Locate the session file, ``server.json``, inside the configured data dir."""
    data_dir = cfg.data_dir
    return data_dir / "server.json"
class SessionManager:
    """Manages the server session token lifecycle.

    Replaces the old module-level ``_session_token`` global so that auth
    state is explicit and injectable rather than hidden mutable state.

    ``token`` is ``None`` until :meth:`load_or_generate` runs, and again after
    :meth:`cleanup`. While it is ``None``, :meth:`validate` accepts every
    header.
    """

    def __init__(self) -> None:
        # None means "no token loaded"; validate() then treats auth as disabled.
        self.token: str | None = None

    def load_or_generate(self) -> str:
        """Return the persisted token if shape-valid; generate a new one otherwise.

        A newly generated token is written to :func:`server_json_path` as
        ``{"token": ...}``. On non-Windows platforms the file is restricted to
        mode ``0o600`` *before* the secret is written into it.

        Returns:
            The active session token (also stored on ``self.token``).
        """
        path = server_json_path()
        existing = self._read_persisted_token(path)
        if existing is not None:
            self.token = existing
            return existing
        token = secrets.token_urlsafe(_TOKEN_BYTES)
        self.token = token
        path.parent.mkdir(parents=True, exist_ok=True)
        # Create (or reuse) the file and tighten its mode first. Writing the
        # secret and chmod-ing afterwards leaves it with umask permissions
        # (often 0o644) in between, readable by other local users.
        path.touch(mode=0o600, exist_ok=True)
        if sys.platform != "win32":
            path.chmod(0o600)
        path.write_text(json.dumps({"token": token}), encoding="utf-8")
        return token

    @staticmethod
    def _read_persisted_token(path: Path) -> str | None:
        """Return a previously-persisted token if shape-valid, else None.

        A missing, unreadable, non-UTF-8, non-JSON or wrongly shaped file all
        yield ``None``, so the caller falls back to generating a fresh token.
        """
        try:
            raw = path.read_text(encoding="utf-8")
        except (OSError, UnicodeDecodeError):
            # OSError already covers FileNotFoundError. UnicodeDecodeError means
            # a corrupted file, which should be regenerated rather than crash.
            return None
        try:
            data = json.loads(raw)
        except json.JSONDecodeError:
            return None
        if not isinstance(data, dict):
            return None
        token = data.get("token")
        # Cheap shape check: token_urlsafe(n) yields at least n characters,
        # so anything shorter cannot be a token we generated.
        if not isinstance(token, str) or len(token) < _TOKEN_BYTES:
            return None
        return token

    def cleanup(self) -> None:
        """Remove server.json on shutdown and clear the in-memory token."""
        self.token = None
        path = server_json_path()
        path.unlink(missing_ok=True)

    def validate(self, auth_header: str) -> bool:
        """Check whether *auth_header* carries a valid bearer token.

        Args:
            auth_header: Raw ``Authorization`` header value.

        Returns:
            ``True`` if auth is disabled (``token is None``) or the header is
            exactly ``"Bearer <token>"``; ``False`` otherwise.

        Raises:
            NotAuthorizedException: If the token is set but empty.
        """
        if self.token is None:
            # NOTE(review): any request passes while no token is loaded; this
            # relies on load_or_generate() having run before serving traffic.
            return True  # auth disabled (tests)
        if not self.token:
            raise NotAuthorizedException("Server token not initialized")
        # hmac.compare_digest accepts only ASCII-only str and raises TypeError
        # otherwise. Comparing UTF-8 bytes stays constant-time and simply
        # returns False for non-ASCII headers.
        expected = f"Bearer {self.token}"
        return hmac.compare_digest(
            auth_header.encode("utf-8", "surrogatepass"),
            expected.encode("utf-8", "surrogatepass"),
        )
# Process-wide SessionManager. AuthMiddleware calls its validate(). The app
# lifespan (not shown here) presumably drives load_or_generate()/cleanup().
session_manager: SessionManager = SessionManager()
class AuthMiddleware:
    """Bearer token auth middleware for mutating endpoints.

    Requests pass straight through to the wrapped ASGI app in three cases:
    non-HTTP scopes, ``OPTIONS`` requests, and routes whose handler function
    was marked with :func:`read_only`. Every other HTTP request must carry an
    ``Authorization`` header that ``session_manager.validate`` accepts.

    Raises:
        NotAuthorizedException: When the bearer token is missing or invalid.
    """

    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        # OPTIONS is let through unauthenticated, presumably for CORS
        # preflight requests, which browsers send without credentials.
        method = scope.get("method", "")
        if method == "OPTIONS":
            await self.app(scope, receive, send)
            return

        handler = scope.get("route_handler")
        if handler and is_read_only(handler.fn):
            await self.app(scope, receive, send)
            return

        headers = dict(scope.get("headers", []))
        # ASGI header values are raw bytes, and a client may send bytes that
        # are not valid UTF-8. Decode leniently so such a header is rejected
        # as unauthorized instead of escaping as an unhandled UnicodeDecodeError.
        raw_auth = headers.get(b"authorization", b"")
        auth_header = raw_auth.decode("utf-8", errors="replace")
        if session_manager.validate(auth_header):
            await self.app(scope, receive, send)
            return
        raise NotAuthorizedException("Missing or invalid bearer token")