diff --git a/src/windows/common/HttpHeaderEndDetector.h b/src/windows/common/HttpHeaderEndDetector.h new file mode 100644 index 000000000..1a5003102 --- /dev/null +++ b/src/windows/common/HttpHeaderEndDetector.h @@ -0,0 +1,66 @@ +/*++ + +Copyright (c) Microsoft. All rights reserved. + +Module Name: + + HttpHeaderEndDetector.h + +Abstract: + + This file contains a small state machine that detects the end-of-header + marker ("\r\n\r\n") in an HTTP message, one byte at a time. + +--*/ + +#pragma once + +namespace wsl::windows::common { + +// Detects the end-of-header marker ("\r\n\r\n") in an HTTP message +class HttpHeaderEndDetector +{ +public: + // Returns true once the full "\r\n\r\n" terminator has been consumed. + bool Consume(char byte) noexcept + { + switch (m_state) + { + case State::Start: + m_state = (byte == '\r') ? State::Cr : State::Start; + break; + case State::Cr: + m_state = (byte == '\n') ? State::CrLf : ((byte == '\r') ? State::Cr : State::Start); + break; + case State::CrLf: + m_state = (byte == '\r') ? State::CrLfCr : State::Start; + break; + case State::CrLfCr: + m_state = (byte == '\n') ? State::Done : ((byte == '\r') ? State::Cr : State::Start); + break; + case State::Done: + break; + } + + return m_state == State::Done; + } + + bool IsDone() const noexcept + { + return m_state == State::Done; + } + +private: + enum class State + { + Start, + Cr, + CrLf, + CrLfCr, + Done + }; + + State m_state = State::Start; +}; + +} // namespace wsl::windows::common diff --git a/src/windows/wslcsession/DockerHTTPClient.cpp b/src/windows/wslcsession/DockerHTTPClient.cpp index 9c855a6d6..1ce4728a6 100644 --- a/src/windows/wslcsession/DockerHTTPClient.cpp +++ b/src/windows/wslcsession/DockerHTTPClient.cpp @@ -663,16 +663,9 @@ void DockerHTTPClient::DockerHttpResponseHandle::OnRead(const gsl::span& C { // Otherwise keep parsing the HTTP response header. size_t i{}; - for (i = 0; i < Content.size() && LineFeeds < 2; i++) + for (i = 0; i < Content.size() && !HeaderEnd.IsDone(); i++) { - if (Content[i] == '\n') - { - LineFeeds++; - } - else if (Content[i] != '\r') - { - LineFeeds = 0; - } + HeaderEnd.Consume(Content[i]); } // Feed the parser up to the end of the header. @@ -805,7 +798,7 @@ std::pair DockerHTTPClient:: parser.eager(false); parser.skip(false); - size_t lineFeeds = 0; + HttpHeaderEndDetector headerEnd; // Consume the socket until the header end is reached while (!parser.is_header_done()) { @@ -822,16 +815,9 @@ std::pair DockerHTTPClient:: // Scan only the newly peeked bytes [Offset, Offset + bytesRead) size_t i = 0; - for (i = Offset; i < bytesRead + Offset && lineFeeds < 2; i++) + for (i = Offset; i < bytesRead + Offset && !headerEnd.IsDone(); i++) { - if (buffer[i] == '\n') - { - lineFeeds++; - } - else if (buffer[i] != '\r') - { - lineFeeds = 0; - } + headerEnd.Consume(buffer[i]); } WI_ASSERT(i >= Offset); @@ -847,7 +833,7 @@ std::pair DockerHTTPClient:: Offset += bytesRead; buffer.resize(Offset); - if (lineFeeds == 2) // Header is complete, feed it to the parser. + if (headerEnd.IsDone()) // Header is complete, feed it to the parser. { #ifdef WSLC_HTTP_DEBUG diff --git a/src/windows/wslcsession/DockerHTTPClient.h b/src/windows/wslcsession/DockerHTTPClient.h index fd24a490b..504c9c29a 100644 --- a/src/windows/wslcsession/DockerHTTPClient.h +++ b/src/windows/wslcsession/DockerHTTPClient.h @@ -20,6 +20,7 @@ Module Name: #include #include "relay.hpp" #include "docker_schema.h" +#include "HttpHeaderEndDetector.h" #define THROW_DOCKER_USER_ERROR_MSG(_Ex, _Msg, ...) \ if ((_Ex).HasErrorMessage()) \ @@ -199,7 +200,7 @@ class DockerHTTPClient std::function&)> OnResponse; std::function OnCompleted; boost::beast::http::response_parser Parser; - size_t LineFeeds = 0; + common::HttpHeaderEndDetector HeaderEnd; std::optional RemainingContentLength; std::optional ResponseParser; }; diff --git a/test/windows/WSLCTests.cpp b/test/windows/WSLCTests.cpp index 8baad20ab..0f00f536a 100644 --- a/test/windows/WSLCTests.cpp +++ b/test/windows/WSLCTests.cpp @@ -22,6 +22,7 @@ Module Name: #include "hcs.hpp" #include "ContainerNameGenerator.h" #include "wslc/e2e/WSLCE2EHelpers.h" +#include "HttpHeaderEndDetector.h" #include using namespace std::literals::chrono_literals; @@ -9220,6 +9221,39 @@ class WSLCTests VERIFY_IS_TRUE(payload == output); } + TEST_METHOD(HttpHeaderEndDetector) + { + // Returns the index of the byte of header end, or -1 if the header never ends. + const auto headerEndIndex = [](std::string_view input) { + wsl::windows::common::HttpHeaderEndDetector detector; + for (size_t i = 0; i < input.size(); i++) + { + if (detector.Consume(input[i])) + { + return static_cast(i); + } + } + + return -1; + }; + + VERIFY_ARE_EQUAL(3, headerEndIndex("\r\n\r\n")); + VERIFY_ARE_EQUAL(4, headerEndIndex("a\r\n\r\n")); + VERIFY_ARE_EQUAL(7, headerEndIndex("a\r\nb\r\n\r\n")); + VERIFY_ARE_EQUAL(4, headerEndIndex("\r\r\n\r\n")); + VERIFY_ARE_EQUAL(3, headerEndIndex("\r\n\r\nbody")); + + VERIFY_ARE_EQUAL(-1, headerEndIndex("")); + VERIFY_ARE_EQUAL(-1, headerEndIndex("Header: value\r\n")); + VERIFY_ARE_EQUAL(-1, headerEndIndex("HTTP/1.1 200 OK\r\n")); + VERIFY_ARE_EQUAL(-1, headerEndIndex("\r\n\r")); + + // Detection is strict. + VERIFY_ARE_EQUAL(-1, headerEndIndex("\n\n")); + VERIFY_ARE_EQUAL(-1, headerEndIndex("\r\n\n")); + VERIFY_ARE_EQUAL(-1, headerEndIndex("\n\r\n")); + } + WSLC_TEST_METHOD(ContainerRecoveryFromStorage) { auto restore = ResetTestSession(); // Required to access the storage folder.