Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 106 additions & 0 deletions tests/test_sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2190,3 +2190,109 @@ async def test_ubisys_polled_em_keeps_polling_when_disabled(
# Polling task must still be running (no duplicate created)
assert entity._polling_task is not None
assert not entity._polling_task.done()


async def test_pollable_sensor_enable_non_idempotent_disable_leaves_orphan_poll_task(
zha_gateway: Gateway,
) -> None:
"""Test PollableSensor enable/disable lifecycle is idempotent."""
zigpy_device = elec_measurement_zigpy_device_mock(zha_gateway)
zha_device = await join_zigpy_device(zha_gateway, zigpy_device)
entity = get_entity(
zha_device,
platform=Platform.SENSOR,
exact_entity_type=sensor.PolledElectricalMeasurement,
)

assert entity._polling_task is not None

first_task: asyncio.Task | None = None
second_task: asyncio.Task | None = None

# Issue being validated:
# PollableSensor.enable() always calls maybe_start_polling() and does not guard
# against an existing polling task, so repeated enable() calls create extra tasks.
# PollableSensor.disable() then only cancels/removes self._polling_task (latest).
#
# Why this is a problem:
# A stale poll task can outlive disable(), continue background work, and leak task
# ownership because only the newest handle is tracked for cancellation.
try:
# Reset baseline from on_add() auto-started polling.
entity.disable()
await asyncio.sleep(0)

entity.enable()
first_task = entity._polling_task
assert first_task is not None
assert not first_task.done()

entity.enable()
second_task = entity._polling_task
assert second_task is not None
assert second_task is first_task
assert not second_task.done()

entity.disable()
await asyncio.sleep(0)

assert first_task.cancelled()
assert second_task.cancelled()
assert first_task not in entity._tracked_tasks
assert second_task not in entity._tracked_tasks
finally:
for task in (first_task, second_task):
if task is None:
continue
if task in entity._tracked_tasks:
entity._tracked_tasks.remove(task)
if not task.done():
task.cancel()
await asyncio.gather(
*(task for task in (first_task, second_task) if task is not None),
return_exceptions=True,
)


async def test_pollable_sensor_replaces_completed_polling_task(
zha_gateway: Gateway,
) -> None:
"""Test completed poll task handles are replaced cleanly."""
zigpy_device = elec_measurement_zigpy_device_mock(zha_gateway)
zha_device = await join_zigpy_device(zha_gateway, zigpy_device)
entity = get_entity(
zha_device,
platform=Platform.SENSOR,
exact_entity_type=sensor.PolledElectricalMeasurement,
)

# Reset baseline from on_add() auto-started polling.
entity.disable()
await asyncio.sleep(0)

completed_task = asyncio.create_task(asyncio.sleep(0))
await completed_task
entity._polling_task = completed_task
entity._tracked_tasks.append(completed_task)

# Issue being validated:
# maybe_start_polling() must remove a completed poll task from tracking before
# creating the next task handle.
#
# Why this is a problem:
# Completed task handles left in _tracked_tasks can accumulate stale lifecycle
# state and interfere with deterministic cleanup in disable/remove paths.
replacement_task: asyncio.Task | None = None
try:
entity.maybe_start_polling()
replacement_task = entity._polling_task

assert replacement_task is not None
assert replacement_task is not completed_task
assert completed_task not in entity._tracked_tasks
assert replacement_task in entity._tracked_tasks
finally:
entity.disable()
await asyncio.sleep(0)
if replacement_task and replacement_task in entity._tracked_tasks:
entity._tracked_tasks.remove(replacement_task)
36 changes: 23 additions & 13 deletions zha/application/platforms/sensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,18 +418,27 @@ def should_poll(self) -> bool:

def maybe_start_polling(self) -> None:
"""Start polling if necessary."""
if self.should_poll:
self._polling_task = self.device.gateway.async_create_background_task(
self._refresh(),
name=f"sensor_state_poller_{self.unique_id}_{self.__class__.__name__}",
eager_start=True,
untracked=True,
)
self._tracked_tasks.append(self._polling_task)
self.debug(
"started polling with refresh interval of %s",
getattr(self, "__polling_interval"),
)
if not self.should_poll:
return

if self._polling_task and not self._polling_task.done():
return

if self._polling_task and self._polling_task.done():
with contextlib.suppress(ValueError):
self._tracked_tasks.remove(self._polling_task)

self._polling_task = self.device.gateway.async_create_background_task(
self._refresh(),
name=f"sensor_state_poller_{self.unique_id}_{self.__class__.__name__}",
eager_start=True,
untracked=True,
)
self._tracked_tasks.append(self._polling_task)
self.debug(
"started polling with refresh interval of %s",
getattr(self, "__polling_interval"),
)

def enable(self) -> None:
"""Enable the entity."""
Expand All @@ -440,7 +449,8 @@ def disable(self) -> None:
"""Disable the entity."""
super().disable()
if self._polling_task:
self._tracked_tasks.remove(self._polling_task)
with contextlib.suppress(ValueError):
self._tracked_tasks.remove(self._polling_task)
self._polling_task.cancel()
self._polling_task = None

Expand Down
Loading