.. rst-class:: sphx-glr-example-title
.. _sphx_glr_tutorials_jit.py:
JIT in CINN
=====================
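In this tutorial, we introduce CINN's JIT module, which compiles and executes DSL programs on X86 CPUs and NVIDIA GPUs. The example below builds a small matrix multiplication and runs it on X86.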
.. code-block:: python

    import cinn
    import numpy as np
    from cinn import runtime
    # sphinx_gallery_thumbnail_path = './paddlepaddle.png'

Declare some variables for later use. ``m`` and ``n`` are the output dimensions and ``k`` is the length of the reduction.

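.. code-block:: python

    m = cinn.Expr(64)   # rows of A and C
    n = cinn.Expr(64)   # columns of B and C
    k = cinn.Expr(8)    # reduction length: columns of A, rows of B
    bn = cinn.Expr(32)  # declared here but not used below
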
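As a quick sanity check on these sizes, the plain-Python snippet below (no CINN calls; the integers simply mirror the ``cinn.Expr`` values above) derives the tensor shapes used in the next step and the number of multiply-adds the matrix multiplication performs.

```python
# Plain integers mirroring the cinn.Expr values declared above.
m, n, k = 64, 64, 8

# Shapes used in the computation: A is m x k, B is k x n, C is m x n.
shape_a = (m, k)
shape_b = (k, n)
shape_c = (m, n)

# Each of the m * n output elements is a dot product of length k.
multiply_adds = m * n * k

print(shape_a, shape_b, shape_c, multiply_adds)  # (64, 8) (8, 64) (64, 64) 32768
```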
Declare the computation
-----------------------

.. code-block:: python

    # Input placeholders: A is m x k, B is k x n.
    A = cinn.Placeholder("float32", "A", [m, k])
    B = cinn.Placeholder("float32", "B", [k, n])
    # Reduction axis of length k.
    kr = cinn.Var(k.as_int32(), "kr")
    # C(i, j) = sum over kr of A(i, kr) * B(kr, j).
    C = cinn.compute([m, n],
                     lambda v: cinn.reduce_sum(
                         A(v[0], kr.expr()) * B(kr.expr(), v[1]), [kr]),
                     "C")

    # Lower the computation into a function named "matmul" and build a module.
    stages = cinn.create_stages([C])
    target = cinn.Target()
    builder = cinn.Module.Builder("matmul", target)
    func = cinn.lower("matmul", stages, [A.to_tensor(), B.to_tensor(), C])
    builder.add_function(func)
    module = builder.build()

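The ``cinn.compute`` call above defines every output element as a reduction over ``kr``: C(i, j) is the sum over ``kr`` of A(i, kr) * B(kr, j). The sketch below spells out those semantics with plain Python loops. It is only a reference for what the lowered ``matmul`` function is expected to compute, not how CINN generates code; ``reference_matmul`` is a hypothetical helper.

```python
def reference_matmul(A, B):
    """Naive matmul mirroring C(i, j) = reduce_sum(A(i, kr) * B(kr, j), [kr])."""
    m, k = len(A), len(A[0])
    k_b, n = len(B), len(B[0])
    assert k == k_b, "inner dimensions must match"
    C = [[0.0] * n for _ in range(m)]
    for i in range(m):           # v[0] in the lambda
        for j in range(n):       # v[1] in the lambda
            acc = 0.0
            for kr in range(k):  # the reduction axis kr
                acc += A[i][kr] * B[kr][j]
            C[i][j] = acc
    return C


A = [[1.0, 2.0], [3.0, 4.0]]
B = [[5.0, 6.0], [7.0, 8.0]]
print(reference_matmul(A, B))  # [[19.0, 22.0], [43.0, 50.0]]
```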
Create a JIT engine
-------------------

Linking the module into the engine makes the functions it contains available to ``lookup``.

.. code-block:: python

    jit = cinn.ExecutionEngine()
    jit.link(module)

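Conceptually, ``link`` compiles a module and registers its functions by name, and ``lookup`` returns a callable for one of them. The toy class below illustrates that link-then-lookup pattern in pure Python by compiling source text at run time with the built-in ``compile`` and ``exec``. ``ToyEngine`` and its ``scale`` function are hypothetical; this is an analogy only, not how CINN's ``ExecutionEngine`` works internally.

```python
class ToyEngine:
    """Hypothetical stand-in for a JIT engine: compile at run time, then look up by name."""

    def __init__(self):
        self._symbols = {}  # name -> compiled callable

    def link(self, source):
        # Compile the source text now (at run time) and record every
        # function it defines in the symbol table.
        code = compile(source, "<module>", "exec")
        namespace = {}
        exec(code, namespace)
        for name, obj in namespace.items():
            if callable(obj) and not name.startswith("__"):
                self._symbols[name] = obj

    def lookup(self, name):
        # Return the compiled function registered under `name`.
        return self._symbols[name]


engine = ToyEngine()
# The "module": one function that writes its result into the last argument.
engine.link("def scale(args):\n    args[1][0] = 2 * args[0][0]\n")
scale = engine.lookup("scale")
out = [0]
scale([[21], out])
print(out)  # [42]
```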
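Execute the compiled function. The inputs and the zero-initialized output are wrapped in ``cinn_buffer_t`` objects on the X86 device and passed as a list of ``cinn_pod_value_t`` arguments; the result is written into ``c``.

.. code-block:: python

    # Input A: m x k.
    a = runtime.cinn_buffer_t(
        np.random.randn(m.int(), k.int()).astype("float32"),
        runtime.cinn_x86_device)
    # Input B: k x n, matching the shape of placeholder B.
    b = runtime.cinn_buffer_t(
        np.random.randn(k.int(), n.int()).astype("float32"),
        runtime.cinn_x86_device)
    # Output C: m x n, zero-initialized and filled in by the call.
    c = runtime.cinn_buffer_t(
        np.zeros([m.int(), n.int()]).astype("float32"), runtime.cinn_x86_device)

    # Pack the buffers as runtime arguments, in the order given to cinn.lower.
    args = [runtime.cinn_pod_value_t(_) for _ in [a, b, c]]
    matmul = jit.lookup("matmul")
    matmul(args)
    print(c.numpy())
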
.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    [[ 2.171584 -2.2534988 0.52248585 ... 0.27057254 0.5578289
      -5.015565 ]
     [ 1.4556594 2.7231464 5.9027996 ... -3.8217752 4.089931
      -7.0368733 ]
     [-2.9768639 1.3423762 1.3999097 ... -2.1418748 -0.10759955
       1.8420299 ]
     ...
     [-4.676755 -0.3913666 -2.065456 ... 0.02816737 -4.3039927
       0.8238183 ]
     [ 4.750404 0.09432149 -2.3829799 ... 0.89701056 1.0036175
      -2.9149737 ]
     [ 1.7998325 4.3764253 -0.98203987 ... 0.21227497 2.9889355
      -2.214414 ]]

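The call above passes the inputs and the pre-allocated output together in one argument list and reads the result back from ``c`` afterwards. The pure-Python sketch below mimics that convention with flat, row-major lists as a hypothetical stand-in for ``cinn_buffer_t``; ``matmul_flat`` is not part of CINN. Because the buffers in this sketch carry no shape, an input allocated as m x k instead of k x n would have the same 512 elements and would not be rejected; only the results would be wrong, which is why the shape of ``b`` must match placeholder ``B``.

```python
def matmul_flat(args, m, n, k):
    """Hypothetical stand-in for the JIT-compiled matmul.

    args is [a, b, c]: flat, row-major buffers for A (m x k), B (k x n)
    and the pre-allocated output C (m x n), which is overwritten in place.
    """
    a, b, c = args
    for i in range(m):
        for j in range(n):
            c[i * n + j] = sum(a[i * k + r] * b[r * n + j] for r in range(k))


m, n, k = 64, 64, 8
a = [1.0] * (m * k)  # A: m x k, all ones
b = [2.0] * (k * n)  # B: k x n, all twos
c = [0.0] * (m * n)  # C: m x n, zero-initialized output

# Inputs and output travel together in one argument list, output last.
matmul_flat([a, b, c], m, n, k)

# Each element is a length-k dot product: 8 * (1.0 * 2.0) = 16.0.
print(c[0], all(x == 16.0 for x in c))  # 16.0 True
```

Known inputs like these make the expected result easy to state; with the real JIT-compiled function, a comparable check would compare ``c.numpy()`` against a NumPy matrix product of the same inputs.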
.. rst-class:: sphx-glr-timing
**Total running time of the script:** ( 0 minutes 0.500 seconds)