diff --git a/adrf/permissions.py b/adrf/permissions.py new file mode 100644 index 0000000..41523f3 --- /dev/null +++ b/adrf/permissions.py @@ -0,0 +1,231 @@ +import asyncio + +from asgiref.sync import sync_to_async +from rest_framework import permissions + + +def try_convert_operator(operator_instance): + """ + Helper function which attempts to convert a given permissions operator (i.e., AND, OR, NOT) and any sub-operators + to their async equivalent. This addresses issues with mixed operator types, where a sync operator (e.g., AND) + does not await the result of an async sub-operator. If the given parameter is NOT an operator, then this function + returns the original argument unchanged. + """ + if not is_perm_operator(operator_instance): + return operator_instance + if is_async_perm_operator(operator_instance): + # Avoid mixed async/sync operators within the operands of the given async operator + if isinstance(operator_instance, (AAND, AOR)): + operator_instance.op1 = try_convert_operator(operator_instance.op1) + operator_instance.op2 = try_convert_operator(operator_instance.op2) + else: + operator_instance.op1 = try_convert_operator(operator_instance.op1) + return operator_instance + # Convert sync operator types to async + if isinstance(operator_instance, permissions.AND): + operator_class = AAND + operands = [operator_instance.op1, operator_instance.op2] + elif isinstance(operator_instance, permissions.OR): + operator_class = AOR + operands = [operator_instance.op1, operator_instance.op2] + else: + operator_class = ANOT + operands = [operator_instance.op1] + operands = [try_convert_operator(operand) for operand in operands] + return operator_class(*operands) + + +def is_perm_operator(operator_instance): + """ + Helper function which checks whether the given parameter is a permissions operator (i.e., AND, OR, NOT). + """ + return isinstance( + operator_instance, (permissions.AND, permissions.OR, permissions.NOT) + ) + + +def is_async_perm_operator(operator_instance): + """ + Helper function which checks whether the given parameter is an async permissions operator (i.e., AAND, AOR, ANOT). + """ + return isinstance(operator_instance, (AAND, AOR, ANOT)) + + +class AsyncOperandHolderMixin: + """ + Async version of rest framework's operand holder mixin. This uses the async versions of permissions operators, + rather than the sync equivalents. + """ + + def __and__(self, other): + return AsyncOperandHolder(AAND, self, other) + + def __or__(self, other): + return AsyncOperandHolder(AOR, self, other) + + def __rand__(self, other): + return AsyncOperandHolder(AAND, other, self) + + def __ror__(self, other): + return AsyncOperandHolder(AOR, other, self) + + def __invert__(self): + return AsyncSingleOperandHolder(ANOT, self) + + +class AsyncLogicOperatorMixin: + """ + Mixin containing common methods for permissions logic operators with two operands. + """ + + def get_async_has_perm(self): + async_has_perm_a = ( + self.op1.has_permission + if asyncio.iscoroutinefunction(self.op1.has_permission) + else sync_to_async(self.op1.has_permission) + ) + async_has_perm_b = ( + self.op2.has_permission + if asyncio.iscoroutinefunction(self.op2.has_permission) + else sync_to_async(self.op2.has_permission) + ) + return async_has_perm_a, async_has_perm_b + + def get_async_has_obj_perm(self): + async_obj_perm_a = ( + self.op1.has_object_permission + if asyncio.iscoroutinefunction(self.op1.has_object_permission) + else sync_to_async(self.op1.has_object_permission) + ) + async_obj_perm_b = ( + self.op2.has_object_permission + if asyncio.iscoroutinefunction(self.op2.has_object_permission) + else sync_to_async(self.op2.has_object_permission) + ) + return async_obj_perm_a, async_obj_perm_b + + +class AsyncSingleLogicOperatorMixin: + """ + Mixin containing common methods for permissions logic operators with one operand. + """ + + def get_async_has_perm(self): + return ( + self.op1.has_permission + if asyncio.iscoroutinefunction(self.op1.has_permission) + else sync_to_async(self.op1.has_permission) + ) + + def get_async_has_obj_perm(self): + return ( + self.op1.has_object_permission + if asyncio.iscoroutinefunction(self.op1.has_object_permission) + else sync_to_async(self.op1.has_object_permission) + ) + + +class AsyncSingleOperandHolder( + AsyncOperandHolderMixin, permissions.SingleOperandHolder +): + """ + Extension to the rest framework single operand holder which uses async operators. + """ + + pass + + +class AsyncOperandHolder(AsyncOperandHolderMixin, permissions.OperandHolder): + """ + Extension to the rest framework operand holder which uses async operators. + """ + + pass + + +class AAND(AsyncLogicOperatorMixin, permissions.AND): + """ + Asynchronous logical AND operator for permissions checks, based on the synchronous equivalent defined by rest + framework. + """ + + async def has_permission(self, request, view): + async_has_perm_a, async_has_perm_b = self.get_async_has_perm() + return await async_has_perm_a(request, view) and await async_has_perm_b( + request, view + ) + + async def has_object_permission(self, request, view, obj): + async_obj_perm_a, async_obj_perm_b = self.get_async_has_obj_perm() + return await async_obj_perm_a(request, view, obj) and await async_obj_perm_b( + request, view, obj + ) + + +class AOR(AsyncLogicOperatorMixin, permissions.OR): + """ + Asynchronous logical OR operator for permissions checks, based on the synchronous equivalent defined by rest + framework. + """ + + async def has_permission(self, request, view): + async_has_perm_a, async_has_perm_b = self.get_async_has_perm() + return await async_has_perm_a(request, view) or await async_has_perm_b( + request, view + ) + + async def has_object_permission(self, request, view, obj): + async_has_perm_a, async_has_perm_b = self.get_async_has_perm() + async_obj_perm_a, async_obj_perm_b = self.get_async_has_obj_perm() + return ( + await async_has_perm_a(request, view) + and await async_obj_perm_a(request, view, obj) + ) or ( + await async_has_perm_b(request, view) + and await async_obj_perm_b(request, view, obj) + ) + + +class ANOT(AsyncSingleLogicOperatorMixin, permissions.NOT): + """ + Asynchronous logical NOT operator for permissions checks, based on the synchronous equivalent defined by rest + framework. + """ + + async def has_permission(self, request, view): + async_has_perm = self.get_async_has_perm() + return not await async_has_perm(request, view) + + async def has_object_permission(self, request, view, obj): + async_obj_perm = self.get_async_has_obj_perm() + return not await async_obj_perm(request, view, obj) + + +class AsyncBasePermissionMetaClass( + AsyncOperandHolderMixin, permissions.BasePermissionMetaclass +): + """ + Extension to the rest framework base permission metaclass which uses async operators. + """ + + pass + + +class AsyncBasePermission( + permissions.BasePermission, metaclass=AsyncBasePermissionMetaClass +): + """ + Asynchronous base permission which can be combined with other permissions using logical operators. + """ + + async def has_permission(self, request, view): + """ + Return `True` if permission is granted, `False` otherwise. + """ + return True + + async def has_object_permission(self, request, view, obj): + """ + Return `True` if permission is granted, `False` otherwise. + """ + return True diff --git a/adrf/views.py b/adrf/views.py index 27fe8af..579d403 100755 --- a/adrf/views.py +++ b/adrf/views.py @@ -7,6 +7,7 @@ from rest_framework.throttling import BaseThrottle from rest_framework.views import APIView as DRFAPIView +from adrf.permissions import try_convert_operator from adrf.requests import AsyncRequest @@ -99,6 +100,10 @@ def initialize_request(self, request, *args, **kwargs): parser_context=parser_context, ) + def get_permissions(self): + permissions = super().get_permissions() + return [try_convert_operator(permission) for permission in permissions] + def check_permissions(self, request: Request) -> None: permissions = self.get_permissions() diff --git a/tests/test_permissions.py b/tests/test_permissions.py index 51ebfe7..566e741 100644 --- a/tests/test_permissions.py +++ b/tests/test_permissions.py @@ -1,14 +1,17 @@ +import unittest.mock + from django.http import HttpResponse from django.test import TestCase, override_settings from rest_framework.permissions import BasePermission from rest_framework.test import APIRequestFactory +from adrf.permissions import AsyncBasePermission from adrf.views import APIView factory = APIRequestFactory() -class AsyncPermission(BasePermission): +class AsyncPermission(AsyncBasePermission): async def has_permission(self, request, view): path = request.path_info.lstrip("/") @@ -21,6 +24,14 @@ async def has_object_permission(self, request, view, obj): return True +class AsyncRejectPermission(AsyncBasePermission): + async def has_permission(self, request, view): + return False + + async def has_object_permission(self, request, view, obj): + return False + + class SyncPermission(BasePermission): def has_permission(self, request, view): path = request.path_info.lstrip("/") @@ -73,3 +84,133 @@ async def test_sync_permission_reject(self): response = await MockView.as_view(permission_classes=(SyncPermission,))(request) self.assertEqual(response.status_code, 403) + + +class TestAsyncPermissionLogicOperators(TestCase): + @unittest.mock.patch.object(AsyncPermission, "has_permission", return_value=True) + @unittest.mock.patch.object( + AsyncRejectPermission, "has_permission", return_value=False + ) + async def test_pure_async_logical_and_permission( + self, mock_has_perm_a, mock_has_perm_b + ): + request = factory.get("/view/async/allow/") + combined_permission = AsyncPermission & AsyncRejectPermission + + response = await MockView.as_view(permission_classes=(combined_permission,))( + request + ) + + mock_has_perm_a.assert_awaited() + mock_has_perm_b.assert_awaited() + self.assertEqual(response.status_code, 403) + + @unittest.mock.patch.object(AsyncPermission, "has_permission", return_value=True) + @unittest.mock.patch.object( + AsyncRejectPermission, "has_permission", return_value=False + ) + async def test_pure_async_logical_or_permission( + self, mock_has_perm_a, mock_has_perm_b + ): + request = factory.get("/view/async/allow/") + combined_permission = AsyncRejectPermission | AsyncPermission + + response = await MockView.as_view(permission_classes=(combined_permission,))( + request + ) + + mock_has_perm_a.assert_awaited() + mock_has_perm_b.assert_awaited() + self.assertEqual(response.status_code, 200) + + @unittest.mock.patch.object( + AsyncRejectPermission, "has_permission", return_value=False + ) + async def test_pure_async_logical_neg_permission(self, mock_has_perm): + request = factory.get("/view/async/allow/") + negated_permission = ~AsyncRejectPermission + + response = await MockView.as_view(permission_classes=(negated_permission,))( + request + ) + + mock_has_perm.assert_awaited() + self.assertEqual(response.status_code, 200) + + @unittest.mock.patch.object(SyncPermission, "has_permission", return_value=True) + @unittest.mock.patch.object(AsyncPermission, "has_permission", return_value=True) + async def test_mixed_logical_and_permission( + self, mock_has_perm_async, mock_has_perm_sync + ): + request = factory.get("/view/async/allow/") + combined_permission = SyncPermission & AsyncPermission + + response = await MockView.as_view(permission_classes=(combined_permission,))( + request + ) + + mock_has_perm_async.assert_awaited() + mock_has_perm_sync.assert_called() + self.assertEqual(response.status_code, 200) + + @unittest.mock.patch.object(SyncPermission, "has_permission", return_value=True) + @unittest.mock.patch.object( + AsyncRejectPermission, "has_permission", return_value=False + ) + async def test_mixed_logical_or_permission( + self, mock_has_perm_async, mock_has_perm_sync + ): + request = factory.get("/view/async/allow/") + combined_permission = AsyncRejectPermission | SyncPermission + + response = await MockView.as_view(permission_classes=(combined_permission,))( + request + ) + + mock_has_perm_async.assert_awaited() + mock_has_perm_sync.assert_called() + self.assertEqual(response.status_code, 200) + + @unittest.mock.patch.object(SyncPermission, "has_permission", return_value=False) + @unittest.mock.patch.object(AsyncPermission, "has_permission", return_value=True) + @unittest.mock.patch.object( + AsyncRejectPermission, "has_permission", return_value=False + ) + async def test_async_first_complex_mixed_permission( + self, mock_async_reject, mock_async_accept, mock_sync_reject + ): + request = factory.get("/view/async/allow/") + combined_permission = AsyncPermission & ( + SyncPermission | ~AsyncRejectPermission + ) + + response = await MockView.as_view(permission_classes=(combined_permission,))( + request + ) + + mock_async_reject.assert_awaited() + mock_async_accept.assert_awaited() + mock_sync_reject.assert_called() + self.assertEqual(response.status_code, 200) + + @unittest.mock.patch.object(SyncPermission, "has_permission", return_value=True) + @unittest.mock.patch.object(AsyncPermission, "has_permission", return_value=False) + @unittest.mock.patch.object( + AsyncRejectPermission, "has_permission", return_value=False + ) + async def test_sync_first_complex_mixed_permission( + self, mock_async_reject, mock_async_accept, mock_sync_reject + ): + request = factory.get("/view/async/allow/") + combined_permission = SyncPermission & ( + AsyncPermission | ~AsyncRejectPermission + ) + + response = await MockView.as_view(permission_classes=(combined_permission,))( + request + ) + + mock_async_reject.assert_awaited() + mock_async_accept.assert_awaited() + mock_sync_reject.assert_called() + self.assertEqual(response.status_code, 200)