diff --git a/CMakeLists.txt b/CMakeLists.txt index db45cd5..0fb8190 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -20,14 +20,6 @@ message(STATUS "TINYTORCH_BUILD_TEST ${TINYTORCH_BUILD_TEST}") message(STATUS "TINYTORCH_USE_CUDA ${TINYTORCH_USE_CUDA}") message(STATUS "TINYTORCH_USE_NCCL ${TINYTORCH_USE_NCCL}") -if (TINYTORCH_USE_CUDA) - add_definitions(-DUSE_CUDA) -endif () - -if (TINYTORCH_USE_NCCL) - add_definitions(-DUSE_NCCL) -endif () - add_subdirectory(src) if (TINYTORCH_BUILD_EXAMPLES) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index aec31e2..7d26186 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -87,10 +87,12 @@ if (TINYTORCH_USE_CUDA) set_property(TARGET ${PROJECT_NAME} PROPERTY CUDA_ARCHITECTURES native) target_include_directories(${PROJECT_NAME} PUBLIC ${CUDA_INCLUDE_DIRS}) target_link_libraries(${PROJECT_NAME} PUBLIC CUDA::cudart CUDA::curand CUDA::cublas) + target_compile_definitions(${PROJECT_NAME} PUBLIC USE_CUDA) endif () # NCCL if (TINYTORCH_USE_NCCL) target_include_directories(${PROJECT_NAME} PUBLIC ${NCCL_INCLUDE_DIRS}) target_link_libraries(${PROJECT_NAME} PUBLIC ${NCCL_LIBRARY}) + target_compile_definitions(${PROJECT_NAME} PUBLIC USE_NCCL) endif () diff --git a/src/Distributed/BackendNCCL.cpp b/src/Distributed/BackendNCCL.cpp index 4d578ee..67d3161 100644 --- a/src/Distributed/BackendNCCL.cpp +++ b/src/Distributed/BackendNCCL.cpp @@ -51,14 +51,29 @@ std::shared_ptr BackendNCCL::allGather(std::vector>& o ASSERT(outputTensors.size() == 1); ASSERT(inputTensors.size() == 1); auto& input = inputTensors.back(); - auto output = op::stack(ArrayView(outputTensors.back()), 0); - return collective( - input, output, + auto& outputs = outputTensors.back(); + + int64_t numel = input.numel(); + auto flatOutput = + Tensor::zeros({numel * static_cast(outputs.size())}, Options(input.device(), input.dtype())); + + auto work = collective( + input, flatOutput, [&](const std::shared_ptr& comm, cudaStream_t stream) { - return ncclAllGather(input.dataPtr<>(), output.dataPtr<>(), input.numel(), getNcclDataType(input.dtype()), + return ncclAllGather(input.dataPtr<>(), flatOutput.dataPtr<>(), numel, getNcclDataType(input.dtype()), comm->getNcclComm(), stream); }, OpType::ALL_GATHER); + + if (work) { + auto workNCCL = std::static_pointer_cast(work); + workNCCL->setPostCompletionFn([flatOutput, outputs, numel]() mutable { + for (size_t i = 0; i < outputs.size(); i++) { + outputs[i].copy_(flatOutput.narrow(0, static_cast(i) * numel, numel).view(outputs[i].sizes())); + } + }); + } + return work; } std::shared_ptr BackendNCCL::reduceScatter(std::vector& outputTensors, @@ -128,12 +143,17 @@ std::shared_ptr BackendNCCL::getNCCLComm(const std::string& deviceKey, ncclUniqueId commId; std::string key = NCCL_COMM_ID_PREFIX + std::to_string(ncclCommCounter_++); if (getRank() == 0) { - NCCL_CALL(ncclGetUniqueId(&commId)); + ncclResult_t ret = ncclGetUniqueId(&commId); + if (ret != ncclSuccess && ret != ncclInProgress) { + NCCL_ERROR(ret); + setErrorLocked(BackendErrorType::COMM_ERROR); + return nullptr; + } store_->set(key, std::string(reinterpret_cast(&commId), sizeof(commId))); } else { auto idData = store_->get(key); if (idData.size() != sizeof(commId)) { - setError(BackendErrorType::COMM_ERROR); + setErrorLocked(BackendErrorType::COMM_ERROR); return nullptr; } std::memcpy(&commId, idData.data(), sizeof(commId)); @@ -141,7 +161,7 @@ std::shared_ptr BackendNCCL::getNCCLComm(const std::string& deviceKey, auto comm = NCCLComm::create(getSize(), getRank(), commId, device.index); if (!comm) { - setError(BackendErrorType::COMM_ERROR); + setErrorLocked(BackendErrorType::COMM_ERROR); return nullptr; } diff --git a/src/Distributed/BackendNCCL.h b/src/Distributed/BackendNCCL.h index b131ab7..70d59c9 100644 --- a/src/Distributed/BackendNCCL.h +++ b/src/Distributed/BackendNCCL.h @@ -38,6 +38,9 @@ class BackendNCCL : public Backend { void setTimeout(std::chrono::milliseconds timeout) override { options_->timeout = timeout; } + void setUseComputeStream(bool flag) { useComputeStream_ = flag; } + bool useComputeStream() const { return useComputeStream_; } + std::string getBackendName() const override { return NCCL_BACKEND_NAME; } std::shared_ptr getBackendOptions() override { return options_; } @@ -90,6 +93,10 @@ class BackendNCCL : public Backend { protected: void setError(BackendErrorType error) { std::lock_guard lock(mutex_); + setErrorLocked(error); + } + + void setErrorLocked(BackendErrorType error) { if (error_ == BackendErrorType::SUCCESS) { error_ = error; } @@ -144,6 +151,8 @@ class BackendNCCL : public Backend { std::list> activeWorks_; + bool useComputeStream_ = false; + mutable std::mutex mutex_; BackendErrorType error_ = BackendErrorType::SUCCESS; }; @@ -164,28 +173,35 @@ std::shared_ptr BackendNCCL::collective(Tensor& input, Tensor& output, Fun return nullptr; } - auto& ncclStream = getNCCLStream(deviceKey, device.index); auto& computeSteam = cuda::getCurrentCUDAStream(device.index); - ncclStream.waitStream(computeSteam); + cuda::CUDAStream& opStream = useComputeStream_ ? computeSteam : getNCCLStream(deviceKey, device.index); + if (!useComputeStream_) { + opStream.waitStream(computeSteam); + } auto work = std::make_shared(std::to_string(getID()), getGroupDesc(), device, getRank(), opType); work->setOpTimeout(options_->timeout); work->setNCCLComm(comm); work->setStore(store_); work->setIsBarrierOp(opType == OpType::BARRIER); - work->setCudaEvent(cuda::createCUDAEvent(device.index)); + work->setUseComputeStream(useComputeStream_); + if (!useComputeStream_) { + work->setCudaEvent(cuda::createCUDAEvent(device.index)); + } work->setOutputs({output}); cuda::CudaDeviceGuard guard(device.index); NCCL_CALL(ncclGroupStart()); - NCCL_CALL(ncclFunc(comm, ncclStream.stream())); + NCCL_CALL(ncclFunc(comm, opStream.stream())); ncclResult_t state = ncclGroupEnd(); if (!waitNcclCommandResult(comm, state)) { return nullptr; } - work->getCudaEvent().record(ncclStream); + if (!useComputeStream_) { + work->getCudaEvent().record(opStream); + } workEnqueue(work); return work; } diff --git a/src/Distributed/DistributedProcessGroup.cpp b/src/Distributed/DistributedProcessGroup.cpp index 9c90ca1..8af2076 100644 --- a/src/Distributed/DistributedProcessGroup.cpp +++ b/src/Distributed/DistributedProcessGroup.cpp @@ -138,13 +138,12 @@ std::shared_ptr DistributedProcessGroup::createStore(const InitConfig& co if (config.method == InitMethod::ENV) { bool isServer = (config.rank == 0); int numWorkers = config.worldSize; - bool waitWorkers = true; - return std::make_shared(config.masterAddr, // host - config.masterPort, // port - isServer, // isServer - waitWorkers, // waitWorkers - numWorkers, // numWorkers - timeout // timeout + return std::make_shared(config.masterAddr, // host + config.masterPort, // port + isServer, // isServer + config.waitWorkers, // waitWorkers + numWorkers, // numWorkers + timeout // timeout ); } @@ -152,11 +151,10 @@ std::shared_ptr DistributedProcessGroup::createStore(const InitConfig& co if (config.method == InitMethod::TCP) { bool isServer = (config.rank == 0); int numWorkers = config.worldSize; - bool waitWorkers = true; return std::make_shared(config.masterAddr, // host static_cast(config.masterPort), // port isServer, // isServer - waitWorkers, // waitWorkers + config.waitWorkers, // waitWorkers numWorkers, // numWorkers timeout // timeout ); diff --git a/src/Distributed/DistributedSampler.cpp b/src/Distributed/DistributedSampler.cpp index 8963521..655c28e 100644 --- a/src/Distributed/DistributedSampler.cpp +++ b/src/Distributed/DistributedSampler.cpp @@ -30,7 +30,7 @@ DistributedSampler::DistributedSampler(size_t datasetSize, std::optional if (dropLast_) { numSamples_ = datasetSize_ / numReplicas_; } else { - numSamples_ = std::ceil(static_cast(datasetSize_) * 1.0f / static_cast(numReplicas_)); + numSamples_ = (datasetSize_ + numReplicas_ - 1) / numReplicas_; } totalSize_ = numSamples_ * numReplicas_; generateIndices(); diff --git a/src/Distributed/FileStore.cpp b/src/Distributed/FileStore.cpp index 1e2388b..82326d2 100644 --- a/src/Distributed/FileStore.cpp +++ b/src/Distributed/FileStore.cpp @@ -30,7 +30,7 @@ class FileLock { return; } - struct flock fl{}; + struct flock fl {}; fl.l_type = F_WRLCK; fl.l_whence = SEEK_SET; fl.l_start = 0; @@ -44,7 +44,7 @@ class FileLock { ~FileLock() { if (fd_ != -1) { - struct flock fl{}; + struct flock fl {}; fl.l_type = F_UNLCK; fl.l_whence = SEEK_SET; fl.l_start = 0; diff --git a/src/Distributed/Reducer.cpp b/src/Distributed/Reducer.cpp index 367a69f..9202f0e 100644 --- a/src/Distributed/Reducer.cpp +++ b/src/Distributed/Reducer.cpp @@ -124,7 +124,7 @@ void Reducer::reduceBucket(int64_t bucketIdx) { ASSERT(!bucket.params.empty()); std::vector tensors = {bucket.flatBuffer}; - bucket.work = processGroup_->allReduce(tensors); + bucket.work = processGroup_->allReduce(tensors, AVG); ASSERT(bucket.work); bucketReadyCnt_++; @@ -169,8 +169,6 @@ void Reducer::checkAllBucketsReady() { ASSERT(success); cuda::CudaDeviceGuard guard(bucket.flatBuffer.device().index); - auto worldSize = processGroup_->getWorldSize(); - bucket.flatBuffer /= static_cast(worldSize); copyFlattenedBufferToGrads(bIdx); bucket.work = nullptr; diff --git a/src/Distributed/TCPStore.cpp b/src/Distributed/TCPStore.cpp index f5f017e..d372ef0 100644 --- a/src/Distributed/TCPStore.cpp +++ b/src/Distributed/TCPStore.cpp @@ -6,11 +6,11 @@ #include "TCPStore.h" -#include #include #include #include -#include +#include +#include #include #include @@ -20,15 +20,22 @@ #include #include #include +#include #include #include #include #include #include +#include #include "Utils/Logger.h" #include "Utils/Macros.h" +#ifdef __APPLE__ +#include +#define MSG_NOSIGNAL 0 +#endif + typedef int socket_t; #define INVALID_SOCKET (-1) #define SOCKET_ERROR (-1) @@ -50,10 +57,7 @@ enum class TCPStoreCommandType : uint8_t { CMD_WORKER_UNREGISTER }; -static void setupSignalHandlers() { - // ignore SIGPIPE - signal(SIGPIPE, SIG_IGN); -} +static void setupSignalHandlers() { signal(SIGPIPE, SIG_IGN); } static bool sendBytes(socket_t socket, const void* data, size_t size) { const char* ptr = static_cast(data); @@ -61,23 +65,16 @@ static bool sendBytes(socket_t socket, const void* data, size_t size) { while (remaining > 0) { ssize_t sent = send(socket, ptr, remaining, MSG_NOSIGNAL); - if (sent < 0) { int error = errno; - if (error == EPIPE || error == ECONNRESET || error == ECONNABORTED) { - LOGE("TCPStore: Client disconnected (EPIPE/ECONNRESET/ECONNABORTED)"); - return false; - } else if (error == EAGAIN) { - std::this_thread::sleep_for(std::chrono::milliseconds(1)); + if (error == EINTR) { continue; - } else { - LOGE("TCPStore: Send error: %d", error); - return false; } + LOGE("TCPStore: Send error: %d", error); + return false; } else if (sent == 0) { return false; } - ptr += sent; remaining -= sent; } @@ -90,38 +87,56 @@ static bool recvBytes(socket_t socket, void* data, size_t size) { while (remaining > 0) { ssize_t received = recv(socket, ptr, remaining, 0); - if (received < 0) { int error = errno; - if (error == ECONNRESET || error == EPIPE || error == ECONNABORTED) { - LOGE("TCPStore: Client disconnected during recv"); - return false; - } else if (error == EAGAIN) { - std::this_thread::sleep_for(std::chrono::milliseconds(1)); + if (error == EINTR) { continue; - } else if (error == ETIMEDOUT) { - LOGE("TCPStore: Recv timeout"); - return false; - } else { + } + if (error != EAGAIN && error != EWOULDBLOCK && error != ETIMEDOUT) { LOGE("TCPStore: Recv error: %d", error); - return false; } + return false; } else if (received == 0) { return false; } - ptr += received; remaining -= received; } return true; } -bool configureSocket(socket_t socket) { +static bool sendString(socket_t sock, const std::string& s) { + uint64_t size = s.size(); + return sendBytes(sock, &size, sizeof(size)) && (size == 0 || sendBytes(sock, s.data(), size)); +} + +static bool recvString(socket_t sock, std::string& s) { + uint64_t size; + if (!recvBytes(sock, &size, sizeof(size))) return false; + s.resize(size); + return size == 0 || recvBytes(sock, &s[0], size); +} + +static bool sendBlob(socket_t sock, const std::vector& v) { + uint64_t size = v.size(); + return sendBytes(sock, &size, sizeof(size)) && (size == 0 || sendBytes(sock, v.data(), size)); +} + +static bool recvBlob(socket_t sock, std::vector& v) { + uint64_t size; + if (!recvBytes(sock, &size, sizeof(size))) return false; + v.resize(size); + return size == 0 || recvBytes(sock, v.data(), size); +} + +static bool configureSocket(socket_t socket, bool isServer) { int optval = 1; - if (setsockopt(socket, SOL_SOCKET, SO_REUSEADDR, &optval, sizeof(optval)) < 0) { - LOGE("TCPStore: Failed to set SO_REUSEADDR"); - return false; + if (isServer) { + if (setsockopt(socket, SOL_SOCKET, SO_REUSEADDR, &optval, sizeof(optval)) < 0) { + LOGE("TCPStore: Failed to set SO_REUSEADDR"); + return false; + } } if (setsockopt(socket, SOL_SOCKET, SO_KEEPALIVE, &optval, sizeof(optval)) < 0) { @@ -129,44 +144,26 @@ bool configureSocket(socket_t socket) { return false; } - timeval timeout{}; - timeout.tv_sec = 30; // 30s - timeout.tv_usec = 0; - if (setsockopt(socket, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof(timeout)) < 0) { - LOGE("TCPStore: Failed to set SO_RCVTIMEO"); - return false; - } - - if (setsockopt(socket, SOL_SOCKET, SO_SNDTIMEO, &timeout, sizeof(timeout)) < 0) { - LOGE("TCPStore: Failed to set SO_SNDTIMEO"); +#ifdef __APPLE__ + if (setsockopt(socket, SOL_SOCKET, SO_NOSIGPIPE, &optval, sizeof(optval)) < 0) { + LOGE("TCPStore: Failed to set SO_NOSIGPIPE"); return false; } +#endif return true; } -static bool isSocketConnected(socket_t socket) { - int error = 0; - socklen_t len = sizeof(error); - int retval = getsockopt(socket, SOL_SOCKET, SO_ERROR, &error, &len); - - if (retval != 0) { +static bool setSocketTimeout(socket_t socket, int timeoutSec) { + timeval timeout{}; + timeout.tv_sec = timeoutSec; + timeout.tv_usec = 0; + if (setsockopt(socket, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof(timeout)) < 0) { return false; } - - if (error != 0) { + if (setsockopt(socket, SOL_SOCKET, SO_SNDTIMEO, &timeout, sizeof(timeout)) < 0) { return false; } - - char dummy; - ssize_t result = send(socket, &dummy, 0, MSG_NOSIGNAL); - if (result < 0) { - int sendError = errno; - if (sendError == EPIPE || sendError == ECONNRESET || sendError == ENOTCONN) { - return false; - } - } - return true; } @@ -177,6 +174,9 @@ class TCPStore::ClientImpl { ClientImpl(std::string host, uint16_t port, const std::chrono::milliseconds& timeout); ~ClientImpl(); + ClientImpl(const ClientImpl&) = delete; + ClientImpl& operator=(const ClientImpl&) = delete; + bool connect(); bool set(const std::string& key, const std::vector& value) const; bool get(const std::string& key, std::vector& value) const; @@ -192,14 +192,14 @@ class TCPStore::ClientImpl { std::uint16_t getPort() const noexcept { return port_; } private: + bool sendCommand(detail::TCPStoreCommandType cmd) const; + bool sendKeys(const std::vector& keys) const; + std::string host_; uint16_t port_; - std::chrono::milliseconds timeout_; socket_t clientSock_; - - int maxRetries = 10; - int retryDelayMs = 1000; - + int maxRetries_ = 10; + int retryDelayMs_ = 1000; mutable std::atomic isRegistered_{false}; }; @@ -210,7 +210,6 @@ class TCPStore::ServerImpl { bool start(); bool waitForWorkers(); - int getRegisteredWorkers() const; private: @@ -227,11 +226,9 @@ class TCPStore::ServerImpl { bool handleWorkerRegister(socket_t socket); bool handleWorkerUnregister(socket_t socket); - void cleanupFinishedThreads(); - void shutdownAllThreads(); - void cleanupDisconnectedClients(); - void startClientMonitoring(); - void stopClientMonitoring(); + void notifyKeyWaiters(const std::string& key); + bool waitForCondition(std::unique_lock& lock, std::shared_ptr& cv, + const std::function& pred); uint16_t port_; int numWorkers_; @@ -249,19 +246,19 @@ class TCPStore::ServerImpl { socket_t serverSock_; std::thread serverThread_; - std::vector> clientThreads_; std::mutex clientThreadsMutex_; - std::atomic activeThreadCount_{0}; - - std::unordered_set activeClients_; - std::mutex activeClientsMutex_; - std::thread clientMonitorThread_; - std::atomic monitoringRunning_{false}; + std::vector clientThreads_; }; +// ============================================================ +// TCPStore +// ============================================================ + TCPStore::TCPStore(const std::string& host, uint16_t port, bool isServer, bool waitWorkers, int numWorkers, const std::chrono::milliseconds& timeout) : Store(timeout), host_(host), port_(port), isServer_(isServer) { + detail::setupSignalHandlers(); + if (isServer_) { server_ = std::make_unique(port, numWorkers, timeout); if (!server_->start()) { @@ -282,7 +279,7 @@ TCPStore::TCPStore(const std::string& host, uint16_t port, bool isServer, bool w if (isServer_ && waitWorkers && numWorkers > 0) { if (!server_->waitForWorkers()) { - LOGE("TCPStore: Timeout while waiting for workers (expected: %d, registered: %d)", numWorkers, + LOGE("TCPStore: Timeout waiting for workers (expected: %d, registered: %d)", numWorkers, server_->getRegisteredWorkers()); return; } @@ -291,10 +288,7 @@ TCPStore::TCPStore(const std::string& host, uint16_t port, bool isServer, bool w TCPStore::~TCPStore() { if (client_) { - bool ret = client_->unregisterWorker(); - if (!ret) { - LOGE("TCPStore: unregisterWorker failed"); - } + client_->unregisterWorker(); } client_.reset(); server_.reset(); @@ -302,14 +296,14 @@ TCPStore::~TCPStore() { void TCPStore::set(const std::string& key, const std::vector& value) { if (!client_->set(key, value)) { - LOGE("TCPStore: set operation failed for key: %s", key.c_str()); + LOGE("TCPStore: set failed for key: %s", key.c_str()); } } std::vector TCPStore::get(const std::string& key) { std::vector value; if (!client_->get(key, value)) { - LOGE("TCPStore: get operation failed for key: %s", key.c_str()); + LOGE("TCPStore: get failed for key: %s", key.c_str()); } return value; } @@ -317,7 +311,7 @@ std::vector TCPStore::get(const std::string& key) { int64_t TCPStore::add(const std::string& key, int64_t value) { int64_t newValue = -1; if (!client_->add(key, value, newValue)) { - LOGE("TCPStore: add operation failed for key: %s", key.c_str()); + LOGE("TCPStore: add failed for key: %s", key.c_str()); } return newValue; } @@ -329,19 +323,23 @@ bool TCPStore::check(const std::vector& keys) { return client_->che int64_t TCPStore::getNumKeys() { int64_t numKeys = -1; if (!client_->getNumKeys(numKeys)) { - LOGE("TCPStore: getNumKeys operation failed"); + LOGE("TCPStore: getNumKeys failed"); } return numKeys; } void TCPStore::wait(const std::vector& keys) { if (!client_->wait(keys)) { - LOGE("TCPStore: wait operation failed or timed out"); + LOGE("TCPStore: wait failed or timed out"); } } +// ============================================================ +// ClientImpl +// ============================================================ + TCPStore::ClientImpl::ClientImpl(std::string host, uint16_t port, const std::chrono::milliseconds& timeout) - : host_(std::move(host)), port_(port), timeout_(timeout), clientSock_(INVALID_SOCKET) {} + : host_(std::move(host)), port_(port), clientSock_(INVALID_SOCKET) {} TCPStore::ClientImpl::~ClientImpl() { if (clientSock_ != INVALID_SOCKET) { @@ -357,19 +355,19 @@ bool TCPStore::ClientImpl::connect() { return false; } - if (!detail::configureSocket(clientSock_)) { + if (!detail::configureSocket(clientSock_, false)) { closesocket(clientSock_); clientSock_ = INVALID_SOCKET; return false; } + detail::setSocketTimeout(clientSock_, 30); + sockaddr_in addr{}; - std::memset(&addr, 0, sizeof(addr)); addr.sin_family = AF_INET; addr.sin_port = htons(port_); addrinfo hints{}, *result; - std::memset(&hints, 0, sizeof(hints)); hints.ai_family = AF_INET; hints.ai_socktype = SOCK_STREAM; @@ -383,22 +381,20 @@ bool TCPStore::ClientImpl::connect() { std::memcpy(&addr.sin_addr, &reinterpret_cast(result->ai_addr)->sin_addr, sizeof(in_addr)); freeaddrinfo(result); - // connect & retry bool connected = false; - for (int i = 0; i < maxRetries && !connected; i++) { + for (int i = 0; i < maxRetries_; i++) { if (::connect(clientSock_, reinterpret_cast(&addr), sizeof(addr)) == 0) { connected = true; break; } - - if (i < maxRetries - 1) { - LOGE("TCPStore: Connection attempt failed, retrying in %d ms", retryDelayMs); - std::this_thread::sleep_for(std::chrono::milliseconds(retryDelayMs)); + if (i < maxRetries_ - 1) { + LOGE("TCPStore: Connection attempt %d failed, retrying...", i + 1); + std::this_thread::sleep_for(std::chrono::milliseconds(retryDelayMs_)); } } if (!connected) { - LOGE("TCPStore: Failed to connect to server after %d attempts", maxRetries); + LOGE("TCPStore: Failed to connect after %d attempts", maxRetries_); closesocket(clientSock_); clientSock_ = INVALID_SOCKET; return false; @@ -407,244 +403,102 @@ bool TCPStore::ClientImpl::connect() { return true; } -bool TCPStore::ClientImpl::set(const std::string& key, const std::vector& value) const { - detail::TCPStoreCommandType cmd = detail::TCPStoreCommandType::CMD_SET; - - if (!detail::sendBytes(clientSock_, &cmd, sizeof(cmd))) { - LOGE("TCPStore: Failed to send SET command"); - return false; - } +bool TCPStore::ClientImpl::sendCommand(detail::TCPStoreCommandType cmd) const { + return detail::sendBytes(clientSock_, &cmd, sizeof(cmd)); +} - uint64_t keySize = key.size(); - if (!detail::sendBytes(clientSock_, &keySize, sizeof(keySize)) || - !detail::sendBytes(clientSock_, key.data(), keySize)) { - LOGE("TCPStore: Failed to send key for SET command"); - return false; +bool TCPStore::ClientImpl::sendKeys(const std::vector& keys) const { + uint64_t numKeys = keys.size(); + if (!detail::sendBytes(clientSock_, &numKeys, sizeof(numKeys))) return false; + for (const auto& key : keys) { + if (!detail::sendString(clientSock_, key)) return false; } + return true; +} - uint64_t valueSize = value.size(); - if (!detail::sendBytes(clientSock_, &valueSize, sizeof(valueSize)) || - !detail::sendBytes(clientSock_, value.data(), valueSize)) { - LOGE("TCPStore: Failed to send value for SET command"); - return false; - } +bool TCPStore::ClientImpl::set(const std::string& key, const std::vector& value) const { + if (!sendCommand(detail::TCPStoreCommandType::CMD_SET)) return false; + if (!detail::sendString(clientSock_, key)) return false; + if (!detail::sendBlob(clientSock_, value)) return false; uint8_t confirmation; - if (!detail::recvBytes(clientSock_, &confirmation, sizeof(confirmation)) || confirmation != 1) { - LOGE("TCPStore: Failed to receive confirmation for SET command"); - return false; - } - - return true; + return detail::recvBytes(clientSock_, &confirmation, sizeof(confirmation)) && confirmation == 1; } bool TCPStore::ClientImpl::get(const std::string& key, std::vector& value) const { - detail::TCPStoreCommandType cmd = detail::TCPStoreCommandType::CMD_GET; - - if (!detail::sendBytes(clientSock_, &cmd, sizeof(cmd))) { - LOGE("TCPStore: Failed to send GET command"); - return false; - } - - uint64_t keySize = key.size(); - if (!detail::sendBytes(clientSock_, &keySize, sizeof(keySize)) || - !detail::sendBytes(clientSock_, key.data(), keySize)) { - LOGE("TCPStore: Failed to send key for GET command"); - return false; - } - - uint64_t valueSize; - if (!detail::recvBytes(clientSock_, &valueSize, sizeof(valueSize))) { - LOGE("TCPStore: Failed to receive value size for GET command"); - return false; - } - - value.resize(valueSize); - if (valueSize > 0 && !detail::recvBytes(clientSock_, value.data(), valueSize)) { - LOGE("TCPStore: Failed to receive value for GET command"); - return false; - } - - return true; + if (!sendCommand(detail::TCPStoreCommandType::CMD_GET)) return false; + if (!detail::sendString(clientSock_, key)) return false; + return detail::recvBlob(clientSock_, value); } bool TCPStore::ClientImpl::add(const std::string& key, int64_t addValue, int64_t& newValue) const { - detail::TCPStoreCommandType cmd = detail::TCPStoreCommandType::CMD_ADD; - - if (!detail::sendBytes(clientSock_, &cmd, sizeof(cmd))) { - LOGE("TCPStore: Failed to send ADD command"); - return false; - } - - uint64_t keySize = key.size(); - if (!detail::sendBytes(clientSock_, &keySize, sizeof(keySize)) || - !detail::sendBytes(clientSock_, key.data(), keySize)) { - LOGE("TCPStore: Failed to send key for ADD command"); - return false; - } - - if (!detail::sendBytes(clientSock_, &addValue, sizeof(addValue))) { - LOGE("TCPStore: Failed to send value for ADD command"); - return false; - } - - if (!detail::recvBytes(clientSock_, &newValue, sizeof(newValue))) { - LOGE("TCPStore: Failed to receive new value for ADD command"); - return false; - } - - return true; + if (!sendCommand(detail::TCPStoreCommandType::CMD_ADD)) return false; + if (!detail::sendString(clientSock_, key)) return false; + if (!detail::sendBytes(clientSock_, &addValue, sizeof(addValue))) return false; + return detail::recvBytes(clientSock_, &newValue, sizeof(newValue)); } bool TCPStore::ClientImpl::deleteKey(const std::string& key) const { - detail::TCPStoreCommandType cmd = detail::TCPStoreCommandType::CMD_DELETE_KEY; - - if (!detail::sendBytes(clientSock_, &cmd, sizeof(cmd))) { - LOGE("TCPStore: Failed to send DELETE command"); - return false; - } - - uint64_t keySize = key.size(); - if (!detail::sendBytes(clientSock_, &keySize, sizeof(keySize)) || - !detail::sendBytes(clientSock_, key.data(), keySize)) { - LOGE("TCPStore: Failed to send key for DELETE command"); - return false; - } + if (!sendCommand(detail::TCPStoreCommandType::CMD_DELETE_KEY)) return false; + if (!detail::sendString(clientSock_, key)) return false; uint8_t result; - if (!detail::recvBytes(clientSock_, &result, sizeof(result))) { - LOGE("TCPStore: Failed to receive result for DELETE command"); - return false; - } - - return result == 1; + return detail::recvBytes(clientSock_, &result, sizeof(result)) && result == 1; } bool TCPStore::ClientImpl::check(const std::vector& keys) const { - detail::TCPStoreCommandType cmd = detail::TCPStoreCommandType::CMD_CHECK; - - if (!detail::sendBytes(clientSock_, &cmd, sizeof(cmd))) { - LOGE("TCPStore: Failed to send CHECK command"); - return false; - } - - uint64_t numKeys = keys.size(); - if (!detail::sendBytes(clientSock_, &numKeys, sizeof(numKeys))) { - LOGE("TCPStore: Failed to send number of keys for CHECK command"); - return false; - } - - for (const auto& key : keys) { - uint64_t keySize = key.size(); - if (!detail::sendBytes(clientSock_, &keySize, sizeof(keySize)) || - !detail::sendBytes(clientSock_, key.data(), keySize)) { - LOGE("TCPStore: Failed to send key for CHECK command"); - return false; - } - } + if (!sendCommand(detail::TCPStoreCommandType::CMD_CHECK)) return false; + if (!sendKeys(keys)) return false; uint8_t result; - if (!detail::recvBytes(clientSock_, &result, sizeof(result))) { - LOGE("TCPStore: Failed to receive result for CHECK command"); - return false; - } - - return result == 1; + return detail::recvBytes(clientSock_, &result, sizeof(result)) && result == 1; } bool TCPStore::ClientImpl::getNumKeys(int64_t& numKeys) const { - detail::TCPStoreCommandType cmd = detail::TCPStoreCommandType::CMD_NUM_KEYS; - - if (!detail::sendBytes(clientSock_, &cmd, sizeof(cmd))) { - LOGE("TCPStore: Failed to send NUM_KEYS command"); - return false; - } - - if (!detail::recvBytes(clientSock_, &numKeys, sizeof(numKeys))) { - LOGE("TCPStore: Failed to receive number of keys"); - return false; - } - - return true; + if (!sendCommand(detail::TCPStoreCommandType::CMD_NUM_KEYS)) return false; + return detail::recvBytes(clientSock_, &numKeys, sizeof(numKeys)); } bool TCPStore::ClientImpl::wait(const std::vector& keys) const { - detail::TCPStoreCommandType cmd = detail::TCPStoreCommandType::CMD_WAIT; - - if (!detail::sendBytes(clientSock_, &cmd, sizeof(cmd))) { - LOGE("TCPStore: Failed to send WAIT command"); - return false; - } - - uint64_t numKeys = keys.size(); - if (!detail::sendBytes(clientSock_, &numKeys, sizeof(numKeys))) { - LOGE("TCPStore: Failed to send number of keys for WAIT command"); - return false; - } - - for (const auto& key : keys) { - uint64_t keySize = key.size(); - if (!detail::sendBytes(clientSock_, &keySize, sizeof(keySize)) || - !detail::sendBytes(clientSock_, key.data(), keySize)) { - LOGE("TCPStore: Failed to send key for WAIT command"); - return false; - } - } + if (!sendCommand(detail::TCPStoreCommandType::CMD_WAIT)) return false; + if (!sendKeys(keys)) return false; uint8_t result; - if (!detail::recvBytes(clientSock_, &result, sizeof(result))) { - LOGE("TCPStore: Failed to receive result for WAIT command"); - return false; - } - - return result == 1; + return detail::recvBytes(clientSock_, &result, sizeof(result)) && result == 1; } bool TCPStore::ClientImpl::registerWorker() const { - detail::TCPStoreCommandType cmd = detail::TCPStoreCommandType::CMD_WORKER_REGISTER; - if (!detail::sendBytes(clientSock_, &cmd, sizeof(cmd))) { - LOGE("TCPStore: Failed to send worker registration command"); - return false; - } + if (!sendCommand(detail::TCPStoreCommandType::CMD_WORKER_REGISTER)) return false; uint8_t confirmation; if (!detail::recvBytes(clientSock_, &confirmation, sizeof(confirmation)) || confirmation != 1) { - LOGE("TCPStore: Failed to receive worker registration confirmation"); return false; } - isRegistered_.store(true); return true; } bool TCPStore::ClientImpl::unregisterWorker() const { - if (!isRegistered_.load()) { - return true; - } - - detail::TCPStoreCommandType cmd = detail::TCPStoreCommandType::CMD_WORKER_UNREGISTER; - if (!detail::sendBytes(clientSock_, &cmd, sizeof(cmd))) { - LOGE("TCPStore: Failed to send worker unregistration command"); - return false; - } - + if (!isRegistered_.load()) return true; + if (!sendCommand(detail::TCPStoreCommandType::CMD_WORKER_UNREGISTER)) return false; isRegistered_.store(false); return true; } +// ============================================================ +// ServerImpl +// ============================================================ + TCPStore::ServerImpl::ServerImpl(uint16_t port, int numWorkers, const std::chrono::milliseconds& timeout) : port_(port), numWorkers_(numWorkers), timeout_(timeout), shutdownServer_(false), registeredWorkers_(0), - serverSock_(INVALID_SOCKET) { - detail::setupSignalHandlers(); -} + serverSock_(INVALID_SOCKET) {} TCPStore::ServerImpl::~ServerImpl() { shutdownServer_.store(true); - stopClientMonitoring(); if (serverSock_ != INVALID_SOCKET) { closesocket(serverSock_); @@ -662,8 +516,19 @@ TCPStore::ServerImpl::~ServerImpl() { workerCV_.notify_all(); } - serverThread_.detach(); - shutdownAllThreads(); + if (serverThread_.joinable()) { + serverThread_.join(); + } + + { + std::lock_guard lock(clientThreadsMutex_); + for (auto& t : clientThreads_) { + if (t.joinable()) { + t.join(); + } + } + clientThreads_.clear(); + } } bool TCPStore::ServerImpl::start() { @@ -673,20 +538,19 @@ bool TCPStore::ServerImpl::start() { return false; } - if (!detail::configureSocket(serverSock_)) { + if (!detail::configureSocket(serverSock_, true)) { closesocket(serverSock_); serverSock_ = INVALID_SOCKET; return false; } sockaddr_in addr{}; - std::memset(&addr, 0, sizeof(addr)); addr.sin_family = AF_INET; addr.sin_port = htons(port_); addr.sin_addr.s_addr = INADDR_ANY; if (bind(serverSock_, reinterpret_cast(&addr), sizeof(addr)) != 0) { - LOGE("TCPStore: Failed to bind server socket"); + LOGE("TCPStore: Failed to bind server socket on port %d", port_); closesocket(serverSock_); serverSock_ = INVALID_SOCKET; return false; @@ -700,14 +564,11 @@ bool TCPStore::ServerImpl::start() { } serverThread_ = std::thread(&ServerImpl::serverLoop, this); - startClientMonitoring(); return true; } bool TCPStore::ServerImpl::waitForWorkers() { - if (numWorkers_ <= 0) { - return true; - } + if (numWorkers_ <= 0) return true; std::unique_lock lock(mutex_); bool success = workerCV_.wait_for( @@ -715,154 +576,74 @@ bool TCPStore::ServerImpl::waitForWorkers() { if (!success) { LOGE("TCPStore: Timeout waiting for workers. Expected: %d, Registered: %d", numWorkers_, registeredWorkers_.load()); - } else if (!shutdownServer_.load()) { - LOGI("TCPStore: All %d workers registered successfully", numWorkers_); } return success && !shutdownServer_.load(); } int TCPStore::ServerImpl::getRegisteredWorkers() const { return registeredWorkers_.load(); } -void TCPStore::ServerImpl::startClientMonitoring() { - monitoringRunning_.store(true); - clientMonitorThread_ = std::thread([this]() { - while (monitoringRunning_.load()) { - cleanupDisconnectedClients(); - std::this_thread::sleep_for(std::chrono::milliseconds(500)); - } - }); -} - -void TCPStore::ServerImpl::stopClientMonitoring() { - monitoringRunning_.store(false); - clientMonitorThread_.detach(); -} - -void TCPStore::ServerImpl::cleanupDisconnectedClients() { - std::lock_guard lock(activeClientsMutex_); - - auto it = activeClients_.begin(); - while (it != activeClients_.end()) { - if (!detail::isSocketConnected(*it)) { - LOGI("TCPStore: Detected disconnected client, cleaning up"); - closesocket(*it); - it = activeClients_.erase(it); - } else { - ++it; - } - } -} - -void TCPStore::ServerImpl::cleanupFinishedThreads() { - std::lock_guard lock(clientThreadsMutex_); - - auto it = clientThreads_.begin(); - while (it != clientThreads_.end()) { - if ((*it)->joinable()) { - ++it; - } else { - it = clientThreads_.erase(it); - activeThreadCount_.fetch_sub(1); - } - } -} - -void TCPStore::ServerImpl::shutdownAllThreads() { - std::vector> threads; - { - std::lock_guard lock(clientThreadsMutex_); - threads = std::move(clientThreads_); - clientThreads_.clear(); - } - - for (auto& thread : threads) { - if (thread) { - thread->detach(); - } - } - - activeThreadCount_.store(0); -} - void TCPStore::ServerImpl::serverLoop() { + detail::setSocketTimeout(serverSock_, 1); + while (!shutdownServer_.load()) { sockaddr_in clientAddr{}; socklen_t addrLen = sizeof(clientAddr); socket_t clientSocket = accept(serverSock_, reinterpret_cast(&clientAddr), &addrLen); if (clientSocket == INVALID_SOCKET) { - if (!shutdownServer_.load()) { - LOGE("TCPStore: Failed to accept client connection"); - } continue; } - { - std::lock_guard lock(activeClientsMutex_); - activeClients_.insert(clientSocket); - } - - { - std::lock_guard lock(clientThreadsMutex_); - - if (clientThreads_.size() > 128) { - cleanupFinishedThreads(); - } - - auto thread = std::make_unique(&ServerImpl::handleClientInThread, this, clientSocket); - clientThreads_.push_back(std::move(thread)); - activeThreadCount_.fetch_add(1); - } + std::lock_guard lock(clientThreadsMutex_); + clientThreads_.emplace_back(&ServerImpl::handleClientInThread, this, clientSocket); } } void TCPStore::ServerImpl::handleClientInThread(socket_t socket) { - struct ThreadGuard { - std::atomic& counter; - socket_t socket_; - std::unordered_set& activeClients_; - std::mutex& activeClientsMutex_; - - explicit ThreadGuard(std::atomic& c, socket_t s, std::unordered_set& clients, std::mutex& mutex) - : counter(c), socket_(s), activeClients_(clients), activeClientsMutex_(mutex) {} - - ~ThreadGuard() { - counter.fetch_sub(1); - { - std::lock_guard lock(activeClientsMutex_); - activeClients_.erase(socket_); - } - closesocket(socket_); - } - }; - - ThreadGuard guard(activeThreadCount_, socket, activeClients_, activeClientsMutex_); - - timeval timeout{}; - timeout.tv_sec = 1; // 1s - timeout.tv_usec = 0; - setsockopt(socket, SOL_SOCKET, SO_RCVTIMEO, (const char*)&timeout, sizeof(timeout)); + detail::configureSocket(socket, false); + // SO_RCVTIMEO acts as a mid-message safety net only. Inter-message idleness + // is governed by poll() below — a long quiet period (e.g. while a client is + // loading a large model between registerWorker() and the first NCCL + // collective) must NOT tear down the connection. + detail::setSocketTimeout(socket, 30); while (!shutdownServer_.load()) { + // wait for either readable data or shutdown, in 1s ticks. + struct pollfd pfd {}; + pfd.fd = socket; + pfd.events = POLLIN; + int rc = poll(&pfd, 1, 1000); + if (rc < 0) { + if (errno == EINTR) continue; + break; + } + if (rc == 0) { + // idle; loop and re-check shutdownServer_. + continue; + } + if (pfd.revents & (POLLERR | POLLHUP | POLLNVAL)) { + break; + } + if (!(pfd.revents & POLLIN)) { + continue; + } if (!handleClient(socket)) { break; } } + + closesocket(socket); } bool TCPStore::ServerImpl::handleClient(socket_t socket) { - if (shutdownServer_.load()) { - return false; - } + if (shutdownServer_.load()) return false; detail::TCPStoreCommandType cmd; if (!detail::recvBytes(socket, &cmd, sizeof(cmd))) { return false; } - if (shutdownServer_.load()) { - return false; - } + if (shutdownServer_.load()) return false; switch (cmd) { case detail::TCPStoreCommandType::CMD_SET: @@ -884,43 +665,42 @@ bool TCPStore::ServerImpl::handleClient(socket_t socket) { case detail::TCPStoreCommandType::CMD_WORKER_UNREGISTER: return handleWorkerUnregister(socket); default: - LOGE("TCPStore: Unknown command type: %d", static_cast(cmd)); + LOGE("TCPStore: Unknown command: %d", static_cast(cmd)); return false; } } -bool TCPStore::ServerImpl::handleSetCommand(socket_t socket) { - uint64_t keySize; - if (!detail::recvBytes(socket, &keySize, sizeof(keySize))) { - return false; +void TCPStore::ServerImpl::notifyKeyWaiters(const std::string& key) { + auto it = keyCVs_.find(key); + if (it != keyCVs_.end()) { + for (auto& cv : it->second) { + cv->notify_all(); + } + keyCVs_.erase(it); } +} - std::string key(keySize, '\0'); - if (!detail::recvBytes(socket, &key[0], keySize)) { - return false; - } - - uint64_t valueSize; - if (!detail::recvBytes(socket, &valueSize, sizeof(valueSize))) { - return false; +bool TCPStore::ServerImpl::waitForCondition(std::unique_lock& lock, + std::shared_ptr& cv, + const std::function& pred) { + if (timeout_ != kNoTimeout) { + return cv->wait_for(lock, timeout_, [&] { return shutdownServer_.load() || pred(); }) && !shutdownServer_.load(); } + cv->wait(lock, [&] { return shutdownServer_.load() || pred(); }); + return !shutdownServer_.load(); +} - std::vector value(valueSize); - if (!detail::recvBytes(socket, value.data(), valueSize)) { +bool TCPStore::ServerImpl::handleSetCommand(socket_t socket) { + std::string key; + std::vector value; + if (!detail::recvString(socket, key) || !detail::recvBlob(socket, value)) { return false; } { std::lock_guard lock(mutex_); keyValueStore_[key] = std::move(value); - - auto it = keyCVs_.find(key); - if (it != keyCVs_.end()) { - for (auto& cv : it->second) { - cv->notify_all(); - } - keyCVs_.erase(it); - } + notifyKeyWaiters(key); } uint8_t confirmation = 1; @@ -928,19 +708,11 @@ bool TCPStore::ServerImpl::handleSetCommand(socket_t socket) { } bool TCPStore::ServerImpl::handleGetCommand(socket_t socket) { - uint64_t keySize; - if (!detail::recvBytes(socket, &keySize, sizeof(keySize))) { - return false; - } - - std::string key(keySize, '\0'); - if (!detail::recvBytes(socket, &key[0], keySize)) { - return false; - } + std::string key; + if (!detail::recvString(socket, key)) return false; std::vector value; std::shared_ptr cv; - bool success = false; { std::unique_lock lock(mutex_); @@ -948,44 +720,19 @@ bool TCPStore::ServerImpl::handleGetCommand(socket_t socket) { auto it = keyValueStore_.find(key); if (it != keyValueStore_.end()) { value = it->second; - success = true; - } else { - cv = std::make_shared(); - keyCVs_[key].push_back(cv); + return detail::sendBlob(socket, value); } - } - if (success) { - uint64_t valueSize = value.size(); - return detail::sendBytes(socket, &valueSize, sizeof(valueSize)) && - (valueSize == 0 || detail::sendBytes(socket, value.data(), valueSize)); - } - - { - std::unique_lock lock(mutex_); - - if (timeout_ != kNoTimeout) { - auto waitUntil = std::chrono::steady_clock::now() + timeout_; - success = cv->wait_until(lock, waitUntil, [this, &key] { - return shutdownServer_.load() || keyValueStore_.find(key) != keyValueStore_.end(); - }); - } else { - cv->wait(lock, - [this, &key] { return shutdownServer_.load() || keyValueStore_.find(key) != keyValueStore_.end(); }); - success = true; - } + cv = std::make_shared(); + keyCVs_[key].push_back(cv); - if (shutdownServer_.load()) { - success = false; - } + bool success = waitForCondition(lock, cv, [&] { return keyValueStore_.find(key) != keyValueStore_.end(); }); auto cvIt = keyCVs_.find(key); if (cvIt != keyCVs_.end()) { auto& cvs = cvIt->second; cvs.erase(std::remove(cvs.begin(), cvs.end(), cv), cvs.end()); - if (cvs.empty()) { - keyCVs_.erase(cvIt); - } + if (cvs.empty()) keyCVs_.erase(cvIt); } if (success) { @@ -996,53 +743,33 @@ bool TCPStore::ServerImpl::handleGetCommand(socket_t socket) { success = false; } } - } - if (success) { - uint64_t valueSize = value.size(); - return detail::sendBytes(socket, &valueSize, sizeof(valueSize)) && - (valueSize == 0 || detail::sendBytes(socket, value.data(), valueSize)); + if (!success) { + LOGE("TCPStore: Get timeout for key: %s", key.c_str()); + return false; + } } - LOGE("TCPStore: Get timeout after %lld ms for key: %s", timeout_.count(), key.c_str()); - return false; + return detail::sendBlob(socket, value); } bool TCPStore::ServerImpl::handleAddCommand(socket_t socket) { - uint64_t keySize; - if (!detail::recvBytes(socket, &keySize, sizeof(keySize))) { - return false; - } - - std::string key(keySize, '\0'); - if (!detail::recvBytes(socket, &key[0], keySize)) { - return false; - } + std::string key; + if (!detail::recvString(socket, key)) return false; int64_t addValue; - if (!detail::recvBytes(socket, &addValue, sizeof(addValue))) { - return false; - } + if (!detail::recvBytes(socket, &addValue, sizeof(addValue))) return false; int64_t newValue = addValue; { std::lock_guard lock(mutex_); - const auto& value = keyValueStore_[key]; - if (!value.empty()) { - auto buf = reinterpret_cast(value.data()); - auto len = value.size(); - newValue += std::stoll(std::string(buf, len)); - } - auto newValStr = std::to_string(newValue); - keyValueStore_[key] = std::vector(newValStr.begin(), newValStr.end()); - - auto it = keyCVs_.find(key); - if (it != keyCVs_.end()) { - for (auto& cv : it->second) { - cv->notify_all(); - } - keyCVs_.erase(it); + const auto& existing = keyValueStore_[key]; + if (!existing.empty()) { + newValue += std::stoll(std::string(reinterpret_cast(existing.data()), existing.size())); } + auto str = std::to_string(newValue); + keyValueStore_[key] = std::vector(str.begin(), str.end()); + notifyKeyWaiters(key); } return detail::sendBytes(socket, &newValue, sizeof(newValue)); @@ -1050,112 +777,68 @@ bool TCPStore::ServerImpl::handleAddCommand(socket_t socket) { bool TCPStore::ServerImpl::handleCheckCommand(socket_t socket) { uint64_t numKeys; - if (!detail::recvBytes(socket, &numKeys, sizeof(numKeys))) { - return false; - } + if (!detail::recvBytes(socket, &numKeys, sizeof(numKeys))) return false; std::vector keys(numKeys); - for (size_t i = 0; i < numKeys; i++) { - uint64_t keySize; - if (!detail::recvBytes(socket, &keySize, sizeof(keySize))) { - return false; - } - - keys[i].resize(keySize); - if (!detail::recvBytes(socket, &keys[i][0], keySize)) { - return false; - } + if (!detail::recvString(socket, keys[i])) return false; } - bool allKeysExist = true; + bool allExist = true; { std::lock_guard lock(mutex_); for (const auto& key : keys) { if (keyValueStore_.find(key) == keyValueStore_.end()) { - allKeysExist = false; + allExist = false; break; } } } - uint8_t result = allKeysExist ? 1 : 0; + uint8_t result = allExist ? 1 : 0; return detail::sendBytes(socket, &result, sizeof(result)); } bool TCPStore::ServerImpl::handleWaitCommand(socket_t socket) { uint64_t numKeys; - if (!detail::recvBytes(socket, &numKeys, sizeof(numKeys))) { - return false; - } + if (!detail::recvBytes(socket, &numKeys, sizeof(numKeys))) return false; std::vector keys(numKeys); for (size_t i = 0; i < numKeys; i++) { - uint64_t keySize; - if (!detail::recvBytes(socket, &keySize, sizeof(keySize))) { - return false; - } - keys[i].resize(keySize); - if (!detail::recvBytes(socket, &keys[i][0], keySize)) { - return false; - } + if (!detail::recvString(socket, keys[i])) return false; } - bool success = true; - std::shared_ptr cv; - std::vector missingKeys; + auto allKeysExist = [&] { + return std::all_of(keys.begin(), keys.end(), + [&](const auto& k) { return keyValueStore_.find(k) != keyValueStore_.end(); }); + }; + bool success; { std::unique_lock lock(mutex_); - for (const auto& key : keys) { - if (keyValueStore_.find(key) == keyValueStore_.end()) { - missingKeys.push_back(key); - } - } - - if (missingKeys.empty()) { + if (allKeysExist()) { uint8_t result = 1; return detail::sendBytes(socket, &result, sizeof(result)); } - cv = std::make_shared(); - for (const auto& key : missingKeys) { - keyCVs_[key].push_back(cv); - } - } - - { - std::unique_lock lock(mutex_); - - if (timeout_ != kNoTimeout) { - auto waitUntil = std::chrono::steady_clock::now() + timeout_; - success = cv->wait_until(lock, waitUntil, [this, &keys] { - return shutdownServer_.load() || std::all_of(keys.begin(), keys.end(), [this](const auto& key) { - return keyValueStore_.find(key) != keyValueStore_.end(); - }); - }); - } else { - cv->wait(lock, [this, &keys] { - return shutdownServer_.load() || std::all_of(keys.begin(), keys.end(), [this](const auto& key) { - return keyValueStore_.find(key) != keyValueStore_.end(); - }); - }); - success = true; + auto cv = std::make_shared(); + std::vector missingKeys; + for (const auto& key : keys) { + if (keyValueStore_.find(key) == keyValueStore_.end()) { + keyCVs_[key].push_back(cv); + missingKeys.push_back(key); + } } - if (shutdownServer_.load()) { - success = false; - } + success = waitForCondition(lock, cv, allKeysExist); for (const auto& key : missingKeys) { auto it = keyCVs_.find(key); if (it != keyCVs_.end()) { auto& cvs = it->second; cvs.erase(std::remove(cvs.begin(), cvs.end(), cv), cvs.end()); - if (cvs.empty()) { - keyCVs_.erase(it); - } + if (cvs.empty()) keyCVs_.erase(it); } } } @@ -1165,27 +848,20 @@ bool TCPStore::ServerImpl::handleWaitCommand(socket_t socket) { } bool TCPStore::ServerImpl::handleDeleteCommand(socket_t socket) { - uint64_t keySize; - if (!detail::recvBytes(socket, &keySize, sizeof(keySize))) { - return false; - } - - std::string key(keySize, '\0'); - if (!detail::recvBytes(socket, &key[0], keySize)) { - return false; - } + std::string key; + if (!detail::recvString(socket, key)) return false; - bool success = false; + bool found = false; { std::lock_guard lock(mutex_); - auto kvIt = keyValueStore_.find(key); - if (kvIt != keyValueStore_.end()) { - keyValueStore_.erase(kvIt); - success = true; + auto it = keyValueStore_.find(key); + if (it != keyValueStore_.end()) { + keyValueStore_.erase(it); + found = true; } } - uint8_t result = success ? 1 : 0; + uint8_t result = found ? 1 : 0; return detail::sendBytes(socket, &result, sizeof(result)); } @@ -1195,7 +871,6 @@ bool TCPStore::ServerImpl::handleNumKeysCommand(socket_t socket) { std::lock_guard lock(mutex_); numKeys = static_cast(keyValueStore_.size()); } - return detail::sendBytes(socket, &numKeys, sizeof(numKeys)); } @@ -1204,7 +879,6 @@ bool TCPStore::ServerImpl::handleWorkerRegister(socket_t socket) { uint8_t confirmation = 1; if (!detail::sendBytes(socket, &confirmation, sizeof(confirmation))) { - LOGE("TCPStore: Failed to send worker registration confirmation"); registeredWorkers_.fetch_sub(1); return false; } @@ -1214,7 +888,7 @@ bool TCPStore::ServerImpl::handleWorkerRegister(socket_t socket) { workerCV_.notify_all(); } - LOGI("TCPStore: Worker registered. Current workers: %d/%d", currentWorkers, numWorkers_); + LOGI("TCPStore: Worker registered (%d/%d)", currentWorkers, numWorkers_); return true; } @@ -1224,4 +898,4 @@ bool TCPStore::ServerImpl::handleWorkerUnregister(socket_t socket) { return true; } -} // namespace tinytorch::distributed \ No newline at end of file +} // namespace tinytorch::distributed diff --git a/src/Distributed/Types.h b/src/Distributed/Types.h index 662c937..90ade20 100644 --- a/src/Distributed/Types.h +++ b/src/Distributed/Types.h @@ -6,6 +6,8 @@ #pragma once +#include + namespace tinytorch::distributed { enum ReduceOpType : uint8_t { diff --git a/src/Distributed/WorkNCCL.cpp b/src/Distributed/WorkNCCL.cpp index 4a4220c..f9e32b5 100644 --- a/src/Distributed/WorkNCCL.cpp +++ b/src/Distributed/WorkNCCL.cpp @@ -15,13 +15,28 @@ bool WorkNCCL::isCompleted() { return true; } + if (useComputeStream_) { + finish(""); + return true; + } + if (!cudaEvent_.query()) { return false; } + finish(""); return true; } bool WorkNCCL::wait(std::chrono::milliseconds timeout) { + if (useComputeStream_) { + if (postCompletionFn_) { + postCompletionFn_(); + postCompletionFn_ = nullptr; + } + finish(""); + return true; + } + synchronize(); if (timeout != kNoTimeout) { @@ -38,6 +53,11 @@ bool WorkNCCL::wait(std::chrono::milliseconds timeout) { stream.synchronize(); } + if (postCompletionFn_) { + postCompletionFn_(); + postCompletionFn_ = nullptr; + } + finish(""); return true; } diff --git a/src/Distributed/WorkNCCL.h b/src/Distributed/WorkNCCL.h index 1991e37..27af7ac 100644 --- a/src/Distributed/WorkNCCL.h +++ b/src/Distributed/WorkNCCL.h @@ -6,6 +6,8 @@ #pragma once +#include + #include "NCCLUtils.h" #include "Store.h" #include "Work.h" @@ -33,6 +35,8 @@ class WorkNCCL : public Work { void setIsBarrierOp(bool isBarrierOp) { isBarrierOp_ = isBarrierOp; } void setCudaEvent(cuda::CUDAEvent&& event) { cudaEvent_ = std::move(event); } void setOutputs(std::vector&& output) { outputs_ = std::move(output); } + void setPostCompletionFn(std::function fn) { postCompletionFn_ = std::move(fn); } + void setUseComputeStream(bool flag) { useComputeStream_ = flag; } const Device& getDevice() const { return device_; } const std::string& getPgUID() const { return pgUID_; } @@ -56,6 +60,8 @@ class WorkNCCL : public Work { bool isBarrierOp_{false}; cuda::CUDAEvent cudaEvent_; std::vector outputs_; + std::function postCompletionFn_; + bool useComputeStream_{false}; }; } // namespace tinytorch::distributed diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 92b2489..1722ab2 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -11,6 +11,8 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON) set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) add_subdirectory(googletest) +include(GoogleTest) + add_executable(${PROJECT_NAME} test.cpp test_tensor.cpp @@ -29,13 +31,18 @@ target_include_directories(${PROJECT_NAME} PRIVATE ) target_link_libraries(${PROJECT_NAME} TinyTorch_lib gtest_main) -# cuda support -find_package(CUDAToolkit QUIET) -if (CUDAToolkit_FOUND) - target_compile_definitions(${PROJECT_NAME} PRIVATE USE_CUDA) -endif () - -include(GoogleTest) - # add tests gtest_discover_tests(${PROJECT_NAME} WORKING_DIRECTORY $) + +# nccl distributed test +if (TINYTORCH_USE_NCCL) + add_executable(TinyTorch_distributed_test test_distributed.cpp) + target_include_directories(TinyTorch_distributed_test PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/../src + ${CMAKE_CURRENT_SOURCE_DIR}/../third_party + googletest/googletest/include + googletest/googlemock/include + ) + target_link_libraries(TinyTorch_distributed_test TinyTorch_lib gtest_main) + gtest_discover_tests(TinyTorch_distributed_test WORKING_DIRECTORY $) +endif () diff --git a/test/test_distributed.cpp b/test/test_distributed.cpp new file mode 100644 index 0000000..4cc3157 --- /dev/null +++ b/test/test_distributed.cpp @@ -0,0 +1,279 @@ +/* + * TinyTorch + * @author : keith@robot9.me + * + */ + +#if defined(USE_CUDA) && defined(USE_NCCL) + +#include +#include + +#include +#include +#include +#include +#include + +#include "Distributed/DistributedProcessGroup.h" +#include "TinyTorch.h" +#include "Utils/CUDAUtils.h" +#include "test.h" + +using namespace tinytorch; +using namespace tinytorch::distributed; + +namespace { + +constexpr int kBasePort = 29600; +std::atomic gPortOffset{0}; + +int getNextPort() { return kBasePort + gPortOffset.fetch_add(1); } + +bool runMultiProcess(int worldSize, const std::function& fn) { + int port = getNextPort(); + std::vector children; + children.reserve(worldSize); + + for (int rank = 0; rank < worldSize; rank++) { + pid_t pid = fork(); + if (pid < 0) { + std::cerr << "fork failed for rank " << rank << std::endl; + return false; + } + if (pid == 0) { + // child process + try { + fn(rank, worldSize, port); + } catch (const std::exception& e) { + std::cerr << "Rank " << rank << " exception: " << e.what() << std::endl; + _exit(1); + } catch (...) { + std::cerr << "Rank " << rank << " unknown exception" << std::endl; + _exit(1); + } + _exit(0); + } + children.push_back(pid); + } + + // parent: wait for all children + bool allSuccess = true; + for (pid_t pid : children) { + int status = 0; + waitpid(pid, &status, 0); + if (!WIFEXITED(status) || WEXITSTATUS(status) != 0) { + allSuccess = false; + } + } + return allSuccess; +} + +int getDeviceCountNoCudaInit() { + // try /proc/driver/nvidia/gpus - each GPU has a numbered directory + int count = 0; + for (int i = 0; i < 16; i++) { + std::string path = "/proc/driver/nvidia/gpus/" + std::to_string(i) + "/information"; + if (access(path.c_str(), F_OK) == 0) { + count++; + } else { + break; + } + } + if (count > 0) { + return count; + } + // fallback: use nvidia-smi + FILE* fp = popen("nvidia-smi --query-gpu=name --format=csv,noheader 2>/dev/null | wc -l", "r"); + if (fp) { + char buf[32] = {}; + if (fgets(buf, sizeof(buf), fp)) { + count = std::atoi(buf); + } + pclose(fp); + } + return count; +} + +#define SKIP_IF_NOT_ENOUGH_GPUS() \ + do { \ + int deviceCount = getDeviceCountNoCudaInit(); \ + if (deviceCount < 2) { \ + GTEST_SKIP() << "Need >= 2 GPUs, have " << deviceCount; \ + } \ + } while (0) + +} // namespace + +TEST(TEST_Distributed, AllReduce_SUM) { + SKIP_IF_NOT_ENOUGH_GPUS(); + + int worldSize = 2; + bool ok = runMultiProcess(worldSize, [](int rank, int worldSize, int port) { + cuda::setDevice(rank); + auto dpg = std::make_shared(); + std::string initMethod = "tcp://127.0.0.1:" + std::to_string(port); + if (!dpg->initProcessGroup(NCCL, initMethod, rank, worldSize)) { + _exit(1); + } + + // each rank has tensor filled with (rank + 1) + std::vector data = {1.0f * (rank + 1), 2.0f * (rank + 1), 3.0f * (rank + 1), 4.0f * (rank + 1)}; + auto tensor = Tensor(data, Options({DeviceType::CUDA, static_cast(rank)}, DType::Float32)); + + std::vector tensors = {tensor}; + auto work = dpg->allReduce(tensors, SUM); + if (!work || !work->wait()) { + _exit(1); + } + + // expected: sum of (rank+1) for all ranks = worldSize*(worldSize+1)/2 + auto expectedScale = static_cast(worldSize * (worldSize + 1) / 2); + auto result = tensor.toList(); + for (int i = 0; i < 4; i++) { + float expected = static_cast(i + 1) * expectedScale; + if (std::fabs(result[i] - expected) > 1e-3f) { + _exit(1); + } + } + dpg->destroyProcessGroup(); + }); + EXPECT_TRUE(ok); +} + +TEST(TEST_Distributed, AllReduce_AVG) { + SKIP_IF_NOT_ENOUGH_GPUS(); + + int worldSize = 2; + bool ok = runMultiProcess(worldSize, [](int rank, int worldSize, int port) { + cuda::setDevice(rank); + auto dpg = std::make_shared(); + std::string initMethod = "tcp://127.0.0.1:" + std::to_string(port); + if (!dpg->initProcessGroup(NCCL, initMethod, rank, worldSize)) { + _exit(1); + } + + std::vector data = {10.0f * (static_cast(rank) + 1)}; + auto tensor = Tensor(data, Options({DeviceType::CUDA, static_cast(rank)}, DType::Float32)); + + std::vector tensors = {tensor}; + auto work = dpg->allReduce(tensors, AVG); + if (!work || !work->wait()) { + _exit(1); + } + + // AVG of (10, 20) = 15 + float expectedAvg = 10.0f * static_cast(worldSize + 1) / 2.0f; + auto result = tensor.toList(); + if (std::fabs(result[0] - expectedAvg) > 1e-3f) { + _exit(1); + } + dpg->destroyProcessGroup(); + }); + EXPECT_TRUE(ok); +} + +TEST(TEST_Distributed, Broadcast) { + SKIP_IF_NOT_ENOUGH_GPUS(); + + int worldSize = 2; + bool ok = runMultiProcess(worldSize, [](int rank, int worldSize, int port) { + cuda::setDevice(rank); + auto dpg = std::make_shared(); + std::string initMethod = "tcp://127.0.0.1:" + std::to_string(port); + if (!dpg->initProcessGroup(NCCL, initMethod, rank, worldSize)) { + _exit(1); + } + + // rank 0 has meaningful data, others have zeros + Tensor tensor; + if (rank == 0) { + std::vector data = {1.0f, 2.0f, 3.0f, 4.0f}; + tensor = Tensor(data, Options({DeviceType::CUDA, static_cast(rank)}, DType::Float32)); + } else { + tensor = Tensor::zeros({4}, Options({DeviceType::CUDA, static_cast(rank)}, DType::Float32)); + } + + std::vector tensors = {tensor}; + auto work = dpg->broadcast(tensors, 0); + if (!work || !work->wait()) { + _exit(1); + } + + // all ranks should have rank 0's data + auto result = tensor.toList(); + for (int i = 0; i < 4; i++) { + if (std::fabs(result[i] - static_cast(i + 1)) > 1e-3f) { + _exit(1); + } + } + dpg->destroyProcessGroup(); + }); + EXPECT_TRUE(ok); +} + +TEST(TEST_Distributed, AllGather) { + SKIP_IF_NOT_ENOUGH_GPUS(); + + int worldSize = 2; + bool ok = runMultiProcess(worldSize, [](int rank, int worldSize, int port) { + cuda::setDevice(rank); + auto dpg = std::make_shared(); + std::string initMethod = "tcp://127.0.0.1:" + std::to_string(port); + if (!dpg->initProcessGroup(NCCL, initMethod, rank, worldSize)) { + _exit(1); + } + + // each rank has a tensor with its own rank value + std::vector inputData = {static_cast(rank + 1) * 10.0f, static_cast(rank + 1) * 20.0f}; + auto input = Tensor(inputData, Options({DeviceType::CUDA, static_cast(rank)}, DType::Float32)); + + // prepare output tensors (one per rank) + std::vector outputs; + outputs.reserve(worldSize); + for (int i = 0; i < worldSize; i++) { + outputs.push_back( + Tensor::zeros({2}, Options({DeviceType::CUDA, static_cast(rank)}, DType::Float32))); + } + + std::vector> outputTensors = {outputs}; + std::vector inputTensors = {input}; + auto work = dpg->allGather(outputTensors, inputTensors); + if (!work || !work->wait()) { + _exit(1); + } + + // verify: outputs[i] should contain rank i's data + for (int i = 0; i < worldSize; i++) { + auto result = outputTensors[0][i].toList(); + float expected0 = static_cast(i + 1) * 10.0f; + float expected1 = static_cast(i + 1) * 20.0f; + if (std::fabs(result[0] - expected0) > 1e-3f || std::fabs(result[1] - expected1) > 1e-3f) { + _exit(1); + } + } + dpg->destroyProcessGroup(); + }); + EXPECT_TRUE(ok); +} + +TEST(TEST_Distributed, Barrier) { + SKIP_IF_NOT_ENOUGH_GPUS(); + + int worldSize = 2; + bool ok = runMultiProcess(worldSize, [](int rank, int worldSize, int port) { + cuda::setDevice(rank); + auto dpg = std::make_shared(); + std::string initMethod = "tcp://127.0.0.1:" + std::to_string(port); + if (!dpg->initProcessGroup(NCCL, initMethod, rank, worldSize)) { + _exit(1); + } + + // barrier should complete without hanging + dpg->barrier(false, {static_cast(rank)}); + dpg->destroyProcessGroup(); + }); + EXPECT_TRUE(ok); +} + +#endif // defined(USE_CUDA) && defined(USE_NCCL)