Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 20 additions & 87 deletions __init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,20 @@

from __future__ import annotations

from pathlib import Path
from typing import Any

import voluptuous as vol
from aiofiles import open as aioopen
from aiofiles.ospath import exists
from asyncssh import HostKeyNotVerifiable, PermissionDenied, connect, read_known_hosts

from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_USERNAME, CONF_PASSWORD, CONF_HOST, CONF_COMMAND, CONF_TIMEOUT, CONF_ERROR
from homeassistant.const import CONF_USERNAME, CONF_PASSWORD, CONF_HOST, CONF_COMMAND, CONF_TIMEOUT
from homeassistant.core import HomeAssistant, ServiceCall, SupportsResponse, ServiceResponse
from homeassistant.exceptions import ServiceValidationError
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.typing import ConfigType
from .const import DOMAIN, SERVICE_EXECUTE, CONF_KEY_FILE, CONF_INPUT, CONST_DEFAULT_TIMEOUT, \
CONF_CHECK_KNOWN_HOSTS, CONF_KNOWN_HOSTS, CONF_CLIENT_KEYS, CONF_CHECK, CONF_OUTPUT, CONF_EXIT_STATUS
CONF_CHECK_KNOWN_HOSTS, CONF_KNOWN_HOSTS
from .coordinator import SshCommandCoordinator

CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN) # pylint: disable=invalid-name

Expand Down Expand Up @@ -77,90 +75,22 @@ async def _validate_service_data(data: dict[str, Any]) -> None:
)


async def _resolve_known_hosts(hass: HomeAssistant, check_known_hosts: bool, known_hosts: str | None) -> str | None:
"""Resolve the known_hosts value for the SSH connection."""
if not check_known_hosts:
return None
if not known_hosts:
known_hosts = str(Path('~', '.ssh', CONF_KNOWN_HOSTS).expanduser())
if await exists(known_hosts):
# open the known hosts file asynchronously, otherwise Home Assistant will complain about blocking I/O
return await hass.async_add_executor_job(read_known_hosts, known_hosts)
return known_hosts


async def async_setup(hass: HomeAssistant, _config: ConfigType) -> bool:
"""Set up the SSH Command integration."""
hass.data.setdefault(DOMAIN, {})

async def async_execute(service_call: ServiceCall) -> ServiceResponse:
await _validate_service_data(service_call.data)
host = service_call.data.get(CONF_HOST)
username = service_call.data.get(CONF_USERNAME)
password = service_call.data.get(CONF_PASSWORD)
key_file = service_call.data.get(CONF_KEY_FILE)
command = service_call.data.get(CONF_COMMAND)
input_data = service_call.data.get(CONF_INPUT)
check_known_hosts = service_call.data.get(CONF_CHECK_KNOWN_HOSTS, True)
known_hosts = service_call.data.get(CONF_KNOWN_HOSTS)
timeout = service_call.data.get(CONF_TIMEOUT, CONST_DEFAULT_TIMEOUT)

if input_data:
if await exists(input_data):
# input is a file path, read it and send content as input
async with aioopen(input_data, 'r') as sf:
input_data = await sf.read()

conn_kwargs = {
CONF_HOST: host,
CONF_USERNAME: username,
CONF_PASSWORD: password,
CONF_CLIENT_KEYS: key_file,
CONF_KNOWN_HOSTS: await _resolve_known_hosts(hass, check_known_hosts, known_hosts),
}

run_kwargs = {
CONF_COMMAND: command,
CONF_CHECK: False,
CONF_TIMEOUT: timeout,
}

if input_data:
run_kwargs[CONF_INPUT] = input_data

try:
async with connect(**conn_kwargs) as conn:
result = await conn.run(**run_kwargs)
except HostKeyNotVerifiable as exc:
raise ServiceValidationError(
"The host key could not be verified.",
translation_domain=DOMAIN,
translation_key="host_key_not_verifiable",
) from exc
except PermissionDenied as exc:
# ssh_command is a single-instance integration (enforced by single_instance_allowed
# in the config flow), so there is at most one coordinator in hass.data[DOMAIN].
coordinator = next(iter(hass.data.get(DOMAIN, {}).values()), None)
if coordinator is None:
raise ServiceValidationError(
"SSH login failed.",
"SSH Command integration is not set up.",
translation_domain=DOMAIN,
translation_key="login_failed",
) from exc
except TimeoutError as exc:
raise ServiceValidationError(
"Connection timed out.",
translation_domain=DOMAIN,
translation_key="connection_timed_out",
) from exc
except OSError as e:
if e.strerror == 'Temporary failure in name resolution':
raise ServiceValidationError(
"Host is not reachable.",
translation_domain=DOMAIN,
translation_key="host_not_reachable",
) from e
raise

return {
CONF_OUTPUT: result.stdout,
CONF_ERROR: result.stderr,
CONF_EXIT_STATUS: result.exit_status,
}
translation_key="integration_not_set_up",
)
return await coordinator.async_execute(service_call.data)

hass.services.async_register(
DOMAIN,
Expand All @@ -173,11 +103,14 @@ async def async_execute(service_call: ServiceCall) -> ServiceResponse:
return True


async def async_setup_entry(_hass: HomeAssistant, _entry: ConfigEntry) -> bool:
"""Set up SSH Command from a config entry. Nothing to do here."""
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Set up SSH Command from a config entry."""
coordinator = SshCommandCoordinator(hass)
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = coordinator
return True


async def async_unload_entry(_hass: HomeAssistant, _entry: ConfigEntry) -> bool:
"""Unload a config entry. Nothing to do here."""
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload a config entry."""
hass.data.get(DOMAIN, {}).pop(entry.entry_id, None)
return True
131 changes: 131 additions & 0 deletions coordinator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
"""Coordinator for the SSH Command integration.

The SshCommandCoordinator is the single owner of all SSH I/O for the
integration. It encapsulates connection management and command execution
so that the service handler in __init__.py is a pure dispatcher.

Preferred HA pattern: coordinator (or "client") owns I/O; the service
handler validates input and delegates to this class.
"""

from __future__ import annotations

import logging
from pathlib import Path
from typing import Any

from aiofiles import open as aioopen
from aiofiles.ospath import exists
from asyncssh import HostKeyNotVerifiable, PermissionDenied, connect, read_known_hosts

from homeassistant.const import CONF_USERNAME, CONF_PASSWORD, CONF_HOST, CONF_COMMAND, CONF_TIMEOUT
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ServiceValidationError

from .const import (
DOMAIN,
CONF_KEY_FILE,
CONF_INPUT,
CONF_CHECK_KNOWN_HOSTS,
CONF_KNOWN_HOSTS,
CONF_CLIENT_KEYS,
CONF_CHECK,
CONF_OUTPUT,
CONF_ERROR,
CONF_EXIT_STATUS,
CONST_DEFAULT_TIMEOUT,
)

_LOGGER = logging.getLogger(__name__)


class SshCommandCoordinator:
"""Single owner of all SSH I/O for the SSH Command integration.

Preferred HA pattern: coordinator owns I/O; the service handler in
__init__.py validates input and delegates execution to this class.
"""

def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the coordinator."""
self.hass = hass

async def async_execute(self, data: dict[str, Any]) -> dict[str, Any]:
"""Execute an SSH command and return stdout, stderr and exit status."""
host = data.get(CONF_HOST)
username = data.get(CONF_USERNAME)
password = data.get(CONF_PASSWORD)
key_file = data.get(CONF_KEY_FILE)
command = data.get(CONF_COMMAND)
input_data = data.get(CONF_INPUT)
check_known_hosts = data.get(CONF_CHECK_KNOWN_HOSTS, True)
known_hosts = data.get(CONF_KNOWN_HOSTS)
timeout = data.get(CONF_TIMEOUT, CONST_DEFAULT_TIMEOUT)

if input_data:
if await exists(input_data):
async with aioopen(input_data, 'r') as sf:
input_data = await sf.read()

conn_kwargs = {
CONF_HOST: host,
CONF_USERNAME: username,
CONF_PASSWORD: password,
CONF_CLIENT_KEYS: key_file,
CONF_KNOWN_HOSTS: await self._resolve_known_hosts(check_known_hosts, known_hosts),
}

run_kwargs: dict[str, Any] = {
CONF_COMMAND: command,
CONF_CHECK: False,
CONF_TIMEOUT: timeout,
}

if input_data:
run_kwargs[CONF_INPUT] = input_data

try:
async with connect(**conn_kwargs) as conn:
result = await conn.run(**run_kwargs)
except HostKeyNotVerifiable as exc:
raise ServiceValidationError(
"The host key could not be verified.",
translation_domain=DOMAIN,
translation_key="host_key_not_verifiable",
) from exc
except PermissionDenied as exc:
raise ServiceValidationError(
"SSH login failed.",
translation_domain=DOMAIN,
translation_key="login_failed",
) from exc
except TimeoutError as exc:
raise ServiceValidationError(
"Connection timed out.",
translation_domain=DOMAIN,
translation_key="connection_timed_out",
) from exc
except OSError as e:
if e.strerror == 'Temporary failure in name resolution':
raise ServiceValidationError(
"Host is not reachable.",
translation_domain=DOMAIN,
translation_key="host_not_reachable",
) from e
raise

return {
CONF_OUTPUT: result.stdout,
CONF_ERROR: result.stderr,
CONF_EXIT_STATUS: result.exit_status,
}

async def _resolve_known_hosts(self, check_known_hosts: bool, known_hosts: str | None) -> str | None:
"""Resolve the known_hosts value for the SSH connection."""
if not check_known_hosts:
return None
if not known_hosts:
known_hosts = str(Path('~', '.ssh', CONF_KNOWN_HOSTS).expanduser())
if await exists(known_hosts):
return await self.hass.async_add_executor_job(read_known_hosts, known_hosts)
return known_hosts
Loading
Loading