Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions open_strix/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,12 @@ def should_process_discord_message(
*,
author_is_bot: bool,
author_id: str | int | None,
channel_id: str | None = None,
) -> bool:
# Channel allowlist: if configured, only process messages from listed channels.
if self.config.discord_channel_allowlist and channel_id is not None:
if str(channel_id) not in self.config.discord_channel_allowlist:
return False
if not author_is_bot:
return True
return self.should_respond_to_bot(author_id)
Expand Down
3 changes: 3 additions & 0 deletions open_strix/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
discord_messages_in_prompt: 10
discord_token_env: DISCORD_TOKEN
always_respond_bot_ids: []
discord_channel_allowlist: []
api_port: 0
web_ui_port: 0
web_ui_host: 127.0.0.1
Expand Down Expand Up @@ -182,6 +183,7 @@ class AppConfig:
discord_messages_in_prompt: int = 10
discord_token_env: str = "DISCORD_TOKEN"
always_respond_bot_ids: set[str] = field(default_factory=set)
discord_channel_allowlist: set[str] = field(default_factory=set)
session_log_retention_days: int = 30
api_port: int = 0
web_ui_port: int = 0
Expand Down Expand Up @@ -249,6 +251,7 @@ def load_config(layout: RepoLayout) -> AppConfig:
discord_messages_in_prompt=int(loaded.get("discord_messages_in_prompt", 10)),
discord_token_env=str(loaded.get("discord_token_env", "DISCORD_TOKEN")),
always_respond_bot_ids=_normalize_id_list(loaded.get("always_respond_bot_ids")),
discord_channel_allowlist=_normalize_id_list(loaded.get("discord_channel_allowlist")),
session_log_retention_days=int(loaded.get("session_log_retention_days", 30)),
api_port=int(loaded.get("api_port", 0)),
web_ui_port=int(loaded.get("web_ui_port", 0)),
Expand Down
2 changes: 2 additions & 0 deletions open_strix/discord.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,9 +177,11 @@ async def on_ready(self) -> None:

async def on_message(self, message: discord.Message) -> None:
author_id = getattr(message.author, "id", None)
channel_id = str(getattr(message.channel, "id", ""))
if not self._app.should_process_discord_message(
author_is_bot=bool(getattr(message.author, "bot", False)),
author_id=author_id,
channel_id=channel_id or None,
):
return
await self._app.handle_discord_message(message)
Expand Down
44 changes: 44 additions & 0 deletions tests/test_discord.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,50 @@ def test_bot_allowlist_config_controls_message_processing(
assert app.should_process_discord_message(author_is_bot=True, author_id="42") is True


def test_channel_allowlist_filters_messages(
tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
) -> None:
_stub_agent_factory(monkeypatch)
(tmp_path / "config.yaml").write_text(
"discord_channel_allowlist:\n"
" - '111'\n"
" - '222'\n",
encoding="utf-8",
)
app = app_mod.OpenStrixApp(tmp_path)

# Allowed channel — should process
assert app.should_process_discord_message(
author_is_bot=False, author_id=None, channel_id="111"
) is True
# Not in allowlist — should filter
assert app.should_process_discord_message(
author_is_bot=False, author_id=None, channel_id="999"
) is False
# No channel_id — still processes (DMs, web UI, etc.)
assert app.should_process_discord_message(
author_is_bot=False, author_id=None, channel_id=None
) is True


def test_empty_channel_allowlist_processes_all(
tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
) -> None:
_stub_agent_factory(monkeypatch)
(tmp_path / "config.yaml").write_text(
"discord_channel_allowlist: []\n",
encoding="utf-8",
)
app = app_mod.OpenStrixApp(tmp_path)

# Empty allowlist — all channels should process
assert app.should_process_discord_message(
author_is_bot=False, author_id=None, channel_id="999"
) is True


def test_log_event_includes_stable_session_id_for_app_run(
tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
Expand Down