Skip to main content

🍉共享家园

⚠ 题目描述

Shiver正在使用Tile技术优化基于CUDA的矩阵乘法,却发现这个算子总是调不对,你能帮他找到问题所在吗?

已知 A, B, C矩阵皆采用行主序 (Row-major ),A的形状为 M * N, B的形状为 N * K, C的形状为 M * K。且代码错误只会出现在incorrect_matrix_multiplication_kernel函数中

代码:

__global__ void incorrect_matrix_multiplication_kernel(const float* A, const float* B, float* C, int M, int N, int K) {
/*
* Row-major
* Shape of A: MxN
* Shape of B: NxK
* Shape of C: MxK
*/
int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
int ltid_x = threadIdx.x;
int ltid_y = threadIdx.y;

if(tid_x >= K || tid_y >= M) {
return;
}

__shared__ float sA[TILE_SIZE][TILE_SIZE];
__shared__ float sB[TILE_SIZE][TILE_SIZE];
float acc = 0.0f;

for(int n = 0; n < N; n += TILE_SIZE) {
sA[ltid_y][ltid_x] = (tid_y < M && ltid_x + n < N) ? A[tid_y * N + ltid_x + n] : 0.0f;
sB[ltid_y][ltid_x] = (n + ltid_y < N && tid_x < K) ? B[(n + ltid_y) * K + tid_x] : 0.0f;
__syncthreads();

if(tid_y < M && tid_x < K) {
for(int t = 0; t < TILE_SIZE; t++) {
acc += sA[ltid_y][t] * sB[t][ltid_x];
}
}
}
if(tid_y < M && tid_x < K) {
C[tid_y * K + tid_x] = acc;
}
}

// A, B, C are device pointers (i.e. pointers to memory on the GPU)
extern "C" void solve(const float* A, const float* B, float* C, int M, int N, int K) {
dim3 threadsPerBlock(16, 16);
dim3 blocksPerGrid((K + threadsPerBlock.x - 1) / threadsPerBlock.x,
(M + threadsPerBlock.y - 1) / threadsPerBlock.y);

matrix_multiplication_kernel<<<blocksPerGrid, threadsPerBlock>>>(A, B, C, M, N, K);
cudaDeviceSynchronize();
}

数据限制

  • 1M,N,K8192 1 \leq M, N, K \leq 8192
  • 我们会在 M = 8192, N = 6144, K = 4096的情形测评程序

要求:

  • 不能使用外部库
  • 不允许修改solve函数
  • 程序的输出应当存储在向量Ctips:
  • 推荐自己构建数据集,在本地初步测试程序的正确性后再提交。

🥨分数分布

  • 如果你能找出一处错误的代码,获得 50% 的分数。
  • 如果你能找出另一处错误的代码,获得另外 50% 的分数。

💡 Hint

了解一下 CUDA 中的 shared_memory


            for(int t = 0; t < TILE_SIZE; t++) {
acc += sA[ltid_y][t] * sB[t][ltid_x];
}

计算后未使用 __syncthreads() 会导致不同线程读取写入冲突

    for(int n = 0; n < N; n += TILE_SIZE) {
sA[ltid_y][ltid_x] = (tid_y < M && ltid_x + n < N) ? A[tid_y * N + ltid_x + n] : 0.0f;
sB[ltid_y][ltid_x] = (n + ltid_y < N && tid_x < K) ? B[(n + ltid_y) * K + tid_x] : 0.0f;
__syncthreads();

访问错误