Run a model by using the NetBuilder API

In this tutorial, we introduce how to build and run a model using the NetBuilder API.

import cinn
from cinn import frontend
from cinn import common
import numpy as np

Define the NetBuilder.

Using NetBuilder is a convenient way to build a model in CINN. You can build and run a model by invoking NetBuilder’s API as follows.

name: the ID of NetBuilder

Generally, the NetBuilder APIs are coarse-grained operators. In other words, they correspond to the operators of a deep learning framework such as Paddle.

builder = frontend.NetBuilder(name="batchnorm_conv2d")

Define the input variable of the model.

Input variables are created with the create_input API. Each variable here is only a placeholder and does not need actual data.

type: the data type of the input variable. Void, Int, UInt, Float, Bool and String are currently supported, and the parameter is the type’s bit-width. Here the data type is float32.

shape: the shape of the input variable. Dynamic shapes are not supported yet, so every dimension must be greater than 0.

id_hint: the name of the variable. The default value is “”.

a = builder.create_input(
    type=common.Float(32), shape=(8, 3, 224, 224), id_hint="x")
scale = builder.create_input(type=common.Float(32), shape=[3], id_hint="scale")
bias = builder.create_input(type=common.Float(32), shape=[3], id_hint="bias")
mean = builder.create_input(type=common.Float(32), shape=[3], id_hint="mean")
variance = builder.create_input(
    type=common.Float(32), shape=[3], id_hint="variance")
weight = builder.create_input(
    type=common.Float(32), shape=(3, 3, 7, 7), id_hint="weight")
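The host data you feed later has to respect the same `type` and `shape` rules. The sketch below shows one way to do that. `placeholder_like` is a hypothetical helper and not part of CINN. It assumes that only float inputs are needed and that the bit-width, such as the `32` in `common.Float(32)`, maps directly to a NumPy dtype. It also rejects non-positive dimensions, because dynamic shapes are not supported.

```python
import numpy as np

# Hypothetical mapping (not a CINN API) from float bit-width to NumPy dtype.
_FLOAT_DTYPES = {16: np.float16, 32: np.float32, 64: np.float64}


def placeholder_like(bits, shape):
    """Return zero-filled host data matching a Float(bits) input of `shape`."""
    if bits not in _FLOAT_DTYPES:
        raise ValueError(f"unsupported float bit-width: {bits}")
    # Dynamic shapes are not supported, so every dimension must be > 0.
    if any(int(d) <= 0 for d in shape):
        raise ValueError(f"all dimensions must be positive, got {tuple(shape)}")
    return np.zeros(tuple(shape), dtype=_FLOAT_DTYPES[bits])


x_data = placeholder_like(32, (8, 3, 224, 224))
print(x_data.dtype, x_data.shape)  # float32 (8, 3, 224, 224)
```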

Build the model by using NetBuilder API

For convenience, we build a simple model that consists of only a batchnorm operator and a conv2d operator. Each operator is described in detail in other documentation, so we won’t go into detail here.

y = builder.batchnorm(a, scale, bias, mean, variance, is_test=True)
res = builder.conv2d(y[0], weight)
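For reference, the following are the standard definitions of what these two operators compute. In inference mode (`is_test=True`), batchnorm normalizes each channel $c$ with the supplied statistics, and conv2d then slides the 7×7 kernel over the result as a cross-correlation. The formulas assume CINN's batchnorm default of $\epsilon = 10^{-5}$ and conv2d defaults of stride $s=1$, padding $p=0$ and dilation $d=1$. Check the operator documentation if your version differs.

```latex
y_{n,c,h,w} = \mathrm{scale}_c \cdot \frac{x_{n,c,h,w} - \mathrm{mean}_c}{\sqrt{\mathrm{variance}_c + \epsilon}} + \mathrm{bias}_c

\mathrm{res}_{n,o,i,j} = \sum_{c=0}^{2} \sum_{k=0}^{6} \sum_{l=0}^{6} \mathrm{weight}_{o,c,k,l} \, y_{n,c,\,i+k,\,j+l}

H_{\mathrm{out}} = \frac{H + 2p - d\,(K - 1) - 1}{s} + 1 = \frac{224 + 0 - 6 - 1}{1} + 1 = 218
```

Under these assumptions, `res` has shape (8, 3, 218, 218).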

Set target

The target specifies where the model runs. Two targets are currently supported:

DefaultHostTarget: the model runs on the CPU.

DefaultNVGPUTarget: the model runs on an NVIDIA GPU.

if common.is_compiled_with_cuda():
    target = common.DefaultNVGPUTarget()
else:
    target = common.DefaultHostTarget()

print("Model running at ", target.arch)

Out:

Model running at  Arch.X86

Generate the program

After the model is built, Computation compiles it into a CINN execution program, which you can obtain as follows:

computation = frontend.Computation.build_and_compile(target, builder)

Generate random input data

Before running, you need to read or generate some data to feed the model’s inputs. Two APIs are used for this:

get_tensor: get the tensor with the specified name in the computation.

from_numpy: fill the tensor with NumPy data.

tensor_data = [
    np.random.random([8, 3, 224, 224]).astype("float32"),  # a
    np.random.random([3]).astype("float32"),  # scale
    np.random.random([3]).astype("float32"),  # bias
    np.random.random([3]).astype("float32"),  # mean
    np.random.random([3]).astype("float32"),  # variance
    np.random.random([3, 3, 7, 7]).astype("float32")  # weight
]

computation.get_tensor("x").from_numpy(tensor_data[0], target)
computation.get_tensor("scale").from_numpy(tensor_data[1], target)
computation.get_tensor("bias").from_numpy(tensor_data[2], target)
computation.get_tensor("mean").from_numpy(tensor_data[3], target)
computation.get_tensor("variance").from_numpy(tensor_data[4], target)
computation.get_tensor("weight").from_numpy(tensor_data[5], target)
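All six `get_tensor(...).from_numpy(...)` calls above follow the same pattern, so you can also drive them from a dictionary keyed by the `id_hint` names. The sketch below shows one way to do that. `make_random_inputs` and `feed_inputs` are hypothetical helpers, and they rely only on the `get_tensor` and `from_numpy` calls already used in this tutorial. The seeded NumPy generator is an addition that makes runs reproducible.

```python
import numpy as np

# Input shapes keyed by the id_hint given to create_input.
INPUT_SHAPES = {
    "x": (8, 3, 224, 224),
    "scale": (3,),
    "bias": (3,),
    "mean": (3,),
    "variance": (3,),
    "weight": (3, 3, 7, 7),
}


def make_random_inputs(seed=0):
    """Build reproducible random float32 data for each named input (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    return {name: rng.random(shape, dtype=np.float32) for name, shape in INPUT_SHAPES.items()}


def feed_inputs(computation, target, inputs):
    """Copy each NumPy array into the computation tensor with the same name."""
    for name, data in inputs.items():
        computation.get_tensor(name).from_numpy(data, target)
```

With the tutorial’s objects in scope, the six feeding lines become `feed_inputs(computation, target, make_random_inputs(seed=0))`.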

Run program and print result

Finally, run the model by calling execute(). After that, you can fetch any tensor by passing its name to get_tensor and convert it to a NumPy array with numpy(target).

computation.execute()
res_tensor = computation.get_tensor(str(res))
res_data = res_tensor.numpy(target)

# print result
print(res_data)
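The inputs are random, so the printed values change from run to run. A more useful check is to compare `res_data` against an independent reference. The sketch below computes batchnorm in inference mode followed by conv2d in NumPy. `reference_batchnorm_conv2d` is a hypothetical helper. The sketch assumes CINN’s defaults of `epsilon=1e-5`, stride 1, no padding, dilation 1 and NCHW layout, and `sliding_window_view` requires NumPy 1.20 or later.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def reference_batchnorm_conv2d(x, scale, bias, mean, variance, weight, epsilon=1e-5):
    """NumPy reference: batchnorm (is_test=True) then a stride-1, unpadded NCHW conv2d."""

    def per_channel(v):
        # Reshape a (C,) parameter to (1, C, 1, 1) so it broadcasts over NCHW data.
        return v.reshape(1, -1, 1, 1)

    inv_std = 1.0 / np.sqrt(per_channel(variance) + epsilon)
    y = per_channel(scale) * (x - per_channel(mean)) * inv_std + per_channel(bias)

    kh, kw = weight.shape[2:]
    # windows has shape (N, C, H - kh + 1, W - kw + 1, kh, kw); no data is copied.
    windows = sliding_window_view(y, (kh, kw), axis=(2, 3))
    # Cross-correlation (no kernel flip): sum over input channels and the kernel window.
    return np.einsum("nchwkl,ockl->nohw", windows, weight)
```

With the tutorial’s variables in scope, you could run `np.testing.assert_allclose(res_data, reference_batchnorm_conv2d(*tensor_data), rtol=1e-4, atol=1e-3)`. The tolerances are a guess for float32 accumulation and may need adjusting.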

Out:

[[[[41.300034 41.32889  40.87694  ... 41.017094 40.977764 41.43262 ]
   [40.66744  40.374546 41.068977 ... 41.64359  40.59516  39.62327 ]
   [41.18457  40.07553  40.306675 ... 42.416054 41.4802   40.468147]
   ...
   [42.274876 41.293266 41.46016  ... 41.1733   41.638245 41.78466 ]
   [41.524933 42.01216  40.807613 ... 41.243847 42.163734 41.821606]
   [41.252533 42.235886 41.395493 ... 42.59662  42.104828 41.440254]]

  [[43.64468  43.11713  43.218945 ... 44.419003 43.60452  43.440025]
   [42.999943 42.594635 43.46457  ... 44.319992 43.11883  42.671257]
   [42.989483 42.126328 43.062424 ... 44.44373  44.230293 43.123837]
   ...
   [45.079826 44.20662  44.09692  ... 44.75235  44.563763 44.032536]
   [44.17878  44.37436  43.422424 ... 43.29678  45.046303 44.439877]
   [43.977108 44.29691  43.525467 ... 44.779503 44.789032 43.94192 ]]

  [[37.902706 37.147522 38.064827 ... 37.633286 37.112053 38.035236]
   [37.937916 37.733925 37.223503 ... 37.785645 37.849224 37.682865]
   [37.63615  36.88459  37.599262 ... 38.829453 37.80579  37.403847]
   ...
   [38.898945 37.51382  37.86469  ... 38.29599  38.365997 37.932575]
   [39.033062 38.40738  37.687614 ... 38.914192 38.727806 38.252518]
   [37.778    38.507298 37.885326 ... 38.903698 38.25751  38.395638]]]


 [[[42.743874 43.090458 42.705624 ... 42.29254  42.29984  42.315804]
   [41.652885 42.61314  43.07195  ... 42.349762 41.72451  41.791553]
   [42.272617 42.885338 42.491203 ... 40.942463 41.132015 41.276176]
   ...
   [41.428234 40.786144 41.092304 ... 42.25208  42.249966 41.074356]
   [41.12266  41.098907 41.339474 ... 42.21494  41.98615  42.16423 ]
   [41.093136 40.736393 40.612026 ... 42.172665 41.888657 40.606525]]

  [[44.885036 45.419273 46.009888 ... 44.758255 44.279404 44.31435 ]
   [44.446453 44.900543 45.541084 ... 44.830772 44.37794  44.25062 ]
   [44.989704 44.842106 45.5936   ... 44.16348  43.43664  43.622528]
   ...
   [43.886513 44.112514 43.552845 ... 44.617413 44.29766  43.1779  ]
   [42.834915 43.53686  43.725327 ... 44.810688 44.49254  43.86616 ]
   [43.13151  42.920265 43.617294 ... 43.78895  44.461014 44.139328]]

  [[39.576283 39.76104  38.940693 ... 39.581345 39.08043  38.17172 ]
   [38.40676  39.181484 39.586117 ... 39.260128 39.443054 38.862144]
   [39.67811  39.61034  39.58296  ... 37.678688 37.63858  37.570877]
   ...
   [38.896503 38.52317  37.95039  ... 38.271748 39.490025 38.525524]
   [37.65303  37.681347 37.494225 ... 38.75704  38.511597 38.41539 ]
   [37.697704 38.209988 37.305393 ... 38.317337 38.559345 37.90143 ]]]


 [[[41.32929  41.74829  41.50349  ... 41.72447  42.585907 42.441044]
   [40.862965 40.75343  41.075466 ... 42.22953  42.453655 42.17996 ]
   [40.458736 40.56273  40.036434 ... 40.891075 42.229977 41.814816]
   ...
   [41.324177 41.208374 40.753242 ... 40.90548  40.210472 40.166634]
   [41.38516  41.701263 41.078594 ... 41.429382 40.10512  40.95803 ]
   [41.10963  41.760563 41.107037 ... 41.583523 41.279423 40.511284]]

  [[42.697742 43.3048   43.416958 ... 44.315994 44.881016 44.524723]
   [43.65681  43.609573 43.47862  ... 44.12984  45.195675 44.301723]
   [42.66219  42.837914 42.660976 ... 43.12605  44.239704 44.185734]
   ...
   [43.35991  44.273983 43.207153 ... 43.064762 42.679085 42.847244]
   [43.215584 44.4136   43.950047 ... 44.113617 43.480686 43.274685]
   [43.044224 44.134    43.45668  ... 44.51243  43.642487 44.029   ]]

  [[37.82997  38.30262  37.787506 ... 38.73762  39.145596 39.214157]
   [37.7688   38.26052  38.676003 ... 39.158955 38.59617  38.96273 ]
   [37.560474 37.328846 37.617752 ... 37.94291  38.034767 38.219986]
   ...
   [38.091568 37.453407 37.60966  ... 37.83764  37.110413 38.025   ]
   [37.61462  37.436928 37.175743 ... 38.37969  37.9572   37.72554 ]
   [37.817383 38.15954  37.50487  ... 37.808636 37.927593 38.064747]]]


 ...


 [[[42.500374 42.63498  43.143276 ... 41.59989  41.403538 41.439278]
   [41.934628 41.667717 41.721813 ... 41.25938  41.717907 41.12168 ]
   [42.086777 41.656315 42.209263 ... 40.612267 41.285275 41.18172 ]
   ...
   [41.601242 40.837345 40.93395  ... 41.57721  40.890324 41.58985 ]
   [41.779762 41.488    41.85838  ... 41.402683 40.187996 41.255142]
   [41.348114 40.954266 40.663338 ... 41.30279  40.796883 41.137814]]

  [[44.657394 45.53907  45.630577 ... 44.103313 44.40315  44.38892 ]
   [44.509567 44.630383 44.980797 ... 43.858788 44.71564  43.71002 ]
   [44.681923 44.48463  44.330616 ... 43.469017 43.979595 43.557213]
   ...
   [43.882473 43.21293  43.204563 ... 44.308258 43.200928 43.94489 ]
   [44.27936  43.209034 43.82105  ... 44.078785 43.275066 43.91738 ]
   [43.55389  43.600197 43.039677 ... 43.96242  43.249626 43.746513]]

  [[39.58318  39.320465 39.307163 ... 38.91711  38.4188   37.95429 ]
   [38.552593 38.997684 38.26397  ... 38.576607 38.561592 38.705746]
   [38.561188 39.047134 39.052013 ... 38.048313 38.5195   37.435566]
   ...
   [38.281456 37.773933 37.552616 ... 38.11439  38.358353 38.466003]
   [38.46599  37.60399  38.231476 ... 38.224827 37.64155  37.612072]
   [38.35948  38.304157 38.339863 ... 38.57184  38.04912  38.03847 ]]]


 [[[41.33109  42.337395 42.226433 ... 41.845016 41.81689  41.838345]
   [42.269596 42.40946  41.885414 ... 41.318443 41.249157 43.030975]
   [41.79686  41.850822 41.24382  ... 40.36381  41.83436  41.478294]
   ...
   [41.865494 41.216534 41.465626 ... 41.557777 40.824944 41.928368]
   [40.651833 40.52206  40.488293 ... 41.33708  41.15255  40.85743 ]
   [40.93058  40.23158  41.188267 ... 41.2085   40.418922 40.98327 ]]

  [[43.59166  44.904793 44.882133 ... 44.375744 44.469074 45.49534 ]
   [43.99026  44.610493 44.02713  ... 44.04601  44.40529  44.809177]
   [43.20949  44.403145 44.120373 ... 43.53783  43.44808  43.863758]
   ...
   [44.147076 43.202297 43.602135 ... 43.992283 43.225285 44.347965]
   [43.081    42.78432  43.434864 ... 43.61393  43.01665  43.658325]
   [43.17987  42.65024  43.901566 ... 43.732456 42.9811   42.32024 ]]

  [[38.399788 38.593864 39.02284  ... 38.261436 38.453922 38.70734 ]
   [38.867073 38.48216  39.371414 ... 38.220577 38.467358 38.481518]
   [38.20534  38.17316  37.907887 ... 38.19695  38.686535 38.449596]
   ...
   [37.68741  37.41581  38.02491  ... 37.38249  38.079636 38.38591 ]
   [37.626015 37.3454   37.000896 ... 37.047356 37.364502 38.381653]
   [37.92605  37.451626 38.096653 ... 37.539745 37.250698 38.186066]]]


 [[[41.68087  42.280346 42.89318  ... 41.69584  41.929256 42.50261 ]
   [41.522465 41.427807 42.00095  ... 41.31566  41.97961  41.62702 ]
   [42.431652 42.685707 41.267784 ... 41.95677  42.240917 41.612564]
   ...
   [40.838047 41.18262  40.94151  ... 42.09618  42.608597 42.508175]
   [41.33259  41.53562  41.837173 ... 41.67547  42.501366 42.1523  ]
   [41.493443 41.757122 41.915096 ... 41.18798  41.721752 42.40112 ]]

  [[44.750164 44.57851  45.26397  ... 44.9957   44.70977  45.24735 ]
   [44.808926 44.77709  44.230633 ... 44.154804 44.47225  44.445686]
   [43.908375 44.38397  44.62626  ... 44.741566 44.519047 44.55006 ]
   ...
   [43.68286  43.816692 43.5502   ... 44.405636 45.3298   44.822346]
   [43.891857 44.411987 43.604332 ... 43.910305 44.84048  44.914944]
   [43.73606  44.097893 44.660606 ... 43.84984  44.46986  44.77598 ]]

  [[38.98857  38.616226 39.361988 ... 38.474712 38.66995  38.37242 ]
   [38.150547 38.697502 38.99803  ... 38.67756  38.64789  38.150562]
   [39.18136  38.32171  38.53559  ... 38.609066 39.11044  38.908623]
   ...
   [37.175915 37.93462  38.557713 ... 39.040977 39.305103 39.279564]
   [38.43907  38.719715 38.484554 ... 37.88855  37.97843  39.52871 ]
   [38.393093 38.361702 38.514805 ... 38.277405 38.47148  39.083828]]]]

Total running time of the script: ( 0 minutes 2.201 seconds)