.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_tutorials_matmul.py>` to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_matmul.py:
Ways to optimize Matrix Multiplication on CPU
=============================================

In this tutorial, we introduce several ways to optimize the performance of matrix multiplication on x86 CPUs.
.. code-block:: python

    import time

    import numpy as np

    import cinn
    from cinn import runtime

    # sphinx_gallery_thumbnail_path = './paddlepaddle.png'
Declare the basic computation for a matmul: ``C[i, j]`` is the sum over the reduce axis ``k1`` of ``A[i, k1] * B[k1, j]``.
.. code-block:: python

    m = cinn.Expr(1024)
    n = cinn.Expr(1024)
    k = cinn.Expr(1024)

    A = cinn.Placeholder("float32", "A", [m, k])
    B = cinn.Placeholder("float32", "B", [k, n])

    # k1 is a reduce axis
    k1 = cinn.Var(k.as_int32(), "k1")
    C = cinn.compute([
        m, n
    ], lambda vs: cinn.reduce_sum(A(vs[0], k1.expr()) * B(k1.expr(), vs[1]), [k1]),
                     "C")

    stages = cinn.create_stages([C])
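The declared computation is only a description; no loop order or layout is fixed yet. Its semantics can be sketched in plain NumPy on a small hypothetical size (the tutorial itself uses 1024):

```python
import numpy as np

# Small hypothetical size for illustration; the tutorial uses 1024.
m, n, k = 4, 4, 4
A = np.random.randn(m, k).astype("float32")
B = np.random.randn(k, n).astype("float32")

# C[i, j] = sum over the reduce axis k1 of A[i, k1] * B[k1, j]
C = np.zeros((m, n), dtype="float32")
for i in range(m):
    for j in range(n):
        for k1 in range(k):
            C[i, j] += A[i, k1] * B[k1, j]

assert np.allclose(C, A @ B, atol=1e-5)
```

Every schedule applied below rewrites this loop nest without changing what it computes.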
Prepare fake input data: we create a runtime buffer for each input of the generated function.
.. code-block:: python

    a = runtime.cinn_buffer_t(
        np.random.randn(m.int(), k.int()).astype("float32"),
        runtime.cinn_x86_device, 32)
    b = runtime.cinn_buffer_t(
        np.random.randn(k.int(), n.int()).astype("float32"),
        runtime.cinn_x86_device, 32)
    c = runtime.cinn_buffer_t(
        np.zeros([m.int(), n.int()]).astype("float32"), runtime.cinn_x86_device,
        32)
Here is a helper function that JIT-compiles the generated program and measures its performance.
.. code-block:: python

    def test_performance(stages,
                         fn_inputs=[A.to_tensor(), B.to_tensor(), C],
                         input_args=[a, b, c]):
        '''
        Compile the program with the given schedule and measure its performance.
        '''
        target = cinn.Target()
        builder = cinn.Module.Builder("matmul", target)

        func = cinn.lower("matmul", stages, fn_inputs)
        builder.add_function(func)
        module = builder.build()

        jit = cinn.ExecutionEngine()
        jit.link(module)

        args = [runtime.cinn_pod_value_t(_) for _ in input_args]
        matmul = jit.lookup("matmul")

        repeat = 4
        tic = time.perf_counter()
        for i in range(repeat):
            matmul(args)
        toc = time.perf_counter()

        milliseconds = (toc - tic) / repeat * 1e3
        print(f"Takes {milliseconds:0.3f} ms")


    # The basic computation without any schedule has the following performance
    test_performance(stages)
.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Takes 3564.947 ms
Blocking
---------------------

Blocking splits the loops into tiles so that a block of ``C``, together with the corresponding slices of ``A`` and ``B``, stays in cache while the reduction over ``k1`` runs.
.. code-block:: python

    stages = cinn.create_stages([C])

    bn = 32
    i_outer, i_inner, j_outer, j_inner = stages[C].tile(0, 1, bn, bn)
    k_outer, k_inner = stages[C].split("k1", 4)
    stages[C].reorder([i_outer, j_outer, k_outer, k_inner, i_inner, j_inner])

    # The performance is
    test_performance(stages)
.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Takes 365.406 ms
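The loop nest this schedule produces can be sketched in plain Python with NumPy, using smaller hypothetical sizes for illustration. The six loops mirror the reordered axes, so each ``bn x bn`` tile of ``C`` is reused across the whole reduction while it is still hot in cache:

```python
import numpy as np

bn = 32
m = n = k = 64  # smaller hypothetical sizes; the tutorial uses 1024
A = np.random.randn(m, k).astype("float32")
B = np.random.randn(k, n).astype("float32")
C = np.zeros((m, n), dtype="float32")

# Loop nest after tile(0, 1, bn, bn), split("k1", 4), and reorder:
# i_outer, j_outer, k_outer, k_inner, i_inner, j_inner
for io in range(0, m, bn):
    for jo in range(0, n, bn):
        for ko in range(0, k, 4):           # split("k1", 4)
            for ki in range(ko, ko + 4):
                for ii in range(io, io + bn):
                    for ji in range(jo, jo + bn):
                        C[ii, ji] += A[ii, ki] * B[ki, ji]

assert np.allclose(C, A @ B, atol=1e-3)
```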
Vectorization
---------------------

Vectorization maps the innermost loop onto SIMD instructions, so several elements of ``C`` are updated per instruction.
.. code-block:: python

    stages = cinn.create_stages([C])

    bn = 32
    i_outer, i_inner, j_outer, j_inner = stages[C].tile(0, 1, bn, bn)
    k_outer, k_inner = stages[C].split("k1", 4)
    stages[C].reorder([i_outer, j_outer, k_outer, k_inner, i_inner, j_inner])
    stages[C].vectorize(j_inner, 8)

    # The performance is
    test_performance(stages)
.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Takes 364.097 ms
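In scalar terms, ``vectorize(j_inner, 8)`` asks the backend to emit 8-wide SIMD instructions for the innermost loop. A NumPy sketch of the transformed nest (again with hypothetical small sizes) replaces the scalar ``j_inner`` loop with 8-element slice updates:

```python
import numpy as np

bn = 32
m = n = k = 64  # hypothetical small sizes for illustration
A = np.random.randn(m, k).astype("float32")
B = np.random.randn(k, n).astype("float32")
C = np.zeros((m, n), dtype="float32")

# Same tiling as before, but j_inner now advances 8 elements at a
# time -- the analogue of vectorize(j_inner, 8) emitting SIMD ops.
for io in range(0, m, bn):
    for jo in range(0, n, bn):
        for ko in range(0, k, 4):
            for ki in range(ko, ko + 4):
                for ii in range(io, io + bn):
                    for ji in range(jo, jo + bn, 8):
                        C[ii, ji:ji + 8] += A[ii, ki] * B[ki, ji:ji + 8]

assert np.allclose(C, A @ B, atol=1e-3)
```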
Loop Permutation
---------------------

Reordering the inner loops changes the memory access pattern: hoisting ``i_inner`` above ``k_inner`` makes the accesses to ``B`` sequential in memory.
.. code-block:: python

    stages = cinn.create_stages([C])

    i_outer, i_inner, j_outer, j_inner = stages[C].tile(0, 1, bn, bn)
    k_outer, k_inner = stages[C].split("k1", 4)
    stages[C].reorder([i_outer, j_outer, k_outer, i_inner, k_inner, j_inner])
    stages[C].vectorize(j_inner, 8)
    stages[C].unroll(5)

    test_performance(stages)
.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Takes 76.672 ms
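The speedup comes from the permuted order: with ``i_inner`` above ``k_inner``, consecutive ``k_inner`` iterations read consecutive rows of ``B``, and each vectorized ``j_inner`` slice walks contiguous memory. Sketched in NumPy (hypothetical small sizes; the ``unroll`` step is left to the compiler and not modeled here):

```python
import numpy as np

bn = 32
m = n = k = 64  # hypothetical small sizes for illustration
A = np.random.randn(m, k).astype("float32")
B = np.random.randn(k, n).astype("float32")
C = np.zeros((m, n), dtype="float32")

# Loop order i_outer, j_outer, k_outer, i_inner, k_inner, j_inner:
# for a fixed row ii, the ki loop streams through rows of B in order.
for io in range(0, m, bn):
    for jo in range(0, n, bn):
        for ko in range(0, k, 4):
            for ii in range(io, io + bn):      # i_inner hoisted above k_inner
                for ki in range(ko, ko + 4):
                    for ji in range(jo, jo + bn, 8):
                        C[ii, ji:ji + 8] += A[ii, ki] * B[ki, ji:ji + 8]

assert np.allclose(C, A @ B, atol=1e-3)
```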
Array Packing
---------------------

Array packing rearranges ``B`` ahead of time into a block-major layout, so the innermost loop reads contiguous memory.
.. code-block:: python

    packedB = cinn.compute(
        [n / bn, k, cinn.Expr(bn)], lambda x: B(x[1], x[0] * bn + x[2]), "packedB")
    C = cinn.compute([m, n], lambda x: cinn.reduce_sum(
        A(x[0], k1.expr()) * packedB(x[1] / bn, k1.expr(), x[1] % bn), [k1]), "C")

    stages = cinn.create_stages([C])
    stages[packedB].vectorize(2, 8)

    i_outer, i_inner, j_outer, j_inner = stages[C].tile(0, 1, bn, bn)
    k_outer, k_inner = stages[C].split("k1", 4)
    stages[C].reorder([i_outer, j_outer, k_outer, i_inner, k_inner, j_inner])
    stages[C].vectorize(j_inner, 8)

    # We make packedB another input of the generated function and allocate a runtime buffer for it.
    packedB_buf = runtime.cinn_buffer_t(
        np.zeros([n.int() // bn, k.int(), bn]).astype("float32"),
        runtime.cinn_x86_device, 32)

    # The final performance is
    test_performance(
        stages,
        fn_inputs=[A.to_tensor(), B.to_tensor(), C, packedB],
        input_args=[a, b, c, packedB_buf])
.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Takes 76.697 ms
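The packed layout itself can be checked in NumPy: ``packedB[x0, x1, x2] = B[x1, x0 * bn + x2]`` turns each ``bn``-wide column block of ``B`` into a contiguous ``(k, bn)`` panel, and an element ``B(k1, j)`` is then read back as ``packedB[j // bn, k1, j % bn]``. A sketch with hypothetical small sizes:

```python
import numpy as np

bn = 32
n = k = 64  # hypothetical small sizes for illustration
B = np.random.randn(k, n).astype("float32")

# packedB[x0, x1, x2] = B[x1, x0 * bn + x2]: each bn-wide column
# block of B becomes a contiguous (k, bn) panel.
packedB = np.empty((n // bn, k, bn), dtype="float32")
for x0 in range(n // bn):
    for x1 in range(k):
        for x2 in range(bn):
            packedB[x0, x1, x2] = B[x1, x0 * bn + x2]

# Reading B(k1, j) through the packed layout matches the original:
for j in (0, 31, 32, 63):
    for k1 in (0, 5, 63):
        assert packedB[j // bn, k1, j % bn] == B[k1, j]
```

Because the inner ``j`` loop now walks the last (contiguous) axis of the panel, the reads vectorize cleanly.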
.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 23.092 seconds)
.. _sphx_glr_download_tutorials_matmul.py:

.. only:: html

 .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

       :download:`Download Python source code: matmul.py <matmul.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

       :download:`Download Jupyter notebook: matmul.ipynb <matmul.ipynb>`
.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_