Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions __init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

from __future__ import annotations

from pathlib import Path
from typing import Any

import voluptuous as vol
from aiofiles.ospath import exists

from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_USERNAME, CONF_PASSWORD, CONF_HOST, CONF_COMMAND, CONF_TIMEOUT
Expand All @@ -20,7 +20,7 @@
CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN) # pylint: disable=invalid-name


async def _validate_service_data(data: dict[str, Any]) -> None:
async def _validate_service_data(hass: HomeAssistant, data: dict[str, Any]) -> None:
has_password: bool = bool(data.get(CONF_PASSWORD))
has_key_file: bool = bool(data.get(CONF_KEY_FILE))

Expand All @@ -41,7 +41,7 @@ async def _validate_service_data(data: dict[str, Any]) -> None:
translation_key="command_or_input",
)

if has_key_file and not await exists(data[CONF_KEY_FILE]):
if has_key_file and not await hass.async_add_executor_job(Path(data[CONF_KEY_FILE]).exists):
raise ServiceValidationError(
"Could not find key file.",
translation_domain=DOMAIN,
Expand Down Expand Up @@ -80,7 +80,7 @@ async def async_setup(hass: HomeAssistant, _config: ConfigType) -> bool:
hass.data.setdefault(DOMAIN, {})

async def async_execute(service_call: ServiceCall) -> ServiceResponse:
await _validate_service_data(service_call.data)
await _validate_service_data(hass, service_call.data)
# ssh_command is a single-instance integration (enforced by single_instance_allowed
# in the config flow), so there is at most one coordinator in hass.data[DOMAIN].
coordinator = next(iter(hass.data.get(DOMAIN, {}).values()), None)
Expand Down
9 changes: 3 additions & 6 deletions coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
from pathlib import Path
from typing import Any

from aiofiles import open as aioopen
from aiofiles.ospath import exists
from asyncssh import HostKeyNotVerifiable, PermissionDenied, connect, read_known_hosts

from homeassistant.const import CONF_USERNAME, CONF_PASSWORD, CONF_HOST, CONF_COMMAND, CONF_TIMEOUT
Expand Down Expand Up @@ -64,9 +62,8 @@ async def async_execute(self, data: dict[str, Any]) -> dict[str, Any]:
timeout = data.get(CONF_TIMEOUT, CONST_DEFAULT_TIMEOUT)

if input_data:
if await exists(input_data):
async with aioopen(input_data, 'r') as sf:
input_data = await sf.read()
if await self.hass.async_add_executor_job(Path(input_data).exists):
input_data = await self.hass.async_add_executor_job(Path(input_data).read_text)

conn_kwargs = {
CONF_HOST: host,
Expand Down Expand Up @@ -131,6 +128,6 @@ async def _resolve_known_hosts(self, check_known_hosts: bool, known_hosts: str |
return None
if not known_hosts:
known_hosts = str(Path("~", ".ssh", "known_hosts").expanduser())
if await exists(known_hosts):
if await self.hass.async_add_executor_job(Path(known_hosts).exists):
return await self.hass.async_add_executor_job(read_known_hosts, known_hosts)
return known_hosts
2 changes: 1 addition & 1 deletion manifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
"iot_class": "calculated",
"issue_tracker": "https://github.com/gensyn/ssh_command/issues",
"quality_scale": "bronze",
"requirements": ["asyncssh==2.22.0", "aiofiles==25.1.0"],
"requirements": ["asyncssh==2.22.0"],
"ssdp": [],
"version": "0.0.0",
"zeroconf": []
Expand Down
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
aiofiles==25.1.0
asyncssh==2.22.0
49 changes: 24 additions & 25 deletions tests/integration_tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ async def test_execute_returns_stdout(self, hass: HomeAssistant) -> None:
mock_conn = _make_mock_conn(stdout="hello\n", stderr="", exit_status=0)
with patch("custom_components.ssh_command.coordinator.connect",
return_value=_MockConnect(mock_conn)):
with patch("custom_components.ssh_command.coordinator.exists", return_value=False):
with patch("pathlib.Path.exists", return_value=False):
result = await hass.services.async_call(
DOMAIN,
SERVICE_EXECUTE,
Expand All @@ -245,7 +245,7 @@ async def test_execute_returns_stderr(self, hass: HomeAssistant) -> None:
mock_conn = _make_mock_conn(stdout="", stderr="some error", exit_status=1)
with patch("custom_components.ssh_command.coordinator.connect",
return_value=_MockConnect(mock_conn)):
with patch("custom_components.ssh_command.coordinator.exists", return_value=False):
with patch("pathlib.Path.exists", return_value=False):
result = await hass.services.async_call(
DOMAIN,
SERVICE_EXECUTE,
Expand All @@ -265,7 +265,7 @@ async def test_execute_with_password_auth(self, hass: HomeAssistant) -> None:
data = {**SERVICE_DATA_BASE, "password": "mysecret"}
with patch("custom_components.ssh_command.coordinator.connect",
return_value=_MockConnect(mock_conn)) as mock_connect:
with patch("custom_components.ssh_command.coordinator.exists", return_value=False):
with patch("pathlib.Path.exists", return_value=False):
await hass.services.async_call(
DOMAIN,
SERVICE_EXECUTE,
Expand All @@ -291,15 +291,14 @@ async def test_execute_with_key_file_auth(self, hass: HomeAssistant) -> None:
}
with patch("custom_components.ssh_command.coordinator.connect",
return_value=_MockConnect(mock_conn)) as mock_connect:
with patch("custom_components.ssh_command.coordinator.exists", return_value=True):
with patch("custom_components.ssh_command.exists", return_value=True):
await hass.services.async_call(
DOMAIN,
SERVICE_EXECUTE,
data,
blocking=True,
return_response=True,
)
with patch("pathlib.Path.exists", return_value=True):
await hass.services.async_call(
DOMAIN,
SERVICE_EXECUTE,
data,
blocking=True,
return_response=True,
)

call_kwargs = mock_connect.call_args[1]
assert call_kwargs["client_keys"] == "/home/user/.ssh/id_rsa"
Expand All @@ -312,7 +311,7 @@ async def test_execute_with_inline_input(self, hass: HomeAssistant) -> None:
data = {**SERVICE_DATA_BASE, "input": "inline input"}
with patch("custom_components.ssh_command.coordinator.connect",
return_value=_MockConnect(mock_conn)):
with patch("custom_components.ssh_command.coordinator.exists", return_value=False):
with patch("pathlib.Path.exists", return_value=False):
await hass.services.async_call(
DOMAIN,
SERVICE_EXECUTE,
Expand All @@ -337,7 +336,7 @@ async def test_execute_with_input_file(self, hass: HomeAssistant) -> None:
data = {**SERVICE_DATA_BASE, "command": "cat", "input": tf_path}
with patch("custom_components.ssh_command.coordinator.connect",
return_value=_MockConnect(mock_conn)):
with patch("custom_components.ssh_command.coordinator.exists", return_value=True):
with patch("pathlib.Path.exists", return_value=True):
await hass.services.async_call(
DOMAIN,
SERVICE_EXECUTE,
Expand All @@ -359,7 +358,7 @@ async def test_execute_with_custom_timeout(self, hass: HomeAssistant) -> None:
data = {**SERVICE_DATA_BASE, "timeout": 60}
with patch("custom_components.ssh_command.coordinator.connect",
return_value=_MockConnect(mock_conn)):
with patch("custom_components.ssh_command.coordinator.exists", return_value=False):
with patch("pathlib.Path.exists", return_value=False):
await hass.services.async_call(
DOMAIN,
SERVICE_EXECUTE,
Expand Down Expand Up @@ -387,7 +386,7 @@ async def test_check_known_hosts_false_passes_none(self, hass: HomeAssistant) ->
mock_conn = _make_mock_conn()
with patch("custom_components.ssh_command.coordinator.connect",
return_value=_MockConnect(mock_conn)) as mock_connect:
with patch("custom_components.ssh_command.coordinator.exists", return_value=False):
with patch("pathlib.Path.exists", return_value=False):
await hass.services.async_call(
DOMAIN,
SERVICE_EXECUTE,
Expand All @@ -412,7 +411,7 @@ async def test_check_known_hosts_true_with_custom_file(self, hass: HomeAssistant
}
with patch("custom_components.ssh_command.coordinator.connect",
return_value=_MockConnect(mock_conn)) as mock_connect:
with patch("custom_components.ssh_command.coordinator.exists", return_value=True):
with patch("pathlib.Path.exists", return_value=True):
with patch("custom_components.ssh_command.coordinator.read_known_hosts",
return_value=mock_known_hosts) as mock_rkh:
await hass.services.async_call(
Expand Down Expand Up @@ -440,7 +439,7 @@ async def test_check_known_hosts_true_with_missing_file(self, hass: HomeAssistan
}
with patch("custom_components.ssh_command.coordinator.connect",
return_value=_MockConnect(mock_conn)) as mock_connect:
with patch("custom_components.ssh_command.coordinator.exists", return_value=False):
with patch("pathlib.Path.exists", return_value=False):
await hass.services.async_call(
DOMAIN,
SERVICE_EXECUTE,
Expand All @@ -463,7 +462,7 @@ async def test_check_known_hosts_true_uses_default_path_when_missing(
data = {**SERVICE_DATA_BASE, "check_known_hosts": True}
with patch("custom_components.ssh_command.coordinator.connect",
return_value=_MockConnect(mock_conn)) as mock_connect:
with patch("custom_components.ssh_command.coordinator.exists", return_value=False):
with patch("pathlib.Path.exists", return_value=False):
await hass.services.async_call(
DOMAIN,
SERVICE_EXECUTE,
Expand Down Expand Up @@ -521,7 +520,7 @@ async def test_key_file_not_found_raises(self, hass: HomeAssistant) -> None:
entry = _make_entry()
await _setup_entry(hass, entry)

with patch("custom_components.ssh_command.exists", return_value=False):
with patch("pathlib.Path.exists", return_value=False):
with pytest.raises(ServiceValidationError) as exc_info:
await hass.services.async_call(
DOMAIN,
Expand Down Expand Up @@ -592,7 +591,7 @@ async def test_host_key_not_verifiable(self, hass: HomeAssistant) -> None:

with patch("custom_components.ssh_command.coordinator.connect",
return_value=_MockConnectRaises(HostKeyNotVerifiable("test"))):
with patch("custom_components.ssh_command.coordinator.exists", return_value=False):
with patch("pathlib.Path.exists", return_value=False):
with pytest.raises(ServiceValidationError) as exc_info:
await hass.services.async_call(
DOMAIN,
Expand All @@ -610,7 +609,7 @@ async def test_permission_denied(self, hass: HomeAssistant) -> None:

with patch("custom_components.ssh_command.coordinator.connect",
return_value=_MockConnectRaises(PermissionDenied("auth failed"))):
with patch("custom_components.ssh_command.coordinator.exists", return_value=False):
with patch("pathlib.Path.exists", return_value=False):
with pytest.raises(ServiceValidationError) as exc_info:
await hass.services.async_call(
DOMAIN,
Expand All @@ -628,7 +627,7 @@ async def test_timeout(self, hass: HomeAssistant) -> None:

with patch("custom_components.ssh_command.coordinator.connect",
return_value=_MockConnectRaises(TimeoutError())):
with patch("custom_components.ssh_command.coordinator.exists", return_value=False):
with patch("pathlib.Path.exists", return_value=False):
with pytest.raises(ServiceValidationError) as exc_info:
await hass.services.async_call(
DOMAIN,
Expand All @@ -647,7 +646,7 @@ async def test_host_not_reachable(self, hass: HomeAssistant) -> None:

with patch("custom_components.ssh_command.coordinator.connect",
return_value=_MockConnectRaises(err)):
with patch("custom_components.ssh_command.coordinator.exists", return_value=False):
with patch("pathlib.Path.exists", return_value=False):
with pytest.raises(ServiceValidationError) as exc_info:
await hass.services.async_call(
DOMAIN,
Expand All @@ -666,7 +665,7 @@ async def test_other_oserror_is_reraised(self, hass: HomeAssistant) -> None:

with patch("custom_components.ssh_command.coordinator.connect",
return_value=_MockConnectRaises(err)):
with patch("custom_components.ssh_command.coordinator.exists", return_value=False):
with patch("pathlib.Path.exists", return_value=False):
with pytest.raises(OSError):
await hass.services.async_call(
DOMAIN,
Expand Down
41 changes: 16 additions & 25 deletions tests/unit_tests/test_async_execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,7 @@ async def test_success(self):
service_call = self._make_service_call(SERVICE_DATA_BASE)

with patch("ssh_command.coordinator.connect", return_value=_MockConnect(mock_conn)):
with patch("ssh_command.coordinator.exists", return_value=False):
result = await self.handler(service_call)
result = await self.handler(service_call)

self.assertEqual(result[CONF_OUTPUT], "hello\n")
self.assertEqual(result[CONF_ERROR], "")
Expand All @@ -96,29 +95,26 @@ async def test_host_key_not_verifiable(self):
service_call = self._make_service_call(SERVICE_DATA_BASE)

with patch("ssh_command.coordinator.connect", return_value=_MockConnectRaises(HostKeyNotVerifiable("test"))):
with patch("ssh_command.coordinator.exists", return_value=False):
with self.assertRaises(ServiceValidationError) as ctx:
await self.handler(service_call)
with self.assertRaises(ServiceValidationError) as ctx:
await self.handler(service_call)

self.assertEqual(ctx.exception.translation_key, "host_key_not_verifiable")

async def test_permission_denied(self):
service_call = self._make_service_call(SERVICE_DATA_BASE)

with patch("ssh_command.coordinator.connect", return_value=_MockConnectRaises(PermissionDenied("auth failed"))):
with patch("ssh_command.coordinator.exists", return_value=False):
with self.assertRaises(ServiceValidationError) as ctx:
await self.handler(service_call)
with self.assertRaises(ServiceValidationError) as ctx:
await self.handler(service_call)

self.assertEqual(ctx.exception.translation_key, "login_failed")

async def test_timeout(self):
service_call = self._make_service_call(SERVICE_DATA_BASE)

with patch("ssh_command.coordinator.connect", return_value=_MockConnectRaises(TimeoutError())):
with patch("ssh_command.coordinator.exists", return_value=False):
with self.assertRaises(ServiceValidationError) as ctx:
await self.handler(service_call)
with self.assertRaises(ServiceValidationError) as ctx:
await self.handler(service_call)

self.assertEqual(ctx.exception.translation_key, "connection_timed_out")

Expand All @@ -127,9 +123,8 @@ async def test_name_resolution_failure(self):
service_call = self._make_service_call(SERVICE_DATA_BASE)

with patch("ssh_command.coordinator.connect", return_value=_MockConnectRaises(err)):
with patch("ssh_command.coordinator.exists", return_value=False):
with self.assertRaises(ServiceValidationError) as ctx:
await self.handler(service_call)
with self.assertRaises(ServiceValidationError) as ctx:
await self.handler(service_call)

self.assertEqual(ctx.exception.translation_key, "host_not_reachable")

Expand All @@ -138,9 +133,8 @@ async def test_other_oserror_is_reraised(self):
service_call = self._make_service_call(SERVICE_DATA_BASE)

with patch("ssh_command.coordinator.connect", return_value=_MockConnectRaises(err)):
with patch("ssh_command.coordinator.exists", return_value=False):
with self.assertRaises(OSError):
await self.handler(service_call)
with self.assertRaises(OSError):
await self.handler(service_call)

async def test_input_from_file(self):
with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as tf:
Expand All @@ -153,8 +147,7 @@ async def test_input_from_file(self):
service_call = self._make_service_call(data)

with patch("ssh_command.coordinator.connect", return_value=_MockConnect(mock_conn)):
with patch("ssh_command.coordinator.exists", return_value=True):
await self.handler(service_call)
await self.handler(service_call)

call_kwargs = mock_conn.run.call_args[1]
self.assertEqual(call_kwargs["input"], "file content\n")
Expand All @@ -167,8 +160,7 @@ async def test_input_string_not_file(self):
service_call = self._make_service_call(data)

with patch("ssh_command.coordinator.connect", return_value=_MockConnect(mock_conn)):
with patch("ssh_command.coordinator.exists", return_value=False):
await self.handler(service_call)
await self.handler(service_call)

call_kwargs = mock_conn.run.call_args[1]
self.assertEqual(call_kwargs["input"], "inline input")
Expand All @@ -178,8 +170,7 @@ async def test_check_known_hosts_false(self):
service_call = self._make_service_call(SERVICE_DATA_BASE)

with patch("ssh_command.coordinator.connect", return_value=_MockConnect(mock_conn)) as mock_connect:
with patch("ssh_command.coordinator.exists", return_value=False):
await self.handler(service_call)
await self.handler(service_call)

call_kwargs = mock_connect.call_args[1]
self.assertIsNone(call_kwargs["known_hosts"])
Expand All @@ -191,7 +182,7 @@ async def test_known_hosts_file_exists(self):
service_call = self._make_service_call(data)

with patch("ssh_command.coordinator.connect", return_value=_MockConnect(mock_conn)) as mock_connect:
with patch("ssh_command.coordinator.exists", return_value=True):
with patch("pathlib.Path.exists", return_value=True):
with patch("ssh_command.coordinator.read_known_hosts", return_value=mock_known_hosts) as mock_rkh:
await self.handler(service_call)

Expand All @@ -205,7 +196,7 @@ async def test_check_known_hosts_default_path_missing(self):
service_call = self._make_service_call(data)

with patch("ssh_command.coordinator.connect", return_value=_MockConnect(mock_conn)) as mock_connect:
with patch("ssh_command.coordinator.exists", return_value=False):
with patch("pathlib.Path.exists", return_value=False):
await self.handler(service_call)

call_kwargs = mock_connect.call_args[1]
Expand Down
Loading
Loading