diff --git a/mls-rs-core/Cargo.toml b/mls-rs-core/Cargo.toml index 7e21eab8..b10a33e2 100644 --- a/mls-rs-core/Cargo.toml +++ b/mls-rs-core/Cargo.toml @@ -20,6 +20,7 @@ ffi = ["dep:safer-ffi", "dep:safer-ffi-gen"] x509 = [] test_suite = ["serde", "dep:serde_json", "dep:itertools"] serde = ["dep:serde", "zeroize/serde", "hex/serde", "dep:serde_bytes"] +replace_proposal = [] [dependencies] mls-rs-codec = { version = "0.5.2", path = "../mls-rs-codec", default-features = false} diff --git a/mls-rs-core/src/extension.rs b/mls-rs-core/src/extension.rs index 77f87215..39402f3d 100644 --- a/mls-rs-core/src/extension.rs +++ b/mls-rs-core/src/extension.rs @@ -34,6 +34,12 @@ impl ExtensionType { pub const EXTERNAL_PUB: ExtensionType = ExtensionType(4); pub const EXTERNAL_SENDERS: ExtensionType = ExtensionType(5); + // XXX(RLB): This value is chosen from the vendor range, since this is an experimental + // implementation that is only intended to work between `mls-rs` instances. Once a real value + // is assigned by IANA, this code will need to be updated. + #[cfg(feature = "replace_proposal")] + pub const LEAF_NODE_EPOCH: ExtensionType = ExtensionType(0xFF01); + /// Default extension types defined /// in [RFC 9420](https://www.rfc-editor.org/rfc/rfc9420.html#name-leaf-node-contents) pub const DEFAULT: &'static [ExtensionType] = &[ diff --git a/mls-rs-core/src/group/proposal_type.rs b/mls-rs-core/src/group/proposal_type.rs index 976cfaee..63ce9711 100644 --- a/mls-rs-core/src/group/proposal_type.rs +++ b/mls-rs-core/src/group/proposal_type.rs @@ -57,6 +57,9 @@ impl ProposalType { pub const EXTERNAL_INIT: ProposalType = ProposalType(6); pub const GROUP_CONTEXT_EXTENSIONS: ProposalType = ProposalType(7); + #[cfg(feature = "replace_proposal")] + pub const REPLACE: ProposalType = ProposalType(0xff01); + /// Default proposal types defined /// in [RFC 9420](https://www.rfc-editor.org/rfc/rfc9420.html#name-leaf-node-contents) pub const DEFAULT: &'static [ProposalType] = &[ diff --git a/mls-rs/Cargo.toml b/mls-rs/Cargo.toml index 8683fc4b..6799c5c9 100644 --- a/mls-rs/Cargo.toml +++ b/mls-rs/Cargo.toml @@ -33,6 +33,7 @@ by_ref_proposal = [] psk = [] x509 = ["mls-rs-core/x509", "dep:mls-rs-identity-x509"] rfc_compliant = ["state_update", "private_message", "custom_proposal", "out_of_order", "psk", "x509", "prior_epoch", "by_ref_proposal", "mls-rs-core/rfc_compliant"] +replace_proposal = ["mls-rs-core/replace_proposal", "by_ref_proposal"] std = ["mls-rs-core/std", "mls-rs-codec/std", "mls-rs-identity-x509?/std", "hex/std", "futures/std", "itertools/use_std", "safer-ffi-gen?/std", "zeroize/std", "dep:debug_tree", "dep:thiserror", "serde?/std"] diff --git a/mls-rs/examples/basic_usage.rs b/mls-rs/examples/basic_usage.rs index c49af8f1..840e50e5 100644 --- a/mls-rs/examples/basic_usage.rs +++ b/mls-rs/examples/basic_usage.rs @@ -9,7 +9,7 @@ use mls_rs::{ basic::{BasicCredential, BasicIdentityProvider}, SigningIdentity, }, - CipherSuite, CipherSuiteProvider, Client, CryptoProvider, ExtensionList, + CipherSuite, CipherSuiteProvider, Client, CryptoProvider, ExtensionList, Group, }; const CIPHERSUITE: CipherSuite = CipherSuite::CURVE25519_AES128; @@ -61,8 +61,19 @@ fn main() -> Result<(), MlsError> { alice_group.apply_pending_commit()?; // Bob joins the group with the welcome message created as part of Alice's commit. - let (mut bob_group, _) = bob.join_group(None, &alice_commit.welcome_messages[0])?; + let (bob_group, _) = bob.join_group(None, &alice_commit.welcome_messages[0])?; + #[cfg(feature = "private_message")] + encrypt_decrypt(alice_group, bob_group)?; + + Ok(()) +} + +#[cfg(feature = "private_message")] +fn encrypt_decrypt( + mut alice_group: Group, + mut bob_group: Group, +) -> Result<(), MlsError> { // Alice encrypts an application message to Bob. let msg = alice_group.encrypt_application_message(b"hello world", Default::default())?; diff --git a/mls-rs/src/client.rs b/mls-rs/src/client.rs index a09b0756..956601c9 100644 --- a/mls-rs/src/client.rs +++ b/mls-rs/src/client.rs @@ -1045,10 +1045,16 @@ mod tests { #[test] fn builder_can_be_obtained_from_client_to_edit_properties_for_new_client() { + #[cfg(not(feature = "replace_proposal"))] + let expected_extensions = [33, 34].map(Into::into); + + #[cfg(feature = "replace_proposal")] + let expected_extensions = [0xff01, 33, 34].map(Into::into); + let alice = TestClientBuilder::new_for_test() .extension_type(33.into()) .build(); let bob = alice.to_builder().extension_type(34.into()).build(); - assert_eq!(bob.config.supported_extensions(), [33, 34].map(Into::into)); + assert_eq!(bob.config.supported_extensions(), expected_extensions); } } diff --git a/mls-rs/src/client_builder.rs b/mls-rs/src/client_builder.rs index 186c4369..75512190 100644 --- a/mls-rs/src/client_builder.rs +++ b/mls-rs/src/client_builder.rs @@ -29,6 +29,9 @@ use crate::{ #[cfg(feature = "std")] use crate::time::MlsTime; +#[cfg(feature = "replace_proposal")] +use alloc::vec; + use alloc::vec::Vec; #[cfg(feature = "sqlite")] @@ -879,8 +882,14 @@ pub(crate) struct Settings { impl Default for Settings { fn default() -> Self { + #[cfg(not(feature = "replace_proposal"))] + let extension_types = Default::default(); + + #[cfg(feature = "replace_proposal")] + let extension_types = vec![ExtensionType::LEAF_NODE_EPOCH]; + Self { - extension_types: Default::default(), + extension_types, protocol_versions: Default::default(), key_package_extensions: Default::default(), leaf_node_extensions: Default::default(), diff --git a/mls-rs/src/extension/built_in.rs b/mls-rs/src/extension/built_in.rs index 361a1125..0e933ae4 100644 --- a/mls-rs/src/extension/built_in.rs +++ b/mls-rs/src/extension/built_in.rs @@ -235,6 +235,35 @@ impl MlsCodecExtension for ExternalSendersExt { } } +/// Mark a LeafNode with the epoch in which it was created. Note that a LeafNode with +/// `leaf_node_source` set to `update` or `commit` is already bound to the `group_id` for the group +/// by way of the `LeafNodeTBS` signed structure. +#[cfg(feature = "replace_proposal")] +#[cfg_attr( + all(feature = "ffi", not(test)), + safer_ffi_gen::ffi_type(clone, opaque) +)] +#[derive(Clone, Debug, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)] +#[non_exhaustive] +pub struct LeafNodeEpochExt { + pub epoch: u64, +} + +#[cfg(feature = "replace_proposal")] +#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)] +impl LeafNodeEpochExt { + pub fn new(epoch: u64) -> Self { + Self { epoch } + } +} + +#[cfg(feature = "replace_proposal")] +impl MlsCodecExtension for LeafNodeEpochExt { + fn extension_type() -> ExtensionType { + ExtensionType::LEAF_NODE_EPOCH + } +} + #[cfg(test)] mod tests { use super::*; @@ -327,4 +356,18 @@ mod tests { let restored = ExternalPubExt::from_extension(&as_extension).unwrap(); assert_eq!(ext, restored) } + + #[cfg(feature = "replace_proposal")] + #[test] + fn test_leaf_node_epoch() { + let ext = LeafNodeEpochExt { + epoch: 0x01234567890abcdef, + }; + + let as_extension = ext.clone().into_extension().unwrap(); + assert_eq!(as_extension.extension_type, ExtensionType::LEAF_NODE_EPOCH); + + let restored = LeafNodeEpochExt::from_extension(&as_extension).unwrap(); + assert_eq!(ext, restored) + } } diff --git a/mls-rs/src/group/commit.rs b/mls-rs/src/group/commit.rs index 093dd979..8dbd2cab 100644 --- a/mls-rs/src/group/commit.rs +++ b/mls-rs/src/group/commit.rs @@ -197,6 +197,19 @@ where Ok(self) } + /// Insert a [`ReplaceProposal`](crate::group::proposal::ReplaceProposal) into + /// the current commit that is being built. + #[cfg(feature = "replace_proposal")] + pub fn replace_member( + mut self, + to_replace: u32, + leaf_node: LeafNode, + ) -> Result { + let proposal = self.group.replace_proposal(to_replace, leaf_node)?; + self.proposals.push(proposal); + Ok(self) + } + /// Insert a /// [`GroupContextExtensions`](crate::group::proposal::Proposal::GroupContextExtensions) /// into the current commit that is being built. @@ -1017,6 +1030,54 @@ mod tests { assert_commit_builder_output(group, commit_output, vec![expected_remove], 0); } + #[cfg(feature = "replace_proposal")] + #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] + async fn test_commit_builder_replace() -> Result<(), MlsError> { + let mut group = test_group_custom_config(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, |b| { + b.custom_proposal_type(ProposalType::REPLACE) + }) + .await + .group; + + let (alice, alice_kp) = + test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "alice").await; + + // Add Alice to the group + let output = group + .commit_builder() + .add_member(alice_kp.clone()) + .unwrap() + .build() + .await?; + + group.apply_pending_commit().await?; + + // Alice creates an Update proposal, including a fresh LeafNode + let (mut alice_group, _) = alice.join_group(None, &output.welcome_messages[0])?; + let proposal = alice_group.update_proposal(None, None)?; + let Proposal::Update(update) = proposal else { + panic!("non update proposal found") + }; + + // The committer replaces Alice's appearance in the group + let commit_output = group + .commit_builder() + .replace_member(1, update.leaf_node.clone())? + .build() + .await?; + + let expected_replace = group.replace_proposal(1, update.leaf_node)?; + + assert_commit_builder_output( + group.clone(), + commit_output.clone(), + vec![expected_replace], + 0, + ); + + Ok(()) + } + #[cfg(feature = "psk")] #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] async fn test_commit_builder_psk() { diff --git a/mls-rs/src/group/mod.rs b/mls-rs/src/group/mod.rs index 4a32017e..664a619d 100644 --- a/mls-rs/src/group/mod.rs +++ b/mls-rs/src/group/mod.rs @@ -35,7 +35,13 @@ use crate::tree_kem::{TreeKemPrivate, TreeKemPublic}; use crate::{CipherSuiteProvider, CryptoProvider}; #[cfg(feature = "by_ref_proposal")] -use crate::crypto::{HpkePublicKey, HpkeSecretKey}; +use crate::{ + crypto::{HpkePublicKey, HpkeSecretKey}, + map::SmallMap, +}; + +#[cfg(feature = "replace_proposal")] +use crate::extension::LeafNodeEpochExt; use crate::extension::ExternalPubExt; @@ -241,6 +247,15 @@ impl NewMemberInfo { } } +#[cfg(feature = "by_ref_proposal")] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[derive(Clone, Debug, PartialEq, MlsEncode, MlsDecode, MlsSize)] +struct PendingUpdate { + epoch: u64, + secret_key: HpkeSecretKey, + signer: Option, +} + /// An MLS end-to-end encrypted group. /// /// # Group Evolution @@ -265,8 +280,7 @@ where private_tree: TreeKemPrivate, key_schedule: KeySchedule, #[cfg(feature = "by_ref_proposal")] - pending_updates: - crate::map::SmallMap)>, // Hash of leaf node hpke public key to secret key + pending_updates: SmallMap, // Hash of leaf node hpke public key to secret key pending_commit: Option, #[cfg(feature = "psk")] previous_psk: Option, @@ -745,34 +759,50 @@ where } } - // Apply own update + // Apply own update or a Replace proposal replacing us let new_signer = None; #[cfg(feature = "by_ref_proposal")] let mut new_signer = new_signer; #[cfg(feature = "by_ref_proposal")] - for p in &provisional_state.applied_proposals.updates { - if p.sender == Sender::Member(*self_index) { - let leaf_pk = &p.proposal.leaf_node.public_key; + { + let updated_leaves = core::iter::empty(); - // Update the leaf in the private tree if this is our update - #[cfg(feature = "std")] - let new_leaf_sk_and_signer = self.pending_updates.get(leaf_pk); + #[cfg(feature = "by_ref_proposal")] + let updated_leaves = updated_leaves.chain( + provisional_state + .applied_proposals + .update_senders + .iter() + .zip(provisional_state.applied_proposals.updates.iter()) + .filter(|(&i, _p)| i.0 == *self_index) + .map(|(_i, p)| p.proposal.leaf_node.public_key.clone()), + ); - #[cfg(not(feature = "std"))] - let new_leaf_sk_and_signer = self - .pending_updates + #[cfg(feature = "replace_proposal")] + let updated_leaves = updated_leaves.chain( + provisional_state + .applied_proposals + .replaces .iter() - .find_map(|(pk, sk)| (pk == leaf_pk).then_some(sk)); + .filter(|p| p.proposal.to_replace.0 == *self_index) + .map(|p| p.proposal.leaf_node.public_key.clone()), + ); + + // The duplicate update filtering above assures that there is at most one self-update. + let mut updated_leaves = updated_leaves; + let self_update = updated_leaves.next(); + + if let Some(leaf_pk) = self_update { + // Update the leaf in the private tree if this is our update + let pending_update = self.pending_updates.get(&leaf_pk); - let new_leaf_sk = new_leaf_sk_and_signer.map(|(sk, _)| sk.clone()); - new_signer = new_leaf_sk_and_signer.and_then(|(_, sk)| sk.clone()); + let new_leaf_sk = pending_update.map(|upd| upd.secret_key.clone()); + new_signer = pending_update.and_then(|upd| upd.signer.clone()); provisional_private_tree .update_leaf(new_leaf_sk.ok_or(MlsError::UpdateErrorNoSecretKey)?); - - break; } } @@ -909,33 +939,78 @@ where signing_identity: Option, ) -> Result { // Grab a copy of the current node and update it to have new key material - let mut new_leaf_node = self.current_user_leaf_node()?.clone(); + let mut leaf_node = self.current_user_leaf_node()?.clone(); + let properties = self.config.leaf_properties(); + let epoch = self.current_epoch(); + + #[cfg(feature = "replace_proposal")] + let mut properties = properties; - let secret_key = new_leaf_node + #[cfg(feature = "replace_proposal")] + properties + .extensions + .set_from(LeafNodeEpochExt::new(epoch))?; + + let secret_key = leaf_node .update( &self.cipher_suite_provider, self.group_id(), self.current_member_index(), - self.config.leaf_properties(), + properties, signing_identity, signer.as_ref().unwrap_or(&self.signer), ) .await?; + let pending_update = PendingUpdate { + epoch, + secret_key, + signer, + }; + // Store the secret key in the pending updates storage for later - #[cfg(feature = "std")] self.pending_updates - .insert(new_leaf_node.public_key.clone(), (secret_key, signer)); + .insert(leaf_node.public_key.clone(), pending_update); - #[cfg(not(feature = "std"))] - self.pending_updates - .push((new_leaf_node.public_key.clone(), (secret_key, signer))); + Ok(Proposal::Update(UpdateProposal { leaf_node })) + } + + /// Create a proposal message that replaces another member. + /// + /// `authenticated_data` will be sent unencrypted along with the contents + /// of the proposal message. + #[cfg(feature = "replace_proposal")] + #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] + pub async fn propose_replace( + &mut self, + to_replace: u32, + leaf_node: LeafNode, + authenticated_data: Vec, + ) -> Result { + let proposal = self.replace_proposal(to_replace, leaf_node).await?; + self.proposal_message(proposal, authenticated_data).await + } - Ok(Proposal::Update(UpdateProposal { - leaf_node: new_leaf_node, + #[cfg(feature = "replace_proposal")] + #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] + pub async fn replace_proposal( + &mut self, + to_replace: u32, + leaf_node: LeafNode, + ) -> Result { + Ok(Proposal::Replace(ReplaceProposal { + to_replace: LeafIndex(to_replace), + leaf_node, })) } + /// Abandon any cached state corresponding to a leaf node + #[cfg(feature = "replace_proposal")] + #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] + pub async fn abandon_leaf_node(&mut self, leaf_node: &LeafNode) { + self.pending_updates.remove(&leaf_node.public_key); + } + /// Create a proposal message that removes an existing member from the /// group. /// @@ -1801,12 +1876,33 @@ where #[cfg(feature = "by_ref_proposal")] self.state.proposals.clear(); - // Clear the pending updates list - #[cfg(feature = "by_ref_proposal")] + // If the only way our leaf can change is via Update, just clear the cache + #[cfg(all(feature = "by_ref_proposal", not(feature = "replace_proposal")))] { self.pending_updates = Default::default(); } + // If Updates can span epochs via Replace, only clear out leaf nodes that cannot possibly + // be used again, namely those that have been committed or those with an earlier epoch than + // the current leaf. + #[cfg(feature = "replace_proposal")] + { + // Delete any cached state for the current public key + let current_leaf_pk = self.current_user_leaf_node()?.public_key.clone(); + self.pending_updates.remove(¤t_leaf_pk); + + // If the current leaf node contains an epoch value, delete any cached state for + // updates from prior epochs. + let epoch_ext = self + .current_user_leaf_node()? + .extensions + .get_as::()?; + if let Some(epoch_ext) = epoch_ext { + self.pending_updates + .retain(|_pk, upd| upd.epoch >= epoch_ext.epoch); + } + } + self.pending_commit = None; Ok(()) @@ -2017,6 +2113,20 @@ mod tests { let mut extension_list = ExtensionList::default(); extension_list.set_from(new_extension).unwrap(); + // If we don't push here, then we get "no need to be `mut`" warnings when replace_proposal + // isn't enabled. + let extensions: Vec = vec![42.into()]; + + #[cfg(feature = "replace_proposal")] + let mut extensions = extensions; + + #[cfg(feature = "replace_proposal")] + { + let epoch_extension = LeafNodeEpochExt::new(0); + extension_list.set_from(epoch_extension).unwrap(); + extensions.push(ExtensionType::LEAF_NODE_EPOCH); + } + let mut test_group = test_group_custom( TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, @@ -2047,7 +2157,7 @@ mod tests { assert_eq!( update.leaf_node.ungreased_capabilities().sorted(), Capabilities { - extensions: vec![42.into()], + extensions, ..get_test_capabilities() } .sorted() @@ -2139,6 +2249,126 @@ mod tests { ); } + #[cfg(feature = "replace_proposal")] + #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] + async fn test_replace_proposals() { + let mut alice_group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await; + let (mut bob_group, _) = alice_group.join("bob").await; + + // Create a replace proposal + let bob_new_leaf = match bob_group.update_proposal() { + Proposal::Update(update) => update.leaf_node, + _ => panic!("non update proposal found"), + }; + + let proposal = alice_group.replace_proposal(1, bob_new_leaf.clone()).await; + + let replace = match proposal.clone() { + Proposal::Replace(replace) => replace, + _ => panic!("non replace proposal found"), + }; + + assert_eq!(replace.to_replace, LeafIndex(1)); + assert_eq!(replace.leaf_node, bob_new_leaf); + + // Commit the replace and verify that Bob was replaced + let commit_output = alice_group + .group + .commit_builder() + .raw_proposal(proposal) + .build() + .unwrap(); + alice_group.process_pending_commit().unwrap(); + bob_group + .process_message(commit_output.commit_message) + .unwrap(); + + let alice_new_leaf_for_bob = alice_group + .group + .current_epoch_tree() + .get_leaf_node(LeafIndex(1)) + .unwrap(); + assert_eq!(*alice_new_leaf_for_bob, bob_new_leaf); + + let bob_new_leaf_for_bob = bob_group.group.current_user_leaf_node().unwrap(); + assert_eq!(*bob_new_leaf_for_bob, bob_new_leaf); + } + + #[cfg(feature = "replace_proposal")] + #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] + async fn test_replace_proposals_across_epochs() { + let mut alice_group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await; + let (mut bob_group, _) = alice_group.join("bob").await; + + // Bob produces an Update including a new LeafNode + let bob_new_leaf = match bob_group.update_proposal() { + Proposal::Update(update) => update.leaf_node, + _ => panic!("non update proposal found"), + }; + + // Alice commits without replacing Bob + let commit_output = alice_group.group.commit_builder().build().unwrap(); + alice_group.process_pending_commit().unwrap(); + bob_group + .process_message(commit_output.commit_message) + .unwrap(); + + // Alice commits a Replace proposal for Bob + let proposal = alice_group.replace_proposal(1, bob_new_leaf.clone()).await; + let commit_output = alice_group + .group + .commit_builder() + .raw_proposal(proposal) + .build() + .unwrap(); + alice_group.process_pending_commit().unwrap(); + bob_group + .process_message(commit_output.commit_message) + .unwrap(); + + // Check that Bob has been replaced by his new appearance + let alice_new_leaf_for_bob = alice_group + .group + .current_epoch_tree() + .get_leaf_node(LeafIndex(1)) + .unwrap(); + assert_eq!(*alice_new_leaf_for_bob, bob_new_leaf); + + let bob_new_leaf_for_bob = bob_group.group.current_user_leaf_node().unwrap(); + assert_eq!(*bob_new_leaf_for_bob, bob_new_leaf); + } + + #[cfg(feature = "replace_proposal")] + #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] + async fn test_replace_proposal_abandon() { + let mut alice_group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await; + let (mut bob_group, _) = alice_group.join("bob").await; + + // Bob produces an Update including a new LeafNode + let bob_new_leaf = match bob_group.update_proposal() { + Proposal::Update(update) => update.leaf_node, + _ => panic!("non update proposal found"), + }; + + // Bob abandons the LeafNode, deleting its cached private state + bob_group.group.abandon_leaf_node(&bob_new_leaf); + + // Alice commits a Replace proposal for Bob + let proposal = alice_group.replace_proposal(1, bob_new_leaf.clone()).await; + let commit_output = alice_group + .group + .commit_builder() + .raw_proposal(proposal) + .build() + .unwrap(); + + alice_group.process_pending_commit().unwrap(); + + // Bob should fail to process the commit because of the missing state + let res = bob_group.process_message(commit_output.commit_message); + assert_matches!(res, Err(MlsError::UpdateErrorNoSecretKey)) + } + #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] async fn test_two_member_group( protocol_version: ProtocolVersion, diff --git a/mls-rs/src/group/proposal.rs b/mls-rs/src/group/proposal.rs index 1aab618f..5b82a31a 100644 --- a/mls-rs/src/group/proposal.rs +++ b/mls-rs/src/group/proposal.rs @@ -112,6 +112,38 @@ impl UpdateProposal { } } +#[cfg(feature = "replace_proposal")] +#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)] +#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +/// A proposal that will replace an existing [`Member`](mls_rs_core::group::Member) of a +/// [`Group`](crate::group::Group). +pub struct ReplaceProposal { + pub(crate) to_replace: LeafIndex, + pub(crate) leaf_node: LeafNode, +} + +#[cfg(feature = "replace_proposal")] +impl ReplaceProposal { + /// The new [`SigningIdentity`] of the [`Member`](mls_rs_core::group::Member) + /// that is being updated by this proposal. + pub fn signing_identity(&self) -> &SigningIdentity { + &self.leaf_node.signing_identity + } + + /// New Client [`Capabilities`] of the [`Member`](mls_rs_core::group::Member) + /// that will be updated by this proposal. + pub fn capabilities(&self) -> Capabilities { + self.leaf_node.ungreased_capabilities() + } + + /// New Leaf node extensions that will be entered into the group state for the + /// [`Member`](mls_rs_core::group::Member) that is being updated by this proposal. + pub fn leaf_node_extensions(&self) -> ExtensionList { + self.leaf_node.ungreased_extensions() + } +} + #[derive(Clone, Debug, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)] #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] @@ -326,6 +358,8 @@ pub enum Proposal { Add(alloc::boxed::Box), #[cfg(feature = "by_ref_proposal")] Update(UpdateProposal), + #[cfg(feature = "replace_proposal")] + Replace(ReplaceProposal), Remove(RemoveProposal), #[cfg(feature = "psk")] Psk(PreSharedKeyProposal), @@ -342,6 +376,8 @@ impl MlsSize for Proposal { Proposal::Add(p) => p.mls_encoded_len(), #[cfg(feature = "by_ref_proposal")] Proposal::Update(p) => p.mls_encoded_len(), + #[cfg(feature = "replace_proposal")] + Proposal::Replace(p) => p.mls_encoded_len(), Proposal::Remove(p) => p.mls_encoded_len(), #[cfg(feature = "psk")] Proposal::Psk(p) => p.mls_encoded_len(), @@ -364,6 +400,8 @@ impl MlsEncode for Proposal { Proposal::Add(p) => p.mls_encode(writer), #[cfg(feature = "by_ref_proposal")] Proposal::Update(p) => p.mls_encode(writer), + #[cfg(feature = "replace_proposal")] + Proposal::Replace(p) => p.mls_encode(writer), Proposal::Remove(p) => p.mls_encode(writer), #[cfg(feature = "psk")] Proposal::Psk(p) => p.mls_encode(writer), @@ -397,6 +435,8 @@ impl MlsDecode for Proposal { } #[cfg(feature = "by_ref_proposal")] ProposalType::UPDATE => Proposal::Update(UpdateProposal::mls_decode(reader)?), + #[cfg(feature = "replace_proposal")] + ProposalType::REPLACE => Proposal::Replace(ReplaceProposal::mls_decode(reader)?), ProposalType::REMOVE => Proposal::Remove(RemoveProposal::mls_decode(reader)?), #[cfg(feature = "psk")] ProposalType::PSK => Proposal::Psk(PreSharedKeyProposal::mls_decode(reader)?), @@ -425,6 +465,8 @@ impl Proposal { Proposal::Add(_) => ProposalType::ADD, #[cfg(feature = "by_ref_proposal")] Proposal::Update(_) => ProposalType::UPDATE, + #[cfg(feature = "replace_proposal")] + Proposal::Replace(_) => ProposalType::REPLACE, Proposal::Remove(_) => ProposalType::REMOVE, #[cfg(feature = "psk")] Proposal::Psk(_) => ProposalType::PSK, @@ -443,6 +485,8 @@ pub enum BorrowedProposal<'a> { Add(&'a AddProposal), #[cfg(feature = "by_ref_proposal")] Update(&'a UpdateProposal), + #[cfg(feature = "replace_proposal")] + Replace(&'a ReplaceProposal), Remove(&'a RemoveProposal), #[cfg(feature = "psk")] Psk(&'a PreSharedKeyProposal), @@ -459,6 +503,8 @@ impl<'a> From> for Proposal { BorrowedProposal::Add(add) => Proposal::Add(alloc::boxed::Box::new(add.clone())), #[cfg(feature = "by_ref_proposal")] BorrowedProposal::Update(update) => Proposal::Update(update.clone()), + #[cfg(feature = "replace_proposal")] + BorrowedProposal::Replace(replace) => Proposal::Replace(replace.clone()), BorrowedProposal::Remove(remove) => Proposal::Remove(remove.clone()), #[cfg(feature = "psk")] BorrowedProposal::Psk(psk) => Proposal::Psk(psk.clone()), @@ -479,6 +525,8 @@ impl BorrowedProposal<'_> { BorrowedProposal::Add(_) => ProposalType::ADD, #[cfg(feature = "by_ref_proposal")] BorrowedProposal::Update(_) => ProposalType::UPDATE, + #[cfg(feature = "replace_proposal")] + BorrowedProposal::Replace(_) => ProposalType::REPLACE, BorrowedProposal::Remove(_) => ProposalType::REMOVE, #[cfg(feature = "psk")] BorrowedProposal::Psk(_) => ProposalType::PSK, @@ -497,6 +545,8 @@ impl<'a> From<&'a Proposal> for BorrowedProposal<'a> { Proposal::Add(p) => BorrowedProposal::Add(p), #[cfg(feature = "by_ref_proposal")] Proposal::Update(p) => BorrowedProposal::Update(p), + #[cfg(feature = "replace_proposal")] + Proposal::Replace(p) => BorrowedProposal::Replace(p), Proposal::Remove(p) => BorrowedProposal::Remove(p), #[cfg(feature = "psk")] Proposal::Psk(p) => BorrowedProposal::Psk(p), @@ -522,6 +572,13 @@ impl<'a> From<&'a UpdateProposal> for BorrowedProposal<'a> { } } +#[cfg(feature = "replace_proposal")] +impl<'a> From<&'a ReplaceProposal> for BorrowedProposal<'a> { + fn from(p: &'a ReplaceProposal) -> Self { + Self::Replace(p) + } +} + impl<'a> From<&'a RemoveProposal> for BorrowedProposal<'a> { fn from(p: &'a RemoveProposal) -> Self { Self::Remove(p) diff --git a/mls-rs/src/group/proposal_cache.rs b/mls-rs/src/group/proposal_cache.rs index 7e714e45..d388c6fb 100644 --- a/mls-rs/src/group/proposal_cache.rs +++ b/mls-rs/src/group/proposal_cache.rs @@ -709,6 +709,9 @@ mod tests { use crate::extension::RequiredCapabilitiesExt; + #[cfg(feature = "replace_proposal")] + use crate::group::{LeafNodeEpochExt, ReplaceProposal}; + #[cfg(feature = "by_ref_proposal")] use crate::{ extension::ExternalSendersExt, @@ -786,15 +789,42 @@ mod tests { .unwrap()[0] } + #[cfg(feature = "replace_proposal")] + #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] + async fn update_member(tree: &mut TreeKemPublic, leaf_index: u32, leaf_node: LeafNode) { + tree.update_leaf( + leaf_index, + leaf_node, + &BasicIdentityProvider, + &test_cipher_suite_provider(TEST_CIPHER_SUITE), + ) + .await + .unwrap(); + } + + #[allow(unused_variables)] // `epoch` is only used conditionally #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] - async fn update_leaf_node(name: &str, leaf_index: u32) -> LeafNode { + async fn update_leaf_node(name: &str, leaf_index: u32, epoch: Option) -> LeafNode { let (mut leaf, _, signer) = get_basic_test_node_sig_key(TEST_CIPHER_SUITE, name).await; + let properties = default_properties(); + + #[cfg(feature = "replace_proposal")] + let mut properties = properties; + + #[cfg(feature = "replace_proposal")] + if let Some(epoch) = epoch { + properties + .extensions + .set_from(LeafNodeEpochExt::new(epoch)) + .unwrap(); + } + leaf.update( &test_cipher_suite_provider(TEST_CIPHER_SUITE), TEST_GROUP, leaf_index, - default_properties(), + properties, None, &signer, ) @@ -2558,6 +2588,29 @@ mod tests { assert_matches!(res, Err(MlsError::InvalidCommitSelfUpdate)); } + #[cfg(feature = "replace_proposal")] + #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] + async fn receiving_replace_for_committer_fails() { + let (alice, mut tree) = new_tree("alice").await; + let bob = add_member(&mut tree, "bob").await; + + let replace = Proposal::Replace(make_replace_proposal(0, "alice")); + let replace_ref = make_proposal_ref(&replace, bob).await; + + let res = CommitReceiver::new( + &tree, + alice, + alice, + test_cipher_suite_provider(TEST_CIPHER_SUITE), + ) + .cache(replace_ref.clone(), replace, bob) + .receive([replace_ref]) + .await; + + // XXX(RLB): Should this use a different error code? + assert_matches!(res, Err(MlsError::InvalidCommitSelfUpdate)); + } + #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] async fn sending_additional_update_for_committer_fails() { let (alice, tree) = new_tree("alice").await; @@ -2593,6 +2646,127 @@ mod tests { assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]); } + #[cfg(feature = "replace_proposal")] + #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] + async fn sending_replace_for_committer_filters_it_out() { + let (alice, mut tree) = new_tree("alice").await; + let bob = add_member(&mut tree, "bob").await; + + let proposal = Proposal::Replace(make_replace_proposal(0, "alice")); + let proposal_info = make_proposal_info(&proposal, bob).await; + + let processed_proposals = + CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE)) + .cache( + proposal_info.proposal_ref().unwrap().clone(), + proposal.clone(), + bob, + ) + .send() + .await + .unwrap(); + + assert_eq!(processed_proposals.0, Vec::new()); + + #[cfg(feature = "state_update")] + assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]); + } + + #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] + async fn sending_multiple_update_sends_only_one() { + let (alice, mut tree) = new_tree("alice").await; + let bob = add_member(&mut tree, "bob").await; + + let update1 = Proposal::Update(make_update_proposal("bob").await); + let update1_info = make_proposal_info(&update1, bob).await; + let update1_ref = update1_info.proposal_ref().unwrap().clone(); + + let update2 = Proposal::Update(make_update_proposal("bob").await); + let update2_info = make_proposal_info(&update2, bob).await; + let update2_ref = update2_info.proposal_ref().unwrap().clone(); + + let processed_proposals = + CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE)) + .cache(update1_ref.clone(), update1.clone(), bob) + .cache(update2_ref.clone(), update2.clone(), bob) + .send() + .await + .unwrap(); + + assert_eq!(processed_proposals.0.len(), 1); + assert_eq!(processed_proposals.1.unused_proposals.len(), 1); + + // Proposals are processed in an unpredictable order, so we can't test that a specific + // proposal was selected. We just check that one was selected and the other was rejected. + let processed1 = processed_proposals.0[0] == update1_ref.into(); + let processed2 = processed_proposals.0[0] == update2_ref.into(); + let unused1 = processed_proposals.1.unused_proposals[0] == update1_info; + let unused2 = processed_proposals.1.unused_proposals[0] == update2_info; + assert!((processed1 && unused2) || (unused1 && processed2)); + } + + #[cfg(feature = "replace_proposal")] + #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] + async fn sending_multiple_replace_sends_only_one() { + let (alice, mut tree) = new_tree("alice").await; + let bob = add_member(&mut tree, "bob").await; + + let replace1 = Proposal::Replace(make_replace_proposal(1, "bob")); + let replace1_info = make_proposal_info(&replace1, bob).await; + let replace1_ref = replace1_info.proposal_ref().unwrap().clone(); + + let replace2 = Proposal::Replace(make_replace_proposal(1, "bob")); + let replace2_info = make_proposal_info(&replace2, bob).await; + let replace2_ref = replace2_info.proposal_ref().unwrap().clone(); + + let processed_proposals = + CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE)) + .cache(replace1_ref.clone(), replace1.clone(), bob) + .cache(replace2_ref.clone(), replace2.clone(), bob) + .send() + .await + .unwrap(); + + assert_eq!(processed_proposals.0.len(), 1); + assert_eq!(processed_proposals.1.unused_proposals.len(), 1); + + // Proposals are processed in an unpredictable order, so we can't test that a specific + // proposal was selected. We just check that one was selected and the other was rejected. + let processed1 = processed_proposals.0[0] == replace1_ref.into(); + let processed2 = processed_proposals.0[0] == replace2_ref.into(); + let unused1 = processed_proposals.1.unused_proposals[0] == replace1_info; + let unused2 = processed_proposals.1.unused_proposals[0] == replace2_info; + assert!((processed1 && unused2) || (unused1 && processed2)); + } + + #[cfg(feature = "replace_proposal")] + #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] + async fn sending_update_and_replace_for_the_same_leaf_filters_the_replace() { + let (alice, mut tree) = new_tree("alice").await; + let bob = add_member(&mut tree, "bob").await; + + let update = Proposal::Update(make_update_proposal("bob").await); + let update_info = make_proposal_info(&update, bob).await; + let update_ref = update_info.proposal_ref().unwrap().clone(); + + let replace = Proposal::Replace(make_replace_proposal(1, "bob")); + let replace_info = make_proposal_info(&replace, alice).await; + let replace_ref = replace_info.proposal_ref().unwrap().clone(); + + let processed_proposals = + CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE)) + .cache(update_ref.clone(), update.clone(), bob) + .cache(replace_ref.clone(), replace.clone(), alice) + .send() + .await + .unwrap(); + + assert_eq!(processed_proposals.0, vec![update_ref.into()]); + + #[cfg(feature = "state_update")] + assert_eq!(processed_proposals.1.unused_proposals, vec![replace_info]); + } + #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] async fn receiving_remove_for_committer_fails() { let (alice, tree) = new_tree("alice").await; @@ -2669,6 +2843,85 @@ mod tests { assert_matches!(res, Err(MlsError::UpdatingNonExistingMember)); } + #[cfg(feature = "replace_proposal")] + #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] + async fn receiving_replace_and_remove_for_same_leaf_fails() { + let (alice, mut tree) = new_tree("alice").await; + let bob = add_member(&mut tree, "bob").await; + + let replace = Proposal::Replace(make_replace_proposal(1, "bob")); + let replace_ref = make_proposal_ref(&replace, alice).await; + + let remove = Proposal::Remove(RemoveProposal { to_remove: bob }); + let remove_ref = make_proposal_ref(&remove, bob).await; + + let res = CommitReceiver::new( + &tree, + alice, + alice, + test_cipher_suite_provider(TEST_CIPHER_SUITE), + ) + .cache(replace_ref.clone(), replace, bob) + .cache(remove_ref.clone(), remove, bob) + .receive([replace_ref, remove_ref]) + .await; + + assert_matches!(res, Err(MlsError::UpdatingNonExistingMember)); + } + + #[cfg(feature = "replace_proposal")] + #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] + async fn receiving_update_and_replace_for_same_leaf_fails() { + let (alice, mut tree) = new_tree("alice").await; + let bob = add_member(&mut tree, "bob").await; + + let update = Proposal::Update(make_update_proposal("bob").await); + let update_ref = make_proposal_ref(&update, bob).await; + + let replace = Proposal::Replace(make_replace_proposal(1, "bob")); + let replace_ref = make_proposal_ref(&replace, alice).await; + + let res = CommitReceiver::new( + &tree, + alice, + alice, + test_cipher_suite_provider(TEST_CIPHER_SUITE), + ) + .cache(update_ref.clone(), update, bob) + .cache(replace_ref.clone(), replace, alice) + .receive([update_ref, replace_ref]) + .await; + + // XXX(RLB): This doesn't seem like the most apt error code. + assert_matches!(res, Err(MlsError::UpdatingNonExistingMember)); + } + + #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] + async fn receiving_multiple_update_for_same_leaf_fails() { + let (alice, mut tree) = new_tree("alice").await; + let bob = add_member(&mut tree, "bob").await; + + let update1 = Proposal::Update(make_update_proposal("bob").await); + let update1_ref = make_proposal_ref(&update1, bob).await; + + let update2 = Proposal::Update(make_update_proposal("bob").await); + let update2_ref = make_proposal_ref(&update2, bob).await; + + let res = CommitReceiver::new( + &tree, + alice, + alice, + test_cipher_suite_provider(TEST_CIPHER_SUITE), + ) + .cache(update1_ref.clone(), update1, bob) + .cache(update2_ref.clone(), update2, bob) + .receive([update1_ref, update2_ref]) + .await; + + // XXX(RLB): This doesn't seem like the most apt error code. + assert_matches!(res, Err(MlsError::UpdatingNonExistingMember)); + } + #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] async fn sending_update_and_remove_for_same_leaf_filters_update_out() { let (alice, mut tree) = new_tree("alice").await; @@ -2698,6 +2951,36 @@ mod tests { assert_eq!(processed_proposals.1.unused_proposals, vec![update_info]); } + #[cfg(feature = "replace_proposal")] + #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] + async fn sending_replace_and_remove_for_same_leaf_filters_replace_out() { + let (alice, mut tree) = new_tree("alice").await; + let bob = add_member(&mut tree, "bob").await; + + let replace = Proposal::Replace(make_replace_proposal(1, "bob")); + let replace_info = make_proposal_info(&replace, alice).await; + + let remove = Proposal::Remove(RemoveProposal { to_remove: bob }); + let remove_ref = make_proposal_ref(&remove, alice).await; + + let processed_proposals = + CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE)) + .cache( + replace_info.proposal_ref().unwrap().clone(), + replace.clone(), + alice, + ) + .cache(remove_ref.clone(), remove, alice) + .send() + .await + .unwrap(); + + assert_eq!(processed_proposals.0, vec![remove_ref.into()]); + + #[cfg(feature = "state_update")] + assert_eq!(processed_proposals.1.unused_proposals, vec![replace_info]); + } + #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] async fn make_add_proposal() -> Box { Box::new(AddProposal { @@ -2815,6 +3098,104 @@ mod tests { assert_eq!(processed_proposals.1.unused_proposals, vec![update_info]); } + #[cfg(feature = "replace_proposal")] + #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] + async fn receiving_replace_for_different_identity_fails() { + let (alice, mut tree) = new_tree("alice").await; + let _bob = add_member(&mut tree, "bob").await; + + let replace = Proposal::Replace(make_replace_proposal(1, "carol")); + let replace_ref = make_proposal_ref(&replace, alice).await; + + let res = CommitReceiver::new( + &tree, + alice, + alice, + test_cipher_suite_provider(TEST_CIPHER_SUITE), + ) + .cache(replace_ref.clone(), replace, alice) + .receive([replace_ref]) + .await; + + assert_matches!(res, Err(MlsError::InvalidSuccessor)); + } + + #[cfg(feature = "replace_proposal")] + #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] + async fn sending_replace_for_different_identity_filters_it_out() { + let (alice, mut tree) = new_tree("alice").await; + let _bob = add_member(&mut tree, "bob").await; + + let replace = Proposal::Replace(make_replace_proposal(1, "carol")); + let replace_info = make_proposal_info(&replace, alice).await; + + let processed_proposals = + CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE)) + .cache(replace_info.proposal_ref().unwrap().clone(), replace, alice) + .send() + .await + .unwrap(); + + assert_eq!(processed_proposals.0, Vec::new()); + + #[cfg(feature = "state_update")] + assert_eq!(processed_proposals.1.unused_proposals, vec![replace_info]); + } + + #[cfg(feature = "replace_proposal")] + #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] + async fn receiving_replace_for_old_epoch_fails() { + let (alice, mut tree) = new_tree("alice").await; + add_member(&mut tree, "bob").await; + + // Fast-forward Bob to epoch 42 + let bob_leaf_new_epoch = update_leaf_node("bob", 1, Some(42)); + update_member(&mut tree, 1, bob_leaf_new_epoch); + + // Try to replace Bob with a LeafNode from epoch 21 + let replace = Proposal::Replace(make_replace_proposal_custom(1, "bob", 21)); + let replace_ref = make_proposal_ref(&replace, alice).await; + + let res = CommitReceiver::new( + &tree, + alice, + alice, + test_cipher_suite_provider(TEST_CIPHER_SUITE), + ) + .cache(replace_ref.clone(), replace, alice) + .receive([replace_ref]) + .await; + + assert_matches!(res, Err(MlsError::InvalidSuccessor)); + } + + #[cfg(feature = "replace_proposal")] + #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] + async fn sending_replace_for_old_epoch_filters_it_out() { + let (alice, mut tree) = new_tree("alice").await; + add_member(&mut tree, "bob").await; + + // Fast-forward Bob to epoch 42 + let bob_leaf_new_epoch = update_leaf_node("bob", 1, Some(42)); + update_member(&mut tree, 1, bob_leaf_new_epoch); + + // Try to replace Bob with a LeafNode from epoch 21 + let replace = Proposal::Replace(make_replace_proposal_custom(1, "bob", 21)); + let replace_info = make_proposal_info(&replace, alice).await; + + let processed_proposals = + CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE)) + .cache(replace_info.proposal_ref().unwrap().clone(), replace, alice) + .send() + .await + .unwrap(); + + assert_eq!(processed_proposals.0, Vec::new()); + + #[cfg(feature = "state_update")] + assert_eq!(processed_proposals.1.unused_proposals, vec![replace_info]); + } + #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] async fn receiving_add_for_same_client_as_existing_member_fails() { let (alice, public_tree) = new_tree("alice").await; @@ -3493,7 +3874,7 @@ mod tests { .await .unwrap(); - let bob_new_leaf = update_leaf_node("bob", 1).await; + let bob_new_leaf = update_leaf_node("bob", 1, None).await; let pk1_to_pk2 = Proposal::Update(UpdateProposal { leaf_node: alice_new_leaf.clone(), @@ -4215,14 +4596,32 @@ mod tests { #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] async fn make_update_proposal(name: &str) -> UpdateProposal { UpdateProposal { - leaf_node: update_leaf_node(name, 1).await, + leaf_node: update_leaf_node(name, 1, None).await, + } + } + + #[cfg(feature = "replace_proposal")] + #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] + async fn make_replace_proposal(index: u32, name: &str) -> ReplaceProposal { + ReplaceProposal { + to_replace: LeafIndex(index), + leaf_node: update_leaf_node(name, index, None).await, + } + } + + #[cfg(feature = "replace_proposal")] + #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] + async fn make_replace_proposal_custom(index: u32, name: &str, epoch: u64) -> ReplaceProposal { + ReplaceProposal { + to_replace: LeafIndex(index), + leaf_node: update_leaf_node(name, index, Some(epoch)).await, } } #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] async fn make_update_proposal_custom(name: &str, leaf_index: u32) -> UpdateProposal { UpdateProposal { - leaf_node: update_leaf_node(name, leaf_index).await, + leaf_node: update_leaf_node(name, leaf_index, None).await, } } diff --git a/mls-rs/src/group/proposal_filter/bundle.rs b/mls-rs/src/group/proposal_filter/bundle.rs index f18a75b2..6bc96f63 100644 --- a/mls-rs/src/group/proposal_filter/bundle.rs +++ b/mls-rs/src/group/proposal_filter/bundle.rs @@ -16,6 +16,9 @@ use crate::{ ExtensionList, }; +#[cfg(feature = "replace_proposal")] +use crate::group::ReplaceProposal; + #[cfg(feature = "by_ref_proposal")] use crate::group::{proposal_cache::CachedProposal, LeafIndex, ProposalRef, UpdateProposal}; @@ -38,6 +41,8 @@ pub struct ProposalBundle { pub(crate) updates: Vec>, #[cfg(feature = "by_ref_proposal")] pub(crate) update_senders: Vec, + #[cfg(feature = "replace_proposal")] + pub(crate) replaces: Vec>, pub(crate) removals: Vec>, #[cfg(feature = "psk")] pub(crate) psks: Vec>, @@ -62,6 +67,12 @@ impl ProposalBundle { sender, source, }), + #[cfg(feature = "replace_proposal")] + Proposal::Replace(proposal) => self.replaces.push(ProposalInfo { + proposal, + sender, + source, + }), Proposal::Remove(proposal) => self.removals.push(ProposalInfo { proposal, sender, @@ -177,6 +188,11 @@ impl ProposalBundle { f(&proposal.as_ref().map(BorrowedProposal::from)) })?; + #[cfg(feature = "replace_proposal")] + self.retain_by_type::(|proposal| { + f(&proposal.as_ref().map(BorrowedProposal::from)) + })?; + self.retain_by_type::(|proposal| { f(&proposal.as_ref().map(BorrowedProposal::from)) })?; @@ -246,6 +262,13 @@ impl ProposalBundle { .map(|p| p.as_ref().map(BorrowedProposal::Update)), ); + #[cfg(feature = "replace_proposal")] + let res = res.chain( + self.replaces + .iter() + .map(|p| p.as_ref().map(BorrowedProposal::Replace)), + ); + #[cfg(feature = "psk")] let res = res.chain( self.psks @@ -298,6 +321,9 @@ impl ProposalBundle { #[cfg(feature = "by_ref_proposal")] let res = res.chain(self.updates.into_iter().map(|p| p.map(Proposal::Update))); + #[cfg(feature = "replace_proposal")] + let res = res.chain(self.replaces.into_iter().map(|p| p.map(Proposal::Replace))); + res.chain( self.additions .into_iter() @@ -344,6 +370,12 @@ impl ProposalBundle { &self.update_senders } + /// Replace proposals in the bundle. + #[cfg(feature = "replace_proposal")] + pub fn replace_proposals(&self) -> &[ProposalInfo] { + &self.replaces + } + /// Remove proposals in the bundle. pub fn remove_proposals(&self) -> &[ProposalInfo] { &self.removals @@ -621,6 +653,8 @@ macro_rules! impl_proposable { impl_proposable!(AddProposal, ADD, additions); #[cfg(feature = "by_ref_proposal")] impl_proposable!(UpdateProposal, UPDATE, updates); +#[cfg(feature = "replace_proposal")] +impl_proposable!(ReplaceProposal, REPLACE, replaces); impl_proposable!(RemoveProposal, REMOVE, removals); #[cfg(feature = "psk")] impl_proposable!(PreSharedKeyProposal, PSK, psks); diff --git a/mls-rs/src/group/proposal_filter/filtering.rs b/mls-rs/src/group/proposal_filter/filtering.rs index 8e67ff58..45e1b22f 100644 --- a/mls-rs/src/group/proposal_filter/filtering.rs +++ b/mls-rs/src/group/proposal_filter/filtering.rs @@ -15,16 +15,22 @@ use crate::{ tree_kem::{ leaf_node_validator::{LeafNodeValidator, ValidationContext}, node::LeafIndex, - TreeKemPublic, }, CipherSuiteProvider, ExtensionList, }; +#[cfg(feature = "custom_proposal")] +use crate::tree_kem::TreeKemPublic; + use super::filtering_common::{filter_out_invalid_psks, ApplyProposalsOutput, ProposalApplier}; #[cfg(feature = "by_ref_proposal")] use crate::extension::ExternalSendersExt; +#[cfg(feature = "replace_proposal")] +use crate::{extension::LeafNodeEpochExt, group::ReplaceProposal}; + +use alloc::vec; use alloc::vec::Vec; use mls_rs_core::{error::IntoAnyError, identity::IdentityProvider, psk::PreSharedKeyStorage}; @@ -60,9 +66,17 @@ where commit_time: Option, ) -> Result { let proposals = filter_out_invalid_proposers(strategy, proposals)?; + let proposals = filter_out_update_for_committer(strategy, commit_sender, proposals)?; + + #[cfg(feature = "replace_proposal")] + let proposals = filter_out_replace_for_committer(strategy, commit_sender, proposals)?; + + let proposals = filter_out_duplicate_updates(strategy, commit_sender, proposals)?; + + #[cfg(feature = "replace_proposal")] + let proposals = filter_out_duplicate_replaces(strategy, commit_sender, proposals)?; - let mut proposals: ProposalBundle = - filter_out_update_for_committer(strategy, commit_sender, proposals)?; + let mut proposals = proposals; // We ignore the strategy here because the check above ensures all updates are from members proposals.update_senders = proposals @@ -175,6 +189,7 @@ where Some(group_extensions_in_use), ); + // Check that Updates are valid let bad_indices: Vec<_> = wrap_iter(proposals.update_proposals()) .zip(wrap_iter(proposals.update_proposal_senders())) .enumerate() @@ -220,6 +235,76 @@ where proposals.update_senders.remove(i); }); + // Check that Replaces are valid + #[cfg(feature = "replace_proposal")] + { + let bad_indices: Vec<_> = wrap_iter(proposals.replace_proposals()) + .enumerate() + .filter_map(|(i, p)| async move { + let res = { + let to_replace = p.proposal.to_replace; + let leaf = &p.proposal.leaf_node; + + let res = leaf_node_validator + .check_if_valid( + leaf, + ValidationContext::Update(( + self.group_id, + *to_replace, + commit_time, + )), + ) + .await; + + let old_leaf = match self.original_tree.get_leaf_node(to_replace) { + Ok(leaf) => leaf, + Err(e) => return Some(Err(e)), + }; + + let valid_successor = self + .identity_provider + .valid_successor( + &old_leaf.signing_identity, + &leaf.signing_identity, + group_extensions_in_use, + ) + .await + .map_err(|e| MlsError::IdentityProviderError(e.into_any_error())) + .and_then(|valid| { + valid.then_some(()).ok_or(MlsError::InvalidSuccessor) + }); + + // XXX(RLB) It's not clear that this is the right policy. On the one hand, + // it allows for the epoch locking to be turned off. On the other hand, + // this can be abused to roll back to any leaf that doesn't have an epoch + // marker. + + // If both the old and new leaves have `leaf_node_epoch` extensions, then + // the new value must be at least as big as the old value. + let old_epoch = old_leaf.extensions.get_as::().ok()?; + let new_epoch = leaf.extensions.get_as::().ok()?; + let epoch_successor = old_epoch + .and_then(|old| new_epoch.map(|new| new.epoch >= old.epoch)) + .unwrap_or(true) + .then_some(()) + .ok_or(MlsError::InvalidSuccessor); + + res.and(valid_successor).and(epoch_successor) + }; + + apply_strategy(strategy, p.is_by_reference(), res) + .map(|b| (!b).then_some(i)) + .transpose() + }) + .try_collect() + .await?; + + bad_indices.into_iter().rev().for_each(|i| { + proposals.remove::(i); + }); + } + + // Check that Adds are valid let bad_indices: Vec<_> = wrap_iter(proposals.add_proposals()) .enumerate() .filter_map(|(i, p)| async move { @@ -291,6 +376,24 @@ fn filter_out_update_for_committer( Ok(proposals) } +#[cfg(feature = "replace_proposal")] +fn filter_out_replace_for_committer( + strategy: FilterStrategy, + commit_sender: LeafIndex, + mut proposals: ProposalBundle, +) -> Result { + proposals.retain_by_type::(|p| { + apply_strategy( + strategy, + p.is_by_reference(), + (p.proposal.to_replace != commit_sender) + .then_some(()) + .ok_or(MlsError::InvalidCommitSelfUpdate), + ) + })?; + Ok(proposals) +} + fn filter_out_removal_of_committer( strategy: FilterStrategy, commit_sender: LeafIndex, @@ -308,6 +411,63 @@ fn filter_out_removal_of_committer( Ok(proposals) } +fn filter_out_duplicate_updates( + strategy: FilterStrategy, + _commit_sender: LeafIndex, + mut proposals: ProposalBundle, +) -> Result { + let mut seen = vec![]; + proposals.retain_by_type::(|p| { + let Sender::Member(index) = p.sender else { + return Err(MlsError::InvalidSender); + }; + + let fresh = !seen.contains(&index); + seen.push(index); + + apply_strategy( + strategy, + p.is_by_reference(), + fresh + .then_some(()) + .ok_or(MlsError::UpdatingNonExistingMember), + ) + })?; + Ok(proposals) +} + +#[cfg(feature = "replace_proposal")] +fn filter_out_duplicate_replaces( + strategy: FilterStrategy, + _commit_sender: LeafIndex, + mut proposals: ProposalBundle, +) -> Result { + let mut seen: Vec<_> = proposals + .by_type::() + .map(|p| { + let Sender::Member(index) = p.sender else { + unreachable!() + }; + index + }) + .collect(); + + proposals.retain_by_type::(|p| { + let index = *p.proposal.to_replace; + let fresh = !seen.contains(&index); + seen.push(index); + + apply_strategy( + strategy, + p.is_by_reference(), + fresh + .then_some(()) + .ok_or(MlsError::UpdatingNonExistingMember), + ) + })?; + Ok(proposals) +} + #[cfg(feature = "by_ref_proposal")] #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] async fn filter_out_invalid_group_extensions( @@ -432,6 +592,7 @@ pub(crate) fn proposer_can_propose( by_ref: bool, ) -> Result<(), MlsError> { let can_propose = match (proposer, by_ref) { + #[cfg(not(feature = "replace_proposal"))] (Sender::Member(_), false) => matches!( proposal_type, ProposalType::ADD @@ -440,6 +601,17 @@ pub(crate) fn proposer_can_propose( | ProposalType::RE_INIT | ProposalType::GROUP_CONTEXT_EXTENSIONS ), + #[cfg(feature = "replace_proposal")] + (Sender::Member(_), false) => matches!( + proposal_type, + ProposalType::ADD + | ProposalType::REPLACE + | ProposalType::REMOVE + | ProposalType::PSK + | ProposalType::RE_INIT + | ProposalType::GROUP_CONTEXT_EXTENSIONS + ), + #[cfg(not(feature = "replace_proposal"))] (Sender::Member(_), true) => matches!( proposal_type, ProposalType::ADD @@ -449,9 +621,20 @@ pub(crate) fn proposer_can_propose( | ProposalType::RE_INIT | ProposalType::GROUP_CONTEXT_EXTENSIONS ), + #[cfg(feature = "replace_proposal")] + (Sender::Member(_), true) => matches!( + proposal_type, + ProposalType::ADD + | ProposalType::UPDATE + | ProposalType::REPLACE + | ProposalType::REMOVE + | ProposalType::PSK + | ProposalType::RE_INIT + | ProposalType::GROUP_CONTEXT_EXTENSIONS + ), #[cfg(feature = "by_ref_proposal")] (Sender::External(_), false) => false, - #[cfg(feature = "by_ref_proposal")] + #[cfg(all(feature = "by_ref_proposal", not(feature = "replace_proposal")))] (Sender::External(_), true) => matches!( proposal_type, ProposalType::ADD @@ -460,6 +643,16 @@ pub(crate) fn proposer_can_propose( | ProposalType::PSK | ProposalType::GROUP_CONTEXT_EXTENSIONS ), + #[cfg(all(feature = "by_ref_proposal", feature = "replace_proposal"))] + (Sender::External(_), true) => matches!( + proposal_type, + ProposalType::ADD + | ProposalType::REMOVE + | ProposalType::REPLACE + | ProposalType::RE_INIT + | ProposalType::PSK + | ProposalType::GROUP_CONTEXT_EXTENSIONS + ), (Sender::NewMemberCommit, false) => matches!( proposal_type, ProposalType::REMOVE | ProposalType::PSK | ProposalType::EXTERNAL_INIT @@ -497,6 +690,16 @@ pub(crate) fn filter_out_invalid_proposers( } } + #[cfg(feature = "replace_proposal")] + for i in (0..proposals.replace_proposals().len()).rev() { + let p = &proposals.replace_proposals()[i]; + let res = proposer_can_propose(p.sender, ProposalType::REPLACE, p.is_by_reference()); + + if !apply_strategy(strategy, p.is_by_reference(), res)? { + proposals.remove::(i); + } + } + for i in (0..proposals.remove_proposals().len()).rev() { let p = &proposals.remove_proposals()[i]; let res = proposer_can_propose(p.sender, ProposalType::REMOVE, p.is_by_reference()); diff --git a/mls-rs/src/group/proposal_filter/filtering_common.rs b/mls-rs/src/group/proposal_filter/filtering_common.rs index 278c0dee..5fd24da5 100644 --- a/mls-rs/src/group/proposal_filter/filtering_common.rs +++ b/mls-rs/src/group/proposal_filter/filtering_common.rs @@ -41,8 +41,11 @@ use crate::group::{JustPreSharedKeyID, ResumptionPSKUsage, ResumptionPsk}; #[cfg(all(feature = "std", feature = "psk"))] use std::collections::HashSet; +#[cfg(all(feature = "by_ref_proposal", feature = "psk"))] +use super::filtering::apply_strategy; + #[cfg(feature = "by_ref_proposal")] -use super::filtering::{apply_strategy, filter_out_invalid_proposers, FilterStrategy}; +use super::filtering::{filter_out_invalid_proposers, FilterStrategy}; #[cfg(feature = "custom_proposal")] use super::filtering::filter_out_unsupported_custom_proposals; diff --git a/mls-rs/src/group/snapshot.rs b/mls-rs/src/group/snapshot.rs index 82d6f94d..8d49e524 100644 --- a/mls-rs/src/group/snapshot.rs +++ b/mls-rs/src/group/snapshot.rs @@ -14,18 +14,18 @@ use crate::{ }; #[cfg(feature = "by_ref_proposal")] -use crate::{ - crypto::{HpkePublicKey, HpkeSecretKey}, - group::{ - message_hash::MessageHash, - proposal_cache::{CachedProposal, ProposalCache}, - ProposalMessageDescription, ProposalRef, - }, - map::SmallMap, +use crate::{crypto::HpkePublicKey, group::PendingUpdate, map::SmallMap}; + +#[cfg(feature = "by_ref_proposal")] +use crate::group::{ + message_hash::MessageHash, + proposal_cache::{CachedProposal, ProposalCache}, + ProposalMessageDescription, ProposalRef, }; use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize}; use mls_rs_core::crypto::SignatureSecretKey; + #[cfg(feature = "tree_index")] use mls_rs_core::identity::IdentityProvider; @@ -38,7 +38,7 @@ pub(crate) struct Snapshot { epoch_secrets: EpochSecrets, key_schedule: KeySchedule, #[cfg(feature = "by_ref_proposal")] - pending_updates: SmallMap)>, + pending_updates: SmallMap, pending_commit: Option, signer: SignatureSecretKey, } diff --git a/mls-rs/src/group/test_utils.rs b/mls-rs/src/group/test_utils.rs index 7cfe5217..98bbe691 100644 --- a/mls-rs/src/group/test_utils.rs +++ b/mls-rs/src/group/test_utils.rs @@ -50,6 +50,19 @@ impl TestGroup { self.group.update_proposal(None, None).await.unwrap() } + #[cfg(feature = "replace_proposal")] + #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] + pub(crate) async fn replace_proposal( + &mut self, + to_replace: u32, + leaf_node: LeafNode, + ) -> Proposal { + self.group + .replace_proposal(to_replace, leaf_node) + .await + .unwrap() + } + #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] pub(crate) async fn join_with_custom_config( &mut self, diff --git a/mls-rs/src/map.rs b/mls-rs/src/map.rs index 067072a5..8e9811f0 100644 --- a/mls-rs/src/map.rs +++ b/mls-rs/src/map.rs @@ -34,8 +34,6 @@ mod map_impl { collections::{btree_map::Entry, BTreeMap}, vec::Vec, }; - #[cfg(feature = "by_ref_proposal")] - use itertools::Itertools; #[derive(Clone, Debug, PartialEq, Eq)] pub struct SmallMap(pub(super) Vec<(K, V)>); @@ -64,6 +62,13 @@ mod map_impl { fn find(&self, key: &K) -> Option { self.0.iter().position(|(k, _)| k == key) } + + pub fn retain(&mut self, mut f: F) + where + F: FnMut(&K, &mut V) -> bool, + { + self.0.retain_mut(|(k, v)| f(k, v)); + } } } diff --git a/mls-rs/src/tree_kem/leaf_node.rs b/mls-rs/src/tree_kem/leaf_node.rs index c59ed789..ba6958fc 100644 --- a/mls-rs/src/tree_kem/leaf_node.rs +++ b/mls-rs/src/tree_kem/leaf_node.rs @@ -267,10 +267,12 @@ pub(crate) mod test_utils { use crate::{ cipher_suite::CipherSuite, crypto::test_utils::{test_cipher_suite_provider, TestCryptoProvider}, + extension::ApplicationIdExt, identity::test_utils::{get_test_signing_identity, BasicWithCustomProvider}, }; - use crate::extension::ApplicationIdExt; + #[cfg(feature = "replace_proposal")] + use crate::extension::ExtensionType; use super::*; @@ -385,6 +387,10 @@ pub(crate) mod test_utils { CredentialType::from(BasicWithCustomProvider::CUSTOM_CREDENTIAL_TYPE), ], cipher_suites: TestCryptoProvider::all_supported_cipher_suites(), + + #[cfg(feature = "replace_proposal")] + extensions: vec![ExtensionType::LEAF_NODE_EPOCH], + ..Default::default() } } diff --git a/mls-rs/src/tree_kem/mod.rs b/mls-rs/src/tree_kem/mod.rs index 430ee161..01935dfe 100644 --- a/mls-rs/src/tree_kem/mod.rs +++ b/mls-rs/src/tree_kem/mod.rs @@ -26,6 +26,9 @@ use crate::crypto::{self, CipherSuiteProvider, HpkeSecretKey}; #[cfg(feature = "by_ref_proposal")] use crate::group::proposal::{AddProposal, UpdateProposal}; +#[cfg(feature = "replace_proposal")] +use crate::group::proposal::ReplaceProposal; + #[cfg(any(test, feature = "by_ref_proposal"))] use crate::group::proposal::RemoveProposal; @@ -369,13 +372,43 @@ impl TreeKemPublic { } } - // Remove from the tree old leaves from updates + // Remove from the tree old leaves from updates and replaces let mut partial_updates = vec![]; + let senders = proposal_bundle.update_senders.iter().copied(); + let updates = + proposal_bundle + .updates + .iter() + .zip(senders) + .enumerate() + .map(|(i, (p, index))| { + ( + true, + i, + index, + p.proposal.leaf_node.clone(), + p.is_by_reference(), + ) + }); + + #[cfg(not(feature = "replace_proposal"))] + let replaces = core::iter::empty(); + + #[cfg(feature = "replace_proposal")] + let replaces = proposal_bundle.replaces.iter().enumerate().map(|(i, p)| { + ( + false, + i, + p.proposal.to_replace, + p.proposal.leaf_node.clone(), + p.is_by_reference(), + ) + }); - for (i, (p, index)) in proposal_bundle.updates.iter().zip(senders).enumerate() { - let new_leaf = p.proposal.leaf_node.clone(); + let updates = updates.chain(replaces); + for (is_update, i, index, new_leaf, is_by_reference) in updates { match self.nodes.blank_leaf_node(index) { Ok(old_leaf) => { #[cfg(feature = "tree_index")] @@ -385,10 +418,10 @@ impl TreeKemPublic { #[cfg(feature = "tree_index")] self.index.remove(&old_leaf, &old_id); - partial_updates.push((index, old_leaf, new_leaf, i)); + partial_updates.push((index, old_leaf, new_leaf, is_update, i)); } _ => { - if !filter || !p.is_by_reference() { + if !filter || !is_by_reference { return Err(MlsError::UpdatingNonExistingMember); } } @@ -400,11 +433,12 @@ impl TreeKemPublic { let mut removed_leaves = vec![]; let mut updated_indices = vec![]; - let mut bad_indices = vec![]; + let mut bad_update_indices = vec![]; + let mut bad_replace_indices = vec![]; // Apply updates one by one. If there's an update which we can't apply or revert, we revert // all updates. - for (index, old_leaf, new_leaf, i) in partial_updates.into_iter() { + for (index, old_leaf, new_leaf, is_update, i) in partial_updates.into_iter() { #[cfg(feature = "tree_index")] let res = index_insert(&mut self.index, &new_leaf, index, id_provider, extensions).await; @@ -433,7 +467,11 @@ impl TreeKemPublic { if res.is_ok() { self.nodes.insert_leaf(index, old_leaf); - bad_indices.push(i); + if is_update { + bad_update_indices.push(i); + } else { + bad_replace_indices.push(i); + } } else { // Revert all updates and stop. We're already in the "filter" case, so we don't throw an error. #[cfg(feature = "tree_index")] @@ -461,11 +499,21 @@ impl TreeKemPublic { if updated_indices.is_empty() { // This takes care of the "revert all" scenario proposal_bundle.updates = vec![]; + + #[cfg(feature = "replace_proposal")] + { + proposal_bundle.replaces = vec![]; + } } else { - for i in bad_indices.into_iter().rev() { + for i in bad_update_indices.into_iter().rev() { proposal_bundle.remove::(i); proposal_bundle.update_senders.remove(i); } + + #[cfg(feature = "replace_proposal")] + for i in bad_replace_indices.into_iter().rev() { + proposal_bundle.remove::(i); + } } // Apply adds @@ -548,6 +596,43 @@ impl TreeKemPublic { self.nodes.blank_direct_path(index)?; } + // Apply replaces + #[cfg(not(feature = "replace_proposal"))] + let replaced = []; + + #[cfg(feature = "replace_proposal")] + let mut replaced = vec![]; + + #[cfg(feature = "replace_proposal")] + for p in &proposal_bundle.replaces { + let index = p.proposal.to_replace; + + #[cfg(feature = "tree_index")] + { + // If this fails, it's not because the proposal is bad. + let old_leaf = self.nodes.blank_leaf_node(index)?; + + let identity = + identity(&old_leaf.signing_identity, id_provider, extensions).await?; + + self.index.remove(&old_leaf, &identity); + index_insert( + &mut self.index, + &p.proposal.leaf_node, + index, + id_provider, + extensions, + ) + .await?; + } + + #[cfg(feature = "tree_index")] + self.nodes.insert_leaf(index, p.proposal.leaf_node); + + self.nodes.blank_direct_path(index)?; + replaced.push(index); + } + // Apply adds let mut start = LeafIndex(0); let mut added = vec![]; @@ -567,6 +652,7 @@ impl TreeKemPublic { .iter() .map(|p| p.proposal.to_remove) .chain(added.iter().copied()) + .chain(replaced.iter().copied()) .collect_vec(); self.update_hashes(&updated_leaves, cipher_suite_provider)