Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docker/ci-scaler/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ FROM $BASE_IMAGE
ENV GH_TOKEN=""
ENV ASGS=""
ENV DOMAIN=""
ENV DYNAMODB_TABLE_PREFIX=""
ENV AWS_ENDPOINT_URL=""
ENV AWS_ACCESS_KEY_ID=""
ENV AWS_SECRET_ACCESS_KEY=""
ENV TZ=""

ENV DEBIAN_FRONTEND=noninteractive
Expand Down
6 changes: 6 additions & 0 deletions docker/ci-scaler/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@ To use:
"{owner}/{repo}:{label}:{asg_name}"
- `DOMAIN`: domain of API Gateway which listens for GitHub webhook
requests via HTTPS and forwards all requests to this container's port 8088
- `DYNAMODB_TABLE_PREFIX`: if set, use DynamoDB tables to store the state
across webhook requests; useful when running multiple instances of
ci-scaler
- `AWS_ENDPOINT_URL`, `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`:
  optionally, you may pass these variables to access the AWS API; used
  mostly for debugging
- `TZ` (optional): timezone name

Example for docker compose:
Expand Down
3 changes: 2 additions & 1 deletion docker/ci-scaler/guest/entrypoint.99-run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ set -u -e
if [[ "$ASGS" != "" ]]; then
exec python3 ./scaler/main.py \
--asgs="$ASGS" \
--domain="$DOMAIN"
--domain="$DOMAIN" \
--dynamodb-table-prefix="$DYNAMODB_TABLE_PREFIX"
else
exec sleep 1000000000
fi
12 changes: 11 additions & 1 deletion docker/ci-scaler/guest/scaler/api_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,19 @@ def aws(
input: str | None = None,
) -> str | None:
region = aws_region()
endpoint_url = os.environ.get("AWS_ENDPOINT_URL")
if args[0] == "dynamodb" and endpoint_url:
region = "us-east-1"
if not region:
return None
return check_output(["aws", f"--region={region}", *args], input=input)
cmd = [
"aws",
f"--region={region}",
*([f"--endpoint-url={endpoint_url}"] if endpoint_url else ()),
*args,
]
out = check_output(cmd, input=input)
return out


def aws_json(
Expand Down
8 changes: 5 additions & 3 deletions docker/ci-scaler/guest/scaler/handler_idle_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,9 @@
AsgSpec,
Runner,
RunnersRegistry,
ExpiringDict,
logged_result,
)
from typing import Literal
from storage import StorageFactory

REVISIT_TERMINATED_INSTANCE_SEC = datetime.timedelta(minutes=10).total_seconds()

Expand All @@ -26,12 +25,15 @@ def __init__(
*,
asg_spec: AsgSpec,
max_idle_age_sec: int,
storage: StorageFactory,
):
super().__init__(asg_spec=asg_spec)
self.max_idle_age_sec = max_idle_age_sec
self.idle_runners = RunnersRegistry()
self.terminated_instance_ids = ExpiringDict[str, Literal[True]](
self.terminated_instance_ids = storage.create(
bool,
ttl=REVISIT_TERMINATED_INSTANCE_SEC,
name="terminated-instance-ids",
)

def handle(self, runners: list[Runner]) -> None:
Expand Down
160 changes: 101 additions & 59 deletions docker/ci-scaler/guest/scaler/handler_webhooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@
aws_cloudwatch_put_metric_data,
)
from helpers import (
ExpiringDict,
PostJsonHttpRequestHandler,
AsgSpec,
log,
logged_result,
)
from typing import Any, Literal, cast
from storage import StorageFactory
from typing import Any, Literal
from zlib import crc32


DUPLICATED_EVENTS_TTL = 3600
Expand Down Expand Up @@ -55,23 +56,39 @@ class Webhook:
@dataclasses.dataclass
class JobTiming:
job_id: int
queued_at: float | None = None
started_at: float | None = None
completed_at: float | None = None
bumped: set[str] = dataclasses.field(default_factory=set)
queued_at: int | None = None
started_at: int | None = None
completed_at: int | None = None
bumped: list[str] = dataclasses.field(default_factory=list)


class HandlerWebhooks:
def __init__(self, *, domain: str, asg_specs: list[AsgSpec]):
def __init__(
self,
*,
domain: str,
asg_specs: list[AsgSpec],
storage: StorageFactory,
):
self.domain = domain
self.asg_specs = asg_specs
self.webhooks: dict[str, Webhook] = {}
self.secret = gh_get_webhook_secret()
self.duplicated_events = ExpiringDict[tuple[int, str], float](
ttl=DUPLICATED_EVENTS_TTL
self.duplicated_events = storage.create(
int,
ttl=DUPLICATED_EVENTS_TTL,
name="duplicated-events",
)
self.job_timings = storage.create(
JobTiming,
ttl=JOB_TIMING_TTL,
name="job-timings",
)
self.workflows = storage.create(
dict[str, Any],
ttl=WORKFLOW_TTL,
name="workflows",
)
self.job_timings = ExpiringDict[int, JobTiming](ttl=JOB_TIMING_TTL)
self.workflows = ExpiringDict[str, dict[str, Any]](ttl=WORKFLOW_TTL)
this = self

class RequestHandler(PostJsonHttpRequestHandler):
Expand Down Expand Up @@ -133,41 +150,52 @@ def handle(
data: dict[str, Any],
data_bytes: bytes,
):
action = data.get("action")
workflow_run = data.get(WORKFLOW_RUN_EVENT)
workflow_job = data.get(WORKFLOW_JOB_EVENT)

# For local debugging only! Allows to simulate a webhook with just
# querying an URL that includes the repo name and label:
# - /workflow_run/owner/repo/label
# - /workflow_job/owner/repo/label/{queued|in_progress|completed}/job_id
extra_debug_labels: dict[str, int] = {}
if (
handler.client_address[0] == "127.0.0.1"
and not action
and not workflow_run
and not workflow_job
and not data.get("action")
and not data.get(WORKFLOW_RUN_EVENT)
and not data.get(WORKFLOW_JOB_EVENT)
):
if match := re.match(
rf"^/{WORKFLOW_RUN_EVENT}/([^/]+/[^/]+)/([^/]+)/?$",
handler.path,
):
return self._handle_workflow_run_in_progress(
handler=handler,
repository=match.group(1),
labels={match.group(2): 1},
)
extra_debug_labels[match.group(2)] = 1
data = {
"action": "in_progress",
"repository": {
"full_name": match.group(1),
},
WORKFLOW_RUN_EVENT: {
"id": crc32(match.group(2).encode()),
"run_attempt": 1,
"name": "test",
"head_sha": "",
"path": "/.github/workflows/ci.yml",
},
}
elif match := re.match(
rf"^/{WORKFLOW_JOB_EVENT}/([^/]+/[^/]+)/([^/]+)/([^/]+)/([^/]+)/?$",
handler.path,
):
return self._handle_workflow_job_timing(
handler=handler,
repository=match.group(1),
labels={match.group(2): 1},
action=cast(Any, match.group(3)),
job_id=int(match.group(4)),
name=None,
)
data = {
"repository": {
"full_name": match.group(1),
},
WORKFLOW_JOB_EVENT: {
"id": int(match.group(4)),
"run_attempt": 1,
"name": "test",
"labels": [match.group(2)],
},
"action": match.group(3),
"job_id": int(match.group(4)),
}
else:
return handler.send_error(
404,
Expand All @@ -177,14 +205,28 @@ def handle(
+ f"/{WORKFLOW_JOB_EVENT}/owner/repo/label/{'{queued|in_progress|completed}'}/job_id"
+ f", but got {handler.path}",
)
else:
assert self.secret
error = verify_signature(
secret=self.secret,
headers=handler.headers,
data_bytes=data_bytes,
)
if error:
return handler.send_error(403, error)

action = data.get("action")
repository: str | None = data.get("repository", {}).get("full_name", None)
keys = [k for k in data.keys() if k not in IGNORE_KEYS]
workflow_run = data.get(WORKFLOW_RUN_EVENT)
workflow_job = data.get(WORKFLOW_JOB_EVENT)

name = (
str(workflow_run.get("name"))
if workflow_run
else str(workflow_job.get("name")) if workflow_job else None
)
keys = [k for k in data.keys() if k not in IGNORE_KEYS]

if keys:
handler.log_suffix = (
f"{{{','.join(keys)}}}"
Expand All @@ -198,29 +240,16 @@ def handle(
if not repository:
return handler.send_json(202, message="ignoring event with no repository")

assert self.secret
error = verify_signature(
secret=self.secret,
headers=handler.headers,
data_bytes=data_bytes,
)
if error:
return handler.send_error(403, error)

# This event is used for increasing the number of runners.
if workflow_run:
if action != "requested" and action != "in_progress":
return handler.send_json(
202,
message='ignoring action != ["requested", "in_progress"]',
)

event_key = (
int(workflow_run["id"]),
str(workflow_run["run_attempt"]),
)
handler.log_suffix += (
f" id={workflow_run['id']}:{workflow_run['run_attempt']}"
)
event_key = f"{workflow_run['id']}:{workflow_run['run_attempt']}"
handler.log_suffix += f" id={event_key}"
processed_at = self.duplicated_events.get(event_key)
if processed_at:
return handler.send_json(
Expand All @@ -243,7 +272,7 @@ def handle(
self.workflows[cache_key] = workflow
else:
message += f" (cached)"
labels = gh_predict_workflow_labels(
labels = extra_debug_labels | gh_predict_workflow_labels(
workflow=workflow,
known_labels=[
asg_spec.label
Expand All @@ -258,29 +287,31 @@ def handle(
except Exception as e:
return handler.send_error(500, f"{message} failed: {e}")

self.duplicated_events[event_key] = time.time()
self.duplicated_events[event_key] = int(time.time())
return self._handle_workflow_run_in_progress(
handler=handler,
repository=repository,
labels=labels,
)

# This event is only used for statistics about timing.
if workflow_job:
if action != "queued" and action != "in_progress" and action != "completed":
allowed_actions = ["queued", "in_progress", "completed"]
if action not in allowed_actions:
return handler.send_json(
202,
message='ignoring action != ["queued", "in_progress", "completed"]',
message=f"ignoring action != {allowed_actions}",
)

event_key = (int(workflow_job["id"]), action)
event_key = f"{workflow_job['id']}:{workflow_job['run_attempt']}:{action}"
processed_at = self.duplicated_events.get(event_key)
if processed_at:
return handler.send_json(
202,
message=f"ignoring event that has already been processed at {time.ctime(processed_at)}",
)

self.duplicated_events[event_key] = time.time()
self.duplicated_events[event_key] = int(time.time())
return self._handle_workflow_job_timing(
handler=handler,
repository=repository,
Expand All @@ -290,6 +321,7 @@ def handle(
name=name,
)

# Unrecognized event, skipping.
return handler.send_json(
202,
message=f"ignoring event with no {WORKFLOW_RUN_EVENT} and {WORKFLOW_JOB_EVENT}",
Expand Down Expand Up @@ -348,10 +380,11 @@ def _handle_workflow_job_timing(
message=f"ignoring event, since no matching auto-scaling group(s) found for repository {repository} and labels {[*labels.keys()]}",
)

timing = self.job_timings.get(job_id) or JobTiming(job_id=job_id)
self.job_timings[job_id] = timing
timing = self.job_timings.get(str(job_id))
if timing is None:
timing = JobTiming(job_id=job_id)

now = time.time()
now = int(time.time())
if action == "queued":
timing.queued_at = now
elif action == "in_progress":
Expand Down Expand Up @@ -383,7 +416,9 @@ def _handle_workflow_job_timing(

for metric in timing.bumped:
metrics.pop(metric, None)
timing.bumped.update(metrics.keys())
timing.bumped.extend(metrics.keys())

self.job_timings[str(job_id)] = timing

if metrics:
job_name = (
Expand Down Expand Up @@ -419,7 +454,14 @@ def _handle_workflow_job_timing(

return handler.send_json(
200,
message=f"processed event for job_id={job_id}: {asg_spec}",
message=(
f"logged timing event for job_id={job_id}: {asg_spec}, "
+ (
", ".join(f"{k}:{v}" for k, v in metrics.items())
if metrics
else "no metrics yet"
)
),
)


Expand Down
Loading
Loading