.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here ` to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_jit.py:


JIT in CINN
=====================

In this tutorial, we introduce the JIT module, which executes programs written in the CINN DSL on X86 CPUs and NVIDIA GPUs. The example below runs on X86.


.. code-block:: python


    import cinn
    import numpy as np
    from cinn import runtime
    # sphinx_gallery_thumbnail_path = './paddlepaddle.png'


Declare some variables for later use.


.. code-block:: python


    m = cinn.Expr(64)
    n = cinn.Expr(64)
    k = cinn.Expr(8)
    bn = cinn.Expr(32)


Declare the computation
-------------------------


.. code-block:: python


    A = cinn.Placeholder("float32", "A", [m, k])
    B = cinn.Placeholder("float32", "B", [k, n])

    # Reduction axis of extent k.
    kr = cinn.Var(k.as_int32(), "kr")
    # C[i, j] = sum over kr of A[i, kr] * B[kr, j]
    C = cinn.compute([m, n],
                     lambda v: cinn.reduce_sum(
                         A(v[0], kr.expr()) * B(kr.expr(), v[1]), [kr]),
                     "C")

    stages = cinn.create_stages([C])

    target = cinn.Target()
    builder = cinn.Module.Builder("matmul", target)

    # Lower the computation to a function named "matmul" and build the module.
    func = cinn.lower("matmul", stages, [A.to_tensor(), B.to_tensor(), C])
    builder.add_function(func)
    module = builder.build()


Create a JIT engine
---------------------


.. code-block:: python


    jit = cinn.ExecutionEngine()
    jit.link(module)


Execute the compiled function.


.. code-block:: python


    # Buffers on the X86 device: A is [m, k], B is [k, n], C is [m, n].
    a = runtime.cinn_buffer_t(
        np.random.randn(m.int(), k.int()).astype("float32"),
        runtime.cinn_x86_device)
    b = runtime.cinn_buffer_t(
        np.random.randn(k.int(), n.int()).astype("float32"),
        runtime.cinn_x86_device)
    c = runtime.cinn_buffer_t(
        np.zeros([m.int(), n.int()]).astype("float32"),
        runtime.cinn_x86_device)

    # Wrap the buffers as arguments, look up the compiled function by name and call it.
    args = [runtime.cinn_pod_value_t(buf) for buf in [a, b, c]]
    matmul = jit.lookup("matmul")
    matmul(args)

    print(c.numpy())


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    [[ 2.171584 -2.2534988 0.52248585 ... 0.27057254 0.5578289 -5.015565 ]
     [ 1.4556594 2.7231464 5.9027996 ... -3.8217752 4.089931 -7.0368733 ]
     [-2.9768639 1.3423762 1.3999097 ... -2.1418748 -0.10759955 1.8420299 ]
     ...
     [-4.676755 -0.3913666 -2.065456 ... 0.02816737 -4.3039927 0.8238183 ]
     [ 4.750404 0.09432149 -2.3829799 ... 0.89701056 1.0036175 -2.9149737 ]
     [ 1.7998325 4.3764253 -0.98203987 ... 0.21227497 2.9889355 -2.214414 ]]


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 0.500 seconds)


.. _sphx_glr_download_tutorials_jit.py:


.. only :: html

 .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example



  .. container:: sphx-glr-download sphx-glr-download-python

     :download:`Download Python source code: jit.py `



  .. container:: sphx-glr-download sphx-glr-download-jupyter

     :download:`Download Jupyter notebook: jit.ipynb `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_
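
Reference semantics of the computation
----------------------------------------

The computation declared with ``cinn.compute`` and ``cinn.reduce_sum`` in this tutorial is a matrix multiplication: each output element ``C[i, j]`` is the sum of ``A[i, kr] * B[kr, j]`` over the reduction axis ``kr``. The sketch below spells out that semantics in pure Python, which may be handy as a slow reference when checking a kernel's output. It does not use CINN, and the helper name ``reference_matmul`` is hypothetical.

.. code-block:: python

    # Pure-Python reference for the reduction declared in this tutorial:
    #     C[i, j] = sum over kr of A[i, kr] * B[kr, j]
    # reference_matmul is a hypothetical helper, not part of the CINN API.
    import random


    def reference_matmul(a, b):
        """Multiply two row-major nested lists, reducing over the shared axis kr."""
        rows, depth = len(a), len(a[0])
        depth_b, cols = len(b), len(b[0])
        if depth != depth_b:
            raise ValueError(f"inner dimensions differ: {depth} vs {depth_b}")
        c = [[0.0] * cols for _ in range(rows)]
        for i in range(rows):
            for j in range(cols):
                # The generator below plays the role of cinn.reduce_sum over kr.
                c[i][j] = sum(a[i][kr] * b[kr][j] for kr in range(depth))
        return c


    print(reference_matmul([[1, 2]], [[3], [4]]))  # [[11]]

    # Same shapes as the tutorial: A is [64, 8], B is [8, 64], C is [64, 64].
    M, N, K = 64, 64, 8
    a_ref = [[random.gauss(0.0, 1.0) for _ in range(K)] for _ in range(M)]
    b_ref = [[random.gauss(0.0, 1.0) for _ in range(N)] for _ in range(K)]
    c_ref = reference_matmul(a_ref, b_ref)
    print(len(c_ref), len(c_ref[0]))  # 64 64

Comparing this reference against ``c.numpy()`` would require keeping the NumPy input arrays in variables (converted with ``.tolist()``) instead of creating them inline, and allowing a small tolerance, since the compiled kernel computes in ``float32``.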