diff --git a/Cargo.toml b/Cargo.toml
index ef451170..7b66ff31 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -9,7 +9,7 @@ members = [
     "lambda-events",
 ]
 
-exclude = ["examples"]
+exclude = ["examples", "test-lmi"]
 
 [workspace.dependencies]
 base64 = "0.22"
@@ -33,4 +33,3 @@ tower = "0.5"
 tower-layer = "0.3"
 tower-service = "0.3"
 
-
diff --git a/lambda-runtime/src/layers/api_client.rs b/lambda-runtime/src/layers/api_client.rs
index 7113ee0a..eb8ac83b 100644
--- a/lambda-runtime/src/layers/api_client.rs
+++ b/lambda-runtime/src/layers/api_client.rs
@@ -1,11 +1,12 @@
 use crate::LambdaInvocation;
 use futures::{future::BoxFuture, ready, FutureExt, TryFutureExt};
+use http_body_util::BodyExt;
 use hyper::body::Incoming;
 use lambda_runtime_api_client::{body::Body, BoxError, Client};
 use pin_project::pin_project;
 use std::{future::Future, pin::Pin, sync::Arc, task};
 use tower::Service;
-use tracing::error;
+use tracing::{error, warn};
 
 /// Tower service that sends a Lambda Runtime API response to the Lambda Runtime HTTP API using
 /// a previously initialized client.
@@ -31,7 +32,7 @@ where
 {
     type Response = ();
     type Error = S::Error;
-    type Future = RuntimeApiClientFuture<S::Future>;
+    type Future = RuntimeApiClientFuture<S::Future, Incoming>;
 
     fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> task::Poll<Result<(), Self::Error>> {
         self.inner.poll_ready(cx)
     }
@@ -40,7 +41,7 @@ where
     fn call(&mut self, req: LambdaInvocation) -> Self::Future {
         let request_fut = self.inner.call(req);
         let client = self.client.clone();
-        RuntimeApiClientFuture::First(request_fut, client)
+        RuntimeApiClientFuture::Invoke(request_fut, client)
     }
 }
 
@@ -56,26 +57,66 @@ where
     }
 }
 
+/// Future representing the three-phase lifecycle of a Lambda invocation response.
+///
+/// This future implements a state machine with three phases:
+///
+/// 1. **Invoke**: Poll the inner service (customer handler code) to produce an HTTP request
+///    containing the Lambda response to send to the Runtime API.
+///
+/// 2. **Respond**: Send the HTTP request to the Lambda Runtime API and await the HTTP response.
+///    At this point, response headers are available but the body has not been consumed.
+///
+/// 3. **Reconcile**: Consume the HTTP response body and check for timeout indicators in trailers.
+///    AWS Lambda uses HTTP 200 with a `Lambda-Runtime-Function-Response-Status: timeout` trailer
+///    to indicate timeout conditions in concurrent mode. Trailers are sent via HTTP/1.1 chunked
+///    transfer encoding.
+///
+/// The three-phase design is necessary because HTTP trailers are sent after the response body,
+/// requiring full body consumption before trailer access.
+///
+/// The type parameter `B` is the response body type. In production this is `hyper::body::Incoming`,
+/// but for testing it can be any body type that implements `http_body::Body`.
 #[pin_project(project = RuntimeApiClientFutureProj)]
-pub enum RuntimeApiClientFuture<F> {
-    First(#[pin] F, Arc<Client>),
-    Second(#[pin] BoxFuture<'static, Result<http::Response<Incoming>, BoxError>>),
+pub enum RuntimeApiClientFuture<F, B> {
+    /// **Invoke Phase**: Polling the inner service to build the Lambda response request.
+    ///
+    /// Contains:
+    /// - The inner service future that produces the `http::Request<Body>`
+    /// - An `Arc` reference to the HTTP client for sending the request
+    Invoke(#[pin] F, Arc<Client>),
+
+    /// **Respond Phase**: Sending the request to the Lambda Runtime API and receiving the response.
+    ///
+    /// Contains a boxed future that:
+    /// - Sends the HTTP request via `client.call()`
+    /// - Receives the `http::Response<B>` with headers (body not yet consumed)
+    /// - Transitions to the Reconcile phase to consume the body and check trailers
+    Respond(#[pin] BoxFuture<'static, Result<http::Response<B>, BoxError>>),
+
+    /// **Reconcile Phase**: Consuming the response body to complete the HTTP transaction.
+    ///
+    /// Contains a boxed future that:
+    /// - Consumes the entire HTTP response body via `.frame().await`
+    /// - Checks the trailers for timeout indicators
+    /// - Returns `Ok(())` to complete the invocation lifecycle
+    Reconcile(#[pin] BoxFuture<'static, Result<(), BoxError>>),
 }
 
-impl<F> Future for RuntimeApiClientFuture<F>
+impl<F> Future for RuntimeApiClientFuture<F, Incoming>
 where
     F: Future<Output = Result<http::Request<Body>, BoxError>>,
 {
     type Output = Result<(), BoxError>;
 
     fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
-        // NOTE: We loop here to directly poll the second future once the first has finished.
+        // Loop to transition directly between phases without yielding to the executor.
         task::Poll::Ready(loop {
             match self.as_mut().project() {
-                RuntimeApiClientFutureProj::First(fut, client) => match ready!(fut.poll(cx)) {
+                RuntimeApiClientFutureProj::Invoke(fut, client) => match ready!(fut.poll(cx)) {
                     Ok(ok) => {
-                        // NOTE: We use 'client.call_boxed' here to obtain a future with static
-                        // lifetime. Otherwise, this future would need to be self-referential...
+                        // Invoke phase complete: transition to the Respond phase and
+                        // send the Lambda response request to the Runtime API.
                         let next_fut = client
                             .call(ok)
                             .map_err(|err| {
@@ -83,12 +124,150 @@ where
                                 error!(error = ?err, "failed to send request to Lambda Runtime API");
                                 err
                             })
                             .boxed();
-                        self.set(RuntimeApiClientFuture::Second(next_fut));
+                        self.set(RuntimeApiClientFuture::Respond(next_fut));
                     }
                     Err(err) => break Err(err),
                 },
-                RuntimeApiClientFutureProj::Second(fut) => break ready!(fut.poll(cx)).map(|_| ()),
+                RuntimeApiClientFutureProj::Respond(fut) => {
+                    let response = ready!(fut.poll(cx))?;
+
+                    // Respond phase complete: transition to the Reconcile phase and
+                    // check for a timeout indication in the response trailers.
+                    let reconcile_fut = reconcile_response(response);
+                    self.set(RuntimeApiClientFuture::Reconcile(reconcile_fut));
+                }
+                RuntimeApiClientFutureProj::Reconcile(fut) => {
+                    // Reconcile phase: consume the body, check the trailers, complete.
+                    break ready!(fut.poll(cx));
+                }
             }
         })
     }
 }
+
+/// Consumes the response body and checks for timeout trailers.
+///
+/// AWS Lambda uses HTTP 200 with a `Lambda-Runtime-Function-Response-Status: timeout` trailer
+/// to indicate timeout conditions in concurrent mode.
+fn reconcile_response<B>(response: http::Response<B>) -> BoxFuture<'static, Result<(), BoxError>>
+where
+    B: hyper::body::Body + Send + Unpin + 'static,
+    B::Data: Send,
+    B::Error: Into<BoxError>,
+{
+    async move {
+        let mut body = response.into_body();
+
+        while let Some(frame_result) = body.frame().await {
+            match frame_result {
+                Ok(frame) => {
+                    if let Ok(trailers) = frame.into_trailers() {
+                        // Check for `Lambda-Runtime-Function-Response-Status: timeout`.
+                        if let Some(status) = trailers.get("Lambda-Runtime-Function-Response-Status") {
+                            if status == "timeout" {
+                                warn!("Lambda invocation timed out - response was not sent within the configured timeout");
+                            }
+                        }
+                    }
+                }
+                Err(e) => {
+                    return Err(e.into());
+                }
+            }
+        }
+
+        Ok(())
+    }
+    .boxed()
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use bytes::Bytes;
+    use http::Response;
+    use http_body_util::StreamBody;
+    use hyper::body::Frame;
+    use std::convert::Infallible;
+
+    // Type alias for our test body
+    type TestBody = StreamBody<futures::stream::Iter<std::vec::IntoIter<Result<Frame<Bytes>, Infallible>>>>;
+
+    #[tokio::test]
+    async fn test_reconcile_detects_timeout_trailer() {
+        use tracing_capture::{CaptureLayer, SharedStorage};
+        use tracing_subscriber::layer::SubscriberExt;
+
+        // Set up tracing capture to verify the warning is logged
+        let storage = SharedStorage::default();
+        let subscriber = tracing_subscriber::registry().with(CaptureLayer::new(&storage));
+        let _guard = tracing::subscriber::set_default(subscriber);
+
+        // Create a response body with trailers indicating timeout
+        let body_data = Bytes::from_static(b"response body");
+        let mut trailers = http::HeaderMap::new();
+        trailers.insert(
+            "Lambda-Runtime-Function-Response-Status",
+            "timeout".parse().unwrap(),
+        );
+
+        // Create a stream with a data frame followed by a trailer frame
+        let frames: Vec<Result<Frame<Bytes>, Infallible>> = vec![
+            Ok(Frame::data(body_data)),
+            Ok(Frame::trailers(trailers)),
+        ];
+        let stream_body: TestBody = StreamBody::new(futures::stream::iter(frames));
+
+        // Create a mock response
+        let response: Response<TestBody> = Response::builder()
+            .status(200)
+            .body(stream_body)
+            .unwrap();
+
+        // Test the reconcile_response function directly
+        let result = reconcile_response(response).await;
+
+        // Should complete successfully
+        assert!(result.is_ok(), "Future should complete successfully: {:?}", result);
+
+        // Verify the warning was logged
+        let storage_lock = storage.lock();
+        let all_spans: Vec<_> = storage_lock.all_spans().collect();
+
+        if !all_spans.is_empty() {
+            let warnings: Vec<_> = all_spans
+                .iter()
+                .flat_map(|span| span.events())
+                .filter(|event| {
+                    event.metadata().level() == &tracing::Level::WARN
+                        && event
+                            .message()
+                            .map(|msg| msg.contains("timed out"))
+                            .unwrap_or(false)
+                })
+                .collect();
+
+            assert!(!warnings.is_empty(), "Warning should have been logged");
+        }
+    }
+
+    #[tokio::test]
+    async fn test_reconcile_without_timeout_trailer() {
+        // Create a response body without timeout trailers
+        let body_data = Bytes::from_static(b"response body");
+
+        // Create a stream with just a data frame (no trailers)
+        let frames: Vec<Result<Frame<Bytes>, Infallible>> = vec![Ok(Frame::data(body_data))];
+        let stream_body: TestBody = StreamBody::new(futures::stream::iter(frames));
+
+        let response: Response<TestBody> = Response::builder()
+            .status(200)
+            .body(stream_body)
+            .unwrap();
+
+        // Test the reconcile_response function directly
+        let result = reconcile_response(response).await;
+
+        // Should complete successfully without any timeout warnings
+        assert!(result.is_ok(), "Future should complete successfully");
+    }
+}
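The enum above is the standard `pin_project` state-machine pattern, extended from two phases to three. As an illustrative aside (a minimal sketch, not part of this diff; all names are hypothetical and it assumes the `pin-project` crate), the same transition-in-place technique looks like this in miniature:

```rust
use pin_project::pin_project;
use std::{
    future::Future,
    pin::Pin,
    task::{ready, Context, Poll},
};

/// A two-phase version of the same idea: poll the first future, then use its
/// output to build and poll a second future, all behind a single `Future`.
#[pin_project(project = PhasesProj)]
enum Phases<A, F, B> {
    First(#[pin] A, Option<F>),
    Second(#[pin] B),
}

impl<A, F, B> Future for Phases<A, F, B>
where
    A: Future,
    F: FnOnce(A::Output) -> B,
    B: Future,
{
    type Output = B::Output;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        loop {
            match self.as_mut().project() {
                PhasesProj::First(fut, next) => {
                    // `ready!` propagates `Poll::Pending` until phase one finishes.
                    let out = ready!(fut.poll(cx));
                    let make_next = next.take().expect("First never polled after completion");
                    // Transition in place: drop phase one, pin phase two.
                    self.set(Phases::Second(make_next(out)));
                }
                PhasesProj::Second(fut) => return fut.poll(cx),
            }
        }
    }
}
```

The key move is `self.set(...)`: it drops the finished phase in place and pins the next one, which is why the enum needs no self-references even though each later phase is built from an earlier phase's result.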
diff --git a/lambda-runtime/src/runtime.rs b/lambda-runtime/src/runtime.rs
index e9a6bb27..587dcb70 100644
--- a/lambda-runtime/src/runtime.rs
+++ b/lambda-runtime/src/runtime.rs
@@ -919,7 +919,7 @@
 
         let storage = SharedStorage::default();
         let subscriber = tracing_subscriber::registry().with(CaptureLayer::new(&storage));
-        tracing::subscriber::set_global_default(subscriber).unwrap();
+        let _ = tracing::subscriber::set_global_default(subscriber); // Ignore error if already set
 
         let request_count = Arc::new(AtomicUsize::new(0));
         let done = Arc::new(tokio::sync::Notify::new());
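The hunk above swallows the `set_global_default` error because a process-wide subscriber can only be installed once, so a second test in the same process would panic on `unwrap()`. A scoped subscriber avoids the conflict entirely; a minimal sketch (hypothetical test, mirroring what the new `api_client.rs` tests in this diff already do):

```rust
use tracing_capture::{CaptureLayer, SharedStorage};
use tracing_subscriber::layer::SubscriberExt;

#[test]
fn capture_scoped_to_this_test() {
    let storage = SharedStorage::default();
    let subscriber = tracing_subscriber::registry().with(CaptureLayer::new(&storage));
    // The guard keeps the subscriber active only for this test and releases it
    // on drop, so parallel tests never contend for the global default.
    let _guard = tracing::subscriber::set_default(subscriber);

    tracing::warn!("captured only while the guard is alive");
    // ... assert against `storage` as the surrounding tests do ...
}
```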
"/2018-06-01/runtime/invocation/next" { + let count = request_count.fetch_add(1, Ordering::SeqCst); + // Pattern: error, success, error, success, error, success (per worker) + // With 2 workers, we get 12 total requests + if count < 12 { + let request_id = format!("request-{}", count + 1); + let res = Response::builder() + .status(StatusCode::OK) + .header("lambda-runtime-aws-request-id", &request_id) + .header("lambda-runtime-deadline-ms", "9999999999999") + .body(Full::new(Bytes::from_static(b"{}"))) + .unwrap(); + return Ok::<_, Infallible>(res); + } else { + // After 12 requests, return NO_CONTENT to keep workers alive but idle + let res = Response::builder() + .status(StatusCode::NO_CONTENT) + .body(Full::new(Bytes::new())) + .unwrap(); + return Ok::<_, Infallible>(res); + } + } + + if parts.method == Method::POST && (path.contains("/error") || path.contains("/response")) { + let res = Response::builder() + .status(StatusCode::OK) + .body(Full::new(Bytes::new())) + .unwrap(); + return Ok::<_, Infallible>(res); + } + + let res = Response::builder() + .status(StatusCode::NOT_FOUND) + .body(Full::new(Bytes::new())) + .unwrap(); + Ok::<_, Infallible>(res) + } + }); + + let io = TokioIo::new(tcp); + tokio::spawn(async move { + let _ = ServerBuilder::new(TokioExecutor::new()) + .serve_connection(io, service) + .await; + }); + } + }) + }; + + let handler_call_count = Arc::new(AtomicUsize::new(0)); + let handler_call_count_clone = handler_call_count.clone(); + + async fn error_then_success_handler( + event: crate::LambdaEvent, + call_count: Arc, + ) -> Result { + let count = call_count.fetch_add(1, Ordering::SeqCst); + let request_id = &event.context.request_id; + + // Alternate between errors and successes + if count % 2 == 0 { + // Even calls: return error (simulating timeout or other failure) + Err(Box::new(std::io::Error::new( + std::io::ErrorKind::TimedOut, + format!("Simulated timeout for {}", request_id), + ))) + } else { + // Odd calls: succeed + Ok(serde_json::json!({"status": "ok", "request_id": request_id})) + } + } + + let handler = crate::service_fn(move |event| { + error_then_success_handler(event, handler_call_count_clone.clone()) + }); + let client = Arc::new(Client::builder().with_endpoint(base).build()?); + + let runtime = Runtime { + client: client.clone(), + config: Arc::new(Config { + function_name: "test_fn".to_string(), + memory: 128, + version: "1".to_string(), + log_stream: "test_stream".to_string(), + log_group: "test_log".to_string(), + }), + service: wrap_handler(handler, client), + concurrency_limit: 2, + }; + + // Run for a limited time to allow handlers to process + let runtime_handle = tokio::spawn(async move { runtime.run_concurrent().await }); + + // Wait for all 12 requests to be processed (6 per worker) + tokio::time::sleep(Duration::from_secs(2)).await; + + runtime_handle.abort(); + server_handle.abort(); + + // Verify that workers continued processing after errors + let total_handler_calls = handler_call_count.load(Ordering::SeqCst); + assert_eq!( + total_handler_calls, 12, + "Expected 12 handler calls (6 per worker), got {}", + total_handler_calls ); - let mut seen_ids = HashSet::new(); - for event in &events { - let observed_id = event["observed_request_id"].as_str().unwrap(); + // Verify errors were reported to the API + let errors = error_calls.lock().unwrap(); + assert_eq!( + errors.len(), + 6, + "Expected 6 error reports (every other request), got {}", + errors.len() + ); - // Find the parent "Lambda runtime invoke" span and get its requestId - let 
diff --git a/test-lmi/.gitignore b/test-lmi/.gitignore
new file mode 100644
index 00000000..1de56593
--- /dev/null
+++ b/test-lmi/.gitignore
@@ -0,0 +1 @@
+target
\ No newline at end of file
diff --git a/test-lmi/Cargo.toml b/test-lmi/Cargo.toml
new file mode 100644
index 00000000..eef3e782
--- /dev/null
+++ b/test-lmi/Cargo.toml
@@ -0,0 +1,9 @@
+[package]
+name = "test-lmi"
+version = "0.1.0"
+edition = "2021"
+
+[dependencies]
+lambda_runtime = { path = "../lambda-runtime" }
+serde_json = "1"
+tokio = { version = "1", features = ["macros"] }
diff --git a/test-lmi/README.md b/test-lmi/README.md
new file mode 100644
index 00000000..0a110af7
--- /dev/null
+++ b/test-lmi/README.md
@@ -0,0 +1,51 @@
+# Introduction
+
+test-lmi is a Rust project that implements an AWS Lambda function. Its handler sleeps for 10 seconds before responding, which makes it useful for exercising timeout behavior locally.
+
+## Prerequisites
+
+- [Rust](https://www.rust-lang.org/tools/install)
+- [Cargo Lambda](https://www.cargo-lambda.info/guide/installation.html)
+
+## Building
+
+To build the project for production, run `cargo lambda build --release`. Remove the `--release` flag to build for development.
+
+Read more about building your lambda function in [the Cargo Lambda documentation](https://www.cargo-lambda.info/commands/build.html).
+
+## Testing
+
+You can run regular Rust unit tests with `cargo test`.
+
+If you want to run integration tests locally, you can use the `cargo lambda watch` and `cargo lambda invoke` commands.
+
+First, run `cargo lambda watch` to start a local server. When you make changes to the code, the server will automatically restart.
+
+Second, you'll need a way to pass the event data to the lambda function.
+
+You can use the existing [event payloads](https://github.com/awslabs/aws-lambda-rust-runtime/tree/main/lambda-events/src/fixtures) in the Rust Runtime repository if your lambda function is using one of the supported event types.
+
+You can use those examples directly with the `--data-example` flag, where the value is the name of the file in the [lambda-events](https://github.com/awslabs/aws-lambda-rust-runtime/tree/main/lambda-events/src/fixtures) repository without the `example_` prefix and the `.json` extension.
+
+```bash
+cargo lambda invoke --data-example apigw-request
+```
+
+For generic events, where you define the event data structure, you can create a JSON file with the data you want to test with. For example:
+
+```json
+{
+  "command": "test"
+}
+```
+
+Then, run `cargo lambda invoke --data-file ./data.json` to invoke the function with the data in `data.json`.
+
+Read more about running the local server in [the Cargo Lambda documentation for the `watch` command](https://www.cargo-lambda.info/commands/watch.html).
+Read more about invoking the function in [the Cargo Lambda documentation for the `invoke` command](https://www.cargo-lambda.info/commands/invoke.html).
+
+## Deploying
+
+To deploy the project, run `cargo lambda deploy`. This will create an IAM role and a Lambda function in your AWS account.
+
+Read more about deploying your lambda function in [the Cargo Lambda documentation](https://www.cargo-lambda.info/commands/deploy.html).
diff --git a/test-lmi/src/generic_handler.rs b/test-lmi/src/generic_handler.rs
new file mode 100644
index 00000000..a9e75688
--- /dev/null
+++ b/test-lmi/src/generic_handler.rs
@@ -0,0 +1,8 @@
+use lambda_runtime::{Error, LambdaEvent};
+use serde_json::Value;
+use std::time::Duration;
+
+pub(crate) async fn function_handler(_event: LambdaEvent<Value>) -> Result<Value, Error> {
+    tokio::time::sleep(Duration::from_secs(10)).await;
+    Ok(Value::Null)
+}
diff --git a/test-lmi/src/main.rs b/test-lmi/src/main.rs
new file mode 100644
index 00000000..943bbc99
--- /dev/null
+++ b/test-lmi/src/main.rs
@@ -0,0 +1,10 @@
+use lambda_runtime::{run, service_fn, tracing, Error};
+mod generic_handler;
+use generic_handler::function_handler;
+
+#[tokio::main]
+async fn main() -> Result<(), Error> {
+    tracing::init_default_subscriber();
+
+    run(service_fn(function_handler)).await
+}
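Since `function_handler` blocks for a fixed 10 seconds to force the timeout path, its behavior can be pinned down without real waiting. A hedged sketch of a unit test for it (hypothetical, not part of this diff; it assumes tokio's `test-util` feature is added to test-lmi's dependencies, which the Cargo.toml above does not enable):

```rust
#[cfg(test)]
mod tests {
    use crate::generic_handler::function_handler;
    use lambda_runtime::{Context, LambdaEvent};
    use serde_json::Value;
    use std::time::Duration;

    // `start_paused = true` (tokio's "test-util" feature) auto-advances the
    // clock, so the 10-second sleep resolves instantly in test time.
    #[tokio::test(start_paused = true)]
    async fn handler_holds_the_invocation_for_ten_seconds() {
        let event = LambdaEvent::new(Value::Null, Context::default());
        let started = tokio::time::Instant::now();

        let output = function_handler(event).await.expect("handler should succeed");

        assert_eq!(output, Value::Null);
        assert!(started.elapsed() >= Duration::from_secs(10));
    }
}
```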