diff --git a/backend-flask/events.py b/backend-flask/events.py index ca09fff..627f83f 100644 --- a/backend-flask/events.py +++ b/backend-flask/events.py @@ -72,7 +72,7 @@ class CreateEventJson(BaseModel): ) ) @requires_authentication -@requires_team_membership +@requires_team_membership() def create_event(player_team: PlayerTeam, json: CreateEventJson, **_): event = Event() event.team_id = player_team.team_id @@ -107,7 +107,7 @@ def create_event(player_team: PlayerTeam, json: CreateEventJson, **_): @api_events.patch("//players") @requires_authentication -@requires_team_membership +@requires_team_membership() def set_event_players(player_team: PlayerTeam, event_id: int, **_): assert_team_authority(player_team, None) diff --git a/backend-flask/middleware.py b/backend-flask/middleware.py index 7bfefa2..17afbca 100644 --- a/backend-flask/middleware.py +++ b/backend-flask/middleware.py @@ -1,5 +1,7 @@ from functools import wraps +from typing import Optional from flask import abort, make_response, request +from sqlalchemy.sql.operators import json_path_getitem_op from app_db import db from models.auth_session import AuthSession from models.player import Player @@ -28,29 +30,47 @@ def requires_authentication(f): return f(*args, **kwargs) return decorator -def requires_team_membership(f): - @wraps(f) - def decorator(*args, **kwargs): - player: Player | None = kwargs["player"] - team_id: int = kwargs["team_id"] +def requires_team_membership( + path_param: Optional[str] = None, + json_param: Optional[str] = None, + query_param: Optional[str] = None +): + def wrapper(f): + @wraps(f) + def decorator(*args, **kwargs): + player: Player | None = kwargs["player"] - if not player: - abort(401) + team_id: int + if path_param: + team_id = kwargs[path_param] + elif json_param: + team_id = getattr(kwargs["json"], json_param) + elif query_param: + team_id = getattr(kwargs["query"], query_param) + else: + team_id = kwargs["team_id"] - player_team = db.session.query( - PlayerTeam - ).where( - PlayerTeam.player == player - ).where( - PlayerTeam.team_id == team_id - ).one_or_none() + if not player: + abort(401) - if not player_team: - abort(404, "Player is not a member of this team") + if not team_id: + abort(500) - kwargs["player_team"] = player_team - return f(*args, **kwargs) - return decorator + player_team = db.session.query( + PlayerTeam + ).where( + PlayerTeam.player == player + ).where( + PlayerTeam.team_id == team_id + ).one_or_none() + + if not player_team: + abort(404, "Player is not a member of this team") + + kwargs["player_team"] = player_team + return f(*args, **kwargs) + return decorator + return wrapper def assert_team_authority( player_team: PlayerTeam, diff --git a/backend-flask/team_integration.py b/backend-flask/team_integration.py index eb16f75..a323eb5 100644 --- a/backend-flask/team_integration.py +++ b/backend-flask/team_integration.py @@ -55,7 +55,7 @@ def get_integrations(player: Player, team_id: int, **_): operation_id="create_integration" ) @requires_authentication -@requires_team_membership +@requires_team_membership() def create_integration(player_team: PlayerTeam, integration_type: str, **_): assert_team_authority(player_team) @@ -81,7 +81,7 @@ def create_integration(player_team: PlayerTeam, integration_type: str, **_): operation_id="delete_integration" ) @requires_authentication -@requires_team_membership +@requires_team_membership() def delete_integration(player_team: PlayerTeam, integration_id: int, **_): assert_team_authority(player_team) @@ -109,7 +109,7 @@ def delete_integration(player_team: PlayerTeam, integration_id: int, **_): operation_id="update_integration" ) @requires_authentication -@requires_team_membership +@requires_team_membership() def update_integration( player_team: PlayerTeam, integration_id: int, diff --git a/backend-flask/team_invite.py b/backend-flask/team_invite.py index 786a981..1521ac3 100644 --- a/backend-flask/team_invite.py +++ b/backend-flask/team_invite.py @@ -22,7 +22,7 @@ api_team_invite = Blueprint("team_invite", __name__) operation_id="get_invites" ) @requires_authentication -@requires_team_membership +@requires_team_membership() def get_invites(team_id: int, **_): invites = db.session.query( TeamInvite @@ -48,7 +48,7 @@ def get_invites(team_id: int, **_): operation_id="create_invite" ) @requires_authentication -@requires_team_membership +@requires_team_membership() def create_invite(team_id: int, **_): team_id_shifted = int(team_id) << 48 random_value_shifted = int(randint(0, (1 << 16) - 1)) << 32