большинство из того, что вам нужно, это здесь и здесь .
В первой ссылке определено, что AxAT включает в себя взятие внутренних произведений строк матрицы A, и аналогично ATxA будет включать в себя взятие внутренних произведений столбцов матрицы A. Также обратите внимание на утверждение симметрии , Во второй ссылке (прокрутите вниз от этой точки в руководстве по программированию) вы найдете умножение полной мозаичной матрицы. Вам просто нужно проиндексировать обе плитки по столбцу .
Вот рабочий пример с использованием кода из SO-ответа, который вы связали :
$ cat t1654.cu
#include <iostream>
#include <cstdio>
#include <cstdlib>
const int TILE_DIM = 32;
template <typename T>
__global__ void ATA(const T * __restrict__ A, T * __restrict__ C, int ARows, int ACols)
{
T CValue = 0;
int Row = blockIdx.y*TILE_DIM + threadIdx.y;
int Col = blockIdx.x*TILE_DIM + threadIdx.x;
__shared__ T As[TILE_DIM][TILE_DIM];
__shared__ T Bs[TILE_DIM][TILE_DIM];
for (int k = 0; k < (TILE_DIM + ARows - 1)/TILE_DIM; k++) {
if (k*TILE_DIM + threadIdx.y < ARows && blockIdx.y*blockDim.y+threadIdx.x < ACols)
As[threadIdx.y][threadIdx.x] = A[(k*TILE_DIM + threadIdx.y)*ACols + blockIdx.y*blockDim.y+threadIdx.x];
else
As[threadIdx.y][threadIdx.x] = 0.0;
if (k*TILE_DIM + threadIdx.y < ARows && Col < ACols)
Bs[threadIdx.y][threadIdx.x] = A[(k*TILE_DIM + threadIdx.y)*ACols + Col];
else
Bs[threadIdx.y][threadIdx.x] = 0.0;
__syncthreads();
for (int n = 0; n < TILE_DIM; ++n)
CValue += As[n][threadIdx.y] * Bs[n][threadIdx.x];
__syncthreads();
}
if (Row < ACols && Col < ACols)
C[((blockIdx.y * blockDim.y + threadIdx.y)*ACols) +
(blockIdx.x * blockDim.x)+ threadIdx.x] = CValue;
}
template <typename T>
__global__ void transpose_naive(const T * __restrict__ in, T * __restrict__ out, const int dim){
int col = threadIdx.x+blockDim.x*blockIdx.x;
int row = threadIdx.y+blockDim.y*blockIdx.y;
if ((col < dim) && (row < dim)) out[col*dim+row] = in[row*dim+col];
}
template <typename T>
__global__ void mm_naive(const T * __restrict__ A, const T * __restrict__ B, T * __restrict__ C, const int rowA, const int colA, const int colB){
int col = threadIdx.x+blockDim.x*blockIdx.x;
int row = threadIdx.y+blockDim.y*blockIdx.y;
if ((row < rowA) && (col < colB)){
T Cval = 0;
for (int i = 0; i < colA; i++) Cval += A[row*colA+i]*B[i*colB+col];
C[row*colB+col] = Cval;}
}
typedef float mt;
int main(){
mt *d_A, *d_B, *d_C, *h_A, *h_C, *h_C1;
int m = 64;
int n = 64;
h_A = new mt[m*n];
h_C = new mt[n*n];
h_C1 = new mt[n*n];
cudaMalloc(&d_A, m*n*sizeof(d_A[0]));
cudaMalloc(&d_B, m*n*sizeof(d_A[0]));
cudaMalloc(&d_C, n*n*sizeof(d_C[0]));
// test 1
for (int i = 0; i < m; i++)
for (int j = 0; j < n; j++)
h_A[i*n+j] = (i==j)?1.0f:0.0f;
cudaMemcpy(d_A, h_A, m*n*sizeof(d_A[0]), cudaMemcpyHostToDevice);
dim3 block(TILE_DIM, TILE_DIM);
dim3 grid((n+block.x-1)/block.x, (n+block.y-1)/block.y);
ATA<<<grid,block>>>(d_A, d_C, m, n);
cudaMemcpy(h_C, d_C, n*n*sizeof(d_C[0]), cudaMemcpyDeviceToHost);
#ifdef DEBUG
for (int i = 0; i < n; i++){
for (int j = 0; j < n; j++)
std::cout << h_C[i*n+j] << " ";
std::cout << std::endl;}
std::cout << std::endl;
#endif
// test 2
for (int i = 0; i < m; i++)
for (int j = 0; j < n; j++)
h_A[i*n+j] = rand()%10;
cudaMemcpy(d_A, h_A, m*n*sizeof(d_A[0]), cudaMemcpyHostToDevice);
ATA<<<grid,block>>>(d_A, d_C, m, n);
cudaMemcpy(h_C, d_C, n*n*sizeof(d_C[0]), cudaMemcpyDeviceToHost);
#ifdef DEBUG
for (int i = 0; i < n; i++){
for (int j = 0; j < n; j++)
std::cout << h_C[i*n+j] << " ";
std::cout << std::endl;}
std::cout << std::endl;
#endif
transpose_naive<<<grid,block>>>(d_A, d_B, n);
mm_naive<<<grid,block>>>(d_B, d_A, d_C, n, n, n);
cudaMemcpy(h_C1, d_C, n*n*sizeof(d_C[0]), cudaMemcpyDeviceToHost);
#ifdef DEBUG
for (int i = 0; i < n; i++){
for (int j = 0; j < n; j++)
std::cout << h_C1[i*n+j] << " ";
std::cout << std::endl;}
std::cout << std::endl;
#endif
for (int i = 0; i < n*n; i++) if (h_C[i] != h_C1[i]) {std::cout << "mismatch at: " << i << " was: " << h_C[i] << " should be: " << h_C1[i] << std::endl; return 0;}
}
$ nvcc -o t1654 t1654.cu
$ cuda-memcheck ./t1654
========= CUDA-MEMCHECK
========= ERROR SUMMARY: 0 errors
$
Обратите внимание, что загрузка плитки Bs
одинакова в обоих случаях. Основные изменения в загрузке плитки As
, а также обратите внимание на изменение индексации при вычислении Cvalue
. Эти изменения необходимы для индексации в обоих случаях по столбцу .
Могут быть ошибки. Я не проверял неквадратный случай и не тестировал случай, когда размер матрицы не кратен размеру блока. Кроме того, я не использовал симметрию в выводе. Однако это должно помочь с индексацией.