Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cudax/include/cuda/experimental/__group/fwd.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class group;

// mappings

template <::cuda::std::size_t _Np = ::cuda::std::dynamic_extent, bool _IsExhaustive = true>
template <::cuda::std::size_t _Count = ::cuda::std::dynamic_extent, bool _IsExhaustive = true>
class group_by;

template <class _Data, bool _IsExahustive>
Expand Down
32 changes: 21 additions & 11 deletions cudax/include/cuda/experimental/__group/group.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#include <cuda/experimental/__group/concepts.cuh>
#include <cuda/experimental/__group/fwd.cuh>
#include <cuda/experimental/__group/mapping/group_by.cuh>
#include <cuda/experimental/__group/mapping/mapping_result.cuh>
#include <cuda/experimental/__group/this_group.cuh>
#include <cuda/experimental/__group/traits.cuh>

Expand All @@ -57,8 +58,26 @@ class group

// todo(dabayer): static_assert that _Unit is (under) typename _ParentGroup::unit_type

// Builds the mapping result that is handed to the mapping as its "previous" result.
//
// The result describes a single group (group count 1, group rank 0) whose unit count and unit rank are
// queried from `__parent` for `_Unit`. The static unit count, the always-exhaustive flag and the
// always-contiguous flag are taken over from the parent group at compile time.
[[nodiscard]] _CCCL_DEVICE_API static constexpr auto
__get_initial_mapping_result(const _ParentGroup& __parent) noexcept
{
  using _ParentResult = typename _ParentGroup::__mapping_result_type;

  // Compile-time properties of the seed result.
  constexpr auto __static_unit_count = ::cuda::experimental::__static_count_query_group<_Unit, _ParentGroup>();
  constexpr auto __always_exhaustive = _ParentResult::is_always_exhaustive();
  constexpr auto __always_contiguous = _ParentResult::is_always_contiguous();

  using _InitialResult =
    ::cuda::experimental::__mapping_result<1, __static_unit_count, __always_exhaustive, __always_contiguous>;

  // Runtime properties; the count is queried before the rank, as in the braced initializer order.
  const auto __unit_count = ::cuda::experimental::__count_query_group<unsigned, _Unit>(__parent);
  const auto __unit_rank  = ::cuda::experimental::__rank_query_group<unsigned, _Unit>(__parent);

  return _InitialResult{1, 0, __unit_count, __unit_rank};
}

using _ParentMappingResult = typename _ParentGroup::__mapping_result_type;
using _MappingResult = __group_mapping_result_t<_Mapping, _Unit, _ParentGroup>;
using _MappingResult = decltype(::cuda::std::declval<const _Mapping&>().map(
::cuda::std::declval<const _ParentGroup&>(),
__get_initial_mapping_result(::cuda::std::declval<const _ParentGroup&>())));
using _SynchronizerInstance =
__group_synchronizer_instance_t<_Synchronizer, _Unit, _ParentGroup, _Mapping, _MappingResult>;
static_assert(__group_mapping_result<_MappingResult>);
Expand All @@ -72,16 +91,7 @@ class group
[[nodiscard]] _CCCL_DEVICE_API static _MappingResult
__do_mapping(const _Mapping& __mapping, const _ParentGroup& __parent) noexcept
{
// Do not invoke the mapping for threads that are not part of the parent group.
if constexpr (!_ParentMappingResult::is_always_exhaustive())
{
if (!__parent.__mapping_result().is_valid())
{
return _MappingResult::invalid();
}
}

const auto __mapping_result = __mapping.map(_Unit{}, __parent);
const auto __mapping_result = __mapping.map(__parent, __get_initial_mapping_result(__parent));
if (__mapping_result.is_valid())
{
_CCCL_ASSERT(__mapping_result.group_rank() < __mapping_result.group_count(), "invalid group rank");
Expand Down
253 changes: 68 additions & 185 deletions cudax/include/cuda/experimental/__group/mapping/group_as.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <cuda/std/span>

#include <cuda/experimental/__group/fwd.cuh>
#include <cuda/experimental/__group/mapping/mapping_result.cuh>
#include <cuda/experimental/__group/queries.cuh>
#include <cuda/experimental/__group/traits.cuh>

Expand All @@ -54,83 +55,6 @@ class group_as<__group_as_static_tag<_Counts...>, _IsExhaustive>
static constexpr auto __counts_sum = (0 + ... + _Counts);

public:
// Result of applying a statically-sized group_as mapping to a parent group.
//
// _ParentIsAlwaysExhaustive / _ParentIsAlwaysContiguous carry the corresponding compile-time
// properties of the parent group's mapping result.
template <bool _ParentIsAlwaysExhaustive, bool _ParentIsAlwaysContiguous>
struct __mapping_result
{
  // Plain aggregate; the member order is relied upon by braced initialization.
  unsigned __group_rank_;
  unsigned __count_;
  unsigned __rank_;

  // ---- compile-time properties ----

  // Exhaustive only if both the parent and this mapping are exhaustive.
  [[nodiscard]] _CCCL_DEVICE_API static constexpr bool is_always_exhaustive() noexcept
  {
    return _IsExhaustive && _ParentIsAlwaysExhaustive;
  }

  // Contiguity is taken over from the parent unchanged.
  [[nodiscard]] _CCCL_DEVICE_API static constexpr bool is_always_contiguous() noexcept
  {
    return _ParentIsAlwaysContiguous;
  }

  // One group per entry in the _Counts pack.
  [[nodiscard]] _CCCL_DEVICE_API static constexpr ::cuda::std::size_t static_group_count() noexcept
  {
    return sizeof...(_Counts);
  }

  // The per-group unit count is not encoded in the type.
  [[nodiscard]] _CCCL_DEVICE_API static constexpr ::cuda::std::size_t static_count() noexcept
  {
    return ::cuda::std::dynamic_extent;
  }

  // Result whose every field holds the invalid marker.
  [[nodiscard]] _CCCL_DEVICE_API static constexpr __mapping_result invalid() noexcept
  {
    return __mapping_result{__invalid_count_or_rank, __invalid_count_or_rank, __invalid_count_or_rank};
  }

  // ---- runtime queries ----

  // A result is valid unless its rank holds the invalid marker; exhaustive results are always valid.
  [[nodiscard]] _CCCL_DEVICE_API bool is_valid() const noexcept
  {
    return is_always_exhaustive() || (__rank_ != __invalid_count_or_rank);
  }

  [[nodiscard]] _CCCL_DEVICE_API unsigned group_count() const noexcept
  {
    return static_cast<unsigned>(static_group_count());
  }

  [[nodiscard]] _CCCL_DEVICE_API unsigned group_rank() const noexcept
  {
    // Validity only needs checking when the result can be invalid at all.
    if constexpr (!is_always_exhaustive())
    {
      _CCCL_ASSERT(is_valid(), "getting group rank of thread that is not part of the group is UB");
    }
    return __group_rank_;
  }

  [[nodiscard]] _CCCL_DEVICE_API unsigned count() const noexcept
  {
    if constexpr (!is_always_exhaustive())
    {
      _CCCL_ASSERT(is_valid(), "getting count of thread that is not part of the group is UB");
    }
    return __count_;
  }

  [[nodiscard]] _CCCL_DEVICE_API unsigned rank() const noexcept
  {
    if constexpr (!is_always_exhaustive())
    {
      _CCCL_ASSERT(is_valid(), "getting rank of thread that is not part of the group is UB");
    }
    return __rank_;
  }
};

_CCCL_HIDE_FROM_ABI explicit group_as() = default;

_CCCL_TEMPLATE(bool _IsExhaustive2 = _IsExhaustive)
Expand Down Expand Up @@ -170,56 +94,73 @@ public:
return static_cast<unsigned>(static_count(__i));
}

template <class _Unit, class _ParentGroup>
[[nodiscard]] _CCCL_DEVICE_API auto map(const _Unit& __unit, const _ParentGroup& __parent) const noexcept
template <class _ParentGroup, class _PrevMappingResult>
[[nodiscard]] _CCCL_DEVICE_API auto
map(const _ParentGroup&, const _PrevMappingResult& __prev_mapping_result) const noexcept
{
using _ParentMappingResult = typename _ParentGroup::__mapping_result_type;
constexpr auto __static_prev_ngroups = _PrevMappingResult::static_group_count();
constexpr auto __static_prev_nunits = _PrevMappingResult::static_count();
constexpr auto __static_curr_ngroups = sizeof...(_Counts);
constexpr auto __static_ngroups =
(__static_prev_ngroups != ::cuda::std::dynamic_extent)
? (__static_prev_ngroups * __static_curr_ngroups)
: ::cuda::std::dynamic_extent;

using _MappingResult =
__mapping_result<_ParentMappingResult::is_always_exhaustive(), _ParentMappingResult::is_always_contiguous()>;
__mapping_result<__static_ngroups,
::cuda::std::dynamic_extent,
_PrevMappingResult::is_always_exhaustive() && _IsExhaustive,
_PrevMappingResult::is_always_contiguous()>;

if (!__prev_mapping_result.is_valid())
{
return _MappingResult::invalid();
}

constexpr auto __static_nunits = ::cuda::experimental::__static_count_query_group<_Unit, _ParentGroup>();
const auto __nunits = _Unit::template count_as<unsigned>(__parent);
const auto __unit_rank = _Unit::template rank_as<unsigned>(__parent);
const auto __ngroups = static_cast<unsigned>(sizeof...(_Counts));
const auto __prev_nunits = __prev_mapping_result.count();
const auto __prev_unit_rank = __prev_mapping_result.rank();
constexpr auto __curr_ngroups = static_cast<unsigned>(sizeof...(_Counts));
const auto __ngroups = __prev_mapping_result.group_count() * __curr_ngroups;

if constexpr (_IsExhaustive)
{
if constexpr (__static_nunits != ::cuda::std::dynamic_extent)
if constexpr (__static_prev_nunits != ::cuda::std::dynamic_extent)
{
static_assert(__static_nunits == __counts_sum, "group_as mapping _IsExhaustive precondition violation");
static_assert(__static_prev_nunits == __counts_sum, "group_as mapping _IsExhaustive precondition violation");
}
else
{
_CCCL_ASSERT(__nunits == static_cast<unsigned>(__counts_sum),
_CCCL_ASSERT(__prev_nunits == static_cast<unsigned>(__counts_sum),
"group_as mapping _IsExhaustive precondition violation");
}
}
else
{
if constexpr (__static_nunits != ::cuda::std::dynamic_extent)
if constexpr (__static_prev_nunits != ::cuda::std::dynamic_extent)
{
static_assert(__static_nunits >= __counts_sum, "group_as mapping requires more units than are available");
static_assert(__static_prev_nunits >= __counts_sum, "group_as mapping requires more units than are available");
}
else
{
_CCCL_ASSERT(__nunits >= static_cast<unsigned>(__counts_sum),
_CCCL_ASSERT(__prev_nunits >= static_cast<unsigned>(__counts_sum),
"group_as mapping requires more units than are available");
}

if (__unit_rank >= static_cast<unsigned>(__counts_sum))
if (__prev_unit_rank >= static_cast<unsigned>(__counts_sum))
{
return _MappingResult::invalid();
return _MappingResult::invalid_with_group_count(__ngroups);
}
}

unsigned __sum = 0;
_CCCL_PRAGMA_UNROLL_FULL()
for (unsigned __i = 0; __i < __ngroups; ++__i)
for (unsigned __i = 0; __i < __curr_ngroups; ++__i)
{
const auto __i_count = count(__i);
if (__unit_rank < __sum + __i_count)
if (__prev_unit_rank < __sum + __i_count)
{
return _MappingResult{__i, __i_count, __unit_rank - __sum};
return _MappingResult{
__ngroups, __prev_mapping_result.group_rank() * __curr_ngroups + __i, __i_count, __prev_unit_rank - __sum};
}
__sum += __i_count;
}
Expand All @@ -238,83 +179,6 @@ class group_as<__group_as_dynamic_tag<_GroupCount>, _IsExhaustive>
unsigned __counts_[_GroupCount];

public:
// Result of applying a dynamically-sized group_as mapping (with _GroupCount groups) to a parent group.
//
// _ParentIsAlwaysExhaustive / _ParentIsAlwaysContiguous carry the corresponding compile-time
// properties of the parent group's mapping result.
template <bool _ParentIsAlwaysExhaustive, bool _ParentIsAlwaysContiguous>
struct __mapping_result
{
  // Plain aggregate; the member order is relied upon by braced initialization.
  unsigned __group_rank_;
  unsigned __count_;
  unsigned __rank_;

  // Returns a result whose every field holds the invalid marker.
  [[nodiscard]] _CCCL_DEVICE_API static constexpr __mapping_result invalid() noexcept
  {
    return {__invalid_count_or_rank, __invalid_count_or_rank, __invalid_count_or_rank};
  }

  // Number of groups, known at compile time.
  [[nodiscard]] _CCCL_DEVICE_API static constexpr ::cuda::std::size_t static_group_count() noexcept
  {
    return _GroupCount;
  }

  // Number of groups as a runtime value.
  [[nodiscard]] _CCCL_DEVICE_API unsigned group_count() const noexcept
  {
    return static_cast<unsigned>(_GroupCount);
  }

  // Rank of the group this unit was mapped to. Asserts validity when the result may be invalid.
  [[nodiscard]] _CCCL_DEVICE_API unsigned group_rank() const noexcept
  {
    if constexpr (!is_always_exhaustive())
    {
      _CCCL_ASSERT(is_valid(), "getting group rank of thread that is not part of the group is UB");
    }
    return __group_rank_;
  }

  // The per-group unit count is only known at run time.
  [[nodiscard]] _CCCL_DEVICE_API static constexpr ::cuda::std::size_t static_count() noexcept
  {
    return ::cuda::std::dynamic_extent;
  }

  // Number of units in the group this unit was mapped to. Asserts validity when the result may be invalid.
  [[nodiscard]] _CCCL_DEVICE_API unsigned count() const noexcept
  {
    if constexpr (!is_always_exhaustive())
    {
      // Message fixed: it previously said "group rank", copied from group_rank().
      _CCCL_ASSERT(is_valid(), "getting count of thread that is not part of the group is UB");
    }
    return __count_;
  }

  // Rank of this unit within its group. Asserts validity when the result may be invalid.
  [[nodiscard]] _CCCL_DEVICE_API unsigned rank() const noexcept
  {
    if constexpr (!is_always_exhaustive())
    {
      _CCCL_ASSERT(is_valid(), "getting rank of thread that is not part of the group is UB");
    }
    return __rank_;
  }

  // Exhaustive results are always valid; otherwise validity is encoded in the rank field.
  [[nodiscard]] _CCCL_DEVICE_API bool is_valid() const noexcept
  {
    if constexpr (is_always_exhaustive())
    {
      return true;
    }
    else
    {
      return __rank_ != __invalid_count_or_rank;
    }
  }

  // Exhaustive only if both the parent and this mapping are exhaustive.
  [[nodiscard]] _CCCL_DEVICE_API static constexpr bool is_always_exhaustive() noexcept
  {
    return _ParentIsAlwaysExhaustive && _IsExhaustive;
  }

  // Contiguity is taken over from the parent unchanged.
  [[nodiscard]] _CCCL_DEVICE_API static constexpr bool is_always_contiguous() noexcept
  {
    return _ParentIsAlwaysContiguous;
  }
};

_CCCL_TEMPLATE(bool _IsExhaustive2 = _IsExhaustive)
_CCCL_REQUIRES(_IsExhaustive2)
_CCCL_DEVICE_API explicit constexpr group_as(::cuda::std::span<const unsigned, _GroupCount> __counts) noexcept
Expand Down Expand Up @@ -368,35 +232,54 @@ public:
return __counts_[__i];
}

template <class _Unit, class _ParentGroup>
[[nodiscard]] _CCCL_DEVICE_API auto map(const _Unit& __unit, const _ParentGroup& __parent) const noexcept
template <class _ParentGroup, class _PrevMappingResult>
[[nodiscard]] _CCCL_DEVICE_API auto
map(const _ParentGroup&, const _PrevMappingResult& __prev_mapping_result) const noexcept
{
using _ParentMappingResult = typename _ParentGroup::__mapping_result_type;
constexpr auto __static_prev_ngroups = _PrevMappingResult::static_group_count();
constexpr auto __static_prev_nunits = _PrevMappingResult::static_count();
constexpr auto __static_curr_ngroups = _GroupCount;
constexpr auto __static_ngroups =
(__static_prev_ngroups != ::cuda::std::dynamic_extent)
? (__static_prev_ngroups * __static_curr_ngroups)
: ::cuda::std::dynamic_extent;

using _MappingResult =
__mapping_result<_ParentMappingResult::is_always_exhaustive(), _ParentMappingResult::is_always_contiguous()>;
__mapping_result<__static_ngroups,
::cuda::std::dynamic_extent,
_PrevMappingResult::is_always_exhaustive() && _IsExhaustive,
_PrevMappingResult::is_always_contiguous()>;

if (!__prev_mapping_result.is_valid())
{
return _MappingResult::invalid();
}

const auto __nunits = __unit.template count_as<unsigned>(__parent);
const auto __unit_rank = __unit.template rank_as<unsigned>(__parent);
const auto __prev_nunits = __prev_mapping_result.count();
const auto __prev_unit_rank = __prev_mapping_result.rank();
constexpr auto __curr_ngroups = static_cast<unsigned>(_GroupCount);
const auto __ngroups = __prev_mapping_result.group_count() * __curr_ngroups;

// If the mapping is exhaustive, check the preconditions, otherwise remove the last partial group.
if constexpr (_IsExhaustive)
{
_CCCL_ASSERT(::cuda::std::accumulate(__counts_, __counts_ + _GroupCount, 0u) == __nunits,
_CCCL_ASSERT(::cuda::std::accumulate(__counts_, __counts_ + __curr_ngroups, 0u) == __prev_nunits,
"group_as mapping _IsExhaustive precondition violation");
}
else if (__unit_rank >= ::cuda::std::accumulate(__counts_, __counts_ + _GroupCount, 0u))
else if (__prev_unit_rank >= ::cuda::std::accumulate(__counts_, __counts_ + __curr_ngroups, 0u))
{
return _MappingResult::invalid();
return _MappingResult::invalid_with_group_count(__ngroups);
}

unsigned __sum = 0;
_CCCL_PRAGMA_UNROLL_FULL()
for (unsigned __i = 0; __i < _GroupCount; ++__i)
for (unsigned __i = 0; __i < __curr_ngroups; ++__i)
{
const auto __i_count = count(__i);
if (__unit_rank < __sum + __i_count)
if (__prev_unit_rank < __sum + __i_count)
{
return _MappingResult{__i, __i_count, __unit_rank - __sum};
return _MappingResult{
__ngroups, __prev_mapping_result.group_rank() * __curr_ngroups + __i, __i_count, __prev_unit_rank - __sum};
}
__sum += __i_count;
}
Expand Down
Loading
Loading