/**
 *  @brief Device side implementation of the cooperative group feature.
 */
#ifndef HIP_INCLUDE_HIP_AMD_DETAIL_HIP_COOPERATIVE_GROUPS_H
#define HIP_INCLUDE_HIP_AMD_DETAIL_HIP_COOPERATIVE_GROUPS_H
#if defined(__clang__)
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wc++98-compat"
#pragma clang diagnostic ignored "-Wsign-conversion"
#pragma clang diagnostic ignored "-Wunused-parameter"
#pragma clang diagnostic ignored "-Wreserved-macro-identifier"
#pragma clang diagnostic ignored "-Wpadded"
#endif
#if !defined(__HIPCC_RTC__)
#include <hip/amd_detail/hip_cooperative_groups_helper.h>
#endif

#define __hip_abort() \
  { abort(); }

#if defined(NDEBUG)
#define __hip_assert(COND)
#else
#define __hip_assert(COND) \
  {                        \
    if (!(COND)) {         \
      __hip_abort();       \
    }                      \
  }
#endif
namespace cooperative_groups {
/** \brief The base type of all cooperative group types. */
class thread_group {
 protected:
  uint32_t _type;  // thread_group type
  uint32_t _size;  // total number of threads in the thread_group
  uint64_t _mask;  // lane mask for coalesced and tiled-partitioned group types;
                   // LSB represents lane 0, MSB represents lane 63

  // Construct a thread group and set its type and other essential properties.
  __CG_QUALIFIER__ thread_group(internal::group_type type,
                                uint32_t size = static_cast<uint32_t>(0),
                                uint64_t mask = static_cast<uint64_t>(0)) {
    _type = type;
    _size = size;
    _mask = mask;
  }

  struct _tiled_info {
    bool is_tiled;
    unsigned int size;
    unsigned int meta_group_rank;
    unsigned int meta_group_size;
  };

  struct _coalesced_info {
    lane_mask member_mask;
    unsigned int size;
    struct _tiled_info tiled_info;
  } coalesced_info;

  friend __CG_QUALIFIER__ thread_group tiled_partition(const thread_group& parent,
                                                       unsigned int tile_size);
  friend class thread_block;

 public:
  // Total number of threads in the thread group; valid for all derived group
  // types since the size is recorded at construction time.
  __CG_QUALIFIER__ uint32_t size() const { return _size; }
  __CG_QUALIFIER__ unsigned int cg_type() const { return _type; }
  // Rank of the calling thread within [0, size())
  __CG_QUALIFIER__ uint32_t thread_rank() const;
  // Is this cooperative group type valid?
  __CG_QUALIFIER__ bool is_valid() const;
  // Synchronize the threads in the thread group.
  __CG_QUALIFIER__ void sync() const;
};
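/**
 * Usage sketch (illustrative only, not part of the original header): a
 * thread_group can be passed around as a generic handle; thread_rank() and
 * sync() dispatch on the stored group type. The function and buffer names
 * below are hypothetical.
 * @code
 * __device__ void zero_first_element(const cooperative_groups::thread_group& g,
 *                                    float* buf) {
 *   if (g.thread_rank() == 0) buf[0] = 0.0f;
 *   g.sync();  // all threads in the group wait here
 * }
 * @endcode
 */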
/** \brief The multi-grid cooperative group type. */
class multi_grid_group : public thread_group {
  // Only this friend function may construct an object of this class
  // and access its resources.
  friend __CG_QUALIFIER__ multi_grid_group this_multi_grid();

 protected:
  // Construct a multi-grid thread group (through the API this_multi_grid())
  explicit __CG_QUALIFIER__ multi_grid_group(uint32_t size)
      : thread_group(internal::cg_multi_grid, size) {}

 public:
  // Number of grids participating in this multi-grid group; in other words,
  // the number of GPUs.
  __CG_QUALIFIER__ uint32_t num_grids() { return internal::multi_grid::num_grids(); }
  // Rank of this grid within [0, num_grids())
  __CG_QUALIFIER__ uint32_t grid_rank() { return internal::multi_grid::grid_rank(); }
  __CG_QUALIFIER__ uint32_t thread_rank() const { return internal::multi_grid::thread_rank(); }
  __CG_QUALIFIER__ bool is_valid() const { return internal::multi_grid::is_valid(); }
  __CG_QUALIFIER__ void sync() const { internal::multi_grid::sync(); }
};
// User-exposed API to construct the multi-grid cooperative group object.
__CG_QUALIFIER__ multi_grid_group this_multi_grid() {
  return multi_grid_group(internal::multi_grid::size());
}
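/**
 * Usage sketch (illustrative only): multi-grid synchronization assumes the
 * kernel was launched on each device through
 * hipLaunchCooperativeKernelMultiDevice. The kernel and buffer names are
 * hypothetical.
 * @code
 * __global__ void multi_gpu_phase(float* slice) {
 *   cooperative_groups::multi_grid_group mgrid = cooperative_groups::this_multi_grid();
 *   // ... each grid (GPU), identified by mgrid.grid_rank(), updates its slice ...
 *   mgrid.sync();  // wait for every participating grid before the next phase
 * }
 * @endcode
 */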
/** \brief The grid cooperative group type. */
class grid_group : public thread_group {
  // Only this friend function may construct an object of this class
  // and access its resources.
  friend __CG_QUALIFIER__ grid_group this_grid();

 protected:
  // Construct a grid thread group (through the API this_grid())
  explicit __CG_QUALIFIER__ grid_group(uint32_t size) : thread_group(internal::cg_grid, size) {}

 public:
  __CG_QUALIFIER__ uint32_t thread_rank() const { return internal::grid::thread_rank(); }
  __CG_QUALIFIER__ bool is_valid() const { return internal::grid::is_valid(); }
  __CG_QUALIFIER__ void sync() const { internal::grid::sync(); }
};
// User-exposed API to construct the grid cooperative group object.
__CG_QUALIFIER__ grid_group this_grid() { return grid_group(internal::grid::size()); }
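/**
 * Usage sketch (illustrative only): grid-wide synchronization assumes a
 * launch through hipLaunchCooperativeKernel. Kernel and buffer names are
 * hypothetical.
 * @code
 * __global__ void two_phase(const float* in, float* out) {
 *   cooperative_groups::grid_group grid = cooperative_groups::this_grid();
 *   // phase 1: every block writes its part of `out` ...
 *   grid.sync();  // all blocks of the grid reach this point
 *   // phase 2: safely read values produced by other blocks ...
 * }
 * @endcode
 */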
/** \brief The workgroup (thread block) cooperative group type. */
class thread_block : public thread_group {
  // Only these friend functions may construct an object of this class
  // and access its resources.
  friend __CG_QUALIFIER__ thread_block this_thread_block();
  friend __CG_QUALIFIER__ thread_group tiled_partition(const thread_group& parent,
                                                       unsigned int tile_size);
  friend __CG_QUALIFIER__ thread_group tiled_partition(const thread_block& parent,
                                                       unsigned int tile_size);

 protected:
  // Construct a workgroup thread group (through the API this_thread_block())
  explicit __CG_QUALIFIER__ thread_block(uint32_t size)
      : thread_group(internal::cg_workgroup, size) {}

  __CG_QUALIFIER__ thread_group new_tiled_group(unsigned int tile_size) const {
    const bool pow2 = ((tile_size & (tile_size - 1)) == 0);
    // Invalid tile size: zero, larger than the wavefront, or not a power of 2.
    if (!tile_size || (tile_size > __AMDGCN_WAVEFRONT_SIZE) || !pow2) {
      __hip_assert(false && "invalid tile size");
    }

    thread_group tiledGroup = thread_group(internal::cg_tiled_group, tile_size);
    tiledGroup.coalesced_info.tiled_info.size = tile_size;
    tiledGroup.coalesced_info.tiled_info.is_tiled = true;
    tiledGroup.coalesced_info.tiled_info.meta_group_rank = thread_rank() / tile_size;
    tiledGroup.coalesced_info.tiled_info.meta_group_size = (size() + tile_size - 1) / tile_size;
    return tiledGroup;
  }

 public:
  // 3-dimensional block index within the grid
  __CG_STATIC_QUALIFIER__ dim3 group_index() { return internal::workgroup::group_index(); }
  // 3-dimensional thread index within the block
  __CG_STATIC_QUALIFIER__ dim3 thread_index() { return internal::workgroup::thread_index(); }
  __CG_STATIC_QUALIFIER__ uint32_t thread_rank() { return internal::workgroup::thread_rank(); }
  __CG_STATIC_QUALIFIER__ uint32_t size() { return internal::workgroup::size(); }
  __CG_STATIC_QUALIFIER__ bool is_valid() { return internal::workgroup::is_valid(); }
  __CG_STATIC_QUALIFIER__ void sync() { internal::workgroup::sync(); }
  // 3-dimensional block dimensions
  __CG_QUALIFIER__ dim3 group_dim() { return internal::workgroup::block_dim(); }
};
// User-exposed API to construct the workgroup cooperative group object.
__CG_QUALIFIER__ thread_block this_thread_block() {
  return thread_block(internal::workgroup::size());
}
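/**
 * Usage sketch (illustrative only, assuming a 256-thread 1D block; kernel and
 * variable names are hypothetical):
 * @code
 * __global__ void block_example(const float* data) {
 *   cooperative_groups::thread_block block = cooperative_groups::this_thread_block();
 *   __shared__ float tile[256];
 *   tile[block.thread_rank()] = data[blockIdx.x * blockDim.x + threadIdx.x];
 *   block.sync();  // equivalent to __syncthreads()
 *   // ... use tile[] ...
 * }
 * @endcode
 */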
/** \brief The tiled-partitioned cooperative group type. */
class tiled_group : public thread_group {
 private:
  friend __CG_QUALIFIER__ thread_group tiled_partition(const thread_group& parent,
                                                       unsigned int tile_size);
  friend __CG_QUALIFIER__ tiled_group tiled_partition(const tiled_group& parent,
                                                      unsigned int tile_size);

  __CG_QUALIFIER__ tiled_group new_tiled_group(unsigned int tile_size) const {
    const bool pow2 = ((tile_size & (tile_size - 1)) == 0);

    if (!tile_size || (tile_size > __AMDGCN_WAVEFRONT_SIZE) || !pow2) {
      __hip_assert(false && "invalid tile size");
    }

    // A partition at least as large as this group is this group itself.
    if (size() <= tile_size) {
      return *this;
    }

    tiled_group tiledGroup = tiled_group(tile_size);
    tiledGroup.coalesced_info.tiled_info.is_tiled = true;
    return tiledGroup;
  }

 protected:
  explicit __CG_QUALIFIER__ tiled_group(unsigned int tileSize)
      : thread_group(internal::cg_tiled_group, tileSize) {
    coalesced_info.tiled_info.size = tileSize;
    coalesced_info.tiled_info.is_tiled = true;
  }

 public:
  __CG_QUALIFIER__ unsigned int size() const { return (coalesced_info.tiled_info.size); }

  __CG_QUALIFIER__ unsigned int thread_rank() const {
    // Tile sizes are powers of 2, so the rank within a tile is a bit mask.
    return (internal::workgroup::thread_rank() & (coalesced_info.tiled_info.size - 1));
  }

  __CG_QUALIFIER__ void sync() const { internal::tiled_group::sync(); }
};
/** \brief The coalesced cooperative group type. */
class coalesced_group : public thread_group {
 private:
  friend __CG_QUALIFIER__ coalesced_group coalesced_threads();
  friend __CG_QUALIFIER__ thread_group tiled_partition(const thread_group& parent,
                                                       unsigned int tile_size);
  friend __CG_QUALIFIER__ coalesced_group tiled_partition(const coalesced_group& parent,
                                                          unsigned int tile_size);

  __CG_QUALIFIER__ coalesced_group new_tiled_group(unsigned int tile_size) const {
    const bool pow2 = ((tile_size & (tile_size - 1)) == 0);

    if (!tile_size || (tile_size > size()) || !pow2) {
      return coalesced_group(0);
    }

    // A fully tiled group: the member mask is a contiguous run of lanes, so
    // the sub-tile mask can be computed directly.
    if (coalesced_info.tiled_info.is_tiled) {
      unsigned int base_offset = (thread_rank() & (~(tile_size - 1)));
      unsigned int masklength = min(static_cast<unsigned int>(size()) - base_offset, tile_size);
      lane_mask member_mask = static_cast<lane_mask>(-1) >> (__AMDGCN_WAVEFRONT_SIZE - masklength);

      member_mask <<= (__lane_id() & ~(tile_size - 1));
      coalesced_group coalesced_tile = coalesced_group(member_mask);
      coalesced_tile.coalesced_info.tiled_info.is_tiled = true;
      coalesced_tile.coalesced_info.tiled_info.meta_group_rank = thread_rank() / tile_size;
      coalesced_tile.coalesced_info.tiled_info.meta_group_size = size() / tile_size;
      return coalesced_tile;
    } else {
      // A sparse member mask: walk the active lanes and collect those that
      // belong to the calling thread's tile.
      lane_mask member_mask = 0;
      unsigned int tile_rank = 0;
      int lanes_to_skip = ((thread_rank()) / tile_size) * tile_size;

      for (unsigned int i = 0; i < __AMDGCN_WAVEFRONT_SIZE; i++) {
        // Widen before shifting so a 64-wide wavefront does not overflow int.
        lane_mask active = coalesced_info.member_mask & (static_cast<lane_mask>(1) << i);
        // Consider only active lanes.
        if (active) {
          if (lanes_to_skip <= 0 && tile_rank < tile_size) {
            member_mask |= active;
            tile_rank++;
          }
          lanes_to_skip--;
        }
      }
      coalesced_group coalesced_tile = coalesced_group(member_mask);
      coalesced_tile.coalesced_info.tiled_info.meta_group_rank = thread_rank() / tile_size;
      coalesced_tile.coalesced_info.tiled_info.meta_group_size =
          (size() + tile_size - 1) / tile_size;
      return coalesced_tile;
    }
    return coalesced_group(0);
  }
 protected:
  // Construct a coalesced thread group from a mask of participating lanes.
  explicit __CG_QUALIFIER__ coalesced_group(lane_mask member_mask)
      : thread_group(internal::cg_coalesced_group) {
    coalesced_info.member_mask = member_mask;                    // which lanes are active
    coalesced_info.size = __popcll(coalesced_info.member_mask);  // how many lanes are active
    coalesced_info.tiled_info.is_tiled = false;                  // not a tiled partition
  }
 public:
  __CG_QUALIFIER__ unsigned int size() const { return coalesced_info.size; }

  __CG_QUALIFIER__ unsigned int thread_rank() const {
    return internal::coalesced_group::masked_bit_count(coalesced_info.member_mask);
  }

  __CG_QUALIFIER__ void sync() const { internal::coalesced_group::sync(); }

  __CG_QUALIFIER__ unsigned int meta_group_rank() const {
    return coalesced_info.tiled_info.meta_group_rank;
  }

  __CG_QUALIFIER__ unsigned int meta_group_size() const {
    return coalesced_info.tiled_info.meta_group_size;
  }
  template <class T>
  __CG_QUALIFIER__ T shfl(T var, int srcRank) const {
    static_assert(is_valid_type<T>::value, "Neither an integer nor a float type.");

    srcRank = srcRank % static_cast<int>(size());

    int lane = (size() == __AMDGCN_WAVEFRONT_SIZE) ? srcRank
        : (__AMDGCN_WAVEFRONT_SIZE == 64)
            ? __fns64(coalesced_info.member_mask, 0, (srcRank + 1))
            : __fns32(coalesced_info.member_mask, 0, (srcRank + 1));

    return __shfl(var, lane, __AMDGCN_WAVEFRONT_SIZE);
  }
  template <class T>
  __CG_QUALIFIER__ T shfl_down(T var, unsigned int lane_delta) const {
    static_assert(is_valid_type<T>::value, "Neither an integer nor a float type.");

    // A full-wavefront group can use the hardware shuffle directly.
    if (size() == __AMDGCN_WAVEFRONT_SIZE) {
      return __shfl_down(var, lane_delta, __AMDGCN_WAVEFRONT_SIZE);
    }

    // Otherwise, find the (lane_delta + 1)-th set bit at or above this lane.
    int lane;
    if (__AMDGCN_WAVEFRONT_SIZE == 64) {
      lane = __fns64(coalesced_info.member_mask, __lane_id(), lane_delta + 1);
    } else {
      lane = __fns32(coalesced_info.member_mask, __lane_id(), lane_delta + 1);
    }

    // Out of range: keep the calling lane's own value.
    if (lane == -1) {
      lane = __lane_id();
    }

    return __shfl(var, lane, __AMDGCN_WAVEFRONT_SIZE);
  }
  template <class T>
  __CG_QUALIFIER__ T shfl_up(T var, unsigned int lane_delta) const {
    static_assert(is_valid_type<T>::value, "Neither an integer nor a float type.");

    // A full-wavefront group can use the hardware shuffle directly.
    if (size() == __AMDGCN_WAVEFRONT_SIZE) {
      return __shfl_up(var, lane_delta, __AMDGCN_WAVEFRONT_SIZE);
    }

    // Otherwise, find the (lane_delta + 1)-th set bit below this lane.
    int lane;
    if (__AMDGCN_WAVEFRONT_SIZE == 64) {
      lane = __fns64(coalesced_info.member_mask, __lane_id(), -(lane_delta + 1));
    } else if (__AMDGCN_WAVEFRONT_SIZE == 32) {
      lane = __fns32(coalesced_info.member_mask, __lane_id(), -(lane_delta + 1));
    }

    // Out of range: keep the calling lane's own value.
    if (lane == -1) {
      lane = __lane_id();
    }

    return __shfl(var, lane, __AMDGCN_WAVEFRONT_SIZE);
  }
};
// User-exposed API: the group of currently active (converged) lanes.
__CG_QUALIFIER__ coalesced_group coalesced_threads() {
  return cooperative_groups::coalesced_group(__builtin_amdgcn_read_exec());
}
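/**
 * Usage sketch (illustrative only): inside a divergent branch, the active
 * lanes form a coalesced_group, and shfl() can broadcast from the lowest-
 * ranked member. The function name is hypothetical.
 * @code
 * __device__ int leader_broadcast(int value) {
 *   cooperative_groups::coalesced_group active = cooperative_groups::coalesced_threads();
 *   // rank 0 is the lowest active lane; broadcast its value to the others
 *   return active.shfl(value, 0);
 * }
 * @endcode
 */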
// Implementation of the publicly exposed thread_group APIs: dispatch on the
// stored group type to the derived class implementation.
__CG_QUALIFIER__ uint32_t thread_group::thread_rank() const {
  switch (this->_type) {
    case internal::cg_multi_grid: {
      return (static_cast<const multi_grid_group*>(this)->thread_rank());
    }
    case internal::cg_grid: {
      return (static_cast<const grid_group*>(this)->thread_rank());
    }
    case internal::cg_workgroup: {
      return (static_cast<const thread_block*>(this)->thread_rank());
    }
    case internal::cg_tiled_group: {
      return (static_cast<const tiled_group*>(this)->thread_rank());
    }
    case internal::cg_coalesced_group: {
      return (static_cast<const coalesced_group*>(this)->thread_rank());
    }
    default: {
      __hip_assert(false && "invalid cooperative group type");
      return -1;
    }
  }
}
__CG_QUALIFIER__ bool thread_group::is_valid() const {
  switch (this->_type) {
    case internal::cg_multi_grid: {
      return (static_cast<const multi_grid_group*>(this)->is_valid());
    }
    case internal::cg_grid: {
      return (static_cast<const grid_group*>(this)->is_valid());
    }
    case internal::cg_workgroup: {
      return (static_cast<const thread_block*>(this)->is_valid());
    }
    case internal::cg_tiled_group: {
      return (static_cast<const tiled_group*>(this)->is_valid());
    }
    case internal::cg_coalesced_group: {
      return (static_cast<const coalesced_group*>(this)->is_valid());
    }
    default: {
      __hip_assert(false && "invalid cooperative group type");
      return false;
    }
  }
}
__CG_QUALIFIER__ void thread_group::sync() const {
  switch (this->_type) {
    case internal::cg_multi_grid: {
      static_cast<const multi_grid_group*>(this)->sync();
      break;
    }
    case internal::cg_grid: {
      static_cast<const grid_group*>(this)->sync();
      break;
    }
    case internal::cg_workgroup: {
      static_cast<const thread_block*>(this)->sync();
      break;
    }
    case internal::cg_tiled_group: {
      static_cast<const tiled_group*>(this)->sync();
      break;
    }
    case internal::cg_coalesced_group: {
      static_cast<const coalesced_group*>(this)->sync();
      break;
    }
    default: {
      __hip_assert(false && "invalid cooperative group type");
    }
  }
}
// Generic free-function wrappers over the group member functions.
template <class CGTy> __CG_QUALIFIER__ uint32_t group_size(CGTy const& g) { return g.size(); }

template <class CGTy> __CG_QUALIFIER__ uint32_t thread_rank(CGTy const& g) {
  return g.thread_rank();
}

template <class CGTy> __CG_QUALIFIER__ bool is_valid(CGTy const& g) { return g.is_valid(); }

template <class CGTy> __CG_QUALIFIER__ void sync(CGTy const& g) { g.sync(); }
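/**
 * Usage sketch (illustrative only): because these wrappers are templates,
 * they work uniformly with any cooperative group type defined in this header.
 * The function and counter names are hypothetical.
 * @code
 * template <class Group>
 * __device__ void barrier_and_count(const Group& g, unsigned int* counter) {
 *   cooperative_groups::sync(g);
 *   if (cooperative_groups::thread_rank(g) == 0) {
 *     atomicAdd(counter, cooperative_groups::group_size(g));
 *   }
 * }
 * @endcode
 */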
template <unsigned int tileSize> class tile_base {
 protected:
  _CG_STATIC_CONST_DECL_ unsigned int numThreads = tileSize;

 public:
  // Rank of the calling thread within this tile
  __CG_STATIC_QUALIFIER__ unsigned int thread_rank() {
    return (internal::workgroup::thread_rank() & (numThreads - 1));
  }

  // Number of threads within this tile
  __CG_STATIC_QUALIFIER__ unsigned int size() { return numThreads; }
};
template <unsigned int size> class thread_block_tile_base : public tile_base<size> {
  static_assert(is_valid_tile_size<size>::value,
                "Tile size is either not a power of 2 or greater than the wavefront size");
  using tile_base<size>::numThreads;

 public:
  __CG_STATIC_QUALIFIER__ void sync() { internal::tiled_group::sync(); }

  template <class T> __CG_QUALIFIER__ T shfl(T var, int srcRank) const {
    static_assert(is_valid_type<T>::value, "Neither an integer nor a float type.");
    return (__shfl(var, srcRank, numThreads));
  }

  template <class T> __CG_QUALIFIER__ T shfl_down(T var, unsigned int lane_delta) const {
    static_assert(is_valid_type<T>::value, "Neither an integer nor a float type.");
    return (__shfl_down(var, lane_delta, numThreads));
  }

  template <class T> __CG_QUALIFIER__ T shfl_up(T var, unsigned int lane_delta) const {
    static_assert(is_valid_type<T>::value, "Neither an integer nor a float type.");
    return (__shfl_up(var, lane_delta, numThreads));
  }

  template <class T> __CG_QUALIFIER__ T shfl_xor(T var, unsigned int laneMask) const {
    static_assert(is_valid_type<T>::value, "Neither an integer nor a float type.");
    return (__shfl_xor(var, laneMask, numThreads));
  }
};
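/**
 * Usage sketch (illustrative only): a butterfly reduction over a
 * thread_block_tile (defined later in this header) using shfl_xor(); after
 * the loop every member of the tile holds the full sum. The function name is
 * hypothetical.
 * @code
 * template <unsigned int TileSize>
 * __device__ float tile_sum(const cooperative_groups::thread_block_tile<TileSize>& tile,
 *                           float v) {
 *   for (unsigned int offset = TileSize / 2; offset > 0; offset /= 2) {
 *     v += tile.shfl_xor(v, offset);
 *   }
 *   return v;
 * }
 * @endcode
 */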
template <unsigned int tileSize, typename ParentCGTy>
class parent_group_info {
 public:
  // Rank of this tile within the set of tiles partitioned from the parent
  // group (bounded by meta_group_size()).
  __CG_STATIC_QUALIFIER__ unsigned int meta_group_rank() {
    return ParentCGTy::thread_rank() / tileSize;
  }

  // Number of tiles created when the parent group was partitioned.
  __CG_STATIC_QUALIFIER__ unsigned int meta_group_size() {
    return (ParentCGTy::size() + tileSize - 1) / tileSize;
  }
};
template <unsigned int tileSize, class ParentCGTy>
class thread_block_tile_type : public thread_block_tile_base<tileSize>,
                               public tiled_group,
                               public parent_group_info<tileSize, ParentCGTy> {
  _CG_STATIC_CONST_DECL_ unsigned int numThreads = tileSize;

 protected:
  __CG_QUALIFIER__ thread_block_tile_type() : tiled_group(numThreads) {
    coalesced_info.tiled_info.size = numThreads;
    coalesced_info.tiled_info.is_tiled = true;
  }
};
// Partial specialization: a tile whose parent group type is erased, so the
// meta-group information is stored per object instead of computed statically.
template <unsigned int tileSize>
class thread_block_tile_type<tileSize, void> : public thread_block_tile_base<tileSize>,
                                               public tiled_group {
  _CG_STATIC_CONST_DECL_ unsigned int numThreads = tileSize;

  typedef thread_block_tile_base<numThreads> tbtBase;

 protected:
  __CG_QUALIFIER__ thread_block_tile_type(unsigned int meta_group_rank,
                                          unsigned int meta_group_size)
      : tiled_group(numThreads) {
    coalesced_info.tiled_info.size = numThreads;
    coalesced_info.tiled_info.is_tiled = true;
    coalesced_info.tiled_info.meta_group_rank = meta_group_rank;
    coalesced_info.tiled_info.meta_group_size = meta_group_size;
  }

 public:
  using tbtBase::size;
  using tbtBase::sync;
  using tbtBase::thread_rank;

  __CG_QUALIFIER__ unsigned int meta_group_rank() const {
    return coalesced_info.tiled_info.meta_group_rank;
  }

  __CG_QUALIFIER__ unsigned int meta_group_size() const {
    return coalesced_info.tiled_info.meta_group_size;
  }
};
/** User-exposed API to partition an existing cooperative group at run time. */
__CG_QUALIFIER__ thread_group tiled_partition(const thread_group& parent,
                                              unsigned int tile_size) {
  if (parent.cg_type() == internal::cg_tiled_group) {
    const tiled_group* cg = static_cast<const tiled_group*>(&parent);
    return cg->new_tiled_group(tile_size);
  } else if (parent.cg_type() == internal::cg_coalesced_group) {
    const coalesced_group* cg = static_cast<const coalesced_group*>(&parent);
    return cg->new_tiled_group(tile_size);
  } else {
    const thread_block* tb = static_cast<const thread_block*>(&parent);
    return tb->new_tiled_group(tile_size);
  }
}
// Overload for thread_block.
__CG_QUALIFIER__ thread_group tiled_partition(const thread_block& parent,
                                              unsigned int tile_size) {
  return (parent.new_tiled_group(tile_size));
}

// Overload for tiled_group: re-partitioning a tile yields another tiled_group.
__CG_QUALIFIER__ tiled_group tiled_partition(const tiled_group& parent,
                                             unsigned int tile_size) {
  return (parent.new_tiled_group(tile_size));
}

// Overload for coalesced_group: a partitioned coalesced group stays coalesced.
__CG_QUALIFIER__ coalesced_group tiled_partition(const coalesced_group& parent,
                                                 unsigned int tile_size) {
  return (parent.new_tiled_group(tile_size));
}
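/**
 * Usage sketch (illustrative only): run-time partitioning of a thread block
 * into 8-wide tiles. The kernel name is hypothetical.
 * @code
 * __global__ void partition_example() {
 *   cooperative_groups::thread_block block = cooperative_groups::this_thread_block();
 *   cooperative_groups::thread_group tile8 = cooperative_groups::tiled_partition(block, 8);
 *   // tile8.thread_rank() is in [0, 8); tile8.sync() synchronizes one tile
 * }
 * @endcode
 */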
template <unsigned int size, class ParentCGTy> class thread_block_tile;

namespace impl {
template <unsigned int size, class ParentCGTy> class thread_block_tile_internal;

template <unsigned int size, class ParentCGTy>
class thread_block_tile_internal : public thread_block_tile_type<size, ParentCGTy> {
 protected:
  template <unsigned int tbtSize, class tbtParentT>
  __CG_QUALIFIER__ thread_block_tile_internal(
      const thread_block_tile_internal<tbtSize, tbtParentT>& g)
      : thread_block_tile_type<size, ParentCGTy>(g.meta_group_rank(), g.meta_group_size()) {}

  __CG_QUALIFIER__ thread_block_tile_internal(const thread_block& g)
      : thread_block_tile_type<size, ParentCGTy>() {}
};
}  // namespace impl
template <unsigned int size, class ParentCGTy>
class thread_block_tile : public impl::thread_block_tile_internal<size, ParentCGTy> {
 protected:
  __CG_QUALIFIER__ thread_block_tile(const ParentCGTy& g)
      : impl::thread_block_tile_internal<size, ParentCGTy>(g) {}

 public:
  __CG_QUALIFIER__ operator thread_block_tile<size, void>() const {
    return thread_block_tile<size, void>(*this);
  }
};
template <unsigned int size>
class thread_block_tile<size, void> : public impl::thread_block_tile_internal<size, void> {
  template <unsigned int, class ParentCGTy> friend class thread_block_tile;

 public:
  template <class ParentCGTy>
  __CG_QUALIFIER__ thread_block_tile(const thread_block_tile<size, ParentCGTy>& g)
      : impl::thread_block_tile_internal<size, void>(g) {}
};
template <unsigned int size, class ParentCGTy = void> class thread_block_tile;
namespace impl {
template <unsigned int size, class ParentCGTy> struct tiled_partition_internal;

template <unsigned int size>
struct tiled_partition_internal<size, thread_block>
    : public thread_block_tile<size, thread_block> {
  __CG_QUALIFIER__ tiled_partition_internal(const thread_block& g)
      : thread_block_tile<size, thread_block>(g) {}
};
}  // namespace impl
/** User-exposed API to partition a group into tiles of compile-time size. */
template <unsigned int size, class ParentCGTy>
__CG_QUALIFIER__ thread_block_tile<size, ParentCGTy> tiled_partition(const ParentCGTy& g) {
  static_assert(is_valid_tile_size<size>::value,
                "Tiled partition with size > wavefront size. Currently not supported ");
  return impl::tiled_partition_internal<size, ParentCGTy>(g);
}
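/**
 * Usage sketch (illustrative only): compile-time tiles expose the warp-level
 * shuffle members directly. Kernel and variable names are hypothetical.
 * @code
 * __global__ void tile_example(const float* data) {
 *   cooperative_groups::thread_block block = cooperative_groups::this_thread_block();
 *   auto tile = cooperative_groups::tiled_partition<16>(block);
 *   float v = data[block.thread_rank()];
 *   float neighbor = tile.shfl_down(v, 1);  // value from the next lane in the tile
 *   // tile.meta_group_rank() identifies which 16-wide tile this thread is in
 * }
 * @endcode
 */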
}  // namespace cooperative_groups

#if defined(__clang__)
#pragma clang diagnostic pop
#endif

#endif  // HIP_INCLUDE_HIP_AMD_DETAIL_HIP_COOPERATIVE_GROUPS_H