Note
Click here to download the full example code
JIT in CINN
In this tutorial, we introduce the JIT module, which executes the DSL on x86 CPUs and NVIDIA GPUs.
import cinn
import numpy as np
from cinn import runtime
# sphinx_gallery_thumbnail_path = './paddlepaddle.png'
Declare some variables for later use.
# Problem sizes as CINN expressions. Below, C (m x n) is computed from
# A (m x k) and B (k x n).
m, n = cinn.Expr(64), cinn.Expr(64)
k = cinn.Expr(8)
# bn is declared here but not referenced in the code shown in this example.
bn = cinn.Expr(32)
Declare the computation.
# Inputs: A has shape [m, k] and B has shape [k, n], both float32.
A = cinn.Placeholder("float32", "A", [m, k])
B = cinn.Placeholder("float32", "B", [k, n])

# Reduction variable shared by both operands, sized by k.
kr = cinn.Var(k.as_int32(), "kr")


def _matmul_element(v):
    """Return the expression for output element C[v[0], v[1]].

    The element is the reduce_sum over ``kr`` of A[v[0], kr] * B[kr, v[1]].
    """
    row, col = v[0], v[1]
    product = A(row, kr.expr()) * B(kr.expr(), col)
    return cinn.reduce_sum(product, [kr])


C = cinn.compute([m, n], _matmul_element, "C")
# Build stages for the output tensor C.
stages = cinn.create_stages([C])
# Default-constructed target. NOTE(review): presumably host x86, which
# matches the cinn_x86_device buffers used later; confirm.
target = cinn.Target()

# The module and the lowered function are both named "matmul"; the
# function is looked up by that name after linking.
matmul_name = "matmul"
builder = cinn.Module.Builder(matmul_name, target)
# Lowered function arguments, in order: A, B, C.
lower_args = [A.to_tensor(), B.to_tensor(), C]
func = cinn.lower(matmul_name, stages, lower_args)
builder.add_function(func)
module = builder.build()
Create a JIT engine.
# Create a JIT execution engine and link the built module into it, so that
# its functions can be retrieved by name (see jit.lookup("matmul") below).
jit = cinn.ExecutionEngine()
jit.link(module)
Execute the compiled function
# Allocate float32 buffers with the shapes declared for the placeholders:
# A is [m, k], B is [k, n], and the output C is [m, n] (zero-initialized).
# B must be allocated as (k, n). An (m, k) buffer would hold the same number
# of elements here (64*8 == 8*64) but with the wrong layout.
a = runtime.cinn_buffer_t(
    np.random.randn(m.int(), k.int()).astype("float32"),
    runtime.cinn_x86_device)
b = runtime.cinn_buffer_t(
    np.random.randn(k.int(), n.int()).astype("float32"),
    runtime.cinn_x86_device)
c = runtime.cinn_buffer_t(
    np.zeros([m.int(), n.int()]).astype("float32"), runtime.cinn_x86_device)
# Wrap each buffer as a POD argument, in the same order as the lowered
# function's argument list [A, B, C].
args = [runtime.cinn_pod_value_t(buf) for buf in [a, b, c]]
matmul = jit.lookup("matmul")
matmul(args)
# The kernel writes its result into c.
print(c.numpy())
Out:
[[ 2.171584 -2.2534988 0.52248585 ... 0.27057254 0.5578289
-5.015565 ]
[ 1.4556594 2.7231464 5.9027996 ... -3.8217752 4.089931
-7.0368733 ]
[-2.9768639 1.3423762 1.3999097 ... -2.1418748 -0.10759955
1.8420299 ]
...
[-4.676755 -0.3913666 -2.065456 ... 0.02816737 -4.3039927
0.8238183 ]
[ 4.750404 0.09432149 -2.3829799 ... 0.89701056 1.0036175
-2.9149737 ]
[ 1.7998325 4.3764253 -0.98203987 ... 0.21227497 2.9889355
-2.214414 ]]
Total running time of the script: ( 0 minutes 0.500 seconds)