diff --git a/api_schemas/user_schemas.py b/api_schemas/user_schemas.py index 04b8d3d9..810c484d 100644 --- a/api_schemas/user_schemas.py +++ b/api_schemas/user_schemas.py @@ -1,6 +1,6 @@ from typing import TYPE_CHECKING, Annotated, Literal from pydantic import StringConstraints -from fastapi_users_pelicanq import schemas as fastapi_users_schemas +from fastapi_users import schemas as fastapi_users_schemas from api_schemas.post_schemas import PostRead from helpers.constants import MAX_FIRST_NAME_LEN, MAX_LAST_NAME_LEN from api_schemas.base_schema import BaseSchema diff --git a/db_models/user_model.py b/db_models/user_model.py index c21c836f..99a94686 100644 --- a/db_models/user_model.py +++ b/db_models/user_model.py @@ -1,5 +1,5 @@ from typing import TYPE_CHECKING, Callable, Optional -from fastapi_users_pelicanq.db import SQLAlchemyBaseUserTable +from fastapi_users.db import SQLAlchemyBaseUserTable from sqlalchemy import String, JSON from sqlalchemy.orm import Mapped, relationship, mapped_column from db_models.candidate_model import Candidate_DB diff --git a/requirements.txt b/requirements.txt index 77413c9d..3fffdb44 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,8 +15,8 @@ dnspython==2.8.0 email-validator==2.1.0.post1 fakeredis==2.31.0 fastapi==0.115.0 -fastapi-users-db-sqlalchemy-pelicanq==6.0.6 -fastapi-users-pelicanq==13.0.4 +fastapi-users-db-sqlalchemy==7.0.0 +fastapi-users==15.0.5 google-api-core==2.25.1 google-api-python-client==2.177.0 google-auth==2.40.3 @@ -45,19 +45,19 @@ pluggy==1.6.0 proto-plus==1.26.1 protobuf==6.33.5 psycopg==3.1.17 -pwdlib==0.2.0 +pwdlib==0.3.0 pyasn1==0.6.3 pyasn1_modules==0.4.2 pycparser==2.21 pydantic==2.9.2 pydantic-extra-types==2.5.0 pydantic_core==2.23.4 -PyJWT==2.8.0 +PyJWT==2.12.1 pyparsing==3.2.3 pytest==7.4.4 python-dateutil==2.9.0.post0 python-dotenv==1.2.1 -python-multipart==0.0.9 +python-multipart==0.0.22 pytz==2025.2 PyYAML==6.0.1 redis==7.4.0 diff --git a/routes/auth_router.py b/routes/auth_router.py index 5ea7e0de..501d73ef 100644 --- a/routes/auth_router.py +++ b/routes/auth_router.py @@ -2,7 +2,7 @@ from api_schemas.user_schemas import UserCreate, UserRead from helpers.rate_limit import rate_limit from fastapi import APIRouter -from fastapi_users_pelicanq.schemas import BaseUserUpdate +from fastapi_users.schemas import BaseUserUpdate from user.custom_auth_router import get_auth_router, get_update_account_router from user.user_stuff import USERS, auth_backend, refresh_backend diff --git a/routes/permission_router.py b/routes/permission_router.py index e380a4ff..9bc1bb50 100644 --- a/routes/permission_router.py +++ b/routes/permission_router.py @@ -5,6 +5,7 @@ from database import DB_dependency, get_db from db_models.permission_model import Permission_DB from db_models.post_model import Post_DB +from db_models.user_model import User_DB from api_schemas.permission_schemas import ( PermissionCreate, PermissionRead, @@ -31,6 +32,21 @@ def get_all_permissions(db: Annotated[Session, Depends(get_db)]): return res +@permission_router.get("/me", response_model=list[tuple[str, str]] | None) +def get_my_permissions(member: Annotated[User_DB | None, Permission.check_member()]): + if not member: + return None + seen: set[tuple[str, str]] = set() + result: list[tuple[str, str]] = [] + for post in member.posts: + for perm in post.post_permissions: + key: tuple[str, str] = (perm.permission.action, perm.permission.target) + if key not in seen: + seen.add(key) + result.append(key) + return result + + # Create a new permission which later can be assigned to posts @permission_router.post("/", response_model=PermissionRead, dependencies=[Permission.require("manage", "Permission")]) def create_permission(perm_data: PermissionCreate, db: Annotated[Session, Depends(get_db)]): diff --git a/routes/user_router.py b/routes/user_router.py index 4bfd8900..1e0f9324 100644 --- a/routes/user_router.py +++ b/routes/user_router.py @@ -15,7 +15,7 @@ UpdateUserPosts, ) from user.user_stuff import USERS -from fastapi_users_pelicanq.manager import BaseUserManager +from fastapi_users.manager import BaseUserManager from helpers.image_checker import validate_image from helpers.rate_limit import rate_limit from helpers.types import ALLOWED_EXT, ALLOWED_IMG_SIZES, ALLOWED_IMG_TYPES, ASSETS_BASE_PATH diff --git a/tests/basic_fixtures.py b/tests/basic_fixtures.py index 21691249..2ec0eb04 100644 --- a/tests/basic_fixtures.py +++ b/tests/basic_fixtures.py @@ -39,9 +39,8 @@ def registered_users(client, user1_data, user2_data): @pytest.fixture -def admin_post(db_session): - """Create and return an admin post.""" - +def admin_council(db_session): + """Create and return a council for the admin user.""" council = Council_DB( name_sv="AdminCouncilSV", description_sv="Svensk beskrivning för admins", @@ -50,12 +49,18 @@ def admin_post(db_session): ) db_session.add(council) db_session.commit() + return council + + +@pytest.fixture +def admin_post(db_session, admin_council): + """Create and return an admin post.""" post = Post_DB( name_sv="AdminPostSV", name_en="AdminPost", description_en="AdminDescriptionEn", description_sv="AdminDescriptionSv", - council_id=council.id, + council_id=admin_council.id, elected_user_recommended_limit=1, elected_user_max_limit=2, elected_at_semester="HT", @@ -107,6 +112,87 @@ def admin_post(db_session): return post +@pytest.fixture +def super_user_post(db_session, admin_council): + """Create and return a post which has only the "super" - "User" permission.""" + + post = Post_DB( + name_sv="SuperUserPostSV", + name_en="SuperUserPostEN", + description_en="SuperUserDescriptionEn", + description_sv="SuperUserDescriptionSv", + council_id=admin_council.id, + elected_user_recommended_limit=1, + elected_user_max_limit=2, + elected_at_semester="HT", + elected_by="Guild", + ) + db_session.add(post) + db_session.commit() + + permissions = [ + Permission_DB(action="super", target="User"), + ] + post.permissions.extend(permissions) + db_session.commit() + + return post + + +@pytest.fixture +def manage_user_post(db_session, admin_council): + """Create and return a post which has only the "manage" - "User" permission.""" + + post = Post_DB( + name_sv="ManageUserPostSV", + name_en="ManageUserPostEN", + description_en="ManageUserDescriptionEn", + description_sv="ManageUserDescriptionSv", + council_id=admin_council.id, + elected_user_recommended_limit=1, + elected_user_max_limit=2, + elected_at_semester="HT", + elected_by="Guild", + ) + db_session.add(post) + db_session.commit() + + permissions = [ + Permission_DB(action="manage", target="User"), + ] + post.permissions.extend(permissions) + db_session.commit() + + return post + + +@pytest.fixture +def view_user_post(db_session, admin_council): + """Create and return a post which has only the "view" - "User" permission.""" + + post = Post_DB( + name_sv="ViewUserPostSV", + name_en="ViewUserPostEN", + description_en="ViewUserDescriptionEn", + description_sv="ViewUserDescriptionSv", + council_id=admin_council.id, + elected_user_recommended_limit=1, + elected_user_max_limit=2, + elected_at_semester="HT", + elected_by="Guild", + ) + db_session.add(post) + db_session.commit() + + permissions = [ + Permission_DB(action="view", target="User"), + ] + post.permissions.extend(permissions) + db_session.commit() + + return post + + @pytest.fixture def admin_user(client, db_session, admin_post): """Create and return a full admin user with the admin post and permissions.""" diff --git a/tests/test_permissions.py b/tests/test_permissions.py new file mode 100644 index 00000000..decd8bd7 --- /dev/null +++ b/tests/test_permissions.py @@ -0,0 +1,212 @@ +# type: ignore +import pytest +import ast +from pathlib import Path +from fastapi import status +from .basic_factories import auth_headers + + +# Bot written function to make sure that when we run tests using permissions below, +# the actual required permissions in the code are what we expect, to guard against silent permission drift. +# If this gives you trouble for whatever reason, just comment it out along with +# the parts of tests containing 'required_permission' found in the tests +def _extract_required_permission(router_file: Path, function_name: str): + """Return (action, target) from Permission.require(action, target) in route dependencies.""" + + module = ast.parse(router_file.read_text(encoding="utf-8")) + + for node in module.body: + if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + continue + if node.name != function_name: + continue + + for decorator in node.decorator_list: + if not isinstance(decorator, ast.Call): + continue + + for keyword in decorator.keywords: + if keyword.arg != "dependencies" or not isinstance(keyword.value, ast.List): + continue + + for dependency in keyword.value.elts: + if not isinstance(dependency, ast.Call): + continue + if not isinstance(dependency.func, ast.Attribute): + continue + if dependency.func.attr != "require": + continue + if not isinstance(dependency.func.value, ast.Name): + continue + if dependency.func.value.id != "Permission": + continue + if len(dependency.args) != 2: + continue + if not all(isinstance(arg, ast.Constant) and isinstance(arg.value, str) for arg in dependency.args): + continue + + return dependency.args[0].value, dependency.args[1].value + + return None + + raise AssertionError(f"Function {function_name} not found in {router_file}") + + +def test_get_my_permissions(client, admin_token): + """Test that the /permissions/me route returns some correct permissions for the admin user""" + + # Check that the admin user has some of the expected permissions + response = client.get("/permissions/me", headers=auth_headers(admin_token)) + assert response.status_code == status.HTTP_200_OK + permissions = response.json() + assert ["manage", "User"] in permissions + assert ["manage", "Permission"] in permissions + assert ["manage", "BlahBlah"] not in permissions + + +def test_deny_super_for_manage(client, manage_user_post, membered_user, db_session, admin_user): + """Test that a user with only manage permissions cannot access super permissions (like deleting a user)""" + + # Give appropriate permissions + membered_user.posts.append(manage_user_post) + db_session.commit() + + # Get token for the user with manage permissions + resp = client.post("/auth/login", data={"username": "member@example.com", "password": "Password123"}) + user_token = resp.json()["access_token"] + + # Check that the admin delete user route requires super permissions, not just manage permissions + # if this test fails due to permission drift, point it at some other route that requires super permissions + router_file = Path(__file__).resolve().parents[1] / "routes" / "user_router.py" + required_permission = _extract_required_permission(router_file, "admin_delete_user") + assert required_permission == ("super", "User") + + # Check that the user has manage permissions but not super permissions + response = client.get("/permissions/me", headers=auth_headers(user_token)) + assert response.status_code == status.HTTP_200_OK + permissions = response.json() + assert ["manage", "User"] in permissions + assert ["super", "User"] not in permissions + + # Try to delete another user (which requires super permissions) + response = client.delete( + "/users/admin/" + str(admin_user.id), + headers=auth_headers(user_token), + ) + assert response.status_code == status.HTTP_403_FORBIDDEN + + +def test_allow_manage_for_super(client, super_user_post, membered_user, db_session, admin_user): + """Test that a user with super permissions can access manage routes""" + + # Give appropriate permissions + membered_user.posts.append(super_user_post) + db_session.commit() + + # Get token for the user with super permissions + resp = client.post("/auth/login", data={"username": "member@example.com", "password": "Password123"}) + user_token = resp.json()["access_token"] + + # Check that the admin manage user route requires manage permissions + # if this test fails due to permission drift, point it at some other route that requires manage permissions + router_file = Path(__file__).resolve().parents[1] / "routes" / "user_router.py" + required_permission = _extract_required_permission(router_file, "admin_update_user") + assert required_permission == ("manage", "User") + + # Check that the user has super permissions but not manage permissions + response = client.get("/permissions/me", headers=auth_headers(user_token)) + assert response.status_code == status.HTTP_200_OK + permissions = response.json() + assert ["super", "User"] in permissions + assert ["manage", "User"] not in permissions + + # The user should still be able to access the manage route + response = client.patch( + f"/users/admin/update/{str(admin_user.id)}", + json={"first_name": "NewName"}, + headers=auth_headers(user_token), + ) + assert response.status_code == status.HTTP_200_OK + + +def test_allow_view_for_manage(client, manage_user_post, membered_user, db_session, admin_user): + """Test that a user with manage permissions can access view routes""" + + # Give appropriate permissions + membered_user.posts.append(manage_user_post) + db_session.commit() + + # Get token for the user with manage permissions + resp = client.post("/auth/login", data={"username": "member@example.com", "password": "Password123"}) + user_token = resp.json()["access_token"] + + # Check that the view user route requires view permissions but not manage permissions + # if this test fails due to permission drift, point it at some other route that requires view permissions + router_file = Path(__file__).resolve().parents[1] / "routes" / "user_router.py" + required_permission = _extract_required_permission(router_file, "get_user") + assert required_permission == ("view", "User") + + # Check that the user has manage permissions but not view permissions + response = client.get("/permissions/me", headers=auth_headers(user_token)) + assert response.status_code == status.HTTP_200_OK + permissions = response.json() + assert ["manage", "User"] in permissions + assert ["view", "User"] not in permissions + + # The user should still be able to access the view route + response = client.get( + "/users/admin/" + str(admin_user.id), + headers=auth_headers(user_token), + ) + assert response.status_code == status.HTTP_200_OK + + +def test_deny_manage_for_view(client, view_user_post, membered_user, db_session, admin_user): + """Test that a user with only view permissions cannot access manage routes""" + + # Give appropriate permissions + membered_user.posts.append(view_user_post) + db_session.commit() + + # Get token for the user with view permissions + resp = client.post("/auth/login", data={"username": "member@example.com", "password": "Password123"}) + user_token = resp.json()["access_token"] + + # Check that the admin manage user route requires manage permissions, not just view permissions + router_file = Path(__file__).resolve().parents[1] / "routes" / "user_router.py" + required_permission = _extract_required_permission(router_file, "admin_update_user") + assert required_permission == ("manage", "User") + + # Check that the user has view permissions but not manage permissions + response = client.get("/permissions/me", headers=auth_headers(user_token)) + assert response.status_code == status.HTTP_200_OK + permissions = response.json() + assert ["view", "User"] in permissions + assert ["manage", "User"] not in permissions + + # Try to access the manage route + response = client.patch( + f"/users/admin/update/{str(admin_user.id)}", + json={"first_name": "NewName"}, + headers=auth_headers(user_token), + ) + assert response.status_code == status.HTTP_403_FORBIDDEN + + +def test_member_without_permissions_is_forbidden(client, member_token): + """Test that a plain member without post permissions cannot access protected permission routes.""" + + response = client.get("/permissions/", headers=auth_headers(member_token)) + assert response.status_code == status.HTTP_403_FORBIDDEN + + +def test_get_my_permissions_requires_member_status(client, non_member_token): + """Test that /permissions/me returns null for verified non-members and anonymous users.""" + + response = client.get("/permissions/me", headers=auth_headers(non_member_token)) + assert response.status_code == status.HTTP_200_OK + assert response.json() is None + + response = client.get("/permissions/me") + assert response.status_code == status.HTTP_200_OK + assert response.json() is None diff --git a/user/custom_auth_router.py b/user/custom_auth_router.py index 59805520..d1b7d7ec 100644 --- a/user/custom_auth_router.py +++ b/user/custom_auth_router.py @@ -3,11 +3,11 @@ from fastapi import APIRouter, Body, Depends, HTTPException, Request, status from fastapi.security import OAuth2PasswordRequestForm -from fastapi_users_pelicanq import models, schemas, exceptions -from fastapi_users_pelicanq.authentication import AuthenticationBackend, Authenticator, Strategy -from fastapi_users_pelicanq.manager import BaseUserManager, UserManagerDependency -from fastapi_users_pelicanq.openapi import OpenAPIResponseType -from fastapi_users_pelicanq.router.common import ErrorCode, ErrorModel +from fastapi_users import models, schemas, exceptions +from fastapi_users.authentication import AuthenticationBackend, Authenticator, Strategy +from fastapi_users.manager import BaseUserManager, UserManagerDependency +from fastapi_users.openapi import OpenAPIResponseType +from fastapi_users.router.common import ErrorCode, ErrorModel from helpers.rate_limit import rate_limit from pydantic import EmailStr from user.refresh_auth_backend import RefreshAuthenticationBackend diff --git a/user/permission.py b/user/permission.py index a4e59fae..8c8a1930 100644 --- a/user/permission.py +++ b/user/permission.py @@ -1,13 +1,9 @@ -from typing import cast from fastapi import Depends, HTTPException, status -from fastapi_users_pelicanq import jwt from db_models.permission_model import PERMISSION_TYPE, PERMISSION_TARGET from db_models.user_model import User_DB -from user.token_strategy import AccessTokenData, CustomTokenStrategy, get_jwt_secret from user.user_stuff import ( current_user, current_verified_user, - current_verified_user_token, ) @@ -66,30 +62,16 @@ def dependency(user: User_DB | None = Depends(current_verified_user)): def require(cls, action: PERMISSION_TYPE, target: PERMISSION_TARGET): # Use this dependency on routes which require specific permissions def dependency( - user_and_token: tuple[User_DB | None, str | None] = Depends(current_verified_user_token), - jwt_secret: str = Depends(get_jwt_secret), + user: User_DB | None = Depends(current_verified_user), ): - user, token = user_and_token - if user is None or token is None: + if user is None: # We can raise here unlike in "check" because this is supposed to be an absolute requirement raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated") - permissions: list[str] = [] for post in user.posts: for perm in post.permissions: - permissions.append(f"{perm.action}:{perm.target}") - - decoded_token = cast(AccessTokenData, jwt.decode_jwt(token, jwt_secret, audience=["fastapi-users:auth"])) - - # see if user has a permission matching the required permission - for perm in decoded_token["permissions"]: - try: - claim_action, claim_target = CustomTokenStrategy.decode_permission(perm) - except: - raise HTTPException(status.HTTP_403_FORBIDDEN) - - verified = cls.verify_permission(claim_action, claim_target, action, target) - if verified: - return user + verified = cls.verify_permission(perm.action, perm.target, action, target) + if verified: + return user raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) @@ -109,7 +91,12 @@ def verify_permission( if claim_target != required_target and claim_target != "all": return False - if claim_action == required_action or claim_action == "manage": + if ( + claim_action == required_action + or (claim_action == "manage" and required_action == "view") + or claim_action == "super" + ): + # Super can do everything, manage can do view and manage, view can only do view return True return False @@ -118,32 +105,18 @@ def verify_permission( def check(cls, action: PERMISSION_TYPE, target: PERMISSION_TARGET): # Use this dependency on routes which work differently if the user has specific permissions def dependency( - user_and_token: tuple[User_DB | None, str | None] = Depends(current_verified_user_token), - jwt_secret: str = Depends(get_jwt_secret), + user: User_DB | None = Depends(current_verified_user), ): - user, token = user_and_token - if user is None or token is None: + if user is None: # If we raise an exception here, it would cause the whole calling function to get an exception # which defeats the point of this check being optional return False - permissions: list[str] = [] for post in user.posts: for perm in post.permissions: - permissions.append(f"{perm.action}:{perm.target}") - - decoded_token = cast(AccessTokenData, jwt.decode_jwt(token, jwt_secret, audience=["fastapi-users:auth"])) - - # see if user has a permission matching the required permission - for perm in decoded_token["permissions"]: - try: - claim_action, claim_target = CustomTokenStrategy.decode_permission(perm) - except: - return False - - verified = cls.verify_permission(claim_action, claim_target, action, target) - if verified: - return True + verified = cls.verify_permission(perm.action, perm.target, action, target) + if verified: + return True return False diff --git a/user/refresh_auth_backend.py b/user/refresh_auth_backend.py index 94df0f60..2e967b00 100644 --- a/user/refresh_auth_backend.py +++ b/user/refresh_auth_backend.py @@ -2,11 +2,11 @@ from fastapi import Response -from fastapi_users_pelicanq import models -from fastapi_users_pelicanq.authentication.strategy import ( +from fastapi_users import models +from fastapi_users.authentication.strategy import ( StrategyDestroyNotSupportedError, ) -from fastapi_users_pelicanq.authentication.backend import AuthenticationBackend +from fastapi_users.authentication.backend import AuthenticationBackend from user.token_strategy import RefreshStrategy diff --git a/user/token_strategy.py b/user/token_strategy.py index 0647568e..5adde427 100644 --- a/user/token_strategy.py +++ b/user/token_strategy.py @@ -1,13 +1,13 @@ import os import secrets from typing import Optional, TypedDict, Generic -from fastapi_users_pelicanq import BaseUserManager -from fastapi_users_pelicanq.authentication import JWTStrategy +from fastapi_users import BaseUserManager +from fastapi_users.authentication import JWTStrategy from db_models.user_model import User_DB import redis.asyncio from fastapi import Depends -from fastapi_users_pelicanq.authentication import RedisStrategy, Strategy -from fastapi_users_pelicanq import models +from fastapi_users.authentication import RedisStrategy, Strategy +from fastapi_users import models from database import get_redis @@ -49,29 +49,6 @@ async def get_jwt_secret() -> str: class AccessTokenData(TypedDict): sub: str aud: list[str] - permissions: list[str] # this is our own field we add for permission system - - -class CustomTokenStrategy(JWTStrategy[User_DB, int]): - # on login we add our own permissions data into the JWT token - async def get_user_permissions(self, user: User_DB) -> list[str]: - # lets add all permissions form the user's post - all_perms: list[str] = [] - for post in user.posts: - for perm in post.permissions: - all_perms.append(self.encode_permission(perm.action, perm.target)) - return all_perms - - @classmethod - def decode_permission(cls, permission: str) -> tuple[str, str]: - decoded = permission.split(":") - action = decoded[0] - target = decoded[1] - return action, target - - @classmethod - def encode_permission(cls, action: str, target: str) -> str: - return f"{action}:{target}" class RefreshStrategy(Strategy[models.UP, models.ID]): @@ -128,7 +105,7 @@ async def write_token(self, user: models.UP) -> str: def get_jwt_strategy( secret: str = Depends(get_jwt_secret), ) -> JWTStrategy[User_DB, int]: - strat = CustomTokenStrategy(secret=secret, lifetime_seconds=JWT_TOKEN_LIFETIME_SECONDS) + strat = JWTStrategy[User_DB, int](secret=secret, lifetime_seconds=JWT_TOKEN_LIFETIME_SECONDS) return strat diff --git a/user/user_manager.py b/user/user_manager.py index 58e9dc8f..4af53547 100644 --- a/user/user_manager.py +++ b/user/user_manager.py @@ -2,8 +2,8 @@ import re from typing import Any, Dict, Optional, Type, Union from fastapi import Request -from fastapi_users_pelicanq import BaseUserManager, IntegerIDMixin, InvalidPasswordException -from fastapi_users_pelicanq import schemas +from fastapi_users import BaseUserManager, IntegerIDMixin, InvalidPasswordException +from fastapi_users import schemas from api_schemas.user_schemas import UserCreate from db_models.user_model import User_DB diff --git a/user/user_stuff.py b/user/user_stuff.py index 6a0213a8..3f054a10 100644 --- a/user/user_stuff.py +++ b/user/user_stuff.py @@ -1,10 +1,11 @@ import os -from typing import Any -from fastapi_users_pelicanq.authentication import AuthenticationBackend, BearerTransport, CookieTransport -from fastapi_users_pelicanq.db import SQLAlchemyUserDatabase +from typing import Any, cast +from fastapi_users.authentication import AuthenticationBackend, BearerTransport, CookieTransport +from fastapi_users.db import SQLAlchemyUserDatabase +from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session from fastapi import Depends -from fastapi_users_pelicanq import FastAPIUsers +from fastapi_users import FastAPIUsers from database import get_db from db_models.user_model import User_DB from user.refresh_auth_backend import RefreshAuthenticationBackend @@ -12,6 +13,29 @@ from user.user_manager import UserManager +# This is such a horrible hack, but the alternative is essentially a full refactor of the code +class _AsyncSessionProxy: + """Expose an async-session-like interface on top of a sync Session.""" + + def __init__(self, session: Session): + self._session = session + + def add(self, instance: Any) -> None: + self._session.add(instance) + + async def execute(self, statement: Any, *args: Any, **kwargs: Any) -> Any: + return self._session.execute(statement, *args, **kwargs) + + async def commit(self) -> None: + self._session.commit() + + async def refresh(self, instance: Any) -> None: + self._session.refresh(instance) + + async def delete(self, instance: Any) -> None: + self._session.delete(instance) + + # Access token is sent in the Authorization header as a Bearer token. bearer_transport = BearerTransport(tokenUrl="auth/login") @@ -64,7 +88,8 @@ async def get_user_db(session: Session = Depends(get_db)): - yield SQLAlchemyUserDatabase[User_DB, int](session, User_DB) + async_session = cast(AsyncSession, _AsyncSessionProxy(session)) + yield SQLAlchemyUserDatabase[User_DB, int](async_session, User_DB) async def get_user_manager(user_db: SQLAlchemyUserDatabase[User_DB, int] = Depends(get_user_db)): @@ -91,7 +116,3 @@ def get_enabled_backends() -> list[AuthenticationBackend[User_DB, int]]: current_user: Any = USERS.current_user(get_enabled_backends=get_enabled_backends, optional=True) current_verified_user: Any = USERS.current_user(verified=True, get_enabled_backends=get_enabled_backends, optional=True) - -current_verified_user_token: Any = USERS.current_user_token( - verified=True, get_enabled_backends=get_enabled_backends, optional=True -)