Skip to content

Commit acbddf3

Browse files
committedSep 26, 2020
restructuring to accept TBP in template
1 parent 78fc58d commit acbddf3

File tree

4 files changed

+64
-44
lines changed

4 files changed

+64
-44
lines changed
 

‎conda/learn_dev_cuda10.2.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
name: learn_dev
22
dependencies:
3-
- cmake>=3.14
3+
- cmake>=3.18
44
- cudatoolkit=10.2

‎cpp/include/learning_cuda/count/count.cuh

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
#include "count_kernels.cuh"
1+
#pragma once
22

3-
#define TPB 32
3+
#include "count_kernels.cuh"
44

55
namespace common {
66

7+
template <int TPB>
78
void _prep_count_if(int **count_d, int &N_BLK, int size) {
89
cudaMalloc(count_d, sizeof(int));
910
cudaMemset(*count_d, 0, sizeof(int));
@@ -20,14 +21,16 @@ void _finish_count_if(int **count_d, int &count_h) {
2021

2122
namespace naive {
2223

23-
template <typename T, class F>
24-
int count_if(T* arr, int size,
25-
F count_if_op) {
24+
template <typename T, int TPB, typename CountIfOp>
25+
inline int count_if(T* arr, int size, CountIfOp count_if_op) {
2626
int *count_d, count_h, N_BLK;
2727

28-
common::_prep_count_if(&count_d, N_BLK, size);
28+
common::_prep_count_if<TPB> (&count_d, N_BLK, size);
2929

30-
count_kernel<<<N_BLK, TPB>>> (arr, size, count_d, count_if_op);
30+
std:: cout << "TPB: " << TPB << "\n";
31+
std:: cout << "N_BLK: " << N_BLK << "\n";
32+
std:: cout << "size: " << size << "\n";
33+
detail::count_kernel<<<N_BLK, TPB>>> (arr, size, count_d, count_if_op);
3134

3235
common::_finish_count_if(&count_d, count_h);
3336

@@ -38,14 +41,15 @@ int count_if(T* arr, int size,
3841

3942
namespace manual_reduction {
4043

41-
template <typename T, class F>
42-
int count_if(T* arr, int size,
43-
F count_if_op) {
44+
template <typename T, int TPB, typename CountIfOp>
45+
inline int count_if(T* arr, int size,
46+
CountIfOp count_if_op) {
4447
int *count_d, count_h, N_BLK;
4548

46-
common::_prep_count_if(&count_d, N_BLK, size);
49+
common::_prep_count_if<TPB> (&count_d, N_BLK, size);
4750

48-
count_kernel<<<N_BLK, TPB>>> (arr, size, count_d, count_if_op);
51+
detail::count_kernel<T, TPB> <<<N_BLK, TPB>>> (arr, size, count_d,
52+
count_if_op);
4953

5054
common::_finish_count_if(&count_d, count_h);
5155

@@ -56,14 +60,14 @@ int count_if(T* arr, int size,
5660

5761
namespace syncthreads_count_reduction {
5862

59-
template <typename T, class F>
60-
int count_if(T* arr, int size,
61-
F count_if_op) {
63+
template <typename T, int TPB, typename CountIfOp>
64+
inline int count_if(T* arr, int size,
65+
CountIfOp count_if_op) {
6266
int *count_d, count_h, N_BLK;
6367

64-
common::_prep_count_if(&count_d, N_BLK, size);
68+
common::_prep_count_if<TPB> (&count_d, N_BLK, size);
6569

66-
count_kernel<<<N_BLK, TPB>>> (arr, size, count_d, count_if_op);
70+
detail::count_kernel<<<N_BLK, TPB>>> (arr, size, count_d, count_if_op);
6771

6872
common::_finish_count_if(&count_d, count_h);
6973

@@ -74,14 +78,15 @@ int count_if(T* arr, int size,
7478

7579
namespace ballot_sync_reduction {
7680

77-
template <typename T, class F>
78-
int count_if(T* arr, int size,
79-
F count_if_op) {
81+
template <typename T, int TPB, typename CountIfOp>
82+
inline int count_if(T* arr, int size,
83+
CountIfOp count_if_op) {
8084
int *count_d, count_h, N_BLK;
8185

82-
common::_prep_count_if(&count_d, N_BLK, size);
86+
common::_prep_count_if<TPB> (&count_d, N_BLK, size);
8387

84-
count_kernel<<<N_BLK, TPB>>> (arr, size, count_d, count_if_op);
88+
detail::count_kernel<T, TPB> <<<N_BLK, TPB>>> (arr, size, count_d,
89+
count_if_op);
8590

8691
common::_finish_count_if(&count_d, count_h);
8792

‎cpp/include/learning_cuda/count/count_kernels.cuh

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,35 @@
1-
// #include <cuda_runtime.h>
1+
#pragma once
2+
3+
#include <cuda_runtime.h>
24

35
namespace naive {
6+
namespace detail {
47

5-
template <typename T, class F>
8+
template <typename T, typename CountIfOp>
69
__global__
7-
void count_kernel(T* arr, int size, int *count, F count_if_op) {
10+
void count_kernel(T* arr, int size, int *count, CountIfOp count_if_op) {
811
int tid = threadIdx.x + blockIdx.x * blockDim.x;
912

1013
if (tid < size) {
14+
1115
if (count_if_op(arr[tid])) {
1216
atomicAdd(count, 1);
1317
}
1418
}
1519
}
1620

21+
} // namespace detail
1722
} // namespace naive
1823

1924
namespace manual_reduction {
25+
namespace detail {
2026

21-
template <typename T, class F>
27+
template <typename T, int TPB, typename CountIfOp>
2228
__global__
23-
void count_kernel(T* arr, int size, int *count, F count_if_op) {
29+
void count_kernel(T* arr, int size, int *count, CountIfOp count_if_op) {
2430
int tid = threadIdx.x + blockIdx.x * blockDim.x;
2531

26-
__shared__ int local_count_array[32];
32+
__shared__ int local_count_array[TPB];
2733

2834
if (tid < size) {
2935
if (count_if_op(arr[tid])) {
@@ -47,14 +53,16 @@ void count_kernel(T* arr, int size, int *count, F count_if_op) {
4753
atomicAdd(count, local_count_array[threadIdx.x]);
4854
}
4955
}
50-
56+
57+
} // namespace detail
5158
} // namespace manual_reduction
5259

5360
namespace syncthreads_count_reduction {
61+
namespace detail {
5462

55-
template <typename T, class F>
63+
template <typename T, typename CountIfOp>
5664
__global__
57-
void count_kernel(T* arr, int size, int *count, F count_if_op) {
65+
void count_kernel(T* arr, int size, int *count, CountIfOp count_if_op) {
5866
int tid = threadIdx.x + blockIdx.x * blockDim.x;
5967

6068
bool predicate = tid < size && count_if_op(arr[tid]);
@@ -67,15 +75,17 @@ void count_kernel(T* arr, int size, int *count, F count_if_op) {
6775

6876
}
6977

78+
} // namespace detail
7079
} // namespace syncthreads_count
7180

7281
#define FULL_MASK 0xffffffff
7382

7483
namespace ballot_sync_reduction {
84+
namespace detail {
7585

76-
template <typename T, class F>
86+
template <typename T, int TPB, typename CountIfOp>
7787
__global__
78-
void count_kernel(T* arr, int size, int *count, F count_if_op) {
88+
void count_kernel(T* arr, int size, int *count, CountIfOp count_if_op) {
7989
int tid = threadIdx.x + blockIdx.x * blockDim.x;
8090

8191
bool predicate = tid < size && count_if_op(arr[tid]);
@@ -89,4 +99,5 @@ void count_kernel(T* arr, int size, int *count, F count_if_op) {
8999

90100
}
91101

92-
} // ballot syncthreads_count
102+
} // namespace detail
103+
} // namespace ballot_sync_reduction

‎cpp/src/count/count.cu

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313

1414
int main() {
15-
int n_elems = 1000;
15+
int n_elems = 100;
1616

1717
thrust::device_vector<int> rand_d(n_elems, 0);
1818

@@ -33,23 +33,27 @@ int main() {
3333

3434
std::cout << "\nThrust Count: " << thrust_count << std::endl;
3535

36-
int naive_count = naive::count_if(thrust::raw_pointer_cast(rand_d.data()),
37-
n_elems, is_greater_than_10);
36+
int *rand_d_ptr = thrust::raw_pointer_cast(rand_d.data());
37+
38+
int naive_count = naive::count_if<int, 32>(rand_d_ptr, n_elems,
39+
is_greater_than_10);
3840

3941
std::cout << "\nNaive Count: " << naive_count << std::endl;
4042

41-
int man_count = manual_reduction::count_if(thrust::raw_pointer_cast(rand_d.data()),
42-
n_elems, is_greater_than_10);
43+
int man_count = manual_reduction::count_if<int, 32>(rand_d_ptr, n_elems,
44+
is_greater_than_10);
4345

4446
std::cout << "\nManual Reduction Count: " << man_count << std::endl;
4547

46-
int syn_count = syncthreads_count_reduction::count_if(thrust::raw_pointer_cast(rand_d.data()),
47-
n_elems, is_greater_than_10);
48+
int syn_count = syncthreads_count_reduction::count_if<int, 32>(rand_d_ptr,
49+
n_elems,
50+
is_greater_than_10);
4851

4952
std::cout << "\nSyncthread Reduction Count: " << syn_count << std::endl;
5053

51-
int bal_count = ballot_sync_reduction::count_if(thrust::raw_pointer_cast(rand_d.data()),
52-
n_elems, is_greater_than_10);
54+
int bal_count = ballot_sync_reduction::count_if<int, 32>(rand_d_ptr,
55+
n_elems,
56+
is_greater_than_10);
5357

5458
std::cout << "\nBallot Reduction Count: " << bal_count << std::endl;
5559

0 commit comments

Comments
 (0)
Please sign in to comment.