Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -145,3 +145,6 @@ docs/manager/rest-reference/openapi.json

/DIST-INFO
/INSTALL-INFO

# Raft cluster config
raft-cluster-config.toml
1 change: 1 addition & 0 deletions changes/2105.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add a Raft-based leader election process to the manager group in HA deployments in order to keep their states consistent.
10 changes: 9 additions & 1 deletion configs/manager/halfstack.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ pool-pre-ping = false


[manager]
num-proc = 4
num-proc = 3
service-addr = { host = "0.0.0.0", port = 8081 }
#user = "nobody"
#group = "nobody"
Expand All @@ -32,9 +32,16 @@ disabled-plugins = []

hide-agents = true

global-timer = "raft"

# The order of agent selection.
agent-selection-resource-priority = ["cuda", "rocm", "tpu", "cpu", "mem"]

[raft]
heartbeat-tick = 3
election-tick = 10
log-dir = "./logs"

Comment on lines +40 to +44
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From my understanding, the raft backend is only enabled when this directive is specified in the configuration file. This kind of approach can be too implicit from a system manager's perspective; please refactor the activation mechanism so that the config writer can explicitly specify which timer backend is to be used.

Copy link
Copy Markdown
Member Author

@jopemachine jopemachine May 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I refactored the related codes in 5eb1e9f to reflect the feedback, and this issue is resolved in the commit.

In this commit, I created global-timer option to manager.toml (manager local config).

I think the global-timer option makes it possible to specify explicitly whether to use raft or distributed-lock.

May I ask for a review for the commit?

[docker-registry]
ssl-verify = false

Expand All @@ -48,6 +55,7 @@ drivers = ["console"]
"aiotools" = "INFO"
"aiohttp" = "INFO"
"ai.backend" = "INFO"
"ai.backend.manager.server.raft" = "INFO"
"alembic" = "INFO"
"sqlalchemy" = "WARNING"

Expand Down
3 changes: 3 additions & 0 deletions configs/manager/sample.toml
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,9 @@ agent-selection-resource-priority = ["cuda", "rocm", "tpu", "cpu", "mem"]
# compatibility issues.
# event-loop = "asyncio"

# One of: "raft", "distributed_lock"
global-timer = "raft"

# One of: "filelock", "pg_advisory", "redlock", "etcd"
# Choose the implementation of distributed lock.
# "filelock" is the simplest one when the manager is deployed on only one node.
Expand Down
767 changes: 404 additions & 363 deletions python.lock

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ coloredlogs~=15.0
colorama>=0.4.4
cryptography>=2.8
dataclasses-json~=0.5.7
etcd-client-py==0.3.0
etcetra~=0.1.19
faker~=24.7.1
graphene~=3.3.0
Expand Down Expand Up @@ -59,6 +60,7 @@ PyYAML~=6.0
pydantic~=2.6.4
packaging>=21.3
hiredis>=2.2.3
raftify==0.1.67
redis[hiredis]==4.5.5
rich~=13.6
SQLAlchemy[postgresql_asyncpg]~=1.4.40
Expand Down Expand Up @@ -97,4 +99,3 @@ types-tabulate
backend.ai-krunner-alpine==5.2.0
backend.ai-krunner-static-gnu==4.2.0

etcd-client-py==0.3.0
113 changes: 90 additions & 23 deletions src/ai/backend/common/distributed.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import annotations

import abc
import asyncio
import logging
from typing import TYPE_CHECKING, Callable, Final

from aiomonitor.task import preserve_termination_log
from raftify import RaftNode

from ai.backend.logging import BraceStyleAdapter

Expand All @@ -16,32 +18,112 @@
log = BraceStyleAdapter(logging.getLogger(__spec__.name))


class GlobalTimer:
"""
Executes the given async function only once in the given interval,
uniquely among multiple manager instances across multiple nodes.
"""

class AbstractGlobalTimer(metaclass=abc.ABCMeta):
    """
    Base class for timers that periodically produce an event exactly once
    across multiple manager instances.

    Subclasses implement :meth:`generate_tick` to decide *which* instance is
    allowed to emit the event in each interval (e.g. via a distributed lock
    or Raft leadership).
    """

    _event_producer: Final[EventProducer]
    _event_factory: Final[Callable[[], AbstractEvent]]
    _stopped: bool
    interval: float
    initial_delay: float
    task_name: str | None

    def __init__(
        self,
        event_producer: EventProducer,
        event_factory: Callable[[], AbstractEvent],
        interval: float = 10.0,
        initial_delay: float = 0.0,
        *,
        task_name: str | None = None,
    ) -> None:
        # NOTE: The base class must NOT take a ``dist_lock`` argument.
        # Both concrete subclasses invoke
        # ``super().__init__(event_producer, event_factory, interval,
        # initial_delay, task_name=...)``; with a leading ``dist_lock``
        # parameter every positional argument would be shifted by one
        # (the producer bound as the lock, the factory as the producer, ...).
        # The lock is an implementation detail of DistributedLockGlobalTimer,
        # which stores it itself.
        self._event_producer = event_producer
        self._event_factory = event_factory
        self._stopped = False
        self.interval = interval
        self.initial_delay = initial_delay
        self.task_name = task_name

    async def join(self) -> None:
        """Start the background tick task (optionally named for monitoring)."""
        self._tick_task = asyncio.create_task(self.generate_tick())
        if self.task_name is not None:
            self._tick_task.set_name(self.task_name)

    async def leave(self) -> None:
        """Stop the tick task, cancelling it if it has not finished yet."""
        self._stopped = True
        # Yield once so the tick loop can observe the stop flag before we
        # resort to cancellation.
        await asyncio.sleep(0)
        if not self._tick_task.done():
            try:
                self._tick_task.cancel()
                await self._tick_task
            except asyncio.CancelledError:
                pass

    @abc.abstractmethod
    async def generate_tick(self) -> None:
        """Run the periodic loop, producing the event once per interval."""
        raise NotImplementedError


class RaftGlobalTimer(AbstractGlobalTimer):
    """
    Executes the given async function only once in the given interval,
    uniquely among multiple manager instances across multiple nodes.

    Uniqueness is guaranteed by Raft leadership: only the instance whose
    Raft node is currently the leader produces the tick event.
    """

    def __init__(
        self,
        raft_node: RaftNode,
        event_producer: EventProducer,
        event_factory: Callable[[], AbstractEvent],
        interval: float = 10.0,
        initial_delay: float = 0.0,
        *,
        task_name: str | None = None,
    ) -> None:
        super().__init__(
            event_producer, event_factory, interval, initial_delay, task_name=task_name
        )
        self.raft_node = raft_node

    async def generate_tick(self) -> None:
        try:
            await asyncio.sleep(self.initial_delay)
            if self._stopped:
                return
            while True:
                try:
                    if self._stopped:
                        return
                    # Only the current Raft leader emits the event so that it
                    # is produced exactly once across the manager cluster.
                    if await self.raft_node.is_leader():
                        await self._event_producer.produce_event(self._event_factory())
                    if self._stopped:
                        return
                    await asyncio.sleep(self.interval)
                except asyncio.TimeoutError:
                    # NOTE(review): presumably raised from the leadership
                    # check; retry on the next iteration. (The previous
                    # message mentioned an etcd lock, which does not exist
                    # in this raft-based implementation.)
                    log.warning("timeout raised while checking raft leadership. retrying...")
        except asyncio.CancelledError:
            pass


class DistributedLockGlobalTimer(AbstractGlobalTimer):
"""
Executes the given async function only once in the given interval,
uniquely among multiple manager instances across multiple nodes.
"""

    def __init__(
        self,
        dist_lock: AbstractDistributedLock,
        event_producer: EventProducer,
        event_factory: Callable[[], AbstractEvent],
        interval: float = 10.0,
        initial_delay: float = 0.0,
        *,
        task_name: str | None = None,
    ) -> None:
        """
        Initialize the timer with the distributed lock used to elect the
        single instance that may emit the event in each interval.
        """
        super().__init__(
            event_producer, event_factory, interval, initial_delay, task_name=task_name
        )
        # The lock is specific to this implementation and therefore kept out
        # of the abstract base class.
        self._dist_lock = dist_lock

@preserve_termination_log
async def generate_tick(self) -> None:
try:
Expand All @@ -63,18 +145,3 @@ async def generate_tick(self) -> None:
log.warning("timeout raised while trying to acquire lock. retrying...")
except asyncio.CancelledError:
pass

async def join(self) -> None:
self._tick_task = asyncio.create_task(self.generate_tick())
if self.task_name is not None:
self._tick_task.set_name(self.task_name)

async def leave(self) -> None:
self._stopped = True
await asyncio.sleep(0)
if not self._tick_task.done():
try:
self._tick_task.cancel()
await self._tick_task
except asyncio.CancelledError:
pass
29 changes: 28 additions & 1 deletion src/ai/backend/manager/api/context.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from __future__ import annotations

from typing import TYPE_CHECKING
import enum
from typing import TYPE_CHECKING, Optional, cast

import attrs
from raftify import Raft

from ai.backend.logging.types import CIStrEnum

if TYPE_CHECKING:
from ai.backend.common.bgtask import BackgroundTaskManager
Expand All @@ -26,6 +30,27 @@ class BaseContext:
pass


class GlobalTimerKind(CIStrEnum):
    # Which global-timer backend the manager uses. Parsed (case-insensitively,
    # via CIStrEnum) from the `global-timer` option in the manager local config.
    RAFT = enum.auto()
    DISTRIBUTED_LOCK = enum.auto()


class GlobalTimerContext:
    """
    Holds the selected global-timer backend kind and, when the raft backend
    is active, the shared Raft handle used to build the per-app timers.
    """

    timer_kind: GlobalTimerKind
    # Populated lazily after the Raft cluster is bootstrapped; expected to
    # remain None when the distributed-lock backend is selected.
    _raft: Optional[Raft] = None

    def __init__(self, timer_kind: GlobalTimerKind) -> None:
        self.timer_kind = timer_kind

    @property
    def raft(self) -> Raft:
        # Fail fast with a descriptive error instead of silently returning
        # None disguised as a Raft object (the previous cast() deferred the
        # failure to an opaque AttributeError at a distant call site).
        if self._raft is None:
            raise RuntimeError(
                "The Raft handle is not initialized; "
                "is the 'raft' global-timer backend configured?"
            )
        return self._raft

    @raft.setter
    def raft(self, rhs: Raft) -> None:
        self._raft = rhs


@attrs.define(slots=True, auto_attribs=True, init=False)
class RootContext(BaseContext):
pidx: int
Expand All @@ -40,6 +65,7 @@ class RootContext(BaseContext):
redis_lock: RedisConnectionInfo
shared_config: SharedConfig
local_config: LocalConfig
raft_cluster_config: Optional[LocalConfig]
cors_options: CORSOptions

webapp_plugin_ctx: WebappPluginContext
Expand All @@ -53,3 +79,4 @@ class RootContext(BaseContext):
error_monitor: ErrorPluginContext
stats_monitor: StatsPluginContext
background_task_manager: BackgroundTaskManager
global_timer_ctx: GlobalTimerContext
39 changes: 29 additions & 10 deletions src/ai/backend/manager/api/logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,15 @@
from dateutil.relativedelta import relativedelta

from ai.backend.common import validators as tx
from ai.backend.common.distributed import GlobalTimer
from ai.backend.common.distributed import (
AbstractGlobalTimer,
DistributedLockGlobalTimer,
RaftGlobalTimer,
)
from ai.backend.common.events import AbstractEvent, EmptyEventArgs, EventHandler
from ai.backend.common.types import AgentId
from ai.backend.logging import BraceStyleAdapter, LogLevel
from ai.backend.manager.api.context import GlobalTimerKind

from ..defs import LockID
from ..models import UserRole, error_logs, groups
Expand Down Expand Up @@ -234,7 +239,7 @@ async def log_cleanup_task(app: web.Application, src: AgentId, event: DoLogClean

@attrs.define(slots=True, auto_attribs=True, init=False)
class PrivateContext:
    # Timer that periodically emits DoLogCleanupEvent; bound to either the
    # raft-based or the distributed-lock-based implementation depending on
    # the configured global-timer backend.
    log_cleanup_timer: AbstractGlobalTimer
    # Handler registration for DoLogCleanupEvent, kept for deregistration.
    log_cleanup_timer_evh: EventHandler[web.Application, DoLogCleanupEvent]


Expand All @@ -246,14 +251,28 @@ async def init(app: web.Application) -> None:
app,
log_cleanup_task,
)
app_ctx.log_cleanup_timer = GlobalTimer(
root_ctx.distributed_lock_factory(LockID.LOCKID_LOG_CLEANUP_TIMER, 20.0),
root_ctx.event_producer,
lambda: DoLogCleanupEvent(),
20.0,
initial_delay=17.0,
task_name="log_cleanup_task",
)

match root_ctx.global_timer_ctx.timer_kind:
case GlobalTimerKind.RAFT:
app_ctx.log_cleanup_timer = RaftGlobalTimer(
root_ctx.global_timer_ctx.raft.get_raft_node(),
root_ctx.event_producer,
lambda: DoLogCleanupEvent(),
20.0,
initial_delay=17.0,
task_name="log_cleanup_task",
)
case GlobalTimerKind.DISTRIBUTED_LOCK:
app_ctx.log_cleanup_timer = DistributedLockGlobalTimer(
root_ctx.distributed_lock_factory(LockID.LOCKID_LOG_CLEANUP_TIMER, 20.0),
root_ctx.event_producer,
lambda: DoLogCleanupEvent(),
20.0,
initial_delay=17.0,
task_name="log_cleanup_task",
)
case _:
assert False, f"Unknown global timer backend: {root_ctx.global_timer_ctx.timer_kind}"
await app_ctx.log_cleanup_timer.join()


Expand Down
Loading