diff --git a/open_strix/app.py b/open_strix/app.py index 502b02b..11b5982 100644 --- a/open_strix/app.py +++ b/open_strix/app.py @@ -530,7 +530,12 @@ def should_process_discord_message( *, author_is_bot: bool, author_id: str | int | None, + channel_id: str | None = None, ) -> bool: + # Channel allowlist: if configured, only process messages from listed channels. + if self.config.discord_channel_allowlist and channel_id is not None: + if str(channel_id) not in self.config.discord_channel_allowlist: + return False if not author_is_bot: return True return self.should_respond_to_bot(author_id) diff --git a/open_strix/config.py b/open_strix/config.py index 220dbbe..0996206 100644 --- a/open_strix/config.py +++ b/open_strix/config.py @@ -29,6 +29,7 @@ discord_messages_in_prompt: 10 discord_token_env: DISCORD_TOKEN always_respond_bot_ids: [] +discord_channel_allowlist: [] api_port: 0 web_ui_port: 0 web_ui_host: 127.0.0.1 @@ -182,6 +183,7 @@ class AppConfig: discord_messages_in_prompt: int = 10 discord_token_env: str = "DISCORD_TOKEN" always_respond_bot_ids: set[str] = field(default_factory=set) + discord_channel_allowlist: set[str] = field(default_factory=set) session_log_retention_days: int = 30 api_port: int = 0 web_ui_port: int = 0 @@ -249,6 +251,7 @@ def load_config(layout: RepoLayout) -> AppConfig: discord_messages_in_prompt=int(loaded.get("discord_messages_in_prompt", 10)), discord_token_env=str(loaded.get("discord_token_env", "DISCORD_TOKEN")), always_respond_bot_ids=_normalize_id_list(loaded.get("always_respond_bot_ids")), + discord_channel_allowlist=_normalize_id_list(loaded.get("discord_channel_allowlist")), session_log_retention_days=int(loaded.get("session_log_retention_days", 30)), api_port=int(loaded.get("api_port", 0)), web_ui_port=int(loaded.get("web_ui_port", 0)), diff --git a/open_strix/discord.py b/open_strix/discord.py index 7de3611..53fa5de 100644 --- a/open_strix/discord.py +++ b/open_strix/discord.py @@ -177,9 +177,11 @@ async def on_ready(self) -> None: async def on_message(self, message: discord.Message) -> None: author_id = getattr(message.author, "id", None) + channel_id = str(getattr(message.channel, "id", "")) if not self._app.should_process_discord_message( author_is_bot=bool(getattr(message.author, "bot", False)), author_id=author_id, + channel_id=channel_id or None, ): return await self._app.handle_discord_message(message) diff --git a/tests/test_discord.py b/tests/test_discord.py index db4e170..4181218 100644 --- a/tests/test_discord.py +++ b/tests/test_discord.py @@ -87,6 +87,50 @@ def test_bot_allowlist_config_controls_message_processing( assert app.should_process_discord_message(author_is_bot=True, author_id="42") is True +def test_channel_allowlist_filters_messages( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + _stub_agent_factory(monkeypatch) + (tmp_path / "config.yaml").write_text( + "discord_channel_allowlist:\n" + " - '111'\n" + " - '222'\n", + encoding="utf-8", + ) + app = app_mod.OpenStrixApp(tmp_path) + + # Allowed channel — should process + assert app.should_process_discord_message( + author_is_bot=False, author_id=None, channel_id="111" + ) is True + # Not in allowlist — should filter + assert app.should_process_discord_message( + author_is_bot=False, author_id=None, channel_id="999" + ) is False + # No channel_id — still processes (DMs, web UI, etc.) + assert app.should_process_discord_message( + author_is_bot=False, author_id=None, channel_id=None + ) is True + + +def test_empty_channel_allowlist_processes_all( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + _stub_agent_factory(monkeypatch) + (tmp_path / "config.yaml").write_text( + "discord_channel_allowlist: []\n", + encoding="utf-8", + ) + app = app_mod.OpenStrixApp(tmp_path) + + # Empty allowlist — all channels should process + assert app.should_process_discord_message( + author_is_bot=False, author_id=None, channel_id="999" + ) is True + + def test_log_event_includes_stable_session_id_for_app_run( tmp_path: Path, monkeypatch: pytest.MonkeyPatch,