From 0d3326ff89ef836de942672f395a53a0a27301e0 Mon Sep 17 00:00:00 2001 From: Tim Geoghegan Date: Thu, 7 May 2026 10:22:26 -0700 Subject: [PATCH] test partially replayed aggregation job init The other day we did some analysis and concluded that it should be OK for aggregation job init requests to contain a mix of replayed and new reports. This commit adds a test to make sure of that. --- .../tests/aggregation_job_init.rs | 156 +++++++++++++++++- core/src/report_id.rs | 12 +- 2 files changed, 161 insertions(+), 7 deletions(-) diff --git a/aggregator/src/aggregator/http_handlers/tests/aggregation_job_init.rs b/aggregator/src/aggregator/http_handlers/tests/aggregation_job_init.rs index 9aac41299..f13979543 100644 --- a/aggregator/src/aggregator/http_handlers/tests/aggregation_job_init.rs +++ b/aggregator/src/aggregator/http_handlers/tests/aggregation_job_init.rs @@ -23,10 +23,10 @@ use janus_core::{ vdaf::VdafInstance, }; use janus_messages::{ - AggregationJobId, AggregationJobInitializeReq, AggregationJobResp, Extension, ExtensionType, - HpkeCiphertext, HpkeConfigId, InputShareAad, Interval, MediaType, PartialBatchSelector, - PrepareInit, PrepareStepResult, ReportError, ReportIdChecksum, ReportMetadata, ReportShare, - Role, Time, + AggregateShareReq, AggregationJobId, AggregationJobInitializeReq, AggregationJobResp, BatchId, + BatchSelector, Extension, ExtensionType, HpkeCiphertext, HpkeConfigId, InputShareAad, Interval, + MediaType, PartialBatchSelector, PrepareInit, PrepareStepResult, ReportError, ReportIdChecksum, + ReportMetadata, ReportShare, Role, Time, batch_mode::{LeaderSelected, TimeInterval}, }; use prio::{codec::Encode, vdaf::dummy}; @@ -37,8 +37,11 @@ use tower::ServiceExt; use crate::aggregator::{ BatchAggregationsIterator, aggregation_job_init::test_util::{PrepareInitGenerator, put_aggregation_job}, - http_handlers::test_util::{ - HttpHandlerTest, decode_response_body, take_problem_details, take_response_body, + http_handlers::{ + test_util::{ + 
HttpHandlerTest, decode_response_body, take_problem_details, take_response_body, + }, + tests::aggregate_share::put_aggregate_share_request, }, test_util::{ BATCH_AGGREGATION_SHARD_COUNT, assert_task_aggregation_counter, @@ -1242,3 +1245,144 @@ async fn aggregate_init_duplicated_report_id() { assert_task_aggregation_counter(&datastore, *task.id(), TaskAggregationCounter::default()) .await; } + +#[tokio::test] +async fn aggregate_init_partially_replayed_aggregation_init() { + // Create 5 reports, 1-5. Send one aggregation job init request containing reports 1 and 2. It + // should succeed normally. Then send another init request containing reports 1-5. We expect: + // - the request overall succeeds (i.e. HTTP 200) + // - the PrepareResps for reports 1 and 2 indicate rejection + // - the PrepareResps for reports 3-5 indicate success + // We then send an aggregate share request for the batch ID. It should succeed and all five + // reports should be included. + let HttpHandlerTest { + clock, + ephemeral_datastore: _ephemeral_datastore, + datastore, + router, + hpke_keypair, + .. 
+ } = HttpHandlerTest::new().await; + + let task = TaskBuilder::new( + BatchMode::LeaderSelected { + batch_time_window_size: None, + }, + AggregationMode::Synchronous, + VdafInstance::Fake { rounds: 1 }, + ) + .with_min_batch_size(1) + .build(); + + let batch_id = BatchId::from([12; 32]); + let agg_param = dummy::AggregationParam(0).get_encoded().unwrap(); + let partial_batch_selector = PartialBatchSelector::new_leader_selected(batch_id); + + let helper_task = task.helper_view().unwrap(); + let prep_init_generator = PrepareInitGenerator::new( + clock.clone(), + helper_task.clone(), + hpke_keypair.config().clone(), + dummy::Vdaf::new(1), + dummy::AggregationParam(0), + ); + + datastore.put_aggregator_task(&helper_task).await.unwrap(); + + let (prepare_init_1, _) = prep_init_generator.next(&1); + let (prepare_init_2, _) = prep_init_generator.next(&2); + let (prepare_init_3, _) = prep_init_generator.next(&3); + let (prepare_init_4, _) = prep_init_generator.next(&4); + let (prepare_init_5, _) = prep_init_generator.next(&5); + let report_ids: Vec<_> = [ + &prepare_init_1, + &prepare_init_2, + &prepare_init_3, + &prepare_init_4, + &prepare_init_5, + ] + .iter() + .map(|pi| *pi.report_share().metadata().id()) + .collect(); + + let request = AggregationJobInitializeReq::new( + agg_param.clone(), + partial_batch_selector.clone(), + Vec::from([prepare_init_1.clone(), prepare_init_2.clone()]), + ); + + let mut response = put_aggregation_job(&task, &random(), &request, &router).await; + assert_eq!(response.status(), StatusCode::CREATED); + let aggregate_resp: AggregationJobResp = decode_response_body(&mut response).await; + + // Response contains all the reports from the request + assert_eq!( + &report_ids[0..2], + request + .prepare_inits() + .iter() + .map(|init| *init.report_share().metadata().id()) + .collect::<Vec<_>>() + .as_slice(), + ); + for resp in &aggregate_resp.prepare_resps { + assert_matches!(resp.result(), &PrepareStepResult::Continue { .. 
}); + } + + let request = AggregationJobInitializeReq::new( + agg_param.clone(), + partial_batch_selector, + Vec::from([ + prepare_init_1.clone(), + prepare_init_2.clone(), + prepare_init_3.clone(), + prepare_init_4.clone(), + prepare_init_5.clone(), + ]), + ); + + let mut response = put_aggregation_job(&task, &random(), &request, &router).await; + assert_eq!(response.status(), StatusCode::CREATED); + let aggregate_resp: AggregationJobResp = decode_response_body(&mut response).await; + + // Response contains all the reports from the request + assert_eq!( + report_ids, + request + .prepare_inits() + .iter() + .map(|init| *init.report_share().metadata().id()) + .collect::<Vec<_>>(), + ); + for resp in &aggregate_resp.prepare_resps { + if report_ids[0..2].contains(resp.report_id()) { + assert_matches!( + resp.result(), + &PrepareStepResult::Reject(ReportError::ReportReplayed), + "first two reports must be rejected as replays", + ) + } + if report_ids[2..5].contains(resp.report_id()) { + assert_matches!( + resp.result(), + &PrepareStepResult::Continue { .. }, + "last three reports must be accepted", + ); + } + } + + let request = AggregateShareReq::new( + BatchSelector::new_leader_selected(batch_id), + agg_param.clone(), + 5, + ReportIdChecksum::from_report_ids(&report_ids), + ); + // If the request succeeds, then the checksum was valid and helper agrees that all 5 reports are + // included. + assert_eq!( + put_aggregate_share_request(&task, &request, &random(), &router) + .await + .status(), + StatusCode::OK, + ); +} diff --git a/core/src/report_id.rs b/core/src/report_id.rs index 99b848c35..2288b8619 100644 --- a/core/src/report_id.rs +++ b/core/src/report_id.rs @@ -4,10 +4,20 @@ use aws_lc_rs::digest::{SHA256, SHA256_OUTPUT_LEN, digest}; use janus_messages::{ReportId, ReportIdChecksum}; /// Additional methods for working with a [`ReportIdChecksum`]. -pub trait ReportIdChecksumExt { +pub trait ReportIdChecksumExt: Sized { /// Initialize a checksum from a single report ID. 
fn for_report_id(report_id: &ReportId) -> Self; + /// Initialize a checksum from multiple report IDs. + fn from_report_ids(report_ids: &[ReportId]) -> Self { + let mut ret = Self::for_report_id(&report_ids[0]); + for report_id in &report_ids[1..] { + ret = ret.updated_with(report_id); + } + + ret + } + /// Incorporate the provided report ID into this checksum. fn updated_with(self, report_id: &ReportId) -> Self;