
Implement a sparse alltoall exchange pattern #959

Open
wence- wants to merge 5 commits into rapidsai:main from wence-:wence/fea/neighbour-alltoall

Conversation

@wence-
Contributor

@wence- wence- commented Apr 10, 2026

In this collective, every rank advertises the destination ranks it will send
to and the source ranks it will receive from (these have to match up
collectively; no error is raised if they do not). The caller can then insert
messages to particular ranks, followed by a final `insert_finished()` call.

On the receive side, after waiting for completion, we can extract received
messages by rank. The receive side message order is defined by the
insertion order on the send side. That is, if rank-A inserts messages in
order [A0, A1, A2] to rank-B, then when rank-B calls `extract(rank-A)` it
will see the same order (even if the messages were sent in a different order).

wence- added 5 commits April 10, 2026 16:03
@wence- wence- requested review from a team as code owners April 10, 2026 16:33
@wence-
Contributor Author

wence- commented Apr 10, 2026

I can split these into bits for review purposes if that is useful

@wence- wence- added improvement Improves an existing functionality non-breaking Introduces a non-breaking change labels Apr 10, 2026
Comment on lines +96 to +97
* @note Concurrent insertion by multiple threads is supported, the caller must ensure
* that `insert_finished()` is called _after_ all `insert()` calls have completed.
Member

They seem like two separate notes, or am I misreading? It seems that the multi-threaded insertion comment and the insert_finished() requirement are independent of each other.

Member

This relates to the note in insert_finished, they are indeed dependent. I would rephrase this slightly to make dependence clearer:

Suggested change
* @note Concurrent insertion by multiple threads is supported, the caller must ensure
* that `insert_finished()` is called _after_ all `insert()` calls have completed.
* @note Concurrent insertion by multiple threads is supported, the caller must ensure
* all `insert()` calls (concurrent or not) have completed before calling
* `insert_finished()`.

packed_data_vector_to_list)


cdef class SparseAlltoall:
Member

I know this isn't common, but especially with Python free threading do we want to also add notes about multithreading support here?


plc.Table(
[plc.Column.from_array(np.array([29], dtype=np.int32), stream=stream)]
),
)
Member

This test is OK, but it only exercises exactly 2 ranks, while all other ranks are no-ops. I think we could also use a test that exercises as many ranks as are available.

);
}

void SparseAlltoall::insert_finished() {
Member

We should probably add a check to prevent calling insert_finished() more than once.

Contributor

Maybe we can check `locally_finished_ == false`.

Member

A few tests I think we're missing include:

  • insert() after insert_finished()
  • multi-threaded insert()
  • extract() with invalid source rank

RAPIDSMPF_EXPECTS_FATAL(
event_.is_set(),
"~SparseAlltoall: not all notification tasks complete, did you forget to await "
"this->wait() or to call this->insert_finished()?"
Member

Is this->wait() a mistake? I don't think we have a wait() method.

Contributor

@nirandaperera nirandaperera left a comment

Had some comments for the normal impl. I will check the coroutine impl later today.

std::uint64_t received_count{0};
std::vector<std::unique_ptr<detail::Chunk>> chunks;

[[nodiscard]] bool ready() const noexcept {
Contributor

Nit.

Suggested change
[[nodiscard]] bool ready() const noexcept {
[[nodiscard]] bool constexpr ready() const noexcept {

void send_ready_messages();
void receive_metadata_messages();
void receive_data_messages();
void complete_data_messages();
Contributor

It would be nice to add some @brief docstrings for these methods (for future reference), either here or in the cpp file.

RAPIDSMPF_EXPECTS(br_ != nullptr, "the buffer resource pointer cannot be null");
auto const size = comm_->nranks();
auto const self = comm_->rank();
for (auto src : srcs_) {
Contributor

Nit

Suggested change
for (auto src : srcs_) {
source_states_.reserve(srcs_.size());
for (auto src : srcs_) {

);
source_states_.emplace(src, SourceState{});
}
for (auto dst : dsts_) {
Contributor

Nit.

Suggested change
for (auto dst : dsts_) {
next_ordinal_per_dst_.reserve(dsts_.size());
for (auto dst : dsts_) {

Comment on lines +63 to +67
SparseAlltoall::~SparseAlltoall() noexcept {
RAPIDSMPF_EXPECTS_FATAL(
locally_finished_.load(std::memory_order_acquire),
"Destroying SparseAlltoall without `insert_finished()`"
);
Contributor

Marked as noexcept but throwing.

Tag const metadata_tag{op_id_, 0};
for (auto src : srcs_) {
auto& state = source_states_.at(src);
while (!state.ready()) {
Contributor

Wouldn't this while-not-ready loop hog the progress thread until all the messages are received from all sources? I feel like this would be unfair to other concurrent collectives, wouldn't it?

Contributor

Did you mean to use an if instead?

);
state.expected_count = chunk->sequence();
} else {
incoming_by_src_.at(src).push_back(std::move(chunk));
Contributor

Nit

Suggested change
incoming_by_src_.at(src).push_back(std::move(chunk));
incoming_by_src_[src].push_back(std::move(chunk));

src >= 0 && src < size && src != self, "SparseAlltoall invalid source rank."
);
RAPIDSMPF_EXPECTS(
incoming_by_src_.emplace(src, std::vector<std::unique_ptr<detail::Chunk>>{})
Contributor

Why can't incoming_by_src_ be a class member of SourceState? I feel like it is the received metadata queue from a particular source, isn't it?

);
}
}
queue.erase(queue.begin(), queue.begin() + processed);
Contributor

Wouldn't a std::deque/queue be better here? The unprocessed chunks will always be shifted by `processed` elements in each progress iteration, won't they?

}
processed++;
if (chunk->data_size() == 0) {
auto& state = source_states_.at(chunk->origin());
Contributor

Suggested change
auto& state = source_states_.at(chunk->origin());
auto& state = source_states_[chunk->origin()];
