JIT in CINN

In this tutorial, we will introduce the JIT module that executes the DSL on X86 CPUs and NVIDIA GPUs.

import cinn
import numpy as np
from cinn import runtime
# sphinx_gallery_thumbnail_path = './paddlepaddle.png'

Declare some variables for later use.

m = cinn.Expr(64)
n = cinn.Expr(64)
k = cinn.Expr(8)
bn = cinn.Expr(32)

Declare the computation.

A = cinn.Placeholder("float32", "A", [m, k])
B = cinn.Placeholder("float32", "B", [k, n])

kr = cinn.Var(k.as_int32(), "kr")
C = cinn.compute([
    m, n
], lambda v: cinn.reduce_sum(A(v[0], kr.expr()) * B(kr.expr(), v[1]), [kr]),
                 "C")

stages = cinn.create_stages([C])

target = cinn.Target()
builder = cinn.Module.Builder("matmul", target)

func = cinn.lower("matmul", stages, [A.to_tensor(), B.to_tensor(), C])
builder.add_function(func)
module = builder.build()

Create a JIT engine.

jit = cinn.ExecutionEngine()
jit.link(module)

Execute the compiled function. Note that each buffer's shape must match the corresponding placeholder: `A` is `[m, k]`, `B` is `[k, n]`, and `C` is `[m, n]`.

a = runtime.cinn_buffer_t(
    np.random.randn(m.int(), k.int()).astype("float32"),
    runtime.cinn_x86_device)
b = runtime.cinn_buffer_t(
    np.random.randn(k.int(), n.int()).astype("float32"),
    runtime.cinn_x86_device)
c = runtime.cinn_buffer_t(
    np.zeros([m.int(), n.int()]).astype("float32"), runtime.cinn_x86_device)

args = [runtime.cinn_pod_value_t(_) for _ in [a, b, c]]
matmul = jit.lookup("matmul")
matmul(args)

print(c.numpy())
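Because the kernel computes a standard matmul, the JIT result can be cross-checked against NumPy on the host; in the tutorial this would be `np.allclose(c.numpy(), a.numpy() @ b.numpy(), atol=1e-4)`. The sketch below mimics that check with host-side arrays only (no `cinn` dependency), using the same shapes as above:

```python
import numpy as np

m, n, k = 64, 64, 8
a_np = np.random.randn(m, k).astype("float32")
b_np = np.random.randn(k, n).astype("float32")

# Reference result the compiled "matmul" kernel should reproduce
expected = a_np @ b_np

# Row-by-row recomputation standing in for the kernel output
c_check = np.zeros((m, n), dtype="float32")
for i in range(m):
    c_check[i] = a_np[i] @ b_np

# Float32 accumulation differs slightly between schedules, hence the tolerance
assert np.allclose(c_check, expected, atol=1e-4)
```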

Out:

[[ 2.171584   -2.2534988   0.52248585 ...  0.27057254  0.5578289
  -5.015565  ]
 [ 1.4556594   2.7231464   5.9027996  ... -3.8217752   4.089931
  -7.0368733 ]
 [-2.9768639   1.3423762   1.3999097  ... -2.1418748  -0.10759955
   1.8420299 ]
 ...
 [-4.676755   -0.3913666  -2.065456   ...  0.02816737 -4.3039927
   0.8238183 ]
 [ 4.750404    0.09432149 -2.3829799  ...  0.89701056  1.0036175
  -2.9149737 ]
 [ 1.7998325   4.3764253  -0.98203987 ...  0.21227497  2.9889355
  -2.214414  ]]

Total running time of the script: ( 0 minutes 0.500 seconds)
