diff --git a/src/bot/middleware/auth.py b/src/bot/middleware/auth.py index 7bba27af..cfef9548 100644 --- a/src/bot/middleware/auth.py +++ b/src/bot/middleware/auth.py @@ -61,8 +61,25 @@ async def auth_middleware(handler: Callable, event: Any, data: Dict[str, Any]) - "Attempting authentication for user", user_id=user_id, username=username ) + # If the message is an /auth command, pass the token as + # credentials so that TokenAuthProvider can verify it. + credentials: Dict[str, Any] = {} + msg_text = getattr(event.effective_message, "text", "") or "" + if msg_text.startswith("/auth "): + parts = msg_text.split(maxsplit=1) + if len(parts) == 2: + candidate = parts[1].strip() + # Only treat as a raw token if it doesn't look like a + # sub-command (generate, revoke, status, add, remove). + if candidate and not candidate.startswith( + ("generate", "revoke", "status", "add", "remove") + ): + credentials = {"token": candidate} + # Try to authenticate (providers will check whitelist and tokens) - authentication_successful = await auth_manager.authenticate_user(user_id) + authentication_successful = await auth_manager.authenticate_user( + user_id, credentials + ) # Log authentication attempt if audit_logger: @@ -82,8 +99,15 @@ async def auth_middleware(handler: Callable, event: Any, data: Dict[str, Any]) - auth_provider=session.auth_provider if session else None, ) - # Welcome message for new session - if event.effective_message: + # Welcome message for new session — but skip it when the + # message is an /auth command: the command handler will send + # its own, operation-specific response (e.g. "Token revoked + # for user 123"). Otherwise the admin sees a confusing mix + # of "Welcome!" + the actual result. + is_auth_command = msg_text.startswith("/auth") and ( + msg_text == "/auth" or msg_text.startswith("/auth ") + ) + if event.effective_message and not is_auth_command: await event.effective_message.reply_text( f"🔓 Welcome! You are now authenticated.\n" f"Session started at {datetime.now(UTC).strftime('%H:%M:%S UTC')}" diff --git a/src/bot/orchestrator.py b/src/bot/orchestrator.py index 6d9719f0..e17e7051 100644 --- a/src/bot/orchestrator.py +++ b/src/bot/orchestrator.py @@ -331,6 +331,8 @@ def _register_agentic_handlers(self, app: Application) -> None: ] if self.settings.enable_project_threads: handlers.append(("sync_threads", command.sync_threads)) + if self.settings.enable_token_auth or self.settings.allowed_users: + handlers.append(("auth", self.agentic_auth)) # Derive known commands dynamically — avoids drift when new commands are added self._known_commands: frozenset[str] = frozenset(cmd for cmd, _ in handlers) @@ -420,6 +422,8 @@ def _register_classic_handlers(self, app: Application) -> None: ] if self.settings.enable_project_threads: handlers.append(("sync_threads", command.sync_threads)) + if self.settings.enable_token_auth or self.settings.allowed_users: + handlers.append(("auth", self.agentic_auth)) for cmd, handler in handlers: app.add_handler(CommandHandler(cmd, self._inject_deps(handler))) @@ -464,6 +468,8 @@ async def get_bot_commands(self) -> list: # type: ignore[type-arg] ] if self.settings.enable_project_threads: commands.append(BotCommand("sync_threads", "Sync project topics")) + if self.settings.enable_token_auth or self.settings.allowed_users: + commands.append(BotCommand("auth", "Authentication management")) return commands else: commands = [ @@ -484,6 +490,8 @@ async def get_bot_commands(self) -> list: # type: ignore[type-arg] ] if self.settings.enable_project_threads: commands.append(BotCommand("sync_threads", "Sync project topics")) + if self.settings.enable_token_auth or self.settings.allowed_users: + commands.append(BotCommand("auth", "Authentication management")) return commands # --- Agentic handlers --- @@ -1670,6 +1678,301 @@ async def agentic_repo( reply_markup=reply_markup, ) + # --- Token auth command --- + + def _is_admin(self, user_id: int) -> bool: + """Return True if *user_id* is in the whitelist (admin).""" + return bool( + self.settings.allowed_users and user_id in self.settings.allowed_users + ) + + async def agentic_auth( + self, update: Update, context: ContextTypes.DEFAULT_TYPE + ) -> None: + """Handle /auth — authentication management. + + /auth — authenticate with a token + /auth status — show own authentication info + /auth add — (admin) add user to persistent allowlist + /auth remove — (admin) remove user from allowlist + /auth generate — (admin) generate token for a user + /auth revoke — (admin) revoke a user's token + """ + user_id = update.effective_user.id + args = update.message.text.split()[1:] if update.message.text else [] + token_provider = context.bot_data.get("token_auth_provider") + storage = context.bot_data.get("storage") + + # /auth or /auth status — always available + if not args or (len(args) == 1 and args[0] == "status"): + auth_manager = context.bot_data.get("auth_manager") + if auth_manager and auth_manager.is_authenticated(user_id): + session = auth_manager.get_session(user_id) + provider = session.auth_provider if session else "unknown" + await update.message.reply_text( + f"\U0001f513 Authenticated via {escape_html(provider)}", + parse_mode="HTML", + ) + else: + await update.message.reply_text( + "\U0001f512 Not authenticated.\n" + "Use /auth <token> to log in.", + parse_mode="HTML", + ) + return + + sub = args[0].lower() + + # /auth add — admin adds user to DB allowlist + if sub == "add": + if not self._is_admin(user_id): + await update.message.reply_text("\U0001f512 Admin access required.") + return + if len(args) < 2: + await update.message.reply_text( + "Usage: /auth add <user_id>", + parse_mode="HTML", + ) + return + try: + target_id = int(args[1]) + except ValueError: + await update.message.reply_text("Invalid user ID.") + return + if not storage: + await update.message.reply_text( + "Storage unavailable \u2014 cannot update allowlist." + ) + return + + # Ensure user row exists, check current state. + user = await storage.get_or_create_user(target_id) + if user.is_allowed: + await update.message.reply_text( + f"\u2139\ufe0f User {target_id} is already " + f"in the allowlist.", + parse_mode="HTML", + ) + return + + await storage.users.set_user_allowed(target_id, True) + await update.message.reply_text( + f"\u2705 User {target_id} added to the allowlist. " + f"They can now write to the bot.", + parse_mode="HTML", + ) + + audit_logger = context.bot_data.get("audit_logger") + if audit_logger: + await audit_logger.log_command( + user_id=user_id, + command="auth_add", + args=[str(target_id)], + success=True, + ) + return + + # /auth remove — admin removes user from DB allowlist + if sub == "remove": + if not self._is_admin(user_id): + await update.message.reply_text("\U0001f512 Admin access required.") + return + if len(args) < 2: + await update.message.reply_text( + "Usage: /auth remove <user_id>", + parse_mode="HTML", + ) + return + try: + target_id = int(args[1]) + except ValueError: + await update.message.reply_text("Invalid user ID.") + return + if not storage: + await update.message.reply_text( + "Storage unavailable \u2014 cannot update allowlist." + ) + return + + user = await storage.users.get_user(target_id) + if not user or not user.is_allowed: + await update.message.reply_text( + f"\u2139\ufe0f User {target_id} is not in " + f"the allowlist \u2014 nothing to remove.", + parse_mode="HTML", + ) + return + + await storage.users.set_user_allowed(target_id, False) + # End their session so the change takes effect immediately + auth_manager = context.bot_data.get("auth_manager") + if auth_manager: + auth_manager.end_session(target_id) + + await update.message.reply_text( + f"\u2705 User {target_id} removed from the allowlist.", + parse_mode="HTML", + ) + + audit_logger = context.bot_data.get("audit_logger") + if audit_logger: + await audit_logger.log_command( + user_id=user_id, + command="auth_remove", + args=[str(target_id)], + success=True, + ) + return + + # All remaining sub-commands require token auth to be enabled. + if not token_provider: + await update.message.reply_text( + "Token authentication is not enabled. " + "Set ENABLE_TOKEN_AUTH=true to use token commands.", + parse_mode="HTML", + ) + return + + # /auth generate + if sub == "generate": + if not self._is_admin(user_id): + await update.message.reply_text("\U0001f512 Admin access required.") + return + if len(args) < 2: + await update.message.reply_text( + "Usage: /auth generate <user_id>", + parse_mode="HTML", + ) + return + try: + target_id = int(args[1]) + except ValueError: + await update.message.reply_text("Invalid user ID.") + return + + token = await token_provider.generate_token(target_id) + # Send the token privately — it's sensitive. + await update.message.reply_text( + f"\U0001f511 Token generated for user " + f"{target_id}:\n\n" + f"
{escape_html(token)}
\n\n" + f"The user should send:\n" + f"/auth {escape_html(token)}\n\n" + f"Expires in {token_provider.token_lifetime.days} days.", + parse_mode="HTML", + ) + + audit_logger = context.bot_data.get("audit_logger") + if audit_logger: + await audit_logger.log_command( + user_id=user_id, + command="auth_generate", + args=[str(target_id)], + success=True, + ) + return + + # /auth revoke + if sub == "revoke": + if not self._is_admin(user_id): + await update.message.reply_text("\U0001f512 Admin access required.") + return + if len(args) < 2: + await update.message.reply_text( + "Usage: /auth revoke <user_id>", + parse_mode="HTML", + ) + return + try: + target_id = int(args[1]) + except ValueError: + await update.message.reply_text("Invalid user ID.") + return + + # Check whether the target actually has an active token + # before claiming we revoked anything. + stored = await token_provider.storage.get_user_token(target_id) + if stored is None: + await update.message.reply_text( + f"\u2139\ufe0f User {target_id} has no active " + f"token \u2014 nothing to revoke.", + parse_mode="HTML", + ) + audit_logger = context.bot_data.get("audit_logger") + if audit_logger: + await audit_logger.log_command( + user_id=user_id, + command="auth_revoke", + args=[str(target_id), "no_active_token"], + success=False, + ) + return + + await token_provider.revoke_token(target_id) + + # Also end the user's session if active + auth_manager = context.bot_data.get("auth_manager") + if auth_manager: + auth_manager.end_session(target_id) + + await update.message.reply_text( + f"\u2705 Token revoked for user {target_id}.", + parse_mode="HTML", + ) + + audit_logger = context.bot_data.get("audit_logger") + if audit_logger: + await audit_logger.log_command( + user_id=user_id, + command="auth_revoke", + args=[str(target_id)], + success=True, + ) + return + + # /auth — reached when `something` isn't a known + # subcommand. Either the user successfully authenticated via + # token (middleware let them through), or the admin typed a + # typo / unknown subcommand. Disambiguate via session provider. + auth_manager = context.bot_data.get("auth_manager") + session = auth_manager.get_session(user_id) if auth_manager else None + provider_name = session.auth_provider if session else None + + if provider_name == "TokenAuthProvider": + # External user presented a valid token — middleware already + # created the session. Confirm it clearly. + await update.message.reply_text( + "\U0001f513 Authenticated via token. Welcome!", + ) + elif provider_name == "WhitelistAuthProvider": + # Admin typed something that isn't a known subcommand. + # Show a help hint instead of a misleading success message. + lines = [ + "\u2139\ufe0f You're already authenticated as admin \u2014 " + "no token needed for your account.\n", + "Did you mean one of:", + "/auth add <user_id>", + "/auth remove <user_id>", + ] + if token_provider: + lines.extend( + [ + "/auth generate <user_id>", + "/auth revoke <user_id>", + ] + ) + lines.append("/auth status") + await update.message.reply_text( + "\n".join(lines), + parse_mode="HTML", + ) + else: + # Should not reach here — middleware would have rejected + # an unauthenticated user. Fall back to a safe message. + await update.message.reply_text( + "\U0001f512 Invalid or expired token.", + ) + async def _handle_stop_callback( self, update: Update, context: ContextTypes.DEFAULT_TYPE ) -> None: diff --git a/src/main.py b/src/main.py index 02660733..7e11bd20 100644 --- a/src/main.py +++ b/src/main.py @@ -29,7 +29,8 @@ from src.security.audit import AuditLogger, InMemoryAuditStorage from src.security.auth import ( AuthenticationManager, - InMemoryTokenStorage, + DatabaseAllowlistAuthProvider, + SqliteTokenStorage, TokenAuthProvider, WhitelistAuthProvider, ) @@ -105,15 +106,20 @@ async def create_application(config: Settings) -> Dict[str, Any]: # Create security components providers = [] + token_provider: Optional[TokenAuthProvider] = None - # Add whitelist provider if users are configured + # Add whitelist provider if users are configured. + # Also add the DB-backed dynamic allowlist so admins can add users + # at runtime via /auth allow . if config.allowed_users: providers.append(WhitelistAuthProvider(config.allowed_users)) + providers.append(DatabaseAllowlistAuthProvider(storage.users)) # Add token provider if enabled if config.enable_token_auth: - token_storage = InMemoryTokenStorage() # TODO: Use database storage - providers.append(TokenAuthProvider(config.auth_token_secret, token_storage)) + token_storage = SqliteTokenStorage(storage.tokens) + token_provider = TokenAuthProvider(config.auth_secret_str, token_storage) + providers.append(token_provider) # Fall back to allowing all users in development mode if not providers and config.development_mode: @@ -173,6 +179,7 @@ async def create_application(config: Settings) -> Dict[str, Any]: # Create bot with all dependencies dependencies = { "auth_manager": auth_manager, + "token_auth_provider": token_provider, "security_validator": security_validator, "rate_limiter": rate_limiter, "audit_logger": audit_logger, diff --git a/src/security/auth.py b/src/security/auth.py index d8d216c9..cb01bb50 100644 --- a/src/security/auth.py +++ b/src/security/auth.py @@ -90,6 +90,40 @@ async def get_user_info(self, user_id: int) -> Optional[Dict[str, Any]]: return None +class DatabaseAllowlistAuthProvider(AuthProvider): + """Dynamic allowlist backed by the ``users.is_allowed`` column. + + Works alongside :class:`WhitelistAuthProvider` — the env-based list is + for permanent admins, this provider is for dynamic entries admins + add via ``/auth allow ``. + """ + + def __init__(self, user_repo: Any) -> None: + self._repo = user_repo + logger.info("Database allowlist auth provider initialized") + + async def authenticate(self, user_id: int, credentials: Dict[str, Any]) -> bool: + """Return True if the user's ``is_allowed`` flag is set.""" + user = await self._repo.get_user(user_id) + is_allowed = bool(user and user.is_allowed) + logger.info( + "Database allowlist authentication attempt", + user_id=user_id, + success=is_allowed, + ) + return is_allowed + + async def get_user_info(self, user_id: int) -> Optional[Dict[str, Any]]: + user = await self._repo.get_user(user_id) + if user and user.is_allowed: + return { + "user_id": user_id, + "auth_type": "db_allowlist", + "permissions": ["basic"], + } + return None + + class TokenStorage(ABC): """Abstract token storage interface.""" @@ -107,6 +141,15 @@ async def get_user_token(self, user_id: int) -> Optional[Dict[str, Any]]: async def revoke_token(self, user_id: int) -> None: """Revoke token for user.""" + @abstractmethod + async def touch_token(self, user_id: int, new_expires_at: datetime) -> None: + """Slide ``expires_at`` and bump ``last_used`` for the active token. + + Called on every successful authentication to implement rolling + expiration: the token stays alive as long as the user keeps using + the bot. No-op if the user has no active token. + """ + class InMemoryTokenStorage(TokenStorage): """In-memory token storage for development/testing.""" @@ -138,6 +181,46 @@ async def revoke_token(self, user_id: int) -> None: """Remove token from memory.""" self._tokens.pop(user_id, None) + async def touch_token(self, user_id: int, new_expires_at: datetime) -> None: + """Slide expiration in-memory.""" + token_data = self._tokens.get(user_id) + if token_data is not None: + token_data["expires_at"] = new_expires_at + token_data["last_used"] = datetime.now(UTC) + + +class SqliteTokenStorage(TokenStorage): + """SQLite-backed token storage using the existing ``user_tokens`` table. + + Wraps a :class:`TokenRepository` (from ``src.storage.repositories``) + so that it satisfies the :class:`TokenStorage` interface expected by + :class:`TokenAuthProvider`. + """ + + def __init__(self, token_repo: Any) -> None: + self._repo = token_repo + + async def store_token( + self, user_id: int, token_hash: str, expires_at: datetime + ) -> None: + await self._repo.store_token(user_id, token_hash, expires_at) + + async def get_user_token(self, user_id: int) -> Optional[Dict[str, Any]]: + model = await self._repo.get_active_token(user_id) + if model is None: + return None + return { + "hash": model.token_hash, + "expires_at": model.expires_at, + "created_at": model.created_at, + } + + async def revoke_token(self, user_id: int) -> None: + await self._repo.revoke_token(user_id) + + async def touch_token(self, user_id: int, new_expires_at: datetime) -> None: + await self._repo.extend_token(user_id, new_expires_at) + class TokenAuthProvider(AuthProvider): """Token-based authentication.""" @@ -154,14 +237,25 @@ def __init__( logger.info("Token auth provider initialized") async def authenticate(self, user_id: int, credentials: Dict[str, Any]) -> bool: - """Authenticate using token.""" - token = credentials.get("token") - if not token: - logger.warning( - "Token authentication failed: no token provided", user_id=user_id - ) - return False + """Authenticate using token (with rolling expiration). + + Two acceptance paths: + + 1. **Raw token provided**: verify SHA256 hash against the stored + hash for *user_id*. This is the initial handshake after + ``/auth generate`` → ``/auth ``. + 2. **No raw token, but user_id has an active stored token**: + accept based on user_id alone. Telegram guarantees the + sender's identity, and the presence of a non-expired token + in the DB means an admin previously authorised this user. + This removes the need to re-submit the token after every + 24-hour session timeout. + + On every success the token's ``expires_at`` is slid forward by + ``token_lifetime`` so active users never hit the 30-day wall. + Admins retain full control via ``/auth revoke``. + """ stored_token = await self.storage.get_user_token(user_id) if not stored_token: logger.warning( @@ -169,9 +263,39 @@ async def authenticate(self, user_id: int, credentials: Dict[str, Any]) -> bool: ) return False - is_valid = self._verify_token(token, stored_token["hash"]) - logger.info("Token authentication attempt", user_id=user_id, success=is_valid) - return is_valid + raw_token = credentials.get("token") + if raw_token: + if not self._verify_token(raw_token, stored_token["hash"]): + logger.info( + "Token authentication attempt", + user_id=user_id, + success=False, + reason="hash_mismatch", + ) + return False + auth_path = "raw_token" + else: + # No raw token — rely on the active stored token. + auth_path = "active_stored_token" + + # Slide expiration forward on every successful auth. + new_expires_at = datetime.now(UTC) + self.token_lifetime + try: + await self.storage.touch_token(user_id, new_expires_at) + except Exception as e: + logger.warning( + "Failed to slide token expiration (auth still succeeded)", + user_id=user_id, + error=str(e), + ) + + logger.info( + "Token authentication successful", + user_id=user_id, + path=auth_path, + new_expires_at=new_expires_at.isoformat(), + ) + return True async def generate_token(self, user_id: int) -> str: """Generate new authentication token.""" diff --git a/src/storage/facade.py b/src/storage/facade.py index 268a55fa..f46f77e2 100644 --- a/src/storage/facade.py +++ b/src/storage/facade.py @@ -24,6 +24,7 @@ MessageRepository, ProjectThreadRepository, SessionRepository, + TokenRepository, ToolUsageRepository, UserRepository, ) @@ -44,6 +45,7 @@ def __init__(self, database_url: str): self.tools = ToolUsageRepository(self.db_manager) self.audit = AuditLogRepository(self.db_manager) self.costs = CostTrackingRepository(self.db_manager) + self.tokens = TokenRepository(self.db_manager) self.analytics = AnalyticsRepository(self.db_manager) async def initialize(self): diff --git a/src/storage/repositories.py b/src/storage/repositories.py index 02492b8e..37ed7ca8 100644 --- a/src/storage/repositories.py +++ b/src/storage/repositories.py @@ -21,6 +21,7 @@ SessionModel, ToolUsageModel, UserModel, + UserTokenModel, ) logger = structlog.get_logger() @@ -829,3 +830,100 @@ async def get_system_stats(self) -> Dict[str, any]: "tool_stats": tool_stats, "daily_activity": daily_activity, } + + +class TokenRepository: + """User token data access.""" + + def __init__(self, db_manager: DatabaseManager): + self.db = db_manager + + async def store_token( + self, user_id: int, token_hash: str, expires_at: datetime + ) -> None: + """Store (or replace) the active token for a user. + + Auto-creates a minimal ``users`` row if the target user has never + interacted with the bot — the ``user_tokens.user_id`` foreign key + requires it, and admins need to be able to generate tokens for + brand-new users before they ever send a message. + """ + now = datetime.now(UTC) + async with self.db.get_connection() as conn: + # Ensure the user exists so the FK constraint is satisfied. + await conn.execute( + """ + INSERT OR IGNORE INTO users + (user_id, first_seen, last_active, is_allowed) + VALUES (?, ?, ?, ?) + """, + (user_id, now, now, False), + ) + # Deactivate any existing tokens for this user first + await conn.execute( + "UPDATE user_tokens SET is_active = 0 WHERE user_id = ?", + (user_id,), + ) + await conn.execute( + """ + INSERT INTO user_tokens + (user_id, token_hash, expires_at, is_active) + VALUES (?, ?, ?, 1) + """, + (user_id, token_hash, expires_at), + ) + await conn.commit() + + async def get_active_token(self, user_id: int) -> Optional[UserTokenModel]: + """Get the active, non-expired token for a user.""" + async with self.db.get_connection() as conn: + cursor = await conn.execute( + """ + SELECT * FROM user_tokens + WHERE user_id = ? AND is_active = 1 + AND (expires_at IS NULL OR expires_at > ?) + ORDER BY created_at DESC LIMIT 1 + """, + (user_id, datetime.now(UTC)), + ) + row = await cursor.fetchone() + return UserTokenModel.from_row(row) if row else None + + async def revoke_token(self, user_id: int) -> None: + """Deactivate all tokens for a user.""" + async with self.db.get_connection() as conn: + await conn.execute( + "UPDATE user_tokens SET is_active = 0 WHERE user_id = ?", + (user_id,), + ) + await conn.commit() + + async def update_last_used(self, user_id: int) -> None: + """Update the last_used timestamp for the active token.""" + async with self.db.get_connection() as conn: + await conn.execute( + """ + UPDATE user_tokens SET last_used = ? + WHERE user_id = ? AND is_active = 1 + """, + (datetime.now(UTC), user_id), + ) + await conn.commit() + + async def extend_token(self, user_id: int, new_expires_at: datetime) -> None: + """Slide ``expires_at`` forward and bump ``last_used`` atomically. + + Used for rolling-expiration auth: every successful authentication + pushes the expiry forward so active users never have to re-paste + their token. Admins can still force a cut-off with /auth revoke. + """ + async with self.db.get_connection() as conn: + await conn.execute( + """ + UPDATE user_tokens + SET expires_at = ?, last_used = ? + WHERE user_id = ? AND is_active = 1 + """, + (new_expires_at, datetime.now(UTC), user_id), + ) + await conn.commit() diff --git a/tests/unit/test_bot/test_auth_command.py b/tests/unit/test_bot/test_auth_command.py new file mode 100644 index 00000000..399c1c1a --- /dev/null +++ b/tests/unit/test_bot/test_auth_command.py @@ -0,0 +1,462 @@ +"""Tests for the /auth command handler.""" + +import tempfile +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from src.bot.orchestrator import MessageOrchestrator +from src.config import create_test_config +from src.security.auth import InMemoryTokenStorage, TokenAuthProvider + + +@pytest.fixture +def tmp_dir(): + with tempfile.TemporaryDirectory() as d: + yield Path(d) + + +@pytest.fixture +def token_provider(): + storage = InMemoryTokenStorage() + return TokenAuthProvider("test_secret", storage) + + +@pytest.fixture +def settings(tmp_dir): + return create_test_config( + approved_directory=str(tmp_dir), + agentic_mode=True, + enable_token_auth=True, + auth_token_secret="test_secret", + ) + + +@pytest.fixture +def deps(token_provider): + return { + "claude_integration": MagicMock(), + "storage": MagicMock(), + "security_validator": MagicMock(), + "rate_limiter": MagicMock(), + "audit_logger": None, + "token_auth_provider": token_provider, + "auth_manager": MagicMock(), + } + + +def _make_update(user_id: int = 123, text: str = "/auth") -> MagicMock: + update = MagicMock() + update.effective_user.id = user_id + update.message.text = text + update.message.reply_text = AsyncMock() + return update + + +def _make_context(deps: dict, settings: object) -> MagicMock: + ctx = MagicMock() + ctx.user_data = {} + ctx.bot_data = {**deps, "settings": settings} + return ctx + + +# --------------------------------------------------------------------------- +# /auth status +# --------------------------------------------------------------------------- + + +class TestAuthStatus: + async def test_status_not_authenticated(self, settings, deps): + orch = MessageOrchestrator(settings, deps) + update = _make_update(text="/auth status") + deps["auth_manager"].is_authenticated.return_value = False + ctx = _make_context(deps, settings) + + await orch.agentic_auth(update, ctx) + + text = update.message.reply_text.call_args[0][0] + assert "Not authenticated" in text + + async def test_status_authenticated(self, settings, deps): + orch = MessageOrchestrator(settings, deps) + update = _make_update(text="/auth status") + deps["auth_manager"].is_authenticated.return_value = True + session = MagicMock() + session.auth_provider = "token" + deps["auth_manager"].get_session.return_value = session + ctx = _make_context(deps, settings) + + await orch.agentic_auth(update, ctx) + + text = update.message.reply_text.call_args[0][0] + assert "Authenticated" in text + + async def test_bare_auth_shows_status(self, settings, deps): + """'/auth' with no args shows status.""" + orch = MessageOrchestrator(settings, deps) + update = _make_update(text="/auth") + deps["auth_manager"].is_authenticated.return_value = False + ctx = _make_context(deps, settings) + + await orch.agentic_auth(update, ctx) + + text = update.message.reply_text.call_args[0][0] + assert "Not authenticated" in text + + +# --------------------------------------------------------------------------- +# /auth generate +# --------------------------------------------------------------------------- + + +class TestAuthGenerate: + async def test_generate_as_admin(self, settings, deps): + orch = MessageOrchestrator(settings, deps) + # The test config has allowed_users=[12345] by default + update = _make_update( + user_id=settings.allowed_users[0], + text="/auth generate 999", + ) + ctx = _make_context(deps, settings) + + await orch.agentic_auth(update, ctx) + + text = update.message.reply_text.call_args[0][0] + assert "Token generated" in text + assert "999" in text + + async def test_generate_rejected_for_non_admin(self, settings, deps): + orch = MessageOrchestrator(settings, deps) + update = _make_update(user_id=77777, text="/auth generate 999") + ctx = _make_context(deps, settings) + + await orch.agentic_auth(update, ctx) + + text = update.message.reply_text.call_args[0][0] + assert "Admin" in text + + async def test_generate_missing_user_id(self, settings, deps): + orch = MessageOrchestrator(settings, deps) + update = _make_update( + user_id=settings.allowed_users[0], + text="/auth generate", + ) + ctx = _make_context(deps, settings) + + await orch.agentic_auth(update, ctx) + + text = update.message.reply_text.call_args[0][0] + assert "Usage" in text + + +# --------------------------------------------------------------------------- +# /auth revoke +# --------------------------------------------------------------------------- + + +class TestAuthRevoke: + async def test_revoke_as_admin_with_active_token( + self, settings, deps, token_provider + ): + """Revoke when the target actually has an active token.""" + await token_provider.generate_token(999) + + orch = MessageOrchestrator(settings, deps) + update = _make_update( + user_id=settings.allowed_users[0], + text="/auth revoke 999", + ) + ctx = _make_context(deps, settings) + + await orch.agentic_auth(update, ctx) + + text = update.message.reply_text.call_args[0][0] + assert "revoked" in text.lower() + assert "no active token" not in text.lower() + + # Verify token is actually revoked + assert await token_provider.authenticate(999, {"token": "any"}) is False + + async def test_revoke_user_without_token(self, settings, deps, token_provider): + """Revoke for a user who has no token should report it clearly.""" + orch = MessageOrchestrator(settings, deps) + update = _make_update( + user_id=settings.allowed_users[0], + text="/auth revoke 12345", # Never had a token + ) + ctx = _make_context(deps, settings) + + await orch.agentic_auth(update, ctx) + + text = update.message.reply_text.call_args[0][0] + assert "no active token" in text.lower() + # Must NOT claim it was revoked + assert "revoked" not in text.lower() + + async def test_revoke_rejected_for_non_admin(self, settings, deps): + orch = MessageOrchestrator(settings, deps) + update = _make_update(user_id=77777, text="/auth revoke 999") + ctx = _make_context(deps, settings) + + await orch.agentic_auth(update, ctx) + + text = update.message.reply_text.call_args[0][0] + assert "Admin" in text + + +# --------------------------------------------------------------------------- +# /auth (passthrough — middleware handles actual auth) +# --------------------------------------------------------------------------- + + +class TestAuthToken: + async def test_token_user_gets_token_specific_message(self, settings, deps): + """External user authenticated via token sees a token-specific message.""" + orch = MessageOrchestrator(settings, deps) + update = _make_update(user_id=99999, text="/auth some_random_token_value") + + session = MagicMock() + session.auth_provider = "TokenAuthProvider" + deps["auth_manager"].get_session.return_value = session + + ctx = _make_context(deps, settings) + await orch.agentic_auth(update, ctx) + + text = update.message.reply_text.call_args[0][0] + assert "token" in text.lower() + assert "Welcome" in text or "Authenticated" in text + + async def test_admin_garbage_shows_help_not_success(self, settings, deps): + """Admin typing an unknown subcommand sees help, not a misleading success.""" + orch = MessageOrchestrator(settings, deps) + update = _make_update( + user_id=settings.allowed_users[0], + text="/auth xyz 123", # "xyz" is not a known subcommand + ) + + session = MagicMock() + session.auth_provider = "WhitelistAuthProvider" + deps["auth_manager"].get_session.return_value = session + + ctx = _make_context(deps, settings) + await orch.agentic_auth(update, ctx) + + text = update.message.reply_text.call_args[0][0] + # Should NOT claim success + assert "Authenticated successfully" not in text + # Should hint at known subcommands — primary ones (add/remove) always listed + assert "add" in text + assert "remove" in text + # Token commands listed because token_provider is enabled in this test + assert "generate" in text + assert "revoke" in text + + +# --------------------------------------------------------------------------- +# Token auth disabled +# --------------------------------------------------------------------------- + + +class TestAuthDisabled: + async def test_no_provider_for_generate(self, settings, deps): + """Generate should fail clearly when token auth is disabled.""" + deps["token_auth_provider"] = None + orch = MessageOrchestrator(settings, deps) + update = _make_update( + user_id=settings.allowed_users[0], text="/auth generate 123" + ) + ctx = _make_context(deps, settings) + + await orch.agentic_auth(update, ctx) + + text = update.message.reply_text.call_args[0][0] + assert "not enabled" in text.lower() + + async def test_status_still_works_without_token_auth(self, settings, deps): + """Status must work even if token auth is disabled.""" + deps["token_auth_provider"] = None + deps["auth_manager"].is_authenticated.return_value = False + orch = MessageOrchestrator(settings, deps) + update = _make_update(text="/auth status") + ctx = _make_context(deps, settings) + + await orch.agentic_auth(update, ctx) + + text = update.message.reply_text.call_args[0][0] + # Should be status message, not "not enabled" + assert "not enabled" not in text.lower() + assert "Not authenticated" in text + + +# --------------------------------------------------------------------------- +# /auth add / /auth remove +# --------------------------------------------------------------------------- + + +class TestAuthAdd: + async def test_add_as_admin(self, settings, deps): + storage = MagicMock() + user = MagicMock() + user.is_allowed = False + storage.get_or_create_user = AsyncMock(return_value=user) + storage.users.set_user_allowed = AsyncMock() + deps["storage"] = storage + + orch = MessageOrchestrator(settings, deps) + update = _make_update(user_id=settings.allowed_users[0], text="/auth add 555") + ctx = _make_context(deps, settings) + + await orch.agentic_auth(update, ctx) + + storage.get_or_create_user.assert_awaited_once_with(555) + storage.users.set_user_allowed.assert_awaited_once_with(555, True) + text = update.message.reply_text.call_args[0][0] + assert "added" in text.lower() or "allow" in text.lower() + + async def test_add_already_allowed(self, settings, deps): + """Allowing an already-allowed user should say so, not double-add.""" + storage = MagicMock() + user = MagicMock() + user.is_allowed = True + storage.get_or_create_user = AsyncMock(return_value=user) + storage.users.set_user_allowed = AsyncMock() + deps["storage"] = storage + + orch = MessageOrchestrator(settings, deps) + update = _make_update(user_id=settings.allowed_users[0], text="/auth add 555") + ctx = _make_context(deps, settings) + + await orch.agentic_auth(update, ctx) + + storage.users.set_user_allowed.assert_not_called() + text = update.message.reply_text.call_args[0][0] + assert "already" in text.lower() + + async def test_add_rejected_for_non_admin(self, settings, deps): + orch = MessageOrchestrator(settings, deps) + update = _make_update(user_id=77777, text="/auth add 555") + ctx = _make_context(deps, settings) + + await orch.agentic_auth(update, ctx) + + text = update.message.reply_text.call_args[0][0] + assert "Admin" in text + + async def test_add_missing_user_id(self, settings, deps): + orch = MessageOrchestrator(settings, deps) + update = _make_update(user_id=settings.allowed_users[0], text="/auth add") + ctx = _make_context(deps, settings) + + await orch.agentic_auth(update, ctx) + + text = update.message.reply_text.call_args[0][0] + assert "Usage" in text + + async def test_add_works_without_token_auth(self, settings, deps): + """allow/deny must work even when token auth is disabled.""" + deps["token_auth_provider"] = None + storage = MagicMock() + user = MagicMock() + user.is_allowed = False + storage.get_or_create_user = AsyncMock(return_value=user) + storage.users.set_user_allowed = AsyncMock() + deps["storage"] = storage + + orch = MessageOrchestrator(settings, deps) + update = _make_update(user_id=settings.allowed_users[0], text="/auth add 555") + ctx = _make_context(deps, settings) + + await orch.agentic_auth(update, ctx) + + storage.users.set_user_allowed.assert_awaited_once_with(555, True) + + +class TestAuthRemove: + async def test_remove_as_admin(self, settings, deps): + storage = MagicMock() + user = MagicMock() + user.is_allowed = True + storage.users.get_user = AsyncMock(return_value=user) + storage.users.set_user_allowed = AsyncMock() + deps["storage"] = storage + + orch = MessageOrchestrator(settings, deps) + update = _make_update( + user_id=settings.allowed_users[0], text="/auth remove 555" + ) + ctx = _make_context(deps, settings) + + await orch.agentic_auth(update, ctx) + + storage.users.set_user_allowed.assert_awaited_once_with(555, False) + text = update.message.reply_text.call_args[0][0] + assert "removed" in text.lower() + + async def test_remove_user_not_in_allowlist(self, settings, deps): + storage = MagicMock() + user = MagicMock() + user.is_allowed = False + storage.users.get_user = AsyncMock(return_value=user) + storage.users.set_user_allowed = AsyncMock() + deps["storage"] = storage + + orch = MessageOrchestrator(settings, deps) + update = _make_update( + user_id=settings.allowed_users[0], text="/auth remove 555" + ) + ctx = _make_context(deps, settings) + + await orch.agentic_auth(update, ctx) + + storage.users.set_user_allowed.assert_not_called() + text = update.message.reply_text.call_args[0][0] + assert "not in" in text.lower() or "nothing" in text.lower() + + async def test_remove_unknown_user(self, settings, deps): + """Deny for a user that doesn't exist in DB at all.""" + storage = MagicMock() + storage.users.get_user = AsyncMock(return_value=None) + storage.users.set_user_allowed = AsyncMock() + deps["storage"] = storage + + orch = MessageOrchestrator(settings, deps) + update = _make_update( + user_id=settings.allowed_users[0], text="/auth remove 999" + ) + ctx = _make_context(deps, settings) + + await orch.agentic_auth(update, ctx) + + storage.users.set_user_allowed.assert_not_called() + text = update.message.reply_text.call_args[0][0] + assert "not in" in text.lower() or "nothing" in text.lower() + + async def test_remove_ends_session(self, settings, deps): + """Denying a user must end their active session.""" + storage = MagicMock() + user = MagicMock() + user.is_allowed = True + storage.users.get_user = AsyncMock(return_value=user) + storage.users.set_user_allowed = AsyncMock() + deps["storage"] = storage + + orch = MessageOrchestrator(settings, deps) + update = _make_update( + user_id=settings.allowed_users[0], text="/auth remove 555" + ) + ctx = _make_context(deps, settings) + + await orch.agentic_auth(update, ctx) + + deps["auth_manager"].end_session.assert_called_once_with(555) + + async def test_remove_rejected_for_non_admin(self, settings, deps): + orch = MessageOrchestrator(settings, deps) + update = _make_update(user_id=77777, text="/auth remove 555") + ctx = _make_context(deps, settings) + + await orch.agentic_auth(update, ctx) + + text = update.message.reply_text.call_args[0][0] + assert "Admin" in text diff --git a/tests/unit/test_bot/test_auth_middleware_welcome.py b/tests/unit/test_bot/test_auth_middleware_welcome.py new file mode 100644 index 00000000..511b6b8d --- /dev/null +++ b/tests/unit/test_bot/test_auth_middleware_welcome.py @@ -0,0 +1,112 @@ +"""Regression tests for auth middleware's welcome message behavior. + +When a user runs an /auth command while not authenticated, the middleware +authenticates them (possibly via whitelist or token). Historically it then +sent a "Welcome! You are now authenticated." message, which collides with +the command handler's own response (e.g. "Token revoked for user 123"). + +The fix: suppress the welcome message for /auth commands so the handler's +operation-specific response stands alone. +""" + +from unittest.mock import AsyncMock, MagicMock + +from src.bot.middleware.auth import auth_middleware + + +def _make_event(text: str, user_id: int = 123) -> MagicMock: + event = MagicMock() + event.effective_user = MagicMock() + event.effective_user.id = user_id + event.effective_user.username = "tester" + event.effective_message = MagicMock() + event.effective_message.text = text + event.effective_message.reply_text = AsyncMock() + return event + + +def _make_data(authenticated: bool = False) -> dict: + auth_manager = MagicMock() + auth_manager.is_authenticated.return_value = authenticated + auth_manager.authenticate_user = AsyncMock(return_value=True) + session = MagicMock() + session.auth_provider = "WhitelistAuthProvider" + auth_manager.get_session.return_value = session + return {"auth_manager": auth_manager, "audit_logger": None} + + +async def _noop_handler(event, data): + return None + + +class TestWelcomeSuppression: + async def test_auth_command_suppresses_welcome(self): + """/auth ... should NOT trigger the 'Welcome!' message.""" + event = _make_event("/auth revoke 123") + data = _make_data(authenticated=False) + + await auth_middleware(_noop_handler, event, data) + + # reply_text should NOT have been called with "Welcome!" + welcome_calls = [ + c + for c in event.effective_message.reply_text.call_args_list + if "Welcome" in str(c) + ] + assert welcome_calls == [] + + async def test_auth_bare_command_suppresses_welcome(self): + """Bare /auth should also suppress welcome.""" + event = _make_event("/auth") + data = _make_data(authenticated=False) + + await auth_middleware(_noop_handler, event, data) + + welcome_calls = [ + c + for c in event.effective_message.reply_text.call_args_list + if "Welcome" in str(c) + ] + assert welcome_calls == [] + + async def test_auth_token_subcommand_suppresses_welcome(self): + """/auth should suppress welcome — handler will respond.""" + event = _make_event("/auth some_token_here") + data = _make_data(authenticated=False) + + await auth_middleware(_noop_handler, event, data) + + welcome_calls = [ + c + for c in event.effective_message.reply_text.call_args_list + if "Welcome" in str(c) + ] + assert welcome_calls == [] + + async def test_regular_message_still_shows_welcome(self): + """Non-/auth messages still get the 'Welcome!' on first auth.""" + event = _make_event("hello") + data = _make_data(authenticated=False) + + await auth_middleware(_noop_handler, event, data) + + welcome_calls = [ + c + for c in event.effective_message.reply_text.call_args_list + if "Welcome" in str(c) + ] + assert len(welcome_calls) == 1 + + async def test_auth_like_prefix_still_shows_welcome(self): + """A message like '/authorize' (not an /auth command) still gets welcome.""" + event = _make_event("/authorize me") + data = _make_data(authenticated=False) + + await auth_middleware(_noop_handler, event, data) + + welcome_calls = [ + c + for c in event.effective_message.reply_text.call_args_list + if "Welcome" in str(c) + ] + assert len(welcome_calls) == 1 diff --git a/tests/unit/test_orchestrator.py b/tests/unit/test_orchestrator.py index ce5e419e..2dcc1af6 100644 --- a/tests/unit/test_orchestrator.py +++ b/tests/unit/test_orchestrator.py @@ -20,12 +20,20 @@ def tmp_dir(): @pytest.fixture def agentic_settings(tmp_dir): - return create_test_config(approved_directory=str(tmp_dir), agentic_mode=True) + return create_test_config( + approved_directory=str(tmp_dir), + agentic_mode=True, + enable_token_auth=False, + ) @pytest.fixture def classic_settings(tmp_dir): - return create_test_config(approved_directory=str(tmp_dir), agentic_mode=False) + return create_test_config( + approved_directory=str(tmp_dir), + agentic_mode=False, + enable_token_auth=False, + ) @pytest.fixture @@ -82,8 +90,8 @@ def deps(): } -def test_agentic_registers_6_commands(agentic_settings, deps): - """Agentic mode registers start, new, status, verbose, repo, restart commands.""" +def test_agentic_registers_commands(agentic_settings, deps): + """Agentic mode registers the expected command handlers.""" orchestrator = MessageOrchestrator(agentic_settings, deps) app = MagicMock() app.add_handler = MagicMock() @@ -100,17 +108,19 @@ def test_agentic_registers_6_commands(agentic_settings, deps): ] commands = [h[0][0].commands for h in cmd_handlers] - assert len(cmd_handlers) == 6 assert frozenset({"start"}) in commands assert frozenset({"new"}) in commands assert frozenset({"status"}) in commands assert frozenset({"verbose"}) in commands assert frozenset({"repo"}) in commands assert frozenset({"restart"}) in commands + # /auth is registered because test config has allowed_users + assert frozenset({"auth"}) in commands + assert len(cmd_handlers) == 7 -def test_classic_registers_14_commands(classic_settings, deps): - """Classic mode registers all 14 commands.""" +def test_classic_registers_commands(classic_settings, deps): + """Classic mode registers the expected command handlers.""" orchestrator = MessageOrchestrator(classic_settings, deps) app = MagicMock() app.add_handler = MagicMock() @@ -125,7 +135,8 @@ def test_classic_registers_14_commands(classic_settings, deps): if isinstance(call[0][0], CommandHandler) ] - assert len(cmd_handlers) == 14 + # 14 classic commands + /auth (test config has allowed_users) + assert len(cmd_handlers) == 15 def test_agentic_registers_text_document_photo_handlers(agentic_settings, deps): @@ -156,21 +167,27 @@ def test_agentic_registers_text_document_photo_handlers(agentic_settings, deps): async def test_agentic_bot_commands(agentic_settings, deps): - """Agentic mode returns 6 bot commands.""" + """Agentic mode returns bot commands including /auth (allowed_users set).""" orchestrator = MessageOrchestrator(agentic_settings, deps) commands = await orchestrator.get_bot_commands() - assert len(commands) == 6 cmd_names = [c.command for c in commands] - assert cmd_names == ["start", "new", "status", "verbose", "repo", "restart"] + assert "start" in cmd_names + assert "new" in cmd_names + assert "status" in cmd_names + assert "verbose" in cmd_names + assert "repo" in cmd_names + assert "restart" in cmd_names + assert "auth" in cmd_names + assert len(commands) == 7 async def test_classic_bot_commands(classic_settings, deps): - """Classic mode returns 14 bot commands.""" + """Classic mode returns bot commands including /auth (allowed_users set).""" orchestrator = MessageOrchestrator(classic_settings, deps) commands = await orchestrator.get_bot_commands() - assert len(commands) == 14 + assert len(commands) == 15 cmd_names = [c.command for c in commands] assert "start" in cmd_names assert "help" in cmd_names diff --git a/tests/unit/test_security/test_db_allowlist.py b/tests/unit/test_security/test_db_allowlist.py new file mode 100644 index 00000000..2bf99154 --- /dev/null +++ b/tests/unit/test_security/test_db_allowlist.py @@ -0,0 +1,121 @@ +"""Tests for DatabaseAllowlistAuthProvider.""" + +import tempfile +from datetime import UTC, datetime + +import pytest + +from src.security.auth import ( + AuthenticationManager, + DatabaseAllowlistAuthProvider, + WhitelistAuthProvider, +) +from src.storage.database import DatabaseManager +from src.storage.models import UserModel +from src.storage.repositories import UserRepository + + +@pytest.fixture +async def db_manager(): + with tempfile.NamedTemporaryFile(suffix=".db") as f: + mgr = DatabaseManager(f"sqlite:///{f.name}") + await mgr.initialize() + yield mgr + await mgr.close() + + +@pytest.fixture +async def user_repo(db_manager): + return UserRepository(db_manager) + + +@pytest.fixture +async def provider(user_repo): + return DatabaseAllowlistAuthProvider(user_repo) + + +async def _create_user( + user_repo: UserRepository, user_id: int, is_allowed: bool +) -> None: + now = datetime.now(UTC) + await user_repo.create_user( + UserModel( + user_id=user_id, + first_seen=now, + last_active=now, + is_allowed=is_allowed, + ) + ) + + +class TestDatabaseAllowlistAuthProvider: + async def test_allowed_user_authenticates(self, provider, user_repo): + await _create_user(user_repo, 42, is_allowed=True) + assert await provider.authenticate(42, {}) is True + + async def test_not_allowed_user_rejected(self, provider, user_repo): + await _create_user(user_repo, 42, is_allowed=False) + assert await provider.authenticate(42, {}) is False + + async def test_unknown_user_rejected(self, provider): + assert await provider.authenticate(999, {}) is False + + async def test_get_user_info_for_allowed(self, provider, user_repo): + await _create_user(user_repo, 42, is_allowed=True) + info = await provider.get_user_info(42) + assert info is not None + assert info["user_id"] == 42 + assert info["auth_type"] == "db_allowlist" + + async def test_get_user_info_for_not_allowed(self, provider, user_repo): + await _create_user(user_repo, 42, is_allowed=False) + assert await provider.get_user_info(42) is None + + async def test_credentials_ignored(self, provider, user_repo): + """Allowlist doesn't care about credentials — just checks the flag.""" + await _create_user(user_repo, 42, is_allowed=True) + # Any credentials should work (or no credentials) + assert await provider.authenticate(42, {"token": "garbage"}) is True + assert await provider.authenticate(42, {}) is True + + +class TestAllowlistIntegration: + async def test_whitelist_takes_priority(self, user_repo): + """When both providers accept, the first one in the chain wins.""" + # User is in both env whitelist AND DB allowlist + await _create_user(user_repo, 1, is_allowed=True) + + mgr = AuthenticationManager( + [ + WhitelistAuthProvider([1]), + DatabaseAllowlistAuthProvider(user_repo), + ] + ) + + assert await mgr.authenticate_user(1, {}) is True + session = mgr.get_session(1) + assert session.auth_provider == "WhitelistAuthProvider" + + async def test_db_allowlist_kicks_in_when_whitelist_fails(self, user_repo): + """User not in env whitelist but in DB allowlist — should be authed via DB.""" + await _create_user(user_repo, 99, is_allowed=True) + + mgr = AuthenticationManager( + [ + WhitelistAuthProvider([1]), # 99 not in env whitelist + DatabaseAllowlistAuthProvider(user_repo), + ] + ) + + assert await mgr.authenticate_user(99, {}) is True + session = mgr.get_session(99) + assert session.auth_provider == "DatabaseAllowlistAuthProvider" + + async def test_both_fail_for_unknown_user(self, user_repo): + mgr = AuthenticationManager( + [ + WhitelistAuthProvider([1]), + DatabaseAllowlistAuthProvider(user_repo), + ] + ) + assert await mgr.authenticate_user(999, {}) is False diff --git a/tests/unit/test_security/test_sqlite_token_storage.py b/tests/unit/test_security/test_sqlite_token_storage.py new file mode 100644 index 00000000..154e0057 --- /dev/null +++ b/tests/unit/test_security/test_sqlite_token_storage.py @@ -0,0 +1,251 @@ +"""Tests for SqliteTokenStorage and TokenRepository.""" + +import tempfile +from datetime import UTC, datetime, timedelta + +import pytest + +from src.security.auth import SqliteTokenStorage, TokenAuthProvider +from src.storage.database import DatabaseManager +from src.storage.repositories import TokenRepository + + +@pytest.fixture +async def db_manager(): + """Create a fresh database with schema. + + Users are *not* pre-created — the fix for the FK constraint issue + is that ``store_token`` auto-creates a user stub when needed, so + tests should verify that path works. + """ + with tempfile.NamedTemporaryFile(suffix=".db") as f: + mgr = DatabaseManager(f"sqlite:///{f.name}") + await mgr.initialize() + yield mgr + await mgr.close() + + +@pytest.fixture +async def repo(db_manager): + return TokenRepository(db_manager) + + +@pytest.fixture +async def storage(repo): + return SqliteTokenStorage(repo) + + +# --------------------------------------------------------------------------- +# TokenRepository +# --------------------------------------------------------------------------- + + +class TestTokenRepository: + async def test_store_and_get(self, repo): + expires = datetime.now(UTC) + timedelta(days=30) + await repo.store_token(42, "hash_abc", expires) + + model = await repo.get_active_token(42) + assert model is not None + assert model.user_id == 42 + assert model.token_hash == "hash_abc" + assert model.is_active + + async def test_get_missing_user(self, repo): + assert await repo.get_active_token(999) is None + + async def test_revoke(self, repo): + expires = datetime.now(UTC) + timedelta(days=30) + await repo.store_token(42, "hash_abc", expires) + await repo.revoke_token(42) + + assert await repo.get_active_token(42) is None + + async def test_expired_token_not_returned(self, repo): + past = datetime.now(UTC) - timedelta(seconds=1) + await repo.store_token(42, "hash_old", past) + + assert await repo.get_active_token(42) is None + + async def test_store_replaces_old_token(self, repo): + expires = datetime.now(UTC) + timedelta(days=30) + await repo.store_token(42, "hash_v1", expires) + await repo.store_token(42, "hash_v2", expires) + + model = await repo.get_active_token(42) + assert model is not None + assert model.token_hash == "hash_v2" + + async def test_update_last_used(self, repo): + expires = datetime.now(UTC) + timedelta(days=30) + await repo.store_token(42, "hash_abc", expires) + await repo.update_last_used(42) + + model = await repo.get_active_token(42) + assert model is not None + assert model.last_used is not None + + +# --------------------------------------------------------------------------- +# SqliteTokenStorage (adapter layer) +# --------------------------------------------------------------------------- + + +class TestSqliteTokenStorage: + async def test_store_and_get(self, storage): + expires = datetime.now(UTC) + timedelta(days=7) + await storage.store_token(1, "hash_x", expires) + + data = await storage.get_user_token(1) + assert data is not None + assert data["hash"] == "hash_x" + assert "expires_at" in data + assert "created_at" in data + + async def test_get_missing(self, storage): + assert await storage.get_user_token(999) is None + + async def test_revoke(self, storage): + expires = datetime.now(UTC) + timedelta(days=7) + await storage.store_token(1, "hash_x", expires) + await storage.revoke_token(1) + + assert await storage.get_user_token(1) is None + + +# --------------------------------------------------------------------------- +# End-to-end: TokenAuthProvider + SqliteTokenStorage +# --------------------------------------------------------------------------- + + +class TestTokenAuthE2E: + async def test_generate_and_authenticate(self, storage): + provider = TokenAuthProvider("secret", storage) + token = await provider.generate_token(100) + + assert await provider.authenticate(100, {"token": token}) is True + assert await provider.authenticate(100, {"token": "wrong"}) is False + + async def test_generate_for_new_user_without_users_row(self, db_manager, storage): + """Regression: admin can generate a token for a user who has never + interacted with the bot (i.e. has no ``users`` row yet). + """ + provider = TokenAuthProvider("secret", storage) + # User 7777 is NOT in the users table. + async with db_manager.get_connection() as conn: + cursor = await conn.execute( + "SELECT 1 FROM users WHERE user_id = ?", (7777,) + ) + assert await cursor.fetchone() is None + + # Should not raise a FK violation — the fix auto-creates a stub row. + token = await provider.generate_token(7777) + assert token + + # And authentication should work end-to-end. + assert await provider.authenticate(7777, {"token": token}) is True + + # User stub should now exist. + async with db_manager.get_connection() as conn: + cursor = await conn.execute( + "SELECT 1 FROM users WHERE user_id = ?", (7777,) + ) + assert await cursor.fetchone() is not None + + async def test_regenerate_same_user(self, storage): + """Regression: admin can re-generate a token for the same user.""" + provider = TokenAuthProvider("secret", storage) + token1 = await provider.generate_token(42) + token2 = await provider.generate_token(42) + + # Old token rejected, new token accepted. + assert token1 != token2 + assert await provider.authenticate(42, {"token": token1}) is False + assert await provider.authenticate(42, {"token": token2}) is True + + async def test_revoke_prevents_auth(self, storage): + provider = TokenAuthProvider("secret", storage) + token = await provider.generate_token(100) + + await provider.revoke_token(100) + + assert await provider.authenticate(100, {"token": token}) is False + + async def test_empty_credentials_accepted_when_active_token_exists(self, storage): + """After first auth, user_id alone suffices while token is active.""" + provider = TokenAuthProvider("secret", storage) + await provider.generate_token(100) + + # No credentials at all — should still succeed because user_id has + # an active stored token. + assert await provider.authenticate(100, {}) is True + + async def test_empty_credentials_rejected_when_no_stored_token(self, storage): + """User without a stored token can't slip through with empty creds.""" + provider = TokenAuthProvider("secret", storage) + assert await provider.authenticate(999, {}) is False + + async def test_wrong_raw_token_still_fails(self, storage): + """Empty-creds path must NOT be reachable by sending a wrong token.""" + provider = TokenAuthProvider("secret", storage) + await provider.generate_token(100) + + # Wrong token → hard fail, do NOT silently fall back to "active + # stored token" path. + assert await provider.authenticate(100, {"token": "wrong"}) is False + + async def test_successful_auth_slides_expiration(self, storage, db_manager): + """Every successful authenticate() pushes expires_at forward.""" + from datetime import timedelta + + provider = TokenAuthProvider( + "secret", storage, token_lifetime=timedelta(days=30) + ) + token = await provider.generate_token(100) + + # Manually move expires_at close to expiration + soon = datetime.now(UTC) + timedelta(days=1) + async with db_manager.get_connection() as conn: + await conn.execute( + "UPDATE user_tokens SET expires_at = ? WHERE user_id = ?", + (soon, 100), + ) + await conn.commit() + + # Authenticate — should slide expires_at back to ~30 days out + assert await provider.authenticate(100, {"token": token}) is True + + async with db_manager.get_connection() as conn: + cursor = await conn.execute( + "SELECT expires_at FROM user_tokens WHERE user_id = ? " + "AND is_active = 1", + (100,), + ) + row = await cursor.fetchone() + + new_expires = row[0] + # Should be ~30 days away, definitely more than the 1-day window + assert new_expires - datetime.now(UTC) > timedelta(days=29) + + async def test_touch_updates_last_used(self, storage, db_manager): + """touch_token bumps last_used.""" + provider = TokenAuthProvider("secret", storage) + await provider.generate_token(100) + + # last_used should be None initially + async with db_manager.get_connection() as conn: + cursor = await conn.execute( + "SELECT last_used FROM user_tokens WHERE user_id = ?", (100,) + ) + assert (await cursor.fetchone())[0] is None + + # Successful auth should set it + assert await provider.authenticate(100, {}) is True + + async with db_manager.get_connection() as conn: + cursor = await conn.execute( + "SELECT last_used FROM user_tokens WHERE user_id = ? " + "AND is_active = 1", + (100,), + ) + assert (await cursor.fetchone())[0] is not None diff --git a/tests/unit/test_security/test_token_auth_flow.py b/tests/unit/test_security/test_token_auth_flow.py new file mode 100644 index 00000000..0f226c86 --- /dev/null +++ b/tests/unit/test_security/test_token_auth_flow.py @@ -0,0 +1,156 @@ +"""End-to-end integration tests for token auth flow. + +Simulates the exact sequence a user would experience: + 1. Admin (whitelisted) runs ``/auth generate ``. + 2. Target user (NOT whitelisted, NOT in users table) sends + ``/auth `` — goes through ``auth_middleware``. + 3. Target user sends a follow-up message — should be recognised as + authenticated without re-presenting the token. + 4. Admin re-runs ``/auth generate `` — should succeed + (regression against an earlier FK-constraint bug). +""" + +import tempfile +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from src.bot.middleware.auth import auth_middleware +from src.security.auth import ( + AuthenticationManager, + SqliteTokenStorage, + TokenAuthProvider, + WhitelistAuthProvider, +) +from src.storage.database import DatabaseManager +from src.storage.repositories import TokenRepository + + +@pytest.fixture +async def db_manager(): + with tempfile.NamedTemporaryFile(suffix=".db") as f: + mgr = DatabaseManager(f"sqlite:///{f.name}") + await mgr.initialize() + # Only the admin exists up front — target user is brand-new. + async with mgr.get_connection() as conn: + await conn.execute( + "INSERT OR IGNORE INTO users (user_id) VALUES (?)", (405,) + ) + await conn.commit() + yield mgr + await mgr.close() + + +@pytest.fixture +async def auth_setup(db_manager): + """Build a realistic auth_manager with both providers.""" + repo = TokenRepository(db_manager) + storage = SqliteTokenStorage(repo) + token_provider = TokenAuthProvider("test_secret", storage) + auth_manager = AuthenticationManager([WhitelistAuthProvider([405]), token_provider]) + return auth_manager, token_provider + + +def _make_event(user_id: int, text: str) -> MagicMock: + event = MagicMock() + event.effective_user.id = user_id + event.effective_user.username = "test" + event.effective_user.is_bot = False + event.effective_message.text = text + event.effective_message.reply_text = AsyncMock() + return event + + +# --------------------------------------------------------------------------- +# End-to-end flow +# --------------------------------------------------------------------------- + + +class TestTokenAuthEndToEnd: + async def test_full_flow_new_user(self, auth_setup): + """The scenario the user reported: brand-new target, full round trip.""" + auth_manager, token_provider = auth_setup + + # 1. Admin generates a token for user 123 (who has never touched the bot). + token = await token_provider.generate_token(123) + assert token + + # 2. User 123 sends ``/auth `` through the middleware. + data = {"auth_manager": auth_manager, "audit_logger": None} + handler_called = [False] + + async def handler(event, data): + handler_called[0] = True + + auth_event = _make_event(123, f"/auth {token}") + await auth_middleware(handler, auth_event, data) + + assert handler_called[0], "Middleware should have let /auth through" + assert auth_manager.is_authenticated(123), "Session must be created" + + # 3. User 123 sends a plain follow-up message. It must pass auth. + handler_called[0] = False + followup = _make_event(123, "hello world") + await auth_middleware(handler, followup, data) + + assert handler_called[ + 0 + ], "Authenticated user should be allowed through on subsequent messages" + assert auth_manager.is_authenticated(123) + + # 4. Admin regenerates the token for the same user — no error. + token2 = await token_provider.generate_token(123) + assert token2 != token + # Old token no longer authenticates. + assert await token_provider.authenticate(123, {"token": token}) is False + # New token does. + assert await token_provider.authenticate(123, {"token": token2}) is True + + async def test_wrong_token_rejects(self, auth_setup): + auth_manager, token_provider = auth_setup + await token_provider.generate_token(123) + + data = {"auth_manager": auth_manager, "audit_logger": None} + handler_called = [False] + + async def handler(event, data): + handler_called[0] = True + + event = _make_event(123, "/auth totally-wrong-token") + await auth_middleware(handler, event, data) + + assert not handler_called[0] + assert not auth_manager.is_authenticated(123) + + async def test_subcommand_not_treated_as_token(self, auth_setup): + """``/auth generate 999`` from admin must not be passed as a token.""" + auth_manager, token_provider = auth_setup + + data = {"auth_manager": auth_manager, "audit_logger": None} + handler_called = [False] + + async def handler(event, data): + handler_called[0] = True + + # Admin is whitelisted, so middleware should let them through. + event = _make_event(405, "/auth generate 999") + await auth_middleware(handler, event, data) + + assert handler_called[0] + assert auth_manager.is_authenticated(405) + + async def test_bare_auth_no_token_rejects_new_user(self, auth_setup): + """``/auth`` alone (no token) from a new user must not authenticate.""" + auth_manager, _ = auth_setup + + data = {"auth_manager": auth_manager, "audit_logger": None} + handler_called = [False] + + async def handler(event, data): + handler_called[0] = True + + event = _make_event(999, "/auth") + await auth_middleware(handler, event, data) + + assert not handler_called[0] + assert not auth_manager.is_authenticated(999)