@@ -11,7 +11,7 @@ void count_kernel(T* arr, int size, int *count, CountIfOp count_if_op) {
11
11
int tid = threadIdx .x + blockIdx .x * blockDim .x ;
12
12
13
13
if (tid < size) {
14
-
14
+ // global memory, hence atomics
15
15
if (count_if_op (arr[tid])) {
16
16
atomicAdd (count, 1 );
17
17
}
@@ -32,6 +32,7 @@ void count_kernel(T* arr, int size, int *count, CountIfOp count_if_op) {
32
32
__shared__ int local_count_array[TPB];
33
33
34
34
if (tid < size) {
35
+ // set if op in shared memory
35
36
if (count_if_op (arr[tid])) {
36
37
local_count_array[threadIdx .x ] = 1 ;
37
38
}
@@ -41,6 +42,9 @@ void count_kernel(T* arr, int size, int *count, CountIfOp count_if_op) {
41
42
42
43
__syncthreads ();
43
44
45
+ // manual reduction within block between first half
46
+ // and second half
47
+ // note format of reduction to enable SIMD
44
48
for (int offset = blockDim .x / 2 ; offset > 0 ; offset >>=1 ) {
45
49
if (threadIdx .x < offset && tid + offset < size) {
46
50
local_count_array[threadIdx .x ] += local_count_array[threadIdx .x + offset];
@@ -67,6 +71,7 @@ void count_kernel(T* arr, int size, int *count, CountIfOp count_if_op) {
67
71
68
72
bool predicate = tid < size && count_if_op (arr[tid]);
69
73
74
+ // block level primitive
70
75
int block_count = __syncthreads_count (predicate);
71
76
72
77
if (threadIdx .x == 0 ) {
@@ -90,26 +95,30 @@ void count_kernel(T* arr, int size, int *count, CountIfOp count_if_op) {
90
95
91
96
bool predicate = tid < size && count_if_op (arr[tid]);
92
97
98
+ // find participating warps in predicate
99
+ // FULL_MASK is 32 bit set
93
100
unsigned ballot_mask = __ballot_sync (FULL_MASK, predicate);
94
- int warp_count = __popc (ballot_mask);
101
+ int warp_count = __popc (ballot_mask); // counts set bits
95
102
96
- // global atomics
97
- // if(threadIdx.x == 0) {
103
+ // global memory atomics
104
+ // if (threadIdx.x % 32 == 0) {
98
105
// atomicAdd(count, warp_count);
99
106
// }
100
107
101
-
108
+ // shared memory of size number of warps per block
102
109
// optimization for block reduction
103
110
__shared__ int block_counts[TPB / 32 ];
104
111
105
112
int warp_id = threadIdx .x / 32 ;
106
113
int lane_id = threadIdx .x % 32 ;
107
114
if (lane_id == 0 ) {
115
+ // set local warp count in smem
108
116
block_counts[warp_id] = warp_count;
109
117
}
110
118
111
119
__syncthreads ();
112
120
121
+ // recall reduction from earlier
113
122
for (int offset = (TPB / 32 ) / 2 ; offset > 0 ; offset >>= 1 ) {
114
123
if (lane_id == 0 && warp_id < offset) {
115
124
block_counts[warp_id] += block_counts[warp_id + offset];
0 commit comments