2021-12-21

Why is row × row matrix multiplication 4–5 times slower than row × column on a Mali GPU?

Recently I ran into a problem while implementing matrix multiplication in a compute shader. For an ordinary matrix product C = AB, I transposed matrix B so that each invocation reads its operands from contiguous memory, expecting this to make the kernel faster. However, when I measured it, the row × row version (A times transposed B) was several times slower than the row × column version (A times the untransposed B). I have investigated this for a long time without understanding why, so I am writing the problem down to ask for help.

  • Environment: Mali-G77 GPU (MediaTek Dimensity 1200)
  • Matrix A dimensions: 4 × 2048 × 2048
  • Matrix B dimensions: 4 × 2048 × 2048

Time comparison:

  • Row × row: about 9 s
  • Row × column: about 1.6 s
  • Column × column: about 3.3 s

Demo project reproducing the issue: https://github.com/yikox/ProfilerDemo

// compute shader
#version 310 es

// Work-group size: 8 x 8 x 1 invocations.
#define XLOCAL 8
#define YLOCAL 8
#define ZLOCAL 1

// Output matrix C (binding 0), stored as a flat array of vec4 elements.
layout(binding = 0) writeonly buffer soutput{
    vec4 data[];
} uOutput;
// Input matrix A (binding 1), stored as a flat array of vec4 elements.
layout(binding = 1) readonly buffer sinput0{
    vec4 data[];
} uInput0;
// Input matrix B (binding 2), stored as a flat array of vec4 elements.
layout(binding = 2) readonly buffer sinput1{
    vec4 data[];
} uInput1;

// Extents of A, B and C (in vec4 elements). How .x / .y map to rows and
// columns is fixed by the index arithmetic in PixelMul and main; .z is the
// number of batch slices. The .w components are not read by this shader.
layout(location=3) uniform ivec4 uInputSize0;
layout(location=4) uniform ivec4 uInputSize1;
layout(location=5) uniform ivec4 uOutputSize;

layout (local_size_x = XLOCAL, local_size_y = YLOCAL, local_size_z = ZLOCAL) in;
// Returns the i-th term of the dot product that produces output element
// (pos.x, pos.y) of batch slice pos.z: one vec4 read from A times one vec4
// read from B, multiplied component-wise.
//
// Three addressing variants are kept for benchmarking; exactly one is active.
// Their strides are only interchangeable for square inputs.
// NOTE(review): B's row stride is taken from uInputSize1.y, while main uses
// uOutputSize.x as the output's row stride. Confirm the host-side size
// convention before using non-square matrices.
//
// Neighbouring invocations differ by 2 in pos.x. In the row x row variant that
// moves B's address by 2 * uInputSize1.y elements per invocation; in the other
// two variants it moves by only 2 elements. This wider per-wave address spread
// is a likely (unconfirmed) cause of the row x row slowdown.
vec4 PixelMul(int i, ivec3 pos)
{
    int sliceA = pos.z * uInputSize0.x * uInputSize0.y;
    int sliceB = pos.z * uInputSize1.x * uInputSize1.y;

    // row x row (B stored transposed):
    // vec4 lhs = uInput0.data[sliceA + pos.y * uInputSize0.x + i];
    // vec4 rhs = uInput1.data[sliceB + pos.x * uInputSize1.y + i];

    // row x column:
    // vec4 lhs = uInput0.data[sliceA + pos.y * uInputSize0.x + i];
    // vec4 rhs = uInput1.data[sliceB + i * uInputSize1.y + pos.x];

    // column x column (A stored transposed) -- the active variant:
    vec4 lhs = uInput0.data[sliceA + i * uInputSize0.x + pos.y];
    vec4 rhs = uInput1.data[sliceB + i * uInputSize1.y + pos.x];
    return lhs * rhs;
}
// Entry point. Each invocation computes a 2x2 tile of the output whose top-left
// element is (pos.x, pos.y) in batch slice pos.z. It accumulates PixelMul over
// i in [0, uInputSize0.x).
void main()
{
    ivec3 pos = ivec3(gl_GlobalInvocationID) * ivec3(2, 2, 1);
    if(all(lessThan(pos, uOutputSize.xyz)))
    {
        // The 2x2 tile overhangs the right/bottom edge when the output width or
        // height is odd. The original code still wrote (and read for) pos.x + 1 /
        // pos.y + 1 there, which spills into the next row or past the end of the
        // buffer. Reads for missing neighbours are clamped back onto valid
        // coordinates (assuming the input extents match the output's), and the
        // corresponding stores are skipped.
        bool hasX1 = pos.x + 1 < uOutputSize.x;
        bool hasY1 = pos.y + 1 < uOutputSize.y;
        ivec3 dx = ivec3(hasX1 ? 1 : 0, 0, 0);
        ivec3 dy = ivec3(0, hasY1 ? 1 : 0, 0);

        vec4 outData00 = vec4(0);
        vec4 outData01 = vec4(0);
        vec4 outData10 = vec4(0);
        vec4 outData11 = vec4(0);

        for(int i = 0; i < uInputSize0.x; i++)
        {
            outData00 += PixelMul(i, pos);
            outData01 += PixelMul(i, pos + dx);
            outData10 += PixelMul(i, pos + dy);
            outData11 += PixelMul(i, pos + dx + dy);
        }

        // Output is addressed with uOutputSize.x as the row stride and
        // uOutputSize.x * uOutputSize.y as the slice stride.
        int rowStride = uOutputSize.x;
        int base = pos.x + pos.y * rowStride + pos.z * uOutputSize.x * uOutputSize.y;

        uOutput.data[base] = outData00;
        if(hasX1)
        {
            uOutput.data[base + 1] = outData01;
        }
        if(hasY1)
        {
            uOutput.data[base + rowStride] = outData10;
        }
        if(hasX1 && hasY1)
        {
            uOutput.data[base + rowStride + 1] = outData11;
        }
    }
}

Shader code:

// compute shader
#version 310 es

// Work-group size: 8 x 8 x 1 invocations.
#define XLOCAL 8
#define YLOCAL 8
#define ZLOCAL 1

// Output matrix C (binding 0), stored as a flat array of vec4 elements.
layout(binding = 0) writeonly buffer soutput{
    vec4 data[];
} uOutput;
// Input matrix A (binding 1), stored as a flat array of vec4 elements.
layout(binding = 1) readonly buffer sinput0{
    vec4 data[];
} uInput0;
// Input matrix B (binding 2), stored as a flat array of vec4 elements.
layout(binding = 2) readonly buffer sinput1{
    vec4 data[];
} uInput1;

// Extents of A, B and C (in vec4 elements). How .x / .y map to rows and
// columns is fixed by the index arithmetic in PixelMul and main; .z is the
// number of batch slices. The .w components are not read by this shader.
layout(location=3) uniform ivec4 uInputSize0;
layout(location=4) uniform ivec4 uInputSize1;
layout(location=5) uniform ivec4 uOutputSize;

layout (local_size_x = XLOCAL, local_size_y = YLOCAL, local_size_z = ZLOCAL) in;
// Returns the i-th term of the dot product that produces output element
// (pos.x, pos.y) of batch slice pos.z: one vec4 read from A times one vec4
// read from B, multiplied component-wise.
//
// Three addressing variants are kept for benchmarking; exactly one is active.
// Their strides are only interchangeable for square inputs.
// NOTE(review): B's row stride is taken from uInputSize1.y, while main uses
// uOutputSize.x as the output's row stride. Confirm the host-side size
// convention before using non-square matrices.
//
// Neighbouring invocations differ by 2 in pos.x. In the row x row variant that
// moves B's address by 2 * uInputSize1.y elements per invocation; in the other
// two variants it moves by only 2 elements. This wider per-wave address spread
// is a likely (unconfirmed) cause of the row x row slowdown.
vec4 PixelMul(int i, ivec3 pos)
{
    int sliceA = pos.z * uInputSize0.x * uInputSize0.y;
    int sliceB = pos.z * uInputSize1.x * uInputSize1.y;

    // row x row (B stored transposed):
    // vec4 lhs = uInput0.data[sliceA + pos.y * uInputSize0.x + i];
    // vec4 rhs = uInput1.data[sliceB + pos.x * uInputSize1.y + i];

    // row x column:
    // vec4 lhs = uInput0.data[sliceA + pos.y * uInputSize0.x + i];
    // vec4 rhs = uInput1.data[sliceB + i * uInputSize1.y + pos.x];

    // column x column (A stored transposed) -- the active variant:
    vec4 lhs = uInput0.data[sliceA + i * uInputSize0.x + pos.y];
    vec4 rhs = uInput1.data[sliceB + i * uInputSize1.y + pos.x];
    return lhs * rhs;
}
// Entry point. Each invocation computes a 2x2 tile of the output whose top-left
// element is (pos.x, pos.y) in batch slice pos.z. It accumulates PixelMul over
// i in [0, uInputSize0.x).
void main()
{
    ivec3 pos = ivec3(gl_GlobalInvocationID) * ivec3(2, 2, 1);
    if(all(lessThan(pos, uOutputSize.xyz)))
    {
        // The 2x2 tile overhangs the right/bottom edge when the output width or
        // height is odd. The original code still wrote (and read for) pos.x + 1 /
        // pos.y + 1 there, which spills into the next row or past the end of the
        // buffer. Reads for missing neighbours are clamped back onto valid
        // coordinates (assuming the input extents match the output's), and the
        // corresponding stores are skipped.
        bool hasX1 = pos.x + 1 < uOutputSize.x;
        bool hasY1 = pos.y + 1 < uOutputSize.y;
        ivec3 dx = ivec3(hasX1 ? 1 : 0, 0, 0);
        ivec3 dy = ivec3(0, hasY1 ? 1 : 0, 0);

        vec4 outData00 = vec4(0);
        vec4 outData01 = vec4(0);
        vec4 outData10 = vec4(0);
        vec4 outData11 = vec4(0);

        for(int i = 0; i < uInputSize0.x; i++)
        {
            outData00 += PixelMul(i, pos);
            outData01 += PixelMul(i, pos + dx);
            outData10 += PixelMul(i, pos + dy);
            outData11 += PixelMul(i, pos + dx + dy);
        }

        // Output is addressed with uOutputSize.x as the row stride and
        // uOutputSize.x * uOutputSize.y as the slice stride.
        int rowStride = uOutputSize.x;
        int base = pos.x + pos.y * rowStride + pos.z * uOutputSize.x * uOutputSize.y;

        uOutput.data[base] = outData00;
        if(hasX1)
        {
            uOutput.data[base + 1] = outData01;
        }
        if(hasY1)
        {
            uOutput.data[base + rowStride] = outData10;
        }
        if(hasX1 && hasY1)
        {
            uOutput.data[base + rowStride + 1] = outData11;
        }
    }
}


from Recent Questions - Stack Overflow https://ift.tt/3sn9E4x
https://ift.tt/eA8V8J

No comments:

Post a Comment