Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 0 additions & 8 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,6 @@ message(STATUS "TINYTORCH_BUILD_TEST ${TINYTORCH_BUILD_TEST}")
message(STATUS "TINYTORCH_USE_CUDA ${TINYTORCH_USE_CUDA}")
message(STATUS "TINYTORCH_USE_NCCL ${TINYTORCH_USE_NCCL}")

if (TINYTORCH_USE_CUDA)
add_definitions(-DUSE_CUDA)
endif ()

if (TINYTORCH_USE_NCCL)
add_definitions(-DUSE_NCCL)
endif ()

add_subdirectory(src)

if (TINYTORCH_BUILD_EXAMPLES)
Expand Down
2 changes: 2 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,12 @@ if (TINYTORCH_USE_CUDA)
set_property(TARGET ${PROJECT_NAME} PROPERTY CUDA_ARCHITECTURES native)
target_include_directories(${PROJECT_NAME} PUBLIC ${CUDA_INCLUDE_DIRS})
target_link_libraries(${PROJECT_NAME} PUBLIC CUDA::cudart CUDA::curand CUDA::cublas)
target_compile_definitions(${PROJECT_NAME} PUBLIC USE_CUDA)
endif ()

# NCCL
if (TINYTORCH_USE_NCCL)
target_include_directories(${PROJECT_NAME} PUBLIC ${NCCL_INCLUDE_DIRS})
target_link_libraries(${PROJECT_NAME} PUBLIC ${NCCL_LIBRARY})
target_compile_definitions(${PROJECT_NAME} PUBLIC USE_NCCL)
endif ()
34 changes: 27 additions & 7 deletions src/Distributed/BackendNCCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,29 @@ std::shared_ptr<Work> BackendNCCL::allGather(std::vector<std::vector<Tensor>>& o
ASSERT(outputTensors.size() == 1);
ASSERT(inputTensors.size() == 1);
auto& input = inputTensors.back();
auto output = op::stack(ArrayView<Tensor>(outputTensors.back()), 0);
return collective(
input, output,
auto& outputs = outputTensors.back();

int64_t numel = input.numel();
auto flatOutput =
Tensor::zeros({numel * static_cast<int64_t>(outputs.size())}, Options(input.device(), input.dtype()));

auto work = collective(
input, flatOutput,
[&](const std::shared_ptr<NCCLComm>& comm, cudaStream_t stream) {
return ncclAllGather(input.dataPtr<>(), output.dataPtr<>(), input.numel(), getNcclDataType(input.dtype()),
return ncclAllGather(input.dataPtr<>(), flatOutput.dataPtr<>(), numel, getNcclDataType(input.dtype()),
comm->getNcclComm(), stream);
},
OpType::ALL_GATHER);

if (work) {
auto workNCCL = std::static_pointer_cast<WorkNCCL>(work);
workNCCL->setPostCompletionFn([flatOutput, outputs, numel]() mutable {
for (size_t i = 0; i < outputs.size(); i++) {
outputs[i].copy_(flatOutput.narrow(0, static_cast<int64_t>(i) * numel, numel).view(outputs[i].sizes()));
}
});
}
return work;
}

std::shared_ptr<Work> BackendNCCL::reduceScatter(std::vector<Tensor>& outputTensors,
Expand Down Expand Up @@ -128,20 +143,25 @@ std::shared_ptr<NCCLComm> BackendNCCL::getNCCLComm(const std::string& deviceKey,
ncclUniqueId commId;
std::string key = NCCL_COMM_ID_PREFIX + std::to_string(ncclCommCounter_++);
if (getRank() == 0) {
NCCL_CALL(ncclGetUniqueId(&commId));
ncclResult_t ret = ncclGetUniqueId(&commId);
if (ret != ncclSuccess && ret != ncclInProgress) {
NCCL_ERROR(ret);
setErrorLocked(BackendErrorType::COMM_ERROR);
return nullptr;
}
store_->set(key, std::string(reinterpret_cast<char*>(&commId), sizeof(commId)));
} else {
auto idData = store_->get(key);
if (idData.size() != sizeof(commId)) {
setError(BackendErrorType::COMM_ERROR);
setErrorLocked(BackendErrorType::COMM_ERROR);
return nullptr;
}
std::memcpy(&commId, idData.data(), sizeof(commId));
}

auto comm = NCCLComm::create(getSize(), getRank(), commId, device.index);
if (!comm) {
setError(BackendErrorType::COMM_ERROR);
setErrorLocked(BackendErrorType::COMM_ERROR);
return nullptr;
}

Expand Down
26 changes: 21 additions & 5 deletions src/Distributed/BackendNCCL.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ class BackendNCCL : public Backend {

void setTimeout(std::chrono::milliseconds timeout) override { options_->timeout = timeout; }

void setUseComputeStream(bool flag) { useComputeStream_ = flag; }
bool useComputeStream() const { return useComputeStream_; }

std::string getBackendName() const override { return NCCL_BACKEND_NAME; }

std::shared_ptr<BackendOptions> getBackendOptions() override { return options_; }
Expand Down Expand Up @@ -90,6 +93,10 @@ class BackendNCCL : public Backend {
protected:
void setError(BackendErrorType error) {
std::lock_guard<std::mutex> lock(mutex_);
setErrorLocked(error);
}

void setErrorLocked(BackendErrorType error) {
if (error_ == BackendErrorType::SUCCESS) {
error_ = error;
}
Expand Down Expand Up @@ -144,6 +151,8 @@ class BackendNCCL : public Backend {

std::list<std::weak_ptr<WorkNCCL>> activeWorks_;

bool useComputeStream_ = false;

mutable std::mutex mutex_;
BackendErrorType error_ = BackendErrorType::SUCCESS;
};
Expand All @@ -164,28 +173,35 @@ std::shared_ptr<Work> BackendNCCL::collective(Tensor& input, Tensor& output, Fun
return nullptr;
}

auto& ncclStream = getNCCLStream(deviceKey, device.index);
auto& computeSteam = cuda::getCurrentCUDAStream(device.index);
ncclStream.waitStream(computeSteam);
cuda::CUDAStream& opStream = useComputeStream_ ? computeSteam : getNCCLStream(deviceKey, device.index);
if (!useComputeStream_) {
opStream.waitStream(computeSteam);
}

auto work = std::make_shared<WorkNCCL>(std::to_string(getID()), getGroupDesc(), device, getRank(), opType);
work->setOpTimeout(options_->timeout);
work->setNCCLComm(comm);
work->setStore(store_);
work->setIsBarrierOp(opType == OpType::BARRIER);
work->setCudaEvent(cuda::createCUDAEvent(device.index));
work->setUseComputeStream(useComputeStream_);
if (!useComputeStream_) {
work->setCudaEvent(cuda::createCUDAEvent(device.index));
}
work->setOutputs({output});

cuda::CudaDeviceGuard guard(device.index);
NCCL_CALL(ncclGroupStart());
NCCL_CALL(ncclFunc(comm, ncclStream.stream()));
NCCL_CALL(ncclFunc(comm, opStream.stream()));
ncclResult_t state = ncclGroupEnd();

if (!waitNcclCommandResult(comm, state)) {
return nullptr;
}

work->getCudaEvent().record(ncclStream);
if (!useComputeStream_) {
work->getCudaEvent().record(opStream);
}
workEnqueue(work);
return work;
}
Expand Down
16 changes: 7 additions & 9 deletions src/Distributed/DistributedProcessGroup.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,25 +138,23 @@ std::shared_ptr<Store> DistributedProcessGroup::createStore(const InitConfig& co
if (config.method == InitMethod::ENV) {
bool isServer = (config.rank == 0);
int numWorkers = config.worldSize;
bool waitWorkers = true;
return std::make_shared<TCPStore>(config.masterAddr, // host
config.masterPort, // port
isServer, // isServer
waitWorkers, // waitWorkers
numWorkers, // numWorkers
timeout // timeout
return std::make_shared<TCPStore>(config.masterAddr, // host
config.masterPort, // port
isServer, // isServer
config.waitWorkers, // waitWorkers
numWorkers, // numWorkers
timeout // timeout
);
}

// tcp
if (config.method == InitMethod::TCP) {
bool isServer = (config.rank == 0);
int numWorkers = config.worldSize;
bool waitWorkers = true;
return std::make_shared<TCPStore>(config.masterAddr, // host
static_cast<uint16_t>(config.masterPort), // port
isServer, // isServer
waitWorkers, // waitWorkers
config.waitWorkers, // waitWorkers
numWorkers, // numWorkers
timeout // timeout
);
Expand Down
2 changes: 1 addition & 1 deletion src/Distributed/DistributedSampler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ DistributedSampler::DistributedSampler(size_t datasetSize, std::optional<size_t>
if (dropLast_) {
numSamples_ = datasetSize_ / numReplicas_;
} else {
numSamples_ = std::ceil(static_cast<float>(datasetSize_) * 1.0f / static_cast<float>(numReplicas_));
numSamples_ = (datasetSize_ + numReplicas_ - 1) / numReplicas_;
}
totalSize_ = numSamples_ * numReplicas_;
generateIndices();
Expand Down
4 changes: 2 additions & 2 deletions src/Distributed/FileStore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class FileLock {
return;
}

struct flock fl{};
struct flock fl {};
fl.l_type = F_WRLCK;
fl.l_whence = SEEK_SET;
fl.l_start = 0;
Expand All @@ -44,7 +44,7 @@ class FileLock {

~FileLock() {
if (fd_ != -1) {
struct flock fl{};
struct flock fl {};
fl.l_type = F_UNLCK;
fl.l_whence = SEEK_SET;
fl.l_start = 0;
Expand Down
4 changes: 1 addition & 3 deletions src/Distributed/Reducer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ void Reducer::reduceBucket(int64_t bucketIdx) {
ASSERT(!bucket.params.empty());

std::vector<Tensor> tensors = {bucket.flatBuffer};
bucket.work = processGroup_->allReduce(tensors);
bucket.work = processGroup_->allReduce(tensors, AVG);
ASSERT(bucket.work);

bucketReadyCnt_++;
Expand Down Expand Up @@ -169,8 +169,6 @@ void Reducer::checkAllBucketsReady() {
ASSERT(success);

cuda::CudaDeviceGuard guard(bucket.flatBuffer.device().index);
auto worldSize = processGroup_->getWorldSize();
bucket.flatBuffer /= static_cast<float>(worldSize);
copyFlattenedBufferToGrads(bIdx);

bucket.work = nullptr;
Expand Down
Loading
Loading