From 3739a5b60db85b52ec16c270aff183b3f63120c0 Mon Sep 17 00:00:00 2001 From: Christiaan676 Date: Fri, 7 Mar 2025 22:16:30 +0100 Subject: [PATCH] Correct usage of ECC --- x-wing/src/lib.rs | 78 +++++++++++++++++++++-------------------------- 1 file changed, 35 insertions(+), 43 deletions(-) diff --git a/x-wing/src/lib.rs b/x-wing/src/lib.rs index c60316e..79460fe 100644 --- a/x-wing/src/lib.rs +++ b/x-wing/src/lib.rs @@ -36,7 +36,7 @@ use rand_core::OsRng; use sha3::digest::core_api::XofReaderCoreWrapper; use sha3::digest::{ExtendableOutput, XofReader}; use sha3::{Sha3_256, Shake256, Shake256ReaderCore}; -use x25519_dalek::{x25519, X25519_BASEPOINT_BYTES}; +use x25519_dalek::{EphemeralSecret, PublicKey, StaticSecret}; #[cfg(feature = "zeroize")] use zeroize::{Zeroize, ZeroizeOnDrop}; @@ -69,7 +69,7 @@ pub type SharedSecret = [u8; 32]; #[derive(Clone, PartialEq)] pub struct EncapsulationKey { pk_m: MlKem768EncapsulationKey, - pk_x: x25519_dalek::PublicKey, + pk_x: PublicKey, } impl Encapsulate for EncapsulationKey { @@ -82,18 +82,13 @@ impl Encapsulate for EncapsulationKey { // Swapped order of operations compared to RFC, so that usage of the rng matches the RFC let (ct_m, ss_m) = self.pk_m.encapsulate(rng)?; - let ek_x: SharedSecret = generate(rng); - let ct_x = x25519(ek_x, X25519_BASEPOINT_BYTES); - let ss_x = x25519(ek_x, self.pk_x.to_bytes()); + let ek_x = EphemeralSecret::random_from_rng(rng); + // Equal to ct_x = x25519(ek_x, BASE_POINT) + let ct_x = PublicKey::from(&ek_x); + // Equal to ss_x = x25519(ek_x, pk_x) + let ss_x = ek_x.diffie_hellman(&self.pk_x); let ss = combiner(&ss_m, &ss_x, &ct_x, &self.pk_x); - - #[cfg(feature = "zeroize")] - { - let mut ss_x = ss_x; - ss_x.zeroize(); - } - let ct = Ciphertext { ct_m, ct_x }; Ok((ct, ss)) } @@ -101,9 +96,9 @@ impl Encapsulate for EncapsulationKey { impl EncapsulationKey { /// Convert the key to the following format: - /// ML-KEM-768 public key(1184 bytes) | X25519 public key(32 bytes). + /// ML-KEM-768 public key(1184 bytes) || X25519 public key(32 bytes). #[must_use] - pub fn as_bytes(&self) -> [u8; ENCAPSULATION_KEY_SIZE] { + pub fn to_bytes(&self) -> [u8; ENCAPSULATION_KEY_SIZE] { let mut buffer = [0u8; ENCAPSULATION_KEY_SIZE]; buffer[0..1184].copy_from_slice(&self.pk_m.as_bytes()); buffer[1184..1216].copy_from_slice(self.pk_x.as_bytes()); @@ -119,7 +114,7 @@ impl From<&[u8; ENCAPSULATION_KEY_SIZE]> for EncapsulationKey { let mut pk_x = [0; 32]; pk_x.copy_from_slice(&value[1184..]); - let pk_x = x25519_dalek::PublicKey::from(pk_x); + let pk_x = PublicKey::from(pk_x); EncapsulationKey { pk_m, pk_x } } } @@ -138,16 +133,13 @@ impl Decapsulate for DecapsulationKey { #[allow(clippy::similar_names)] // So we can use the names as in the RFC fn decapsulate(&self, ct: &Ciphertext) -> Result { let (sk_m, sk_x, _pk_m, pk_x) = self.expand_key(); + let ss_m = sk_m.decapsulate(&ct.ct_m)?; - let ss_x = x25519(sk_x.to_bytes(), ct.ct_x); - let ss = combiner(&ss_m, &ss_x, &ct.ct_x, &pk_x); - #[cfg(feature = "zeroize")] - { - let mut ss_x = ss_x; - ss_x.zeroize(); - } + // equal to ss_x = x25519(sk_x, ct_x) + let ss_x = sk_x.diffie_hellman(&ct.ct_x); + let ss = combiner(&ss_m, &ss_x, &ct.ct_x, &pk_x); Ok(ss) } } @@ -176,9 +168,9 @@ impl DecapsulationKey { &self, ) -> ( MlKem768DecapsulationKey, - x25519_dalek::StaticSecret, + StaticSecret, MlKem768EncapsulationKey, - x25519_dalek::PublicKey, + PublicKey, ) { use sha3::digest::Update; let mut shaker = Shake256::default(); @@ -190,8 +182,8 @@ impl DecapsulationKey { let (sk_m, pk_m) = MlKem768::generate_deterministic(&d, &z); let sk_x = read_from(&mut expanded); - let sk_x = x25519_dalek::StaticSecret::from(sk_x); - let pk_x = x25519_dalek::PublicKey::from(&sk_x); + let sk_x = StaticSecret::from(sk_x); + let pk_x = PublicKey::from(&sk_x); (sk_m, sk_x, pk_m, pk_x) } @@ -214,17 +206,17 @@ impl From<[u8; DECAPSULATION_KEY_SIZE]> for DecapsulationKey { #[cfg_attr(feature = "zeroize", derive(Zeroize, ZeroizeOnDrop))] pub struct Ciphertext { ct_m: ArrayN, - ct_x: [u8; 32], + ct_x: PublicKey, } impl Ciphertext { /// Convert the ciphertext to the following format: - /// ML-KEM-768 ciphertext(1088 bytes) | X25519 ciphertext(32 bytes). + /// ML-KEM-768 ciphertext(1088 bytes) || X25519 ciphertext(32 bytes). #[must_use] - pub fn as_bytes(&self) -> [u8; CIPHERTEXT_SIZE] { + pub fn to_bytes(&self) -> [u8; CIPHERTEXT_SIZE] { let mut buffer = [0; CIPHERTEXT_SIZE]; buffer[0..1088].copy_from_slice(&self.ct_m); - buffer[1088..].copy_from_slice(&self.ct_x); + buffer[1088..].copy_from_slice(self.ct_x.as_bytes()); buffer } } @@ -238,7 +230,7 @@ impl From<&[u8; CIPHERTEXT_SIZE]> for Ciphertext { Ciphertext { ct_m: ct_m.into(), - ct_x, + ct_x: ct_x.into(), } } } @@ -258,9 +250,9 @@ pub fn generate_key_pair(rng: &mut impl CryptoRngCore) -> (DecapsulationKey, Enc fn combiner( ss_m: &B32, - ss_x: &[u8; 32], - ct_x: &[u8; 32], - pk_x: &x25519_dalek::PublicKey, + ss_x: &x25519_dalek::SharedSecret, + ct_x: &PublicKey, + pk_x: &PublicKey, ) -> SharedSecret { use sha3::Digest; @@ -292,8 +284,8 @@ mod tests { use super::*; - struct SeedRng { - seed: Vec, + pub(crate) struct SeedRng { + pub(crate) seed: Vec, } impl SeedRng { @@ -360,14 +352,14 @@ mod tests { let mut seed = SeedRng::new(test_vector.seed); let (sk, pk) = generate_key_pair(&mut seed); - assert_eq!(sk.as_bytes().to_vec(), test_vector.sk); - assert_eq!(pk.as_bytes().to_vec(), test_vector.pk); + assert_eq!(sk.as_bytes(), &test_vector.sk); + assert_eq!(&pk.to_bytes(), test_vector.pk.as_slice()); let mut eseed = SeedRng::new(test_vector.eseed); let (ct, ss) = pk.encapsulate(&mut eseed).unwrap(); assert_eq!(ss, test_vector.ss); - assert_eq!(ct.as_bytes().to_vec(), test_vector.ct); + assert_eq!(&ct.to_bytes(), test_vector.ct.as_slice()); let ss = sk.decapsulate(&ct).unwrap(); assert_eq!(ss, test_vector.ss); @@ -379,10 +371,10 @@ mod tests { let ct_a = Ciphertext { ct_m: generate(&mut rng).into(), - ct_x: generate(&mut rng), + ct_x: generate(&mut rng).into(), }; - let bytes = ct_a.as_bytes(); + let bytes = ct_a.to_bytes(); let ct_b = Ciphertext::from(&bytes); @@ -395,10 +387,10 @@ mod tests { let pk = sk.encapsulation_key(); let sk_bytes = sk.as_bytes(); - let pk_bytes = pk.as_bytes(); + let pk_bytes = pk.to_bytes(); let sk_b = DecapsulationKey::from(*sk_bytes); - let pk_b = EncapsulationKey::from(&pk_bytes.clone()); + let pk_b = EncapsulationKey::from(&pk_bytes); assert!(sk == sk_b); assert!(pk == pk_b);