Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion fairy/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,5 @@ async fn run(cfg: AppConfig) -> Result<()> {
error!("{e}");
tokio::time::sleep(Duration::from_secs(5)).await;
}
tokio::time::sleep(Duration::from_secs(1)).await;
}
}
3 changes: 1 addition & 2 deletions vicky/src/bin/vicky/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@ use rocket::{Request, http::Status, response::Responder};
use thiserror::Error;
use tokio::sync::broadcast::error::SendError;
use vickylib::errors::VickyError;

use crate::events::GlobalEvent;
use vickylib::vicky::events::GlobalEvent;

#[derive(Error, Debug)]
pub enum AppError {
Expand Down
9 changes: 1 addition & 8 deletions vicky/src/bin/vicky/events.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,8 @@
use rocket::response::stream::{Event, EventStream};
use rocket::{State, get};
use serde::{Deserialize, Serialize};
use std::time;
use tokio::sync::broadcast::{self, error::TryRecvError};

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum GlobalEvent {
TaskAdd,
TaskUpdate { uuid: uuid::Uuid },
}
use vickylib::vicky::events::GlobalEvent;

#[get("/")]
pub fn get_global_events(
Expand Down
17 changes: 15 additions & 2 deletions vicky/src/bin/vicky/main.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::config::{Config, OIDCConfigResolved, build_rocket_config};
use crate::events::{GlobalEvent, get_global_events};
use crate::events::get_global_events;
use crate::locks::{
locks_get_active, locks_get_detailed_poisoned, locks_get_poisoned, locks_unlock,
};
Expand All @@ -17,6 +17,7 @@ use log::{LevelFilter, error, info, trace, warn};
use rocket::fairing::AdHoc;
use rocket::{Build, Ignite, Rocket, routes};
use snafu::ResultExt;
use std::sync::Arc;
use std::time::Duration;
use tokio::select;
use tokio::sync::broadcast;
Expand All @@ -25,6 +26,8 @@ use vickylib::database::entities::Database;
use vickylib::database::entities::task::HEARTBEAT_TIMEOUT_SEC;
use vickylib::logs::LogDrain;
use vickylib::s3::client::S3Client;
use vickylib::vicky::events::GlobalEvent;
use vickylib::vicky::scheduler::Scheduler;

mod auth;
mod config;
Expand Down Expand Up @@ -108,14 +111,17 @@ async fn inner_main() -> Result<()> {

let (tx_global_events, _rx_task_events) = broadcast::channel::<GlobalEvent>(5);

let scheduler = Scheduler::new();

let web_server = build_web_api(
app_config,
build_rocket,
oidc_config_resolved,
jwks_verifier,
s3_log_bucket_client,
log_drain,
tx_global_events,
tx_global_events.clone(),
scheduler.clone(),
)
.await?;

Expand All @@ -126,6 +132,9 @@ async fn inner_main() -> Result<()> {
let web_task =
tokio::task::spawn(async move { web_server.launch().await.context(startup::LaunchErr) });

let scheduler_join_handle =
tokio::task::spawn(scheduler.run(tx_global_events, db_pool.clone()));

let task_timeout_sweeper = tokio::task::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(10));

Expand Down Expand Up @@ -153,11 +162,13 @@ async fn inner_main() -> Result<()> {
select! {
e = web_task => e.map(|_| ()).context(startup::JoinErr)?,
_ = task_timeout_sweeper => panic!("Task timeout sweeper shouldn't exit"),
_ = scheduler_join_handle => panic!("Scheduler shouldn't exit"),
}

Ok(())
}

#[allow(clippy::too_many_arguments)]
async fn build_web_api(
app_config: Config,
build_rocket: Rocket<Build>,
Expand All @@ -166,6 +177,7 @@ async fn build_web_api(
s3_log_bucket_client: S3Client,
log_drain: LogDrain,
tx_global_events: Sender<GlobalEvent>,
scheduler: Arc<Scheduler>,
) -> Result<Rocket<Ignite>> {
info!("starting web api");

Expand All @@ -176,6 +188,7 @@ async fn build_web_api(
.manage(tx_global_events)
.manage(app_config.web_config)
.manage(oidc_config_resolved)
.manage(scheduler)
.attach(Database::fairing())
.attach(AdHoc::config::<Config>())
.attach(AdHoc::try_on_ignite(
Expand Down
19 changes: 11 additions & 8 deletions vicky/src/bin/vicky/tasks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ use rocket::http::Status;
use rocket::response::stream::{Event, EventStream};
use rocket::{State, get, post, serde::json::Json};
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::sync::Arc;
use std::time;
use tokio::sync::broadcast::{self, error::TryRecvError};
use uuid::Uuid;
Expand All @@ -12,7 +14,7 @@ use vickylib::database::entities::task::{FlakeRef, TaskResult, TaskStatus};
use vickylib::database::entities::{Database, Lock, Task};
use vickylib::query::FilterParams;
use vickylib::{
errors::VickyError, logs::LogDrain, s3::client::S3Client, vicky::scheduler::Scheduler,
logs::LogDrain, s3::client::S3Client, vicky::events::GlobalEvent, vicky::scheduler::Scheduler,
};

macro_rules! task_or {
Expand All @@ -34,7 +36,6 @@ use crate::auth::AnyAuthGuard;
use crate::{
auth::{MachineGuard, UserGuard},
errors::AppError,
events::GlobalEvent,
};

#[derive(Debug, PartialEq, Serialize, Deserialize)]
Expand All @@ -44,7 +45,7 @@ pub struct RoTaskNew {
display_name: String,
flake_ref: FlakeRef,
locks: Vec<Lock>,
features: Vec<String>,
features: HashSet<String>,
group: Option<String>,
}

Expand Down Expand Up @@ -249,15 +250,17 @@ pub async fn tasks_put_logs(
#[post("/claim", format = "json", data = "<features>")]
pub async fn tasks_claim(
db: Database,
scheduler: &State<Arc<Scheduler>>,
features: Json<RoTaskClaim>,
global_events: &State<broadcast::Sender<GlobalEvent>>,
_machine: MachineGuard,
) -> Result<Json<Option<Task>>, AppError> {
let tasks = db.get_all_tasks().await?;
let poisoned_locks = db.get_poisoned_locks().await?;
let scheduler = Scheduler::new(&tasks, &poisoned_locks, &features.features)
.map_err(|x| VickyError::Scheduler { source: x })?;
let next_task = scheduler.get_next_task();
let next_task: Option<Task> = tokio::time::timeout(
time::Duration::from_secs(10),
scheduler.get_next_task(&features.features),
)
.await
.ok();

match next_task {
Some(next_task) => {
Expand Down
11 changes: 6 additions & 5 deletions vicky/src/lib/database/entities/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use chrono::{DateTime, Utc};
use diesel::{AsExpression, FromSqlRow};
use itertools::Itertools;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use uuid::Uuid;

pub const HEARTBEAT_TIMEOUT_SEC: i64 = 60;
Expand Down Expand Up @@ -62,7 +63,7 @@ pub struct Task {
pub flake_ref: FlakeRef,

#[builder(field)]
pub features: Vec<String>,
pub features: HashSet<String>,

#[builder(default = Uuid::new_v4())]
pub id: Uuid,
Expand Down Expand Up @@ -163,11 +164,11 @@ impl<T: task_builder::State> TaskBuilder<T> {
}

pub fn requires_feature<S: Into<String>>(mut self, feature: S) -> Self {
self.features.push(feature.into());
self.features.insert(feature.into());
self
}

pub fn requires_features(mut self, features: Vec<String>) -> Self {
pub fn requires_features(mut self, features: HashSet<String>) -> Self {
self.features = features;
self
}
Expand Down Expand Up @@ -216,7 +217,7 @@ impl From<(DbTask, Vec<DbLock>)> for Task {
flake: task.flake_ref_uri,
args: task.flake_ref_args,
},
features: task.features,
features: task.features.into_iter().collect(),
created_at: task.created_at,
claimed_at: task.claimed_at,
finished_at: task.finished_at,
Expand Down Expand Up @@ -366,7 +367,7 @@ pub mod db_impl {
id: task.id,
display_name: task.display_name,
status: task.status,
features: task.features,
features: task.features.into_iter().collect(),
flake_ref_uri: task.flake_ref.flake,
flake_ref_args: task.flake_ref.args,
created_at: task.created_at,
Expand Down
2 changes: 2 additions & 0 deletions vicky/src/lib/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ pub enum SchedulerError {
GeneralSchedulingError,
#[error("lock already owned")]
LockAlreadyOwnedError,
#[error("channel closed")]
ChannelClosed,
}

#[derive(Error, Debug)]
Expand Down
Loading
Loading