diff --git a/forum/api/threads.py b/forum/api/threads.py index b5b036c5..f04bcc80 100644 --- a/forum/api/threads.py +++ b/forum/api/threads.py @@ -362,6 +362,7 @@ def get_user_threads( user_id: Optional[str] = None, group_id: Optional[int] = None, group_ids: Optional[int] = None, + context: Optional[str] = None, **kwargs: Any, ) -> dict[str, Any]: """ @@ -385,6 +386,7 @@ def get_user_threads( "user_id": user_id, "group_id": group_id, "group_ids": group_ids, + "context": context, } params = {k: v for k, v in params.items() if v is not None} backend.validate_params(params) diff --git a/forum/backends/mongodb/api.py b/forum/backends/mongodb/api.py index b279ac8e..d74c281c 100644 --- a/forum/backends/mongodb/api.py +++ b/forum/backends/mongodb/api.py @@ -992,6 +992,7 @@ def get_threads( int(params.get("per_page", 100)), commentable_ids=params.get("commentable_ids", []), is_moderator=params.get("is_moderator", False), + context=params.get("context", "course"), ) context: dict[str, Any] = { "count_flagged": count_flagged, diff --git a/forum/backends/mysql/api.py b/forum/backends/mysql/api.py index b115119b..fea6340e 100644 --- a/forum/backends/mysql/api.py +++ b/forum/backends/mysql/api.py @@ -1158,6 +1158,7 @@ def get_threads( params.get("sort_key", ""), int(params.get("page", 1)), int(params.get("per_page", 100)), + context=params.get("context", "course"), commentable_ids=params.get("commentable_ids", []), is_moderator=params.get("is_moderator", False), ) diff --git a/package-lock.json b/package-lock.json new file mode 100644 index 00000000..929f206a --- /dev/null +++ b/package-lock.json @@ -0,0 +1,6 @@ +{ + "name": "forum", + "lockfileVersion": 3, + "requires": true, + "packages": {} +} diff --git a/tests/test_views/test_threads.py b/tests/test_views/test_threads.py index ca28d864..40eeece4 100644 --- a/tests/test_views/test_threads.py +++ b/tests/test_views/test_threads.py @@ -447,6 +447,41 @@ def test_unresponded_filter(api_client: APIClient, patched_get_backend: Any) -> assert len(thread) == 1 +def test_get_user_threads_context(api_client: APIClient, patched_get_backend: Any) -> None: + """Test get_user_threads filters threads by context.""" + backend = patched_get_backend + user_id, course_thread_id = setup_models(backend=backend) + standalone_thread_id = backend.create_thread( + { + "title": "Standalone Thread", + "body": "Standalone Thread", + "course_id": "course1", + "commentable_id": "CommentThread", + "author_id": user_id, + "author_username": "user1", + "abuse_flaggers": [], + "historical_abuse_flaggers": [], + "context": "standalone", + } + ) + + # Default (course) context: only the course thread is returned + response = api_client.get_json("/api/v2/threads", {"course_id": "course1"}) + assert response.status_code == 200 + ids = [t["id"] for t in response.json()["collection"]] + assert course_thread_id in ids + assert standalone_thread_id not in ids + + # Explicit standalone context: only the standalone thread is returned + response = api_client.get_json( + "/api/v2/threads", {"course_id": "course1", "context": "standalone"} + ) + assert response.status_code == 200 + ids = [t["id"] for t in response.json()["collection"]] + assert standalone_thread_id in ids + assert course_thread_id not in ids + + def test_filter_by_post_type(api_client: APIClient, patched_get_backend: Any) -> None: """Test filter threads by thread_type through get thread API.""" backend = patched_get_backend