Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 25 additions & 6 deletions examples/server/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include <chrono>
#include <filesystem>
#include <fstream>
#include <future>
#include <iomanip>
#include <iostream>
#include <mutex>
Expand Down Expand Up @@ -384,6 +385,18 @@ int main(int argc, const char** argv) {
return httplib::Server::HandlerResponse::Unhandled;
});

// Blocks until the generation task behind `ft` is ready, waking up once per
// second to check whether the HTTP client has gone away. While the connection
// is reported closed, every wake-up asks the context to cancel everything
// (SD_CANCEL_ALL). Returns immediately if `ft` holds no shared state.
auto wait_for_generation = [](std::future<void>& ft, sd_ctx_t* sd_ctx, const httplib::Request& req) {
    const auto poll_interval = std::chrono::milliseconds(1000);
    while (ft.valid()) {
        const std::future_status task_state = ft.wait_for(poll_interval);
        // Checked after every wait, including the final one, so a disconnect
        // seen on the last tick still triggers a cancel request.
        if (req.is_connection_closed()) {
            sd_cancel_generation(sd_ctx, SD_CANCEL_ALL);
        }
        if (task_state == std::future_status::ready) {
            break;
        }
    }
};

// index html
std::string index_html;
#ifdef HAVE_INDEX_HTML
Expand Down Expand Up @@ -532,11 +545,13 @@ int main(int argc, const char** argv) {
sd_image_t* results = nullptr;
int num_results = 0;

{
std::future<void> ft = std::async(std::launch::async, [&]() {
std::lock_guard<std::mutex> lock(sd_ctx_mutex);
results = generate_image(sd_ctx, &img_gen_params);
num_results = gen_params.batch_count;
}
});

wait_for_generation(ft, sd_ctx, req);

for (int i = 0; i < num_results; i++) {
if (results[i].data == nullptr) {
Expand Down Expand Up @@ -779,11 +794,13 @@ int main(int argc, const char** argv) {
sd_image_t* results = nullptr;
int num_results = 0;

{
std::future<void> ft = std::async(std::launch::async, [&]() {
std::lock_guard<std::mutex> lock(sd_ctx_mutex);
results = generate_image(sd_ctx, &img_gen_params);
num_results = gen_params.batch_count;
}
});

wait_for_generation(ft, sd_ctx, req);

json out;
out["created"] = static_cast<long long>(std::time(nullptr));
Expand Down Expand Up @@ -1095,11 +1112,13 @@ int main(int argc, const char** argv) {
sd_image_t* results = nullptr;
int num_results = 0;

{
std::future<void> ft = std::async(std::launch::async, [&]() {
std::lock_guard<std::mutex> lock(sd_ctx_mutex);
results = generate_image(sd_ctx, &img_gen_params);
num_results = gen_params.batch_count;
}
});

wait_for_generation(ft, sd_ctx, req);

json out;
out["images"] = json::array();
Expand Down
9 changes: 9 additions & 0 deletions include/stable-diffusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,15 @@ SD_API void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params);
SD_API char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params);
SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params);

// Cancellation request understood by sd_cancel_generation().
enum sd_cancel_mode_t {
    // Stop sampling and also skip decoding of the latents produced so far.
    SD_CANCEL_ALL = 0,
    // Stop producing new latents; latents that are already finished are still decoded.
    SD_CANCEL_NEW_LATENTS = 1,
    // No cancellation pending; generate_image() resets the flag to this value before it starts.
    SD_CANCEL_RESET = 2
};

// Stores `mode` as the cancellation flag of `sd_ctx`. The flag is polled by the
// generation code (before each denoise step, before each batch item, and before
// each latent decode). A null context is ignored, and an out-of-range `mode` is
// treated as SD_CANCEL_ALL.
SD_API void sd_cancel_generation(sd_ctx_t* sd_ctx, enum sd_cancel_mode_t mode);

SD_API void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params);
SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* sd_vid_gen_params, int* num_frames_out);

Expand Down
49 changes: 49 additions & 0 deletions src/stable-diffusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
#include "latent-preview.h"
#include "name_conversion.h"

#include <atomic>

const char* model_version_to_str[] = {
"SD 1.x",
"SD 1.x Inpaint",
Expand Down Expand Up @@ -478,6 +480,9 @@ static void log_sample_cache_summary(const SampleCacheRuntime& runtime, size_t t

/*=============================================== StableDiffusionGGML ================================================*/

// Compile-time guarantee that the atomic holding the cancellation mode is
// always lock-free on this platform.
static_assert(std::atomic<sd_cancel_mode_t>::is_always_lock_free,
"sd_cancel_mode_t must be lock-free");

class StableDiffusionGGML {
public:
ggml_backend_t backend = nullptr; // general backend
Expand Down Expand Up @@ -528,6 +533,8 @@ class StableDiffusionGGML {

std::shared_ptr<Denoiser> denoiser = std::make_shared<CompVisDenoiser>();

std::atomic<sd_cancel_mode_t> cancellation_flag;

StableDiffusionGGML() = default;

~StableDiffusionGGML() {
Expand All @@ -543,6 +550,18 @@ class StableDiffusionGGML {
ggml_backend_free(backend);
}

// Stores `flag` as the current cancellation request. The release store pairs
// with the acquire load in get_cancel_flag().
void set_cancel_flag(enum sd_cancel_mode_t flag) {
cancellation_flag.store(flag, std::memory_order_release);
}

void reset_cancel_flag() {
set_cancel_flag(SD_CANCEL_RESET);
}

enum sd_cancel_mode_t get_cancel_flag() {
return cancellation_flag.load(std::memory_order_acquire);
}

void init_backend() {
#ifdef SD_USE_CUDA
LOG_DEBUG("Using CUDA backend");
Expand Down Expand Up @@ -2100,6 +2119,12 @@ class StableDiffusionGGML {
}

auto denoise = [&](ggml_tensor* input, float sigma, int step) -> ggml_tensor* {
enum sd_cancel_mode_t cancel_flag = get_cancel_flag();
if (cancel_flag != SD_CANCEL_RESET) {
LOG_DEBUG("cancelling latent decodings");
return nullptr;
}

auto sd_preview_cb = sd_get_preview_callback();
auto sd_preview_cb_data = sd_get_preview_callback_data();
auto sd_preview_mode = sd_get_preview_mode();
Expand Down Expand Up @@ -3146,6 +3171,12 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
img_cond = SDCondition(uncond.c_crossattn, uncond.c_vector, cond.c_concat);
}
for (int b = 0; b < batch_count; b++) {

if (sd_ctx->sd->get_cancel_flag() != SD_CANCEL_RESET) {
LOG_ERROR("cancelling generation");
break;
}

int64_t sampling_start = ggml_time_ms();
int64_t cur_seed = seed + b;
LOG_INFO("generating image: %i/%i - seed %" PRId64, b + 1, batch_count, cur_seed);
Expand Down Expand Up @@ -3207,6 +3238,12 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
LOG_INFO("decoding %zu latents", final_latents.size());
std::vector<ggml_tensor*> decoded_images; // collect decoded images
for (size_t i = 0; i < final_latents.size(); i++) {

if (sd_ctx->sd->get_cancel_flag() == SD_CANCEL_ALL) {
LOG_ERROR("cancelling latent decodings");
break;
}

t1 = ggml_time_ms();
ggml_tensor* img = sd_ctx->sd->decode_first_stage(work_ctx, final_latents[i] /* x_0 */);
// print_ggml_tensor(img);
Expand Down Expand Up @@ -3243,6 +3280,16 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
return result_images;
}

// Records a cancellation request on the context so that the generation loops
// see it the next time they poll the flag. Does nothing when the context (or
// its implementation object) is null. A `mode` outside the known enumerator
// range falls back to SD_CANCEL_ALL.
void sd_cancel_generation(sd_ctx_t* sd_ctx, enum sd_cancel_mode_t mode) {
    if (sd_ctx == nullptr || sd_ctx->sd == nullptr) {
        return;
    }
    const bool known_mode = (mode >= SD_CANCEL_ALL) && (mode <= SD_CANCEL_RESET);
    sd_ctx->sd->set_cancel_flag(known_mode ? mode : SD_CANCEL_ALL);
}

sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params) {
sd_ctx->sd->vae_tiling_params = sd_img_gen_params->vae_tiling_params;

Expand Down Expand Up @@ -3300,6 +3347,8 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
return nullptr;
}

sd_ctx->sd->reset_cancel_flag();

ggml_init_params params;
params.mem_size = static_cast<size_t>(1024 * 1024) * 1024; // 1G
params.mem_buffer = nullptr;
Expand Down
Loading