diff --git a/src/compute.rs b/src/compute.rs index cf5249f..83e7a45 100644 --- a/src/compute.rs +++ b/src/compute.rs @@ -13,20 +13,32 @@ use crate::{ }; #[allow(non_snake_case)] -/// Compute a binding value from the party ID, public nonces, and signed message using XMD-based expansion. -pub fn binding(id: &Scalar, B: &[PublicNonce], msg: &[u8]) -> Scalar { - let prefix = b"WSTS/binding"; +/// Compute the group commitment from the list of PartyIDs and nonce commitments using XMD-based expansion. +pub fn group_commitment(commitment_list: &[(Scalar, PublicNonce)]) -> Scalar { + let prefix = b"WSTS/group_commitment"; - // Serialize all input into a buffer let mut buf = Vec::new(); - buf.extend_from_slice(&id.to_bytes()); - - for b in B { - buf.extend_from_slice(b.D.compress().as_bytes()); - buf.extend_from_slice(b.E.compress().as_bytes()); + for (id, public_nonce) in commitment_list { + buf.extend_from_slice(&id.to_bytes()); + buf.extend_from_slice(public_nonce.D.compress().as_bytes()); + buf.extend_from_slice(public_nonce.E.compress().as_bytes()); } - buf.extend_from_slice(msg); + expand_to_scalar(&buf, prefix) + .expect("FATAL: DST is less than 256 bytes so operation should not fail") +} + +#[allow(non_snake_case)] +/// Compute the group commitment from the list of PartyIDs and nonce commitments using XMD-based expansion. +pub fn group_commitment_compressed(commitment_list: &[(Scalar, Compressed, Compressed)]) -> Scalar { + let prefix = b"WSTS/group_commitment"; + + let mut buf = Vec::new(); + for (id, hiding_commitment, binding_commitment) in commitment_list { + buf.extend_from_slice(&id.to_bytes()); + buf.extend_from_slice(hiding_commitment.as_bytes()); + buf.extend_from_slice(binding_commitment.as_bytes()); + } expand_to_scalar(&buf, prefix) .expect("FATAL: DST is less than 256 bytes so operation should not fail") @@ -34,19 +46,41 @@ pub fn binding(id: &Scalar, B: &[PublicNonce], msg: &[u8]) -> Scalar { #[allow(non_snake_case)] /// Compute a binding value from the party ID, public nonces, and signed message using XMD-based expansion. -pub fn binding_compressed(id: &Scalar, B: &[(Compressed, Compressed)], msg: &[u8]) -> Scalar { +pub fn binding( + id: &Scalar, + group_public_key: Point, + commitment_list: &[(Scalar, PublicNonce)], + msg: &[u8], +) -> Scalar { let prefix = b"WSTS/binding"; + let encoded_group_commitment = group_commitment(commitment_list); - // Serialize all input into a buffer let mut buf = Vec::new(); buf.extend_from_slice(&id.to_bytes()); + buf.extend_from_slice(group_public_key.compress().as_bytes()); + buf.extend_from_slice(msg); + buf.extend_from_slice(&encoded_group_commitment.to_bytes()); - for (D, E) in B { - buf.extend_from_slice(D.as_bytes()); - buf.extend_from_slice(E.as_bytes()); - } + expand_to_scalar(&buf, prefix) + .expect("FATAL: DST is less than 256 bytes so operation should not fail") +} + +#[allow(non_snake_case)] +/// Compute a binding value from the party ID, public nonces, and signed message using XMD-based expansion. +pub fn binding_compressed( + id: &Scalar, + group_public_key: Point, + commitment_list: &[(Scalar, Compressed, Compressed)], + msg: &[u8], +) -> Scalar { + let prefix = b"WSTS/binding"; + let encoded_group_commitment = group_commitment_compressed(commitment_list); + let mut buf = Vec::new(); + buf.extend_from_slice(&id.to_bytes()); + buf.extend_from_slice(group_public_key.compress().as_bytes()); buf.extend_from_slice(msg); + buf.extend_from_slice(&encoded_group_commitment.to_bytes()); expand_to_scalar(&buf, prefix) .expect("FATAL: DST is less than 256 bytes so operation should not fail") @@ -82,10 +116,20 @@ pub fn lambda(i: u32, key_ids: &[u32]) -> Scalar { // Is this the best way to return these values? #[allow(non_snake_case)] /// Compute the intermediate values used in both the parties and the aggregator -pub fn intermediate(msg: &[u8], party_ids: &[u32], nonces: &[PublicNonce]) -> (Vec, Point) { +pub fn intermediate( + msg: &[u8], + group_key: Point, + party_ids: &[u32], + nonces: &[PublicNonce], +) -> (Vec, Point) { + let commitment_list: Vec<(Scalar, PublicNonce)> = party_ids + .iter() + .zip(nonces) + .map(|(i, nonce)| (Scalar::from(*i), nonce.clone())) + .collect(); let rhos: Vec = party_ids .iter() - .map(|&i| binding(&id(i), nonces, msg)) + .map(|i| binding(&id(*i), group_key, &commitment_list, msg)) .collect(); let R_vec: Vec = zip(nonces, rhos) .map(|(nonce, rho)| nonce.D + rho * nonce.E) @@ -99,19 +143,21 @@ pub fn intermediate(msg: &[u8], party_ids: &[u32], nonces: &[PublicNonce]) -> (V /// Compute the aggregate nonce pub fn aggregate_nonce( msg: &[u8], + group_key: Point, party_ids: &[u32], nonces: &[PublicNonce], ) -> Result { - let compressed_nonces: Vec<(Compressed, Compressed)> = nonces + let commitment_list: Vec<(Scalar, Compressed, Compressed)> = party_ids .iter() - .map(|nonce| (nonce.D.compress(), nonce.E.compress())) + .zip(nonces) + .map(|(id, nonce)| (Scalar::from(*id), nonce.D.compress(), nonce.E.compress())) .collect(); let scalars: Vec = party_ids .iter() .flat_map(|&i| { [ Scalar::from(1), - binding_compressed(&id(i), &compressed_nonces, msg), + binding_compressed(&id(i), group_key, &commitment_list, msg), ] }) .collect(); diff --git a/src/state_machine/coordinator/fire.rs b/src/state_machine/coordinator/fire.rs index b9ff95a..fccadaf 100644 --- a/src/state_machine/coordinator/fire.rs +++ b/src/state_machine/coordinator/fire.rs @@ -1019,7 +1019,7 @@ impl Coordinator { if nonce_info.nonce_recv_key_ids.len() >= self.config.threshold as usize { // We have a winning message! self.message.clone_from(&nonce_response.message); - let aggregate_nonce = self.compute_aggregate_nonce(); + let aggregate_nonce = self.compute_aggregate_nonce()?; info!("Aggregate nonce: {aggregate_nonce}"); self.move_to(State::SigShareRequest(signature_type))?; @@ -1248,7 +1248,7 @@ impl Coordinator { } #[allow(non_snake_case)] - fn compute_aggregate_nonce(&self) -> Point { + fn compute_aggregate_nonce(&self) -> Result { // XXX this needs to be key_ids for v1 and signer_ids for v2 let public_nonces = self .message_nonces @@ -1266,9 +1266,14 @@ impl Coordinator { .cloned() .flat_map(|pn| pn.nonces) .collect::>(); - let (_, R) = compute::intermediate(&self.message, &party_ids, &nonces); - R + let Some(group_key) = self.aggregate_public_key else { + return Err(Error::MissingAggregatePublicKey); + }; + let (_, aggregate_nonce) = + compute::intermediate(&self.message, group_key, &party_ids, &nonces); + + Ok(aggregate_nonce) } fn compute_num_key_ids<'a, I>(&self, signer_ids: I) -> Result @@ -1828,6 +1833,7 @@ pub mod test { let signature_type = SignatureType::Frost; let message = vec![0u8]; coordinator.state = State::NonceGather(signature_type); + coordinator.aggregate_public_key = Some(Point::from(Scalar::random(&mut rng))); let nonce_response = NonceResponse { dkg_id: 0, diff --git a/src/state_machine/coordinator/frost.rs b/src/state_machine/coordinator/frost.rs index ff5e20a..37a0012 100644 --- a/src/state_machine/coordinator/frost.rs +++ b/src/state_machine/coordinator/frost.rs @@ -551,7 +551,7 @@ impl Coordinator { ); } if self.ids_to_await.is_empty() { - let aggregate_nonce = self.compute_aggregate_nonce(); + let aggregate_nonce = self.compute_aggregate_nonce()?; info!( %aggregate_nonce, "Aggregate nonce" @@ -733,7 +733,7 @@ impl Coordinator { } #[allow(non_snake_case)] - fn compute_aggregate_nonce(&self) -> Point { + fn compute_aggregate_nonce(&self) -> Result { // XXX this needs to be key_ids for v1 and signer_ids for v2 let party_ids = self .public_nonces @@ -745,9 +745,13 @@ impl Coordinator { .values() .flat_map(|pn| pn.nonces.clone()) .collect::>(); - let (_, R) = compute::intermediate(&self.message, &party_ids, &nonces); + let Some(group_key) = self.aggregate_public_key else { + return Err(Error::MissingAggregatePublicKey); + }; + let (_, aggregate_nonce) = + compute::intermediate(&self.message, group_key, &party_ids, &nonces); - R + Ok(aggregate_nonce) } } @@ -1284,6 +1288,7 @@ pub mod test { let signature_type = SignatureType::Frost; let message = vec![0u8]; coordinator.state = State::NonceGather(signature_type); + coordinator.aggregate_public_key = Some(Point::from(Scalar::random(&mut rng))); let nonce_response = NonceResponse { dkg_id: 0, diff --git a/src/traits.rs b/src/traits.rs index 68c074d..b6fdc21 100644 --- a/src/traits.rs +++ b/src/traits.rs @@ -108,6 +108,7 @@ pub trait Signer: Clone + Debug + PartialEq { /// Compute intermediate values fn compute_intermediate( + &self, msg: &[u8], signer_ids: &[u32], key_ids: &[u32], diff --git a/src/v1.rs b/src/v1.rs index a03f56a..bccc608 100644 --- a/src/v1.rs +++ b/src/v1.rs @@ -219,8 +219,14 @@ impl Party { /// Sign `msg` with this party's share of the group private key, using the set of `signers` and corresponding `nonces` pub fn sign(&self, msg: &[u8], signers: &[u32], nonces: &[PublicNonce]) -> SignatureShare { - let (_, aggregate_nonce) = compute::intermediate(msg, signers, nonces); - let mut z = &self.nonce.d + &self.nonce.e * compute::binding(&self.id(), nonces, msg); + let (_, aggregate_nonce) = compute::intermediate(msg, self.group_key, signers, nonces); + let commitment_list: Vec<(Scalar, PublicNonce)> = signers + .iter() + .zip(nonces) + .map(|(id, nonce)| (Scalar::from(*id), nonce.clone())) + .collect(); + let mut z = &self.nonce.d + + &self.nonce.e * compute::binding(&self.id(), self.group_key, &commitment_list, msg); z += compute::challenge(&self.group_key, &aggregate_nonce, msg) * &self.private_key * compute::lambda(self.id, signers); @@ -255,7 +261,13 @@ impl Party { aggregate_nonce: &Point, tweak: Option, ) -> SignatureShare { - let mut r = &self.nonce.d + &self.nonce.e * compute::binding(&self.id(), nonces, msg); + let commitment_list: Vec<(Scalar, PublicNonce)> = signers + .iter() + .zip(nonces) + .map(|(id, nonce)| (Scalar::from(*id), nonce.clone())) + .collect(); + let mut r = &self.nonce.d + + &self.nonce.e * compute::binding(&self.id(), self.group_key, &commitment_list, msg); if tweak.is_some() && !aggregate_nonce.has_even_y() { r = -r; } @@ -327,7 +339,6 @@ impl Aggregator { } let signers: Vec = sig_shares.iter().map(|ss| ss.id).collect(); - let (_Rs, R) = compute::intermediate(msg, &signers, nonces); let mut z = Scalar::zero(); let mut cx_sign = Scalar::one(); let aggregate_public_key = self.poly[0]; @@ -341,6 +352,7 @@ impl Aggregator { } _ => aggregate_public_key, }; + let (_Rs, R) = compute::intermediate(msg, aggregate_public_key, &signers, nonces); let c = compute::challenge(&tweaked_public_key, &R, msg); for sig_share in sig_shares { @@ -374,7 +386,6 @@ impl Aggregator { } let signers: Vec = sig_shares.iter().map(|ss| ss.id).collect(); - let (Rs, R) = compute::intermediate(msg, &signers, nonces); let mut bad_party_keys = Vec::new(); let mut bad_party_sigs = Vec::new(); let aggregate_public_key = self.poly[0]; @@ -384,6 +395,7 @@ impl Aggregator { } _ => aggregate_public_key, }; + let (Rs, R) = compute::intermediate(msg, aggregate_public_key, &signers, nonces); let c = compute::challenge(&tweaked_public_key, &R, msg); let mut r_sign = Scalar::one(); let mut cx_sign = Scalar::one(); @@ -689,12 +701,13 @@ impl traits::Signer for Signer { } fn compute_intermediate( + &self, msg: &[u8], _signer_ids: &[u32], key_ids: &[u32], nonces: &[PublicNonce], ) -> (Vec, Point) { - compute::intermediate(msg, key_ids, nonces) + compute::intermediate(msg, self.group_key, key_ids, nonces) } fn validate_party_id( @@ -715,7 +728,8 @@ impl traits::Signer for Signer { key_ids: &[u32], nonces: &[PublicNonce], ) -> Vec { - let aggregate_nonce = compute::aggregate_nonce(msg, key_ids, nonces).unwrap(); + let aggregate_nonce = + compute::aggregate_nonce(msg, self.group_key, key_ids, nonces).unwrap(); self.parties .iter() .map(|p| p.sign_precomputed(msg, key_ids, nonces, &aggregate_nonce)) @@ -730,7 +744,8 @@ impl traits::Signer for Signer { nonces: &[PublicNonce], merkle_root: Option<[u8; 32]>, ) -> Vec { - let aggregate_nonce = compute::aggregate_nonce(msg, key_ids, nonces).unwrap(); + let aggregate_nonce = + compute::aggregate_nonce(msg, self.group_key, key_ids, nonces).unwrap(); let tweak = compute::tweak(&self.parties[0].group_key, merkle_root); self.parties .iter() @@ -747,7 +762,8 @@ impl traits::Signer for Signer { key_ids: &[u32], nonces: &[PublicNonce], ) -> Vec { - let aggregate_nonce = compute::aggregate_nonce(msg, key_ids, nonces).unwrap(); + let aggregate_nonce = + compute::aggregate_nonce(msg, self.group_key, key_ids, nonces).unwrap(); self.parties .iter() .map(|p| { diff --git a/src/v2.rs b/src/v2.rs index 199f0aa..5cf5c32 100644 --- a/src/v2.rs +++ b/src/v2.rs @@ -256,13 +256,18 @@ impl Party { } else { self.group_key }; - let (_, R) = compute::intermediate(msg, party_ids, nonces); + let (_, R) = compute::intermediate(msg, self.group_key, party_ids, nonces); let c = compute::challenge(&tweaked_public_key, &R, msg); - let mut r = &self.nonce.d + &self.nonce.e * compute::binding(&self.id(), nonces, msg); + let commitment_list: Vec<(Scalar, PublicNonce)> = party_ids + .iter() + .zip(nonces) + .map(|(id, nonce)| (Scalar::from(*id), nonce.clone())) + .collect(); + let mut r = &self.nonce.d + + &self.nonce.e * compute::binding(&self.id(), self.group_key, &commitment_list, msg); if tweak.is_some() && !R.has_even_y() { r = -r; } - let mut cx = Scalar::zero(); for key_id in self.key_ids.iter() { cx += c * &self.private_keys[key_id] * compute::lambda(*key_id, key_ids); @@ -310,7 +315,7 @@ impl Aggregator { } let party_ids: Vec = sig_shares.iter().map(|ss| ss.id).collect(); - let (_Rs, R) = compute::intermediate(msg, &party_ids, nonces); + let (_Rs, R) = compute::intermediate(msg, self.poly[0], &party_ids, nonces); let mut z = Scalar::zero(); let mut cx_sign = Scalar::one(); let aggregate_public_key = self.poly[0]; @@ -361,7 +366,7 @@ impl Aggregator { } let party_ids: Vec = sig_shares.iter().map(|ss| ss.id).collect(); - let (Rs, R) = compute::intermediate(msg, &party_ids, nonces); + let (Rs, R) = compute::intermediate(msg, self.poly[0], &party_ids, nonces); let mut bad_party_keys = Vec::new(); let mut bad_party_sigs = Vec::new(); let aggregate_public_key = self.poly[0]; @@ -637,12 +642,13 @@ impl traits::Signer for Party { } fn compute_intermediate( + &self, msg: &[u8], signer_ids: &[u32], _key_ids: &[u32], nonces: &[PublicNonce], ) -> (Vec, Point) { - compute::intermediate(msg, signer_ids, nonces) + compute::intermediate(msg, self.group_key, signer_ids, nonces) } fn validate_party_id(