.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_tutorials_net_builder.py>` to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_net_builder.py:

Run a model by using the NetBuilder API
=======================================

In this tutorial, we introduce how to build and run a model using the NetBuilder API.

.. code-block:: python

    import cinn
    from cinn import frontend
    from cinn import common
    import numpy as np

    # sphinx_gallery_thumbnail_path = './paddlepaddle.png'

Define the NetBuilder
---------------------

NetBuilder is a convenient way to build a model in CINN. You can build and run a model by invoking NetBuilder's APIs as follows.

:code:`name`: the ID of the NetBuilder.

Generally, each API in `NetBuilder` corresponds to a coarse-grained operator, i.e., an operator at the granularity of a DL framework such as Paddle.

.. code-block:: python

    builder = frontend.NetBuilder(name="batchnorm_conv2d")

Define the input variables of the model
---------------------------------------

Input variables should be created with the `create_input` API. Note that a variable here is just a placeholder; it does not need to hold actual data.

:code:`type`: the data type of the input variable. `Void`, `Int`, `UInt`, `Float`, `Bool`, and `String` are currently supported; the parameter is the type's bit width. Here the data type is `float32`.

:code:`shape`: the shape of the input variable. Note that dynamic shapes are not supported yet, so every dimension value must be greater than 0.

:code:`id_hint`: the name of the variable; the default value is `""`.
.. code-block:: python

    a = builder.create_input(
        type=common.Float(32), shape=(8, 3, 224, 224), id_hint="x")
    scale = builder.create_input(type=common.Float(32), shape=[3], id_hint="scale")
    bias = builder.create_input(type=common.Float(32), shape=[3], id_hint="bias")
    mean = builder.create_input(type=common.Float(32), shape=[3], id_hint="mean")
    variance = builder.create_input(
        type=common.Float(32), shape=[3], id_hint="variance")
    weight = builder.create_input(
        type=common.Float(32), shape=(3, 3, 7, 7), id_hint="weight")

Build the model by using the NetBuilder API
-------------------------------------------

For convenience, we build a simple model that consists of only batchnorm and conv2d operators. A detailed introduction to each operator can be found in a separate document, so we do not go into detail here.

.. code-block:: python

    y = builder.batchnorm(a, scale, bias, mean, variance, is_test=True)
    res = builder.conv2d(y[0], weight)

Set target
----------

The target specifies where the model should run. Two targets are currently supported:

:code:`DefaultHostTarget`: the model runs on the CPU.

:code:`DefaultNVGPUTarget`: the model runs on an NVIDIA GPU.

.. code-block:: python

    if common.is_compiled_with_cuda():
        target = common.DefaultNVGPUTarget()
    else:
        target = common.DefaultHostTarget()

    print("Model running at ", target.arch)

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Model running at  Arch.X86

Generate the program
--------------------

After the model is built, `Computation` generates a CINN execution program, which you can obtain as follows:

.. code-block:: python

    computation = frontend.Computation.build_and_compile(target, builder)

Random fake input data
----------------------

Before running, you should read or generate some data to feed to the model's inputs.

:code:`get_tensor`: get the tensor with a specific name from the computation.

:code:`from_numpy`: fill the tensor with numpy data.
.. code-block:: python

    tensor_data = [
        np.random.random([8, 3, 224, 224]).astype("float32"),  # a
        np.random.random([3]).astype("float32"),               # scale
        np.random.random([3]).astype("float32"),               # bias
        np.random.random([3]).astype("float32"),               # mean
        np.random.random([3]).astype("float32"),               # variance
        np.random.random([3, 3, 7, 7]).astype("float32")       # weight
    ]

    computation.get_tensor("x").from_numpy(tensor_data[0], target)
    computation.get_tensor("scale").from_numpy(tensor_data[1], target)
    computation.get_tensor("bias").from_numpy(tensor_data[2], target)
    computation.get_tensor("mean").from_numpy(tensor_data[3], target)
    computation.get_tensor("variance").from_numpy(tensor_data[4], target)
    computation.get_tensor("weight").from_numpy(tensor_data[5], target)

Run program and print result
----------------------------

Finally, you can run the model by invoking `execute()`. After that, you can get any tensor you want via `get_tensor` with the tensor's name.

.. code-block:: python

    computation.execute()

    res_tensor = computation.get_tensor(str(res))
    res_data = res_tensor.numpy(target)

    # print result
    print(res_data)

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    [[[[41.300034 41.32889  40.87694  ... 41.017094 40.977764 41.43262 ]
       [40.66744  40.374546 41.068977 ... 41.64359  40.59516  39.62327 ]
       [41.18457  40.07553  40.306675 ... 42.416054 41.4802   40.468147]
       ...
       [42.274876 41.293266 41.46016  ... 41.1733   41.638245 41.78466 ]
       [41.524933 42.01216  40.807613 ... 41.243847 42.163734 41.821606]
       [41.252533 42.235886 41.395493 ... 42.59662  42.104828 41.440254]]
      ...]]]]
.. rst-class:: sphx-glr-timing

**Total running time of the script:** ( 0 minutes 2.201 seconds)

.. _sphx_glr_download_tutorials_net_builder.py:

.. only:: html

    .. container:: sphx-glr-footer
        :class: sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

        :download:`Download Python source code: net_builder.py <net_builder.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

        :download:`Download Jupyter notebook: net_builder.ipynb <net_builder.ipynb>`

.. only:: html

    .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
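As a side note, the inference-mode batchnorm used in this tutorial (``is_test=True``) follows the standard formula ``y = scale * (x - mean) / sqrt(variance + eps) + bias``, applied per channel. The NumPy sketch below is only an illustrative reference for what that step computes; it is not CINN's implementation, and the ``eps`` value is an assumed typical default.

```python
import numpy as np

def batchnorm_infer(x, scale, bias, mean, variance, eps=1e-5):
    """Inference-mode batch normalization for NCHW data (reference sketch)."""
    # Reshape the per-channel parameters so they broadcast over (N, C, H, W).
    c = (1, -1, 1, 1)
    x_hat = (x - mean.reshape(c)) / np.sqrt(variance.reshape(c) + eps)
    return scale.reshape(c) * x_hat + bias.reshape(c)

# Same fake input shapes as in the tutorial above.
x = np.random.random([8, 3, 224, 224]).astype("float32")
scale = np.random.random([3]).astype("float32")
bias = np.random.random([3]).astype("float32")
mean = np.random.random([3]).astype("float32")
variance = np.random.random([3]).astype("float32")

y = batchnorm_infer(x, scale, bias, mean, variance)
print(y.shape)  # (8, 3, 224, 224)
```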