Cuda Reduce(规约)学习笔记
本文是记录阅读https://zhuanlan.zhihu.com/p/2039848582770058820后自己的理解。
什么是Reduce
规约就是把一组很多个数据,通过某种运算,合并成更少的数据,通常最后合并成一个结果。
Reduce是将多个输入变成一个输出的操作,例如:
$$ \text{Sum Reduction:} [1,2,3,4]\rightarrow10\\ \text{Min Reduction:} [1,2,3,4]\rightarrow1\\ \text{Max Reduction:} [1,2,3,4]\rightarrow4\\ \text{Product Reduction:} [1,2,3,4]\rightarrow24 $$
因为GPU中有很多线程,适合对上述操作进行并行处理,所以Reduce是Cuda中一个很重要的东西。
简单版本的规约
下面以一个数组相加的问题为例子,使用规约进行计算。
__global__ void reduce_v0(float* input, float* output, int n){
extern __shared__ float smem[];
int tid = threadIdx.x;
int gid = blockDim.x * blockIdx.x + threadIdx.x;
smem[tid] = (gid < n)?input[gid] : 0.0f;
__syncthreads();
for (int step = 1; step < blockDim.x; step*= 2){
if (tid % (2 * step) == 0){
smem[tid] += smem[tid + step];
}
__syncthreads();
}
if(tid==0){
output[blockIdx.x] = smem[0];
}
}
首先将input拷贝至当前block的share memory,使用tid定位当前的thread,用gid来访问input数组。值得注意的是,在数据拷贝后有一句__syncthreads()来同步线程。如果不同步的话会出现 race condition。
然后从step=1开始,当前thread处理当前位置的值以及当前位置加step位置的值,所以是smem[tid]+=smem[tid+step]。
(tid%(2*step)==0)用于判断当前thread是否为要使用的thread。
例:当前tid为0,step为1,则当前thread激活。tid为1,step为1,则当前thread闲置。当tid为2时,step=1,当前thread激活。
所以这里要跨2倍的step来选择thread,因为每一个激活的thread都处理tid位置与tid+step位置的值。下一个线程则位于当前线程位置再跨2倍的step。
在每次step后还要使用__syncthreads()来等待所以线程完成,再进行下一次循环。否则会出现race condition。
Race condition (竞争条件)。
**意思是:**多个线程同时读写同一个数据,最终结果取决于它们执行的先后顺序。如果这个顺序不确定,结果就可能不稳定、错误。
例如:thread 0: sum += 1,thread 1: sum += 1。理想结果是sum=2。但是实际可能发生:
thread 0 读取 sum = 0 thread 1 读取 sum = 0
thread 0 计算 0 + 1 = 1 thread 1 计算 0 + 1 = 1
thread 0 写回 sum = 1 thread 1 写回 sum = 1
最后的结果sum=1
完成所有规约操作之后,结果存在smem[0]中。当前block中的第一个thread即tid=0激活,将结果写入output数组中当前block id的位置。至此就完成了当前block的计算。
Warp Divergence
GPU以Warp(32线程)为最小调度单位,Warp中所有线程都执行同样的指令。当****同一个 warp 里的线程走了不同的分支路径,导致这些路径不能真正同时执行,只能分批执行。在上面的简单规约中,if (tid % (2 * step) == 0)就会造成严重的Warp Divergence,
在step=1时,只有一半的threads是真正工作的。step=2时,只有25%的线程工作。在最后的step里只有一个线程工作。会造成很差的线程利用率。
版本V1 解决Warp Divergence
在上个版本中我们使用了取模来选择活跃的线程。为了解决Warp Divergence,这里使用 Strided Index来选择线程。
如果使用普通的连续index:tid 0 -> index 0,tid 1 -> index 1,tid 2 -> index 2,tid 3 -> index 3 …
如果使用****Strided index:tid 0 -> index 0,tid 1 -> index 2,tid 2 -> index 4,tid 3 -> index 6 …
这里的线程是连续的,但是访问的数据下标是跳跃的。
__global__ void reduce_v1(float* input, float* output, int n){
extern __shared__ float smem[];
int tid = threadIdx.x;
int gid = blockIdx.x * blockDim.x + threadIdx.x;
smem[tid] = (gid < n) ? input[gid] : 0.0f;
__syncthreads();
for(int step = 1; s < blockDim.x; s *= 2){
int index = threadIdx.x * 2 * s;
if(index < blockDim.x){
smem[index] = smem[index + s];
}
__syncthreads();
}
if (tid == 0){
output[blockIdx.x] = smem[0];
}
}
在代码实现中与前一个版本的区别只有选择活跃线程的语句。
以blockDim=256为例子:
s=1时,index = 2*tid,活跃的线程范围是0~127,正好4个Warp,没有Warp Divergence。s=2,4时同理。
当s=8时,线程0~15活跃,只有当前warp有divergence。s=16,32,64,128时同理。