From c1f783b7588843a26d0e6631166d66c97f84f979 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=AD=E4=B8=9A=E6=98=8C?= Date: Fri, 29 May 2026 16:35:05 +0800 Subject: [PATCH] Fix UB shm allocation cleanup crashes --- src/brpc/ubshm/shm/shm_ipc.cpp | 58 +++++++++++++ src/brpc/ubshm/shm/shm_mgr.cpp | 7 +- src/brpc/ubshm/ub_endpoint.cpp | 15 +++- src/brpc/ubshm/ub_ring.cpp | 30 ++++--- test/brpc_ubshm_unittest.cpp | 153 +++++++++++++++++++++++++++++++++ 5 files changed, 247 insertions(+), 16 deletions(-) create mode 100644 test/brpc_ubshm_unittest.cpp diff --git a/src/brpc/ubshm/shm/shm_ipc.cpp b/src/brpc/ubshm/shm/shm_ipc.cpp index 7e934c7568..a63e9cdd7c 100644 --- a/src/brpc/ubshm/shm/shm_ipc.cpp +++ b/src/brpc/ubshm/shm/shm_ipc.cpp @@ -29,6 +29,41 @@ namespace brpc { namespace ubring { +namespace { + +RETURN_CODE ReserveIpcShm(int fd, const SHM *shm) +{ +#if defined(__linux__) + const int rc = posix_fallocate(fd, 0, (off_t)shm->len); + if (rc != 0) { + LOG(ERROR) << "IPC reserve shm=" << shm->name << " length=" << shm->len + << " failed, ret(" << rc << ")."; + return SHM_ERR; + } +#else + UNREFERENCE_PARAM(fd); + UNREFERENCE_PARAM(shm); +#endif + return UBRING_OK; +} + +RETURN_CODE CheckIpcShmSize(int fd, const SHM *shm) +{ + struct stat st; + if (fstat(fd, &st) != 0) { + LOG(ERROR) << "IPC stat shm=" << shm->name << " failed, ret(" << errno << ")."; + return SHM_ERR; + } + if ((uint64_t)st.st_size < (uint64_t)shm->len) { + LOG(ERROR) << "IPC shm=" << shm->name << " actual length=" << st.st_size + << " is shorter than requested length=" << shm->len << "."; + return SHM_ERR; + } + return UBRING_OK; +} + +} // namespace + RETURN_CODE IpcShmLocalMalloc(SHM *shm) { int fd = shm_open(shm->name, O_CREAT | O_EXCL | O_RDWR, SHM_IPC_MODE); @@ -50,9 +85,16 @@ RETURN_CODE IpcShmLocalMalloc(SHM *shm) return SHM_ERR; } + if (ReserveIpcShm(fd, shm) != UBRING_OK) { + close(fd); + shm_unlink(shm->name); + return SHM_ERR; + } + shm->addr = (uint8_t*)mmap(NULL, shm->len, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); if (shm->addr == (uint8_t*)MAP_FAILED) { LOG(ERROR) << "IPC map shm=" << shm->name << " length=" << shm->len << " failed, ret(" << errno << ")."; + shm->addr = NULL; close(fd); shm_unlink(shm->name); return SHM_ERR; @@ -75,6 +117,7 @@ RETURN_CODE IpcShmMunmap(SHM *shm) return SHM_ERR; } + shm->addr = NULL; LOG(INFO) << "IPC unmap shm=" << shm->name << " length=" << shm->len << " success."; return UBRING_OK; } @@ -109,6 +152,8 @@ RETURN_CODE IpcShmLocalFree(SHM *shm) int ret = munmap(shm->addr, shm->len); if (ret != UBRING_OK) { LOG(WARNING) << "IPC unmap shm=" << shm->name << " failed, ret=" << ret; + } else { + shm->addr = NULL; } ret = shm_unlink(shm->name); @@ -138,9 +183,15 @@ RETURN_CODE IpcShmRemoteMalloc(SHM *shm) return SHM_ERR; } + if (CheckIpcShmSize(fd, shm) != UBRING_OK) { + close(fd); + return SHM_ERR; + } + shm->addr = (uint8_t*)mmap(NULL, shm->len, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); if (shm->addr == (uint8_t*)MAP_FAILED) { LOG(ERROR) << "IPC map shm=" << shm->name << " failed, ret=" << errno; + shm->addr = NULL; close(fd); return SHM_ERR; } @@ -157,9 +208,15 @@ RETURN_CODE IpcShmLocalMmap(SHM *shm, int prot) return SHM_ERR; } + if (CheckIpcShmSize(fd, shm) != UBRING_OK) { + close(fd); + return SHM_ERR; + } + shm->addr = (uint8_t*)mmap(NULL, shm->len, prot, MAP_SHARED, fd, 0); if (shm->addr == (uint8_t*)MAP_FAILED) { LOG(ERROR) << "IPC map shm=" << shm->name << " failed, ret=" << errno; + shm->addr = NULL; close(fd); return SHM_ERR; } @@ -182,6 +239,7 @@ RETURN_CODE IpcShmRemoteFree(SHM *shm) return SHM_ERR; } + shm->addr = NULL; LOG(INFO) << "IPC free remote shm=" << shm->name << " success."; return UBRING_OK; } diff --git a/src/brpc/ubshm/shm/shm_mgr.cpp b/src/brpc/ubshm/shm/shm_mgr.cpp index 3f819857b2..39a54dd1df 100644 --- a/src/brpc/ubshm/shm/shm_mgr.cpp +++ b/src/brpc/ubshm/shm/shm_mgr.cpp @@ -114,6 +114,11 @@ RETURN_CODE ShmLocalCalloc(SHM *shm) { LOG(ERROR) << "Failed to alloc local shm."; return rc; } + if (UNLIKELY(shm->addr == NULL)) { + LOG(ERROR) << "Local shm=" << shm->name << " allocated with NULL address."; + ShmFree(shm); + return SHM_ERR; + } memset(shm->addr, 0, shm->len); return UBRING_OK; } @@ -244,4 +249,4 @@ RETURN_CODE ShmFree(SHM *shm) { return rc; } } -} \ No newline at end of file +} diff --git a/src/brpc/ubshm/ub_endpoint.cpp b/src/brpc/ubshm/ub_endpoint.cpp index b4c728c057..7b4868209a 100644 --- a/src/brpc/ubshm/ub_endpoint.cpp +++ b/src/brpc/ubshm/ub_endpoint.cpp @@ -17,6 +17,8 @@ #if BRPC_WITH_UBRING +#include + #include #include #include "butil/fd_utility.h" @@ -526,12 +528,12 @@ void* UBShmEndpoint::ProcessHandshakeAtServer(void* arg) { ub_transport->_ub_state = UBShmTransport::UB_OFF; } else { ep->_state = S_ALLOC_SHM; - ubring::SHM remote_trx_shm = {NULL, remote_msg.len, 0, {0}, (uint8_t)ep->_socket->fd()}; + ubring::SHM remote_trx_shm = {NULL, remote_msg.len, 0, {0}, (uint32_t)ep->_socket->fd()}; strncpy(remote_trx_shm.name, remote_msg.shm_name, SHM_MAX_NAME_BUFF_LEN); size_t local_shm_len = (size_t)(FLAGS_data_queue_size) * MB_TO_BYTE; // server端共享内存名称 - ubring::SHM local_trx_shm = {NULL, local_shm_len, 0, {0}, (uint8_t)ep->_socket->fd()}; + ubring::SHM local_trx_shm = {NULL, local_shm_len, 0, {0}, (uint32_t)ep->_socket->fd()}; char clientName[SHM_MAX_NAME_BUFF_LEN]; strncpy(clientName, remote_msg.shm_name, SHM_MAX_NAME_BUFF_LEN); @@ -646,10 +648,15 @@ ssize_t UBShmEndpoint::CutFromIOBufList(butil::IOBuf** from, size_t ndata) { } ssize_t nw = 0; + errno = 0; nw = _ub_ring->UbrTrxWritev(vec, nvec); if (UNLIKELY(nw == -1)) { - LOG(ERROR) << "Non-blocking send msg in failed, connection has been closed."; - errno = EPIPE; + if (errno == EMSGSIZE) { + LOG(ERROR) << "Non-blocking send msg failed, message is larger than ubring capacity."; + } else { + LOG(ERROR) << "Non-blocking send msg in failed, connection has been closed."; + errno = EPIPE; + } } else if (UNLIKELY(nw == UBRING_RETRY)) { errno = EAGAIN; nw = -1; diff --git a/src/brpc/ubshm/ub_ring.cpp b/src/brpc/ubshm/ub_ring.cpp index 0ea64f07c1..11a5d9b311 100644 --- a/src/brpc/ubshm/ub_ring.cpp +++ b/src/brpc/ubshm/ub_ring.cpp @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +#include #include #include #include @@ -91,9 +92,6 @@ RETURN_CODE UBRing::UbrTrxClose() { if (_trx->ubrTx.remoteRxEventQ.addr != nullptr) { ((UbrEventQMsg *)_trx->ubrTx.remoteRxEventQ.addr)->flag = UBR_STATE_CLOSED; } - if (UNLIKELY(ShmRemoteFree(&_trx->remoteShm) != UBRING_OK)) { - LOG(WARNING) << "Force close, remote shm " << _trx->remoteShm.name << " free failed."; - } if (UNLIKELY(UbrTrxFreeShm(_trx) != UBRING_OK)) { LOG(WARNING) << "Force close, local shm " << _trx->localShm.name << " free failed."; } @@ -321,10 +319,6 @@ void *UBRing::UbrAsynClearCallback(void *args) return NULL; } - if (UNLIKELY(ShmRemoteFree(&trx->remoteShm) != UBRING_OK)) { - LOG(ERROR) << "Trx close, remote shm " << trx->remoteShm.name << " free failed."; - } - if (UNLIKELY(UbrTrxFreeShm(trx) != UBRING_OK)) { LOG(ERROR) << "Trx close, wait for local shm " << trx->localShm.name << " free fail."; } @@ -348,6 +342,12 @@ int UBRing::UbrTrxSend(const void *buf, uint32_t bufLen) uint32_t remainChunkNum = (_trx->ubrTx.writePos > tail) ? (tail + cap - _trx->ubrTx.writePos) : (tail - _trx->ubrTx.writePos); uint32_t needMsgChunkNum = CalcUbrMsgChunkCnt(bufLen); + if (needMsgChunkNum >= cap) { + LOG(ERROR) << "Ubr send failed, payload length=" << bufLen + << " needs " << needMsgChunkNum << " chunks, capacity=" << cap << "."; + errno = EMSGSIZE; + return UBRING_ERR; + } if (remainChunkNum < needMsgChunkNum) { return UBRING_RETRY; } @@ -653,7 +653,7 @@ RETURN_CODE UBRing::UbrTrxFreeShm(UbrTrx *trx) RETURN_CODE remoteRc = UBRING_OK; if (trx->remoteShm.addr != NULL) { - remoteRc = IpcShmRemoteFree(&trx->remoteShm); + remoteRc = ShmRemoteFree(&trx->remoteShm); } if (remoteRc != UBRING_OK) { LOG(WARNING) << "Free remote shm " << trx->remoteShm.name << " failed, rc=" << remoteRc; @@ -795,6 +795,7 @@ int UBRing::UbrAllocateServerShm(SHM* remote_trx_shm, SHM* local_trx_shm) { if (UNLIKELY((ShmLocalCalloc(local_trx_shm)) != UBRING_OK)) { LOG(ERROR) << "Trx apply local shared memory failed."; + ShmRemoteFree(remote_trx_shm); return -1; } @@ -808,9 +809,9 @@ int UBRing::UbrAllocateServerShm(SHM* remote_trx_shm, SHM* local_trx_shm) { _trx->type = TCP_TRX; if (UNLIKELY((UbrServerTrxInit(local_trx_shm, remote_trx_shm)) != UBRING_OK)) { LOG(ERROR) << "Server trx init failed."; - ShmRemoteFree(remote_trx_shm); UbrTrxFreeShm(_trx); UBRingManager::ReleaseUbrTrxFromMgr(_trx); + _trx = nullptr; return -1; } return 0; @@ -826,6 +827,7 @@ int UBRing::UbrAllocateLocalShm(SHM *local_trx_shm, const char *shm_name) _trx->type = TCP_TRX; if (UNLIKELY((ApplyAndMapLocalShm(local_trx_shm, shm_name)) != UBRING_OK)) { LOG(ERROR) << "Trx apply or map local shared memory failed, localName=" << shm_name; + _trx = nullptr; return -1; } return 0; @@ -873,7 +875,7 @@ RETURN_CODE UBRing::UbrMapRemoteShmAddTimer(SHM *localTrxShm, const char *localN if (UNLIKELY(UbrAddTimer() != UBRING_OK)) { LOG(ERROR) << "Ubr add timer failed, localName=" << localName; - ShmRemoteFree(&remoteTrxShm); + ShmRemoteFree(&_trx->remoteShm); return UBRING_ERR; } @@ -884,7 +886,7 @@ RETURN_CODE UBRing::UbrMapRemoteShmAddTimer(SHM *localTrxShm, const char *localN LOG(ERROR) << "Local shm " << localTrxShm->name << " wait for connect remote map timeout."; DeleteTimerSafe((uint32_t)_trx->hbTimerFd); DeleteTimerSafe((uint32_t)_trx->timerFd); - ShmRemoteFree(&remoteTrxShm); + ShmRemoteFree(&_trx->remoteShm); return UBRING_ERR_TIMEOUT; } @@ -961,6 +963,12 @@ RETURN_CODE UBRing::WritevHasEnoughSpace(size_t bufLen) uint32_t remainChunkNum = (_trx->ubrTx.writePos > tail) ? (tail + cap - _trx->ubrTx.writePos) : (tail - _trx->ubrTx.writePos); uint32_t needMsgChunkNum = CalcUbrMsgChunkCnt((uint32_t)bufLen); + if (needMsgChunkNum >= cap) { + LOG(ERROR) << "Ubr write failed, payload length=" << bufLen + << " needs " << needMsgChunkNum << " chunks, capacity=" << cap << "."; + errno = EMSGSIZE; + return UBRING_ERR; + } if (remainChunkNum < needMsgChunkNum) { return UBRING_RETRY; } diff --git a/test/brpc_ubshm_unittest.cpp b/test/brpc_ubshm_unittest.cpp new file mode 100644 index 0000000000..84b1d814ec --- /dev/null +++ b/test/brpc_ubshm_unittest.cpp @@ -0,0 +1,153 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include +#include +#include + +#include + +#include "brpc/ubshm/common/common.h" +#include "brpc/ubshm/shm/shm_ipc.h" +#include "brpc/ubshm/shm/shm_mgr.h" +#include "brpc/ubshm/ub_ring.h" + +namespace { + +constexpr size_t kShmLen = SHM_ALLOC_UNIT_SIZE; + +brpc::ubring::SHM MakeShm(const char* suffix, size_t len = kShmLen, bool unlink_existing = true) { + brpc::ubring::SHM shm = {NULL, len, 0, {0}, 0}; + snprintf(shm.name, sizeof(shm.name), "/brpc_ut_%d_%s", (int)getpid(), suffix); + if (unlink_existing) { + brpc::ubring::IpcShmFree(&shm); + } + return shm; +} + +void CleanupShm(brpc::ubring::SHM* shm) { + if (shm == NULL) { + return; + } + if (shm->addr != NULL) { + brpc::ubring::IpcShmLocalFree(shm); + } else { + brpc::ubring::IpcShmFree(shm); + } +} + +} // namespace + +TEST(UBShmTest, IpcRemoteMallocRejectsShortObject) { + brpc::ubring::SHM owner = MakeShm("short_owner", kShmLen); + errno = 0; + RETURN_CODE rc = brpc::ubring::IpcShmLocalMalloc(&owner); + if (rc != UBRING_OK && errno == EPERM) { + GTEST_SKIP() << "POSIX shm is not permitted in this environment."; + } + ASSERT_EQ(UBRING_OK, rc); + + brpc::ubring::SHM remote = MakeShm("short_owner", kShmLen * 2, false); + rc = brpc::ubring::IpcShmRemoteMalloc(&remote); + if (rc == UBRING_OK) { + brpc::ubring::IpcShmRemoteFree(&remote); + } + EXPECT_NE(UBRING_OK, rc); + EXPECT_EQ(NULL, remote.addr); + + CleanupShm(&owner); +} + +TEST(UBShmTest, IpcMunmapClearsAddress) { + brpc::ubring::SHM shm = MakeShm("munmap"); + errno = 0; + RETURN_CODE rc = brpc::ubring::IpcShmLocalMalloc(&shm); + if (rc != UBRING_OK && errno == EPERM) { + GTEST_SKIP() << "POSIX shm is not permitted in this environment."; + } + ASSERT_EQ(UBRING_OK, rc); + + EXPECT_EQ(UBRING_OK, brpc::ubring::IpcShmMunmap(&shm)); + EXPECT_EQ(NULL, shm.addr); + + CleanupShm(&shm); +} + +TEST(UBShmTest, IpcRemoteFreeClearsAddress) { + brpc::ubring::SHM shm = MakeShm("remote_free"); + errno = 0; + RETURN_CODE rc = brpc::ubring::IpcShmLocalMalloc(&shm); + if (rc != UBRING_OK && errno == EPERM) { + GTEST_SKIP() << "POSIX shm is not permitted in this environment."; + } + ASSERT_EQ(UBRING_OK, rc); + + EXPECT_EQ(UBRING_OK, brpc::ubring::IpcShmRemoteFree(&shm)); + EXPECT_EQ(NULL, shm.addr); + + CleanupShm(&shm); +} + +TEST(UBShmTest, ServerAllocFailureReleasesMappedRemoteShm) { + brpc::ubring::SetShmType(brpc::ubring::SHM_TYPE_IPC); + + brpc::ubring::SHM remote_owner = MakeShm("server_remote"); + errno = 0; + RETURN_CODE rc = brpc::ubring::IpcShmLocalMalloc(&remote_owner); + if (rc != UBRING_OK && errno == EPERM) { + GTEST_SKIP() << "POSIX shm is not permitted in this environment."; + } + ASSERT_EQ(UBRING_OK, rc); + brpc::ubring::SHM local_conflict = MakeShm("server_local"); + errno = 0; + rc = brpc::ubring::IpcShmLocalMalloc(&local_conflict); + if (rc != UBRING_OK && errno == EPERM) { + CleanupShm(&remote_owner); + GTEST_SKIP() << "POSIX shm is not permitted in this environment."; + } + ASSERT_EQ(UBRING_OK, rc); + + brpc::ubring::SHM remote_param = MakeShm("server_remote", kShmLen, false); + brpc::ubring::SHM local_param = MakeShm("server_local", kShmLen, false); + + brpc::ubring::UBRing ring; + EXPECT_EQ(-1, ring.UbrAllocateServerShm(&remote_param, &local_param)); + EXPECT_EQ(NULL, remote_param.addr); + + if (remote_param.addr != NULL) { + brpc::ubring::IpcShmRemoteFree(&remote_param); + } + CleanupShm(&remote_owner); + CleanupShm(&local_conflict); +} + +TEST(UBShmTest, WritevRejectsPayloadLargerThanRingCapacity) { + brpc::ubring::UbrDataStatusQMsg data_status = {}; + data_status.tail = 3; + + brpc::ubring::UbrTrx trx = {}; + trx.ubrTx.localDataStatusQ.addr = reinterpret_cast(&data_status); + trx.ubrTx.capacity = 4; + trx.ubrTx.writePos = 0; + + brpc::ubring::UBRing ring; + ring._trx = &trx; + + const size_t too_large_payload = UBR_MSG_PAYLOAD_LEN * trx.ubrTx.capacity; + EXPECT_EQ(UBRING_ERR, ring.WritevHasEnoughSpace(too_large_payload)); +}