Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 35 additions & 43 deletions x-wing/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ use rand_core::OsRng;
use sha3::digest::core_api::XofReaderCoreWrapper;
use sha3::digest::{ExtendableOutput, XofReader};
use sha3::{Sha3_256, Shake256, Shake256ReaderCore};
use x25519_dalek::{x25519, X25519_BASEPOINT_BYTES};
use x25519_dalek::{EphemeralSecret, PublicKey, StaticSecret};
#[cfg(feature = "zeroize")]
use zeroize::{Zeroize, ZeroizeOnDrop};

Expand Down Expand Up @@ -69,7 +69,7 @@ pub type SharedSecret = [u8; 32];
#[derive(Clone, PartialEq)]
pub struct EncapsulationKey {
pk_m: MlKem768EncapsulationKey,
pk_x: x25519_dalek::PublicKey,
pk_x: PublicKey,
}

impl Encapsulate<Ciphertext, SharedSecret> for EncapsulationKey {
Expand All @@ -82,28 +82,23 @@ impl Encapsulate<Ciphertext, SharedSecret> for EncapsulationKey {
// Swapped order of operations compared to RFC, so that usage of the rng matches the RFC
let (ct_m, ss_m) = self.pk_m.encapsulate(rng)?;

let ek_x: SharedSecret = generate(rng);
let ct_x = x25519(ek_x, X25519_BASEPOINT_BYTES);
let ss_x = x25519(ek_x, self.pk_x.to_bytes());
let ek_x = EphemeralSecret::random_from_rng(rng);
// Equal to ct_x = x25519(ek_x, BASE_POINT)
let ct_x = PublicKey::from(&ek_x);
// Equal to ss_x = x25519(ek_x, pk_x)
let ss_x = ek_x.diffie_hellman(&self.pk_x);

let ss = combiner(&ss_m, &ss_x, &ct_x, &self.pk_x);

#[cfg(feature = "zeroize")]
{
let mut ss_x = ss_x;
ss_x.zeroize();
}

let ct = Ciphertext { ct_m, ct_x };
Ok((ct, ss))
}
}

impl EncapsulationKey {
/// Convert the key to the following format:
/// ML-KEM-768 public key(1184 bytes) | X25519 public key(32 bytes).
/// ML-KEM-768 public key(1184 bytes) || X25519 public key(32 bytes).
#[must_use]
pub fn as_bytes(&self) -> [u8; ENCAPSULATION_KEY_SIZE] {
pub fn to_bytes(&self) -> [u8; ENCAPSULATION_KEY_SIZE] {
let mut buffer = [0u8; ENCAPSULATION_KEY_SIZE];
buffer[0..1184].copy_from_slice(&self.pk_m.as_bytes());
buffer[1184..1216].copy_from_slice(self.pk_x.as_bytes());
Expand All @@ -119,7 +114,7 @@ impl From<&[u8; ENCAPSULATION_KEY_SIZE]> for EncapsulationKey {

let mut pk_x = [0; 32];
pk_x.copy_from_slice(&value[1184..]);
let pk_x = x25519_dalek::PublicKey::from(pk_x);
let pk_x = PublicKey::from(pk_x);
EncapsulationKey { pk_m, pk_x }
}
}
Expand All @@ -138,16 +133,13 @@ impl Decapsulate<Ciphertext, SharedSecret> for DecapsulationKey {
#[allow(clippy::similar_names)] // So we can use the names as in the RFC
fn decapsulate(&self, ct: &Ciphertext) -> Result<SharedSecret, Self::Error> {
let (sk_m, sk_x, _pk_m, pk_x) = self.expand_key();

let ss_m = sk_m.decapsulate(&ct.ct_m)?;
let ss_x = x25519(sk_x.to_bytes(), ct.ct_x);
let ss = combiner(&ss_m, &ss_x, &ct.ct_x, &pk_x);

#[cfg(feature = "zeroize")]
{
let mut ss_x = ss_x;
ss_x.zeroize();
}
// equal to ss_x = x25519(sk_x, ct_x)
let ss_x = sk_x.diffie_hellman(&ct.ct_x);

let ss = combiner(&ss_m, &ss_x, &ct.ct_x, &pk_x);
Ok(ss)
}
}
Expand Down Expand Up @@ -176,9 +168,9 @@ impl DecapsulationKey {
&self,
) -> (
MlKem768DecapsulationKey,
x25519_dalek::StaticSecret,
StaticSecret,
MlKem768EncapsulationKey,
x25519_dalek::PublicKey,
PublicKey,
) {
use sha3::digest::Update;
let mut shaker = Shake256::default();
Expand All @@ -190,8 +182,8 @@ impl DecapsulationKey {
let (sk_m, pk_m) = MlKem768::generate_deterministic(&d, &z);

let sk_x = read_from(&mut expanded);
let sk_x = x25519_dalek::StaticSecret::from(sk_x);
let pk_x = x25519_dalek::PublicKey::from(&sk_x);
let sk_x = StaticSecret::from(sk_x);
let pk_x = PublicKey::from(&sk_x);

(sk_m, sk_x, pk_m, pk_x)
}
Expand All @@ -214,17 +206,17 @@ impl From<[u8; DECAPSULATION_KEY_SIZE]> for DecapsulationKey {
#[cfg_attr(feature = "zeroize", derive(Zeroize, ZeroizeOnDrop))]
pub struct Ciphertext {
ct_m: ArrayN<u8, 1088>,
ct_x: [u8; 32],
ct_x: PublicKey,
}

impl Ciphertext {
/// Convert the ciphertext to the following format:
/// ML-KEM-768 ciphertext(1088 bytes) | X25519 ciphertext(32 bytes).
/// ML-KEM-768 ciphertext(1088 bytes) || X25519 ciphertext(32 bytes).
#[must_use]
pub fn as_bytes(&self) -> [u8; CIPHERTEXT_SIZE] {
pub fn to_bytes(&self) -> [u8; CIPHERTEXT_SIZE] {
let mut buffer = [0; CIPHERTEXT_SIZE];
buffer[0..1088].copy_from_slice(&self.ct_m);
buffer[1088..].copy_from_slice(&self.ct_x);
buffer[1088..].copy_from_slice(self.ct_x.as_bytes());
buffer
}
}
Expand All @@ -238,7 +230,7 @@ impl From<&[u8; CIPHERTEXT_SIZE]> for Ciphertext {

Ciphertext {
ct_m: ct_m.into(),
ct_x,
ct_x: ct_x.into(),
}
}
}
Expand All @@ -258,9 +250,9 @@ pub fn generate_key_pair(rng: &mut impl CryptoRngCore) -> (DecapsulationKey, Enc

fn combiner(
ss_m: &B32,
ss_x: &[u8; 32],
ct_x: &[u8; 32],
pk_x: &x25519_dalek::PublicKey,
ss_x: &x25519_dalek::SharedSecret,
ct_x: &PublicKey,
pk_x: &PublicKey,
) -> SharedSecret {
use sha3::Digest;

Expand Down Expand Up @@ -292,8 +284,8 @@ mod tests {

use super::*;

struct SeedRng {
seed: Vec<u8>,
pub(crate) struct SeedRng {
pub(crate) seed: Vec<u8>,
}

impl SeedRng {
Expand Down Expand Up @@ -360,14 +352,14 @@ mod tests {
let mut seed = SeedRng::new(test_vector.seed);
let (sk, pk) = generate_key_pair(&mut seed);

assert_eq!(sk.as_bytes().to_vec(), test_vector.sk);
assert_eq!(pk.as_bytes().to_vec(), test_vector.pk);
assert_eq!(sk.as_bytes(), &test_vector.sk);
assert_eq!(&pk.to_bytes(), test_vector.pk.as_slice());

let mut eseed = SeedRng::new(test_vector.eseed);
let (ct, ss) = pk.encapsulate(&mut eseed).unwrap();

assert_eq!(ss, test_vector.ss);
assert_eq!(ct.as_bytes().to_vec(), test_vector.ct);
assert_eq!(&ct.to_bytes(), test_vector.ct.as_slice());

let ss = sk.decapsulate(&ct).unwrap();
assert_eq!(ss, test_vector.ss);
Expand All @@ -379,10 +371,10 @@ mod tests {

let ct_a = Ciphertext {
ct_m: generate(&mut rng).into(),
ct_x: generate(&mut rng),
ct_x: generate(&mut rng).into(),
};

let bytes = ct_a.as_bytes();
let bytes = ct_a.to_bytes();

let ct_b = Ciphertext::from(&bytes);

Expand All @@ -395,10 +387,10 @@ mod tests {
let pk = sk.encapsulation_key();

let sk_bytes = sk.as_bytes();
let pk_bytes = pk.as_bytes();
let pk_bytes = pk.to_bytes();

let sk_b = DecapsulationKey::from(*sk_bytes);
let pk_b = EncapsulationKey::from(&pk_bytes.clone());
let pk_b = EncapsulationKey::from(&pk_bytes);

assert!(sk == sk_b);
assert!(pk == pk_b);
Expand Down