diff --git a/.github/workflows/pull_request.yml b/.github/workflows/pull_request.yml index c3b2761d25..3159e08a4e 100644 --- a/.github/workflows/pull_request.yml +++ b/.github/workflows/pull_request.yml @@ -13,7 +13,7 @@ on: branches: [ "JDK17/Springboot3" ] jobs: - test: + JDK17-Test: runs-on: ubuntu-latest steps: - name: Checkout codes @@ -35,3 +35,19 @@ jobs: token: ${{ secrets.CODECOV_TOKEN }} flags: pull-request name: PR-Coverage + JDK21-Test: + runs-on: ubuntu-latest + steps: + - name: Checkout codes + uses: actions/checkout@v3 + with: + ref: ${{ github.event.pull_request.head.ref }} + repository: ${{ github.event.pull_request.head.repo.full_name }} + - name: Set up JDK 21 + uses: actions/setup-java@v3 + with: + java-version: '21' + distribution: 'temurin' + cache: maven + - name: Test with Maven + run: mvn clean install -B -U --file pom.xml diff --git a/trpc-core/src/main/java/com/tencent/trpc/core/common/config/ClientConfig.java b/trpc-core/src/main/java/com/tencent/trpc/core/common/config/ClientConfig.java index 2e5cc283fa..7cc1c0e8a5 100644 --- a/trpc-core/src/main/java/com/tencent/trpc/core/common/config/ClientConfig.java +++ b/trpc-core/src/main/java/com/tencent/trpc/core/common/config/ClientConfig.java @@ -1,7 +1,7 @@ /* * Tencent is pleased to support the open source community by making tRPC available. * - * Copyright (C) 2023 THL A29 Limited, a Tencent company. + * Copyright (C) 2023 THL A29 Limited, a Tencent company. * All rights reserved. * * If you have downloaded a copy of the tRPC source code from Tencent, @@ -32,7 +32,7 @@ public class ClientConfig extends BaseProtocolConfig { private static final Logger logger = LoggerFactory.getLogger(ClientConfig.class); - + @ConfigProperty protected String namespace; @@ -75,7 +75,7 @@ public class ClientConfig extends BaseProtocolConfig { /** * BackendConfig mapping. */ - protected Map backendConfigMap = Maps.newHashMap(); + protected Map backendConfigMap = Maps.newConcurrentMap(); /** * Whether the service is registered. diff --git a/trpc-core/src/test/java/com/tencent/trpc/core/common/config/ClientConfigTest.java b/trpc-core/src/test/java/com/tencent/trpc/core/common/config/ClientConfigTest.java index 8e51941cb3..c29327aa24 100644 --- a/trpc-core/src/test/java/com/tencent/trpc/core/common/config/ClientConfigTest.java +++ b/trpc-core/src/test/java/com/tencent/trpc/core/common/config/ClientConfigTest.java @@ -227,4 +227,54 @@ public void testClientConfigNotEmptryBackendConfig() { assertEquals(false, backendConfig.isIoThreadGroupShare()); assertEquals(1000, backendConfig.getIoThreads()); } -} + + @Test + public void testInitAndStop() { + ClientConfig config = new ClientConfig(); + config.init(); + assertTrue(config.isInitialized()); + config.init(); + assertTrue(config.isInitialized()); + config.stop(); + config.stop(); + } + + @Test + public void testGetBackendConfig() { + ClientConfig config = new ClientConfig(); + BackendConfig backendConfig = new BackendConfig(); + backendConfig.setName("svc"); + backendConfig.setNamingUrl("a://b"); + config.addBackendConfig(backendConfig); + assertEquals(backendConfig, config.getBackendConfig("svc")); + } + + @Test + public void testSetterAfterInitThrows() { + ClientConfig config = new ClientConfig(); + config.init(); + Exception ex = null; + try { + config.setNamespace("ns"); + } catch (Exception e) { + ex = e; + } + assertTrue(ex instanceof IllegalArgumentException); + } + + @Test + public void testSetBackendConfigMap() { + ClientConfig config = new ClientConfig(); + java.util.Map map = new java.util.HashMap<>(); + config.setBackendConfigMap(map); + assertEquals(map, config.getBackendConfigMap()); + } + + @Test + public void testSetDefaultIdempotent() { + ClientConfig config = new ClientConfig(); + config.setDefault(); + config.setDefault(); + assertTrue(config.isSetDefault()); + } +} \ No newline at end of file diff --git a/trpc-proto/trpc-proto-http/src/main/java/com/tencent/trpc/proto/http/client/HttpConsumerInvoker.java b/trpc-proto/trpc-proto-http/src/main/java/com/tencent/trpc/proto/http/client/HttpConsumerInvoker.java index 8d1468f854..91ec9463ef 100644 --- a/trpc-proto/trpc-proto-http/src/main/java/com/tencent/trpc/proto/http/client/HttpConsumerInvoker.java +++ b/trpc-proto/trpc-proto-http/src/main/java/com/tencent/trpc/proto/http/client/HttpConsumerInvoker.java @@ -93,10 +93,8 @@ private Response handleResponse(Request request, CloseableHttpResponse httpRespo Map respAttachments = new HashMap<>(); for (Header header : httpResponse.getAllHeaders()) { String name = header.getName(); - for (HeaderElement element : header.getElements()) { - String value = element.getName(); - respAttachments.put(name, value.getBytes(StandardCharsets.UTF_8)); - } + String value = header.getValue(); + respAttachments.put(name, value.getBytes(StandardCharsets.UTF_8)); } Header contentLengthHdr = httpResponse.getFirstHeader(HttpHeaders.CONTENT_LENGTH); diff --git a/trpc-proto/trpc-proto-http/src/main/java/com/tencent/trpc/proto/http/server/AbstractHttpExecutor.java b/trpc-proto/trpc-proto-http/src/main/java/com/tencent/trpc/proto/http/server/AbstractHttpExecutor.java index e12345b9ca..ab09329119 100644 --- a/trpc-proto/trpc-proto-http/src/main/java/com/tencent/trpc/proto/http/server/AbstractHttpExecutor.java +++ b/trpc-proto/trpc-proto-http/src/main/java/com/tencent/trpc/proto/http/server/AbstractHttpExecutor.java @@ -1,7 +1,7 @@ /* * Tencent is pleased to support the open source community by making tRPC available. * - * Copyright (C) 2023 THL A29 Limited, a Tencent company. + * Copyright (C) 2023 Tencent. * All rights reserved. * * If you have downloaded a copy of the tRPC source code from Tencent, @@ -19,11 +19,11 @@ import com.tencent.trpc.core.exception.TRpcException; import com.tencent.trpc.core.logger.Logger; import com.tencent.trpc.core.logger.LoggerFactory; +import com.tencent.trpc.core.rpc.RpcContext; import com.tencent.trpc.core.rpc.CallInfo; import com.tencent.trpc.core.rpc.ProviderInvoker; import com.tencent.trpc.core.rpc.RequestMeta; import com.tencent.trpc.core.rpc.Response; -import com.tencent.trpc.core.rpc.RpcContext; import com.tencent.trpc.core.rpc.RpcInvocation; import com.tencent.trpc.core.rpc.RpcServerContext; import com.tencent.trpc.core.rpc.common.RpcMethodInfo; @@ -38,17 +38,19 @@ import com.tencent.trpc.proto.http.common.RpcServerContextWithHttp; import com.tencent.trpc.proto.http.common.TrpcServletRequestWrapper; import com.tencent.trpc.proto.http.common.TrpcServletResponseWrapper; -import jakarta.servlet.http.HttpServletRequest; -import jakarta.servlet.http.HttpServletResponse; import java.lang.reflect.ParameterizedType; import java.lang.reflect.Type; import java.net.InetSocketAddress; import java.nio.charset.StandardCharsets; import java.util.Enumeration; import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.CompletionStage; -import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; import org.apache.commons.lang3.StringUtils; import org.apache.http.HttpStatus; @@ -66,15 +68,15 @@ public abstract class AbstractHttpExecutor { protected void execute(HttpServletRequest request, HttpServletResponse response, RpcMethodInfoAndInvoker methodInfoAndInvoker) { - + AtomicBoolean responded = new AtomicBoolean(false); try { DefRequest rpcRequest = buildDefRequest(request, response, methodInfoAndInvoker); - CountDownLatch countDownLatch = new CountDownLatch(1); + CompletableFuture completionFuture = new CompletableFuture<>(); // use a thread pool for asynchronous processing - invokeRpcRequest(methodInfoAndInvoker.getInvoker(), rpcRequest, countDownLatch); + invokeRpcRequest(methodInfoAndInvoker.getInvoker(), rpcRequest, completionFuture, responded); // If the request carries a timeout, use this timeout to wait for the request to be processed. // If not carried, use the default timeout. @@ -82,18 +84,25 @@ protected void execute(HttpServletRequest request, HttpServletResponse response, if (requestTimeout <= 0) { requestTimeout = methodInfoAndInvoker.getInvoker().getConfig().getRequestTimeout(); } - if (requestTimeout > 0 && !countDownLatch.await(requestTimeout, TimeUnit.MILLISECONDS)) { - throw TRpcException.newFrameException(ErrorCode.TRPC_SERVER_TIMEOUT_ERR, - "wait http request execute timeout"); + if (requestTimeout > 0) { + try { + completionFuture.get(requestTimeout, TimeUnit.MILLISECONDS); + } catch (TimeoutException ex) { + if (responded.compareAndSet(false, true)) { + doErrorReply(request, response, + TRpcException.newFrameException(ErrorCode.TRPC_SERVER_TIMEOUT_ERR, + "wait http request execute timeout")); + } + } } else { - countDownLatch.await(); + completionFuture.get(); } - } catch (Exception ex) { logger.error("dispatch request [{}] error", request, ex); - doErrorReply(request, response, ex); + if (responded.compareAndSet(false, true)) { + doErrorReply(request, response, ex); + } } - } /** @@ -108,55 +117,83 @@ protected void execute(HttpServletRequest request, HttpServletResponse response, /** * Request processing * - * @param countDownLatch latch used to wait for the request processing + * @param invoker the invoker + * @param rpcRequest the rpc request + * @param completionFuture the completion future + * @param responded the responded flag */ - private void invokeRpcRequest(ProviderInvoker invoker, DefRequest rpcRequest, CountDownLatch countDownLatch) { + private void invokeRpcRequest(ProviderInvoker invoker, DefRequest rpcRequest, + CompletableFuture completionFuture, + AtomicBoolean responded) { WorkerPool workerPool = invoker.getConfig().getWorkerPoolObj(); if (null == workerPool) { logger.error("dispatch rpcRequest [{}] error, workerPool is empty", rpcRequest); - throw TRpcException.newFrameException(ErrorCode.TRPC_SERVER_NOSERVICE_ERR, - "not found service, workerPool is empty"); + completionFuture.completeExceptionally(TRpcException.newFrameException(ErrorCode.TRPC_SERVER_NOSERVICE_ERR, + "not found service, workerPool is empty")); + return; } workerPool.execute(() -> { - - // Get the original http response - HttpServletResponse response = getOriginalResponse(rpcRequest); - - // Invoke the routing implementation method to handle the request. - CompletionStage future = invoker.invoke(rpcRequest); - future.whenComplete((result, t) -> { - try { - // Throw the call exception, which will be handled uniformly by the exception handling program. - if (t != null) { - throw t; - } - - // Throw a business logic exception, which will be handled uniformly - // by the exception handling program. - Throwable ex = result.getException(); - if (ex != null) { - throw ex; + try { + // Get the original http response + HttpServletResponse response = getOriginalResponse(rpcRequest); + // Invoke the routing implementation method to handle the request. + CompletionStage rpcFuture = invoker.invoke(rpcRequest); + + rpcFuture.whenComplete((result, throwable) -> { + try { + if (responded.get()) { + return; + } + + // Throw the call exception, which will be handled uniformly by the exception handling program. + if (throwable != null) { + throw throwable; + } + + // Throw a business logic exception, which will be handled uniformly + // by the exception handling program. + if (result.getException() != null) { + throw result.getException(); + } + + // normal response + if (responded.compareAndSet(false, true)) { + response.setStatus(HttpStatus.SC_OK); + httpCodec.writeHttpResponse(response, result); + response.flushBuffer(); + } + + completionFuture.complete(null); + } catch (Throwable t) { + handleError(t, rpcRequest, response, responded, completionFuture); } + }); - // normal response - response.setStatus(HttpStatus.SC_OK); - httpCodec.writeHttpResponse(response, result); - response.flushBuffer(); - } catch (Throwable e) { - HttpServletRequest request = getOriginalRequest(rpcRequest); - logger.warn("reply message error, channel: [{}], msg:[{}]", request.getRemoteAddr(), request, e); - httpErrorReply(request, response, - ErrorResponse.create(request, HttpStatus.SC_SERVICE_UNAVAILABLE, e)); - } finally { - countDownLatch.countDown(); - } - }); + } catch (Exception e) { + handleError(e, rpcRequest, getOriginalResponse(rpcRequest), responded, completionFuture); + } }); } + /** + * Handle error + */ + private void handleError(Throwable t, DefRequest rpcRequest, HttpServletResponse response, + AtomicBoolean responded, CompletableFuture completionFuture) { + try { + if (responded.compareAndSet(false, true)) { + HttpServletRequest request = getOriginalRequest(rpcRequest); + logger.warn("reply message error, channel: [{}], msg:[{}]", request.getRemoteAddr(), request, t); + httpErrorReply(request, response, ErrorResponse.create(request, HttpStatus.SC_SERVICE_UNAVAILABLE, t)); + } + } finally { + completionFuture.completeExceptionally(t); + } + } + /** * Build the context request. * @@ -392,7 +429,6 @@ private void setRpcServerContext(HttpServletRequest request, HttpServletResponse // to maintain consistency. rpcRequest.getAttachments().put(header, value.getBytes(StandardCharsets.UTF_8)); } - logger.debug("request attachment: {}", JsonUtils.toJson(rpcRequest.getAttachments())); } /** @@ -488,4 +524,4 @@ private String getString(String[] callInfos, int length, int cursor) { return callInfos.length < length ? StringUtils.EMPTY : callInfos[cursor]; } -} +} \ No newline at end of file diff --git a/trpc-proto/trpc-proto-http/src/test/java/com/tencent/trpc/proto/http/client/HttpConsumerInvokerTest.java b/trpc-proto/trpc-proto-http/src/test/java/com/tencent/trpc/proto/http/client/HttpConsumerInvokerTest.java new file mode 100644 index 0000000000..cff454a950 --- /dev/null +++ b/trpc-proto/trpc-proto-http/src/test/java/com/tencent/trpc/proto/http/client/HttpConsumerInvokerTest.java @@ -0,0 +1,404 @@ +/* + * Tencent is pleased to support the open source community by making tRPC available. + * + * Copyright (C) 2023 Tencent. + * All rights reserved. + * + * If you have downloaded a copy of the tRPC source code from Tencent, + * please note that tRPC source code is licensed under the Apache 2.0 License, + * A copy of the Apache 2.0 License can be found in the LICENSE file. + */ + +package com.tencent.trpc.proto.http.client; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.tencent.trpc.core.common.ConfigManager; +import com.tencent.trpc.core.common.config.BackendConfig; +import com.tencent.trpc.core.common.config.ConsumerConfig; +import com.tencent.trpc.core.common.config.ProtocolConfig; +import com.tencent.trpc.core.rpc.CallInfo; +import com.tencent.trpc.core.rpc.Request; +import com.tencent.trpc.core.rpc.RequestMeta; +import com.tencent.trpc.core.rpc.Response; +import com.tencent.trpc.core.rpc.RpcInvocation; +import com.tencent.trpc.core.rpc.common.RpcMethodInfo; +import com.tencent.trpc.core.worker.spi.WorkerPool; +import java.io.ByteArrayInputStream; +import java.lang.reflect.Method; +import java.nio.charset.StandardCharsets; +import java.util.HashMap; +import org.apache.http.Header; +import org.apache.http.HttpHeaders; +import org.apache.http.HttpStatus; +import org.apache.http.HttpVersion; +import org.apache.http.StatusLine; +import org.apache.http.client.methods.CloseableHttpResponse; +import org.apache.http.entity.BasicHttpEntity; +import org.apache.http.message.BasicHeader; +import org.apache.http.message.BasicStatusLine; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +/** + * Tests the response header parsing logic in {@link HttpConsumerInvoker}. + * Verifies that {@code header.getValue()} returns the complete header value instead of + * the truncated result previously produced by {@code HeaderElement.getName()}. + */ +public class HttpConsumerInvokerTest { + + private HttpRpcClient mockHttpRpcClient; + private ConsumerConfig mockConsumerConfig; + private ProtocolConfig mockProtocolConfig; + private BackendConfig mockBackendConfig; + private WorkerPool mockWorkerPool; + + private HttpConsumerInvoker invoker; + + @BeforeEach + public void setUp() { + ConfigManager.stopTest(); + ConfigManager.startTest(); + + mockHttpRpcClient = mock(HttpRpcClient.class); + mockConsumerConfig = mock(ConsumerConfig.class); + mockProtocolConfig = mock(ProtocolConfig.class); + mockBackendConfig = mock(BackendConfig.class); + mockWorkerPool = mock(WorkerPool.class); + + when(mockConsumerConfig.getBackendConfig()).thenReturn(mockBackendConfig); + when(mockBackendConfig.getWorkerPoolObj()).thenReturn(mockWorkerPool); + when(mockProtocolConfig.getIp()).thenReturn("127.0.0.1"); + when(mockProtocolConfig.getPort()).thenReturn(8080); + when(mockProtocolConfig.getExtMap()).thenReturn(new HashMap<>()); + + invoker = new HttpConsumerInvoker<>(mockHttpRpcClient, mockConsumerConfig, mockProtocolConfig); + } + + @AfterEach + public void tearDown() { + ConfigManager.stopTest(); + } + + /** + * Verifies that a simple header value (no delimiters) is parsed correctly. + * e.g. X-Custom-Header: simple-value + */ + @Test + public void testSimpleHeaderValueParsedCorrectly() throws Exception { + CloseableHttpResponse mockResponse = buildMockResponse( + HttpStatus.SC_OK, + new Header[]{ + new BasicHeader("X-Custom-Header", "simple-value"), + new BasicHeader(HttpHeaders.CONTENT_LENGTH, "0") + }, + null + ); + + Request mockRequest = buildMockRequest(); + + Response response = invokeHandleResponse(mockRequest, mockResponse); + + assertNotNull(response); + assertNotNull(response.getAttachments()); + byte[] headerValue = (byte[]) response.getAttachments().get("X-Custom-Header"); + assertNotNull(headerValue); + assertEquals("simple-value", new String(headerValue, StandardCharsets.UTF_8)); + } + + /** + * Verifies that a composite header value containing semicolons is parsed completely + * (this is the core fix scenario). The old {@code HeaderElement.getName()} only returned + * the token before the first {@code ;} (e.g. {@code "application/json"}), whereas + * {@code header.getValue()} returns the full value + * (e.g. {@code "application/json; charset=utf-8"}). + */ + @Test + public void testComplexHeaderWithSemicolonParsedCompletely() throws Exception { + String fullContentType = "application/json; charset=utf-8"; + CloseableHttpResponse mockResponse = buildMockResponse( + HttpStatus.SC_OK, + new Header[]{ + new BasicHeader(HttpHeaders.CONTENT_TYPE, fullContentType), + new BasicHeader(HttpHeaders.CONTENT_LENGTH, "0") + }, + null + ); + + Request mockRequest = buildMockRequest(); + + Response response = invokeHandleResponse(mockRequest, mockResponse); + + assertNotNull(response); + byte[] contentTypeValue = (byte[]) response.getAttachments().get(HttpHeaders.CONTENT_TYPE); + assertNotNull(contentTypeValue); + // After the fix, the full value should be returned, not just "application/json" + assertEquals(fullContentType, new String(contentTypeValue, StandardCharsets.UTF_8)); + } + + /** + * Verifies that a header value containing an equals sign is parsed completely. + * The old {@code HeaderElement.getName()} also truncated values with {@code =}, + * whereas {@code header.getValue()} returns the full value. + * e.g. X-Token: key=abc123 + */ + @Test + public void testHeaderWithEqualSignParsedCompletely() throws Exception { + String tokenValue = "key=abc123"; + CloseableHttpResponse mockResponse = buildMockResponse( + HttpStatus.SC_OK, + new Header[]{ + new BasicHeader("X-Token", tokenValue), + new BasicHeader(HttpHeaders.CONTENT_LENGTH, "0") + }, + null + ); + + Request mockRequest = buildMockRequest(); + + Response response = invokeHandleResponse(mockRequest, mockResponse); + + assertNotNull(response); + byte[] tokenBytes = (byte[]) response.getAttachments().get("X-Token"); + assertNotNull(tokenBytes); + assertEquals(tokenValue, new String(tokenBytes, StandardCharsets.UTF_8)); + } + + /** + * Verifies that multiple response headers are all parsed correctly and stored in attachments. + */ + @Test + public void testMultipleHeadersAllParsedCorrectly() throws Exception { + CloseableHttpResponse mockResponse = buildMockResponse( + HttpStatus.SC_OK, + new Header[]{ + new BasicHeader("X-Trace-Id", "trace-abc-123"), + new BasicHeader("X-Caller", "service-a"), + new BasicHeader("X-Callee", "service-b"), + new BasicHeader(HttpHeaders.CONTENT_LENGTH, "0") + }, + null + ); + + Request mockRequest = buildMockRequest(); + + Response response = invokeHandleResponse(mockRequest, mockResponse); + + assertNotNull(response); + assertEquals("trace-abc-123", + new String((byte[]) response.getAttachments().get("X-Trace-Id"), StandardCharsets.UTF_8)); + assertEquals("service-a", + new String((byte[]) response.getAttachments().get("X-Caller"), StandardCharsets.UTF_8)); + assertEquals("service-b", + new String((byte[]) response.getAttachments().get("X-Callee"), StandardCharsets.UTF_8)); + } + + /** + * Verifies that header values are stored as {@code byte[]} to maintain consistency + * with the tRPC protocol. + */ + @Test + public void testHeaderValueStoredAsByteArray() throws Exception { + String expectedValue = "test-value"; + CloseableHttpResponse mockResponse = buildMockResponse( + HttpStatus.SC_OK, + new Header[]{ + new BasicHeader("X-Test", expectedValue), + new BasicHeader(HttpHeaders.CONTENT_LENGTH, "0") + }, + null + ); + + Request mockRequest = buildMockRequest(); + + Response response = invokeHandleResponse(mockRequest, mockResponse); + + assertNotNull(response); + Object storedValue = response.getAttachments().get("X-Test"); + // Verify that the stored type is byte[] + assertNotNull(storedValue); + assertEquals(byte[].class, storedValue.getClass()); + assertArrayEquals(expectedValue.getBytes(StandardCharsets.UTF_8), (byte[]) storedValue); + } + + /** + * Verifies that a non-200 HTTP status code causes a {@link com.tencent.trpc.core.exception.TRpcException} + * to be thrown. + */ + @Test + public void testNon200StatusCodeThrowsException() throws Exception { + CloseableHttpResponse mockResponse = buildMockResponse( + HttpStatus.SC_NOT_FOUND, + new Header[]{}, + null + ); + + Request mockRequest = buildMockRequest(); + assertThrows(com.tencent.trpc.core.exception.TRpcException.class, + () -> invokeHandleResponse(mockRequest, mockResponse)); + } + + /** + * Verifies that a {@code Content-Length: 0} response returns an empty response body. + */ + @Test + public void testZeroContentLengthReturnsEmptyResponse() throws Exception { + CloseableHttpResponse mockResponse = buildMockResponse( + HttpStatus.SC_OK, + new Header[]{ + new BasicHeader(HttpHeaders.CONTENT_LENGTH, "0") + }, + null + ); + + Request mockRequest = buildMockRequest(); + + Response response = invokeHandleResponse(mockRequest, mockResponse); + + assertNotNull(response); + assertNull(response.getValue()); + } + + /** + * Verifies that a complex header value with multiple semicolons and equals signs + * is parsed completely. + * e.g. Set-Cookie: sessionId=abc; Path=/; HttpOnly + */ + @Test + public void testComplexCookieHeaderParsedCompletely() throws Exception { + String cookieValue = "sessionId=abc; Path=/; HttpOnly"; + CloseableHttpResponse mockResponse = buildMockResponse( + HttpStatus.SC_OK, + new Header[]{ + new BasicHeader("Set-Cookie", cookieValue), + new BasicHeader(HttpHeaders.CONTENT_LENGTH, "0") + }, + null + ); + + Request mockRequest = buildMockRequest(); + + Response response = invokeHandleResponse(mockRequest, mockResponse); + + assertNotNull(response); + byte[] cookieBytes = (byte[]) response.getAttachments().get("Set-Cookie"); + assertNotNull(cookieBytes); + // After the fix, the full cookie value should be returned, not just "sessionId" + assertEquals(cookieValue, new String(cookieBytes, StandardCharsets.UTF_8)); + } + + /** + * Verifies that the response body is decoded correctly when Content-Length is non-zero. + */ + @Test + public void testResponseBodyParsedWhenContentLengthNonZero() throws Exception { + String jsonBody = "\"hello\""; + CloseableHttpResponse mockResponse = buildMockResponse( + HttpStatus.SC_OK, + new Header[]{ + new BasicHeader("X-Custom", "custom-value"), + new BasicHeader(HttpHeaders.CONTENT_LENGTH, String.valueOf(jsonBody.length())) + }, + jsonBody + ); + + Request mockRequest = buildMockRequest(); + + Response response = invokeHandleResponse(mockRequest, mockResponse); + + assertNotNull(response); + // Verify that response headers are also parsed correctly + byte[] customValue = (byte[]) response.getAttachments().get("X-Custom"); + assertNotNull(customValue); + assertEquals("custom-value", new String(customValue, StandardCharsets.UTF_8)); + // Verify that the response body is decoded correctly + assertEquals("hello", response.getValue()); + } + + // ==================== Helper methods ==================== + + /** + * Builds a mock HTTP response with the given status code, headers, and optional body. + */ + private CloseableHttpResponse buildMockResponse(int statusCode, Header[] headers, String body) { + CloseableHttpResponse mockResponse = mock(CloseableHttpResponse.class); + StatusLine statusLine = new BasicStatusLine(HttpVersion.HTTP_1_1, statusCode, + statusCode == HttpStatus.SC_OK ? "OK" : "Not Found"); + when(mockResponse.getStatusLine()).thenReturn(statusLine); + when(mockResponse.getAllHeaders()).thenReturn(headers); + + // Wire up Content-Length header lookup + for (Header header : headers) { + if (HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(header.getName())) { + when(mockResponse.getFirstHeader(HttpHeaders.CONTENT_LENGTH)).thenReturn(header); + } + } + + if (body != null) { + BasicHttpEntity entity = new BasicHttpEntity(); + entity.setContent(new ByteArrayInputStream(body.getBytes(StandardCharsets.UTF_8))); + when(mockResponse.getEntity()).thenReturn(entity); + } + + return mockResponse; + } + + /** + * Builds a mock {@link Request} backed by a real {@link RpcMethodInfo} whose return type is + * {@link String}. + */ + private Request buildMockRequest() throws Exception { + Request mockRequest = mock(Request.class); + RequestMeta mockMeta = mock(RequestMeta.class); + CallInfo mockCallInfo = mock(CallInfo.class); + + // Build real RpcInvocation and RpcMethodInfo instances + RpcInvocation invocation = new RpcInvocation(); + Method method = TestService.class.getMethod("testMethod", String.class); + RpcMethodInfo methodInfo = new RpcMethodInfo(TestService.class, method); + invocation.setRpcMethodInfo(methodInfo); + + when(mockRequest.getInvocation()).thenReturn(invocation); + when(mockRequest.getMeta()).thenReturn(mockMeta); + when(mockMeta.getCallInfo()).thenReturn(mockCallInfo); + when(mockCallInfo.getCaller()).thenReturn("test-caller"); + when(mockCallInfo.getCallee()).thenReturn("test-callee"); + when(mockRequest.getAttachments()).thenReturn(new HashMap<>()); + + return mockRequest; + } + + /** + * Invokes the private {@code handleResponse} method via reflection. + */ + private Response invokeHandleResponse(Request request, CloseableHttpResponse httpResponse) + throws Exception { + Method handleResponseMethod = HttpConsumerInvoker.class + .getDeclaredMethod("handleResponse", Request.class, CloseableHttpResponse.class); + handleResponseMethod.setAccessible(true); + try { + return (Response) handleResponseMethod.invoke(invoker, request, httpResponse); + } catch (java.lang.reflect.InvocationTargetException e) { + // Unwrap and rethrow so that @Test(expected=...) can catch the original exception + if (e.getCause() instanceof Exception) { + throw (Exception) e.getCause(); + } + throw e; + } + } + + /** + * Stub service interface used only for constructing {@link RpcMethodInfo} in tests. + */ + private interface TestService { + + String testMethod(String input); + } +} \ No newline at end of file diff --git a/trpc-proto/trpc-proto-http/src/test/java/com/tencent/trpc/proto/http/server/AbstractHttpExecutorTest.java b/trpc-proto/trpc-proto-http/src/test/java/com/tencent/trpc/proto/http/server/AbstractHttpExecutorTest.java index a46044d30c..464b8c4041 100644 --- a/trpc-proto/trpc-proto-http/src/test/java/com/tencent/trpc/proto/http/server/AbstractHttpExecutorTest.java +++ b/trpc-proto/trpc-proto-http/src/test/java/com/tencent/trpc/proto/http/server/AbstractHttpExecutorTest.java @@ -11,8 +11,49 @@ package com.tencent.trpc.proto.http.server; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import com.google.protobuf.Message; +import com.tencent.trpc.core.common.ConfigManager; +import com.tencent.trpc.core.common.config.ProviderConfig; +import com.tencent.trpc.core.exception.ErrorCode; +import com.tencent.trpc.core.exception.TRpcException; +import com.tencent.trpc.core.rpc.ProviderInvoker; +import com.tencent.trpc.core.rpc.RpcContext; +import com.tencent.trpc.core.rpc.RpcInvocation; +import com.tencent.trpc.core.rpc.Response; +import com.tencent.trpc.core.rpc.common.RpcMethodInfo; +import com.tencent.trpc.core.rpc.common.RpcMethodInfoAndInvoker; +import com.tencent.trpc.core.rpc.def.DefRequest; +import com.tencent.trpc.core.rpc.def.DefResponse; +import com.tencent.trpc.core.worker.spi.WorkerPool; +import com.tencent.trpc.core.worker.spi.WorkerPool.Task; +import com.tencent.trpc.proto.http.common.HttpCodec; +import com.tencent.trpc.proto.http.common.HttpConstants; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import java.io.IOException; +import java.lang.reflect.Field; +import java.lang.reflect.Method; +import java.lang.reflect.ParameterizedType; +import java.lang.reflect.Type; +import java.util.Collections; +import java.util.Enumeration; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicBoolean; +import org.apache.http.HttpStatus; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.junit.jupiter.MockitoExtension; @@ -23,9 +64,720 @@ @MockitoSettings(strictness = Strictness.LENIENT) public class AbstractHttpExecutorTest { + private static final String TEST_SERVICE = "trpc.demo.server"; + private static final String TEST_METHOD = "hello"; + private static final String TEST_IP = "127.0.0.1"; + private static final int TEST_PORT = 8080; + + @BeforeEach + public void setUp() { + ConfigManager.stopTest(); + ConfigManager.startTest(); + } + + @AfterEach + public void tearDown() { + ConfigManager.stopTest(); + } + + private HttpServletRequest mockRequest() { + HttpServletRequest request = mock(HttpServletRequest.class); + when(request.getAttribute(HttpConstants.REQUEST_ATTRIBUTE_TRPC_SERVICE)).thenReturn(TEST_SERVICE); + when(request.getAttribute(HttpConstants.REQUEST_ATTRIBUTE_TRPC_METHOD)).thenReturn(TEST_METHOD); + when(request.getRemoteAddr()).thenReturn(TEST_IP); + when(request.getRemotePort()).thenReturn(TEST_PORT); + when(request.getMethod()).thenReturn("POST"); + when(request.getRequestURI()).thenReturn("/api/test"); + when(request.getHeaderNames()).thenReturn(Collections.emptyEnumeration()); + return request; + } + + private WorkerPool mockSyncWorkerPool() { + WorkerPool workerPool = mock(WorkerPool.class); + doAnswer(invocation -> { + Object arg = invocation.getArguments()[0]; + if (arg instanceof Runnable) { + ((Runnable) arg).run(); + } else if (arg instanceof Task) { + ((Task) arg).run(); + } + return null; + }).when(workerPool).execute(any()); + return workerPool; + } + + private ProviderConfig mockProviderConfig(int timeout) { + ProviderConfig config = mock(ProviderConfig.class); + when(config.getRequestTimeout()).thenReturn(timeout); + WorkerPool workerPool = mockSyncWorkerPool(); + when(config.getWorkerPoolObj()).thenReturn(workerPool); + return config; + } + + private DefRequest mockDefRequest(HttpServletRequest request, HttpServletResponse response) { + DefRequest defRequest = new DefRequest(); + defRequest.getAttachments().put(HttpConstants.TRPC_ATTACH_SERVLET_RESPONSE, response); + defRequest.getAttachments().put(HttpConstants.TRPC_ATTACH_SERVLET_REQUEST, request); + return defRequest; + } + + private AbstractHttpExecutor createExecutorWithCodec() { + AbstractHttpExecutor executor = new AbstractHttpExecutor() { + @Override + protected RpcMethodInfoAndInvoker getRpcMethodInfoAndInvoker(Object object) { + return null; + } + }; + HttpCodec httpCodec = mock(HttpCodec.class); + setField(executor, "httpCodec", httpCodec); + return executor; + } + + private AbstractHttpExecutor createExecutorWithInvoker(RpcMethodInfoAndInvoker methodInfoAndInvoker) { + AbstractHttpExecutor executor = new AbstractHttpExecutor() { + @Override + protected RpcMethodInfoAndInvoker getRpcMethodInfoAndInvoker(Object object) { + return methodInfoAndInvoker; + } + }; + HttpCodec httpCodec = mock(HttpCodec.class); + setField(executor, "httpCodec", httpCodec); + return executor; + } + + private void setField(Object target, String fieldName, Object value) { + try { + Field field = AbstractHttpExecutor.class.getDeclaredField(fieldName); + field.setAccessible(true); + field.set(target, value); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + private Object getField(Object target, String fieldName) { + try { + Field field = AbstractHttpExecutor.class.getDeclaredField(fieldName); + field.setAccessible(true); + return field.get(target); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + private Object invokePrivate(Object target, String methodName, Class[] paramTypes, Object... args) + throws Exception { + Method method = AbstractHttpExecutor.class.getDeclaredMethod(methodName, paramTypes); + method.setAccessible(true); + try { + return method.invoke(target, args); + } catch (java.lang.reflect.InvocationTargetException e) { + if (e.getCause() instanceof Exception) { + throw (Exception) e.getCause(); + } + throw e; + } + } + + private RpcMethodInfoAndInvoker buildMethodInfoAndInvoker(ProviderInvoker invoker) throws Exception { + Method method = TestService.class.getMethod("hello", RpcContext.class, String.class); + RpcMethodInfo methodInfo = new RpcMethodInfo(TestService.class, method); + RpcMethodInfoAndInvoker methodInfoAndInvoker = new RpcMethodInfoAndInvoker(); + methodInfoAndInvoker.setMethodInfo(methodInfo); + methodInfoAndInvoker.setInvoker(invoker); + return methodInfoAndInvoker; + } + + // ==================== buildRpcInvocation ==================== + + @Test + public void testBuildRpcInvocation() throws Exception { + HttpServletRequest request = mock(HttpServletRequest.class); + when(request.getAttribute(HttpConstants.REQUEST_ATTRIBUTE_TRPC_SERVICE)).thenReturn(TEST_SERVICE); + when(request.getAttribute(HttpConstants.REQUEST_ATTRIBUTE_TRPC_METHOD)).thenReturn(TEST_METHOD); + + RpcMethodInfo methodInfo = mock(RpcMethodInfo.class); + when(methodInfo.getParamsTypes()).thenReturn(new Type[]{RpcContext.class, String.class}); + + AbstractHttpExecutor executor = createExecutorWithCodec(); + HttpCodec httpCodec = (HttpCodec) getField(executor, "httpCodec"); + when(httpCodec.convertToJavaBean(any(), any())).thenReturn("param"); + + RpcInvocation invocation = (RpcInvocation) invokePrivate(executor, "buildRpcInvocation", + new Class[]{HttpServletRequest.class, RpcMethodInfo.class}, request, methodInfo); + + assertEquals("/" + TEST_SERVICE + "/" + TEST_METHOD, invocation.getFunc()); + } + + // ==================== parseRpcParams ==================== + + @Test + public void testParseRpcParamsUnsupported() throws Exception { + HttpServletRequest request = mock(HttpServletRequest.class); + RpcMethodInfo methodInfo = mock(RpcMethodInfo.class); + when(methodInfo.getParamsTypes()).thenReturn(new Type[]{String.class}); + + AbstractHttpExecutor executor = createExecutorWithCodec(); + try { + invokePrivate(executor, "parseRpcParams", + new Class[]{HttpServletRequest.class, RpcMethodInfo.class}, request, methodInfo); + } catch (UnsupportedOperationException e) { + assertNotNull(e); + } + } + + @Test + public void testParseRpcParamsMap() throws Exception { + HttpServletRequest request = mock(HttpServletRequest.class); + RpcMethodInfo methodInfo = mock(RpcMethodInfo.class); + when(methodInfo.getParamsTypes()).thenReturn(new Type[]{RpcContext.class, Map.class}); + + AbstractHttpExecutor executor = createExecutorWithCodec(); + HttpCodec httpCodec = (HttpCodec) getField(executor, "httpCodec"); + Map mockMap = new HashMap<>(); + when(httpCodec.convertToJsonParam(any())).thenReturn(mockMap); + + Object[] result = (Object[]) invokePrivate(executor, "parseRpcParams", + new Class[]{HttpServletRequest.class, RpcMethodInfo.class}, request, methodInfo); + + assertNotNull(result); + assertEquals(mockMap, result[0]); + } + + @Test + public void testParseRpcParamsParameterized() throws Exception { + HttpServletRequest request = mock(HttpServletRequest.class); + RpcMethodInfo methodInfo = mock(RpcMethodInfo.class); + ParameterizedType paramType = mock(ParameterizedType.class); + when(methodInfo.getParamsTypes()).thenReturn(new Type[]{RpcContext.class, paramType}); + + AbstractHttpExecutor executor = createExecutorWithCodec(); + HttpCodec httpCodec = (HttpCodec) getField(executor, "httpCodec"); + when(httpCodec.convertToParameterizedBean(any(), any())).thenReturn("paramResult"); + + Object[] result = (Object[]) invokePrivate(executor, "parseRpcParams", + new Class[]{HttpServletRequest.class, RpcMethodInfo.class}, request, methodInfo); + + assertNotNull(result); + assertEquals("paramResult", result[0]); + } + + // ==================== invokeRpcRequest ==================== + + @Test + public void testInvokeRpcSuccess() throws Exception { + ProviderInvoker invoker = mock(ProviderInvoker.class); + DefResponse successResponse = new DefResponse(); + successResponse.setValue("success"); + CompletableFuture successFuture = CompletableFuture.completedFuture(successResponse); + when(invoker.invoke(any())).thenReturn(successFuture); + ProviderConfig config = mockProviderConfig(0); + when(invoker.getConfig()).thenReturn(config); + + HttpServletRequest request = mockRequest(); + HttpServletResponse response = mock(HttpServletResponse.class); + DefRequest defRequest = mockDefRequest(request, response); + + AbstractHttpExecutor executor = createExecutorWithCodec(); + AtomicBoolean responded = new AtomicBoolean(false); + CompletableFuture completionFuture = new CompletableFuture<>(); + + invokePrivate(executor, "invokeRpcRequest", + new Class[]{ProviderInvoker.class, DefRequest.class, CompletableFuture.class, AtomicBoolean.class}, + invoker, defRequest, completionFuture, responded); + + completionFuture.get(); + verify(response).setStatus(HttpStatus.SC_OK); + HttpCodec httpCodec = (HttpCodec) getField(executor, "httpCodec"); + verify(httpCodec).writeHttpResponse(response, successResponse); + verify(response).flushBuffer(); + } + + @Test + public void testInvokeRpcWorkerPoolNull() throws Exception { + ProviderConfig config = mock(ProviderConfig.class); + when(config.getWorkerPoolObj()).thenReturn(null); + ProviderInvoker invoker = mock(ProviderInvoker.class); + when(invoker.getConfig()).thenReturn(config); + + HttpServletRequest request = mockRequest(); + HttpServletResponse response = mock(HttpServletResponse.class); + DefRequest defRequest = mockDefRequest(request, response); + + AbstractHttpExecutor executor = createExecutorWithCodec(); + AtomicBoolean responded = new AtomicBoolean(false); + CompletableFuture completionFuture = new CompletableFuture<>(); + + invokePrivate(executor, "invokeRpcRequest", + new Class[]{ProviderInvoker.class, DefRequest.class, CompletableFuture.class, AtomicBoolean.class}, + invoker, defRequest, completionFuture, responded); + + assertTrue(completionFuture.isCompletedExceptionally()); + } + + @Test + public void testInvokeRpcAlreadyResponded() throws Exception { + ProviderInvoker invoker = mock(ProviderInvoker.class); + DefResponse successResponse = new DefResponse(); + successResponse.setValue("success"); + CompletableFuture successFuture = CompletableFuture.completedFuture(successResponse); + when(invoker.invoke(any())).thenReturn(successFuture); + ProviderConfig config = mockProviderConfig(0); + when(invoker.getConfig()).thenReturn(config); + + HttpServletRequest request = mockRequest(); + HttpServletResponse response = mock(HttpServletResponse.class); + DefRequest defRequest = mockDefRequest(request, response); + + AbstractHttpExecutor executor = createExecutorWithCodec(); + AtomicBoolean responded = new AtomicBoolean(true); + CompletableFuture completionFuture = new CompletableFuture<>(); + + invokePrivate(executor, "invokeRpcRequest", + new Class[]{ProviderInvoker.class, DefRequest.class, CompletableFuture.class, AtomicBoolean.class}, + invoker, defRequest, completionFuture, responded); + + // when responded=true, completionFuture won't complete, verify it's not done + assertTrue(responded.get()); + assertTrue(!completionFuture.isDone()); + } + + @Test + public void testInvokeRpcBusinessException() throws Exception { + ProviderInvoker invoker = mock(ProviderInvoker.class); + DefResponse responseWithEx = new DefResponse(); + responseWithEx.setException(new RuntimeException("business error")); + CompletableFuture future = CompletableFuture.completedFuture(responseWithEx); + when(invoker.invoke(any())).thenReturn(future); + ProviderConfig config = mockProviderConfig(0); + when(invoker.getConfig()).thenReturn(config); + + HttpServletRequest request = mock(HttpServletRequest.class); + when(request.getRemoteAddr()).thenReturn(TEST_IP); + when(request.getMethod()).thenReturn("POST"); + when(request.getRequestURI()).thenReturn("/api/test"); + HttpServletResponse response = mock(HttpServletResponse.class); + DefRequest defRequest = mockDefRequest(request, response); + + AbstractHttpExecutor executor = createExecutorWithCodec(); + AtomicBoolean responded = new AtomicBoolean(false); + CompletableFuture completionFuture = new CompletableFuture<>(); + + invokePrivate(executor, "invokeRpcRequest", + new Class[]{ProviderInvoker.class, DefRequest.class, CompletableFuture.class, AtomicBoolean.class}, + invoker, defRequest, completionFuture, responded); + + assertTrue(responded.get()); + assertTrue(completionFuture.isCompletedExceptionally()); + verify(response).setStatus(HttpStatus.SC_SERVICE_UNAVAILABLE); + } + + @Test + public void testInvokeRpcWithException() throws Exception { + ProviderConfig config = mockProviderConfig(0); + ProviderInvoker invoker = mock(ProviderInvoker.class); + when(invoker.getConfig()).thenReturn(config); + CompletableFuture failedFuture = new CompletableFuture<>(); + failedFuture.completeExceptionally(new RuntimeException("boom")); + when(invoker.invoke(any())).thenReturn(failedFuture); + + HttpServletRequest request = mock(HttpServletRequest.class); + when(request.getRemoteAddr()).thenReturn(TEST_IP); + when(request.getMethod()).thenReturn("POST"); + when(request.getRequestURI()).thenReturn("/api/test"); + HttpServletResponse response = mock(HttpServletResponse.class); + DefRequest defRequest = mockDefRequest(request, response); + + AbstractHttpExecutor executor = createExecutorWithCodec(); + AtomicBoolean responded = new AtomicBoolean(false); + CompletableFuture completionFuture = new CompletableFuture<>(); + + invokePrivate(executor, "invokeRpcRequest", + new Class[]{ProviderInvoker.class, DefRequest.class, CompletableFuture.class, AtomicBoolean.class}, + invoker, defRequest, completionFuture, responded); + + assertTrue(responded.get()); + assertTrue(completionFuture.isCompletedExceptionally()); + verify(response).setStatus(HttpStatus.SC_SERVICE_UNAVAILABLE); + } + + @Test + public void testInvokeRpcThrowsDirectly() throws Exception { + HttpServletRequest request = mock(HttpServletRequest.class); + when(request.getRemoteAddr()).thenReturn(TEST_IP); + when(request.getMethod()).thenReturn("POST"); + when(request.getRequestURI()).thenReturn("/api/test"); + HttpServletResponse response = mock(HttpServletResponse.class); + DefRequest defRequest = mockDefRequest(request, response); + + ProviderConfig config = mockProviderConfig(0); + ProviderInvoker invoker = mock(ProviderInvoker.class); + when(invoker.getConfig()).thenReturn(config); + when(invoker.invoke(any())).thenThrow(new RuntimeException("boom-direct")); + + AbstractHttpExecutor executor = createExecutorWithCodec(); + AtomicBoolean responded = new AtomicBoolean(false); + CompletableFuture completionFuture = new CompletableFuture<>(); + + invokePrivate(executor, "invokeRpcRequest", + new Class[]{ProviderInvoker.class, DefRequest.class, CompletableFuture.class, AtomicBoolean.class}, + invoker, defRequest, completionFuture, responded); + + assertTrue(responded.get()); + assertTrue(completionFuture.isCompletedExceptionally()); + verify(response).setStatus(HttpStatus.SC_SERVICE_UNAVAILABLE); + } + + // ==================== handleError ==================== + + @Test + public void testHandleError() throws Exception { + HttpServletRequest request = mock(HttpServletRequest.class); + when(request.getRemoteAddr()).thenReturn(TEST_IP); + when(request.getMethod()).thenReturn("POST"); + when(request.getRequestURI()).thenReturn("/api/test"); + when(request.getQueryString()).thenReturn("param=value"); + HttpServletResponse response = mock(HttpServletResponse.class); + DefRequest defRequest = mockDefRequest(request, response); + + AbstractHttpExecutor executor = createExecutorWithCodec(); + AtomicBoolean responded = new AtomicBoolean(false); + CompletableFuture completionFuture = new CompletableFuture<>(); + Throwable testException = new RuntimeException("Test error"); + + invokePrivate(executor, "handleError", + new Class[]{Throwable.class, DefRequest.class, HttpServletResponse.class, + AtomicBoolean.class, CompletableFuture.class}, + testException, defRequest, response, responded, completionFuture); + + assertTrue(responded.get()); + assertTrue(completionFuture.isCompletedExceptionally()); + verify(response).setStatus(HttpStatus.SC_SERVICE_UNAVAILABLE); + HttpCodec httpCodec = (HttpCodec) getField(executor, "httpCodec"); + verify(httpCodec).writeHttpResponse(any(HttpServletResponse.class), any()); + } + + // ==================== doErrorReply ==================== + + @Test + public void testDoErrorReplyTimeout() throws Exception { + HttpServletRequest request = mockRequest(); + HttpServletResponse response = mock(HttpServletResponse.class); + + AbstractHttpExecutor executor = createExecutorWithCodec(); + TRpcException ex = TRpcException.newFrameException(ErrorCode.TRPC_SERVER_TIMEOUT_ERR, "timeout"); + + executor.doErrorReply(request, response, ex); + + verify(response).setStatus(HttpStatus.SC_REQUEST_TIMEOUT); + } + + @Test + public void testDoErrorReplyNotFound() throws Exception { + HttpServletRequest request = mockRequest(); + HttpServletResponse response = mock(HttpServletResponse.class); + + AbstractHttpExecutor executor = createExecutorWithCodec(); + TRpcException ex = TRpcException.newFrameException(ErrorCode.TRPC_SERVER_NOSERVICE_ERR, "not found"); + + executor.doErrorReply(request, response, ex); + + verify(response).setStatus(HttpStatus.SC_NOT_FOUND); + } + + @Test + public void testDoErrorReplyNoFunc() throws Exception { + HttpServletRequest request = mockRequest(); + HttpServletResponse response = mock(HttpServletResponse.class); + + AbstractHttpExecutor executor = createExecutorWithCodec(); + TRpcException ex = TRpcException.newFrameException(ErrorCode.TRPC_SERVER_NOFUNC_ERR, "no func"); + + executor.doErrorReply(request, response, ex); + + verify(response).setStatus(HttpStatus.SC_NOT_FOUND); + } + + @Test + public void testDoErrorReplyValidate() throws Exception { + HttpServletRequest request = mockRequest(); + HttpServletResponse response = mock(HttpServletResponse.class); + + AbstractHttpExecutor executor = createExecutorWithCodec(); + TRpcException ex = TRpcException.newFrameException(ErrorCode.TRPC_SERVER_VALIDATE_ERR, "validate error"); + + executor.doErrorReply(request, response, ex); + + verify(response).setStatus(HttpStatus.SC_BAD_REQUEST); + } + + @Test + public void testDoErrorReplyAuth() throws Exception { + HttpServletRequest request = mockRequest(); + HttpServletResponse response = mock(HttpServletResponse.class); + + AbstractHttpExecutor executor = createExecutorWithCodec(); + TRpcException ex = TRpcException.newFrameException(ErrorCode.TRPC_SERVER_AUTH_ERR, "no auth"); + + executor.doErrorReply(request, response, ex); + + verify(response).setStatus(HttpStatus.SC_UNAUTHORIZED); + } + + @Test + public void testDoErrorReplyOverload() throws Exception { + HttpServletRequest request = mockRequest(); + HttpServletResponse response = mock(HttpServletResponse.class); + + AbstractHttpExecutor executor = createExecutorWithCodec(); + TRpcException ex = TRpcException.newFrameException(ErrorCode.TRPC_SERVER_OVERLOAD_ERR, "overload"); + + executor.doErrorReply(request, response, ex); + + verify(response).setStatus(HttpStatus.SC_INTERNAL_SERVER_ERROR); + } + + @Test + public void testDoErrorReplyEncode() throws Exception { + HttpServletRequest request = mockRequest(); + HttpServletResponse response = mock(HttpServletResponse.class); + + AbstractHttpExecutor executor = createExecutorWithCodec(); + TRpcException ex = TRpcException.newFrameException(ErrorCode.TRPC_SERVER_ENCODE_ERR, "encode error"); + + executor.doErrorReply(request, response, ex); + + verify(response).setStatus(HttpStatus.SC_INTERNAL_SERVER_ERROR); + } + + @Test + public void testDoErrorReplySystem() throws Exception { + HttpServletRequest request = mockRequest(); + HttpServletResponse response = mock(HttpServletResponse.class); + + AbstractHttpExecutor executor = createExecutorWithCodec(); + TRpcException ex = TRpcException.newFrameException(ErrorCode.TRPC_SERVER_SYSTEM_ERR, "system error"); + + executor.doErrorReply(request, response, ex); + + verify(response).setStatus(HttpStatus.SC_INTERNAL_SERVER_ERROR); + } + + @Test + public void testDoErrorReplyDefault() throws Exception { + HttpServletRequest request = mockRequest(); + HttpServletResponse response = mock(HttpServletResponse.class); + + AbstractHttpExecutor executor = createExecutorWithCodec(); + + executor.doErrorReply(request, response, new RuntimeException("unknown")); + + verify(response).setStatus(HttpStatus.SC_SERVICE_UNAVAILABLE); + } + + @Test + public void testHttpErrorReplyFlushException() throws Exception { + HttpServletRequest request = mockRequest(); + HttpServletResponse response = mock(HttpServletResponse.class); + doThrow(new IOException("flush error")).when(response).flushBuffer(); + + AbstractHttpExecutor executor = createExecutorWithCodec(); + + executor.doErrorReply(request, response, new RuntimeException("err")); + + verify(response).setStatus(HttpStatus.SC_SERVICE_UNAVAILABLE); + } + @Test - public void testExecutor() { - DefaultHttpExecutor executor = new DefaultHttpExecutor(null); - assertNotNull(executor); + public void testExecuteSuccess() throws Exception { + ProviderInvoker invoker = mock(ProviderInvoker.class); + ProviderConfig config = mockProviderConfig(0); + when(invoker.getConfig()).thenReturn(config); + + DefResponse successResponse = new DefResponse(); + successResponse.setValue("ok"); + when(invoker.invoke(any())).thenReturn(CompletableFuture.completedFuture(successResponse)); + + HttpServletRequest request = mockRequest(); + HttpServletResponse response = mock(HttpServletResponse.class); + + RpcMethodInfoAndInvoker methodInfoAndInvoker = buildMethodInfoAndInvoker(invoker); + AbstractHttpExecutor executor = createExecutorWithInvoker(methodInfoAndInvoker); + HttpCodec httpCodec = (HttpCodec) getField(executor, "httpCodec"); + when(httpCodec.convertToJavaBean(any(), any())).thenReturn("param"); + + executor.execute(request, response, methodInfoAndInvoker); + + verify(response).setStatus(HttpStatus.SC_OK); + } + + @Test + public void testExecuteWithTimeout() throws Exception { + ProviderInvoker invoker = mock(ProviderInvoker.class); + ProviderConfig config = mock(ProviderConfig.class); + when(config.getRequestTimeout()).thenReturn(50); + WorkerPool workerPool = mock(WorkerPool.class); + when(config.getWorkerPoolObj()).thenReturn(workerPool); + when(invoker.getConfig()).thenReturn(config); + + CompletableFuture neverFuture = new CompletableFuture<>(); + when(invoker.invoke(any())).thenReturn(neverFuture); + + HttpServletRequest request = mockRequest(); + HttpServletResponse response = mock(HttpServletResponse.class); + + RpcMethodInfoAndInvoker methodInfoAndInvoker = buildMethodInfoAndInvoker(invoker); + AbstractHttpExecutor executor = createExecutorWithInvoker(methodInfoAndInvoker); + HttpCodec httpCodec = (HttpCodec) getField(executor, "httpCodec"); + when(httpCodec.convertToJavaBean(any(), any())).thenReturn("param"); + + executor.execute(request, response, methodInfoAndInvoker); + + verify(response).setStatus(HttpStatus.SC_REQUEST_TIMEOUT); + } + + @Test + public void testExecuteWithCallerCallee() throws Exception { + ProviderInvoker invoker = mock(ProviderInvoker.class); + ProviderConfig config = mockProviderConfig(0); + when(invoker.getConfig()).thenReturn(config); + + DefResponse successResponse = new DefResponse(); + successResponse.setValue("ok"); + when(invoker.invoke(any())).thenReturn(CompletableFuture.completedFuture(successResponse)); + + HttpServletRequest request = mockRequest(); + when(request.getHeader(HttpConstants.HTTP_HEADER_TRPC_CALLER)).thenReturn("trpc.app.server.service"); + when(request.getHeader(HttpConstants.HTTP_HEADER_TRPC_CALLEE)).thenReturn("trpc.app.server.service.method"); + when(request.getHeader(HttpConstants.HTTP_HEADER_TRPC_MESSAGE_TYPE)).thenReturn("1"); + when(request.getHeader(HttpConstants.HTTP_HEADER_TRPC_REQUEST_ID)).thenReturn("12345"); + when(request.getHeader(HttpConstants.HTTP_HEADER_TRPC_TIMEOUT)).thenReturn("3000"); + HttpServletResponse response = mock(HttpServletResponse.class); + + RpcMethodInfoAndInvoker methodInfoAndInvoker = buildMethodInfoAndInvoker(invoker); + AbstractHttpExecutor executor = createExecutorWithInvoker(methodInfoAndInvoker); + HttpCodec httpCodec = (HttpCodec) getField(executor, "httpCodec"); + when(httpCodec.convertToJavaBean(any(), any())).thenReturn("param"); + + executor.execute(request, response, methodInfoAndInvoker); + + verify(response).setStatus(HttpStatus.SC_OK); + } + + @Test + public void testExecuteWithTransInfo() throws Exception { + ProviderInvoker invoker = mock(ProviderInvoker.class); + ProviderConfig config = mockProviderConfig(0); + when(invoker.getConfig()).thenReturn(config); + + DefResponse successResponse = new DefResponse(); + successResponse.setValue("ok"); + when(invoker.invoke(any())).thenReturn(CompletableFuture.completedFuture(successResponse)); + + HttpServletRequest request = mockRequest(); + String transInfo = "{\"key\":\"dmFsdWU=\"}"; + when(request.getHeader(HttpConstants.HTTP_HEADER_TRPC_TRANS_INFO)).thenReturn(transInfo); + HttpServletResponse response = mock(HttpServletResponse.class); + + RpcMethodInfoAndInvoker methodInfoAndInvoker = buildMethodInfoAndInvoker(invoker); + AbstractHttpExecutor executor = createExecutorWithInvoker(methodInfoAndInvoker); + HttpCodec httpCodec = (HttpCodec) getField(executor, "httpCodec"); + when(httpCodec.convertToJavaBean(any(), any())).thenReturn("param"); + + executor.execute(request, response, methodInfoAndInvoker); + + verify(response).setStatus(HttpStatus.SC_OK); + } + + // ==================== execute catch branch ==================== + + @Test + public void testExecuteCatchBranch() throws Exception { + ProviderInvoker invoker = mock(ProviderInvoker.class); + ProviderConfig config = mockProviderConfig(0); + when(invoker.getConfig()).thenReturn(config); + + HttpServletRequest request = mockRequest(); + HttpServletResponse response = mock(HttpServletResponse.class); + + Method method = BadService.class.getMethod("hello", String.class); + RpcMethodInfo methodInfo = new RpcMethodInfo(BadService.class, method); + RpcMethodInfoAndInvoker methodInfoAndInvoker = new RpcMethodInfoAndInvoker(); + methodInfoAndInvoker.setMethodInfo(methodInfo); + methodInfoAndInvoker.setInvoker(invoker); + + AbstractHttpExecutor executor = createExecutorWithInvoker(methodInfoAndInvoker); + + executor.execute(request, response, methodInfoAndInvoker); + + verify(response).setStatus(HttpStatus.SC_SERVICE_UNAVAILABLE); + } + + // ==================== parseRpcParams Protobuf ==================== + + @Test + public void testParseRpcParamsProtobuf() throws Exception { + HttpServletRequest request = mock(HttpServletRequest.class); + RpcMethodInfo methodInfo = mock(RpcMethodInfo.class); + when(methodInfo.getParamsTypes()).thenReturn( + new Type[]{RpcContext.class, tests.service.HelloRequestProtocol.HelloRequest.class}); + + AbstractHttpExecutor executor = createExecutorWithCodec(); + HttpCodec httpCodec = (HttpCodec) getField(executor, "httpCodec"); + tests.service.HelloRequestProtocol.HelloRequest mockMsg = + tests.service.HelloRequestProtocol.HelloRequest.getDefaultInstance(); + when(httpCodec.convertToPBParam(any(), any())).thenReturn(mockMsg); + + Object[] result = (Object[]) invokePrivate(executor, "parseRpcParams", + new Class[]{HttpServletRequest.class, RpcMethodInfo.class}, request, methodInfo); + + assertNotNull(result); + assertEquals(mockMsg, result[0]); + } + + // ==================== setRpcServerContext header loop ==================== + + @Test + public void testExecuteWithHeaders() throws Exception { + ProviderInvoker invoker = mock(ProviderInvoker.class); + ProviderConfig config = mockProviderConfig(0); + when(invoker.getConfig()).thenReturn(config); + + DefResponse successResponse = new DefResponse(); + successResponse.setValue("ok"); + when(invoker.invoke(any())).thenReturn(CompletableFuture.completedFuture(successResponse)); + + HttpServletRequest request = mock(HttpServletRequest.class); + when(request.getAttribute(HttpConstants.REQUEST_ATTRIBUTE_TRPC_SERVICE)).thenReturn(TEST_SERVICE); + when(request.getAttribute(HttpConstants.REQUEST_ATTRIBUTE_TRPC_METHOD)).thenReturn(TEST_METHOD); + when(request.getRemoteAddr()).thenReturn(TEST_IP); + when(request.getRemotePort()).thenReturn(TEST_PORT); + when(request.getMethod()).thenReturn("POST"); + when(request.getRequestURI()).thenReturn("/api/test"); + Enumeration headerNames = Collections.enumeration(Collections.singletonList("X-Custom")); + when(request.getHeaderNames()).thenReturn(headerNames); + when(request.getHeader("X-Custom")).thenReturn("custom-value"); + HttpServletResponse response = mock(HttpServletResponse.class); + + RpcMethodInfoAndInvoker methodInfoAndInvoker = buildMethodInfoAndInvoker(invoker); + AbstractHttpExecutor executor = createExecutorWithInvoker(methodInfoAndInvoker); + HttpCodec httpCodec = (HttpCodec) getField(executor, "httpCodec"); + when(httpCodec.convertToJavaBean(any(), any())).thenReturn("param"); + + executor.execute(request, response, methodInfoAndInvoker); + + verify(response).setStatus(HttpStatus.SC_OK); + } + + // ==================== TestService ==================== + + private interface TestService { + + String hello(RpcContext ctx, String req); + } + + private interface BadService { + + String hello(String req); } -} +} \ No newline at end of file diff --git a/trpc-proto/trpc-proto-http/src/test/resources/log4j2-test.xml b/trpc-proto/trpc-proto-http/src/test/resources/log4j2-test.xml new file mode 100644 index 0000000000..9b9061b327 --- /dev/null +++ b/trpc-proto/trpc-proto-http/src/test/resources/log4j2-test.xml @@ -0,0 +1,13 @@ + + + + + + + + + + + + +