Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions src/omniq/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import redis

from dataclasses import dataclass
from typing import Optional, Any, List
from typing import Optional, Any, List, ClassVar
from threading import Lock

from .clock import now_ms
Expand All @@ -14,7 +14,7 @@

@dataclass
class OmniqOps:
_script_lock = Lock()
_script_lock: ClassVar[Lock] = Lock()
r: RedisLike
scripts: OmniqScripts

Expand All @@ -29,7 +29,11 @@ def _evalsha_with_noscript_fallback(
return self.r.evalsha(sha, numkeys, *keys_and_args)
except redis.exceptions.NoScriptError:
with self._script_lock:
return self.r.eval(src, numkeys, *keys_and_args)
try:
return self.r.evalsha(sha, numkeys, *keys_and_args)
except redis.exceptions.NoScriptError:
new_sha = self.r.script_load(src)
return self.r.evalsha(new_sha, numkeys, *keys_and_args)

def publish(
self,
Expand Down Expand Up @@ -500,7 +504,6 @@ def child_ack(self, *, key: str, child_id: str) -> int:
except Exception:
return -1


@staticmethod
def paused_backoff_s(poll_interval_s: float) -> float:
return max(0.25, float(poll_interval_s) * 10.0)
Expand Down
34 changes: 34 additions & 0 deletions src/omniq/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,19 @@
from .types import ReserveResult, AckFailResult
from .helper import queue_base

def _safe_close_redis(r: Any) -> None:
    """Best-effort shutdown of a Redis-like client; swallows ``Exception``.

    Does nothing when ``r`` is ``None``. Otherwise ``r.close()`` is called;
    only if that call raises is ``r.connection_pool.disconnect()`` attempted
    as a fallback. Errors from either step are intentionally ignored so that
    cleanup paths never fail because of a broken connection object.

    Args:
        r: The client object to close, or ``None``.
    """
    if r is not None:
        try:
            r.close()
        except Exception:
            # close() is missing or failed: fall back to dropping the pool.
            try:
                r.connection_pool.disconnect()
            except Exception:
                pass

@dataclass
class OmniqClient:
_ops: OmniqOps
Expand All @@ -23,7 +36,10 @@ def __init__(
password: Optional[str] = None,
ssl: bool = False,
scripts_dir: Optional[str] = None,
client_name: Optional[str] = None,
):
self._owns_redis = redis is None

if redis is not None:
r = redis
else:
Expand All @@ -38,13 +54,31 @@ def __init__(
ssl=ssl,
)
)

if client_name:
try:
r.client_setname(str(client_name))
except Exception:
pass

if scripts_dir is None:
scripts_dir = default_scripts_dir()
scripts = load_scripts(r, scripts_dir)

self._ops = OmniqOps(r=r, scripts=scripts)

def close(self) -> None:
    """Close the underlying Redis client, but only if this object owns it.

    Ownership is read from ``self._owns_redis`` (treated as ``False`` when the
    attribute is missing). When owned, the client is looked up as
    ``self._ops.r`` (``None`` if absent) and passed to ``_safe_close_redis``.
    """
    owns_connection = getattr(self, "_owns_redis", False)
    if owns_connection:
        _safe_close_redis(getattr(self._ops, "r", None))

def __enter__(self) -> "OmniqClient":
    """Enter a ``with`` block; the client itself is the bound value."""
    return self

def __exit__(self, exc_type, exc, tb) -> None:
    """Leave a ``with`` block by calling ``self.close()``.

    The exception details are ignored and ``None`` is returned, so an
    exception raised inside the ``with`` body is not suppressed.
    """
    self.close()
    return None

@staticmethod
def queue_base(queue_name: str) -> str:
return queue_base(queue_name)
Expand Down
96 changes: 62 additions & 34 deletions src/omniq/consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,25 +31,37 @@ def start_heartbeater(
stop_evt = threading.Event()
flags: Dict[str, bool] = {"lost": False}

def _lost(msg: str) -> bool:
    """Tell whether an error message contains one of the two lost-lease markers.

    The check is case-insensitive. A falsy ``msg`` (``None`` or ``""``) is
    treated as an empty string and therefore yields ``False``.

    Args:
        msg: Error text to inspect.

    Returns:
        ``True`` if ``NOT_ACTIVE`` or ``TOKEN_MISMATCH`` occurs in the message.
    """
    normalized = msg.upper() if msg else ""
    return any(marker in normalized for marker in ("NOT_ACTIVE", "TOKEN_MISMATCH"))

def hb_loop():
try:
client.heartbeat(queue=queue, job_id=job_id, lease_token=lease_token)
except Exception as e:
if stop_evt.is_set():
return
msg = str(e)
if "NOT_ACTIVE" in msg or "TOKEN_MISMATCH" in msg:
if _lost(msg):
flags["lost"] = True
stop_evt.set()
return
time.sleep(min(0.2, max(0.01, float(interval_s))))

while not stop_evt.wait(interval_s):
while True:
if stop_evt.wait(interval_s):
return
try:
client.heartbeat(queue=queue, job_id=job_id, lease_token=lease_token)
except Exception as e:
if stop_evt.is_set():
return
msg = str(e)
if "NOT_ACTIVE" in msg or "TOKEN_MISMATCH" in msg:
if _lost(msg):
flags["lost"] = True
stop_evt.set()
return
time.sleep(min(0.2, max(0.01, float(interval_s))))

t = threading.Thread(target=hb_loop, daemon=True)
t.start()
Expand Down Expand Up @@ -93,37 +105,35 @@ def consume(

ctrl = StopController(stop=False, sigint_count=0)

if stop_on_ctrl_c and threading.current_thread() is threading.main_thread():
def on_sigterm(signum, _frame):
ctrl.stop = True
if verbose:
_safe_log(logger, f"[consume] SIGTERM received; stopping... queue={queue}")
prev_sigterm = None
prev_sigint = None

signal.signal(signal.SIGTERM, on_sigterm)
try:
if stop_on_ctrl_c and threading.current_thread() is threading.main_thread():
def on_sigterm(signum, _frame):
ctrl.stop = True
if verbose:
_safe_log(logger, f"[consume] SIGTERM received; stopping... queue={queue}")

if drain:
prev = signal.getsignal(signal.SIGINT)
prev_sigterm = signal.getsignal(signal.SIGTERM)
signal.signal(signal.SIGTERM, on_sigterm)

def on_sigint(signum, frame):
ctrl.sigint_count += 1
if ctrl.sigint_count >= 2:
if verbose:
_safe_log(logger, f"[consume] SIGINT x2; hard exit now. queue={queue}")
try:
signal.signal(signal.SIGINT, prev if prev else signal.SIG_DFL)
except Exception:
signal.signal(signal.SIGINT, signal.SIG_DFL)
raise KeyboardInterrupt
if drain:
prev_sigint = signal.getsignal(signal.SIGINT)

ctrl.stop = True
if verbose:
_safe_log(logger, f"[consume] Ctrl+C received; draining current job then exiting. queue={queue}")
def on_sigint(signum, frame):
ctrl.sigint_count += 1
if ctrl.sigint_count >= 2:
if verbose:
_safe_log(logger, f"[consume] SIGINT x2; hard exit now. queue={queue}")
raise KeyboardInterrupt

signal.signal(signal.SIGINT, on_sigint)
else:
pass
ctrl.stop = True
if verbose:
_safe_log(logger, f"[consume] Ctrl+C received; draining current job then exiting. queue={queue}")

signal.signal(signal.SIGINT, on_sigint)

try:
while True:
if ctrl.stop:
if verbose:
Expand Down Expand Up @@ -214,8 +224,6 @@ def on_sigint(signum, frame):
try:
handler(ctx)

hb.stop_evt.set()

if not hb.flags.get("lost", False):
try:
client.ack_success(queue=queue, job_id=res.job_id, lease_token=res.lease_token)
Expand All @@ -228,12 +236,9 @@ def on_sigint(signum, frame):
except KeyboardInterrupt:
if verbose:
_safe_log(logger, f"[consume] KeyboardInterrupt; exiting now. queue={queue}")
hb.stop_evt.set()
return

except Exception as e:
hb.stop_evt.set()

if not hb.flags.get("lost", False):
try:
err = f"{type(e).__name__}: {e}"
Expand All @@ -255,7 +260,12 @@ def on_sigint(signum, frame):

finally:
try:
hb.thread.join(timeout=0.1)
hb.stop_evt.set()
except Exception:
pass
try:
join_timeout = max(0.2, min(2.0, hb_s * 1.5))
hb.thread.join(timeout=join_timeout)
except Exception:
pass

Expand All @@ -268,3 +278,21 @@ def on_sigint(signum, frame):
if verbose:
_safe_log(logger, f"[consume] KeyboardInterrupt (outer); exiting now. queue={queue}")
return

finally:
if stop_on_ctrl_c and threading.current_thread() is threading.main_thread():
try:
if prev_sigterm is not None:
signal.signal(signal.SIGTERM, prev_sigterm)
except Exception:
pass
try:
if prev_sigint is not None:
signal.signal(signal.SIGINT, prev_sigint)
except Exception:
pass

try:
client.close()
except Exception:
pass
1 change: 0 additions & 1 deletion src/omniq/exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

from .client import OmniqClient


@dataclass(frozen=True)
class Exec:
client: OmniqClient
Expand Down
16 changes: 15 additions & 1 deletion src/omniq/scripts.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
from dataclasses import dataclass
from threading import Lock
from typing import Protocol

class ScriptLoader(Protocol):
Expand Down Expand Up @@ -32,15 +33,23 @@ def default_scripts_dir() -> str:
here = os.path.dirname(__file__)
return os.path.join(here, "core", "scripts")

_scripts_cache: dict[str, OmniqScripts] = {}
_scripts_cache_lock = Lock()

def load_scripts(r: ScriptLoader, scripts_dir: str) -> OmniqScripts:
with _scripts_cache_lock:
cached = _scripts_cache.get(scripts_dir)
if cached is not None:
return cached

def load_one(name: str) -> ScriptDef:
path = os.path.join(scripts_dir, name)
with open(path, "r", encoding="utf-8") as f:
src = f.read()
sha = r.script_load(src)
return ScriptDef(sha=sha, src=src)

return OmniqScripts(
scripts = OmniqScripts(
enqueue=load_one("enqueue.lua"),
reserve=load_one("reserve.lua"),
ack_success=load_one("ack_success.lua"),
Expand All @@ -57,3 +66,8 @@ def load_one(name: str) -> ScriptDef:
childs_init=load_one("childs_init.lua"),
child_ack=load_one("child_ack.lua"),
)

with _scripts_cache_lock:
_scripts_cache[scripts_dir] = scripts

return scripts
Loading