Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 27 additions & 3 deletions src/bot/middleware/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,25 @@ async def auth_middleware(handler: Callable, event: Any, data: Dict[str, Any]) -
"Attempting authentication for user", user_id=user_id, username=username
)

# If the message is an /auth <token> command, pass the token as
# credentials so that TokenAuthProvider can verify it.
credentials: Dict[str, Any] = {}
msg_text = getattr(event.effective_message, "text", "") or ""
if msg_text.startswith("/auth "):
parts = msg_text.split(maxsplit=1)
if len(parts) == 2:
candidate = parts[1].strip()
# Only treat as a raw token if it doesn't look like a
# sub-command (generate, revoke, status, add, remove).
if candidate and not candidate.startswith(
("generate", "revoke", "status", "add", "remove")
):
credentials = {"token": candidate}

# Try to authenticate (providers will check whitelist and tokens)
authentication_successful = await auth_manager.authenticate_user(user_id)
authentication_successful = await auth_manager.authenticate_user(
user_id, credentials
)

# Log authentication attempt
if audit_logger:
Expand All @@ -82,8 +99,15 @@ async def auth_middleware(handler: Callable, event: Any, data: Dict[str, Any]) -
auth_provider=session.auth_provider if session else None,
)

# Welcome message for new session
if event.effective_message:
# Welcome message for new session — but skip it when the
# message is an /auth command: the command handler will send
# its own, operation-specific response (e.g. "Token revoked
# for user 123"). Otherwise the admin sees a confusing mix
# of "Welcome!" + the actual result.
is_auth_command = msg_text.startswith("/auth") and (
msg_text == "/auth" or msg_text.startswith("/auth ")
)
if event.effective_message and not is_auth_command:
await event.effective_message.reply_text(
f"🔓 Welcome! You are now authenticated.\n"
f"Session started at {datetime.now(UTC).strftime('%H:%M:%S UTC')}"
Expand Down
303 changes: 303 additions & 0 deletions src/bot/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,8 @@ def _register_agentic_handlers(self, app: Application) -> None:
]
if self.settings.enable_project_threads:
handlers.append(("sync_threads", command.sync_threads))
if self.settings.enable_token_auth or self.settings.allowed_users:
handlers.append(("auth", self.agentic_auth))

# Derive known commands dynamically — avoids drift when new commands are added
self._known_commands: frozenset[str] = frozenset(cmd for cmd, _ in handlers)
Expand Down Expand Up @@ -420,6 +422,8 @@ def _register_classic_handlers(self, app: Application) -> None:
]
if self.settings.enable_project_threads:
handlers.append(("sync_threads", command.sync_threads))
if self.settings.enable_token_auth or self.settings.allowed_users:
handlers.append(("auth", self.agentic_auth))

for cmd, handler in handlers:
app.add_handler(CommandHandler(cmd, self._inject_deps(handler)))
Expand Down Expand Up @@ -464,6 +468,8 @@ async def get_bot_commands(self) -> list: # type: ignore[type-arg]
]
if self.settings.enable_project_threads:
commands.append(BotCommand("sync_threads", "Sync project topics"))
if self.settings.enable_token_auth or self.settings.allowed_users:
commands.append(BotCommand("auth", "Authentication management"))
return commands
else:
commands = [
Expand All @@ -484,6 +490,8 @@ async def get_bot_commands(self) -> list: # type: ignore[type-arg]
]
if self.settings.enable_project_threads:
commands.append(BotCommand("sync_threads", "Sync project topics"))
if self.settings.enable_token_auth or self.settings.allowed_users:
commands.append(BotCommand("auth", "Authentication management"))
return commands

# --- Agentic handlers ---
Expand Down Expand Up @@ -1670,6 +1678,301 @@ async def agentic_repo(
reply_markup=reply_markup,
)

# --- Token auth command ---

def _is_admin(self, user_id: int) -> bool:
"""Return True if *user_id* is in the whitelist (admin)."""
return bool(
self.settings.allowed_users and user_id in self.settings.allowed_users
)

async def agentic_auth(
self, update: Update, context: ContextTypes.DEFAULT_TYPE
) -> None:
"""Handle /auth — authentication management.

/auth <token> — authenticate with a token
/auth status — show own authentication info
/auth add <user_id> — (admin) add user to persistent allowlist
/auth remove <user_id> — (admin) remove user from allowlist
/auth generate <user_id> — (admin) generate token for a user
/auth revoke <user_id> — (admin) revoke a user's token
"""
user_id = update.effective_user.id
args = update.message.text.split()[1:] if update.message.text else []
token_provider = context.bot_data.get("token_auth_provider")
storage = context.bot_data.get("storage")

# /auth or /auth status — always available
if not args or (len(args) == 1 and args[0] == "status"):
auth_manager = context.bot_data.get("auth_manager")
if auth_manager and auth_manager.is_authenticated(user_id):
session = auth_manager.get_session(user_id)
provider = session.auth_provider if session else "unknown"
await update.message.reply_text(
f"\U0001f513 Authenticated via <b>{escape_html(provider)}</b>",
parse_mode="HTML",
)
else:
await update.message.reply_text(
"\U0001f512 Not authenticated.\n"
"Use <code>/auth &lt;token&gt;</code> to log in.",
parse_mode="HTML",
)
return

sub = args[0].lower()

# /auth add <user_id> — admin adds user to DB allowlist
if sub == "add":
if not self._is_admin(user_id):
await update.message.reply_text("\U0001f512 Admin access required.")
return
if len(args) < 2:
await update.message.reply_text(
"Usage: <code>/auth add &lt;user_id&gt;</code>",
parse_mode="HTML",
)
return
try:
target_id = int(args[1])
except ValueError:
await update.message.reply_text("Invalid user ID.")
return
if not storage:
await update.message.reply_text(
"Storage unavailable \u2014 cannot update allowlist."
)
return

# Ensure user row exists, check current state.
user = await storage.get_or_create_user(target_id)
if user.is_allowed:
await update.message.reply_text(
f"\u2139\ufe0f User <code>{target_id}</code> is already "
f"in the allowlist.",
parse_mode="HTML",
)
return

await storage.users.set_user_allowed(target_id, True)
await update.message.reply_text(
f"\u2705 User <code>{target_id}</code> added to the allowlist. "
f"They can now write to the bot.",
parse_mode="HTML",
)

audit_logger = context.bot_data.get("audit_logger")
if audit_logger:
await audit_logger.log_command(
user_id=user_id,
command="auth_add",
args=[str(target_id)],
success=True,
)
return

# /auth remove <user_id> — admin removes user from DB allowlist
if sub == "remove":
if not self._is_admin(user_id):
await update.message.reply_text("\U0001f512 Admin access required.")
return
if len(args) < 2:
await update.message.reply_text(
"Usage: <code>/auth remove &lt;user_id&gt;</code>",
parse_mode="HTML",
)
return
try:
target_id = int(args[1])
except ValueError:
await update.message.reply_text("Invalid user ID.")
return
if not storage:
await update.message.reply_text(
"Storage unavailable \u2014 cannot update allowlist."
)
return

user = await storage.users.get_user(target_id)
if not user or not user.is_allowed:
await update.message.reply_text(
f"\u2139\ufe0f User <code>{target_id}</code> is not in "
f"the allowlist \u2014 nothing to remove.",
parse_mode="HTML",
)
return

await storage.users.set_user_allowed(target_id, False)
# End their session so the change takes effect immediately
auth_manager = context.bot_data.get("auth_manager")
if auth_manager:
auth_manager.end_session(target_id)

await update.message.reply_text(
f"\u2705 User <code>{target_id}</code> removed from the allowlist.",
parse_mode="HTML",
)

audit_logger = context.bot_data.get("audit_logger")
if audit_logger:
await audit_logger.log_command(
user_id=user_id,
command="auth_remove",
args=[str(target_id)],
success=True,
)
return

# All remaining sub-commands require token auth to be enabled.
if not token_provider:
await update.message.reply_text(
"Token authentication is not enabled. "
"Set <code>ENABLE_TOKEN_AUTH=true</code> to use token commands.",
parse_mode="HTML",
)
return

# /auth generate <user_id>
if sub == "generate":
if not self._is_admin(user_id):
await update.message.reply_text("\U0001f512 Admin access required.")
return
if len(args) < 2:
await update.message.reply_text(
"Usage: <code>/auth generate &lt;user_id&gt;</code>",
parse_mode="HTML",
)
return
try:
target_id = int(args[1])
except ValueError:
await update.message.reply_text("Invalid user ID.")
return

token = await token_provider.generate_token(target_id)
# Send the token privately — it's sensitive.
await update.message.reply_text(
f"\U0001f511 Token generated for user "
f"<code>{target_id}</code>:\n\n"
f"<pre>{escape_html(token)}</pre>\n\n"
f"The user should send:\n"
f"<code>/auth {escape_html(token)}</code>\n\n"
f"Expires in {token_provider.token_lifetime.days} days.",
parse_mode="HTML",
)

audit_logger = context.bot_data.get("audit_logger")
if audit_logger:
await audit_logger.log_command(
user_id=user_id,
command="auth_generate",
args=[str(target_id)],
success=True,
)
return

# /auth revoke <user_id>
if sub == "revoke":
if not self._is_admin(user_id):
await update.message.reply_text("\U0001f512 Admin access required.")
return
if len(args) < 2:
await update.message.reply_text(
"Usage: <code>/auth revoke &lt;user_id&gt;</code>",
parse_mode="HTML",
)
return
try:
target_id = int(args[1])
except ValueError:
await update.message.reply_text("Invalid user ID.")
return

# Check whether the target actually has an active token
# before claiming we revoked anything.
stored = await token_provider.storage.get_user_token(target_id)
if stored is None:
await update.message.reply_text(
f"\u2139\ufe0f User <code>{target_id}</code> has no active "
f"token \u2014 nothing to revoke.",
parse_mode="HTML",
)
audit_logger = context.bot_data.get("audit_logger")
if audit_logger:
await audit_logger.log_command(
user_id=user_id,
command="auth_revoke",
args=[str(target_id), "no_active_token"],
success=False,
)
return

await token_provider.revoke_token(target_id)

# Also end the user's session if active
auth_manager = context.bot_data.get("auth_manager")
if auth_manager:
auth_manager.end_session(target_id)

await update.message.reply_text(
f"\u2705 Token revoked for user <code>{target_id}</code>.",
parse_mode="HTML",
)

audit_logger = context.bot_data.get("audit_logger")
if audit_logger:
await audit_logger.log_command(
user_id=user_id,
command="auth_revoke",
args=[str(target_id)],
success=True,
)
return

# /auth <something> — reached when `something` isn't a known
# subcommand. Either the user successfully authenticated via
# token (middleware let them through), or the admin typed a
# typo / unknown subcommand. Disambiguate via session provider.
auth_manager = context.bot_data.get("auth_manager")
session = auth_manager.get_session(user_id) if auth_manager else None
provider_name = session.auth_provider if session else None

if provider_name == "TokenAuthProvider":
# External user presented a valid token — middleware already
# created the session. Confirm it clearly.
await update.message.reply_text(
"\U0001f513 Authenticated via token. Welcome!",
)
elif provider_name == "WhitelistAuthProvider":
# Admin typed something that isn't a known subcommand.
# Show a help hint instead of a misleading success message.
lines = [
"\u2139\ufe0f You're already authenticated as admin \u2014 "
"no token needed for your account.\n",
"Did you mean one of:",
"<code>/auth add &lt;user_id&gt;</code>",
"<code>/auth remove &lt;user_id&gt;</code>",
]
if token_provider:
lines.extend(
[
"<code>/auth generate &lt;user_id&gt;</code>",
"<code>/auth revoke &lt;user_id&gt;</code>",
]
)
lines.append("<code>/auth status</code>")
await update.message.reply_text(
"\n".join(lines),
parse_mode="HTML",
)
else:
# Should not reach here — middleware would have rejected
# an unauthenticated user. Fall back to a safe message.
await update.message.reply_text(
"\U0001f512 Invalid or expired token.",
)

async def _handle_stop_callback(
self, update: Update, context: ContextTypes.DEFAULT_TYPE
) -> None:
Expand Down
Loading