Ways to optimize Matrix Multiplication on CPU

In this tutorial, we will introduce several ways to optimize the performance of the matrix multiplication on X86 CPU.

import cinn
import numpy as np
import time
from cinn import runtime
# sphinx_gallery_thumbnail_path = './paddlepaddle.png'

Declare the basic computation for a matmul

m = cinn.Expr(1024)
n = cinn.Expr(1024)
k = cinn.Expr(1024)

A = cinn.Placeholder("float32", "A", [m, k])
B = cinn.Placeholder("float32", "B", [k, n])

# k1 is a reduce axis
k1 = cinn.Var(k.as_int32(), "k1")

C = cinn.compute([
    m, n
], lambda vs: cinn.reduce_sum(A(vs[0], k1.expr()) * B(k1.expr(), vs[1]), [k1]),
                 "C")

stages = cinn.create_stages([C])

Fake input data, here we create a runtime buffer for each input of the generated function.

a = runtime.cinn_buffer_t(
    np.random.randn(m.int(), k.int()).astype("float32"),
    runtime.cinn_x86_device, 32)
b = runtime.cinn_buffer_t(
    np.random.randn(m.int(), k.int()).astype("float32"),
    runtime.cinn_x86_device, 32)
c = runtime.cinn_buffer_t(
    np.zeros([m.int(), n.int()]).astype("float32"), runtime.cinn_x86_device,
    32)

Here is a helper function to JIT compile the generated program and test the performance

def test_performance(stages,
                     fn_inputs=[A.to_tensor(), B.to_tensor(), C],
                     input_args=[a, b, c]):
    '''
    fake input data, compile and test program's performance
    '''
    target = cinn.Target()
    builder = cinn.Module.Builder("matmul", target)

    func = cinn.lower("matmul", stages, fn_inputs)
    builder.add_function(func)
    module = builder.build()

    jit = cinn.ExecutionEngine()
    jit.link(module)

    args = [runtime.cinn_pod_value_t(_) for _ in input_args]
    matmul = jit.lookup("matmul")

    repeat = 4

    tic = time.perf_counter()
    for i in range(repeat):
        matmul(args)
    toc = time.perf_counter()
    miniseconds = (toc - tic) / repeat * 1e3
    print(f"Takes {miniseconds:0.3f} ms")


# The basic computation without any schedule has a performance as follows
test_performance(stages)

Out:

Takes 3564.947 ms

Blocking

stages = cinn.create_stages([C])
bn = 32
i_outer, i_inner, j_outer, j_inner = stages[C].tile(0, 1, bn, bn)
k_outer, k_inner = stages[C].split("k1", 4)
stages[C].reorder([i_outer, j_outer, k_outer, k_inner, i_inner, j_inner])

# The performance is
test_performance(stages)

Out:

Takes 365.406 ms

Vectorization

stages = cinn.create_stages([C])
bn = 32
i_outer, i_inner, j_outer, j_inner = stages[C].tile(0, 1, bn, bn)
k_outer, k_inner = stages[C].split("k1", 4)
stages[C].reorder([i_outer, j_outer, k_outer, k_inner, i_inner, j_inner])
stages[C].vectorize(j_inner, 8)

# The performance is
test_performance(stages)

Out:

Takes 364.097 ms

Loop Permutation

stages = cinn.create_stages([C])
i_outer, i_inner, j_outer, j_inner = stages[C].tile(0, 1, bn, bn)
k_outer, k_inner = stages[C].split("k1", 4)
stages[C].reorder([i_outer, j_outer, k_outer, i_inner, k_inner, j_inner])
stages[C].vectorize(j_inner, 8)
stages[C].unroll(5)

test_performance(stages)

Out:

Takes 76.672 ms

Array Packing

packedB = cinn.compute(
    [n / bn, k, cinn.Expr(bn)], lambda x: B(x[1], x[0] * bn + x[2]), "packedB")
C = cinn.compute([m, n], lambda x: cinn.reduce_sum(
    A(x[0], k1.expr()) * packedB(x[1] / bn, k1.expr(), x[1] % bn), [k1]), "C")

stages = cinn.create_stages([C])
stages[packedB].vectorize(2, 8)

i_outer, i_inner, j_outer, j_inner = stages[C].tile(0, 1, bn, bn)
k_outer, k_inner = stages[C].split("k1", 4)
stages[C].reorder([i_outer, j_outer, k_outer, i_inner, k_inner, j_inner])
stages[C].vectorize(j_inner, 8)

# We make the packedB as another input of the generated function and allocate a runtime buffer for it.
packedB_buf = runtime.cinn_buffer_t(
    np.zeros([n.int() // bn, k.int(), bn]).astype("float32"),
    runtime.cinn_x86_device, 32)

# The final performance is
test_performance(
    stages,
    fn_inputs=[A.to_tensor(), B.to_tensor(), C, packedB],
    input_args=[a, b, c, packedB_buf])

Out:

Takes 76.697 ms

Total running time of the script: ( 0 minutes 23.092 seconds)

Gallery generated by Sphinx-Gallery