Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 7 additions & 9 deletions .github/workflows/ci.yml → .github/workflows/unit-tests.yml
Original file line number Diff line number Diff line change
@@ -1,42 +1,40 @@
name: CI
name: unit-tests

on:
push:
branches: [main]
pull_request:

env:
CARGO_TERM_COLOR: always

jobs:
fmt:
name: Formatting
name: formatting
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
submodules: recursive

- uses: dtolnay/rust-toolchain@stable
with:
toolchain: "1.92.0"
components: rustfmt

- uses: Swatinem/rust-cache@v2

- run: make fmt

test:
name: Tests
name: tests
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
submodules: recursive

- name: Install system dependencies
- name: install system dependencies
run: sudo apt-get update && sudo apt-get install -y cmake libclang-dev

- uses: dtolnay/rust-toolchain@stable
with:
toolchain: "1.92.0"

- uses: Swatinem/rust-cache@v2

Expand Down
120 changes: 111 additions & 9 deletions llama-cpp-bindings-sys/wrapper_common.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "wrapper_common.h"

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <exception>
Expand All @@ -15,18 +16,30 @@
extern "C" llama_rs_status llama_rs_json_schema_to_grammar(
const char * schema_json,
bool force_gbnf,
char ** out_grammar) {
if (!schema_json || !out_grammar) {
char ** out_grammar,
char ** out_error) {
if (!schema_json || !out_grammar || !out_error) {
return LLAMA_RS_STATUS_INVALID_ARGUMENT;
}

*out_grammar = nullptr;
*out_error = nullptr;

try {
const auto schema = nlohmann::ordered_json::parse(schema_json);
const auto grammar = json_schema_to_grammar(schema, force_gbnf);
*out_grammar = llama_rs_dup_string(grammar);

return *out_grammar ? LLAMA_RS_STATUS_OK : LLAMA_RS_STATUS_ALLOCATION_FAILED;
} catch (const std::exception &) {
} catch (const std::exception & err) {
fprintf(stderr, "%s: C++ exception: %s\n", __func__, err.what());
*out_error = llama_rs_dup_string(err.what());

return LLAMA_RS_STATUS_EXCEPTION;
} catch (...) {
fprintf(stderr, "%s: unknown C++ exception\n", __func__);
*out_error = llama_rs_dup_string("unknown C++ exception");

return LLAMA_RS_STATUS_EXCEPTION;
}
}
Expand Down Expand Up @@ -85,10 +98,25 @@ extern "C" void llama_rs_string_free(char * ptr) {
extern "C" struct llama_sampler * llama_rs_sampler_init_grammar(
const struct llama_vocab * vocab,
const char * grammar_str,
const char * grammar_root) {
const char * grammar_root,
char ** out_error) {
if (!out_error) {
return nullptr;
}

*out_error = nullptr;

try {
return llama_sampler_init_grammar(vocab, grammar_str, grammar_root);
} catch (const std::exception & err) {
fprintf(stderr, "%s: C++ exception: %s\n", __func__, err.what());
*out_error = llama_rs_dup_string(err.what());

return nullptr;
} catch (...) {
fprintf(stderr, "%s: unknown C++ exception\n", __func__);
*out_error = llama_rs_dup_string("unknown C++ exception");

return nullptr;
}
}
Expand All @@ -100,7 +128,14 @@ extern "C" struct llama_sampler * llama_rs_sampler_init_grammar_lazy(
const char ** trigger_words,
size_t num_trigger_words,
const llama_token * trigger_tokens,
size_t num_trigger_tokens) {
size_t num_trigger_tokens,
char ** out_error) {
if (!out_error) {
return nullptr;
}

*out_error = nullptr;

try {
std::vector<std::string> trigger_patterns;
trigger_patterns.reserve(num_trigger_words);
Expand All @@ -115,6 +150,7 @@ extern "C" struct llama_sampler * llama_rs_sampler_init_grammar_lazy(
for (const auto & pattern : trigger_patterns) {
trigger_patterns_c.push_back(pattern.c_str());
}

return llama_sampler_init_grammar_lazy_patterns(
vocab,
grammar_str,
Expand All @@ -123,7 +159,15 @@ extern "C" struct llama_sampler * llama_rs_sampler_init_grammar_lazy(
trigger_patterns_c.size(),
trigger_tokens,
num_trigger_tokens);
} catch (const std::exception & err) {
fprintf(stderr, "%s: C++ exception: %s\n", __func__, err.what());
*out_error = llama_rs_dup_string(err.what());

return nullptr;
} catch (...) {
fprintf(stderr, "%s: unknown C++ exception\n", __func__);
*out_error = llama_rs_dup_string("unknown C++ exception");

return nullptr;
}
}
Expand All @@ -135,7 +179,14 @@ extern "C" struct llama_sampler * llama_rs_sampler_init_grammar_lazy_patterns(
const char ** trigger_patterns,
size_t num_trigger_patterns,
const llama_token * trigger_tokens,
size_t num_trigger_tokens) {
size_t num_trigger_tokens,
char ** out_error) {
if (!out_error) {
return nullptr;
}

*out_error = nullptr;

try {
return llama_sampler_init_grammar_lazy_patterns(
vocab,
Expand All @@ -145,7 +196,15 @@ extern "C" struct llama_sampler * llama_rs_sampler_init_grammar_lazy_patterns(
num_trigger_patterns,
trigger_tokens,
num_trigger_tokens);
} catch (const std::exception & err) {
fprintf(stderr, "%s: C++ exception: %s\n", __func__, err.what());
*out_error = llama_rs_dup_string(err.what());

return nullptr;
} catch (...) {
fprintf(stderr, "%s: unknown C++ exception\n", __func__);
*out_error = llama_rs_dup_string("unknown C++ exception");

return nullptr;
}
}
Expand All @@ -164,6 +223,7 @@ extern "C" llama_pos llama_rs_memory_seq_pos_max(
if (seq_id < 0 || (uint32_t) seq_id >= n_seq_max) {
return -1;
}

return llama_memory_seq_pos_max(mem, seq_id);
}

Expand Down Expand Up @@ -231,16 +291,58 @@ extern "C" llama_rs_status llama_rs_memory_seq_div(
return LLAMA_RS_STATUS_OK;
}

extern "C" llama_rs_status llama_rs_sampler_accept(struct llama_sampler * sampler, llama_token token) {
if (!sampler) {
extern "C" llama_rs_status llama_rs_sampler_sample(
struct llama_sampler * sampler,
struct llama_context * ctx,
int32_t idx,
llama_token * out_token,
char ** out_error) {
if (!sampler || !ctx || !out_token || !out_error) {
return LLAMA_RS_STATUS_INVALID_ARGUMENT;
}

*out_error = nullptr;

try {
*out_token = llama_sampler_sample(sampler, ctx, idx);

return LLAMA_RS_STATUS_OK;
} catch (const std::exception & err) {
fprintf(stderr, "%s: C++ exception: %s\n", __func__, err.what());
*out_error = llama_rs_dup_string(err.what());

return LLAMA_RS_STATUS_EXCEPTION;
} catch (...) {
fprintf(stderr, "%s: unknown C++ exception\n", __func__);
*out_error = llama_rs_dup_string("unknown C++ exception");

return LLAMA_RS_STATUS_EXCEPTION;
}
}

extern "C" llama_rs_status llama_rs_sampler_accept(
struct llama_sampler * sampler,
llama_token token,
char ** out_error) {
if (!sampler || !out_error) {
return LLAMA_RS_STATUS_INVALID_ARGUMENT;
}

*out_error = nullptr;

try {
llama_sampler_accept(sampler, token);

return LLAMA_RS_STATUS_OK;
} catch (const std::exception &) {
} catch (const std::exception & err) {
fprintf(stderr, "%s: C++ exception: %s\n", __func__, err.what());
*out_error = llama_rs_dup_string(err.what());

return LLAMA_RS_STATUS_EXCEPTION;
} catch (...) {
fprintf(stderr, "%s: unknown C++ exception\n", __func__);
*out_error = llama_rs_dup_string("unknown C++ exception");

return LLAMA_RS_STATUS_EXCEPTION;
}
}
24 changes: 19 additions & 5 deletions llama-cpp-bindings-sys/wrapper_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,14 @@ extern "C" {
llama_rs_status llama_rs_json_schema_to_grammar(
const char * schema_json,
bool force_gbnf,
char ** out_grammar);
char ** out_grammar,
char ** out_error);

struct llama_sampler * llama_rs_sampler_init_grammar(
const struct llama_vocab * vocab,
const char * grammar_str,
const char * grammar_root);
const char * grammar_root,
char ** out_error);

struct llama_sampler * llama_rs_sampler_init_grammar_lazy(
const struct llama_vocab * vocab,
Expand All @@ -53,7 +55,8 @@ struct llama_sampler * llama_rs_sampler_init_grammar_lazy(
const char ** trigger_words,
size_t num_trigger_words,
const llama_token * trigger_tokens,
size_t num_trigger_tokens);
size_t num_trigger_tokens,
char ** out_error);

struct llama_sampler * llama_rs_sampler_init_grammar_lazy_patterns(
const struct llama_vocab * vocab,
Expand All @@ -62,9 +65,20 @@ struct llama_sampler * llama_rs_sampler_init_grammar_lazy_patterns(
const char ** trigger_patterns,
size_t num_trigger_patterns,
const llama_token * trigger_tokens,
size_t num_trigger_tokens);
size_t num_trigger_tokens,
char ** out_error);

llama_rs_status llama_rs_sampler_accept(
struct llama_sampler * sampler,
llama_token token,
char ** out_error);

llama_rs_status llama_rs_sampler_accept(struct llama_sampler * sampler, llama_token token);
llama_rs_status llama_rs_sampler_sample(
struct llama_sampler * sampler,
struct llama_context * ctx,
int32_t idx,
llama_token * out_token,
char ** out_error);

void llama_rs_chat_template_result_free(struct llama_rs_chat_template_result * result);
void llama_rs_string_free(char * ptr);
Expand Down
26 changes: 21 additions & 5 deletions llama-cpp-bindings/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,8 @@ pub enum GrammarError {
#[error("String contains null bytes: {0}")]
NulError(#[from] NulError),
/// The grammar call returned null
#[error("Grammar call returned null")]
NullGrammar,
#[error("Grammar initialization failed: {0}")]
NullGrammar(String),
/// An integer value exceeded the allowed range
#[error("Integer overflow: {0}")]
IntegerOverflow(String),
Expand All @@ -193,6 +193,18 @@ pub enum SamplingError {
IntegerOverflow(String),
}

/// Errors that can occur when sampling a token.
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum SampleError {
/// A C++ exception was thrown during sampling
#[error("C++ exception during sampling: {0}")]
CppException(String),

/// An invalid argument was passed to the sampler
#[error("Invalid argument passed to sampler")]
InvalidArgument,
}

/// Decode a error from llama.cpp into a [`DecodeError`].
impl From<NonZeroI32> for DecodeError {
fn from(value: NonZeroI32) -> Self {
Expand Down Expand Up @@ -345,9 +357,13 @@ pub enum ChatParseError {
/// Failed to accept a token in a sampler.
#[derive(Debug, thiserror::Error)]
pub enum SamplerAcceptError {
/// llama.cpp returned an error code.
#[error("ffi error {0}")]
FfiError(i32),
/// A C++ exception was thrown during accept
#[error("C++ exception during sampler accept: {0}")]
CppException(String),

/// An invalid argument was passed (null sampler or null error pointer)
#[error("Invalid argument passed to sampler accept")]
InvalidArgument,
}

/// Errors that can occur when modifying model parameters.
Expand Down
33 changes: 33 additions & 0 deletions llama-cpp-bindings/src/ffi_error_reader.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
use std::ffi::{CStr, c_char};

/// Reads a C error string, converts to Rust `String`, and frees the C memory.
///
/// # Safety
///
/// `error_ptr` must be either null or a valid pointer to a null-terminated
/// C string allocated by `llama_rs_dup_string`.
pub unsafe fn read_and_free_cpp_error(error_ptr: *mut c_char) -> String {
if error_ptr.is_null() {
return "unknown error".to_owned();
}

let message = unsafe { CStr::from_ptr(error_ptr) }
.to_string_lossy()
.into_owned();

unsafe { llama_cpp_bindings_sys::llama_rs_string_free(error_ptr) };

message
}

#[cfg(test)]
mod tests {
use super::read_and_free_cpp_error;

#[test]
fn returns_unknown_for_null_pointer() {
let result = unsafe { read_and_free_cpp_error(std::ptr::null_mut()) };

assert_eq!(result, "unknown error");
}
}
Loading
Loading