diff --git a/src/omniq/_ops.py b/src/omniq/_ops.py index 2e8b8fe..9059c62 100644 --- a/src/omniq/_ops.py +++ b/src/omniq/_ops.py @@ -2,7 +2,7 @@ import redis from dataclasses import dataclass -from typing import Optional, Any, List +from typing import Optional, Any, List, ClassVar from threading import Lock from .clock import now_ms @@ -14,7 +14,7 @@ @dataclass class OmniqOps: - _script_lock = Lock() + _script_lock: ClassVar[Lock] = Lock() r: RedisLike scripts: OmniqScripts @@ -29,7 +29,11 @@ def _evalsha_with_noscript_fallback( return self.r.evalsha(sha, numkeys, *keys_and_args) except redis.exceptions.NoScriptError: with self._script_lock: - return self.r.eval(src, numkeys, *keys_and_args) + try: + return self.r.evalsha(sha, numkeys, *keys_and_args) + except redis.exceptions.NoScriptError: + new_sha = self.r.script_load(src) + return self.r.evalsha(new_sha, numkeys, *keys_and_args) def publish( self, @@ -500,7 +504,7 @@ def child_ack(self, *, key: str, child_id: str) -> int: except Exception: return -1 @staticmethod def paused_backoff_s(poll_interval_s: float) -> float: return max(0.25, float(poll_interval_s) * 10.0) diff --git a/src/omniq/client.py b/src/omniq/client.py index 70d7c0c..6a8d949 100644 --- a/src/omniq/client.py +++ b/src/omniq/client.py @@ -7,6 +7,19 @@ from .types import ReserveResult, AckFailResult from .helper import queue_base +def _safe_close_redis(r: Any) -> None: + if r is None: + return + try: + r.close() + return + except Exception: + pass + try: + r.connection_pool.disconnect() + except Exception: + pass + @dataclass class OmniqClient: _ops: OmniqOps @@ -23,7 +36,10 @@ def __init__( password: Optional[str] = None, ssl: bool = False, scripts_dir: Optional[str] = None, + client_name: Optional[str] = None, ): + self._owns_redis = redis is None + if redis is not None: r = redis else: @@ -38,6 +54,12 @@ ssl=ssl, ) ) + + if client_name: + try: + r.client_setname(str(client_name)) + except Exception: + pass if scripts_dir is 
None: scripts_dir = default_scripts_dir() @@ -45,6 +67,18 @@ def __init__( self._ops = OmniqOps(r=r, scripts=scripts) + def close(self) -> None: + if not getattr(self, "_owns_redis", False): + return + r = getattr(self._ops, "r", None) + _safe_close_redis(r) + + def __enter__(self) -> "OmniqClient": + return self + + def __exit__(self, exc_type, exc, tb) -> None: + self.close() + @staticmethod def queue_base(queue_name: str) -> str: return queue_base(queue_name) diff --git a/src/omniq/consumer.py b/src/omniq/consumer.py index 3429e87..c946e24 100644 --- a/src/omniq/consumer.py +++ b/src/omniq/consumer.py @@ -31,25 +31,37 @@ def start_heartbeater( stop_evt = threading.Event() flags: Dict[str, bool] = {"lost": False} + def _lost(msg: str) -> bool: + msg_u = (msg or "").upper() + return ("NOT_ACTIVE" in msg_u) or ("TOKEN_MISMATCH" in msg_u) + def hb_loop(): try: client.heartbeat(queue=queue, job_id=job_id, lease_token=lease_token) except Exception as e: + if stop_evt.is_set(): + return msg = str(e) - if "NOT_ACTIVE" in msg or "TOKEN_MISMATCH" in msg: + if _lost(msg): flags["lost"] = True stop_evt.set() return + time.sleep(min(0.2, max(0.01, float(interval_s)))) - while not stop_evt.wait(interval_s): + while True: + if stop_evt.wait(interval_s): + return try: client.heartbeat(queue=queue, job_id=job_id, lease_token=lease_token) except Exception as e: + if stop_evt.is_set(): + return msg = str(e) - if "NOT_ACTIVE" in msg or "TOKEN_MISMATCH" in msg: + if _lost(msg): flags["lost"] = True stop_evt.set() return + time.sleep(min(0.2, max(0.01, float(interval_s)))) t = threading.Thread(target=hb_loop, daemon=True) t.start() @@ -93,37 +105,35 @@ def consume( ctrl = StopController(stop=False, sigint_count=0) - if stop_on_ctrl_c and threading.current_thread() is threading.main_thread(): - def on_sigterm(signum, _frame): - ctrl.stop = True - if verbose: - _safe_log(logger, f"[consume] SIGTERM received; stopping... 
queue={queue}") + prev_sigterm = None + prev_sigint = None - signal.signal(signal.SIGTERM, on_sigterm) + try: + if stop_on_ctrl_c and threading.current_thread() is threading.main_thread(): + def on_sigterm(signum, _frame): + ctrl.stop = True + if verbose: + _safe_log(logger, f"[consume] SIGTERM received; stopping... queue={queue}") - if drain: - prev = signal.getsignal(signal.SIGINT) + prev_sigterm = signal.getsignal(signal.SIGTERM) + signal.signal(signal.SIGTERM, on_sigterm) - def on_sigint(signum, frame): - ctrl.sigint_count += 1 - if ctrl.sigint_count >= 2: - if verbose: - _safe_log(logger, f"[consume] SIGINT x2; hard exit now. queue={queue}") - try: - signal.signal(signal.SIGINT, prev if prev else signal.SIG_DFL) - except Exception: - signal.signal(signal.SIGINT, signal.SIG_DFL) - raise KeyboardInterrupt + if drain: + prev_sigint = signal.getsignal(signal.SIGINT) - ctrl.stop = True - if verbose: - _safe_log(logger, f"[consume] Ctrl+C received; draining current job then exiting. queue={queue}") + def on_sigint(signum, frame): + ctrl.sigint_count += 1 + if ctrl.sigint_count >= 2: + if verbose: + _safe_log(logger, f"[consume] SIGINT x2; hard exit now. queue={queue}") + raise KeyboardInterrupt - signal.signal(signal.SIGINT, on_sigint) - else: - pass + ctrl.stop = True + if verbose: + _safe_log(logger, f"[consume] Ctrl+C received; draining current job then exiting. queue={queue}") + + signal.signal(signal.SIGINT, on_sigint) - try: while True: if ctrl.stop: if verbose: @@ -214,8 +224,6 @@ def on_sigint(signum, frame): try: handler(ctx) - hb.stop_evt.set() - if not hb.flags.get("lost", False): try: client.ack_success(queue=queue, job_id=res.job_id, lease_token=res.lease_token) @@ -228,12 +236,9 @@ def on_sigint(signum, frame): except KeyboardInterrupt: if verbose: _safe_log(logger, f"[consume] KeyboardInterrupt; exiting now. 
queue={queue}") - hb.stop_evt.set() return except Exception as e: - hb.stop_evt.set() - if not hb.flags.get("lost", False): try: err = f"{type(e).__name__}: {e}" @@ -255,7 +260,12 @@ def on_sigint(signum, frame): finally: try: - hb.thread.join(timeout=0.1) + hb.stop_evt.set() + except Exception: + pass + try: + join_timeout = max(0.2, min(2.0, hb_s * 1.5)) + hb.thread.join(timeout=join_timeout) except Exception: pass @@ -268,3 +278,21 @@ def on_sigint(signum, frame): if verbose: _safe_log(logger, f"[consume] KeyboardInterrupt (outer); exiting now. queue={queue}") return + + finally: + if stop_on_ctrl_c and threading.current_thread() is threading.main_thread(): + try: + if prev_sigterm is not None: + signal.signal(signal.SIGTERM, prev_sigterm) + except Exception: + pass + try: + if prev_sigint is not None: + signal.signal(signal.SIGINT, prev_sigint) + except Exception: + pass + + try: + client.close() + except Exception: + pass diff --git a/src/omniq/exec.py b/src/omniq/exec.py index 11f6ea2..5bb3df9 100644 --- a/src/omniq/exec.py +++ b/src/omniq/exec.py @@ -3,7 +3,6 @@ from .client import OmniqClient - @dataclass(frozen=True) class Exec: client: OmniqClient diff --git a/src/omniq/scripts.py b/src/omniq/scripts.py index 45b9296..f7830f5 100644 --- a/src/omniq/scripts.py +++ b/src/omniq/scripts.py @@ -1,5 +1,6 @@ import os from dataclasses import dataclass +from threading import Lock from typing import Protocol class ScriptLoader(Protocol): @@ -32,7 +33,15 @@ def default_scripts_dir() -> str: here = os.path.dirname(__file__) return os.path.join(here, "core", "scripts") +_scripts_cache: dict[str, OmniqScripts] = {} +_scripts_cache_lock = Lock() + def load_scripts(r: ScriptLoader, scripts_dir: str) -> OmniqScripts: + with _scripts_cache_lock: + cached = _scripts_cache.get(scripts_dir) + if cached is not None: + return cached + def load_one(name: str) -> ScriptDef: path = os.path.join(scripts_dir, name) with open(path, "r", encoding="utf-8") as f: @@ -40,7 +49,7 @@ def 
load_one(name: str) -> ScriptDef: sha = r.script_load(src) return ScriptDef(sha=sha, src=src) - return OmniqScripts( + scripts = OmniqScripts( enqueue=load_one("enqueue.lua"), reserve=load_one("reserve.lua"), ack_success=load_one("ack_success.lua"), @@ -57,3 +66,8 @@ def load_one(name: str) -> ScriptDef: childs_init=load_one("childs_init.lua"), child_ack=load_one("child_ack.lua"), ) + + with _scripts_cache_lock: + _scripts_cache[scripts_dir] = scripts + + return scripts diff --git a/src/omniq/transport.py b/src/omniq/transport.py index 94414e1..9717ee1 100644 --- a/src/omniq/transport.py +++ b/src/omniq/transport.py @@ -22,7 +22,7 @@ def zcard(self, key: str) -> int: ... def zrange(self, key: str, start: int, end: int) -> list[Optional[str]]: ... def get(self, key: str) -> Optional[str]: ... def hmget(self, key: str, *fields: str) -> list[Optional[str]]: ... - def zscore(self, key: str, member: str) -> list[Optional[str]]: ... + def zscore(self, key: str, member: str) -> Optional[float]: ... 
@dataclass(frozen=True) class RedisConnOpts: @@ -33,16 +33,31 @@ class RedisConnOpts: username: Optional[str] = None password: Optional[str] = None ssl: bool = False + socket_timeout: Optional[float] = None socket_connect_timeout: Optional[float] = None + max_connections: Optional[int] = None + health_check_interval: Optional[int] = 30 + socket_keepalive: bool = True + +def _safe_close(client: Any) -> None: + try: + client.close() + return + except Exception: + pass + try: + client.connection_pool.disconnect() + except Exception: + pass + def _looks_like_cluster_error(e: Exception) -> bool: msg = str(e).lower() - return ( "cluster support disabled" in msg or "cluster mode is not enabled" in msg - or "unknown command" in msg and "cluster" in msg + or ("unknown command" in msg and "cluster" in msg) or "this instance has cluster support disabled" in msg or "err this instance has cluster support disabled" in msg or "only (p)subscribe / (p)unsubscribe / ping / quit allowed in this context" in msg @@ -50,28 +65,47 @@ def _looks_like_cluster_error(e: Exception) -> bool: or "ask" in msg ) +def _common_kwargs(opts: RedisConnOpts) -> dict[str, Any]: + kw: dict[str, Any] = { + "decode_responses": True, + "ssl": bool(opts.ssl), + "username": opts.username, + "password": opts.password, + "socket_timeout": opts.socket_timeout, + "socket_connect_timeout": opts.socket_connect_timeout, + "socket_keepalive": bool(opts.socket_keepalive), + } + + if opts.max_connections is not None: + kw["max_connections"] = int(opts.max_connections) + if opts.health_check_interval is not None: + kw["health_check_interval"] = int(opts.health_check_interval) + + return {k: v for k, v in kw.items() if v is not None} + def build_redis_client(opts: RedisConnOpts) -> redis.Redis: + kw = _common_kwargs(opts) + if opts.redis_url: - return redis.Redis.from_url(opts.redis_url, decode_responses=True) + kw.pop("ssl", None) + return redis.Redis.from_url(opts.redis_url, **kw) if not opts.host: raise ValueError("RedisConnOpts requires 
host (or redis_url)") if RedisCluster is not None: + rc = None try: rc = RedisCluster( host=opts.host, port=int(opts.port), - username=opts.username, - password=opts.password, - ssl=bool(opts.ssl), - socket_timeout=opts.socket_timeout, - socket_connect_timeout=opts.socket_connect_timeout, - decode_responses=True, + **kw, ) rc.ping() return rc except Exception as e: + if rc is not None: + _safe_close(rc) if _looks_like_cluster_error(e): pass else: @@ -81,11 +114,5 @@ def build_redis_client(opts: RedisConnOpts) -> redis.Redis: host=opts.host, port=int(opts.port), db=int(opts.db), - username=opts.username, - password=opts.password, - ssl=bool(opts.ssl), - socket_timeout=opts.socket_timeout, - socket_connect_timeout=opts.socket_connect_timeout, - decode_responses=True, + **kw, ) -