Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 67 additions & 21 deletions src/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,40 +13,74 @@ use crate::{
};

#[allow(non_snake_case)]
/// Compute a binding value from the party ID, public nonces, and signed message using XMD-based expansion.
pub fn binding(id: &Scalar, B: &[PublicNonce], msg: &[u8]) -> Scalar {
let prefix = b"WSTS/binding";
/// Compute the group commitment from the list of PartyIDs and nonce commitments using XMD-based expansion.
pub fn group_commitment(commitment_list: &[(Scalar, PublicNonce)]) -> Scalar {
let prefix = b"WSTS/group_commitment";

// Serialize all input into a buffer
let mut buf = Vec::new();
buf.extend_from_slice(&id.to_bytes());

for b in B {
buf.extend_from_slice(b.D.compress().as_bytes());
buf.extend_from_slice(b.E.compress().as_bytes());
for (id, public_nonce) in commitment_list {
buf.extend_from_slice(&id.to_bytes());
buf.extend_from_slice(public_nonce.D.compress().as_bytes());
buf.extend_from_slice(public_nonce.E.compress().as_bytes());
}

buf.extend_from_slice(msg);
expand_to_scalar(&buf, prefix)
.expect("FATAL: DST is less than 256 bytes so operation should not fail")
}

#[allow(non_snake_case)]
/// Compute the group commitment from the list of PartyIDs and nonce commitments using XMD-based expansion.
pub fn group_commitment_compressed(commitment_list: &[(Scalar, Compressed, Compressed)]) -> Scalar {
let prefix = b"WSTS/group_commitment";

let mut buf = Vec::new();
for (id, hiding_commitment, binding_commitment) in commitment_list {
buf.extend_from_slice(&id.to_bytes());
buf.extend_from_slice(hiding_commitment.as_bytes());
buf.extend_from_slice(binding_commitment.as_bytes());
}

expand_to_scalar(&buf, prefix)
.expect("FATAL: DST is less than 256 bytes so operation should not fail")
}

#[allow(non_snake_case)]
/// Compute a binding value from the party ID, public nonces, and signed message using XMD-based expansion.
pub fn binding_compressed(id: &Scalar, B: &[(Compressed, Compressed)], msg: &[u8]) -> Scalar {
pub fn binding(
id: &Scalar,
group_public_key: Point,
commitment_list: &[(Scalar, PublicNonce)],
msg: &[u8],
) -> Scalar {
let prefix = b"WSTS/binding";
let encoded_group_commitment = group_commitment(commitment_list);

// Serialize all input into a buffer
let mut buf = Vec::new();
buf.extend_from_slice(&id.to_bytes());
buf.extend_from_slice(group_public_key.compress().as_bytes());
buf.extend_from_slice(msg);
buf.extend_from_slice(&encoded_group_commitment.to_bytes());

for (D, E) in B {
buf.extend_from_slice(D.as_bytes());
buf.extend_from_slice(E.as_bytes());
}
expand_to_scalar(&buf, prefix)
.expect("FATAL: DST is less than 256 bytes so operation should not fail")
}

#[allow(non_snake_case)]
/// Compute a binding value from the party ID, public nonces, and signed message using XMD-based expansion.
pub fn binding_compressed(
id: &Scalar,
group_public_key: Point,
commitment_list: &[(Scalar, Compressed, Compressed)],
msg: &[u8],
) -> Scalar {
let prefix = b"WSTS/binding";
let encoded_group_commitment = group_commitment_compressed(commitment_list);

let mut buf = Vec::new();
buf.extend_from_slice(&id.to_bytes());
buf.extend_from_slice(group_public_key.compress().as_bytes());
buf.extend_from_slice(msg);
buf.extend_from_slice(&encoded_group_commitment.to_bytes());

expand_to_scalar(&buf, prefix)
.expect("FATAL: DST is less than 256 bytes so operation should not fail")
Expand Down Expand Up @@ -82,10 +116,20 @@ pub fn lambda(i: u32, key_ids: &[u32]) -> Scalar {
// Is this the best way to return these values?
#[allow(non_snake_case)]
/// Compute the intermediate values used in both the parties and the aggregator
pub fn intermediate(msg: &[u8], party_ids: &[u32], nonces: &[PublicNonce]) -> (Vec<Point>, Point) {
pub fn intermediate(
msg: &[u8],
group_key: Point,
party_ids: &[u32],
nonces: &[PublicNonce],
) -> (Vec<Point>, Point) {
let commitment_list: Vec<(Scalar, PublicNonce)> = party_ids
.iter()
.zip(nonces)
.map(|(i, nonce)| (Scalar::from(*i), nonce.clone()))
.collect();
let rhos: Vec<Scalar> = party_ids
.iter()
.map(|&i| binding(&id(i), nonces, msg))
.map(|i| binding(&id(*i), group_key, &commitment_list, msg))
.collect();
let R_vec: Vec<Point> = zip(nonces, rhos)
.map(|(nonce, rho)| nonce.D + rho * nonce.E)
Expand All @@ -99,19 +143,21 @@ pub fn intermediate(msg: &[u8], party_ids: &[u32], nonces: &[PublicNonce]) -> (V
/// Compute the aggregate nonce
pub fn aggregate_nonce(
msg: &[u8],
group_key: Point,
party_ids: &[u32],
nonces: &[PublicNonce],
) -> Result<Point, PointError> {
let compressed_nonces: Vec<(Compressed, Compressed)> = nonces
let commitment_list: Vec<(Scalar, Compressed, Compressed)> = party_ids
.iter()
.map(|nonce| (nonce.D.compress(), nonce.E.compress()))
.zip(nonces)
.map(|(id, nonce)| (Scalar::from(*id), nonce.D.compress(), nonce.E.compress()))
.collect();
let scalars: Vec<Scalar> = party_ids
.iter()
.flat_map(|&i| {
[
Scalar::from(1),
binding_compressed(&id(i), &compressed_nonces, msg),
binding_compressed(&id(i), group_key, &commitment_list, msg),
]
})
.collect();
Expand Down
14 changes: 10 additions & 4 deletions src/state_machine/coordinator/fire.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1019,7 +1019,7 @@ impl<Aggregator: AggregatorTrait> Coordinator<Aggregator> {
if nonce_info.nonce_recv_key_ids.len() >= self.config.threshold as usize {
// We have a winning message!
self.message.clone_from(&nonce_response.message);
let aggregate_nonce = self.compute_aggregate_nonce();
let aggregate_nonce = self.compute_aggregate_nonce()?;
info!("Aggregate nonce: {aggregate_nonce}");

self.move_to(State::SigShareRequest(signature_type))?;
Expand Down Expand Up @@ -1248,7 +1248,7 @@ impl<Aggregator: AggregatorTrait> Coordinator<Aggregator> {
}

#[allow(non_snake_case)]
fn compute_aggregate_nonce(&self) -> Point {
fn compute_aggregate_nonce(&self) -> Result<Point, Error> {
// XXX this needs to be key_ids for v1 and signer_ids for v2
let public_nonces = self
.message_nonces
Expand All @@ -1266,9 +1266,14 @@ impl<Aggregator: AggregatorTrait> Coordinator<Aggregator> {
.cloned()
.flat_map(|pn| pn.nonces)
.collect::<Vec<PublicNonce>>();
let (_, R) = compute::intermediate(&self.message, &party_ids, &nonces);

R
let Some(group_key) = self.aggregate_public_key else {
return Err(Error::MissingAggregatePublicKey);
};
let (_, aggregate_nonce) =
compute::intermediate(&self.message, group_key, &party_ids, &nonces);

Ok(aggregate_nonce)
}

fn compute_num_key_ids<'a, I>(&self, signer_ids: I) -> Result<u32, Error>
Expand Down Expand Up @@ -1828,6 +1833,7 @@ pub mod test {
let signature_type = SignatureType::Frost;
let message = vec![0u8];
coordinator.state = State::NonceGather(signature_type);
coordinator.aggregate_public_key = Some(Point::from(Scalar::random(&mut rng)));

let nonce_response = NonceResponse {
dkg_id: 0,
Expand Down
13 changes: 9 additions & 4 deletions src/state_machine/coordinator/frost.rs
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,7 @@ impl<Aggregator: AggregatorTrait> Coordinator<Aggregator> {
);
}
if self.ids_to_await.is_empty() {
let aggregate_nonce = self.compute_aggregate_nonce();
let aggregate_nonce = self.compute_aggregate_nonce()?;
info!(
%aggregate_nonce,
"Aggregate nonce"
Expand Down Expand Up @@ -733,7 +733,7 @@ impl<Aggregator: AggregatorTrait> Coordinator<Aggregator> {
}

#[allow(non_snake_case)]
fn compute_aggregate_nonce(&self) -> Point {
fn compute_aggregate_nonce(&self) -> Result<Point, Error> {
// XXX this needs to be key_ids for v1 and signer_ids for v2
let party_ids = self
.public_nonces
Expand All @@ -745,9 +745,13 @@ impl<Aggregator: AggregatorTrait> Coordinator<Aggregator> {
.values()
.flat_map(|pn| pn.nonces.clone())
.collect::<Vec<PublicNonce>>();
let (_, R) = compute::intermediate(&self.message, &party_ids, &nonces);
let Some(group_key) = self.aggregate_public_key else {
return Err(Error::MissingAggregatePublicKey);
};
let (_, aggregate_nonce) =
compute::intermediate(&self.message, group_key, &party_ids, &nonces);

R
Ok(aggregate_nonce)
}
}

Expand Down Expand Up @@ -1284,6 +1288,7 @@ pub mod test {
let signature_type = SignatureType::Frost;
let message = vec![0u8];
coordinator.state = State::NonceGather(signature_type);
coordinator.aggregate_public_key = Some(Point::from(Scalar::random(&mut rng)));

let nonce_response = NonceResponse {
dkg_id: 0,
Expand Down
1 change: 1 addition & 0 deletions src/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ pub trait Signer: Clone + Debug + PartialEq {

/// Compute intermediate values
fn compute_intermediate(
&self,
msg: &[u8],
signer_ids: &[u32],
key_ids: &[u32],
Expand Down
34 changes: 25 additions & 9 deletions src/v1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,14 @@ impl Party {

/// Sign `msg` with this party's share of the group private key, using the set of `signers` and corresponding `nonces`
pub fn sign(&self, msg: &[u8], signers: &[u32], nonces: &[PublicNonce]) -> SignatureShare {
let (_, aggregate_nonce) = compute::intermediate(msg, signers, nonces);
let mut z = &self.nonce.d + &self.nonce.e * compute::binding(&self.id(), nonces, msg);
let (_, aggregate_nonce) = compute::intermediate(msg, self.group_key, signers, nonces);
let commitment_list: Vec<(Scalar, PublicNonce)> = signers
.iter()
.zip(nonces)
.map(|(id, nonce)| (Scalar::from(*id), nonce.clone()))
.collect();
let mut z = &self.nonce.d
+ &self.nonce.e * compute::binding(&self.id(), self.group_key, &commitment_list, msg);
z += compute::challenge(&self.group_key, &aggregate_nonce, msg)
* &self.private_key
* compute::lambda(self.id, signers);
Expand Down Expand Up @@ -255,7 +261,13 @@ impl Party {
aggregate_nonce: &Point,
tweak: Option<Scalar>,
) -> SignatureShare {
let mut r = &self.nonce.d + &self.nonce.e * compute::binding(&self.id(), nonces, msg);
let commitment_list: Vec<(Scalar, PublicNonce)> = signers
.iter()
.zip(nonces)
.map(|(id, nonce)| (Scalar::from(*id), nonce.clone()))
.collect();
let mut r = &self.nonce.d
+ &self.nonce.e * compute::binding(&self.id(), self.group_key, &commitment_list, msg);
if tweak.is_some() && !aggregate_nonce.has_even_y() {
r = -r;
}
Expand Down Expand Up @@ -327,7 +339,6 @@ impl Aggregator {
}

let signers: Vec<u32> = sig_shares.iter().map(|ss| ss.id).collect();
let (_Rs, R) = compute::intermediate(msg, &signers, nonces);
let mut z = Scalar::zero();
let mut cx_sign = Scalar::one();
let aggregate_public_key = self.poly[0];
Expand All @@ -341,6 +352,7 @@ impl Aggregator {
}
_ => aggregate_public_key,
};
let (_Rs, R) = compute::intermediate(msg, aggregate_public_key, &signers, nonces);
let c = compute::challenge(&tweaked_public_key, &R, msg);

for sig_share in sig_shares {
Expand Down Expand Up @@ -374,7 +386,6 @@ impl Aggregator {
}

let signers: Vec<u32> = sig_shares.iter().map(|ss| ss.id).collect();
let (Rs, R) = compute::intermediate(msg, &signers, nonces);
let mut bad_party_keys = Vec::new();
let mut bad_party_sigs = Vec::new();
let aggregate_public_key = self.poly[0];
Expand All @@ -384,6 +395,7 @@ impl Aggregator {
}
_ => aggregate_public_key,
};
let (Rs, R) = compute::intermediate(msg, aggregate_public_key, &signers, nonces);
let c = compute::challenge(&tweaked_public_key, &R, msg);
let mut r_sign = Scalar::one();
let mut cx_sign = Scalar::one();
Expand Down Expand Up @@ -689,12 +701,13 @@ impl traits::Signer for Signer {
}

fn compute_intermediate(
&self,
msg: &[u8],
_signer_ids: &[u32],
key_ids: &[u32],
nonces: &[PublicNonce],
) -> (Vec<Point>, Point) {
compute::intermediate(msg, key_ids, nonces)
compute::intermediate(msg, self.group_key, key_ids, nonces)
}

fn validate_party_id(
Expand All @@ -715,7 +728,8 @@ impl traits::Signer for Signer {
key_ids: &[u32],
nonces: &[PublicNonce],
) -> Vec<SignatureShare> {
let aggregate_nonce = compute::aggregate_nonce(msg, key_ids, nonces).unwrap();
let aggregate_nonce =
compute::aggregate_nonce(msg, self.group_key, key_ids, nonces).unwrap();
self.parties
.iter()
.map(|p| p.sign_precomputed(msg, key_ids, nonces, &aggregate_nonce))
Expand All @@ -730,7 +744,8 @@ impl traits::Signer for Signer {
nonces: &[PublicNonce],
merkle_root: Option<[u8; 32]>,
) -> Vec<SignatureShare> {
let aggregate_nonce = compute::aggregate_nonce(msg, key_ids, nonces).unwrap();
let aggregate_nonce =
compute::aggregate_nonce(msg, self.group_key, key_ids, nonces).unwrap();
let tweak = compute::tweak(&self.parties[0].group_key, merkle_root);
self.parties
.iter()
Expand All @@ -747,7 +762,8 @@ impl traits::Signer for Signer {
key_ids: &[u32],
nonces: &[PublicNonce],
) -> Vec<SignatureShare> {
let aggregate_nonce = compute::aggregate_nonce(msg, key_ids, nonces).unwrap();
let aggregate_nonce =
compute::aggregate_nonce(msg, self.group_key, key_ids, nonces).unwrap();
self.parties
.iter()
.map(|p| {
Expand Down
18 changes: 12 additions & 6 deletions src/v2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -256,13 +256,18 @@ impl Party {
} else {
self.group_key
};
let (_, R) = compute::intermediate(msg, party_ids, nonces);
let (_, R) = compute::intermediate(msg, self.group_key, party_ids, nonces);
let c = compute::challenge(&tweaked_public_key, &R, msg);
let mut r = &self.nonce.d + &self.nonce.e * compute::binding(&self.id(), nonces, msg);
let commitment_list: Vec<(Scalar, PublicNonce)> = party_ids
.iter()
.zip(nonces)
.map(|(id, nonce)| (Scalar::from(*id), nonce.clone()))
.collect();
let mut r = &self.nonce.d
+ &self.nonce.e * compute::binding(&self.id(), self.group_key, &commitment_list, msg);
if tweak.is_some() && !R.has_even_y() {
r = -r;
}

let mut cx = Scalar::zero();
for key_id in self.key_ids.iter() {
cx += c * &self.private_keys[key_id] * compute::lambda(*key_id, key_ids);
Expand Down Expand Up @@ -310,7 +315,7 @@ impl Aggregator {
}

let party_ids: Vec<u32> = sig_shares.iter().map(|ss| ss.id).collect();
let (_Rs, R) = compute::intermediate(msg, &party_ids, nonces);
let (_Rs, R) = compute::intermediate(msg, self.poly[0], &party_ids, nonces);
let mut z = Scalar::zero();
let mut cx_sign = Scalar::one();
let aggregate_public_key = self.poly[0];
Expand Down Expand Up @@ -361,7 +366,7 @@ impl Aggregator {
}

let party_ids: Vec<u32> = sig_shares.iter().map(|ss| ss.id).collect();
let (Rs, R) = compute::intermediate(msg, &party_ids, nonces);
let (Rs, R) = compute::intermediate(msg, self.poly[0], &party_ids, nonces);
let mut bad_party_keys = Vec::new();
let mut bad_party_sigs = Vec::new();
let aggregate_public_key = self.poly[0];
Expand Down Expand Up @@ -637,12 +642,13 @@ impl traits::Signer for Party {
}

fn compute_intermediate(
&self,
msg: &[u8],
signer_ids: &[u32],
_key_ids: &[u32],
nonces: &[PublicNonce],
) -> (Vec<Point>, Point) {
compute::intermediate(msg, signer_ids, nonces)
compute::intermediate(msg, self.group_key, signer_ids, nonces)
}

fn validate_party_id(
Expand Down
Loading