From 0b9aa871b45818a66f5f5cb0b739a9065268d6cb Mon Sep 17 00:00:00 2001 From: Mateusz Charytoniuk Date: Thu, 2 Apr 2026 21:05:25 +0200 Subject: [PATCH 1/6] catch C++ exceptions at FFI boundary, rename utility modules --- .github/workflows/ci.yml | 8 +- llama-cpp-bindings-sys/wrapper_common.cpp | 120 ++++++++- llama-cpp-bindings-sys/wrapper_common.h | 24 +- llama-cpp-bindings/src/error.rs | 22 +- llama-cpp-bindings/src/ffi_error_reader.rs | 27 ++ ...ty_status_is_ok.rs => ffi_status_is_ok.rs} | 0 ..._status_to_i32.rs => ffi_status_to_i32.rs} | 0 ...tility_ggml_time_us.rs => ggml_time_us.rs} | 0 ...o_grammar.rs => json_schema_to_grammar.rs} | 47 ++-- llama-cpp-bindings/src/lib.rs | 37 +-- ...lity_llama_time_us.rs => llama_time_us.rs} | 0 llama-cpp-bindings/src/llguidance_sampler.rs | 2 +- ..._utility_max_devices.rs => max_devices.rs} | 0 ..._mlock_supported.rs => mlock_supported.rs} | 0 ...ty_mmap_supported.rs => mmap_supported.rs} | 0 llama-cpp-bindings/src/model.rs | 250 ++++++++++++++++++ llama-cpp-bindings/src/sampling.rs | 84 ++++-- llama-cpp-bindings/src/test_model.rs | 8 + .../tests/constrained_decoding.rs | 2 +- llama-cpp-bindings/tests/multimodal.rs | 2 +- llama-cpp-bindings/tests/openai_server.rs | 2 +- llama-cpp-bindings/tests/openai_streaming.rs | 2 +- llama-cpp-bindings/tests/text_generation.rs | 6 +- llama-cpp-bindings/tests/tool_calling.rs | 2 +- 24 files changed, 554 insertions(+), 91 deletions(-) create mode 100644 llama-cpp-bindings/src/ffi_error_reader.rs rename llama-cpp-bindings/src/{llama_utility_status_is_ok.rs => ffi_status_is_ok.rs} (100%) rename llama-cpp-bindings/src/{llama_utility_status_to_i32.rs => ffi_status_to_i32.rs} (100%) rename llama-cpp-bindings/src/{llama_utility_ggml_time_us.rs => ggml_time_us.rs} (100%) rename llama-cpp-bindings/src/{llama_utility_json_schema_to_grammar.rs => json_schema_to_grammar.rs} (61%) rename llama-cpp-bindings/src/{llama_utility_llama_time_us.rs => llama_time_us.rs} (100%) rename llama-cpp-bindings/src/{llama_utility_max_devices.rs => max_devices.rs} (100%) rename llama-cpp-bindings/src/{llama_utility_mlock_supported.rs => mlock_supported.rs} (100%) rename llama-cpp-bindings/src/{llama_utility_mmap_supported.rs => mmap_supported.rs} (100%) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8e1ebf03..fcf98d7b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,4 +1,4 @@ -name: CI +name: ci on: push: @@ -9,7 +9,7 @@ env: jobs: fmt: - name: Formatting + name: formatting runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 @@ -24,14 +24,14 @@ jobs: - run: make fmt test: - name: Tests + name: tests runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 with: submodules: recursive - - name: Install system dependencies + - name: install system dependencies run: sudo apt-get update && sudo apt-get install -y cmake libclang-dev - uses: dtolnay/rust-toolchain@stable diff --git a/llama-cpp-bindings-sys/wrapper_common.cpp b/llama-cpp-bindings-sys/wrapper_common.cpp index 373082a8..d4210720 100644 --- a/llama-cpp-bindings-sys/wrapper_common.cpp +++ b/llama-cpp-bindings-sys/wrapper_common.cpp @@ -8,6 +8,7 @@ #include "llama.cpp/common/json-schema-to-grammar.h" #include "llama.cpp/include/llama.h" +#include "llama.cpp/src/llama-impl.h" #include "wrapper_utils.h" #include @@ -15,18 +16,30 @@ extern "C" llama_rs_status llama_rs_json_schema_to_grammar( const char * schema_json, bool force_gbnf, - char ** out_grammar) { - if (!schema_json || !out_grammar) { + char ** out_grammar, + char ** 
out_error) { + if (!schema_json || !out_grammar || !out_error) { return LLAMA_RS_STATUS_INVALID_ARGUMENT; } *out_grammar = nullptr; + *out_error = nullptr; + try { const auto schema = nlohmann::ordered_json::parse(schema_json); const auto grammar = json_schema_to_grammar(schema, force_gbnf); *out_grammar = llama_rs_dup_string(grammar); + return *out_grammar ? LLAMA_RS_STATUS_OK : LLAMA_RS_STATUS_ALLOCATION_FAILED; - } catch (const std::exception &) { + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: C++ exception: %s\n", __func__, err.what()); + *out_error = llama_rs_dup_string(err.what()); + + return LLAMA_RS_STATUS_EXCEPTION; + } catch (...) { + LLAMA_LOG_ERROR("%s: unknown C++ exception\n", __func__); + *out_error = llama_rs_dup_string("unknown C++ exception"); + return LLAMA_RS_STATUS_EXCEPTION; } } @@ -85,10 +98,25 @@ extern "C" void llama_rs_string_free(char * ptr) { extern "C" struct llama_sampler * llama_rs_sampler_init_grammar( const struct llama_vocab * vocab, const char * grammar_str, - const char * grammar_root) { + const char * grammar_root, + char ** out_error) { + if (!out_error) { + return nullptr; + } + + *out_error = nullptr; + try { return llama_sampler_init_grammar(vocab, grammar_str, grammar_root); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: C++ exception: %s\n", __func__, err.what()); + *out_error = llama_rs_dup_string(err.what()); + + return nullptr; } catch (...) { + LLAMA_LOG_ERROR("%s: unknown C++ exception\n", __func__); + *out_error = llama_rs_dup_string("unknown C++ exception"); + return nullptr; } } @@ -100,7 +128,14 @@ extern "C" struct llama_sampler * llama_rs_sampler_init_grammar_lazy( const char ** trigger_words, size_t num_trigger_words, const llama_token * trigger_tokens, - size_t num_trigger_tokens) { + size_t num_trigger_tokens, + char ** out_error) { + if (!out_error) { + return nullptr; + } + + *out_error = nullptr; + try { std::vector trigger_patterns; trigger_patterns.reserve(num_trigger_words); @@ -115,6 +150,7 @@ extern "C" struct llama_sampler * llama_rs_sampler_init_grammar_lazy( for (const auto & pattern : trigger_patterns) { trigger_patterns_c.push_back(pattern.c_str()); } + return llama_sampler_init_grammar_lazy_patterns( vocab, grammar_str, @@ -123,7 +159,15 @@ extern "C" struct llama_sampler * llama_rs_sampler_init_grammar_lazy( trigger_patterns_c.size(), trigger_tokens, num_trigger_tokens); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: C++ exception: %s\n", __func__, err.what()); + *out_error = llama_rs_dup_string(err.what()); + + return nullptr; } catch (...) { + LLAMA_LOG_ERROR("%s: unknown C++ exception\n", __func__); + *out_error = llama_rs_dup_string("unknown C++ exception"); + return nullptr; } } @@ -135,7 +179,14 @@ extern "C" struct llama_sampler * llama_rs_sampler_init_grammar_lazy_patterns( const char ** trigger_patterns, size_t num_trigger_patterns, const llama_token * trigger_tokens, - size_t num_trigger_tokens) { + size_t num_trigger_tokens, + char ** out_error) { + if (!out_error) { + return nullptr; + } + + *out_error = nullptr; + try { return llama_sampler_init_grammar_lazy_patterns( vocab, @@ -145,7 +196,15 @@ extern "C" struct llama_sampler * llama_rs_sampler_init_grammar_lazy_patterns( num_trigger_patterns, trigger_tokens, num_trigger_tokens); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: C++ exception: %s\n", __func__, err.what()); + *out_error = llama_rs_dup_string(err.what()); + + return nullptr; } catch (...) 
{ + LLAMA_LOG_ERROR("%s: unknown C++ exception\n", __func__); + *out_error = llama_rs_dup_string("unknown C++ exception"); + return nullptr; } } @@ -164,6 +223,7 @@ extern "C" llama_pos llama_rs_memory_seq_pos_max( if (seq_id < 0 || (uint32_t) seq_id >= n_seq_max) { return -1; } + return llama_memory_seq_pos_max(mem, seq_id); } @@ -231,16 +291,58 @@ extern "C" llama_rs_status llama_rs_memory_seq_div( return LLAMA_RS_STATUS_OK; } -extern "C" llama_rs_status llama_rs_sampler_accept(struct llama_sampler * sampler, llama_token token) { - if (!sampler) { +extern "C" llama_rs_status llama_rs_sampler_sample( + struct llama_sampler * sampler, + struct llama_context * ctx, + int32_t idx, + llama_token * out_token, + char ** out_error) { + if (!sampler || !ctx || !out_token || !out_error) { return LLAMA_RS_STATUS_INVALID_ARGUMENT; } + + *out_error = nullptr; + + try { + *out_token = llama_sampler_sample(sampler, ctx, idx); + + return LLAMA_RS_STATUS_OK; + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: C++ exception: %s\n", __func__, err.what()); + *out_error = llama_rs_dup_string(err.what()); + + return LLAMA_RS_STATUS_EXCEPTION; + } catch (...) { + LLAMA_LOG_ERROR("%s: unknown C++ exception\n", __func__); + *out_error = llama_rs_dup_string("unknown C++ exception"); + + return LLAMA_RS_STATUS_EXCEPTION; + } +} + +extern "C" llama_rs_status llama_rs_sampler_accept( + struct llama_sampler * sampler, + llama_token token, + char ** out_error) { + if (!sampler || !out_error) { + return LLAMA_RS_STATUS_INVALID_ARGUMENT; + } + + *out_error = nullptr; + try { llama_sampler_accept(sampler, token); + return LLAMA_RS_STATUS_OK; - } catch (const std::exception &) { + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: C++ exception: %s\n", __func__, err.what()); + *out_error = llama_rs_dup_string(err.what()); + return LLAMA_RS_STATUS_EXCEPTION; } catch (...) 
{ + LLAMA_LOG_ERROR("%s: unknown C++ exception\n", __func__); + *out_error = llama_rs_dup_string("unknown C++ exception"); + return LLAMA_RS_STATUS_EXCEPTION; } } diff --git a/llama-cpp-bindings-sys/wrapper_common.h b/llama-cpp-bindings-sys/wrapper_common.h index d1ab9c27..14a8998e 100644 --- a/llama-cpp-bindings-sys/wrapper_common.h +++ b/llama-cpp-bindings-sys/wrapper_common.h @@ -39,12 +39,14 @@ extern "C" { llama_rs_status llama_rs_json_schema_to_grammar( const char * schema_json, bool force_gbnf, - char ** out_grammar); + char ** out_grammar, + char ** out_error); struct llama_sampler * llama_rs_sampler_init_grammar( const struct llama_vocab * vocab, const char * grammar_str, - const char * grammar_root); + const char * grammar_root, + char ** out_error); struct llama_sampler * llama_rs_sampler_init_grammar_lazy( const struct llama_vocab * vocab, @@ -53,7 +55,8 @@ struct llama_sampler * llama_rs_sampler_init_grammar_lazy( const char ** trigger_words, size_t num_trigger_words, const llama_token * trigger_tokens, - size_t num_trigger_tokens); + size_t num_trigger_tokens, + char ** out_error); struct llama_sampler * llama_rs_sampler_init_grammar_lazy_patterns( const struct llama_vocab * vocab, @@ -62,9 +65,20 @@ struct llama_sampler * llama_rs_sampler_init_grammar_lazy_patterns( const char ** trigger_patterns, size_t num_trigger_patterns, const llama_token * trigger_tokens, - size_t num_trigger_tokens); + size_t num_trigger_tokens, + char ** out_error); + +llama_rs_status llama_rs_sampler_accept( + struct llama_sampler * sampler, + llama_token token, + char ** out_error); -llama_rs_status llama_rs_sampler_accept(struct llama_sampler * sampler, llama_token token); +llama_rs_status llama_rs_sampler_sample( + struct llama_sampler * sampler, + struct llama_context * ctx, + int32_t idx, + llama_token * out_token, + char ** out_error); void llama_rs_chat_template_result_free(struct llama_rs_chat_template_result * result); void llama_rs_string_free(char * ptr); diff --git a/llama-cpp-bindings/src/error.rs b/llama-cpp-bindings/src/error.rs index 2f56ba88..58b3f460 100644 --- a/llama-cpp-bindings/src/error.rs +++ b/llama-cpp-bindings/src/error.rs @@ -175,8 +175,8 @@ pub enum GrammarError { #[error("String contains null bytes: {0}")] NulError(#[from] NulError), /// The grammar call returned null - #[error("Grammar call returned null")] - NullGrammar, + #[error("Grammar initialization failed: {0}")] + NullGrammar(String), /// An integer value exceeded the allowed range #[error("Integer overflow: {0}")] IntegerOverflow(String), @@ -193,6 +193,18 @@ pub enum SamplingError { IntegerOverflow(String), } +/// Errors that can occur when sampling a token. +#[derive(Debug, Eq, PartialEq, thiserror::Error)] +pub enum SampleError { + /// A C++ exception was thrown during sampling + #[error("C++ exception during sampling: {0}")] + CppException(String), + + /// An invalid argument was passed to the sampler + #[error("Invalid argument passed to sampler")] + InvalidArgument, +} + /// Decode a error from llama.cpp into a [`DecodeError`]. impl From for DecodeError { fn from(value: NonZeroI32) -> Self { @@ -345,9 +357,9 @@ pub enum ChatParseError { /// Failed to accept a token in a sampler. #[derive(Debug, thiserror::Error)] pub enum SamplerAcceptError { - /// llama.cpp returned an error code. 
- #[error("ffi error {0}")] - FfiError(i32), + /// A C++ exception was thrown during accept + #[error("C++ exception during sampler accept: {0}")] + CppException(String), } /// Errors that can occur when modifying model parameters. diff --git a/llama-cpp-bindings/src/ffi_error_reader.rs b/llama-cpp-bindings/src/ffi_error_reader.rs new file mode 100644 index 00000000..257b311a --- /dev/null +++ b/llama-cpp-bindings/src/ffi_error_reader.rs @@ -0,0 +1,27 @@ +use std::ffi::{CStr, c_char}; + +pub fn read_and_free_cpp_error(error_ptr: *mut c_char) -> String { + if error_ptr.is_null() { + return "unknown error".to_owned(); + } + + let message = unsafe { CStr::from_ptr(error_ptr) } + .to_string_lossy() + .into_owned(); + + unsafe { llama_cpp_bindings_sys::llama_rs_string_free(error_ptr) }; + + message +} + +#[cfg(test)] +mod tests { + use super::read_and_free_cpp_error; + + #[test] + fn returns_unknown_for_null_pointer() { + let result = read_and_free_cpp_error(std::ptr::null_mut()); + + assert_eq!(result, "unknown error"); + } +} diff --git a/llama-cpp-bindings/src/llama_utility_status_is_ok.rs b/llama-cpp-bindings/src/ffi_status_is_ok.rs similarity index 100% rename from llama-cpp-bindings/src/llama_utility_status_is_ok.rs rename to llama-cpp-bindings/src/ffi_status_is_ok.rs diff --git a/llama-cpp-bindings/src/llama_utility_status_to_i32.rs b/llama-cpp-bindings/src/ffi_status_to_i32.rs similarity index 100% rename from llama-cpp-bindings/src/llama_utility_status_to_i32.rs rename to llama-cpp-bindings/src/ffi_status_to_i32.rs diff --git a/llama-cpp-bindings/src/llama_utility_ggml_time_us.rs b/llama-cpp-bindings/src/ggml_time_us.rs similarity index 100% rename from llama-cpp-bindings/src/llama_utility_ggml_time_us.rs rename to llama-cpp-bindings/src/ggml_time_us.rs diff --git a/llama-cpp-bindings/src/llama_utility_json_schema_to_grammar.rs b/llama-cpp-bindings/src/json_schema_to_grammar.rs similarity index 61% rename from llama-cpp-bindings/src/llama_utility_json_schema_to_grammar.rs rename to llama-cpp-bindings/src/json_schema_to_grammar.rs index af7ba6c8..fa3ad0fd 100644 --- a/llama-cpp-bindings/src/llama_utility_json_schema_to_grammar.rs +++ b/llama-cpp-bindings/src/json_schema_to_grammar.rs @@ -1,8 +1,7 @@ -use std::ffi::{CStr, CString}; +use std::ffi::{CStr, CString, c_char}; use crate::error::{LlamaCppError, Result}; -use crate::llama_utility_status_is_ok::status_is_ok; -use crate::llama_utility_status_to_i32::status_to_i32; +use crate::ffi_status_is_ok::status_is_ok; /// Convert a JSON schema string into a llama.cpp grammar string. 
/// @@ -11,32 +10,40 @@ use crate::llama_utility_status_to_i32::status_to_i32; pub fn json_schema_to_grammar(schema_json: &str) -> Result { let schema_cstr = CString::new(schema_json) .map_err(|err| LlamaCppError::JsonSchemaToGrammarError(err.to_string()))?; - let mut out = std::ptr::null_mut(); - let rc = unsafe { + let mut out: *mut c_char = std::ptr::null_mut(); + let mut error_ptr: *mut c_char = std::ptr::null_mut(); + + let status = unsafe { llama_cpp_bindings_sys::llama_rs_json_schema_to_grammar( schema_cstr.as_ptr(), false, - &raw mut out, + &mut out, + &mut error_ptr, ) }; - let result = { - if !status_is_ok(rc) || out.is_null() { - return Err(LlamaCppError::JsonSchemaToGrammarError(format!( - "ffi error {}", - status_to_i32(rc) - ))); - } - let grammar_bytes = unsafe { CStr::from_ptr(out) }.to_bytes().to_vec(); - let grammar = String::from_utf8(grammar_bytes) - .map_err(|err| LlamaCppError::JsonSchemaToGrammarError(err.to_string()))?; - - Ok(grammar) - }; + if !status_is_ok(status) || out.is_null() { + let message = if error_ptr.is_null() { + "unknown error".to_owned() + } else { + let message = unsafe { CStr::from_ptr(error_ptr) } + .to_string_lossy() + .into_owned(); + + unsafe { llama_cpp_bindings_sys::llama_rs_string_free(error_ptr) }; + + message + }; + + return Err(LlamaCppError::JsonSchemaToGrammarError(message)); + } + + let grammar_bytes = unsafe { CStr::from_ptr(out) }.to_bytes().to_vec(); unsafe { llama_cpp_bindings_sys::llama_rs_string_free(out) }; - result + String::from_utf8(grammar_bytes) + .map_err(|err| LlamaCppError::JsonSchemaToGrammarError(err.to_string())) } #[cfg(test)] diff --git a/llama-cpp-bindings/src/lib.rs b/llama-cpp-bindings/src/lib.rs index 8aaad611..bd60075f 100644 --- a/llama-cpp-bindings/src/lib.rs +++ b/llama-cpp-bindings/src/lib.rs @@ -12,25 +12,26 @@ pub mod context; pub mod error; +pub mod ffi_error_reader; +pub mod ffi_status_is_ok; +pub mod ffi_status_to_i32; +pub mod ggml_time_us; pub mod gguf_context; pub mod gguf_context_error; pub mod gguf_type; +pub mod json_schema_to_grammar; pub mod llama_backend; pub mod llama_backend_device; pub mod llama_backend_numa_strategy; pub mod llama_batch; -pub mod llama_utility_ggml_time_us; -pub mod llama_utility_json_schema_to_grammar; -pub mod llama_utility_llama_time_us; -pub mod llama_utility_max_devices; -pub mod llama_utility_mlock_supported; -pub mod llama_utility_mmap_supported; -pub mod llama_utility_status_is_ok; -pub mod llama_utility_status_to_i32; +pub mod llama_time_us; #[cfg(feature = "llguidance")] pub mod llguidance_sampler; pub mod log; pub mod log_options; +pub mod max_devices; +pub mod mlock_supported; +pub mod mmap_supported; pub mod model; #[cfg(feature = "mtmd")] pub mod mtmd; @@ -44,22 +45,22 @@ pub use error::{ ApplyChatTemplateError, ChatParseError, ChatTemplateError, DecodeError, EmbeddingsError, EncodeError, GrammarError, LlamaContextLoadError, LlamaCppError, LlamaLoraAdapterInitError, LlamaLoraAdapterRemoveError, LlamaLoraAdapterSetError, LlamaModelLoadError, LogitsError, - MetaValError, ModelParamsError, NewLlamaChatMessageError, Result, SamplerAcceptError, - SamplingError, StringToTokenError, TokenSamplingError, TokenToStringError, + MetaValError, ModelParamsError, NewLlamaChatMessageError, Result, SampleError, + SamplerAcceptError, SamplingError, StringToTokenError, TokenSamplingError, TokenToStringError, }; pub use llama_backend_device::{ LlamaBackendDevice, LlamaBackendDeviceType, list_llama_ggml_backend_devices, }; -pub use llama_utility_ggml_time_us::ggml_time_us; -pub 
use llama_utility_json_schema_to_grammar::json_schema_to_grammar; -pub use llama_utility_llama_time_us::llama_time_us; -pub use llama_utility_max_devices::max_devices; -pub use llama_utility_mlock_supported::mlock_supported; -pub use llama_utility_mmap_supported::mmap_supported; -pub use llama_utility_status_is_ok::status_is_ok; -pub use llama_utility_status_to_i32::status_to_i32; +pub use ffi_status_is_ok::status_is_ok; +pub use ffi_status_to_i32::status_to_i32; +pub use ggml_time_us::ggml_time_us; +pub use json_schema_to_grammar::json_schema_to_grammar; +pub use llama_time_us::llama_time_us; +pub use max_devices::max_devices; +pub use mlock_supported::mlock_supported; +pub use mmap_supported::mmap_supported; pub use log::send_logs_to_tracing; pub use log_options::LogOptions; diff --git a/llama-cpp-bindings/src/llama_utility_llama_time_us.rs b/llama-cpp-bindings/src/llama_time_us.rs similarity index 100% rename from llama-cpp-bindings/src/llama_utility_llama_time_us.rs rename to llama-cpp-bindings/src/llama_time_us.rs diff --git a/llama-cpp-bindings/src/llguidance_sampler.rs b/llama-cpp-bindings/src/llguidance_sampler.rs index 0db8dd16..b7a565e9 100644 --- a/llama-cpp-bindings/src/llguidance_sampler.rs +++ b/llama-cpp-bindings/src/llguidance_sampler.rs @@ -207,7 +207,7 @@ pub fn create_llg_sampler( }; if sampler.is_null() { - Err(GrammarError::NullGrammar) + Err(GrammarError::NullGrammar("llguidance sampler returned null".to_owned())) } else { Ok(LlamaSampler { sampler }) } diff --git a/llama-cpp-bindings/src/llama_utility_max_devices.rs b/llama-cpp-bindings/src/max_devices.rs similarity index 100% rename from llama-cpp-bindings/src/llama_utility_max_devices.rs rename to llama-cpp-bindings/src/max_devices.rs diff --git a/llama-cpp-bindings/src/llama_utility_mlock_supported.rs b/llama-cpp-bindings/src/mlock_supported.rs similarity index 100% rename from llama-cpp-bindings/src/llama_utility_mlock_supported.rs rename to llama-cpp-bindings/src/mlock_supported.rs diff --git a/llama-cpp-bindings/src/llama_utility_mmap_supported.rs b/llama-cpp-bindings/src/mmap_supported.rs similarity index 100% rename from llama-cpp-bindings/src/llama_utility_mmap_supported.rs rename to llama-cpp-bindings/src/mmap_supported.rs diff --git a/llama-cpp-bindings/src/model.rs b/llama-cpp-bindings/src/model.rs index a3f81d2d..52f1fc3c 100644 --- a/llama-cpp-bindings/src/model.rs +++ b/llama-cpp-bindings/src/model.rs @@ -1838,4 +1838,254 @@ mod tests { assert!(result.is_err()); } + + #[test] + #[serial] + fn sample_returns_result_and_succeeds_with_valid_index() { + use crate::sampling::LlamaSampler; + use crate::token::LlamaToken; + + let (backend, model) = test_model::load_default_model().unwrap(); + let ctx_params = crate::context::params::LlamaContextParams::default() + .with_n_ctx(std::num::NonZeroU32::new(256)); + let mut context = model.new_context(&backend, ctx_params).unwrap(); + + let tokens = model.str_to_token("Hello", AddBos::Always).unwrap(); + let mut batch = crate::llama_batch::LlamaBatch::new(512, 1).unwrap(); + + batch.add_sequence(&tokens, 0, false).unwrap(); + + context.decode(&mut batch).unwrap(); + + let mut sampler = + LlamaSampler::chain_simple([LlamaSampler::temp(0.8), LlamaSampler::greedy()]); + + // sample() now returns Result to catch C++ exceptions at the FFI + // boundary instead of aborting the process. 
+ let result = sampler.sample(&context, batch.n_tokens() - 1); + + assert!(result.is_ok()); + } + + #[test] + #[serial] + fn grammar_sampler_constrains_output_to_yes_or_no() { + use crate::sampling::LlamaSampler; + use std::sync::Arc; + + let backend = Arc::new(LlamaBackend::init().unwrap()); + let model_params = LlamaModelParams::default(); + let model_path = + test_model::download_file_from("Qwen/Qwen3-8B-GGUF", "Qwen3-8B-Q4_K_M.gguf").unwrap(); + let model = LlamaModel::load_from_file(&backend, &model_path, &model_params).unwrap(); + + let ctx_params = crate::context::params::LlamaContextParams::default() + .with_n_ctx(std::num::NonZeroU32::new(512)); + let mut context = model.new_context(&backend, ctx_params).unwrap(); + + let prompt = "<|im_start|>user\nIs the sky blue? Answer yes or no.<|im_end|>\n<|im_start|>assistant\n\n\n\n\n"; + let tokens = model.str_to_token(prompt, AddBos::Always).unwrap(); + let mut batch = crate::llama_batch::LlamaBatch::new(512, 1).unwrap(); + + batch.add_sequence(&tokens, 0, false).unwrap(); + + context.decode(&mut batch).unwrap(); + + let mut sampler = LlamaSampler::chain_simple([ + LlamaSampler::grammar(&model, r#"root ::= [Yy] [Ee] [Ss] | [Nn] [Oo]"#, "root") + .unwrap(), + LlamaSampler::temp(0.8), + LlamaSampler::greedy(), + ]); + + let token = sampler.sample(&context, batch.n_tokens() - 1).unwrap(); + + assert!( + !model.is_eog_token(token), + "Grammar sampler should not allow EOS as first token" + ); + + let mut decoder = encoding_rs::UTF_8.new_decoder(); + let piece = model + .token_to_piece(token, &mut decoder, true, None) + .unwrap(); + let first_char = piece.chars().next().unwrap().to_lowercase().next().unwrap(); + + assert!( + first_char == 'y' || first_char == 'n', + "Grammar should constrain first token to start with y/n, got: '{piece}'" + ); + } + + #[test] + #[serial] + fn json_schema_grammar_sampler_constrains_output_to_json() { + use crate::sampling::LlamaSampler; + use std::sync::Arc; + + let backend = Arc::new(LlamaBackend::init().unwrap()); + let model_params = LlamaModelParams::default(); + let model_path = + test_model::download_file_from("Qwen/Qwen3-8B-GGUF", "Qwen3-8B-Q4_K_M.gguf").unwrap(); + let model = LlamaModel::load_from_file(&backend, &model_path, &model_params).unwrap(); + + let ctx_params = crate::context::params::LlamaContextParams::default() + .with_n_ctx(std::num::NonZeroU32::new(512)); + let mut context = model.new_context(&backend, ctx_params).unwrap(); + + let prompt = "<|im_start|>user\nWhat is 2+2? 
Respond with a JSON object.<|im_end|>\n<|im_start|>assistant\n\n\n\n\n"; + let tokens = model.str_to_token(prompt, AddBos::Always).unwrap(); + let mut batch = crate::llama_batch::LlamaBatch::new(512, 1).unwrap(); + + batch.add_sequence(&tokens, 0, false).unwrap(); + + context.decode(&mut batch).unwrap(); + + let grammar_str = crate::json_schema_to_grammar( + r#"{"type": "object", "properties": {"answer": {"type": "string"}}, "required": ["answer"]}"# + ).unwrap(); + + let mut sampler = LlamaSampler::chain_simple([ + LlamaSampler::grammar(&model, &grammar_str, "root").unwrap(), + LlamaSampler::temp(0.8), + LlamaSampler::greedy(), + ]); + + let token = sampler.sample(&context, batch.n_tokens() - 1).unwrap(); + + assert!( + !model.is_eog_token(token), + "Grammar sampler should not allow EOS as first token" + ); + + let mut decoder = encoding_rs::UTF_8.new_decoder(); + let piece = model + .token_to_piece(token, &mut decoder, true, None) + .unwrap(); + + assert!( + piece.starts_with('{'), + "JSON schema grammar should constrain first token to start with '{{', got: '{piece}'" + ); + } + + #[test] + #[serial] + fn sample_with_grammar_produces_constrained_output_in_loop() { + use crate::sampling::LlamaSampler; + use std::sync::Arc; + + let backend = Arc::new(LlamaBackend::init().unwrap()); + let model_params = LlamaModelParams::default(); + let model_path = + test_model::download_file_from("Qwen/Qwen3-8B-GGUF", "Qwen3-8B-Q4_K_M.gguf").unwrap(); + let model = LlamaModel::load_from_file(&backend, &model_path, &model_params).unwrap(); + + let ctx_params = crate::context::params::LlamaContextParams::default() + .with_n_ctx(std::num::NonZeroU32::new(512)); + let mut context = model.new_context(&backend, ctx_params).unwrap(); + + let prompt = "<|im_start|>user\nIs the sky blue? 
yes or no<|im_end|>\n<|im_start|>assistant\n\n\n\n\n"; + let tokens = model.str_to_token(prompt, AddBos::Always).unwrap(); + let mut batch = crate::llama_batch::LlamaBatch::new(512, 1).unwrap(); + + batch.add_sequence(&tokens, 0, false).unwrap(); + + context.decode(&mut batch).unwrap(); + + let mut sampler = LlamaSampler::chain_simple([ + LlamaSampler::grammar(&model, r#"root ::= "yes" | "no""#, "root").unwrap(), + LlamaSampler::temp(0.8), + LlamaSampler::greedy(), + ]); + + let mut generated = String::new(); + let mut decoder = encoding_rs::UTF_8.new_decoder(); + let mut position = batch.n_tokens(); + + for iteration in 0..10 { + let token = sampler.sample(&context, -1).unwrap(); + let is_eog = model.is_eog_token(token); + + eprintln!(" iteration={iteration} token={} eog={is_eog}", token.0); + + if is_eog { + break; + } + + let piece = model + .token_to_piece(token, &mut decoder, true, None) + .unwrap(); + + eprintln!(" piece='{piece}'"); + + generated.push_str(&piece); + + batch.clear(); + batch.add(token, position, &[0], true).unwrap(); + position += 1; + + context.decode(&mut batch).unwrap(); + } + + let lowercase = generated.to_lowercase(); + + assert!( + lowercase == "yes" || lowercase == "no", + "Grammar loop should produce 'yes' or 'no', got: '{generated}'" + ); + } + + #[test] + #[serial] + fn sample_without_grammar_produces_multiple_tokens() { + use crate::sampling::LlamaSampler; + use std::sync::Arc; + + let backend = Arc::new(LlamaBackend::init().unwrap()); + let model_params = LlamaModelParams::default(); + let model_path = + test_model::download_file_from("Qwen/Qwen3-8B-GGUF", "Qwen3-8B-Q4_K_M.gguf").unwrap(); + let model = LlamaModel::load_from_file(&backend, &model_path, &model_params).unwrap(); + + let ctx_params = crate::context::params::LlamaContextParams::default() + .with_n_ctx(std::num::NonZeroU32::new(512)); + let mut context = model.new_context(&backend, ctx_params).unwrap(); + + let prompt = + "<|im_start|>user\nSay hello<|im_end|>\n<|im_start|>assistant\n\n\n\n\n"; + let tokens = model.str_to_token(prompt, AddBos::Always).unwrap(); + let mut batch = crate::llama_batch::LlamaBatch::new(512, 1).unwrap(); + + batch.add_sequence(&tokens, 0, false).unwrap(); + + context.decode(&mut batch).unwrap(); + + let mut sampler = + LlamaSampler::chain_simple([LlamaSampler::temp(0.8), LlamaSampler::greedy()]); + + let mut token_count = 0; + let mut position = batch.n_tokens(); + + for _ in 0..5 { + let token = sampler.sample(&context, -1).unwrap(); + + if model.is_eog_token(token) { + break; + } + + token_count += 1; + + batch.clear(); + batch.add(token, position, &[0], true).unwrap(); + position += 1; + + context.decode(&mut batch).unwrap(); + } + + assert!( + token_count > 0, + "Should produce at least one token without grammar" + ); + } } diff --git a/llama-cpp-bindings/src/sampling.rs b/llama-cpp-bindings/src/sampling.rs index 5b606ef7..4a96dfe7 100644 --- a/llama-cpp-bindings/src/sampling.rs +++ b/llama-cpp-bindings/src/sampling.rs @@ -5,27 +5,34 @@ use std::ffi::{CString, c_char}; use std::fmt::{Debug, Formatter}; use crate::context::LlamaContext; +use crate::ffi_error_reader::read_and_free_cpp_error; use crate::model::LlamaModel; use crate::token::LlamaToken; use crate::token::data_array::LlamaTokenDataArray; use crate::token::logit_bias::LlamaLogitBias; -use crate::{GrammarError, SamplerAcceptError, SamplingError, status_is_ok, status_to_i32}; +use crate::{GrammarError, SampleError, SamplerAcceptError, SamplingError, status_is_ok}; -const fn check_sampler_accept_status( 
+fn check_sampler_accept_status( status: llama_cpp_bindings_sys::llama_rs_status, + error_ptr: *mut c_char, ) -> Result<(), SamplerAcceptError> { if status_is_ok(status) { Ok(()) } else { - Err(SamplerAcceptError::FfiError(status_to_i32(status))) + Err(SamplerAcceptError::CppException(read_and_free_cpp_error( + error_ptr, + ))) } } -const fn check_sampler_not_null( +fn check_sampler_not_null( sampler: *mut llama_cpp_bindings_sys::llama_sampler, + error_ptr: *mut c_char, ) -> Result { if sampler.is_null() { - Err(GrammarError::NullGrammar) + Err(GrammarError::NullGrammar(read_and_free_cpp_error( + error_ptr, + ))) } else { Ok(LlamaSampler { sampler }) } @@ -56,14 +63,34 @@ impl Debug for LlamaSampler { } impl LlamaSampler { - /// Sample and accept a token from the idx-th output of the last evaluation - #[must_use] - pub fn sample(&mut self, ctx: &LlamaContext, idx: i32) -> LlamaToken { - let token = unsafe { - llama_cpp_bindings_sys::llama_sampler_sample(self.sampler, ctx.context.as_ptr(), idx) + /// Sample and accept a token from the idx-th output of the last evaluation. + /// + /// # Errors + /// + /// Returns [`SampleError`] if the C++ sampler throws an exception or if the index is invalid. + pub fn sample(&mut self, ctx: &LlamaContext, idx: i32) -> Result { + let mut token: i32 = -1; + let mut error_ptr: *mut c_char = std::ptr::null_mut(); + + let status = unsafe { + llama_cpp_bindings_sys::llama_rs_sampler_sample( + self.sampler, + ctx.context.as_ptr(), + idx, + &mut token, + &mut error_ptr, + ) }; - LlamaToken(token) + match status { + llama_cpp_bindings_sys::LLAMA_RS_STATUS_OK => Ok(LlamaToken(token)), + llama_cpp_bindings_sys::LLAMA_RS_STATUS_INVALID_ARGUMENT => { + Err(SampleError::InvalidArgument) + } + _ => Err(SampleError::CppException(read_and_free_cpp_error( + error_ptr, + ))), + } } /// Applies this sampler to a [`LlamaTokenDataArray`]. @@ -115,10 +142,13 @@ impl LlamaSampler { /// # Errors /// Returns an error if the underlying sampler rejects the token. pub fn try_accept(&mut self, token: LlamaToken) -> Result<(), SamplerAcceptError> { - let sampler_result = - unsafe { llama_cpp_bindings_sys::llama_rs_sampler_accept(self.sampler, token.0) }; + let mut error_ptr: *mut c_char = std::ptr::null_mut(); - check_sampler_accept_status(sampler_result) + let status = unsafe { + llama_cpp_bindings_sys::llama_rs_sampler_accept(self.sampler, token.0, &mut error_ptr) + }; + + check_sampler_accept_status(status, error_ptr) } /// Resets the internal state of the sampler. 
@@ -344,16 +374,18 @@ impl LlamaSampler { ) -> Result { let (grammar_str, grammar_root) = Self::sanitize_grammar_strings(grammar_str, grammar_root)?; + let mut error_ptr: *mut c_char = std::ptr::null_mut(); let sampler = unsafe { llama_cpp_bindings_sys::llama_rs_sampler_init_grammar( model.vocab_ptr(), grammar_str.as_ptr(), grammar_root.as_ptr(), + &mut error_ptr, ) }; - check_sampler_not_null(sampler) + check_sampler_not_null(sampler, error_ptr) } /// Lazy grammar sampler, introduced in @@ -372,6 +404,7 @@ impl LlamaSampler { let (grammar_str, grammar_root) = Self::sanitize_grammar_strings(grammar_str, grammar_root)?; let trigger_words = Self::sanitize_trigger_words(trigger_words)?; + let mut error_ptr: *mut c_char = std::ptr::null_mut(); let mut trigger_word_ptrs: Vec<*const c_char> = trigger_words.iter().map(|cs| cs.as_ptr()).collect(); @@ -385,10 +418,11 @@ impl LlamaSampler { trigger_word_ptrs.len(), trigger_tokens.as_ptr().cast(), trigger_tokens.len(), + &mut error_ptr, ) }; - check_sampler_not_null(sampler) + check_sampler_not_null(sampler, error_ptr) } /// Lazy grammar sampler using regex trigger patterns. @@ -409,6 +443,7 @@ impl LlamaSampler { let (grammar_str, grammar_root) = Self::sanitize_grammar_strings(grammar_str, grammar_root)?; let trigger_patterns = Self::sanitize_trigger_patterns(trigger_patterns)?; + let mut error_ptr: *mut c_char = std::ptr::null_mut(); let mut trigger_pattern_ptrs: Vec<*const c_char> = trigger_patterns.iter().map(|cs| cs.as_ptr()).collect(); @@ -422,10 +457,11 @@ impl LlamaSampler { trigger_pattern_ptrs.len(), trigger_tokens.as_ptr().cast(), trigger_tokens.len(), + &mut error_ptr, ) }; - check_sampler_not_null(sampler) + check_sampler_not_null(sampler, error_ptr) } /// `LLGuidance` sampler for constrained decoding. @@ -938,9 +974,9 @@ mod tests { context.decode(&mut batch).unwrap(); let mut sampler = LlamaSampler::chain_simple([LlamaSampler::temp(0.8), LlamaSampler::greedy()]); - let token = sampler.sample(&context, batch.n_tokens() - 1); + let result = sampler.sample(&context, batch.n_tokens() - 1); - assert_ne!(token, LlamaToken::new(-1)); + assert!(result.is_ok()); } #[test] @@ -957,14 +993,18 @@ mod tests { #[test] fn check_sampler_accept_status_error() { - let result = - super::check_sampler_accept_status(llama_cpp_bindings_sys::LLAMA_RS_STATUS_EXCEPTION); + let result = super::check_sampler_accept_status( + llama_cpp_bindings_sys::LLAMA_RS_STATUS_EXCEPTION, + std::ptr::null_mut(), + ); + assert!(result.is_err()); } #[test] fn check_sampler_not_null_returns_error() { - let result = super::check_sampler_not_null(std::ptr::null_mut()); + let result = super::check_sampler_not_null(std::ptr::null_mut(), std::ptr::null_mut()); + assert!(result.is_err()); } } diff --git a/llama-cpp-bindings/src/test_model.rs b/llama-cpp-bindings/src/test_model.rs index e0048909..752ce9d6 100644 --- a/llama-cpp-bindings/src/test_model.rs +++ b/llama-cpp-bindings/src/test_model.rs @@ -39,6 +39,14 @@ fn hf_encoder_model() -> Result { required_env("LLAMA_TEST_HF_ENCODER_MODEL") } +/// Downloads a file from a specific HuggingFace repo. +/// +/// # Errors +/// Returns an error if the download fails. 
+pub fn download_file_from(repo: &str, filename: &str) -> Result { + download_file(repo, filename) +} + fn download_file(repo: &str, filename: &str) -> Result { let path = hf_hub::api::sync::ApiBuilder::new() .with_progress(true) diff --git a/llama-cpp-bindings/tests/constrained_decoding.rs b/llama-cpp-bindings/tests/constrained_decoding.rs index 799c3ed0..6e2a4d52 100644 --- a/llama-cpp-bindings/tests/constrained_decoding.rs +++ b/llama-cpp-bindings/tests/constrained_decoding.rs @@ -52,7 +52,7 @@ fn json_schema_constrains_output() -> Result<()> { let mut generated = String::new(); while n_cur <= 128 { - let token = sampler.sample(&ctx, batch.n_tokens() - 1); + let token = sampler.sample(&ctx, batch.n_tokens() - 1)?; if model.is_eog_token(token) { break; diff --git a/llama-cpp-bindings/tests/multimodal.rs b/llama-cpp-bindings/tests/multimodal.rs index 258e8c55..eee8b46c 100644 --- a/llama-cpp-bindings/tests/multimodal.rs +++ b/llama-cpp-bindings/tests/multimodal.rs @@ -91,7 +91,7 @@ fn multimodal_vision_inference_produces_output() -> Result<()> { let mut current_position = n_past; for _ in 0..max_tokens { - let token = sampler.sample(&ctx, -1); + let token = sampler.sample(&ctx, -1)?; if model.is_eog_token(token) { break; diff --git a/llama-cpp-bindings/tests/openai_server.rs b/llama-cpp-bindings/tests/openai_server.rs index 8629dc0d..aa239e24 100644 --- a/llama-cpp-bindings/tests/openai_server.rs +++ b/llama-cpp-bindings/tests/openai_server.rs @@ -78,7 +78,7 @@ fn run_chat_completion( let mut sampler = LlamaSampler::greedy(); while n_cur < max_tokens_total { - let token = sampler.sample(&ctx, batch.n_tokens() - 1); + let token = sampler.sample(&ctx, batch.n_tokens() - 1)?; if model.is_eog_token(token) { break; diff --git a/llama-cpp-bindings/tests/openai_streaming.rs b/llama-cpp-bindings/tests/openai_streaming.rs index 156286f7..a36e85e9 100644 --- a/llama-cpp-bindings/tests/openai_streaming.rs +++ b/llama-cpp-bindings/tests/openai_streaming.rs @@ -100,7 +100,7 @@ fn streaming_deltas_produce_valid_chunks() -> Result<()> { let mut total_chunks = 0usize; while n_cur <= max_tokens { - let token = sampler.sample(&ctx, batch.n_tokens() - 1); + let token = sampler.sample(&ctx, batch.n_tokens() - 1)?; if model.is_eog_token(token) { break; diff --git a/llama-cpp-bindings/tests/text_generation.rs b/llama-cpp-bindings/tests/text_generation.rs index 2eee177b..218e015d 100644 --- a/llama-cpp-bindings/tests/text_generation.rs +++ b/llama-cpp-bindings/tests/text_generation.rs @@ -64,7 +64,8 @@ fn raw_prompt_completion_with_timing() -> Result<()> { let mut generated = String::new(); while n_cur <= n_len { - let token = sampler.sample(&ctx, batch.n_tokens() - 1); + let token = sampler.sample(&ctx, batch.n_tokens() - 1)?; + sampler.accept(token)?; if model.is_eog_token(token) { @@ -137,7 +138,8 @@ fn chat_inference_produces_coherent_output() -> Result<()> { let mut generated = String::new(); while position <= max_tokens { - let token = sampler.sample(&context, batch.n_tokens() - 1); + let token = sampler.sample(&context, batch.n_tokens() - 1)?; + sampler.accept(token)?; if model.is_eog_token(token) { diff --git a/llama-cpp-bindings/tests/tool_calling.rs b/llama-cpp-bindings/tests/tool_calling.rs index c18f06f6..3bee8792 100644 --- a/llama-cpp-bindings/tests/tool_calling.rs +++ b/llama-cpp-bindings/tests/tool_calling.rs @@ -100,7 +100,7 @@ fn tool_calling_generates_grammar_and_prompt() -> Result<()> { let additional_stops = result.additional_stops.clone(); while n_cur <= max_tokens { - let token = 
sampler.sample(&ctx, batch.n_tokens() - 1); + let token = sampler.sample(&ctx, batch.n_tokens() - 1)?; if model.is_eog_token(token) { break; From 25d995992bc45f34e34cce4497a35071a692643f Mon Sep 17 00:00:00 2001 From: Mateusz Charytoniuk Date: Thu, 2 Apr 2026 21:54:19 +0200 Subject: [PATCH 2/6] fix CI: align rust-toolchain to 1.93.0, mark ffi_error_reader as unsafe, fix llguidance NullGrammar, lowercase workflow names --- .github/workflows/ci.yml | 5 ----- llama-cpp-bindings/src/ffi_error_reader.rs | 10 ++++++++-- llama-cpp-bindings/src/llguidance_sampler.rs | 4 +++- llama-cpp-bindings/src/sampling.rs | 18 +++++++++--------- rust-toolchain.toml | 2 +- 5 files changed, 21 insertions(+), 18 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index fcf98d7b..259a0eae 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -17,9 +17,6 @@ jobs: submodules: recursive - uses: dtolnay/rust-toolchain@stable - with: - toolchain: "1.92.0" - components: rustfmt - run: make fmt @@ -35,8 +32,6 @@ jobs: run: sudo apt-get update && sudo apt-get install -y cmake libclang-dev - uses: dtolnay/rust-toolchain@stable - with: - toolchain: "1.92.0" - uses: Swatinem/rust-cache@v2 diff --git a/llama-cpp-bindings/src/ffi_error_reader.rs b/llama-cpp-bindings/src/ffi_error_reader.rs index 257b311a..17a445ab 100644 --- a/llama-cpp-bindings/src/ffi_error_reader.rs +++ b/llama-cpp-bindings/src/ffi_error_reader.rs @@ -1,6 +1,12 @@ use std::ffi::{CStr, c_char}; -pub fn read_and_free_cpp_error(error_ptr: *mut c_char) -> String { +/// Reads a C error string, converts to Rust `String`, and frees the C memory. +/// +/// # Safety +/// +/// `error_ptr` must be either null or a valid pointer to a null-terminated +/// C string allocated by `llama_rs_dup_string`. 
+pub unsafe fn read_and_free_cpp_error(error_ptr: *mut c_char) -> String { if error_ptr.is_null() { return "unknown error".to_owned(); } @@ -20,7 +26,7 @@ mod tests { #[test] fn returns_unknown_for_null_pointer() { - let result = read_and_free_cpp_error(std::ptr::null_mut()); + let result = unsafe { read_and_free_cpp_error(std::ptr::null_mut()) }; assert_eq!(result, "unknown error"); } diff --git a/llama-cpp-bindings/src/llguidance_sampler.rs b/llama-cpp-bindings/src/llguidance_sampler.rs index b7a565e9..4a1738ab 100644 --- a/llama-cpp-bindings/src/llguidance_sampler.rs +++ b/llama-cpp-bindings/src/llguidance_sampler.rs @@ -207,7 +207,9 @@ pub fn create_llg_sampler( }; if sampler.is_null() { - Err(GrammarError::NullGrammar("llguidance sampler returned null".to_owned())) + Err(GrammarError::NullGrammar( + "llguidance sampler returned null".to_owned(), + )) } else { Ok(LlamaSampler { sampler }) } diff --git a/llama-cpp-bindings/src/sampling.rs b/llama-cpp-bindings/src/sampling.rs index 4a96dfe7..5ca3dc94 100644 --- a/llama-cpp-bindings/src/sampling.rs +++ b/llama-cpp-bindings/src/sampling.rs @@ -19,9 +19,9 @@ fn check_sampler_accept_status( if status_is_ok(status) { Ok(()) } else { - Err(SamplerAcceptError::CppException(read_and_free_cpp_error( - error_ptr, - ))) + Err(SamplerAcceptError::CppException(unsafe { + read_and_free_cpp_error(error_ptr) + })) } } @@ -30,9 +30,9 @@ fn check_sampler_not_null( error_ptr: *mut c_char, ) -> Result { if sampler.is_null() { - Err(GrammarError::NullGrammar(read_and_free_cpp_error( - error_ptr, - ))) + Err(GrammarError::NullGrammar(unsafe { + read_and_free_cpp_error(error_ptr) + })) } else { Ok(LlamaSampler { sampler }) } @@ -87,9 +87,9 @@ impl LlamaSampler { llama_cpp_bindings_sys::LLAMA_RS_STATUS_INVALID_ARGUMENT => { Err(SampleError::InvalidArgument) } - _ => Err(SampleError::CppException(read_and_free_cpp_error( - error_ptr, - ))), + _ => Err(SampleError::CppException(unsafe { + read_and_free_cpp_error(error_ptr) + })), } } diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 711b3a48..fb5449af 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,3 +1,3 @@ [toolchain] -channel = "1.92.0" +channel = "1.93.0" components = ["clippy", "rustfmt"] From 93d122c5e6bad78ba82a2cb7c945cf7370c7fe6b Mon Sep 17 00:00:00 2001 From: Mateusz Charytoniuk Date: Thu, 2 Apr 2026 21:58:53 +0200 Subject: [PATCH 3/6] run CI on push only for main, PRs use pull_request event --- .github/workflows/ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 259a0eae..6f9e4b97 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,6 +2,7 @@ name: ci on: push: + branches: [main] pull_request: env: From 417c00d2255bbfcebc05a7655a457441301e0cb0 Mon Sep 17 00:00:00 2001 From: Mateusz Charytoniuk Date: Thu, 2 Apr 2026 22:02:45 +0200 Subject: [PATCH 4/6] fix clippy: use raw mut for FFI pointers, add backticks in doc comment --- llama-cpp-bindings/src/json_schema_to_grammar.rs | 4 ++-- llama-cpp-bindings/src/sampling.rs | 16 ++++++++++------ llama-cpp-bindings/src/test_model.rs | 2 +- 3 files changed, 13 insertions(+), 9 deletions(-) diff --git a/llama-cpp-bindings/src/json_schema_to_grammar.rs b/llama-cpp-bindings/src/json_schema_to_grammar.rs index fa3ad0fd..34590a82 100644 --- a/llama-cpp-bindings/src/json_schema_to_grammar.rs +++ b/llama-cpp-bindings/src/json_schema_to_grammar.rs @@ -17,8 +17,8 @@ pub fn json_schema_to_grammar(schema_json: &str) -> Result { 
llama_cpp_bindings_sys::llama_rs_json_schema_to_grammar( schema_cstr.as_ptr(), false, - &mut out, - &mut error_ptr, + &raw mut out, + &raw mut error_ptr, ) }; diff --git a/llama-cpp-bindings/src/sampling.rs b/llama-cpp-bindings/src/sampling.rs index 5ca3dc94..4d3a2fd3 100644 --- a/llama-cpp-bindings/src/sampling.rs +++ b/llama-cpp-bindings/src/sampling.rs @@ -77,8 +77,8 @@ impl LlamaSampler { self.sampler, ctx.context.as_ptr(), idx, - &mut token, - &mut error_ptr, + &raw mut token, + &raw mut error_ptr, ) }; @@ -145,7 +145,11 @@ impl LlamaSampler { let mut error_ptr: *mut c_char = std::ptr::null_mut(); let status = unsafe { - llama_cpp_bindings_sys::llama_rs_sampler_accept(self.sampler, token.0, &mut error_ptr) + llama_cpp_bindings_sys::llama_rs_sampler_accept( + self.sampler, + token.0, + &raw mut error_ptr, + ) }; check_sampler_accept_status(status, error_ptr) @@ -381,7 +385,7 @@ impl LlamaSampler { model.vocab_ptr(), grammar_str.as_ptr(), grammar_root.as_ptr(), - &mut error_ptr, + &raw mut error_ptr, ) }; @@ -418,7 +422,7 @@ impl LlamaSampler { trigger_word_ptrs.len(), trigger_tokens.as_ptr().cast(), trigger_tokens.len(), - &mut error_ptr, + &raw mut error_ptr, ) }; @@ -457,7 +461,7 @@ impl LlamaSampler { trigger_pattern_ptrs.len(), trigger_tokens.as_ptr().cast(), trigger_tokens.len(), - &mut error_ptr, + &raw mut error_ptr, ) }; diff --git a/llama-cpp-bindings/src/test_model.rs b/llama-cpp-bindings/src/test_model.rs index 752ce9d6..991dee9a 100644 --- a/llama-cpp-bindings/src/test_model.rs +++ b/llama-cpp-bindings/src/test_model.rs @@ -39,7 +39,7 @@ fn hf_encoder_model() -> Result { required_env("LLAMA_TEST_HF_ENCODER_MODEL") } -/// Downloads a file from a specific HuggingFace repo. +/// Downloads a file from a specific `HuggingFace` repo. /// /// # Errors /// Returns an error if the download fails. From 88ecfb496b0cd08a7ff0a4eb4e3cd901e2573918 Mon Sep 17 00:00:00 2001 From: Mateusz Charytoniuk Date: Thu, 2 Apr 2026 22:43:43 +0200 Subject: [PATCH 5/6] remove panics from production code, fix double-accept, use env vars for test models, remove internal header dependency --- llama-cpp-bindings-sys/wrapper_common.cpp | 26 ++-- llama-cpp-bindings/src/error.rs | 4 + llama-cpp-bindings/src/model.rs | 132 +++++++----------- llama-cpp-bindings/src/mtmd/mtmd_context.rs | 2 +- llama-cpp-bindings/src/mtmd/mtmd_error.rs | 3 + .../src/mtmd/mtmd_input_chunks.rs | 33 ++--- llama-cpp-bindings/src/sampling.rs | 51 +++++-- .../tests/constrained_decoding.rs | 2 - llama-cpp-bindings/tests/openai_server.rs | 2 +- llama-cpp-bindings/tests/openai_streaming.rs | 2 +- llama-cpp-bindings/tests/text_generation.rs | 4 - llama-cpp-bindings/tests/tool_calling.rs | 2 +- 12 files changed, 133 insertions(+), 130 deletions(-) diff --git a/llama-cpp-bindings-sys/wrapper_common.cpp b/llama-cpp-bindings-sys/wrapper_common.cpp index d4210720..a71cae64 100644 --- a/llama-cpp-bindings-sys/wrapper_common.cpp +++ b/llama-cpp-bindings-sys/wrapper_common.cpp @@ -1,5 +1,6 @@ #include "wrapper_common.h" +#include #include #include #include @@ -8,7 +9,6 @@ #include "llama.cpp/common/json-schema-to-grammar.h" #include "llama.cpp/include/llama.h" -#include "llama.cpp/src/llama-impl.h" #include "wrapper_utils.h" #include @@ -32,12 +32,12 @@ extern "C" llama_rs_status llama_rs_json_schema_to_grammar( return *out_grammar ? 
LLAMA_RS_STATUS_OK : LLAMA_RS_STATUS_ALLOCATION_FAILED; } catch (const std::exception & err) { - LLAMA_LOG_ERROR("%s: C++ exception: %s\n", __func__, err.what()); + fprintf(stderr, "%s: C++ exception: %s\n", __func__, err.what()); *out_error = llama_rs_dup_string(err.what()); return LLAMA_RS_STATUS_EXCEPTION; } catch (...) { - LLAMA_LOG_ERROR("%s: unknown C++ exception\n", __func__); + fprintf(stderr, "%s: unknown C++ exception\n", __func__); *out_error = llama_rs_dup_string("unknown C++ exception"); return LLAMA_RS_STATUS_EXCEPTION; @@ -109,12 +109,12 @@ extern "C" struct llama_sampler * llama_rs_sampler_init_grammar( try { return llama_sampler_init_grammar(vocab, grammar_str, grammar_root); } catch (const std::exception & err) { - LLAMA_LOG_ERROR("%s: C++ exception: %s\n", __func__, err.what()); + fprintf(stderr, "%s: C++ exception: %s\n", __func__, err.what()); *out_error = llama_rs_dup_string(err.what()); return nullptr; } catch (...) { - LLAMA_LOG_ERROR("%s: unknown C++ exception\n", __func__); + fprintf(stderr, "%s: unknown C++ exception\n", __func__); *out_error = llama_rs_dup_string("unknown C++ exception"); return nullptr; @@ -160,12 +160,12 @@ extern "C" struct llama_sampler * llama_rs_sampler_init_grammar_lazy( trigger_tokens, num_trigger_tokens); } catch (const std::exception & err) { - LLAMA_LOG_ERROR("%s: C++ exception: %s\n", __func__, err.what()); + fprintf(stderr, "%s: C++ exception: %s\n", __func__, err.what()); *out_error = llama_rs_dup_string(err.what()); return nullptr; } catch (...) { - LLAMA_LOG_ERROR("%s: unknown C++ exception\n", __func__); + fprintf(stderr, "%s: unknown C++ exception\n", __func__); *out_error = llama_rs_dup_string("unknown C++ exception"); return nullptr; @@ -197,12 +197,12 @@ extern "C" struct llama_sampler * llama_rs_sampler_init_grammar_lazy_patterns( trigger_tokens, num_trigger_tokens); } catch (const std::exception & err) { - LLAMA_LOG_ERROR("%s: C++ exception: %s\n", __func__, err.what()); + fprintf(stderr, "%s: C++ exception: %s\n", __func__, err.what()); *out_error = llama_rs_dup_string(err.what()); return nullptr; } catch (...) { - LLAMA_LOG_ERROR("%s: unknown C++ exception\n", __func__); + fprintf(stderr, "%s: unknown C++ exception\n", __func__); *out_error = llama_rs_dup_string("unknown C++ exception"); return nullptr; @@ -308,12 +308,12 @@ extern "C" llama_rs_status llama_rs_sampler_sample( return LLAMA_RS_STATUS_OK; } catch (const std::exception & err) { - LLAMA_LOG_ERROR("%s: C++ exception: %s\n", __func__, err.what()); + fprintf(stderr, "%s: C++ exception: %s\n", __func__, err.what()); *out_error = llama_rs_dup_string(err.what()); return LLAMA_RS_STATUS_EXCEPTION; } catch (...) { - LLAMA_LOG_ERROR("%s: unknown C++ exception\n", __func__); + fprintf(stderr, "%s: unknown C++ exception\n", __func__); *out_error = llama_rs_dup_string("unknown C++ exception"); return LLAMA_RS_STATUS_EXCEPTION; @@ -335,12 +335,12 @@ extern "C" llama_rs_status llama_rs_sampler_accept( return LLAMA_RS_STATUS_OK; } catch (const std::exception & err) { - LLAMA_LOG_ERROR("%s: C++ exception: %s\n", __func__, err.what()); + fprintf(stderr, "%s: C++ exception: %s\n", __func__, err.what()); *out_error = llama_rs_dup_string(err.what()); return LLAMA_RS_STATUS_EXCEPTION; } catch (...) 
{ - LLAMA_LOG_ERROR("%s: unknown C++ exception\n", __func__); + fprintf(stderr, "%s: unknown C++ exception\n", __func__); *out_error = llama_rs_dup_string("unknown C++ exception"); return LLAMA_RS_STATUS_EXCEPTION; diff --git a/llama-cpp-bindings/src/error.rs b/llama-cpp-bindings/src/error.rs index 58b3f460..9fc73b02 100644 --- a/llama-cpp-bindings/src/error.rs +++ b/llama-cpp-bindings/src/error.rs @@ -360,6 +360,10 @@ pub enum SamplerAcceptError { /// A C++ exception was thrown during accept #[error("C++ exception during sampler accept: {0}")] CppException(String), + + /// An invalid argument was passed (null sampler or null error pointer) + #[error("Invalid argument passed to sampler accept")] + InvalidArgument, } /// Errors that can occur when modifying model parameters. diff --git a/llama-cpp-bindings/src/model.rs b/llama-cpp-bindings/src/model.rs index 52f1fc3c..bf22e247 100644 --- a/llama-cpp-bindings/src/model.rs +++ b/llama-cpp-bindings/src/model.rs @@ -77,16 +77,15 @@ impl LlamaModel { unsafe { llama_cpp_bindings_sys::llama_model_get_vocab(self.model.as_ptr()) } } - /// get the number of tokens the model was trained on + /// Get the number of tokens the model was trained on. /// - /// # Panics + /// # Errors /// - /// If the number of tokens the model was trained on does not fit into an `u32`. This should be impossible on most - /// platforms due to llama.cpp returning a `c_int` (i32 on most platforms) which is almost certainly positive. - #[must_use] - pub fn n_ctx_train(&self) -> u32 { + /// Returns an error if the value returned by llama.cpp does not fit into a `u32`. + pub fn n_ctx_train(&self) -> Result { let n_ctx_train = unsafe { llama_cpp_bindings_sys::llama_n_ctx_train(self.model.as_ptr()) }; - u32::try_from(n_ctx_train).expect("n_ctx_train fits into an u32") + + u32::try_from(n_ctx_train) } /// Get all tokens in the model. @@ -151,11 +150,8 @@ impl LlamaModel { /// /// # Errors /// - /// - if [`str`] contains a null byte. - /// - /// # Panics - /// - /// - if there is more than [`usize::MAX`] [`LlamaToken`]s in [`str`]. + /// - if [`str`] contains a null byte + /// - if an integer conversion fails during tokenization /// /// /// ```no_run @@ -200,7 +196,7 @@ impl LlamaModel { }; let size = if size.is_negative() { - buffer.reserve_exact(usize::try_from(-size).expect("negated size fits into usize")); + buffer.reserve_exact(usize::try_from(-size)?); unsafe { llama_cpp_bindings_sys::llama_tokenize( self.vocab_ptr(), @@ -228,14 +224,17 @@ impl LlamaModel { /// Get the type of a token. /// - /// # Panics + /// # Errors /// - /// If the token type is not known to this library. - #[must_use] - pub fn token_attr(&self, LlamaToken(id): LlamaToken) -> LlamaTokenAttrs { + /// Returns an error if the token type is not known to this library. + pub fn token_attr( + &self, + LlamaToken(id): LlamaToken, + ) -> Result { let token_type = unsafe { llama_cpp_bindings_sys::llama_token_get_attr(self.vocab_ptr(), id) }; - LlamaTokenAttrs::try_from(token_type).expect("token type is valid") + + LlamaTokenAttrs::try_from(token_type) } /// Convert a token to a string using the underlying llama.cpp `llama_token_to_piece` function. @@ -252,9 +251,7 @@ impl LlamaModel { /// /// - if the token type is unknown /// - /// # Panics - /// - /// - if the returned size from llama-cpp does not fit into a [`usize`]. 
(this should never happen) + /// - if the returned size from llama.cpp does not fit into a `usize` pub fn token_to_piece( &self, token: LlamaToken, @@ -263,15 +260,11 @@ impl LlamaModel { lstrip: Option, ) -> Result { let bytes = match self.token_to_piece_bytes(token, 8, special, lstrip) { - Err(TokenToStringError::InsufficientBufferSpace(required_size)) => self - .token_to_piece_bytes( - token, - (-required_size) - .try_into() - .expect("Error buffer size is positive"), - special, - lstrip, - ), + Err(TokenToStringError::InsufficientBufferSpace(required_size)) => { + let buffer_size: usize = (-required_size).try_into()?; + + self.token_to_piece_bytes(token, buffer_size, special, lstrip) + } other => other, }?; @@ -292,6 +285,7 @@ impl LlamaModel { /// /// - if the token type is unknown /// - the resultant token is larger than `buffer_size`. + /// - if an integer conversion fails #[allow(clippy::missing_panics_doc)] pub fn token_to_piece_bytes( &self, @@ -325,7 +319,7 @@ impl LlamaModel { size => { let string = unsafe { CString::from_raw(buf) }; let mut bytes = string.into_bytes(); - let len = usize::try_from(size).expect("size is positive and fits into usize"); + let len = usize::try_from(size)?; bytes.truncate(len); Ok(bytes) @@ -344,13 +338,13 @@ impl LlamaModel { /// The type of vocab the model was trained on. /// - /// # Panics + /// # Errors /// - /// If llama-cpp emits a vocab type that is not known to this library. - #[must_use] - pub fn vocab_type(&self) -> VocabType { + /// Returns an error if llama.cpp emits a vocab type that is not known to this library. + pub fn vocab_type(&self) -> Result { let vocab_type = unsafe { llama_cpp_bindings_sys::llama_vocab_type(self.vocab_ptr()) }; - VocabType::try_from(vocab_type).expect("invalid vocab type") + + VocabType::try_from(vocab_type) } /// This returns a `c_int` for maximum compatibility. Most of the time it can be cast to an i32 @@ -380,35 +374,29 @@ impl LlamaModel { /// Returns the number of layers within the model. /// - /// # Panics - /// Panics if the layer count returned by llama.cpp is negative. - #[must_use] - pub fn n_layer(&self) -> u32 { - // llama.cpp API returns int32_t but the underlying field is uint32_t, so this is safe + /// # Errors + /// + /// Returns an error if the layer count returned by llama.cpp does not fit into a `u32`. + pub fn n_layer(&self) -> Result { u32::try_from(unsafe { llama_cpp_bindings_sys::llama_model_n_layer(self.model.as_ptr()) }) - .expect("llama.cpp returns a positive value for n_layer") } /// Returns the number of attention heads within the model. /// - /// # Panics - /// Panics if the head count returned by llama.cpp is negative. - #[must_use] - pub fn n_head(&self) -> u32 { - // llama.cpp API returns int32_t but the underlying field is uint32_t, so this is safe + /// # Errors + /// + /// Returns an error if the head count returned by llama.cpp does not fit into a `u32`. + pub fn n_head(&self) -> Result { u32::try_from(unsafe { llama_cpp_bindings_sys::llama_model_n_head(self.model.as_ptr()) }) - .expect("llama.cpp returns a positive value for n_head") } /// Returns the number of KV attention heads. /// - /// # Panics - /// Panics if the KV head count returned by llama.cpp is negative. - #[must_use] - pub fn n_head_kv(&self) -> u32 { - // llama.cpp API returns int32_t but the underlying field is uint32_t, so this is safe + /// # Errors + /// + /// Returns an error if the KV head count returned by llama.cpp does not fit into a `u32`. 
+    pub fn n_head_kv(&self) -> Result<u32, std::num::TryFromIntError> {
         u32::try_from(unsafe { llama_cpp_bindings_sys::llama_model_n_head_kv(self.model.as_ptr()) })
-            .expect("llama.cpp returns a positive value for n_head_kv")
     }
 
     /// Returns whether the model is a hybrid network (Jamba, Granite, Qwen3xx, etc.)
@@ -969,7 +957,7 @@ mod tests {
         assert!(model.n_vocab() > 0);
         assert!(model.n_embd() > 0);
         assert!(model.n_params() > 0);
-        assert!(model.n_ctx_train() > 0);
+        assert!(model.n_ctx_train().unwrap() > 0);
     }
 
     #[test]
@@ -1086,7 +1074,7 @@ mod tests {
     fn n_layer_returns_positive() {
         let (_backend, model) = test_model::load_default_model().unwrap();
 
-        assert!(model.n_layer() > 0);
+        assert!(model.n_layer().unwrap() > 0);
     }
 
     #[test]
@@ -1094,7 +1082,7 @@ mod tests {
     fn n_head_returns_positive() {
         let (_backend, model) = test_model::load_default_model().unwrap();
 
-        assert!(model.n_head() > 0);
+        assert!(model.n_head().unwrap() > 0);
    }
 
     #[test]
@@ -1102,7 +1090,7 @@ mod tests {
     fn n_head_kv_returns_positive() {
         let (_backend, model) = test_model::load_default_model().unwrap();
 
-        assert!(model.n_head_kv() > 0);
+        assert!(model.n_head_kv().unwrap() > 0);
     }
 
     #[test]
@@ -1342,14 +1330,14 @@ mod tests {
     fn token_attr_returns_valid_attr() {
         let (_backend, model) = test_model::load_default_model().unwrap();
         let bos = model.token_bos();
-        let _attr = model.token_attr(bos);
+        let _attr = model.token_attr(bos).unwrap();
     }
 
     #[test]
     #[serial]
     fn vocab_type_returns_valid_type() {
         let (_backend, model) = test_model::load_default_model().unwrap();
-        let _vocab_type = model.vocab_type();
+        let _vocab_type = model.vocab_type().unwrap();
     }
 
     #[test]
@@ -1873,11 +1861,7 @@ mod tests {
         use crate::sampling::LlamaSampler;
         use std::sync::Arc;
 
-        let backend = Arc::new(LlamaBackend::init().unwrap());
-        let model_params = LlamaModelParams::default();
-        let model_path =
-            test_model::download_file_from("Qwen/Qwen3-8B-GGUF", "Qwen3-8B-Q4_K_M.gguf").unwrap();
-        let model = LlamaModel::load_from_file(&backend, &model_path, &model_params).unwrap();
+        let (backend, model) = test_model::load_default_model().unwrap();
 
         let ctx_params = crate::context::params::LlamaContextParams::default()
             .with_n_ctx(std::num::NonZeroU32::new(512));
@@ -1923,11 +1907,7 @@ mod tests {
         use crate::sampling::LlamaSampler;
         use std::sync::Arc;
 
-        let backend = Arc::new(LlamaBackend::init().unwrap());
-        let model_params = LlamaModelParams::default();
-        let model_path =
-            test_model::download_file_from("Qwen/Qwen3-8B-GGUF", "Qwen3-8B-Q4_K_M.gguf").unwrap();
-        let model = LlamaModel::load_from_file(&backend, &model_path, &model_params).unwrap();
+        let (backend, model) = test_model::load_default_model().unwrap();
 
         let ctx_params = crate::context::params::LlamaContextParams::default()
             .with_n_ctx(std::num::NonZeroU32::new(512));
@@ -1975,11 +1955,7 @@ mod tests {
         use crate::sampling::LlamaSampler;
         use std::sync::Arc;
 
-        let backend = Arc::new(LlamaBackend::init().unwrap());
-        let model_params = LlamaModelParams::default();
-        let model_path =
-            test_model::download_file_from("Qwen/Qwen3-8B-GGUF", "Qwen3-8B-Q4_K_M.gguf").unwrap();
-        let model = LlamaModel::load_from_file(&backend, &model_path, &model_params).unwrap();
+        let (backend, model) = test_model::load_default_model().unwrap();
 
         let ctx_params = crate::context::params::LlamaContextParams::default()
             .with_n_ctx(std::num::NonZeroU32::new(512));
@@ -2042,11 +2018,7 @@ mod tests {
         use crate::sampling::LlamaSampler;
         use std::sync::Arc;
 
-        let backend = Arc::new(LlamaBackend::init().unwrap());
-        let model_params = LlamaModelParams::default();
-        let model_path =
-            test_model::download_file_from("Qwen/Qwen3-8B-GGUF", "Qwen3-8B-Q4_K_M.gguf").unwrap();
-        let model = LlamaModel::load_from_file(&backend, &model_path, &model_params).unwrap();
+        let (backend, model) = test_model::load_default_model().unwrap();
 
         let ctx_params = crate::context::params::LlamaContextParams::default()
             .with_n_ctx(std::num::NonZeroU32::new(512));
diff --git a/llama-cpp-bindings/src/mtmd/mtmd_context.rs b/llama-cpp-bindings/src/mtmd/mtmd_context.rs
index 5a2e3f44..5f34544f 100644
--- a/llama-cpp-bindings/src/mtmd/mtmd_context.rs
+++ b/llama-cpp-bindings/src/mtmd/mtmd_context.rs
@@ -132,7 +132,7 @@ impl MtmdContext {
         text: MtmdInputText,
         bitmaps: &[&MtmdBitmap],
     ) -> Result {
-        let chunks = MtmdInputChunks::new();
+        let chunks = MtmdInputChunks::new()?;
         let text_cstring = CString::new(text.text)?;
         let input_text = llama_cpp_bindings_sys::mtmd_input_text {
             text: text_cstring.as_ptr(),
diff --git a/llama-cpp-bindings/src/mtmd/mtmd_error.rs b/llama-cpp-bindings/src/mtmd/mtmd_error.rs
index 9515fd06..09048ab8 100644
--- a/llama-cpp-bindings/src/mtmd/mtmd_error.rs
+++ b/llama-cpp-bindings/src/mtmd/mtmd_error.rs
@@ -51,6 +51,9 @@ pub enum MtmdTokenizeError {
     /// Image preprocessing error occurred
     #[error("Image preprocessing error")]
     ImagePreprocessingError,
+    /// Failed to create input chunks collection
+    #[error("{0}")]
+    InputChunksError(#[from] MtmdInputChunksError),
     /// Text contains characters that cannot be converted to C string
     #[error("Failed to create CString from text: {0}")]
     CStringError(#[from] std::ffi::NulError),
diff --git a/llama-cpp-bindings/src/mtmd/mtmd_input_chunks.rs b/llama-cpp-bindings/src/mtmd/mtmd_input_chunks.rs
index 0b8aacd1..cc564a39 100644
--- a/llama-cpp-bindings/src/mtmd/mtmd_input_chunks.rs
+++ b/llama-cpp-bindings/src/mtmd/mtmd_input_chunks.rs
@@ -4,6 +4,7 @@ use crate::context::LlamaContext;
 
 use super::mtmd_context::MtmdContext;
 use super::mtmd_error::MtmdEvalError;
+use super::mtmd_error::MtmdInputChunksError;
 use super::mtmd_input_chunk::MtmdInputChunk;
 
 const fn check_eval_result(result: i32) -> Result<(), MtmdEvalError> {
@@ -25,32 +26,28 @@ pub struct MtmdInputChunks {
     pub chunks: NonNull,
 }
 
-impl Default for MtmdInputChunks {
-    fn default() -> Self {
-        Self::new()
-    }
-}
-
 impl MtmdInputChunks {
-    /// Create a new empty input chunks collection
-    /// # Panics
-    /// This function will panic if the underlying llama.cpp function returns null,
-    /// which should not happen.
+    /// Create a new empty input chunks collection.
+    ///
+    /// # Errors
+    ///
+    /// Returns `MtmdInputChunksError::NullResult` if the underlying llama.cpp function
+    /// returns null.
     ///
     /// # Examples
     ///
     /// ```
     /// use llama_cpp_bindings::mtmd::MtmdInputChunks;
     ///
-    /// let chunks = MtmdInputChunks::new();
+    /// let chunks = MtmdInputChunks::new().unwrap();
     /// assert_eq!(chunks.len(), 0);
     /// assert!(chunks.is_empty());
     /// ```
-    #[must_use]
-    pub fn new() -> Self {
+    pub fn new() -> Result<Self, MtmdInputChunksError> {
         let chunks = unsafe { llama_cpp_bindings_sys::mtmd_input_chunks_init() };
-        let chunks = NonNull::new(chunks).expect("llama.cpp mtmd_input_chunks_init returned null");
-        Self { chunks }
+        let chunks = NonNull::new(chunks).ok_or(MtmdInputChunksError::NullResult)?;
+
+        Ok(Self { chunks })
     }
 
     /// Get the number of chunks
@@ -148,8 +145,8 @@ mod tests {
     use super::MtmdInputChunks;
 
     #[test]
-    fn default_creates_empty_chunks() {
-        let chunks = MtmdInputChunks::default();
+    fn new_creates_empty_chunks() {
+        let chunks = MtmdInputChunks::new().unwrap();
 
         assert!(chunks.is_empty());
         assert_eq!(chunks.len(), 0);
     }
 
     #[test]
     fn get_out_of_bounds_returns_none() {
-        let chunks = MtmdInputChunks::new();
+        let chunks = MtmdInputChunks::new().unwrap();
 
         assert!(chunks.get(0).is_none());
         assert!(chunks.get(999).is_none());
diff --git a/llama-cpp-bindings/src/sampling.rs b/llama-cpp-bindings/src/sampling.rs
index 4d3a2fd3..1e677faf 100644
--- a/llama-cpp-bindings/src/sampling.rs
+++ b/llama-cpp-bindings/src/sampling.rs
@@ -10,18 +10,20 @@ use crate::model::LlamaModel;
 use crate::token::LlamaToken;
 use crate::token::data_array::LlamaTokenDataArray;
 use crate::token::logit_bias::LlamaLogitBias;
-use crate::{GrammarError, SampleError, SamplerAcceptError, SamplingError, status_is_ok};
+use crate::{GrammarError, SampleError, SamplerAcceptError, SamplingError};
 
 fn check_sampler_accept_status(
     status: llama_cpp_bindings_sys::llama_rs_status,
     error_ptr: *mut c_char,
 ) -> Result<(), SamplerAcceptError> {
-    if status_is_ok(status) {
-        Ok(())
-    } else {
-        Err(SamplerAcceptError::CppException(unsafe {
+    match status {
+        llama_cpp_bindings_sys::LLAMA_RS_STATUS_OK => Ok(()),
+        llama_cpp_bindings_sys::LLAMA_RS_STATUS_INVALID_ARGUMENT => {
+            Err(SamplerAcceptError::InvalidArgument)
+        }
+        _ => Err(SamplerAcceptError::CppException(unsafe {
             read_and_free_cpp_error(error_ptr)
-        }))
+        })),
     }
 }
 
@@ -539,7 +541,12 @@ impl LlamaSampler {
         let mut seq_breaker_pointers: Vec<*const c_char> =
             seq_breakers.iter().map(|s| s.as_ptr()).collect();
 
-        let n_ctx_train = checked_u32_as_i32(model.n_ctx_train())?;
+        let n_ctx_train_value = model.n_ctx_train().map_err(|convert_error| {
+            GrammarError::IntegerOverflow(format!(
+                "n_ctx_train does not fit into u32: {convert_error}"
+            ))
+        })?;
+        let n_ctx_train = checked_u32_as_i32(n_ctx_train_value)?;
         let sampler = unsafe {
             llama_cpp_bindings_sys::llama_sampler_init_dry(
                 model.vocab_ptr(),
@@ -996,13 +1003,39 @@ mod tests {
     }
 
     #[test]
-    fn check_sampler_accept_status_error() {
+    fn check_sampler_accept_status_ok() {
+        let result = super::check_sampler_accept_status(
+            llama_cpp_bindings_sys::LLAMA_RS_STATUS_OK,
+            std::ptr::null_mut(),
+        );
+
+        assert!(result.is_ok());
+    }
+
+    #[test]
+    fn check_sampler_accept_status_invalid_argument() {
+        let result = super::check_sampler_accept_status(
+            llama_cpp_bindings_sys::LLAMA_RS_STATUS_INVALID_ARGUMENT,
+            std::ptr::null_mut(),
+        );
+
+        assert!(matches!(
+            result,
+            Err(crate::SamplerAcceptError::InvalidArgument)
+        ));
+    }
+
+    #[test]
+    fn check_sampler_accept_status_exception() {
         let result = super::check_sampler_accept_status(
             llama_cpp_bindings_sys::LLAMA_RS_STATUS_EXCEPTION,
             std::ptr::null_mut(),
         );
-        assert!(result.is_err());
+        assert!(matches!(
+            result,
+            Err(crate::SamplerAcceptError::CppException(_))
+        ));
     }
 
     #[test]
diff --git a/llama-cpp-bindings/tests/constrained_decoding.rs b/llama-cpp-bindings/tests/constrained_decoding.rs
index 6e2a4d52..faf01321 100644
--- a/llama-cpp-bindings/tests/constrained_decoding.rs
+++ b/llama-cpp-bindings/tests/constrained_decoding.rs
@@ -63,8 +63,6 @@ fn json_schema_constrains_output() -> Result<()> {
         print!("{output_string}");
         std::io::stdout().flush()?;
 
-        sampler.accept(token)?;
-
         batch.clear();
         batch.add(token, n_cur, &[0], true)?;
         n_cur += 1;
diff --git a/llama-cpp-bindings/tests/openai_server.rs b/llama-cpp-bindings/tests/openai_server.rs
index aa239e24..0c698e7d 100644
--- a/llama-cpp-bindings/tests/openai_server.rs
+++ b/llama-cpp-bindings/tests/openai_server.rs
@@ -44,7 +44,7 @@ fn run_chat_completion(
     let tokens = model.str_to_token(&result.prompt, AddBos::Always)?;
     let tokens_len_u32 = u32::try_from(tokens.len())?;
 
-    let n_ctx = model.n_ctx_train().max(tokens_len_u32 + max_tokens);
+    let n_ctx = model.n_ctx_train()?.max(tokens_len_u32 + max_tokens);
     let ctx_params = LlamaContextParams::default()
         .with_n_ctx(NonZeroU32::new(n_ctx))
         .with_n_batch(n_ctx);
diff --git a/llama-cpp-bindings/tests/openai_streaming.rs b/llama-cpp-bindings/tests/openai_streaming.rs
index a36e85e9..2f874ce1 100644
--- a/llama-cpp-bindings/tests/openai_streaming.rs
+++ b/llama-cpp-bindings/tests/openai_streaming.rs
@@ -63,7 +63,7 @@ fn streaming_deltas_produce_valid_chunks() -> Result<()> {
     let tokens = model.str_to_token(&result.prompt, AddBos::Always)?;
     let n_predict: i32 = 128;
     let n_ctx = model
-        .n_ctx_train()
+        .n_ctx_train()?
         .max(tokens.len() as u32 + n_predict as u32);
     let ctx_params = LlamaContextParams::default()
         .with_n_ctx(NonZeroU32::new(n_ctx))
diff --git a/llama-cpp-bindings/tests/text_generation.rs b/llama-cpp-bindings/tests/text_generation.rs
index 218e015d..7b866209 100644
--- a/llama-cpp-bindings/tests/text_generation.rs
+++ b/llama-cpp-bindings/tests/text_generation.rs
@@ -66,8 +66,6 @@ fn raw_prompt_completion_with_timing() -> Result<()> {
     while n_cur <= n_len {
         let token = sampler.sample(&ctx, batch.n_tokens() - 1)?;
 
-        sampler.accept(token)?;
-
         if model.is_eog_token(token) {
             break;
         }
@@ -140,8 +138,6 @@ fn chat_inference_produces_coherent_output() -> Result<()> {
     while position <= max_tokens {
         let token = sampler.sample(&context, batch.n_tokens() - 1)?;
 
-        sampler.accept(token)?;
-
         if model.is_eog_token(token) {
             break;
         }
diff --git a/llama-cpp-bindings/tests/tool_calling.rs b/llama-cpp-bindings/tests/tool_calling.rs
index 3bee8792..5012c833 100644
--- a/llama-cpp-bindings/tests/tool_calling.rs
+++ b/llama-cpp-bindings/tests/tool_calling.rs
@@ -67,7 +67,7 @@ fn tool_calling_generates_grammar_and_prompt() -> Result<()> {
     let tokens = model.str_to_token(&result.prompt, AddBos::Always)?;
     let n_predict: i32 = 128;
     let n_ctx = model
-        .n_ctx_train()
+        .n_ctx_train()?
         .max(tokens.len() as u32 + n_predict as u32);
     let ctx_params = LlamaContextParams::default()
         .with_n_ctx(NonZeroU32::new(n_ctx))

From 452860a88d10901ec4627c580576c92d3c916064 Mon Sep 17 00:00:00 2001
From: Mateusz Charytoniuk
Date: Thu, 2 Apr 2026 23:18:09 +0200
Subject: [PATCH 6/6] rename ci.yml to unit-tests.yml, add rust-cache to formatting job

---
 .github/workflows/{ci.yml => unit-tests.yml} | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)
 rename .github/workflows/{ci.yml => unit-tests.yml} (92%)

diff --git a/.github/workflows/ci.yml b/.github/workflows/unit-tests.yml
similarity index 92%
rename from .github/workflows/ci.yml
rename to .github/workflows/unit-tests.yml
index 6f9e4b97..a1e2a152 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/unit-tests.yml
@@ -1,4 +1,4 @@
-name: ci
+name: unit-tests
 
 on:
   push:
@@ -19,6 +19,8 @@ jobs:
 
       - uses: dtolnay/rust-toolchain@stable
 
+      - uses: Swatinem/rust-cache@v2
+
       - run: make fmt
 
   test: