Run a model using the NetBuilder API
This tutorial shows how to build and run a model with the NetBuilder API.
import cinn
from cinn import frontend
from cinn import common
import numpy as np
Define the NetBuilder
Using NetBuilder is a convenient way to build a model in CINN. You can build and run a model by invoking NetBuilder's APIs, as follows.
name
: the ID of the NetBuilder.
Generally, the NetBuilder APIs are coarse-grained operators, that is, operators at the same level as those of a DL framework such as Paddle.
builder = frontend.NetBuilder(name="batchnorm_conv2d")
Define the input variable of the model.
Input variables are created with the create_input API. Note that such a variable is only a placeholder: it does not need actual data at this point.
type
: the data type of the input variable. Void, Int, UInt, Float, Bool and String are currently supported, and the parameter is the type's bit width. Here the data type is float32, i.e. common.Float(32).
shape
: the shape of the input variable. Dynamic shapes are not supported yet, so every dimension must be greater than 0.
id_hint
: the name of the variable; the default value is "".
a = builder.create_input(
type=common.Float(32), shape=(8, 3, 224, 224), id_hint="x")
scale = builder.create_input(type=common.Float(32), shape=[3], id_hint="scale")
bias = builder.create_input(type=common.Float(32), shape=[3], id_hint="bias")
mean = builder.create_input(type=common.Float(32), shape=[3], id_hint="mean")
variance = builder.create_input(
type=common.Float(32), shape=[3], id_hint="variance")
weight = builder.create_input(
type=common.Float(32), shape=(3, 3, 7, 7), id_hint="weight")
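Because `type` is given as a bit width and every dimension must be a fixed positive integer, the size of each placeholder's buffer is known up front. The sketch below mirrors the `create_input` calls above in a hypothetical `INPUT_SPECS` table and a hypothetical `input_nbytes` helper (plain Python for illustration, not part of the CINN API) to check that the shapes are static and to show how much memory each float32 input needs (bit width / 8 bytes per element).

```python
from math import prod

# Hypothetical mirror of the create_input calls above: name -> (bit width, shape).
# This is plain Python for illustration, not a CINN API.
INPUT_SPECS = {
    "x": (32, (8, 3, 224, 224)),
    "scale": (32, (3,)),
    "bias": (32, (3,)),
    "mean": (32, (3,)),
    "variance": (32, (3,)),
    "weight": (32, (3, 3, 7, 7)),
}


def input_nbytes(bits, shape):
    """Return the buffer size in bytes of a static-shaped input."""
    # Dynamic shapes are not supported, so every dimension must be > 0.
    if any(d <= 0 for d in shape):
        raise ValueError(f"shape {shape} must have only positive dimensions")
    return prod(shape) * bits // 8


for name, (bits, shape) in INPUT_SPECS.items():
    print(f"{name}: {prod(shape)} elements, {input_nbytes(bits, shape)} bytes")
```

For the image input x this gives 8 × 3 × 224 × 224 = 1,204,224 elements, or 4,816,896 bytes at 4 bytes per float32 element.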
Build the model using the NetBuilder API
For convenience, we build a simple model here that consists only of the batchnorm and conv2d operators. Each operator is described in detail in a separate document, so we won't go into detail here. batchnorm returns a list of variables; its first element, y[0], is the normalized output that is fed into conv2d.
y = builder.batchnorm(a, scale, bias, mean, variance, is_test=True)
res = builder.conv2d(y[0], weight)
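For intuition about what the two operators compute, here is a plain-Python sketch of inference-mode batch normalization (the is_test=True case, which uses the supplied mean and variance) for a single value, plus the output-shape formula for conv2d. The helper names are hypothetical, and the epsilon of 1e-5 and the conv2d defaults (stride 1, no padding, dilation 1) are assumptions; check the operator document for the actual defaults.

```python
import math


def batchnorm_infer(x, scale, bias, mean, variance, epsilon=1e-5):
    """Normalize one value with fixed statistics (inference mode).

    epsilon=1e-5 is an assumed default, not taken from CINN.
    """
    return scale * (x - mean) / math.sqrt(variance + epsilon) + bias


def conv2d_out_shape(in_shape, w_shape, stride=1, padding=0, dilation=1):
    """Map an NCHW input shape and OIHW weight shape to the NCHW output shape.

    stride/padding/dilation defaults here are assumptions.
    """
    n, c, h, w = in_shape
    oc, ic, kh, kw = w_shape
    assert c == ic, "input channels must match weight channels"
    out_h = (h + 2 * padding - dilation * (kh - 1) - 1) // stride + 1
    out_w = (w + 2 * padding - dilation * (kw - 1) - 1) // stride + 1
    return (n, oc, out_h, out_w)


print(conv2d_out_shape((8, 3, 224, 224), (3, 3, 7, 7)))  # (8, 3, 218, 218)
```

Under these assumptions, the 7×7 kernel shrinks each 224-pixel side to 224 - 7 + 1 = 218, and the output has 3 channels because the weight's first dimension is 3.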
Set target
The target specifies where the model runs. Two targets are currently supported:
DefaultHostTarget
: the model runs on the CPU.
DefaultNVGPUTarget
: the model runs on an NVIDIA GPU.
if common.is_compiled_with_cuda():
target = common.DefaultNVGPUTarget()
else:
target = common.DefaultHostTarget()
print("Model running at ", target.arch)
Out:
Model running at Arch.X86
Generate the program
Once the model is built, Computation compiles it into a CINN execution program for the chosen target, which you can obtain as follows:
computation = frontend.Computation.build_and_compile(target, builder)
Generate random input data
Before running, read or generate data to feed the model's inputs.
get_tensor
: gets the tensor with the specified name from the computation.
from_numpy
: fills the tensor with NumPy data.
tensor_data = [
np.random.random([8, 3, 224, 224]).astype("float32"), # a
np.random.random([3]).astype("float32"), # scale
np.random.random([3]).astype("float32"), # bias
np.random.random([3]).astype("float32"), # mean
np.random.random([3]).astype("float32"), # variance
np.random.random([3, 3, 7, 7]).astype("float32") # weight
]
computation.get_tensor("x").from_numpy(tensor_data[0], target)
computation.get_tensor("scale").from_numpy(tensor_data[1], target)
computation.get_tensor("bias").from_numpy(tensor_data[2], target)
computation.get_tensor("mean").from_numpy(tensor_data[3], target)
computation.get_tensor("variance").from_numpy(tensor_data[4], target)
computation.get_tensor("weight").from_numpy(tensor_data[5], target)
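Because each entry of tensor_data is matched to an input name purely by its list position, an ordering slip only shows up at run time. A small pre-check can compare each array's shape with the declared placeholder before calling from_numpy. The sketch below is hypothetical (plain Python, not a CINN API): `DECLARED` and `check_feed` are illustrative names, and the shapes are written out by hand where NumPy arrays would supply `arr.shape`.

```python
# Hypothetical pre-check: declared placeholder shapes, copied from the
# create_input calls above.
DECLARED = {
    "x": (8, 3, 224, 224),
    "scale": (3,),
    "bias": (3,),
    "mean": (3,),
    "variance": (3,),
    "weight": (3, 3, 7, 7),
}


def check_feed(declared, feed_shapes):
    """Return a list of problems; an empty list means every input matches."""
    problems = []
    for name, shape in declared.items():
        if name not in feed_shapes:
            problems.append(f"missing input {name!r}")
        elif tuple(feed_shapes[name]) != tuple(shape):
            problems.append(
                f"{name!r}: expected {shape}, got {tuple(feed_shapes[name])}")
    # Names that were fed but never declared with create_input.
    for name in sorted(feed_shapes.keys() - declared.keys()):
        problems.append(f"unexpected input {name!r}")
    return problems


names = ["x", "scale", "bias", "mean", "variance", "weight"]
shapes = [(8, 3, 224, 224), (3,), (3,), (3,), (3,), (3, 3, 7, 7)]
print(check_feed(DECLARED, dict(zip(names, shapes))))  # []
```

With the tutorial's arrays, `dict(zip(names, (t.shape for t in tensor_data)))` builds the mapping, and the same `zip(names, tensor_data)` pairing could drive a loop over `computation.get_tensor(name).from_numpy(data, target)` in place of the six explicit calls.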
Run the program and print the result
Finally, run the model by calling execute(). After that, you can retrieve any tensor you need with get_tensor by passing the tensor's name.
Out:
[[[[41.300034 41.32889 40.87694 ... 41.017094 40.977764 41.43262 ]
[40.66744 40.374546 41.068977 ... 41.64359 40.59516 39.62327 ]
[41.18457 40.07553 40.306675 ... 42.416054 41.4802 40.468147]
...
[42.274876 41.293266 41.46016 ... 41.1733 41.638245 41.78466 ]
[41.524933 42.01216 40.807613 ... 41.243847 42.163734 41.821606]
[41.252533 42.235886 41.395493 ... 42.59662 42.104828 41.440254]]
[[43.64468 43.11713 43.218945 ... 44.419003 43.60452 43.440025]
[42.999943 42.594635 43.46457 ... 44.319992 43.11883 42.671257]
[42.989483 42.126328 43.062424 ... 44.44373 44.230293 43.123837]
...
[45.079826 44.20662 44.09692 ... 44.75235 44.563763 44.032536]
[44.17878 44.37436 43.422424 ... 43.29678 45.046303 44.439877]
[43.977108 44.29691 43.525467 ... 44.779503 44.789032 43.94192 ]]
[[37.902706 37.147522 38.064827 ... 37.633286 37.112053 38.035236]
[37.937916 37.733925 37.223503 ... 37.785645 37.849224 37.682865]
[37.63615 36.88459 37.599262 ... 38.829453 37.80579 37.403847]
...
[38.898945 37.51382 37.86469 ... 38.29599 38.365997 37.932575]
[39.033062 38.40738 37.687614 ... 38.914192 38.727806 38.252518]
[37.778 38.507298 37.885326 ... 38.903698 38.25751 38.395638]]]
[[[42.743874 43.090458 42.705624 ... 42.29254 42.29984 42.315804]
[41.652885 42.61314 43.07195 ... 42.349762 41.72451 41.791553]
[42.272617 42.885338 42.491203 ... 40.942463 41.132015 41.276176]
...
[41.428234 40.786144 41.092304 ... 42.25208 42.249966 41.074356]
[41.12266 41.098907 41.339474 ... 42.21494 41.98615 42.16423 ]
[41.093136 40.736393 40.612026 ... 42.172665 41.888657 40.606525]]
[[44.885036 45.419273 46.009888 ... 44.758255 44.279404 44.31435 ]
[44.446453 44.900543 45.541084 ... 44.830772 44.37794 44.25062 ]
[44.989704 44.842106 45.5936 ... 44.16348 43.43664 43.622528]
...
[43.886513 44.112514 43.552845 ... 44.617413 44.29766 43.1779 ]
[42.834915 43.53686 43.725327 ... 44.810688 44.49254 43.86616 ]
[43.13151 42.920265 43.617294 ... 43.78895 44.461014 44.139328]]
[[39.576283 39.76104 38.940693 ... 39.581345 39.08043 38.17172 ]
[38.40676 39.181484 39.586117 ... 39.260128 39.443054 38.862144]
[39.67811 39.61034 39.58296 ... 37.678688 37.63858 37.570877]
...
[38.896503 38.52317 37.95039 ... 38.271748 39.490025 38.525524]
[37.65303 37.681347 37.494225 ... 38.75704 38.511597 38.41539 ]
[37.697704 38.209988 37.305393 ... 38.317337 38.559345 37.90143 ]]]
[[[41.32929 41.74829 41.50349 ... 41.72447 42.585907 42.441044]
[40.862965 40.75343 41.075466 ... 42.22953 42.453655 42.17996 ]
[40.458736 40.56273 40.036434 ... 40.891075 42.229977 41.814816]
...
[41.324177 41.208374 40.753242 ... 40.90548 40.210472 40.166634]
[41.38516 41.701263 41.078594 ... 41.429382 40.10512 40.95803 ]
[41.10963 41.760563 41.107037 ... 41.583523 41.279423 40.511284]]
[[42.697742 43.3048 43.416958 ... 44.315994 44.881016 44.524723]
[43.65681 43.609573 43.47862 ... 44.12984 45.195675 44.301723]
[42.66219 42.837914 42.660976 ... 43.12605 44.239704 44.185734]
...
[43.35991 44.273983 43.207153 ... 43.064762 42.679085 42.847244]
[43.215584 44.4136 43.950047 ... 44.113617 43.480686 43.274685]
[43.044224 44.134 43.45668 ... 44.51243 43.642487 44.029 ]]
[[37.82997 38.30262 37.787506 ... 38.73762 39.145596 39.214157]
[37.7688 38.26052 38.676003 ... 39.158955 38.59617 38.96273 ]
[37.560474 37.328846 37.617752 ... 37.94291 38.034767 38.219986]
...
[38.091568 37.453407 37.60966 ... 37.83764 37.110413 38.025 ]
[37.61462 37.436928 37.175743 ... 38.37969 37.9572 37.72554 ]
[37.817383 38.15954 37.50487 ... 37.808636 37.927593 38.064747]]]
...
[[[42.500374 42.63498 43.143276 ... 41.59989 41.403538 41.439278]
[41.934628 41.667717 41.721813 ... 41.25938 41.717907 41.12168 ]
[42.086777 41.656315 42.209263 ... 40.612267 41.285275 41.18172 ]
...
[41.601242 40.837345 40.93395 ... 41.57721 40.890324 41.58985 ]
[41.779762 41.488 41.85838 ... 41.402683 40.187996 41.255142]
[41.348114 40.954266 40.663338 ... 41.30279 40.796883 41.137814]]
[[44.657394 45.53907 45.630577 ... 44.103313 44.40315 44.38892 ]
[44.509567 44.630383 44.980797 ... 43.858788 44.71564 43.71002 ]
[44.681923 44.48463 44.330616 ... 43.469017 43.979595 43.557213]
...
[43.882473 43.21293 43.204563 ... 44.308258 43.200928 43.94489 ]
[44.27936 43.209034 43.82105 ... 44.078785 43.275066 43.91738 ]
[43.55389 43.600197 43.039677 ... 43.96242 43.249626 43.746513]]
[[39.58318 39.320465 39.307163 ... 38.91711 38.4188 37.95429 ]
[38.552593 38.997684 38.26397 ... 38.576607 38.561592 38.705746]
[38.561188 39.047134 39.052013 ... 38.048313 38.5195 37.435566]
...
[38.281456 37.773933 37.552616 ... 38.11439 38.358353 38.466003]
[38.46599 37.60399 38.231476 ... 38.224827 37.64155 37.612072]
[38.35948 38.304157 38.339863 ... 38.57184 38.04912 38.03847 ]]]
[[[41.33109 42.337395 42.226433 ... 41.845016 41.81689 41.838345]
[42.269596 42.40946 41.885414 ... 41.318443 41.249157 43.030975]
[41.79686 41.850822 41.24382 ... 40.36381 41.83436 41.478294]
...
[41.865494 41.216534 41.465626 ... 41.557777 40.824944 41.928368]
[40.651833 40.52206 40.488293 ... 41.33708 41.15255 40.85743 ]
[40.93058 40.23158 41.188267 ... 41.2085 40.418922 40.98327 ]]
[[43.59166 44.904793 44.882133 ... 44.375744 44.469074 45.49534 ]
[43.99026 44.610493 44.02713 ... 44.04601 44.40529 44.809177]
[43.20949 44.403145 44.120373 ... 43.53783 43.44808 43.863758]
...
[44.147076 43.202297 43.602135 ... 43.992283 43.225285 44.347965]
[43.081 42.78432 43.434864 ... 43.61393 43.01665 43.658325]
[43.17987 42.65024 43.901566 ... 43.732456 42.9811 42.32024 ]]
[[38.399788 38.593864 39.02284 ... 38.261436 38.453922 38.70734 ]
[38.867073 38.48216 39.371414 ... 38.220577 38.467358 38.481518]
[38.20534 38.17316 37.907887 ... 38.19695 38.686535 38.449596]
...
[37.68741 37.41581 38.02491 ... 37.38249 38.079636 38.38591 ]
[37.626015 37.3454 37.000896 ... 37.047356 37.364502 38.381653]
[37.92605 37.451626 38.096653 ... 37.539745 37.250698 38.186066]]]
[[[41.68087 42.280346 42.89318 ... 41.69584 41.929256 42.50261 ]
[41.522465 41.427807 42.00095 ... 41.31566 41.97961 41.62702 ]
[42.431652 42.685707 41.267784 ... 41.95677 42.240917 41.612564]
...
[40.838047 41.18262 40.94151 ... 42.09618 42.608597 42.508175]
[41.33259 41.53562 41.837173 ... 41.67547 42.501366 42.1523 ]
[41.493443 41.757122 41.915096 ... 41.18798 41.721752 42.40112 ]]
[[44.750164 44.57851 45.26397 ... 44.9957 44.70977 45.24735 ]
[44.808926 44.77709 44.230633 ... 44.154804 44.47225 44.445686]
[43.908375 44.38397 44.62626 ... 44.741566 44.519047 44.55006 ]
...
[43.68286 43.816692 43.5502 ... 44.405636 45.3298 44.822346]
[43.891857 44.411987 43.604332 ... 43.910305 44.84048 44.914944]
[43.73606 44.097893 44.660606 ... 43.84984 44.46986 44.77598 ]]
[[38.98857 38.616226 39.361988 ... 38.474712 38.66995 38.37242 ]
[38.150547 38.697502 38.99803 ... 38.67756 38.64789 38.150562]
[39.18136 38.32171 38.53559 ... 38.609066 39.11044 38.908623]
...
[37.175915 37.93462 38.557713 ... 39.040977 39.305103 39.279564]
[38.43907 38.719715 38.484554 ... 37.88855 37.97843 39.52871 ]
[38.393093 38.361702 38.514805 ... 38.277405 38.47148 39.083828]]]]
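Because the inputs are random, the printed values differ from run to run. To spot-check a single entry of the result against your own inputs, one output element can be recomputed in plain Python from nested lists (for example `tensor_data[0].tolist()`). This is a slow, hypothetical reference (`reference_element` is an illustrative name, not how CINN computes the result). It assumes conv2d with stride 1, no padding and no bias term, and an assumed batchnorm epsilon of 1e-5.

```python
import math


def reference_element(x, scale, bias, mean, variance, weight, n, oc, i, j,
                      epsilon=1e-5):
    """Compute res[n][oc][i][j] for inference batchnorm followed by conv2d.

    x is NCHW and weight is OIHW, both as nested lists. Stride 1, zero
    padding and epsilon=1e-5 are assumptions.
    """
    total = 0.0
    for c in range(len(weight[oc])):
        inv_std = 1.0 / math.sqrt(variance[c] + epsilon)
        for kh in range(len(weight[oc][c])):
            for kw in range(len(weight[oc][c][kh])):
                v = x[n][c][i + kh][j + kw]
                # Batchnorm applied to the input value, then weighted by the kernel.
                normalized = scale[c] * (v - mean[c]) * inv_std + bias[c]
                total += weight[oc][c][kh][kw] * normalized
    return total


# Tiny check: identity batchnorm and an all-ones 2x2 kernel sum the inputs.
x = [[[[1.0, 2.0], [3.0, 4.0]]]]
weight = [[[[1.0, 1.0], [1.0, 1.0]]]]
print(reference_element(x, [1.0], [0.0], [0.0], [1.0], weight, 0, 0, 0, 0,
                        epsilon=0.0))  # 10.0
```

Comparing the returned value with the corresponding entry of the tensor read back from the computation should agree up to float32 rounding if the assumed defaults match CINN's.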