diff --git a/esdk/src/message/encrypted_data_keys.rs b/esdk/src/message/encrypted_data_keys.rs new file mode 100644 index 000000000..cd4b96d65 --- /dev/null +++ b/esdk/src/message/encrypted_data_keys.rs @@ -0,0 +1,236 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Encrypted data key serialization and deserialization. + +use super::serialize_functions::{read_seq_u16, read_str_u16, read_u16, write_bytes, write_u16}; +use super::{Error, ser_err}; +use crate::types::{SafeRead, SafeWrite}; +use aws_mpl_legacy::EncryptedDataKey; + + +pub(crate) fn write_edks(w: &mut dyn SafeWrite, edks: &[EncryptedDataKey]) -> Result<(), Error> { + //= spec/data-format/message-header.md#encrypted-data-keys + //# The Encrypted Data Keys MUST consist of, in order, + //# Encrypted Data Key Count, + //# and Encrypted Data Key Entries. + + // Encrypted Data Key Count + + //= spec/data-format/message-header.md#encrypted-data-key-count + //# This value MUST be greater than 0. + if edks.is_empty() { + return ser_err("Cannot serialize empty encrypted data keys list"); + } + + //= spec/data-format/message-header.md#encrypted-data-key-count + //# The length of the serialized encrypted data key count MUST be 2 bytes. + // + //= spec/data-format/message-header.md#encrypted-data-key-count + //# The encrypted data key count MUST be interpreted as a UInt16. + let Ok(edk_count) = u16::try_from(edks.len()) else { + return ser_err("Count too large for UInt16"); + }; + write_u16(w, edk_count)?; + + // Encrypted Data Key Entries + + for edk in edks { + write_edk(w, edk)?; + } + Ok(()) +} + +pub(crate) fn write_edk(w: &mut dyn SafeWrite, edk: &EncryptedDataKey) -> Result<(), Error> { + //= spec/data-format/message-header.md#encrypted-data-key-entries + //# Each Encrypted Data Key Entry MUST consist of, in order, + //# Key Provider ID Length, + //# Key Provider ID, + //# Key Provider Information Length, + //# Key Provider Information, + //# Encrypted Data Key Length, + //# and Encrypted Data Key. + + // Key Provider ID Length + + let kp_id_bytes = edk.key_provider_id.as_bytes(); + + //= spec/data-format/message-header.md#key-provider-id-length + //# The key provider ID length MUST be interpreted as a UInt16. + let Ok(kp_id_len) = u16::try_from(kp_id_bytes.len()) else { + return ser_err("Key provider ID length too long for 16 bits"); + }; + + //= spec/data-format/message-header.md#key-provider-id-length + //# The length of the serialized key provider ID length field MUST be 2 bytes. + write_u16(w, kp_id_len)?; + + // Key Provider ID + + //= spec/data-format/message-header.md#key-provider-id + //= reason=The length field is derived from the same byte slice that is serialized, so they are equal by construction. + //# The length of the serialized key provider ID MUST be equal to the value of the [Key Provider ID Length](#key-provider-id-length) field. + // + //= spec/data-format/message-header.md#key-provider-id + //# The key provider ID MUST be interpreted as UTF-8 encoded bytes. + write_bytes(w, kp_id_bytes)?; + + // Key Provider Information Length + + //= spec/data-format/message-header.md#key-provider-information-length + //# The key provider information length MUST be interpreted as a UInt16. + let Ok(kp_info_len) = u16::try_from(edk.key_provider_info.len()) else { + return ser_err("Key provider info length too long for 16 bits"); + }; + + //= spec/data-format/message-header.md#key-provider-information-length + //# The length of the serialized key provider information length field MUST be 2 bytes. + write_u16(w, kp_info_len)?; + + // Key Provider Information + + //= spec/data-format/message-header.md#key-provider-information + //= reason=The length field is derived from the same byte slice that is serialized, so they are equal by construction. + //# The length of the serialized key provider information MUST be equal to the value of the [Key Provider Information Length](#key-provider-information-length) field. + // + //= spec/data-format/message-header.md#key-provider-information + //# The key provider information MUST be interpreted as bytes. + write_bytes(w, &edk.key_provider_info)?; + + // Encrypted Data Key Length + + //= spec/data-format/message-header.md#encrypted-data-key-length + //# The encrypted data key length MUST be interpreted as a UInt16. + let Ok(edk_len) = u16::try_from(edk.ciphertext.len()) else { + return ser_err("Encrypted data key length too long for 16 bits"); + }; + + //= spec/data-format/message-header.md#encrypted-data-key-length + //# The length of the serialized encrypted data key length field MUST be 2 bytes. + write_u16(w, edk_len)?; + + // Encrypted Data Key + + //= spec/data-format/message-header.md#encrypted-data-key + //= reason=The length field is derived from the same byte slice that is serialized, so they are equal by construction. + //# The length of the serialized encrypted data key MUST be equal to the value of the [Encrypted Data Key Length](#encrypted-data-key-length) field. + // + //= spec/data-format/message-header.md#encrypted-data-key + //= type=implication + //# The encrypted data key MUST be interpreted as bytes. + write_bytes(w, &edk.ciphertext) +} + +pub(crate) fn read_edks( + r: &mut dyn SafeRead, + max_edks: Option, + raw: &mut dyn SafeWrite, +) -> Result, Error> { + //= spec/data-format/message-header.md#encrypted-data-keys + //# The Encrypted Data Keys MUST consist of, in order, + //# Encrypted Data Key Count, + //# and Encrypted Data Key Entries. + + // Encrypted Data Key Count + + //= spec/data-format/message-header.md#encrypted-data-key-count + //# The length of the serialized encrypted data key count MUST be 2 bytes. + // + //= spec/data-format/message-header.md#encrypted-data-key-count + //# The encrypted data key count MUST be interpreted as a UInt16. + let count = read_u16(r, raw)?; + + //= spec/data-format/message-header.md#encrypted-data-key-count + //# This value MUST be greater than 0. + if count == 0 { + return ser_err("Encrypted data key count must be greater than 0"); + } + + if let Some(max_edks) = max_edks + && usize::from(count) > max_edks.get() + { + //= spec/data-format/message-header.md#encrypted-data-key-count + //# This value MUST be less than or equal to the [maximum number of encrypted data keys](../client-apis/client.md#maximum-number-of-encrypted-data-keys) if the maximum number is configured. + // + //= spec/client-apis/decrypt.md#v2-header-deserialization + //# If the number of [encrypted data keys](../framework/structures.md#encrypted-data-keys) + //# deserialized from the [message header](../data-format/message-header.md) + //# is greater than the [maximum number of encrypted data keys](client.md#maximum-number-of-encrypted-data-keys) configured in the [client](client.md), + //# then as soon as that can be determined during deserializing + //# decrypt MUST process no more bytes and yield an error. + return ser_err("Ciphertext encrypted data keys exceed maximum encrypted data keys limit"); + } + + // Encrypted Data Key Entries + + let mut edks = Vec::with_capacity(usize::from(count)); + for _ in 0..count { + edks.push(read_edk(r, raw)?); + } + Ok(edks) +} + +pub(crate) fn read_edk( + r: &mut dyn SafeRead, + raw: &mut dyn SafeWrite, +) -> Result { + //= spec/data-format/message-header.md#encrypted-data-key-entries + //# Each Encrypted Data Key Entry MUST consist of, in order, + //# Key Provider ID Length, + //# Key Provider ID, + //# Key Provider Information Length, + //# Key Provider Information, + //# Encrypted Data Key Length, + //# and Encrypted Data Key. + + // Key Provider ID Length and Key Provider ID + + //= spec/data-format/message-header.md#key-provider-id-length + //# The key provider ID length MUST be interpreted as a UInt16. + // + //= spec/data-format/message-header.md#key-provider-id-length + //# The length of the serialized key provider ID length field MUST be 2 bytes. + // + //= spec/data-format/message-header.md#key-provider-id + //= reason=read_str_u16 reads a u16 length then that many bytes, so the length field and data are equal by construction. + //# The length of the serialized key provider ID MUST be equal to the value of the [Key Provider ID Length](#key-provider-id-length) field. + // + //= spec/data-format/message-header.md#key-provider-id + //# The key provider ID MUST be interpreted as UTF-8 encoded bytes. + let provider_id = read_str_u16(r, raw)?; + + // Key Provider Information Length and Key Provider Information + + //= spec/data-format/message-header.md#key-provider-information-length + //# The key provider information length MUST be interpreted as a UInt16. + // + //= spec/data-format/message-header.md#key-provider-information-length + //# The length of the serialized key provider information length field MUST be 2 bytes. + // + //= spec/data-format/message-header.md#key-provider-information + //= reason=read_seq_u16 reads a u16 length then that many bytes, so the length field and data are equal by construction. + //# The length of the serialized key provider information MUST be equal to the value of the [Key Provider Information Length](#key-provider-information-length) field. + // + //= spec/data-format/message-header.md#key-provider-information + //# The key provider information MUST be interpreted as bytes. + let provider_info = read_seq_u16(r, raw)?; + + // Encrypted Data Key Length and Encrypted Data Key + + //= spec/data-format/message-header.md#encrypted-data-key-length + //# The encrypted data key length MUST be interpreted as a UInt16. + // + //= spec/data-format/message-header.md#encrypted-data-key-length + //# The length of the serialized encrypted data key length field MUST be 2 bytes. + // + //= spec/data-format/message-header.md#encrypted-data-key + //= reason=read_seq_u16 reads a u16 length then that many bytes, so the length field and data are equal by construction. + //# The length of the serialized encrypted data key MUST be equal to the value of the [Encrypted Data Key Length](#encrypted-data-key-length) field. + // + //= spec/data-format/message-header.md#encrypted-data-key + //= type=implication + //# The encrypted data key MUST be interpreted as bytes. + let edk = read_seq_u16(r, raw)?; + + Ok(EncryptedDataKey::new(provider_id, provider_info, edk)) +} diff --git a/esdk/src/message/encryption_context.rs b/esdk/src/message/encryption_context.rs new file mode 100644 index 000000000..e1b886820 --- /dev/null +++ b/esdk/src/message/encryption_context.rs @@ -0,0 +1,128 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Encryption context serialization for message header and AAD. + +use super::serializable_types::ESDKCanonicalEncryptionContext; +use super::serialize_functions::{read_str_u16, read_u16, write_bytes, write_u16}; +use super::{Error, ser_err}; +use crate::types::{SafeRead, SafeWrite}; + +/// Read the header's AAD encryption context sub-section. +pub(crate) fn read_canonical_ec( + r: &mut dyn SafeRead, + raw: &mut dyn SafeWrite, +) -> Result { + // Empty EC: length 0, no further bytes. + let bytes = usize::from(read_u16(r, raw)?); + if bytes == 0 { + return Ok(Vec::new()); + } + + // Count, then `count` (key, value) pairs. + let count = usize::from(read_u16(r, raw)?); + let mut result: ESDKCanonicalEncryptionContext = Vec::with_capacity(count); + for _ in 0..count { + let key = read_str_u16(r, raw)?; + let value = read_str_u16(r, raw)?; + result.push((key, value)); + } + + Ok(result) +} + +/// Write canonical EC bytes for signing/AES-GCM AAD; empty EC writes nothing. +pub(crate) fn write_empty_ec_or_write_aad( + w: &mut dyn SafeWrite, + data: &ESDKCanonicalEncryptionContext, +) -> Result<(), Error> { + if data.is_empty() { + //= spec/data-format/message-header.md#key-value-pairs + //# When the [encryption context](../framework/structures.md#encryption-context) is empty, + //# this field MUST NOT be included in the [AAD](#aad). + Ok(()) + } else { + write_aad(w, data) + } +} + +/// Serialized length of the key-value-pairs body in bytes. +fn get_length(data: &ESDKCanonicalEncryptionContext) -> usize { + let mut length = 0; + for pair in data { + // key_len(2) + key bytes + val_len(2) + val bytes. + // `.len()` on a String returns the number of UTF-8 bytes, which is what we serialize. + length += 2 + pair.0.len() + 2 + pair.1.len(); + } + length +} + +/// Write the header's AAD EC sub-section: length + key-value pairs. +pub(crate) fn write_aad_section( + w: &mut dyn SafeWrite, + data: &ESDKCanonicalEncryptionContext, +) -> Result<(), Error> { + if data.is_empty() { + //= spec/data-format/message-header.md#key-value-pairs-length + //# When the [encryption context](../framework/structures.md#encryption-context) is empty, the value of this field MUST be 0. + write_u16(w, 0)?; + return Ok(()); + } + + //= spec/data-format/message-header.md#aad + //# The AAD MUST consist of, in order, + //# Key Value Pairs Length, + //# and Key Value Pairs. + + // Key Value Pairs Length: covers the Key Value Pair Count field plus all pairs. + + //= spec/data-format/message-header.md#key-value-pairs-length + //# The length of the serialized key value pairs length field MUST be 2 bytes. + let bytes = 2 + get_length(data); // 2 for the Key Value Pair Count UInt16. + + //= spec/data-format/message-header.md#key-value-pairs-length + //# The length of the serialized key value pairs length field MUST be 2 bytes. + // + //= spec/data-format/message-header.md#key-value-pairs-length + //# The key value pairs length MUST be interpreted as a UInt16. + let Ok(bytes_u16) = u16::try_from(bytes) else { + return ser_err("Encryption context key value pair length value is too large for u16"); + }; + write_u16(w, bytes_u16)?; + + // Key Value Pairs + + write_aad(w, data) +} + +/// Write the key-value-pairs body: count, then (key, value) pairs. +pub(crate) fn write_aad( + w: &mut dyn SafeWrite, + data: &ESDKCanonicalEncryptionContext, +) -> Result<(), Error> { + // Count. + let Ok(data_len) = u16::try_from(data.len()) else { + return ser_err("Encryption context key value pair count is too large for u16"); + }; + write_u16(w, data_len)?; + + for pair in data { + //= spec/data-format/message-header.md#key-value-pairs + //# The encryption context key-value pairs MUST be serialized according to its [specification for serialization](../framework/structures.md#serialization). + + // Key: length + UTF-8 bytes. + let Ok(key_len) = u16::try_from(pair.0.len()) else { + return ser_err("Encryption context key length is too large for u16"); + }; + write_u16(w, key_len)?; + write_bytes(w, pair.0.as_bytes())?; + + // Value: length + UTF-8 bytes. + let Ok(val_len) = u16::try_from(pair.1.len()) else { + return ser_err("Encryption context value length is too large for u16"); + }; + write_u16(w, val_len)?; + write_bytes(w, pair.1.as_bytes())?; + } + Ok(()) +} diff --git a/esdk/src/message/serializable_types.rs b/esdk/src/message/serializable_types.rs new file mode 100644 index 000000000..243eb0f81 --- /dev/null +++ b/esdk/src/message/serializable_types.rs @@ -0,0 +1,123 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Type aliases and helper functions for message serialization. + +use crate::types::EncryptionContext; +use aws_mpl_legacy::EncryptedDataKey; +use aws_mpl_legacy::suites::AlgorithmSuite; + +/// Unordered encryption context from the public API. +pub(crate) type ESDKEncryptionContext = EncryptionContext; +pub(crate) type ESDKEncryptionContextPair = (String, String); +/// Sorted-by-key encryption context used for on-wire serialization. +pub(crate) type ESDKCanonicalEncryptionContext = Vec; + +/// Max total key-value-pairs body length: the outer UInt16 length field (2 bytes) +/// plus the count UInt16 (2 bytes) must together still fit in a UInt16, leaving +/// u16::MAX - 2 for the payload. +const ESDK_CANONICAL_ENCRYPTION_CONTEXT_MAX_LENGTH: u64 = u16::MAX as u64 - 2; + +pub(crate) const fn get_iv_length(a: &AlgorithmSuite) -> u8 { + match a.encrypt { + aws_mpl_legacy::suites::Encrypt::AesGcm(_e) => 12, + _ => 0, + } +} + +pub(crate) const fn get_tag_length(a: &AlgorithmSuite) -> u8 { + match a.encrypt { + aws_mpl_legacy::suites::Encrypt::AesGcm(_e) => 16, + _ => 0, + } +} + +pub(crate) const fn get_encrypt_key_length(a: &AlgorithmSuite) -> u8 { + match a.encrypt { + aws_mpl_legacy::suites::Encrypt::AesGcm(e) => e.key_len(), + _ => 0, + } +} + +// Length properties of the Encryption Context. +// The Encryption Context has a complex relationship with length. +// Each key or value MUST be less than Uint16, +// However the entire thing MUST also serialize to less than Uint16. +// In practice, this means than the longest value, +// given a key of 1 bytes is Uint16-2-2-1. +// e.g. +// 2 for the key length +// 1 for the key data +// 2 for the value length +// Uint16-2-2-1 for the value data + +/// Serialized byte length of the key-value-pairs body (no outer length prefix). +/// +/// Accumulates in `usize` and casts to `u64` on return. Per the ESDK message +/// format, the AAD's maximum allowed length is `2^16 - 1` bytes, so a legal +/// encryption context never produces a sum that overflows even a 16-bit +/// accumulator — the `usize` sum and `as u64` cast are safe by construction. +pub(crate) fn length(encryption_context: &ESDKEncryptionContext) -> u64 { + let mut length: usize = 0; + for (key, value) in encryption_context { + length += 2 + key.len() + 2 + value.len(); + } + length as u64 +} + +/// Sort by key to produce the canonical on-wire ordering. +pub(crate) fn to_canonical_pairs( + encryption_context: ESDKEncryptionContext, +) -> ESDKCanonicalEncryptionContext { + let mut pairs: Vec<(String, String)> = encryption_context.into_iter().collect(); + pairs.sort_by(|a, b| a.0.cmp(&b.0)); + pairs +} + +pub(crate) fn from_canonical_pairs(pairs: ESDKCanonicalEncryptionContext) -> ESDKEncryptionContext { + let mut map: ESDKEncryptionContext = ESDKEncryptionContext::new(); + for (key, value) in pairs { + map.insert(key, value); + } + map +} + +/// True iff `ec` fits the on-wire encoding: pair count, each key/value length, +/// and total serialized length all fit in a UInt16. +pub(crate) fn is_esdk_encryption_context(ec: &EncryptionContext) -> bool { + if ec.len() > usize::from(u16::MAX) { + return false; + } + if length(ec) > ESDK_CANONICAL_ENCRYPTION_CONTEXT_MAX_LENGTH { + return false; + } + for (key, value) in ec { + if key.len() > usize::from(u16::MAX) { + return false; + } + if value.len() > usize::from(u16::MAX) { + return false; + } + } + true +} + +/// True iff every EDK field length fits in a UInt16. +pub(crate) fn is_esdk_encrypted_data_key(edk: &EncryptedDataKey) -> bool { + u16::try_from(edk.key_provider_id.len()).is_ok() + && u16::try_from(edk.key_provider_info.len()).is_ok() + && u16::try_from(edk.ciphertext.len()).is_ok() +} + +/// True iff the EDK count and each entry fit in UInt16. +pub(crate) fn is_esdk_encrypted_data_keys(edks: &[EncryptedDataKey]) -> bool { + if edks.len() > usize::from(u16::MAX) { + return false; + } + for edk in edks { + if !is_esdk_encrypted_data_key(edk) { + return false; + } + } + true +} diff --git a/esdk/src/message/serialize_functions.rs b/esdk/src/message/serialize_functions.rs new file mode 100644 index 000000000..b2db5ecb4 --- /dev/null +++ b/esdk/src/message/serialize_functions.rs @@ -0,0 +1,210 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +//! Low-level byte read/write primitives for message serialization. +//! +//! The `raw` SafeWrite parameter on every `read_*` function tees the consumed +//! bytes into a mirror buffer so callers can reconstruct the exact raw header +//! bytes used for authentication and signing. + +use super::{Error, ser_err}; +use crate::error::ErrorKind; +use crate::types::{SafeRead, SafeWrite}; +use std::backtrace::Backtrace; +use std::sync::Arc; + +#[track_caller] +fn ser_io(e: std::io::Error) -> Error { + match e.kind() { + std::io::ErrorKind::UnexpectedEof => Error { + kind: ErrorKind::SerializationError, + message: "Unexpected end of data".into(), + cause: Some(Arc::new(e)), + backtrace: Arc::new(Backtrace::capture()), + }, + _ => Error { + kind: ErrorKind::SerializationError, + message: "IO Error".into(), + cause: Some(Arc::new(e)), + backtrace: Arc::new(Backtrace::capture()), + }, + } +} + +/// Read up to `buf.len()` bytes; returns the number actually read (may be < len on EOF). +pub(crate) fn read_up_to(this: &mut dyn SafeRead, buf: &mut [u8]) -> Result { + let mut curr: usize = 0; + loop { + match this.read(&mut buf[curr..]) { + Ok(0) => { + return Ok(curr); + } + Ok(n) => { + curr += n; + if curr == buf.len() { + return Ok(curr); + } + } + // Err(ref e) if e.is_interrupted() => {} + Err(e) => return Err(ser_io(e)), + } + } +} + +/// Like `read_up_to`, but `first` is a one-byte peek already consumed by the caller. +pub(crate) fn read_up_to_peek( + this: &mut dyn SafeRead, + buf: &mut [u8], + first: Option, +) -> Result { + if buf.is_empty() { + return Ok(0); + } + match first { + Some(f) => { + buf[0] = f; + match read_up_to(this, &mut buf[1..]) { + Ok(n) => Ok(n + 1), + Err(e) => Err(e), + } + } + None => read_up_to(this, buf), + } +} + +#[track_caller] +fn ser_utf8(item: std::string::FromUtf8Error) -> Error { + Error { + kind: ErrorKind::SerializationError, + message: "UTF8 Decode Error".into(), + cause: Some(Arc::new(item)), + backtrace: Arc::new(Backtrace::capture()), + } +} + +pub(crate) fn write_bytes(w: &mut dyn SafeWrite, data: &[u8]) -> Result<(), Error> { + w.write_all(data).map_err(ser_io)?; + Ok(()) +} + +// Big-endian fixed-width writers. +pub(crate) fn write_u8(w: &mut dyn SafeWrite, data: u8) -> Result<(), Error> { + write_bytes(w, &data.to_be_bytes()) +} +pub(crate) fn write_u16(w: &mut dyn SafeWrite, data: u16) -> Result<(), Error> { + write_bytes(w, &data.to_be_bytes()) +} +pub(crate) fn write_u32(w: &mut dyn SafeWrite, data: u32) -> Result<(), Error> { + write_bytes(w, &data.to_be_bytes()) +} + +/// Read exactly `buf.len()` bytes and mirror them into `raw`. +pub(crate) fn read_bytes( + r: &mut dyn SafeRead, + buf: &mut [u8], + raw: &mut dyn SafeWrite, +) -> Result<(), Error> { + r.read_exact(buf).map_err(ser_io)?; + write_bytes(raw, buf) +} + +/// Read exactly `length` bytes into a fresh Vec. +pub(crate) fn read_vec( + r: &mut dyn SafeRead, + length: usize, + raw: &mut dyn SafeWrite, +) -> Result, Error> { + let mut result = vec![0; length]; + read_bytes(r, &mut result, raw)?; + Ok(result) +} + +// Big-endian fixed-width readers. Each mirrors the consumed bytes into `raw`. +pub(crate) fn read_u8(r: &mut dyn SafeRead, raw: &mut dyn SafeWrite) -> Result { + let mut result = [0u8; 1]; + read_bytes(r, &mut result, raw)?; + Ok(result[0]) +} + +/// Read one byte, returning `Ok(None)` on clean EOF. Does NOT mirror into a +/// raw buffer (used for streaming peek). +pub(crate) fn read_opt_u8(r: &mut dyn SafeRead) -> Result, Error> { + let mut result = [0u8; 1]; + match r.read_exact(&mut result) { + Ok(()) => Ok(Some(result[0])), + Err(e) => match e.kind() { + std::io::ErrorKind::UnexpectedEof => Ok(None), + _ => Err(ser_io(e)), + }, + } +} + +pub(crate) fn read_u16(r: &mut dyn SafeRead, raw: &mut dyn SafeWrite) -> Result { + let mut result = [0u8; 2]; + read_bytes(r, &mut result, raw)?; + Ok(u16::from_be_bytes(result)) +} +pub(crate) fn read_u32(r: &mut dyn SafeRead, raw: &mut dyn SafeWrite) -> Result { + let mut result = [0u8; 4]; + read_bytes(r, &mut result, raw)?; + Ok(u32::from_be_bytes(result)) +} +pub(crate) fn read_u64(r: &mut dyn SafeRead, raw: &mut dyn SafeWrite) -> Result { + let mut result = [0u8; 8]; + read_bytes(r, &mut result, raw)?; + Ok(u64::from_be_bytes(result)) +} + +/// Read a UInt16 length prefix followed by that many bytes. +pub(crate) fn read_seq_u16( + r: &mut dyn SafeRead, + raw: &mut dyn SafeWrite, +) -> Result, Error> { + let len = read_u16(r, raw)?; + read_vec(r, usize::from(len), raw) +} + +/// Read a UInt32 length prefix followed by that many bytes into `data`, +/// rejecting lengths above `bound`. +pub(crate) fn read_seq_u32_bounded( + r: &mut dyn SafeRead, + bound: u32, + msg: &str, + data: &mut Vec, + raw: &mut dyn SafeWrite, +) -> Result<(), Error> { + let len = read_u32(r, raw)?; + if len > bound { + return ser_err(msg); + } + let Ok(len_usize) = usize::try_from(len) else { + return ser_err("length too large for platform"); + }; + data.resize(len_usize, 0); + read_bytes(r, &mut data[..], raw) +} + +/// Read a UInt64 length prefix followed by that many bytes, rejecting lengths +/// above `bound`. +pub(crate) fn read_seq_u64_bounded( + r: &mut dyn SafeRead, + bound: u64, + msg: &str, + raw: &mut dyn SafeWrite, +) -> Result, Error> { + let len = read_u64(r, raw)?; + if len > bound { + return ser_err(msg); + } + let Ok(len_usize) = usize::try_from(len) else { + return ser_err("length too large for platform"); + }; + read_vec(r, len_usize, raw) +} + +/// Read a UInt16-prefixed UTF-8 string. +pub(crate) fn read_str_u16(r: &mut dyn SafeRead, raw: &mut dyn SafeWrite) -> Result { + let len = read_u16(r, raw)?; + let result = read_vec(r, usize::from(len), raw)?; + let result = String::from_utf8(result).map_err(ser_utf8)?; + Ok(result) +} diff --git a/esdk/tests/test_encrypted_data_keys.rs b/esdk/tests/test_encrypted_data_keys.rs new file mode 100644 index 000000000..ef9287f40 --- /dev/null +++ b/esdk/tests/test_encrypted_data_keys.rs @@ -0,0 +1,377 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Tests for the Encrypted Data Keys sections of spec/data-format/message-header.md + +mod fixtures; +mod test_helpers; + +use aws_esdk::*; +use aws_mpl_legacy::commitment::EsdkCommitmentPolicy; +use fixtures::*; +use test_helpers::*; + +#[tokio::test(flavor = "multi_thread")] +async fn test_encrypted_data_keys_ordering() { + let single = aes_keyring(0).await; + let generator = aes_keyring(0).await; + let c1 = aes_keyring(1).await; + let c2 = aes_keyring(2).await; + let triple = multi_keyring(generator, vec![c1, c2]).await; + + // Cover both the single-EDK case and the multi-EDK case so that the + // "Count, then Entries" structure is exercised with the count field set + // to both 1 and 3. + for (label, keyring, expected_count) in [ + ("single", single, 1u16), + ("triple", triple, 3u16), + ] { + for version in VERSIONS { + //= spec/data-format/message-header.md#encrypted-data-keys + //= type=test + //# The Encrypted Data Keys MUST consist of, in order, + //# Encrypted Data Key Count, + //# and Encrypted Data Key Entries. + let ct = encrypt_with_version(b"ordering test", version, keyring.clone()).await; + let edk_section_start = skip_to_edk_section(&ct, version); + let parsed = parse_edk_section(&ct, version); + + // 1. Encrypted Data Key Count (2 bytes) + let count = u16::from_be_bytes([ct[edk_section_start], ct[edk_section_start + 1]]); + assert_eq!(count, expected_count, "{label} {version:?}: count field value"); + assert_eq!(parsed.edk_count, count); + + // 2. Encrypted Data Key Entries (immediately after the count) + let entries_start = edk_section_start + 2; + let first_pid_len = u16::from_be_bytes([ct[entries_start], ct[entries_start + 1]]); + assert_eq!( + first_pid_len, parsed.edks[0].provider_id_len, + "{label} {version:?}: EDK entries must immediately follow the count field" + ); + assert_eq!( + parsed.edks.len(), + count as usize, + "{label} {version:?}: parsed entries must match count" + ); + } + } +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_edk_section_length_fields_are_big_endian_uint16() { + let generator = aes_keyring(0).await; + let child = aes_keyring(1).await; + let mk = multi_keyring(generator, vec![child]).await; + let (expected_ns, _) = namespace_and_name(0); + let expected_pid_len = expected_ns.len() as u16; + + for version in VERSIONS { + let ct = encrypt_with_version(b"length fields uint16", version, mk.clone()).await; + let parsed = parse_edk_section(&ct, version); + let edk = &parsed.edks[0]; + + // Decode each length field directly from the wire as a big-endian UInt16. + let entries_start = parsed.edk_count_offset + 2; + let pid_len_offset = entries_start; + let pinfo_len_offset = entries_start + 2 + edk.provider_id_len as usize; + let edk_len_offset = pinfo_len_offset + 2 + edk.provider_info_len as usize; + + let count_wire = u16::from_be_bytes([ct[parsed.edk_count_offset], ct[parsed.edk_count_offset + 1]]); + let pid_len_wire = u16::from_be_bytes([ct[pid_len_offset], ct[pid_len_offset + 1]]); + let pinfo_len_wire = u16::from_be_bytes([ct[pinfo_len_offset], ct[pinfo_len_offset + 1]]); + let edk_len_wire = u16::from_be_bytes([ct[edk_len_offset], ct[edk_len_offset + 1]]); + + // EDK count: 2 keyrings → UInt16 value 2 ([0x00, 0x02]). + //= spec/data-format/message-header.md#encrypted-data-key-count + //= type=test + //# The length of the serialized encrypted data key count MUST be 2 bytes. + // + //= spec/data-format/message-header.md#encrypted-data-key-count + //= type=test + //# The encrypted data key count MUST be interpreted as a UInt16. + assert_eq!(count_wire, 2, "{version:?}: EDK count UInt16 value"); + assert_eq!(ct[parsed.edk_count_offset], 0x00, "{version:?}: EDK count high byte"); + assert_eq!(ct[parsed.edk_count_offset + 1], 0x02, "{version:?}: EDK count low byte"); + + // Key provider ID length: the UInt16 at this offset equals the known keyring namespace byte length. + //= spec/data-format/message-header.md#key-provider-id-length + //= type=test + //# The length of the serialized key provider ID length field MUST be 2 bytes. + // + //= spec/data-format/message-header.md#key-provider-id-length + //= type=test + //# The key provider ID length MUST be interpreted as a UInt16. + assert_eq!(pid_len_wire, expected_pid_len, "{version:?}: provider ID length UInt16 value"); + + // Key provider information length: the UInt16 at this offset must be positive for a raw AES keyring + // (which packs key name + bit length + IV length + IV into provider info). + //= spec/data-format/message-header.md#key-provider-information-length + //= type=test + //# The length of the serialized key provider information length field MUST be 2 bytes. + // + //= spec/data-format/message-header.md#key-provider-information-length + //= type=test + //# The key provider information length MUST be interpreted as a UInt16. + assert!(pinfo_len_wire > 0, "{version:?}: provider info length UInt16 must be positive"); + + // Encrypted data key length: raw AES keyring stores IV in provider_info; the ciphertext field is + // wrapped data key (32 bytes) + GCM tag (16 bytes) = 48. + //= spec/data-format/message-header.md#encrypted-data-key-length + //= type=test + //# The length of the serialized encrypted data key length field MUST be 2 bytes. + // + //= spec/data-format/message-header.md#encrypted-data-key-length + //= type=test + //# The encrypted data key length MUST be interpreted as a UInt16. + assert_eq!(edk_len_wire, 48, "{version:?}: EDK ciphertext length UInt16 value (wrapped 32B key + 16B tag)"); + } +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_edk_count_zero_rejected_on_decrypt() { + let keyring = aes_keyring(0).await; + + for version in VERSIONS { + let mut ct = encrypt_with_version(b"zero count", version, keyring.clone()).await; + let offset = skip_to_edk_section(&ct, version); + // Tamper: set count to 0. + ct[offset] = 0x00; + ct[offset + 1] = 0x00; + let mut dec = + DecryptInput::with_legacy_keyring(&ct, EncryptionContext::new(), keyring.clone()); + if let Version::V1 = version { + dec.commitment_policy = EsdkCommitmentPolicy::ForbidEncryptAllowDecrypt; + } + + //= spec/data-format/message-header.md#encrypted-data-key-count + //= type=test + //= reason=Tampering the count to 0 and verifying decrypt rejects it proves the >0 constraint is enforced on the deserialization path. + //# This value MUST be greater than 0. + let err = decrypt(&dec) + .await + .expect_err(&format!("{version:?}: decrypt must reject EDK count of 0")); + assert!( + matches!(err.kind, aws_esdk::ErrorKind::SerializationError), + "{version:?}: expected SerializationError, got: {} ({:?})", + err.message, err.kind + ); + } +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_edk_count_max_enforcement() { + let generator = aes_keyring(0).await; + let child = aes_keyring(1).await; + let mk = multi_keyring(generator, vec![child]).await; + + let expect_exceed = |err: &aws_esdk::Error| { + assert!( + err.message.contains("exceed") && err.message.contains("maximum"), + "error must indicate EDK count exceeds maximum, got: {} ({:?})", + err.message, err.kind + ); + }; + + //= spec/data-format/message-header.md#encrypted-data-key-count + //= type=test + //# This value MUST be less than or equal to the [maximum number of encrypted data keys](../client-apis/client.md#maximum-number-of-encrypted-data-keys) if the maximum number is configured. + + // Encrypt with 2 EDKs and max=1 → error. + let mut enc_over = EncryptInput::with_legacy_keyring( + b"max edk encrypt", + EncryptionContext::new(), + mk.clone(), + ); + enc_over.max_encrypted_data_keys = Some(std::num::NonZeroUsize::new(1).unwrap()); + expect_exceed(&encrypt(&enc_over).await.expect_err("encrypt must fail when EDK count exceeds max")); + + // Decrypt a 2-EDK message with max=1 → error. + let ct = encrypt_with_version(b"max edk decrypt", Version::V2, mk.clone()).await; + let mut dec_over = DecryptInput::with_legacy_keyring(&ct, EncryptionContext::new(), mk.clone()); + dec_over.max_encrypted_data_keys = Some(std::num::NonZeroUsize::new(1).unwrap()); + expect_exceed(&decrypt(&dec_over).await.expect_err("decrypt must fail when EDK count exceeds max")); + + // Encrypt with 2 EDKs and max=2 → ok (the "equal to" side of ≤). + let mut enc_at = EncryptInput::with_legacy_keyring(b"at max", EncryptionContext::new(), mk); + enc_at.max_encrypted_data_keys = Some(std::num::NonZeroUsize::new(2).unwrap()); + assert!( + encrypt(&enc_at).await.is_ok(), + "encrypt must succeed when EDK count equals max" + ); +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_edk_entry_field_order() { + let keyring = aes_keyring(0).await; + + for version in VERSIONS { + let ct = encrypt_with_version(b"entry order", version, keyring.clone()).await; + let edk_start = skip_to_edk_section(&ct, version) + 2; // skip count + let mut pos = edk_start; + + //= spec/data-format/message-header.md#encrypted-data-key-entries + //= type=test + //# Each Encrypted Data Key Entry MUST consist of, in order, + //# Key Provider ID Length, + //# Key Provider ID, + //# Key Provider Information Length, + //# Key Provider Information, + //# Encrypted Data Key Length, + //# and Encrypted Data Key. + + // 1. Key Provider ID Length (2 bytes) + let pid_len = u16::from_be_bytes([ct[pos], ct[pos + 1]]); + pos += 2; + assert!(pid_len > 0, "{version:?}: provider ID length must be positive"); + + // 2. Key Provider ID (pid_len bytes) + let pid = &ct[pos..pos + pid_len as usize]; + let pid_str = std::str::from_utf8(pid).expect("provider ID must be valid UTF-8"); + let (expected_ns, _) = namespace_and_name(0); + assert_eq!( + pid_str, expected_ns, + "{version:?}: provider ID must match keyring namespace" + ); + pos += pid_len as usize; + + // 3. Key Provider Information Length (2 bytes) + let pinfo_len = u16::from_be_bytes([ct[pos], ct[pos + 1]]); + pos += 2; + + // 4. Key Provider Information (pinfo_len bytes) + let _pinfo = &ct[pos..pos + pinfo_len as usize]; + pos += pinfo_len as usize; + + // 5. Encrypted Data Key Length (2 bytes) + let edk_len = u16::from_be_bytes([ct[pos], ct[pos + 1]]); + pos += 2; + assert!( + edk_len > 0, + "{version:?}: encrypted data key length must be positive" + ); + + // 6. Encrypted Data Key (edk_len bytes) + let _edk = &ct[pos..pos + edk_len as usize]; + pos += edk_len as usize; + + // Verify we consumed exactly one entry and the position matches the parser. + let parsed = parse_edk_section(&ct, version); + assert_eq!( + pos, parsed.end_offset, + "{version:?}: manual walk must match parser end offset" + ); + } +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_edk_entries_preserve_keyring_order() { + let generator = aes_keyring(0).await; + let c1 = aes_keyring(1).await; + let c2 = aes_keyring(2).await; + let mk = multi_keyring(generator, vec![c1, c2]).await; + + for version in VERSIONS { + let ct = encrypt_with_version(b"order check", version, mk.clone()).await; + let parsed = parse_edk_section(&ct, version); + + //= spec/data-format/message-header.md#encrypted-data-keys + //= type=test + //= reason=Verifying that EDK provider IDs appear in generator-then-children order proves entries are serialized in the order they appear in the encryption materials, exercising the "Entries" component of the Count+Entries structure. + //# The Encrypted Data Keys MUST consist of, in order, + //# Encrypted Data Key Count, + //# and Encrypted Data Key Entries. + for (i, edk) in parsed.edks.iter().enumerate() { + let pid_str = std::str::from_utf8(&edk.provider_id).unwrap(); + let (expected_ns, _) = namespace_and_name(i as u8); + assert_eq!( + pid_str, expected_ns, + "{version:?}: EDK {i} provider ID must match keyring {i} namespace" + ); + } + } +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_edk_entry_lengths_match_fields() { + // Multi-keyring so multiple entries are checked per run. + let generator = aes_keyring(0).await; + let child = aes_keyring(1).await; + let mk = multi_keyring(generator, vec![child]).await; + + for version in VERSIONS { + let ct = encrypt_with_version(b"entry lengths match", version, mk.clone()).await; + let parsed = parse_edk_section(&ct, version); + + for (i, edk) in parsed.edks.iter().enumerate() { + //= spec/data-format/message-header.md#key-provider-id + //= type=test + //# The length of the serialized key provider ID MUST be equal to the value of the [Key Provider ID Length](#key-provider-id-length) field. + assert_eq!( + edk.provider_id.len(), edk.provider_id_len as usize, + "{version:?}: EDK {i}: provider ID byte length must equal the provider ID length field" + ); + + //= spec/data-format/message-header.md#key-provider-information + //= type=test + //# The length of the serialized key provider information MUST be equal to the value of the [Key Provider Information Length](#key-provider-information-length) field. + assert_eq!( + edk.provider_info.len(), edk.provider_info_len as usize, + "{version:?}: EDK {i}: provider info byte length must equal the provider info length field" + ); + + //= spec/data-format/message-header.md#encrypted-data-key + //= type=test + //# The length of the serialized encrypted data key MUST be equal to the value of the [Encrypted Data Key Length](#encrypted-data-key-length) field. + assert_eq!( + edk.edk.len(), edk.edk_len as usize, + "{version:?}: EDK {i}: encrypted data key byte length must equal the EDK length field" + ); + } + } +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_key_provider_id_is_utf8() { + let generator = aes_keyring(0).await; + let child = aes_keyring(1).await; + let mk = multi_keyring(generator, vec![child]).await; + + for version in VERSIONS { + let ct = encrypt_with_version(b"pid utf8", version, mk.clone()).await; + let parsed = parse_edk_section(&ct, version); + + //= spec/data-format/message-header.md#key-provider-id + //= type=test + //# The key provider ID MUST be interpreted as UTF-8 encoded bytes. + for (i, edk) in parsed.edks.iter().enumerate() { + let pid_str = + std::str::from_utf8(&edk.provider_id).expect("provider ID must be valid UTF-8"); + let (expected_ns, _) = namespace_and_name(i as u8); + assert_eq!( + pid_str, expected_ns, + "{version:?}: provider ID must be the keyring namespace as UTF-8" + ); + } + } +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_key_provider_info_interpreted_as_bytes() { + let keyring = aes_keyring(0).await; + + for version in VERSIONS { + let ct = encrypt_with_version(b"pinfo bytes", version, keyring.clone()).await; + let parsed = parse_edk_section(&ct, version); + let edk = &parsed.edks[0]; + // Provider info for raw AES keyring starts with the key name. + let (_, expected_name) = namespace_and_name(0); + + //= spec/data-format/message-header.md#key-provider-information + //= type=test + //# The key provider information MUST be interpreted as bytes. + assert!( + edk.provider_info.starts_with(expected_name.as_bytes()), + "{version:?}: provider info must start with the known key name" + ); + } +} diff --git a/esdk/tests/test_encryption_context_aad.rs b/esdk/tests/test_encryption_context_aad.rs new file mode 100644 index 000000000..bede4d50d --- /dev/null +++ b/esdk/tests/test_encryption_context_aad.rs @@ -0,0 +1,234 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Tests for spec/data-format/message-header.md#aad, +//! #key-value-pairs-length, and #key-value-pairs + +mod fixtures; +mod test_helpers; + +use aws_esdk::*; +use fixtures::*; +use test_helpers::*; + +/// V1 header AAD offset: Version(1) + Type(1) + AlgSuiteID(2) + MessageID(16) = 20. +const V1_AAD_OFFSET: usize = 20; +/// V2 header AAD offset: Version(1) + AlgSuiteID(2) + MessageID(32) = 35. +const V2_AAD_OFFSET: usize = 35; + +fn aad_offset(version: Version) -> usize { + match version { + Version::V1 => V1_AAD_OFFSET, + Version::V2 => V2_AAD_OFFSET, + } +} + +/// Assert that every (key, value) pair in `expected` is present in `actual`. +/// Used to verify the encryption context survives the round trip intact, +/// while ignoring any keys the SDK may add (e.g. `aws-crypto-public-key`) +fn assert_ec_contains(actual: &EncryptionContext, expected: &EncryptionContext, version: Version) { + for (k, v) in expected { + assert_eq!( + actual.get(k), + Some(v), + "{version:?}: decrypted EC missing or mismatched for key {k:?}" + ); + } +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_aad_serialization_order() { + for version in VERSIONS { + let ec = small_encryption_context(SmallEncryptionContextVariation::AB); + let pt = b"aad serialization order"; + let ct = encrypt_no_sign_with_ec(pt, ec.clone(), version).await; + let off = aad_offset(version); + + //= spec/data-format/message-header.md#aad + //= type=test + //# The AAD MUST consist of, in order, + //# Key Value Pairs Length, + //# and Key Value Pairs. + + // 1. Key Value Pairs Length (2 bytes at the AAD offset) + let kvp_len = u16::from_be_bytes([ct[off], ct[off + 1]]) as usize; + assert!(kvp_len > 0, "{version:?}: non-empty EC must have non-zero KVP length"); + + // 2. Key Value Pairs (immediately follow the length field) + let kvp_count_offset = off + 2; + let kvp_count = + u16::from_be_bytes([ct[kvp_count_offset], ct[kvp_count_offset + 1]]) as usize; + assert_eq!(kvp_count, 2, "{version:?}: AB encryption context has 2 key-value pairs"); + + // Cross-check: the decrypt path recovers the same encryption context, which is + // only possible if the on-wire ordering agreed with the spec on both sides. + let dec = decrypt_with_version(&ct, version).await; + assert_ec_contains(&dec.encryption_context, &ec, version); + } +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_aad_key_value_pairs_length_field_size() { + for version in VERSIONS { + let ec = small_encryption_context(SmallEncryptionContextVariation::A); + let pt = b"kvp length field size"; + let ct = encrypt_no_sign_with_ec(pt, ec.clone(), version).await; + let off = aad_offset(version); + + //= spec/data-format/message-header.md#key-value-pairs-length + //= type=test + //# The length of the serialized key value pairs length field MUST be 2 bytes. + + // The KVP length field occupies exactly 2 bytes at [off..off+2]. + let kvp_len = u16::from_be_bytes([ct[off], ct[off + 1]]) as usize; + // For "A" (keyA=valA) the Key Value Pairs field is: + // count(2) + key_len(2) + key(4) + val_len(2) + val(4) = 14 bytes. + // The length field covers the entire Key Value Pairs structure, including + // the Key Value Pair Count. + assert_eq!(kvp_len, 14, "{version:?}: KVP length for single pair keyA=valA must be 14"); + + // Cross-check: the decrypted EC matches what we encrypted, confirming the 2-byte + // length field was parsed correctly on the decrypt side too. + let dec = decrypt_with_version(&ct, version).await; + assert_ec_contains(&dec.encryption_context, &ec, version); + } +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_aad_key_value_pairs_length_uint16() { + for version in VERSIONS { + let ec = small_encryption_context(SmallEncryptionContextVariation::A); + let pt = b"kvp length uint16"; + let ct = encrypt_no_sign_with_ec(pt, ec.clone(), version).await; + let off = aad_offset(version); + + //= spec/data-format/message-header.md#key-value-pairs-length + //= type=test + //# The key value pairs length MUST be interpreted as a UInt16. + + // Read the 2 bytes as big-endian u16 and verify the value. + let kvp_len = u16::from_be_bytes([ct[off], ct[off + 1]]); + // keyA=valA: count(2) + key_len(2) + key(4) + val_len(2) + val(4) = 14. + assert_eq!(kvp_len, 14, "{version:?}: UInt16 KVP length for keyA=valA must be 14"); + + // Cross-check: the decrypted EC round-trips, confirming both sides agree that + // the field is a big-endian UInt16. + let dec = decrypt_with_version(&ct, version).await; + assert_ec_contains(&dec.encryption_context, &ec, version); + } +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_aad_empty_encryption_context_length_zero() { + for version in VERSIONS { + let ec = small_encryption_context(SmallEncryptionContextVariation::Empty); + let pt = b"empty ec length zero"; + let ct = encrypt_no_sign_with_ec(pt, ec.clone(), version).await; + let off = aad_offset(version); + + //= spec/data-format/message-header.md#key-value-pairs-length + //= type=test + //# When the [encryption context](../framework/structures.md#encryption-context) is empty, the value of this field MUST be 0. + + // The 2 bytes at the AAD offset must be [0x00, 0x00]. + assert_eq!(ct[off], 0x00, "{version:?}: empty EC KVP length high byte must be 0"); + assert_eq!( + ct[off + 1], + 0x00, + "{version:?}: empty EC KVP length low byte must be 0" + ); + + // Cross-check: decrypt returns an empty encryption context (non-signing suite + // means the SDK added no entries of its own). + let dec = decrypt_with_version(&ct, version).await; + assert!( + dec.encryption_context.is_empty(), + "{version:?}: decrypted EC must be empty, got {:?}", + dec.encryption_context + ); + } +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_aad_key_value_pairs_serialization() { + for version in VERSIONS { + let ec = small_encryption_context(SmallEncryptionContextVariation::AB); + let pt = b"kvp serialization"; + let ct = encrypt_no_sign_with_ec(pt, ec.clone(), version).await; + let off = aad_offset(version); + + //= spec/data-format/message-header.md#key-value-pairs + //= type=test + //# The encryption context key-value pairs MUST be serialized according to its [specification for serialization](../framework/structures.md#serialization). + + // Parse the KVP section: after 2-byte length, 2-byte count, then pairs. + let kvp_len = u16::from_be_bytes([ct[off], ct[off + 1]]) as usize; + assert!(kvp_len > 0, "{version:?}: non-empty KVP length"); + let mut pos = off + 2; + let count = u16::from_be_bytes([ct[pos], ct[pos + 1]]) as usize; + assert_eq!(count, 2, "{version:?}: AB has 2 pairs"); + pos += 2; + + // Pairs must be sorted by key: keyA < keyB. + let key1_len = u16::from_be_bytes([ct[pos], ct[pos + 1]]) as usize; + pos += 2; + let key1 = std::str::from_utf8(&ct[pos..pos + key1_len]).unwrap(); + pos += key1_len; + let val1_len = u16::from_be_bytes([ct[pos], ct[pos + 1]]) as usize; + pos += 2; + let val1 = std::str::from_utf8(&ct[pos..pos + val1_len]).unwrap(); + pos += val1_len; + + let key2_len = u16::from_be_bytes([ct[pos], ct[pos + 1]]) as usize; + pos += 2; + let key2 = std::str::from_utf8(&ct[pos..pos + key2_len]).unwrap(); + pos += key2_len; + let val2_len = u16::from_be_bytes([ct[pos], ct[pos + 1]]) as usize; + pos += 2; + let val2 = std::str::from_utf8(&ct[pos..pos + val2_len]).unwrap(); + + assert_eq!(key1, "keyA", "{version:?}: first key in sorted order"); + assert_eq!(val1, "valA", "{version:?}: first value"); + assert_eq!(key2, "keyB", "{version:?}: second key in sorted order"); + assert_eq!(val2, "valB", "{version:?}: second value"); + + // Cross-check: decrypted EC contains the same key/value pairs we encrypted. + let dec = decrypt_with_version(&ct, version).await; + assert_ec_contains(&dec.encryption_context, &ec, version); + } +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_aad_empty_encryption_context_no_kvp_field() { + for version in VERSIONS { + let ec = small_encryption_context(SmallEncryptionContextVariation::Empty); + let pt = b"empty ec no kvp"; + let ct = encrypt_no_sign_with_ec(pt, ec.clone(), version).await; + let off = aad_offset(version); + + //= spec/data-format/message-header.md#key-value-pairs + //= type=test + //# When the [encryption context](../framework/structures.md#encryption-context) is empty, + //# this field MUST NOT be included in the [AAD](#aad). + + // KVP Length is 0, and the next field (EDK count) starts immediately after. + let kvp_len = u16::from_be_bytes([ct[off], ct[off + 1]]); + assert_eq!(kvp_len, 0, "{version:?}: empty EC must have KVP length 0"); + // The bytes right after the 2-byte KVP Length field are the EDK count (not KVP data). + let edk_count_offset = off + 2; + let edk_count = + u16::from_be_bytes([ct[edk_count_offset], ct[edk_count_offset + 1]]); + assert!( + edk_count >= 1, + "{version:?}: EDK count must be at least 1, proving no KVP field between AAD length and EDKs" + ); + + // Cross-check: decrypt recovers an empty encryption context. + let dec = decrypt_with_version(&ct, version).await; + assert!( + dec.encryption_context.is_empty(), + "{version:?}: decrypted EC must be empty, got {:?}", + dec.encryption_context + ); + } +}