diff --git a/__init__.py b/__init__.py index 07dbc7f..61d9e4a 100644 --- a/__init__.py +++ b/__init__.py @@ -2,22 +2,20 @@ from __future__ import annotations -from pathlib import Path from typing import Any import voluptuous as vol -from aiofiles import open as aioopen from aiofiles.ospath import exists -from asyncssh import HostKeyNotVerifiable, PermissionDenied, connect, read_known_hosts from homeassistant.config_entries import ConfigEntry -from homeassistant.const import CONF_USERNAME, CONF_PASSWORD, CONF_HOST, CONF_COMMAND, CONF_TIMEOUT, CONF_ERROR +from homeassistant.const import CONF_USERNAME, CONF_PASSWORD, CONF_HOST, CONF_COMMAND, CONF_TIMEOUT from homeassistant.core import HomeAssistant, ServiceCall, SupportsResponse, ServiceResponse from homeassistant.exceptions import ServiceValidationError from homeassistant.helpers import config_validation as cv from homeassistant.helpers.typing import ConfigType from .const import DOMAIN, SERVICE_EXECUTE, CONF_KEY_FILE, CONF_INPUT, CONST_DEFAULT_TIMEOUT, \ - CONF_CHECK_KNOWN_HOSTS, CONF_KNOWN_HOSTS, CONF_CLIENT_KEYS, CONF_CHECK, CONF_OUTPUT, CONF_EXIT_STATUS + CONF_CHECK_KNOWN_HOSTS, CONF_KNOWN_HOSTS +from .coordinator import SshCommandCoordinator CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN) # pylint: disable=invalid-name @@ -77,90 +75,22 @@ async def _validate_service_data(data: dict[str, Any]) -> None: ) -async def _resolve_known_hosts(hass: HomeAssistant, check_known_hosts: bool, known_hosts: str | None) -> str | None: - """Resolve the known_hosts value for the SSH connection.""" - if not check_known_hosts: - return None - if not known_hosts: - known_hosts = str(Path('~', '.ssh', CONF_KNOWN_HOSTS).expanduser()) - if await exists(known_hosts): - # open the known hosts file asynchronously, otherwise Home Assistant will complain about blocking I/O - return await hass.async_add_executor_job(read_known_hosts, known_hosts) - return known_hosts - - async def async_setup(hass: HomeAssistant, _config: ConfigType) -> bool: """Set up the SSH Command integration.""" + hass.data.setdefault(DOMAIN, {}) + async def async_execute(service_call: ServiceCall) -> ServiceResponse: await _validate_service_data(service_call.data) - host = service_call.data.get(CONF_HOST) - username = service_call.data.get(CONF_USERNAME) - password = service_call.data.get(CONF_PASSWORD) - key_file = service_call.data.get(CONF_KEY_FILE) - command = service_call.data.get(CONF_COMMAND) - input_data = service_call.data.get(CONF_INPUT) - check_known_hosts = service_call.data.get(CONF_CHECK_KNOWN_HOSTS, True) - known_hosts = service_call.data.get(CONF_KNOWN_HOSTS) - timeout = service_call.data.get(CONF_TIMEOUT, CONST_DEFAULT_TIMEOUT) - - if input_data: - if await exists(input_data): - # input is a file path, read it and send content as input - async with aioopen(input_data, 'r') as sf: - input_data = await sf.read() - - conn_kwargs = { - CONF_HOST: host, - CONF_USERNAME: username, - CONF_PASSWORD: password, - CONF_CLIENT_KEYS: key_file, - CONF_KNOWN_HOSTS: await _resolve_known_hosts(hass, check_known_hosts, known_hosts), - } - - run_kwargs = { - CONF_COMMAND: command, - CONF_CHECK: False, - CONF_TIMEOUT: timeout, - } - - if input_data: - run_kwargs[CONF_INPUT] = input_data - - try: - async with connect(**conn_kwargs) as conn: - result = await conn.run(**run_kwargs) - except HostKeyNotVerifiable as exc: - raise ServiceValidationError( - "The host key could not be verified.", - translation_domain=DOMAIN, - translation_key="host_key_not_verifiable", - ) from exc - except PermissionDenied as exc: + # ssh_command is a single-instance integration (enforced by single_instance_allowed + # in the config flow), so there is at most one coordinator in hass.data[DOMAIN]. + coordinator = next(iter(hass.data.get(DOMAIN, {}).values()), None) + if coordinator is None: raise ServiceValidationError( - "SSH login failed.", + "SSH Command integration is not set up.", translation_domain=DOMAIN, - translation_key="login_failed", - ) from exc - except TimeoutError as exc: - raise ServiceValidationError( - "Connection timed out.", - translation_domain=DOMAIN, - translation_key="connection_timed_out", - ) from exc - except OSError as e: - if e.strerror == 'Temporary failure in name resolution': - raise ServiceValidationError( - "Host is not reachable.", - translation_domain=DOMAIN, - translation_key="host_not_reachable", - ) from e - raise - - return { - CONF_OUTPUT: result.stdout, - CONF_ERROR: result.stderr, - CONF_EXIT_STATUS: result.exit_status, - } + translation_key="integration_not_set_up", + ) + return await coordinator.async_execute(service_call.data) hass.services.async_register( DOMAIN, @@ -173,11 +103,14 @@ async def async_execute(service_call: ServiceCall) -> ServiceResponse: return True -async def async_setup_entry(_hass: HomeAssistant, _entry: ConfigEntry) -> bool: - """Set up SSH Command from a config entry. Nothing to do here.""" +async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: + """Set up SSH Command from a config entry.""" + coordinator = SshCommandCoordinator(hass) + hass.data.setdefault(DOMAIN, {})[entry.entry_id] = coordinator return True -async def async_unload_entry(_hass: HomeAssistant, _entry: ConfigEntry) -> bool: - """Unload a config entry. Nothing to do here.""" +async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: + """Unload a config entry.""" + hass.data.get(DOMAIN, {}).pop(entry.entry_id, None) return True diff --git a/coordinator.py b/coordinator.py new file mode 100644 index 0000000..3c64bb5 --- /dev/null +++ b/coordinator.py @@ -0,0 +1,131 @@ +"""Coordinator for the SSH Command integration. + +The SshCommandCoordinator is the single owner of all SSH I/O for the +integration. It encapsulates connection management and command execution +so that the service handler in __init__.py is a pure dispatcher. + +Preferred HA pattern: coordinator (or "client") owns I/O; the service +handler validates input and delegates to this class. +""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import Any + +from aiofiles import open as aioopen +from aiofiles.ospath import exists +from asyncssh import HostKeyNotVerifiable, PermissionDenied, connect, read_known_hosts + +from homeassistant.const import CONF_USERNAME, CONF_PASSWORD, CONF_HOST, CONF_COMMAND, CONF_TIMEOUT +from homeassistant.core import HomeAssistant +from homeassistant.exceptions import ServiceValidationError + +from .const import ( + DOMAIN, + CONF_KEY_FILE, + CONF_INPUT, + CONF_CHECK_KNOWN_HOSTS, + CONF_KNOWN_HOSTS, + CONF_CLIENT_KEYS, + CONF_CHECK, + CONF_OUTPUT, + CONF_ERROR, + CONF_EXIT_STATUS, + CONST_DEFAULT_TIMEOUT, +) + +_LOGGER = logging.getLogger(__name__) + + +class SshCommandCoordinator: + """Single owner of all SSH I/O for the SSH Command integration. + + Preferred HA pattern: coordinator owns I/O; the service handler in + __init__.py validates input and delegates execution to this class. + """ + + def __init__(self, hass: HomeAssistant) -> None: + """Initialize the coordinator.""" + self.hass = hass + + async def async_execute(self, data: dict[str, Any]) -> dict[str, Any]: + """Execute an SSH command and return stdout, stderr and exit status.""" + host = data.get(CONF_HOST) + username = data.get(CONF_USERNAME) + password = data.get(CONF_PASSWORD) + key_file = data.get(CONF_KEY_FILE) + command = data.get(CONF_COMMAND) + input_data = data.get(CONF_INPUT) + check_known_hosts = data.get(CONF_CHECK_KNOWN_HOSTS, True) + known_hosts = data.get(CONF_KNOWN_HOSTS) + timeout = data.get(CONF_TIMEOUT, CONST_DEFAULT_TIMEOUT) + + if input_data: + if await exists(input_data): + async with aioopen(input_data, 'r') as sf: + input_data = await sf.read() + + conn_kwargs = { + CONF_HOST: host, + CONF_USERNAME: username, + CONF_PASSWORD: password, + CONF_CLIENT_KEYS: key_file, + CONF_KNOWN_HOSTS: await self._resolve_known_hosts(check_known_hosts, known_hosts), + } + + run_kwargs: dict[str, Any] = { + CONF_COMMAND: command, + CONF_CHECK: False, + CONF_TIMEOUT: timeout, + } + + if input_data: + run_kwargs[CONF_INPUT] = input_data + + try: + async with connect(**conn_kwargs) as conn: + result = await conn.run(**run_kwargs) + except HostKeyNotVerifiable as exc: + raise ServiceValidationError( + "The host key could not be verified.", + translation_domain=DOMAIN, + translation_key="host_key_not_verifiable", + ) from exc + except PermissionDenied as exc: + raise ServiceValidationError( + "SSH login failed.", + translation_domain=DOMAIN, + translation_key="login_failed", + ) from exc + except TimeoutError as exc: + raise ServiceValidationError( + "Connection timed out.", + translation_domain=DOMAIN, + translation_key="connection_timed_out", + ) from exc + except OSError as e: + if e.strerror == 'Temporary failure in name resolution': + raise ServiceValidationError( + "Host is not reachable.", + translation_domain=DOMAIN, + translation_key="host_not_reachable", + ) from e + raise + + return { + CONF_OUTPUT: result.stdout, + CONF_ERROR: result.stderr, + CONF_EXIT_STATUS: result.exit_status, + } + + async def _resolve_known_hosts(self, check_known_hosts: bool, known_hosts: str | None) -> str | None: + """Resolve the known_hosts value for the SSH connection.""" + if not check_known_hosts: + return None + if not known_hosts: + known_hosts = str(Path('~', '.ssh', CONF_KNOWN_HOSTS).expanduser()) + if await exists(known_hosts): + return await self.hass.async_add_executor_job(read_known_hosts, known_hosts) + return known_hosts diff --git a/test/test_async_execute.py b/test/test_async_execute.py index f14ac32..890f410 100644 --- a/test/test_async_execute.py +++ b/test/test_async_execute.py @@ -15,7 +15,7 @@ from homeassistant.exceptions import ServiceValidationError -from ssh_command import async_setup +from ssh_command import async_setup, async_setup_entry from ssh_command.const import CONF_ERROR, CONF_EXIT_STATUS, CONF_OUTPUT SERVICE_DATA_BASE = { @@ -53,11 +53,15 @@ class TestAsyncExecute(unittest.IsolatedAsyncioTestCase): async def asyncSetUp(self): self.mock_hass = MagicMock() + self.mock_hass.data = {} async def _executor_job(func, *args): return func(*args) self.mock_hass.async_add_executor_job = AsyncMock(side_effect=_executor_job) + mock_entry = MagicMock() + mock_entry.entry_id = "test_entry" + await async_setup_entry(self.mock_hass, mock_entry) await async_setup(self.mock_hass, {}) self.handler = self.mock_hass.services.async_register.call_args[0][2] @@ -79,8 +83,8 @@ async def test_success(self): mock_conn = self._make_mock_conn(stdout="hello\n", stderr="", exit_status=0) service_call = self._make_service_call(SERVICE_DATA_BASE) - with patch("ssh_command.connect", return_value=_MockConnect(mock_conn)): - with patch("ssh_command.exists", return_value=False): + with patch("ssh_command.coordinator.connect", return_value=_MockConnect(mock_conn)): + with patch("ssh_command.coordinator.exists", return_value=False): result = await self.handler(service_call) self.assertEqual(result[CONF_OUTPUT], "hello\n") @@ -90,8 +94,8 @@ async def test_success(self): async def test_host_key_not_verifiable(self): service_call = self._make_service_call(SERVICE_DATA_BASE) - with patch("ssh_command.connect", return_value=_MockConnectRaises(HostKeyNotVerifiable("test"))): - with patch("ssh_command.exists", return_value=False): + with patch("ssh_command.coordinator.connect", return_value=_MockConnectRaises(HostKeyNotVerifiable("test"))): + with patch("ssh_command.coordinator.exists", return_value=False): with self.assertRaises(ServiceValidationError) as ctx: await self.handler(service_call) @@ -100,8 +104,8 @@ async def test_host_key_not_verifiable(self): async def test_permission_denied(self): service_call = self._make_service_call(SERVICE_DATA_BASE) - with patch("ssh_command.connect", return_value=_MockConnectRaises(PermissionDenied("auth failed"))): - with patch("ssh_command.exists", return_value=False): + with patch("ssh_command.coordinator.connect", return_value=_MockConnectRaises(PermissionDenied("auth failed"))): + with patch("ssh_command.coordinator.exists", return_value=False): with self.assertRaises(ServiceValidationError) as ctx: await self.handler(service_call) @@ -110,8 +114,8 @@ async def test_permission_denied(self): async def test_timeout(self): service_call = self._make_service_call(SERVICE_DATA_BASE) - with patch("ssh_command.connect", return_value=_MockConnectRaises(TimeoutError())): - with patch("ssh_command.exists", return_value=False): + with patch("ssh_command.coordinator.connect", return_value=_MockConnectRaises(TimeoutError())): + with patch("ssh_command.coordinator.exists", return_value=False): with self.assertRaises(ServiceValidationError) as ctx: await self.handler(service_call) @@ -122,8 +126,8 @@ async def test_name_resolution_failure(self): err.strerror = "Temporary failure in name resolution" service_call = self._make_service_call(SERVICE_DATA_BASE) - with patch("ssh_command.connect", return_value=_MockConnectRaises(err)): - with patch("ssh_command.exists", return_value=False): + with patch("ssh_command.coordinator.connect", return_value=_MockConnectRaises(err)): + with patch("ssh_command.coordinator.exists", return_value=False): with self.assertRaises(ServiceValidationError) as ctx: await self.handler(service_call) @@ -134,8 +138,8 @@ async def test_other_oserror_is_reraised(self): err.strerror = "something else" service_call = self._make_service_call(SERVICE_DATA_BASE) - with patch("ssh_command.connect", return_value=_MockConnectRaises(err)): - with patch("ssh_command.exists", return_value=False): + with patch("ssh_command.coordinator.connect", return_value=_MockConnectRaises(err)): + with patch("ssh_command.coordinator.exists", return_value=False): with self.assertRaises(OSError): await self.handler(service_call) @@ -149,8 +153,8 @@ async def test_input_from_file(self): data = {**SERVICE_DATA_BASE, "command": "cat", "input": tf_path} service_call = self._make_service_call(data) - with patch("ssh_command.connect", return_value=_MockConnect(mock_conn)): - with patch("ssh_command.exists", return_value=True): + with patch("ssh_command.coordinator.connect", return_value=_MockConnect(mock_conn)): + with patch("ssh_command.coordinator.exists", return_value=True): await self.handler(service_call) call_kwargs = mock_conn.run.call_args[1] @@ -163,8 +167,8 @@ async def test_input_string_not_file(self): data = {**SERVICE_DATA_BASE, "input": "inline input"} service_call = self._make_service_call(data) - with patch("ssh_command.connect", return_value=_MockConnect(mock_conn)): - with patch("ssh_command.exists", return_value=False): + with patch("ssh_command.coordinator.connect", return_value=_MockConnect(mock_conn)): + with patch("ssh_command.coordinator.exists", return_value=False): await self.handler(service_call) call_kwargs = mock_conn.run.call_args[1] @@ -174,8 +178,8 @@ async def test_check_known_hosts_false(self): mock_conn = self._make_mock_conn() service_call = self._make_service_call(SERVICE_DATA_BASE) - with patch("ssh_command.connect", return_value=_MockConnect(mock_conn)) as mock_connect: - with patch("ssh_command.exists", return_value=False): + with patch("ssh_command.coordinator.connect", return_value=_MockConnect(mock_conn)) as mock_connect: + with patch("ssh_command.coordinator.exists", return_value=False): await self.handler(service_call) call_kwargs = mock_connect.call_args[1] @@ -187,9 +191,9 @@ async def test_known_hosts_file_exists(self): data = {**SERVICE_DATA_BASE, "check_known_hosts": True, "known_hosts": "/home/user/.ssh/known_hosts"} service_call = self._make_service_call(data) - with patch("ssh_command.connect", return_value=_MockConnect(mock_conn)) as mock_connect: - with patch("ssh_command.exists", return_value=True): - with patch("ssh_command.read_known_hosts", return_value=mock_known_hosts) as mock_rkh: + with patch("ssh_command.coordinator.connect", return_value=_MockConnect(mock_conn)) as mock_connect: + with patch("ssh_command.coordinator.exists", return_value=True): + with patch("ssh_command.coordinator.read_known_hosts", return_value=mock_known_hosts) as mock_rkh: await self.handler(service_call) mock_rkh.assert_called_once_with("/home/user/.ssh/known_hosts") @@ -201,8 +205,8 @@ async def test_check_known_hosts_default_path_missing(self): data = {**SERVICE_DATA_BASE, "check_known_hosts": True} service_call = self._make_service_call(data) - with patch("ssh_command.connect", return_value=_MockConnect(mock_conn)) as mock_connect: - with patch("ssh_command.exists", return_value=False): + with patch("ssh_command.coordinator.connect", return_value=_MockConnect(mock_conn)) as mock_connect: + with patch("ssh_command.coordinator.exists", return_value=False): await self.handler(service_call) call_kwargs = mock_connect.call_args[1] diff --git a/test/test_async_setup.py b/test/test_async_setup.py index 11e74c2..c044cfe 100644 --- a/test/test_async_setup.py +++ b/test/test_async_setup.py @@ -11,12 +11,14 @@ from ssh_command import async_setup, async_setup_entry, async_unload_entry from ssh_command.const import DOMAIN, SERVICE_EXECUTE +from ssh_command.coordinator import SshCommandCoordinator class TestAsyncSetup(unittest.IsolatedAsyncioTestCase): async def test_registers_service_and_returns_true(self): mock_hass = MagicMock() + mock_hass.data = {} result = await async_setup(mock_hass, {}) @@ -31,19 +33,59 @@ class TestAsyncSetupEntry(unittest.IsolatedAsyncioTestCase): async def test_returns_true(self): mock_hass = MagicMock() + mock_hass.data = {} mock_entry = MagicMock() + mock_entry.entry_id = "test_entry" result = await async_setup_entry(mock_hass, mock_entry) self.assertTrue(result) + async def test_creates_coordinator_in_hass_data(self): + mock_hass = MagicMock() + mock_hass.data = {} + mock_entry = MagicMock() + mock_entry.entry_id = "test_entry" + + await async_setup_entry(mock_hass, mock_entry) + + self.assertIn(DOMAIN, mock_hass.data) + self.assertIn("test_entry", mock_hass.data[DOMAIN]) + self.assertIsInstance(mock_hass.data[DOMAIN]["test_entry"], SshCommandCoordinator) + + async def test_coordinator_holds_hass_reference(self): + mock_hass = MagicMock() + mock_hass.data = {} + mock_entry = MagicMock() + mock_entry.entry_id = "test_entry" + + await async_setup_entry(mock_hass, mock_entry) + + coordinator = mock_hass.data[DOMAIN]["test_entry"] + self.assertIs(coordinator.hass, mock_hass) + class TestAsyncUnloadEntry(unittest.IsolatedAsyncioTestCase): async def test_returns_true(self): mock_hass = MagicMock() + mock_hass.data = {} mock_entry = MagicMock() + mock_entry.entry_id = "test_entry" result = await async_unload_entry(mock_hass, mock_entry) self.assertTrue(result) + + async def test_removes_coordinator_from_hass_data(self): + mock_hass = MagicMock() + mock_hass.data = {} + mock_entry = MagicMock() + mock_entry.entry_id = "test_entry" + + await async_setup_entry(mock_hass, mock_entry) + self.assertIn("test_entry", mock_hass.data[DOMAIN]) + + await async_unload_entry(mock_hass, mock_entry) + + self.assertNotIn("test_entry", mock_hass.data[DOMAIN]) diff --git a/test/test_coordinator.py b/test/test_coordinator.py new file mode 100644 index 0000000..afe3052 --- /dev/null +++ b/test/test_coordinator.py @@ -0,0 +1,151 @@ +import sys +import unittest +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +absolute_mock_path = str(Path(__file__).parent / "homeassistant_mock") +sys.path.insert(0, absolute_mock_path) + +absolute_plugin_path = str(Path(__file__).parent.parent.parent.absolute()) +sys.path.insert(0, absolute_plugin_path) + +from asyncssh import HostKeyNotVerifiable, PermissionDenied + +from homeassistant.exceptions import ServiceValidationError + +from ssh_command.coordinator import SshCommandCoordinator +from ssh_command.const import CONF_OUTPUT, CONF_ERROR, CONF_EXIT_STATUS + +EXECUTE_DATA_BASE = { + "host": "192.0.2.1", + "username": "user", + "password": "secret", + "command": "echo hello", + "check_known_hosts": False, +} + + +class _MockConnect: + def __init__(self, conn): + self._conn = conn + + async def __aenter__(self): + return self._conn + + async def __aexit__(self, *args): + return None + + +class _MockConnectRaises: + def __init__(self, exc): + self._exc = exc + + async def __aenter__(self): + raise self._exc + + async def __aexit__(self, *args): + return None + + +class TestSshCommandCoordinator(unittest.IsolatedAsyncioTestCase): + + def setUp(self): + self.mock_hass = MagicMock() + + async def _executor_job(func, *args): + return func(*args) + + self.mock_hass.async_add_executor_job = AsyncMock(side_effect=_executor_job) + self.coordinator = SshCommandCoordinator(self.mock_hass) + + def _make_mock_conn(self, stdout="", stderr="", exit_status=0): + mock_result = MagicMock() + mock_result.stdout = stdout + mock_result.stderr = stderr + mock_result.exit_status = exit_status + mock_conn = AsyncMock() + mock_conn.run = AsyncMock(return_value=mock_result) + return mock_conn + + async def test_async_execute_success(self): + mock_conn = self._make_mock_conn(stdout="hello\n", stderr="", exit_status=0) + + with patch("ssh_command.coordinator.connect", return_value=_MockConnect(mock_conn)): + with patch("ssh_command.coordinator.exists", return_value=False): + result = await self.coordinator.async_execute(EXECUTE_DATA_BASE) + + self.assertEqual(result[CONF_OUTPUT], "hello\n") + self.assertEqual(result[CONF_ERROR], "") + self.assertEqual(result[CONF_EXIT_STATUS], 0) + + async def test_async_execute_host_key_not_verifiable(self): + with patch("ssh_command.coordinator.connect", return_value=_MockConnectRaises(HostKeyNotVerifiable("test"))): + with patch("ssh_command.coordinator.exists", return_value=False): + with self.assertRaises(ServiceValidationError) as ctx: + await self.coordinator.async_execute(EXECUTE_DATA_BASE) + + self.assertEqual(ctx.exception.translation_key, "host_key_not_verifiable") + + async def test_async_execute_permission_denied(self): + with patch("ssh_command.coordinator.connect", return_value=_MockConnectRaises(PermissionDenied("auth failed"))): + with patch("ssh_command.coordinator.exists", return_value=False): + with self.assertRaises(ServiceValidationError) as ctx: + await self.coordinator.async_execute(EXECUTE_DATA_BASE) + + self.assertEqual(ctx.exception.translation_key, "login_failed") + + async def test_async_execute_timeout(self): + with patch("ssh_command.coordinator.connect", return_value=_MockConnectRaises(TimeoutError())): + with patch("ssh_command.coordinator.exists", return_value=False): + with self.assertRaises(ServiceValidationError) as ctx: + await self.coordinator.async_execute(EXECUTE_DATA_BASE) + + self.assertEqual(ctx.exception.translation_key, "connection_timed_out") + + async def test_async_execute_name_resolution_failure(self): + err = OSError() + err.strerror = "Temporary failure in name resolution" + + with patch("ssh_command.coordinator.connect", return_value=_MockConnectRaises(err)): + with patch("ssh_command.coordinator.exists", return_value=False): + with self.assertRaises(ServiceValidationError) as ctx: + await self.coordinator.async_execute(EXECUTE_DATA_BASE) + + self.assertEqual(ctx.exception.translation_key, "host_not_reachable") + + async def test_async_execute_other_oserror_reraised(self): + err = OSError("something else") + err.strerror = "something else" + + with patch("ssh_command.coordinator.connect", return_value=_MockConnectRaises(err)): + with patch("ssh_command.coordinator.exists", return_value=False): + with self.assertRaises(OSError): + await self.coordinator.async_execute(EXECUTE_DATA_BASE) + + async def test_resolve_known_hosts_check_disabled(self): + result = await self.coordinator._resolve_known_hosts(False, None) + self.assertIsNone(result) + + async def test_resolve_known_hosts_file_exists(self): + mock_known_hosts = MagicMock() + + with patch("ssh_command.coordinator.exists", return_value=True): + with patch("ssh_command.coordinator.read_known_hosts", return_value=mock_known_hosts) as mock_rkh: + result = await self.coordinator._resolve_known_hosts(True, "/home/user/.ssh/known_hosts") + + mock_rkh.assert_called_once_with("/home/user/.ssh/known_hosts") + self.assertIs(result, mock_known_hosts) + + async def test_resolve_known_hosts_file_missing(self): + with patch("ssh_command.coordinator.exists", return_value=False): + result = await self.coordinator._resolve_known_hosts(True, "/nonexistent/known_hosts") + + self.assertEqual(result, "/nonexistent/known_hosts") + + async def test_resolve_known_hosts_default_path(self): + with patch("ssh_command.coordinator.exists", return_value=False): + result = await self.coordinator._resolve_known_hosts(True, None) + + self.assertIsInstance(result, str) + self.assertIn(".ssh", result) + self.assertIn("known_hosts", result)