diff --git a/.github/workflows/ci.yml b/.github/workflows/unit-tests.yml similarity index 73% rename from .github/workflows/ci.yml rename to .github/workflows/unit-tests.yml index 8e1ebf03..a1e2a152 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/unit-tests.yml @@ -1,7 +1,8 @@ -name: CI +name: unit-tests on: push: + branches: [main] pull_request: env: @@ -9,7 +10,7 @@ env: jobs: fmt: - name: Formatting + name: formatting runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 @@ -17,26 +18,23 @@ jobs: submodules: recursive - uses: dtolnay/rust-toolchain@stable - with: - toolchain: "1.92.0" - components: rustfmt + + - uses: Swatinem/rust-cache@v2 - run: make fmt test: - name: Tests + name: tests runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 with: submodules: recursive - - name: Install system dependencies + - name: install system dependencies run: sudo apt-get update && sudo apt-get install -y cmake libclang-dev - uses: dtolnay/rust-toolchain@stable - with: - toolchain: "1.92.0" - uses: Swatinem/rust-cache@v2 diff --git a/llama-cpp-bindings-sys/wrapper_common.cpp b/llama-cpp-bindings-sys/wrapper_common.cpp index 373082a8..a71cae64 100644 --- a/llama-cpp-bindings-sys/wrapper_common.cpp +++ b/llama-cpp-bindings-sys/wrapper_common.cpp @@ -1,5 +1,6 @@ #include "wrapper_common.h" +#include #include #include #include @@ -15,18 +16,30 @@ extern "C" llama_rs_status llama_rs_json_schema_to_grammar( const char * schema_json, bool force_gbnf, - char ** out_grammar) { - if (!schema_json || !out_grammar) { + char ** out_grammar, + char ** out_error) { + if (!schema_json || !out_grammar || !out_error) { return LLAMA_RS_STATUS_INVALID_ARGUMENT; } *out_grammar = nullptr; + *out_error = nullptr; + try { const auto schema = nlohmann::ordered_json::parse(schema_json); const auto grammar = json_schema_to_grammar(schema, force_gbnf); *out_grammar = llama_rs_dup_string(grammar); + return *out_grammar ? LLAMA_RS_STATUS_OK : LLAMA_RS_STATUS_ALLOCATION_FAILED; - } catch (const std::exception &) { + } catch (const std::exception & err) { + fprintf(stderr, "%s: C++ exception: %s\n", __func__, err.what()); + *out_error = llama_rs_dup_string(err.what()); + + return LLAMA_RS_STATUS_EXCEPTION; + } catch (...) { + fprintf(stderr, "%s: unknown C++ exception\n", __func__); + *out_error = llama_rs_dup_string("unknown C++ exception"); + return LLAMA_RS_STATUS_EXCEPTION; } } @@ -85,10 +98,25 @@ extern "C" void llama_rs_string_free(char * ptr) { extern "C" struct llama_sampler * llama_rs_sampler_init_grammar( const struct llama_vocab * vocab, const char * grammar_str, - const char * grammar_root) { + const char * grammar_root, + char ** out_error) { + if (!out_error) { + return nullptr; + } + + *out_error = nullptr; + try { return llama_sampler_init_grammar(vocab, grammar_str, grammar_root); + } catch (const std::exception & err) { + fprintf(stderr, "%s: C++ exception: %s\n", __func__, err.what()); + *out_error = llama_rs_dup_string(err.what()); + + return nullptr; } catch (...) { + fprintf(stderr, "%s: unknown C++ exception\n", __func__); + *out_error = llama_rs_dup_string("unknown C++ exception"); + return nullptr; } } @@ -100,7 +128,14 @@ extern "C" struct llama_sampler * llama_rs_sampler_init_grammar_lazy( const char ** trigger_words, size_t num_trigger_words, const llama_token * trigger_tokens, - size_t num_trigger_tokens) { + size_t num_trigger_tokens, + char ** out_error) { + if (!out_error) { + return nullptr; + } + + *out_error = nullptr; + try { std::vector trigger_patterns; trigger_patterns.reserve(num_trigger_words); @@ -115,6 +150,7 @@ extern "C" struct llama_sampler * llama_rs_sampler_init_grammar_lazy( for (const auto & pattern : trigger_patterns) { trigger_patterns_c.push_back(pattern.c_str()); } + return llama_sampler_init_grammar_lazy_patterns( vocab, grammar_str, @@ -123,7 +159,15 @@ extern "C" struct llama_sampler * llama_rs_sampler_init_grammar_lazy( trigger_patterns_c.size(), trigger_tokens, num_trigger_tokens); + } catch (const std::exception & err) { + fprintf(stderr, "%s: C++ exception: %s\n", __func__, err.what()); + *out_error = llama_rs_dup_string(err.what()); + + return nullptr; } catch (...) { + fprintf(stderr, "%s: unknown C++ exception\n", __func__); + *out_error = llama_rs_dup_string("unknown C++ exception"); + return nullptr; } } @@ -135,7 +179,14 @@ extern "C" struct llama_sampler * llama_rs_sampler_init_grammar_lazy_patterns( const char ** trigger_patterns, size_t num_trigger_patterns, const llama_token * trigger_tokens, - size_t num_trigger_tokens) { + size_t num_trigger_tokens, + char ** out_error) { + if (!out_error) { + return nullptr; + } + + *out_error = nullptr; + try { return llama_sampler_init_grammar_lazy_patterns( vocab, @@ -145,7 +196,15 @@ extern "C" struct llama_sampler * llama_rs_sampler_init_grammar_lazy_patterns( num_trigger_patterns, trigger_tokens, num_trigger_tokens); + } catch (const std::exception & err) { + fprintf(stderr, "%s: C++ exception: %s\n", __func__, err.what()); + *out_error = llama_rs_dup_string(err.what()); + + return nullptr; } catch (...) { + fprintf(stderr, "%s: unknown C++ exception\n", __func__); + *out_error = llama_rs_dup_string("unknown C++ exception"); + return nullptr; } } @@ -164,6 +223,7 @@ extern "C" llama_pos llama_rs_memory_seq_pos_max( if (seq_id < 0 || (uint32_t) seq_id >= n_seq_max) { return -1; } + return llama_memory_seq_pos_max(mem, seq_id); } @@ -231,16 +291,58 @@ extern "C" llama_rs_status llama_rs_memory_seq_div( return LLAMA_RS_STATUS_OK; } -extern "C" llama_rs_status llama_rs_sampler_accept(struct llama_sampler * sampler, llama_token token) { - if (!sampler) { +extern "C" llama_rs_status llama_rs_sampler_sample( + struct llama_sampler * sampler, + struct llama_context * ctx, + int32_t idx, + llama_token * out_token, + char ** out_error) { + if (!sampler || !ctx || !out_token || !out_error) { return LLAMA_RS_STATUS_INVALID_ARGUMENT; } + + *out_error = nullptr; + + try { + *out_token = llama_sampler_sample(sampler, ctx, idx); + + return LLAMA_RS_STATUS_OK; + } catch (const std::exception & err) { + fprintf(stderr, "%s: C++ exception: %s\n", __func__, err.what()); + *out_error = llama_rs_dup_string(err.what()); + + return LLAMA_RS_STATUS_EXCEPTION; + } catch (...) { + fprintf(stderr, "%s: unknown C++ exception\n", __func__); + *out_error = llama_rs_dup_string("unknown C++ exception"); + + return LLAMA_RS_STATUS_EXCEPTION; + } +} + +extern "C" llama_rs_status llama_rs_sampler_accept( + struct llama_sampler * sampler, + llama_token token, + char ** out_error) { + if (!sampler || !out_error) { + return LLAMA_RS_STATUS_INVALID_ARGUMENT; + } + + *out_error = nullptr; + try { llama_sampler_accept(sampler, token); + return LLAMA_RS_STATUS_OK; - } catch (const std::exception &) { + } catch (const std::exception & err) { + fprintf(stderr, "%s: C++ exception: %s\n", __func__, err.what()); + *out_error = llama_rs_dup_string(err.what()); + return LLAMA_RS_STATUS_EXCEPTION; } catch (...) { + fprintf(stderr, "%s: unknown C++ exception\n", __func__); + *out_error = llama_rs_dup_string("unknown C++ exception"); + return LLAMA_RS_STATUS_EXCEPTION; } } diff --git a/llama-cpp-bindings-sys/wrapper_common.h b/llama-cpp-bindings-sys/wrapper_common.h index d1ab9c27..14a8998e 100644 --- a/llama-cpp-bindings-sys/wrapper_common.h +++ b/llama-cpp-bindings-sys/wrapper_common.h @@ -39,12 +39,14 @@ extern "C" { llama_rs_status llama_rs_json_schema_to_grammar( const char * schema_json, bool force_gbnf, - char ** out_grammar); + char ** out_grammar, + char ** out_error); struct llama_sampler * llama_rs_sampler_init_grammar( const struct llama_vocab * vocab, const char * grammar_str, - const char * grammar_root); + const char * grammar_root, + char ** out_error); struct llama_sampler * llama_rs_sampler_init_grammar_lazy( const struct llama_vocab * vocab, @@ -53,7 +55,8 @@ struct llama_sampler * llama_rs_sampler_init_grammar_lazy( const char ** trigger_words, size_t num_trigger_words, const llama_token * trigger_tokens, - size_t num_trigger_tokens); + size_t num_trigger_tokens, + char ** out_error); struct llama_sampler * llama_rs_sampler_init_grammar_lazy_patterns( const struct llama_vocab * vocab, @@ -62,9 +65,20 @@ struct llama_sampler * llama_rs_sampler_init_grammar_lazy_patterns( const char ** trigger_patterns, size_t num_trigger_patterns, const llama_token * trigger_tokens, - size_t num_trigger_tokens); + size_t num_trigger_tokens, + char ** out_error); + +llama_rs_status llama_rs_sampler_accept( + struct llama_sampler * sampler, + llama_token token, + char ** out_error); -llama_rs_status llama_rs_sampler_accept(struct llama_sampler * sampler, llama_token token); +llama_rs_status llama_rs_sampler_sample( + struct llama_sampler * sampler, + struct llama_context * ctx, + int32_t idx, + llama_token * out_token, + char ** out_error); void llama_rs_chat_template_result_free(struct llama_rs_chat_template_result * result); void llama_rs_string_free(char * ptr); diff --git a/llama-cpp-bindings/src/error.rs b/llama-cpp-bindings/src/error.rs index 2f56ba88..9fc73b02 100644 --- a/llama-cpp-bindings/src/error.rs +++ b/llama-cpp-bindings/src/error.rs @@ -175,8 +175,8 @@ pub enum GrammarError { #[error("String contains null bytes: {0}")] NulError(#[from] NulError), /// The grammar call returned null - #[error("Grammar call returned null")] - NullGrammar, + #[error("Grammar initialization failed: {0}")] + NullGrammar(String), /// An integer value exceeded the allowed range #[error("Integer overflow: {0}")] IntegerOverflow(String), @@ -193,6 +193,18 @@ pub enum SamplingError { IntegerOverflow(String), } +/// Errors that can occur when sampling a token. +#[derive(Debug, Eq, PartialEq, thiserror::Error)] +pub enum SampleError { + /// A C++ exception was thrown during sampling + #[error("C++ exception during sampling: {0}")] + CppException(String), + + /// An invalid argument was passed to the sampler + #[error("Invalid argument passed to sampler")] + InvalidArgument, +} + /// Decode a error from llama.cpp into a [`DecodeError`]. impl From for DecodeError { fn from(value: NonZeroI32) -> Self { @@ -345,9 +357,13 @@ pub enum ChatParseError { /// Failed to accept a token in a sampler. #[derive(Debug, thiserror::Error)] pub enum SamplerAcceptError { - /// llama.cpp returned an error code. - #[error("ffi error {0}")] - FfiError(i32), + /// A C++ exception was thrown during accept + #[error("C++ exception during sampler accept: {0}")] + CppException(String), + + /// An invalid argument was passed (null sampler or null error pointer) + #[error("Invalid argument passed to sampler accept")] + InvalidArgument, } /// Errors that can occur when modifying model parameters. diff --git a/llama-cpp-bindings/src/ffi_error_reader.rs b/llama-cpp-bindings/src/ffi_error_reader.rs new file mode 100644 index 00000000..17a445ab --- /dev/null +++ b/llama-cpp-bindings/src/ffi_error_reader.rs @@ -0,0 +1,33 @@ +use std::ffi::{CStr, c_char}; + +/// Reads a C error string, converts to Rust `String`, and frees the C memory. +/// +/// # Safety +/// +/// `error_ptr` must be either null or a valid pointer to a null-terminated +/// C string allocated by `llama_rs_dup_string`. +pub unsafe fn read_and_free_cpp_error(error_ptr: *mut c_char) -> String { + if error_ptr.is_null() { + return "unknown error".to_owned(); + } + + let message = unsafe { CStr::from_ptr(error_ptr) } + .to_string_lossy() + .into_owned(); + + unsafe { llama_cpp_bindings_sys::llama_rs_string_free(error_ptr) }; + + message +} + +#[cfg(test)] +mod tests { + use super::read_and_free_cpp_error; + + #[test] + fn returns_unknown_for_null_pointer() { + let result = unsafe { read_and_free_cpp_error(std::ptr::null_mut()) }; + + assert_eq!(result, "unknown error"); + } +} diff --git a/llama-cpp-bindings/src/llama_utility_status_is_ok.rs b/llama-cpp-bindings/src/ffi_status_is_ok.rs similarity index 100% rename from llama-cpp-bindings/src/llama_utility_status_is_ok.rs rename to llama-cpp-bindings/src/ffi_status_is_ok.rs diff --git a/llama-cpp-bindings/src/llama_utility_status_to_i32.rs b/llama-cpp-bindings/src/ffi_status_to_i32.rs similarity index 100% rename from llama-cpp-bindings/src/llama_utility_status_to_i32.rs rename to llama-cpp-bindings/src/ffi_status_to_i32.rs diff --git a/llama-cpp-bindings/src/llama_utility_ggml_time_us.rs b/llama-cpp-bindings/src/ggml_time_us.rs similarity index 100% rename from llama-cpp-bindings/src/llama_utility_ggml_time_us.rs rename to llama-cpp-bindings/src/ggml_time_us.rs diff --git a/llama-cpp-bindings/src/llama_utility_json_schema_to_grammar.rs b/llama-cpp-bindings/src/json_schema_to_grammar.rs similarity index 62% rename from llama-cpp-bindings/src/llama_utility_json_schema_to_grammar.rs rename to llama-cpp-bindings/src/json_schema_to_grammar.rs index af7ba6c8..34590a82 100644 --- a/llama-cpp-bindings/src/llama_utility_json_schema_to_grammar.rs +++ b/llama-cpp-bindings/src/json_schema_to_grammar.rs @@ -1,8 +1,7 @@ -use std::ffi::{CStr, CString}; +use std::ffi::{CStr, CString, c_char}; use crate::error::{LlamaCppError, Result}; -use crate::llama_utility_status_is_ok::status_is_ok; -use crate::llama_utility_status_to_i32::status_to_i32; +use crate::ffi_status_is_ok::status_is_ok; /// Convert a JSON schema string into a llama.cpp grammar string. /// @@ -11,32 +10,40 @@ use crate::llama_utility_status_to_i32::status_to_i32; pub fn json_schema_to_grammar(schema_json: &str) -> Result { let schema_cstr = CString::new(schema_json) .map_err(|err| LlamaCppError::JsonSchemaToGrammarError(err.to_string()))?; - let mut out = std::ptr::null_mut(); - let rc = unsafe { + let mut out: *mut c_char = std::ptr::null_mut(); + let mut error_ptr: *mut c_char = std::ptr::null_mut(); + + let status = unsafe { llama_cpp_bindings_sys::llama_rs_json_schema_to_grammar( schema_cstr.as_ptr(), false, &raw mut out, + &raw mut error_ptr, ) }; - let result = { - if !status_is_ok(rc) || out.is_null() { - return Err(LlamaCppError::JsonSchemaToGrammarError(format!( - "ffi error {}", - status_to_i32(rc) - ))); - } - let grammar_bytes = unsafe { CStr::from_ptr(out) }.to_bytes().to_vec(); - let grammar = String::from_utf8(grammar_bytes) - .map_err(|err| LlamaCppError::JsonSchemaToGrammarError(err.to_string()))?; - - Ok(grammar) - }; + if !status_is_ok(status) || out.is_null() { + let message = if error_ptr.is_null() { + "unknown error".to_owned() + } else { + let message = unsafe { CStr::from_ptr(error_ptr) } + .to_string_lossy() + .into_owned(); + + unsafe { llama_cpp_bindings_sys::llama_rs_string_free(error_ptr) }; + + message + }; + + return Err(LlamaCppError::JsonSchemaToGrammarError(message)); + } + + let grammar_bytes = unsafe { CStr::from_ptr(out) }.to_bytes().to_vec(); unsafe { llama_cpp_bindings_sys::llama_rs_string_free(out) }; - result + String::from_utf8(grammar_bytes) + .map_err(|err| LlamaCppError::JsonSchemaToGrammarError(err.to_string())) } #[cfg(test)] diff --git a/llama-cpp-bindings/src/lib.rs b/llama-cpp-bindings/src/lib.rs index 8aaad611..bd60075f 100644 --- a/llama-cpp-bindings/src/lib.rs +++ b/llama-cpp-bindings/src/lib.rs @@ -12,25 +12,26 @@ pub mod context; pub mod error; +pub mod ffi_error_reader; +pub mod ffi_status_is_ok; +pub mod ffi_status_to_i32; +pub mod ggml_time_us; pub mod gguf_context; pub mod gguf_context_error; pub mod gguf_type; +pub mod json_schema_to_grammar; pub mod llama_backend; pub mod llama_backend_device; pub mod llama_backend_numa_strategy; pub mod llama_batch; -pub mod llama_utility_ggml_time_us; -pub mod llama_utility_json_schema_to_grammar; -pub mod llama_utility_llama_time_us; -pub mod llama_utility_max_devices; -pub mod llama_utility_mlock_supported; -pub mod llama_utility_mmap_supported; -pub mod llama_utility_status_is_ok; -pub mod llama_utility_status_to_i32; +pub mod llama_time_us; #[cfg(feature = "llguidance")] pub mod llguidance_sampler; pub mod log; pub mod log_options; +pub mod max_devices; +pub mod mlock_supported; +pub mod mmap_supported; pub mod model; #[cfg(feature = "mtmd")] pub mod mtmd; @@ -44,22 +45,22 @@ pub use error::{ ApplyChatTemplateError, ChatParseError, ChatTemplateError, DecodeError, EmbeddingsError, EncodeError, GrammarError, LlamaContextLoadError, LlamaCppError, LlamaLoraAdapterInitError, LlamaLoraAdapterRemoveError, LlamaLoraAdapterSetError, LlamaModelLoadError, LogitsError, - MetaValError, ModelParamsError, NewLlamaChatMessageError, Result, SamplerAcceptError, - SamplingError, StringToTokenError, TokenSamplingError, TokenToStringError, + MetaValError, ModelParamsError, NewLlamaChatMessageError, Result, SampleError, + SamplerAcceptError, SamplingError, StringToTokenError, TokenSamplingError, TokenToStringError, }; pub use llama_backend_device::{ LlamaBackendDevice, LlamaBackendDeviceType, list_llama_ggml_backend_devices, }; -pub use llama_utility_ggml_time_us::ggml_time_us; -pub use llama_utility_json_schema_to_grammar::json_schema_to_grammar; -pub use llama_utility_llama_time_us::llama_time_us; -pub use llama_utility_max_devices::max_devices; -pub use llama_utility_mlock_supported::mlock_supported; -pub use llama_utility_mmap_supported::mmap_supported; -pub use llama_utility_status_is_ok::status_is_ok; -pub use llama_utility_status_to_i32::status_to_i32; +pub use ffi_status_is_ok::status_is_ok; +pub use ffi_status_to_i32::status_to_i32; +pub use ggml_time_us::ggml_time_us; +pub use json_schema_to_grammar::json_schema_to_grammar; +pub use llama_time_us::llama_time_us; +pub use max_devices::max_devices; +pub use mlock_supported::mlock_supported; +pub use mmap_supported::mmap_supported; pub use log::send_logs_to_tracing; pub use log_options::LogOptions; diff --git a/llama-cpp-bindings/src/llama_utility_llama_time_us.rs b/llama-cpp-bindings/src/llama_time_us.rs similarity index 100% rename from llama-cpp-bindings/src/llama_utility_llama_time_us.rs rename to llama-cpp-bindings/src/llama_time_us.rs diff --git a/llama-cpp-bindings/src/llguidance_sampler.rs b/llama-cpp-bindings/src/llguidance_sampler.rs index 0db8dd16..4a1738ab 100644 --- a/llama-cpp-bindings/src/llguidance_sampler.rs +++ b/llama-cpp-bindings/src/llguidance_sampler.rs @@ -207,7 +207,9 @@ pub fn create_llg_sampler( }; if sampler.is_null() { - Err(GrammarError::NullGrammar) + Err(GrammarError::NullGrammar( + "llguidance sampler returned null".to_owned(), + )) } else { Ok(LlamaSampler { sampler }) } diff --git a/llama-cpp-bindings/src/llama_utility_max_devices.rs b/llama-cpp-bindings/src/max_devices.rs similarity index 100% rename from llama-cpp-bindings/src/llama_utility_max_devices.rs rename to llama-cpp-bindings/src/max_devices.rs diff --git a/llama-cpp-bindings/src/llama_utility_mlock_supported.rs b/llama-cpp-bindings/src/mlock_supported.rs similarity index 100% rename from llama-cpp-bindings/src/llama_utility_mlock_supported.rs rename to llama-cpp-bindings/src/mlock_supported.rs diff --git a/llama-cpp-bindings/src/llama_utility_mmap_supported.rs b/llama-cpp-bindings/src/mmap_supported.rs similarity index 100% rename from llama-cpp-bindings/src/llama_utility_mmap_supported.rs rename to llama-cpp-bindings/src/mmap_supported.rs diff --git a/llama-cpp-bindings/src/model.rs b/llama-cpp-bindings/src/model.rs index a3f81d2d..bf22e247 100644 --- a/llama-cpp-bindings/src/model.rs +++ b/llama-cpp-bindings/src/model.rs @@ -77,16 +77,15 @@ impl LlamaModel { unsafe { llama_cpp_bindings_sys::llama_model_get_vocab(self.model.as_ptr()) } } - /// get the number of tokens the model was trained on + /// Get the number of tokens the model was trained on. /// - /// # Panics + /// # Errors /// - /// If the number of tokens the model was trained on does not fit into an `u32`. This should be impossible on most - /// platforms due to llama.cpp returning a `c_int` (i32 on most platforms) which is almost certainly positive. - #[must_use] - pub fn n_ctx_train(&self) -> u32 { + /// Returns an error if the value returned by llama.cpp does not fit into a `u32`. + pub fn n_ctx_train(&self) -> Result { let n_ctx_train = unsafe { llama_cpp_bindings_sys::llama_n_ctx_train(self.model.as_ptr()) }; - u32::try_from(n_ctx_train).expect("n_ctx_train fits into an u32") + + u32::try_from(n_ctx_train) } /// Get all tokens in the model. @@ -151,11 +150,8 @@ impl LlamaModel { /// /// # Errors /// - /// - if [`str`] contains a null byte. - /// - /// # Panics - /// - /// - if there is more than [`usize::MAX`] [`LlamaToken`]s in [`str`]. + /// - if [`str`] contains a null byte + /// - if an integer conversion fails during tokenization /// /// /// ```no_run @@ -200,7 +196,7 @@ impl LlamaModel { }; let size = if size.is_negative() { - buffer.reserve_exact(usize::try_from(-size).expect("negated size fits into usize")); + buffer.reserve_exact(usize::try_from(-size)?); unsafe { llama_cpp_bindings_sys::llama_tokenize( self.vocab_ptr(), @@ -228,14 +224,17 @@ impl LlamaModel { /// Get the type of a token. /// - /// # Panics + /// # Errors /// - /// If the token type is not known to this library. - #[must_use] - pub fn token_attr(&self, LlamaToken(id): LlamaToken) -> LlamaTokenAttrs { + /// Returns an error if the token type is not known to this library. + pub fn token_attr( + &self, + LlamaToken(id): LlamaToken, + ) -> Result { let token_type = unsafe { llama_cpp_bindings_sys::llama_token_get_attr(self.vocab_ptr(), id) }; - LlamaTokenAttrs::try_from(token_type).expect("token type is valid") + + LlamaTokenAttrs::try_from(token_type) } /// Convert a token to a string using the underlying llama.cpp `llama_token_to_piece` function. @@ -252,9 +251,7 @@ impl LlamaModel { /// /// - if the token type is unknown /// - /// # Panics - /// - /// - if the returned size from llama-cpp does not fit into a [`usize`]. (this should never happen) + /// - if the returned size from llama.cpp does not fit into a `usize` pub fn token_to_piece( &self, token: LlamaToken, @@ -263,15 +260,11 @@ impl LlamaModel { lstrip: Option, ) -> Result { let bytes = match self.token_to_piece_bytes(token, 8, special, lstrip) { - Err(TokenToStringError::InsufficientBufferSpace(required_size)) => self - .token_to_piece_bytes( - token, - (-required_size) - .try_into() - .expect("Error buffer size is positive"), - special, - lstrip, - ), + Err(TokenToStringError::InsufficientBufferSpace(required_size)) => { + let buffer_size: usize = (-required_size).try_into()?; + + self.token_to_piece_bytes(token, buffer_size, special, lstrip) + } other => other, }?; @@ -292,6 +285,7 @@ impl LlamaModel { /// /// - if the token type is unknown /// - the resultant token is larger than `buffer_size`. + /// - if an integer conversion fails #[allow(clippy::missing_panics_doc)] pub fn token_to_piece_bytes( &self, @@ -325,7 +319,7 @@ impl LlamaModel { size => { let string = unsafe { CString::from_raw(buf) }; let mut bytes = string.into_bytes(); - let len = usize::try_from(size).expect("size is positive and fits into usize"); + let len = usize::try_from(size)?; bytes.truncate(len); Ok(bytes) @@ -344,13 +338,13 @@ impl LlamaModel { /// The type of vocab the model was trained on. /// - /// # Panics + /// # Errors /// - /// If llama-cpp emits a vocab type that is not known to this library. - #[must_use] - pub fn vocab_type(&self) -> VocabType { + /// Returns an error if llama.cpp emits a vocab type that is not known to this library. + pub fn vocab_type(&self) -> Result { let vocab_type = unsafe { llama_cpp_bindings_sys::llama_vocab_type(self.vocab_ptr()) }; - VocabType::try_from(vocab_type).expect("invalid vocab type") + + VocabType::try_from(vocab_type) } /// This returns a `c_int` for maximum compatibility. Most of the time it can be cast to an i32 @@ -380,35 +374,29 @@ impl LlamaModel { /// Returns the number of layers within the model. /// - /// # Panics - /// Panics if the layer count returned by llama.cpp is negative. - #[must_use] - pub fn n_layer(&self) -> u32 { - // llama.cpp API returns int32_t but the underlying field is uint32_t, so this is safe + /// # Errors + /// + /// Returns an error if the layer count returned by llama.cpp does not fit into a `u32`. + pub fn n_layer(&self) -> Result { u32::try_from(unsafe { llama_cpp_bindings_sys::llama_model_n_layer(self.model.as_ptr()) }) - .expect("llama.cpp returns a positive value for n_layer") } /// Returns the number of attention heads within the model. /// - /// # Panics - /// Panics if the head count returned by llama.cpp is negative. - #[must_use] - pub fn n_head(&self) -> u32 { - // llama.cpp API returns int32_t but the underlying field is uint32_t, so this is safe + /// # Errors + /// + /// Returns an error if the head count returned by llama.cpp does not fit into a `u32`. + pub fn n_head(&self) -> Result { u32::try_from(unsafe { llama_cpp_bindings_sys::llama_model_n_head(self.model.as_ptr()) }) - .expect("llama.cpp returns a positive value for n_head") } /// Returns the number of KV attention heads. /// - /// # Panics - /// Panics if the KV head count returned by llama.cpp is negative. - #[must_use] - pub fn n_head_kv(&self) -> u32 { - // llama.cpp API returns int32_t but the underlying field is uint32_t, so this is safe + /// # Errors + /// + /// Returns an error if the KV head count returned by llama.cpp does not fit into a `u32`. + pub fn n_head_kv(&self) -> Result { u32::try_from(unsafe { llama_cpp_bindings_sys::llama_model_n_head_kv(self.model.as_ptr()) }) - .expect("llama.cpp returns a positive value for n_head_kv") } /// Returns whether the model is a hybrid network (Jamba, Granite, Qwen3xx, etc.) @@ -969,7 +957,7 @@ mod tests { assert!(model.n_vocab() > 0); assert!(model.n_embd() > 0); assert!(model.n_params() > 0); - assert!(model.n_ctx_train() > 0); + assert!(model.n_ctx_train().unwrap() > 0); } #[test] @@ -1086,7 +1074,7 @@ mod tests { fn n_layer_returns_positive() { let (_backend, model) = test_model::load_default_model().unwrap(); - assert!(model.n_layer() > 0); + assert!(model.n_layer().unwrap() > 0); } #[test] @@ -1094,7 +1082,7 @@ mod tests { fn n_head_returns_positive() { let (_backend, model) = test_model::load_default_model().unwrap(); - assert!(model.n_head() > 0); + assert!(model.n_head().unwrap() > 0); } #[test] @@ -1102,7 +1090,7 @@ mod tests { fn n_head_kv_returns_positive() { let (_backend, model) = test_model::load_default_model().unwrap(); - assert!(model.n_head_kv() > 0); + assert!(model.n_head_kv().unwrap() > 0); } #[test] @@ -1342,14 +1330,14 @@ mod tests { fn token_attr_returns_valid_attr() { let (_backend, model) = test_model::load_default_model().unwrap(); let bos = model.token_bos(); - let _attr = model.token_attr(bos); + let _attr = model.token_attr(bos).unwrap(); } #[test] #[serial] fn vocab_type_returns_valid_type() { let (_backend, model) = test_model::load_default_model().unwrap(); - let _vocab_type = model.vocab_type(); + let _vocab_type = model.vocab_type().unwrap(); } #[test] @@ -1838,4 +1826,238 @@ mod tests { assert!(result.is_err()); } + + #[test] + #[serial] + fn sample_returns_result_and_succeeds_with_valid_index() { + use crate::sampling::LlamaSampler; + use crate::token::LlamaToken; + + let (backend, model) = test_model::load_default_model().unwrap(); + let ctx_params = crate::context::params::LlamaContextParams::default() + .with_n_ctx(std::num::NonZeroU32::new(256)); + let mut context = model.new_context(&backend, ctx_params).unwrap(); + + let tokens = model.str_to_token("Hello", AddBos::Always).unwrap(); + let mut batch = crate::llama_batch::LlamaBatch::new(512, 1).unwrap(); + + batch.add_sequence(&tokens, 0, false).unwrap(); + + context.decode(&mut batch).unwrap(); + + let mut sampler = + LlamaSampler::chain_simple([LlamaSampler::temp(0.8), LlamaSampler::greedy()]); + + // sample() now returns Result to catch C++ exceptions at the FFI + // boundary instead of aborting the process. + let result = sampler.sample(&context, batch.n_tokens() - 1); + + assert!(result.is_ok()); + } + + #[test] + #[serial] + fn grammar_sampler_constrains_output_to_yes_or_no() { + use crate::sampling::LlamaSampler; + use std::sync::Arc; + + let (backend, model) = test_model::load_default_model().unwrap(); + + let ctx_params = crate::context::params::LlamaContextParams::default() + .with_n_ctx(std::num::NonZeroU32::new(512)); + let mut context = model.new_context(&backend, ctx_params).unwrap(); + + let prompt = "<|im_start|>user\nIs the sky blue? Answer yes or no.<|im_end|>\n<|im_start|>assistant\n\n\n\n\n"; + let tokens = model.str_to_token(prompt, AddBos::Always).unwrap(); + let mut batch = crate::llama_batch::LlamaBatch::new(512, 1).unwrap(); + + batch.add_sequence(&tokens, 0, false).unwrap(); + + context.decode(&mut batch).unwrap(); + + let mut sampler = LlamaSampler::chain_simple([ + LlamaSampler::grammar(&model, r#"root ::= [Yy] [Ee] [Ss] | [Nn] [Oo]"#, "root") + .unwrap(), + LlamaSampler::temp(0.8), + LlamaSampler::greedy(), + ]); + + let token = sampler.sample(&context, batch.n_tokens() - 1).unwrap(); + + assert!( + !model.is_eog_token(token), + "Grammar sampler should not allow EOS as first token" + ); + + let mut decoder = encoding_rs::UTF_8.new_decoder(); + let piece = model + .token_to_piece(token, &mut decoder, true, None) + .unwrap(); + let first_char = piece.chars().next().unwrap().to_lowercase().next().unwrap(); + + assert!( + first_char == 'y' || first_char == 'n', + "Grammar should constrain first token to start with y/n, got: '{piece}'" + ); + } + + #[test] + #[serial] + fn json_schema_grammar_sampler_constrains_output_to_json() { + use crate::sampling::LlamaSampler; + use std::sync::Arc; + + let (backend, model) = test_model::load_default_model().unwrap(); + + let ctx_params = crate::context::params::LlamaContextParams::default() + .with_n_ctx(std::num::NonZeroU32::new(512)); + let mut context = model.new_context(&backend, ctx_params).unwrap(); + + let prompt = "<|im_start|>user\nWhat is 2+2? Respond with a JSON object.<|im_end|>\n<|im_start|>assistant\n\n\n\n\n"; + let tokens = model.str_to_token(prompt, AddBos::Always).unwrap(); + let mut batch = crate::llama_batch::LlamaBatch::new(512, 1).unwrap(); + + batch.add_sequence(&tokens, 0, false).unwrap(); + + context.decode(&mut batch).unwrap(); + + let grammar_str = crate::json_schema_to_grammar( + r#"{"type": "object", "properties": {"answer": {"type": "string"}}, "required": ["answer"]}"# + ).unwrap(); + + let mut sampler = LlamaSampler::chain_simple([ + LlamaSampler::grammar(&model, &grammar_str, "root").unwrap(), + LlamaSampler::temp(0.8), + LlamaSampler::greedy(), + ]); + + let token = sampler.sample(&context, batch.n_tokens() - 1).unwrap(); + + assert!( + !model.is_eog_token(token), + "Grammar sampler should not allow EOS as first token" + ); + + let mut decoder = encoding_rs::UTF_8.new_decoder(); + let piece = model + .token_to_piece(token, &mut decoder, true, None) + .unwrap(); + + assert!( + piece.starts_with('{'), + "JSON schema grammar should constrain first token to start with '{{', got: '{piece}'" + ); + } + + #[test] + #[serial] + fn sample_with_grammar_produces_constrained_output_in_loop() { + use crate::sampling::LlamaSampler; + use std::sync::Arc; + + let (backend, model) = test_model::load_default_model().unwrap(); + + let ctx_params = crate::context::params::LlamaContextParams::default() + .with_n_ctx(std::num::NonZeroU32::new(512)); + let mut context = model.new_context(&backend, ctx_params).unwrap(); + + let prompt = "<|im_start|>user\nIs the sky blue? yes or no<|im_end|>\n<|im_start|>assistant\n\n\n\n\n"; + let tokens = model.str_to_token(prompt, AddBos::Always).unwrap(); + let mut batch = crate::llama_batch::LlamaBatch::new(512, 1).unwrap(); + + batch.add_sequence(&tokens, 0, false).unwrap(); + + context.decode(&mut batch).unwrap(); + + let mut sampler = LlamaSampler::chain_simple([ + LlamaSampler::grammar(&model, r#"root ::= "yes" | "no""#, "root").unwrap(), + LlamaSampler::temp(0.8), + LlamaSampler::greedy(), + ]); + + let mut generated = String::new(); + let mut decoder = encoding_rs::UTF_8.new_decoder(); + let mut position = batch.n_tokens(); + + for iteration in 0..10 { + let token = sampler.sample(&context, -1).unwrap(); + let is_eog = model.is_eog_token(token); + + eprintln!(" iteration={iteration} token={} eog={is_eog}", token.0); + + if is_eog { + break; + } + + let piece = model + .token_to_piece(token, &mut decoder, true, None) + .unwrap(); + + eprintln!(" piece='{piece}'"); + + generated.push_str(&piece); + + batch.clear(); + batch.add(token, position, &[0], true).unwrap(); + position += 1; + + context.decode(&mut batch).unwrap(); + } + + let lowercase = generated.to_lowercase(); + + assert!( + lowercase == "yes" || lowercase == "no", + "Grammar loop should produce 'yes' or 'no', got: '{generated}'" + ); + } + + #[test] + #[serial] + fn sample_without_grammar_produces_multiple_tokens() { + use crate::sampling::LlamaSampler; + use std::sync::Arc; + + let (backend, model) = test_model::load_default_model().unwrap(); + + let ctx_params = crate::context::params::LlamaContextParams::default() + .with_n_ctx(std::num::NonZeroU32::new(512)); + let mut context = model.new_context(&backend, ctx_params).unwrap(); + + let prompt = + "<|im_start|>user\nSay hello<|im_end|>\n<|im_start|>assistant\n\n\n\n\n"; + let tokens = model.str_to_token(prompt, AddBos::Always).unwrap(); + let mut batch = crate::llama_batch::LlamaBatch::new(512, 1).unwrap(); + + batch.add_sequence(&tokens, 0, false).unwrap(); + + context.decode(&mut batch).unwrap(); + + let mut sampler = + LlamaSampler::chain_simple([LlamaSampler::temp(0.8), LlamaSampler::greedy()]); + + let mut token_count = 0; + let mut position = batch.n_tokens(); + + for _ in 0..5 { + let token = sampler.sample(&context, -1).unwrap(); + + if model.is_eog_token(token) { + break; + } + + token_count += 1; + + batch.clear(); + batch.add(token, position, &[0], true).unwrap(); + position += 1; + + context.decode(&mut batch).unwrap(); + } + + assert!( + token_count > 0, + "Should produce at least one token without grammar" + ); + } } diff --git a/llama-cpp-bindings/src/mtmd/mtmd_context.rs b/llama-cpp-bindings/src/mtmd/mtmd_context.rs index 5a2e3f44..5f34544f 100644 --- a/llama-cpp-bindings/src/mtmd/mtmd_context.rs +++ b/llama-cpp-bindings/src/mtmd/mtmd_context.rs @@ -132,7 +132,7 @@ impl MtmdContext { text: MtmdInputText, bitmaps: &[&MtmdBitmap], ) -> Result { - let chunks = MtmdInputChunks::new(); + let chunks = MtmdInputChunks::new()?; let text_cstring = CString::new(text.text)?; let input_text = llama_cpp_bindings_sys::mtmd_input_text { text: text_cstring.as_ptr(), diff --git a/llama-cpp-bindings/src/mtmd/mtmd_error.rs b/llama-cpp-bindings/src/mtmd/mtmd_error.rs index 9515fd06..09048ab8 100644 --- a/llama-cpp-bindings/src/mtmd/mtmd_error.rs +++ b/llama-cpp-bindings/src/mtmd/mtmd_error.rs @@ -51,6 +51,9 @@ pub enum MtmdTokenizeError { /// Image preprocessing error occurred #[error("Image preprocessing error")] ImagePreprocessingError, + /// Failed to create input chunks collection + #[error("{0}")] + InputChunksError(#[from] MtmdInputChunksError), /// Text contains characters that cannot be converted to C string #[error("Failed to create CString from text: {0}")] CStringError(#[from] std::ffi::NulError), diff --git a/llama-cpp-bindings/src/mtmd/mtmd_input_chunks.rs b/llama-cpp-bindings/src/mtmd/mtmd_input_chunks.rs index 0b8aacd1..cc564a39 100644 --- a/llama-cpp-bindings/src/mtmd/mtmd_input_chunks.rs +++ b/llama-cpp-bindings/src/mtmd/mtmd_input_chunks.rs @@ -4,6 +4,7 @@ use crate::context::LlamaContext; use super::mtmd_context::MtmdContext; use super::mtmd_error::MtmdEvalError; +use super::mtmd_error::MtmdInputChunksError; use super::mtmd_input_chunk::MtmdInputChunk; const fn check_eval_result(result: i32) -> Result<(), MtmdEvalError> { @@ -25,32 +26,28 @@ pub struct MtmdInputChunks { pub chunks: NonNull, } -impl Default for MtmdInputChunks { - fn default() -> Self { - Self::new() - } -} - impl MtmdInputChunks { - /// Create a new empty input chunks collection - /// # Panics - /// This function will panic if the underlying llama.cpp function returns null, - /// which should not happen. + /// Create a new empty input chunks collection. + /// + /// # Errors + /// + /// Returns `MtmdInputChunksError::NullResult` if the underlying llama.cpp function + /// returns null. /// /// # Examples /// /// ``` /// use llama_cpp_bindings::mtmd::MtmdInputChunks; /// - /// let chunks = MtmdInputChunks::new(); + /// let chunks = MtmdInputChunks::new().unwrap(); /// assert_eq!(chunks.len(), 0); /// assert!(chunks.is_empty()); /// ``` - #[must_use] - pub fn new() -> Self { + pub fn new() -> Result { let chunks = unsafe { llama_cpp_bindings_sys::mtmd_input_chunks_init() }; - let chunks = NonNull::new(chunks).expect("llama.cpp mtmd_input_chunks_init returned null"); - Self { chunks } + let chunks = NonNull::new(chunks).ok_or(MtmdInputChunksError::NullResult)?; + + Ok(Self { chunks }) } /// Get the number of chunks @@ -148,8 +145,8 @@ mod tests { use super::MtmdInputChunks; #[test] - fn default_creates_empty_chunks() { - let chunks = MtmdInputChunks::default(); + fn new_creates_empty_chunks() { + let chunks = MtmdInputChunks::new().unwrap(); assert!(chunks.is_empty()); assert_eq!(chunks.len(), 0); @@ -157,7 +154,7 @@ mod tests { #[test] fn get_out_of_bounds_returns_none() { - let chunks = MtmdInputChunks::new(); + let chunks = MtmdInputChunks::new().unwrap(); assert!(chunks.get(0).is_none()); assert!(chunks.get(999).is_none()); diff --git a/llama-cpp-bindings/src/sampling.rs b/llama-cpp-bindings/src/sampling.rs index 5b606ef7..1e677faf 100644 --- a/llama-cpp-bindings/src/sampling.rs +++ b/llama-cpp-bindings/src/sampling.rs @@ -5,27 +5,36 @@ use std::ffi::{CString, c_char}; use std::fmt::{Debug, Formatter}; use crate::context::LlamaContext; +use crate::ffi_error_reader::read_and_free_cpp_error; use crate::model::LlamaModel; use crate::token::LlamaToken; use crate::token::data_array::LlamaTokenDataArray; use crate::token::logit_bias::LlamaLogitBias; -use crate::{GrammarError, SamplerAcceptError, SamplingError, status_is_ok, status_to_i32}; +use crate::{GrammarError, SampleError, SamplerAcceptError, SamplingError}; -const fn check_sampler_accept_status( +fn check_sampler_accept_status( status: llama_cpp_bindings_sys::llama_rs_status, + error_ptr: *mut c_char, ) -> Result<(), SamplerAcceptError> { - if status_is_ok(status) { - Ok(()) - } else { - Err(SamplerAcceptError::FfiError(status_to_i32(status))) + match status { + llama_cpp_bindings_sys::LLAMA_RS_STATUS_OK => Ok(()), + llama_cpp_bindings_sys::LLAMA_RS_STATUS_INVALID_ARGUMENT => { + Err(SamplerAcceptError::InvalidArgument) + } + _ => Err(SamplerAcceptError::CppException(unsafe { + read_and_free_cpp_error(error_ptr) + })), } } -const fn check_sampler_not_null( +fn check_sampler_not_null( sampler: *mut llama_cpp_bindings_sys::llama_sampler, + error_ptr: *mut c_char, ) -> Result { if sampler.is_null() { - Err(GrammarError::NullGrammar) + Err(GrammarError::NullGrammar(unsafe { + read_and_free_cpp_error(error_ptr) + })) } else { Ok(LlamaSampler { sampler }) } @@ -56,14 +65,34 @@ impl Debug for LlamaSampler { } impl LlamaSampler { - /// Sample and accept a token from the idx-th output of the last evaluation - #[must_use] - pub fn sample(&mut self, ctx: &LlamaContext, idx: i32) -> LlamaToken { - let token = unsafe { - llama_cpp_bindings_sys::llama_sampler_sample(self.sampler, ctx.context.as_ptr(), idx) + /// Sample and accept a token from the idx-th output of the last evaluation. + /// + /// # Errors + /// + /// Returns [`SampleError`] if the C++ sampler throws an exception or if the index is invalid. + pub fn sample(&mut self, ctx: &LlamaContext, idx: i32) -> Result { + let mut token: i32 = -1; + let mut error_ptr: *mut c_char = std::ptr::null_mut(); + + let status = unsafe { + llama_cpp_bindings_sys::llama_rs_sampler_sample( + self.sampler, + ctx.context.as_ptr(), + idx, + &raw mut token, + &raw mut error_ptr, + ) }; - LlamaToken(token) + match status { + llama_cpp_bindings_sys::LLAMA_RS_STATUS_OK => Ok(LlamaToken(token)), + llama_cpp_bindings_sys::LLAMA_RS_STATUS_INVALID_ARGUMENT => { + Err(SampleError::InvalidArgument) + } + _ => Err(SampleError::CppException(unsafe { + read_and_free_cpp_error(error_ptr) + })), + } } /// Applies this sampler to a [`LlamaTokenDataArray`]. @@ -115,10 +144,17 @@ impl LlamaSampler { /// # Errors /// Returns an error if the underlying sampler rejects the token. pub fn try_accept(&mut self, token: LlamaToken) -> Result<(), SamplerAcceptError> { - let sampler_result = - unsafe { llama_cpp_bindings_sys::llama_rs_sampler_accept(self.sampler, token.0) }; + let mut error_ptr: *mut c_char = std::ptr::null_mut(); + + let status = unsafe { + llama_cpp_bindings_sys::llama_rs_sampler_accept( + self.sampler, + token.0, + &raw mut error_ptr, + ) + }; - check_sampler_accept_status(sampler_result) + check_sampler_accept_status(status, error_ptr) } /// Resets the internal state of the sampler. @@ -344,16 +380,18 @@ impl LlamaSampler { ) -> Result { let (grammar_str, grammar_root) = Self::sanitize_grammar_strings(grammar_str, grammar_root)?; + let mut error_ptr: *mut c_char = std::ptr::null_mut(); let sampler = unsafe { llama_cpp_bindings_sys::llama_rs_sampler_init_grammar( model.vocab_ptr(), grammar_str.as_ptr(), grammar_root.as_ptr(), + &raw mut error_ptr, ) }; - check_sampler_not_null(sampler) + check_sampler_not_null(sampler, error_ptr) } /// Lazy grammar sampler, introduced in @@ -372,6 +410,7 @@ impl LlamaSampler { let (grammar_str, grammar_root) = Self::sanitize_grammar_strings(grammar_str, grammar_root)?; let trigger_words = Self::sanitize_trigger_words(trigger_words)?; + let mut error_ptr: *mut c_char = std::ptr::null_mut(); let mut trigger_word_ptrs: Vec<*const c_char> = trigger_words.iter().map(|cs| cs.as_ptr()).collect(); @@ -385,10 +424,11 @@ impl LlamaSampler { trigger_word_ptrs.len(), trigger_tokens.as_ptr().cast(), trigger_tokens.len(), + &raw mut error_ptr, ) }; - check_sampler_not_null(sampler) + check_sampler_not_null(sampler, error_ptr) } /// Lazy grammar sampler using regex trigger patterns. @@ -409,6 +449,7 @@ impl LlamaSampler { let (grammar_str, grammar_root) = Self::sanitize_grammar_strings(grammar_str, grammar_root)?; let trigger_patterns = Self::sanitize_trigger_patterns(trigger_patterns)?; + let mut error_ptr: *mut c_char = std::ptr::null_mut(); let mut trigger_pattern_ptrs: Vec<*const c_char> = trigger_patterns.iter().map(|cs| cs.as_ptr()).collect(); @@ -422,10 +463,11 @@ impl LlamaSampler { trigger_pattern_ptrs.len(), trigger_tokens.as_ptr().cast(), trigger_tokens.len(), + &raw mut error_ptr, ) }; - check_sampler_not_null(sampler) + check_sampler_not_null(sampler, error_ptr) } /// `LLGuidance` sampler for constrained decoding. @@ -499,7 +541,12 @@ impl LlamaSampler { let mut seq_breaker_pointers: Vec<*const c_char> = seq_breakers.iter().map(|s| s.as_ptr()).collect(); - let n_ctx_train = checked_u32_as_i32(model.n_ctx_train())?; + let n_ctx_train_value = model.n_ctx_train().map_err(|convert_error| { + GrammarError::IntegerOverflow(format!( + "n_ctx_train does not fit into u32: {convert_error}" + )) + })?; + let n_ctx_train = checked_u32_as_i32(n_ctx_train_value)?; let sampler = unsafe { llama_cpp_bindings_sys::llama_sampler_init_dry( model.vocab_ptr(), @@ -938,9 +985,9 @@ mod tests { context.decode(&mut batch).unwrap(); let mut sampler = LlamaSampler::chain_simple([LlamaSampler::temp(0.8), LlamaSampler::greedy()]); - let token = sampler.sample(&context, batch.n_tokens() - 1); + let result = sampler.sample(&context, batch.n_tokens() - 1); - assert_ne!(token, LlamaToken::new(-1)); + assert!(result.is_ok()); } #[test] @@ -956,15 +1003,45 @@ mod tests { } #[test] - fn check_sampler_accept_status_error() { - let result = - super::check_sampler_accept_status(llama_cpp_bindings_sys::LLAMA_RS_STATUS_EXCEPTION); - assert!(result.is_err()); + fn check_sampler_accept_status_ok() { + let result = super::check_sampler_accept_status( + llama_cpp_bindings_sys::LLAMA_RS_STATUS_OK, + std::ptr::null_mut(), + ); + + assert!(result.is_ok()); + } + + #[test] + fn check_sampler_accept_status_invalid_argument() { + let result = super::check_sampler_accept_status( + llama_cpp_bindings_sys::LLAMA_RS_STATUS_INVALID_ARGUMENT, + std::ptr::null_mut(), + ); + + assert!(matches!( + result, + Err(crate::SamplerAcceptError::InvalidArgument) + )); + } + + #[test] + fn check_sampler_accept_status_exception() { + let result = super::check_sampler_accept_status( + llama_cpp_bindings_sys::LLAMA_RS_STATUS_EXCEPTION, + std::ptr::null_mut(), + ); + + assert!(matches!( + result, + Err(crate::SamplerAcceptError::CppException(_)) + )); } #[test] fn check_sampler_not_null_returns_error() { - let result = super::check_sampler_not_null(std::ptr::null_mut()); + let result = super::check_sampler_not_null(std::ptr::null_mut(), std::ptr::null_mut()); + assert!(result.is_err()); } } diff --git a/llama-cpp-bindings/src/test_model.rs b/llama-cpp-bindings/src/test_model.rs index e0048909..991dee9a 100644 --- a/llama-cpp-bindings/src/test_model.rs +++ b/llama-cpp-bindings/src/test_model.rs @@ -39,6 +39,14 @@ fn hf_encoder_model() -> Result { required_env("LLAMA_TEST_HF_ENCODER_MODEL") } +/// Downloads a file from a specific `HuggingFace` repo. +/// +/// # Errors +/// Returns an error if the download fails. +pub fn download_file_from(repo: &str, filename: &str) -> Result { + download_file(repo, filename) +} + fn download_file(repo: &str, filename: &str) -> Result { let path = hf_hub::api::sync::ApiBuilder::new() .with_progress(true) diff --git a/llama-cpp-bindings/tests/constrained_decoding.rs b/llama-cpp-bindings/tests/constrained_decoding.rs index 799c3ed0..faf01321 100644 --- a/llama-cpp-bindings/tests/constrained_decoding.rs +++ b/llama-cpp-bindings/tests/constrained_decoding.rs @@ -52,7 +52,7 @@ fn json_schema_constrains_output() -> Result<()> { let mut generated = String::new(); while n_cur <= 128 { - let token = sampler.sample(&ctx, batch.n_tokens() - 1); + let token = sampler.sample(&ctx, batch.n_tokens() - 1)?; if model.is_eog_token(token) { break; @@ -63,8 +63,6 @@ fn json_schema_constrains_output() -> Result<()> { print!("{output_string}"); std::io::stdout().flush()?; - sampler.accept(token)?; - batch.clear(); batch.add(token, n_cur, &[0], true)?; n_cur += 1; diff --git a/llama-cpp-bindings/tests/multimodal.rs b/llama-cpp-bindings/tests/multimodal.rs index 258e8c55..eee8b46c 100644 --- a/llama-cpp-bindings/tests/multimodal.rs +++ b/llama-cpp-bindings/tests/multimodal.rs @@ -91,7 +91,7 @@ fn multimodal_vision_inference_produces_output() -> Result<()> { let mut current_position = n_past; for _ in 0..max_tokens { - let token = sampler.sample(&ctx, -1); + let token = sampler.sample(&ctx, -1)?; if model.is_eog_token(token) { break; diff --git a/llama-cpp-bindings/tests/openai_server.rs b/llama-cpp-bindings/tests/openai_server.rs index 8629dc0d..0c698e7d 100644 --- a/llama-cpp-bindings/tests/openai_server.rs +++ b/llama-cpp-bindings/tests/openai_server.rs @@ -44,7 +44,7 @@ fn run_chat_completion( let tokens = model.str_to_token(&result.prompt, AddBos::Always)?; let tokens_len_u32 = u32::try_from(tokens.len())?; - let n_ctx = model.n_ctx_train().max(tokens_len_u32 + max_tokens); + let n_ctx = model.n_ctx_train()?.max(tokens_len_u32 + max_tokens); let ctx_params = LlamaContextParams::default() .with_n_ctx(NonZeroU32::new(n_ctx)) .with_n_batch(n_ctx); @@ -78,7 +78,7 @@ fn run_chat_completion( let mut sampler = LlamaSampler::greedy(); while n_cur < max_tokens_total { - let token = sampler.sample(&ctx, batch.n_tokens() - 1); + let token = sampler.sample(&ctx, batch.n_tokens() - 1)?; if model.is_eog_token(token) { break; diff --git a/llama-cpp-bindings/tests/openai_streaming.rs b/llama-cpp-bindings/tests/openai_streaming.rs index 156286f7..2f874ce1 100644 --- a/llama-cpp-bindings/tests/openai_streaming.rs +++ b/llama-cpp-bindings/tests/openai_streaming.rs @@ -63,7 +63,7 @@ fn streaming_deltas_produce_valid_chunks() -> Result<()> { let tokens = model.str_to_token(&result.prompt, AddBos::Always)?; let n_predict: i32 = 128; let n_ctx = model - .n_ctx_train() + .n_ctx_train()? .max(tokens.len() as u32 + n_predict as u32); let ctx_params = LlamaContextParams::default() .with_n_ctx(NonZeroU32::new(n_ctx)) @@ -100,7 +100,7 @@ fn streaming_deltas_produce_valid_chunks() -> Result<()> { let mut total_chunks = 0usize; while n_cur <= max_tokens { - let token = sampler.sample(&ctx, batch.n_tokens() - 1); + let token = sampler.sample(&ctx, batch.n_tokens() - 1)?; if model.is_eog_token(token) { break; diff --git a/llama-cpp-bindings/tests/text_generation.rs b/llama-cpp-bindings/tests/text_generation.rs index 2eee177b..7b866209 100644 --- a/llama-cpp-bindings/tests/text_generation.rs +++ b/llama-cpp-bindings/tests/text_generation.rs @@ -64,8 +64,7 @@ fn raw_prompt_completion_with_timing() -> Result<()> { let mut generated = String::new(); while n_cur <= n_len { - let token = sampler.sample(&ctx, batch.n_tokens() - 1); - sampler.accept(token)?; + let token = sampler.sample(&ctx, batch.n_tokens() - 1)?; if model.is_eog_token(token) { break; @@ -137,8 +136,7 @@ fn chat_inference_produces_coherent_output() -> Result<()> { let mut generated = String::new(); while position <= max_tokens { - let token = sampler.sample(&context, batch.n_tokens() - 1); - sampler.accept(token)?; + let token = sampler.sample(&context, batch.n_tokens() - 1)?; if model.is_eog_token(token) { break; diff --git a/llama-cpp-bindings/tests/tool_calling.rs b/llama-cpp-bindings/tests/tool_calling.rs index c18f06f6..5012c833 100644 --- a/llama-cpp-bindings/tests/tool_calling.rs +++ b/llama-cpp-bindings/tests/tool_calling.rs @@ -67,7 +67,7 @@ fn tool_calling_generates_grammar_and_prompt() -> Result<()> { let tokens = model.str_to_token(&result.prompt, AddBos::Always)?; let n_predict: i32 = 128; let n_ctx = model - .n_ctx_train() + .n_ctx_train()? .max(tokens.len() as u32 + n_predict as u32); let ctx_params = LlamaContextParams::default() .with_n_ctx(NonZeroU32::new(n_ctx)) @@ -100,7 +100,7 @@ fn tool_calling_generates_grammar_and_prompt() -> Result<()> { let additional_stops = result.additional_stops.clone(); while n_cur <= max_tokens { - let token = sampler.sample(&ctx, batch.n_tokens() - 1); + let token = sampler.sample(&ctx, batch.n_tokens() - 1)?; if model.is_eog_token(token) { break; diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 711b3a48..fb5449af 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,3 +1,3 @@ [toolchain] -channel = "1.92.0" +channel = "1.93.0" components = ["clippy", "rustfmt"]