Skip to content

Commit 78fc58d

Browse files
committed
three working kernels
1 parent 4b75b61 commit 78fc58d

File tree

3 files changed

+94
-7
lines changed

3 files changed

+94
-7
lines changed

cpp/include/learning_cuda/count/count.cuh

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,4 +52,40 @@ int count_if(T* arr, int size,
5252
return count_h;
5353
}
5454

55-
} // namespace manual_reduction
55+
} // namespace manual_reduction
56+
57+
namespace syncthreads_count_reduction {
58+
59+
template <typename T, class F>
60+
int count_if(T* arr, int size,
61+
F count_if_op) {
62+
int *count_d, count_h, N_BLK;
63+
64+
common::_prep_count_if(&count_d, N_BLK, size);
65+
66+
count_kernel<<<N_BLK, TPB>>> (arr, size, count_d, count_if_op);
67+
68+
common::_finish_count_if(&count_d, count_h);
69+
70+
return count_h;
71+
}
72+
73+
} // namespace syncthreads_count_reduction
74+
75+
namespace ballot_sync_reduction {
76+
77+
template <typename T, class F>
78+
int count_if(T* arr, int size,
79+
F count_if_op) {
80+
int *count_d, count_h, N_BLK;
81+
82+
common::_prep_count_if(&count_d, N_BLK, size);
83+
84+
count_kernel<<<N_BLK, TPB>>> (arr, size, count_d, count_if_op);
85+
86+
common::_finish_count_if(&count_d, count_h);
87+
88+
return count_h;
89+
}
90+
91+
} // namespace ballot_count_reduction

cpp/include/learning_cuda/count/count_kernels.cuh

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#include <cuda_runtime.h>
1+
// #include <cuda_runtime.h>
22

33
namespace naive {
44

@@ -38,7 +38,7 @@ void count_kernel(T* arr, int size, int *count, F count_if_op) {
3838

3939
if (tid < size) {
4040
for (int offset = blockDim.x / 2; offset > 0; offset /=2 ) {
41-
local_count_array[threadIdx.x] = local_count_array[threadIdx.x + offset];
41+
local_count_array[threadIdx.x] += local_count_array[threadIdx.x + offset];
4242
__syncthreads();
4343
}
4444
}
@@ -48,4 +48,45 @@ void count_kernel(T* arr, int size, int *count, F count_if_op) {
4848
}
4949
}
5050

51-
} // namespace manual_reduction
51+
} // namespace manual_reduction
52+
53+
namespace syncthreads_count_reduction {
54+
55+
template <typename T, class F>
56+
__global__
57+
void count_kernel(T* arr, int size, int *count, F count_if_op) {
58+
int tid = threadIdx.x + blockIdx.x * blockDim.x;
59+
60+
bool predicate = tid < size && count_if_op(arr[tid]);
61+
62+
int block_count = __syncthreads_count(predicate);
63+
64+
if(threadIdx.x == 0) {
65+
atomicAdd(count, block_count);
66+
}
67+
68+
}
69+
70+
} // namespace syncthreads_count
71+
72+
#define FULL_MASK 0xffffffff
73+
74+
namespace ballot_sync_reduction {
75+
76+
template <typename T, class F>
77+
__global__
78+
void count_kernel(T* arr, int size, int *count, F count_if_op) {
79+
int tid = threadIdx.x + blockIdx.x * blockDim.x;
80+
81+
bool predicate = tid < size && count_if_op(arr[tid]);
82+
83+
unsigned ballot_mask = __ballot_sync(FULL_MASK, predicate);
84+
int warp_count = __popc(ballot_mask);
85+
86+
if(threadIdx.x % 32 == 0) {
87+
atomicAdd(count, warp_count);
88+
}
89+
90+
}
91+
92+
} // ballot syncthreads_count

cpp/src/count/count.cu

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313

1414
int main() {
15-
int n_elems = 100000;
15+
int n_elems = 1000;
1616

1717
thrust::device_vector<int> rand_d(n_elems, 0);
1818

@@ -38,10 +38,20 @@ int main() {
3838

3939
std::cout << "\nNaive Count: " << naive_count << std::endl;
4040

41-
int man_red_count = naive::count_if(thrust::raw_pointer_cast(rand_d.data()),
41+
int man_count = manual_reduction::count_if(thrust::raw_pointer_cast(rand_d.data()),
4242
n_elems, is_greater_than_10);
4343

44-
std::cout << "\nManual Reduction Count: " << man_red_count << std::endl;
44+
std::cout << "\nManual Reduction Count: " << man_count << std::endl;
45+
46+
int syn_count = syncthreads_count_reduction::count_if(thrust::raw_pointer_cast(rand_d.data()),
47+
n_elems, is_greater_than_10);
48+
49+
std::cout << "\nSyncthread Reduction Count: " << syn_count << std::endl;
50+
51+
int bal_count = ballot_sync_reduction::count_if(thrust::raw_pointer_cast(rand_d.data()),
52+
n_elems, is_greater_than_10);
53+
54+
std::cout << "\nBallot Reduction Count: " << bal_count << std::endl;
4555

4656
return 0;
4757
}

0 commit comments

Comments
 (0)