Schedule Primitives in CINN

This tutorial walks through examples of using the schedule primitives in CINN.

import cinn
import numpy as np

Declare some variables for later use. Expr is short for expression.

m = cinn.Expr(32)
n = cinn.Expr(8)

print(m, n)
# get the integer contained in an integer expression
print(m.int())

Out:

32 8
32

A schedule can be created from a list of Tensors.

# declare an elementwise multiply
A = cinn.Placeholder('float32', 'A', (m, n))
B = cinn.Placeholder('float32', 'B', (m, n))
C = cinn.compute((m, n), lambda v: A(v[0], v[1]) * B(v[0], v[1]), name='C')

# create the stages for further schedule
stages = cinn.create_stages([C])

# lower will transform the computation to real code
fn = cinn.lower("fn", stages, [A.to_tensor(), B.to_tensor(), C])
print(fn)

Out:

function fn (_A, _B, _C)
{
  for (i, 0, 32)
  {
    for (j, 0, 8)
    {
      C[i, j] = (A[i, j] * B[i, j])
    }
  }
}

A schedule is composed of multiple stages. We provide several methods to schedule each stage.

split

split partitions a specific axis into two axes by the given factor.

A = cinn.Placeholder('float32', 'A', (m, ))
B = cinn.compute((m, ), lambda v: A(v[0]) * 2., name='B')

stages = cinn.create_stages([B])
i0, i1 = stages[B].split(level=0, factor=4)
fn = cinn.lower("fn", stages, [A.to_tensor(), B])
print(fn)

Out:

function fn (_A, _B)
{
  for (i_outer, 0, 8)
  {
    for (i_inner, 0, 4)
    {
      B[((4 * i_outer) + i_inner)] = (2 * A[((4 * i_outer) + i_inner)])
    }
  }
}
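
The index arithmetic in the lowered code can be checked without CINN. The following is a minimal plain-Python sketch that replays the two loops and records the flat index each iteration touches. The helper name split_indices is hypothetical and is not part of the CINN API, and the sketch assumes the factor divides the extent evenly, as it does here (32 / 4).

```python
def split_indices(extent, factor):
    """Replay the split loop nest and return the flat indices it visits.

    Assumes `factor` divides `extent` evenly (hypothetical helper, not CINN).
    """
    visited = []
    for i_outer in range(extent // factor):
        for i_inner in range(factor):
            # Same expression as in the lowered code: (4 * i_outer) + i_inner
            visited.append(factor * i_outer + i_inner)
    return visited


split_visited = split_indices(32, 4)
print(split_visited == list(range(32)))
```

Every index from 0 to 31 is visited exactly once and in the original order, so split changes the loop structure but not the result.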

fuse

fuse merges two specific axes into a single axis. It is the reverse operation of split.

A = cinn.Placeholder('float32', 'A', (m, n))
B = cinn.compute((m, n), lambda v: A(v[0], v[1]) * 2., name='B')

stages = cinn.create_stages([B])
i0 = stages[B].fuse(level0=0, level1=1)
fn = cinn.lower("fn", stages, [A.to_tensor(), B])
print(fn)

Out:

function fn (_A, _B)
{
  for (i_j_fused, 0, 256)
  {
    B[(i_j_fused / 8), (i_j_fused % 8)] = (2 * A[(i_j_fused / 8), (i_j_fused % 8)])
  }
}
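
To see why the fused loop still reaches every element, the sketch below recovers the original (i, j) pair from the fused index in plain Python. It assumes that "/" in the printed code denotes integer division on these integer indices; the helper name fused_to_pair is hypothetical.

```python
def fused_to_pair(fused, inner_extent):
    """Map a fused loop index back to (outer, inner) indices.

    Integer division recovers the outer index, the remainder the inner one.
    Hypothetical helper for illustration, not a CINN function.
    """
    return fused // inner_extent, fused % inner_extent


fused_pairs = [fused_to_pair(f, 8) for f in range(32 * 8)]
nested_pairs = [(i, j) for i in range(32) for j in range(8)]
print(fused_pairs == nested_pairs)
```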

tile

tile partitions two adjacent axes into blocks.

A = cinn.Placeholder('float32', 'A', (m, n))
B = cinn.Placeholder('float32', 'B', (m, n))
C = cinn.compute((m, n), lambda v: A(v[0], v[1]) * B(v[0], v[1]), name='C')

stages = cinn.create_stages([C])

i, j = stages[C].axis(0), stages[C].axis(1)
i_outer, i_inner, j_outer, j_inner = stages[C].tile(i, j, 4, 4)
fn = cinn.lower("fn", stages, [A.to_tensor(), B.to_tensor(), C])
print(fn)

Out:

function fn (_A, _B, _C)
{
  for (i_outer, 0, 8)
  {
    for (i_inner, 0, 4)
    {
      for (j_outer, 0, 2)
      {
        for (j_inner, 0, 4)
        {
          C[((4 * i_outer) + i_inner), ((4 * j_outer) + j_inner)] = (A[((4 * i_outer) + i_inner), ((4 * j_outer) + j_inner)] * B[((4 * i_outer) + i_inner), ((4 * j_outer) + j_inner)])
        }
      }
    }
  }
}
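
The loop nest printed above can be replayed in plain Python to inspect which elements it visits and in what order. This is only a sketch of the printed loops, not of CINN internals; the helper name tiled_order is hypothetical and the tile sizes are assumed to divide the extents evenly.

```python
def tiled_order(rows, cols, tile_i, tile_j):
    """Replay the printed loop nest (i_outer, i_inner, j_outer, j_inner)."""
    order = []
    for i_outer in range(rows // tile_i):
        for i_inner in range(tile_i):
            for j_outer in range(cols // tile_j):
                for j_inner in range(tile_j):
                    order.append((tile_i * i_outer + i_inner,
                                  tile_j * j_outer + j_inner))
    return order


tile_visits = tiled_order(32, 8, 4, 4)
row_major = [(i, j) for i in range(32) for j in range(8)]
print(tile_visits == row_major)
```

With the loops in this order, the traversal covers every element once and still matches the original row-major order. Visiting one 4 x 4 block at a time would additionally require moving j_outer above i_inner, for example with reorder.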

reorder

reorder rearranges the axes into the specified order.

A = cinn.Placeholder('float32', 'A', (m, n))
B = cinn.Placeholder('float32', 'B', (m, n))
C = cinn.compute((m, n), lambda v: A(v[0], v[1]) * B(v[0], v[1]), name='C')

stages = cinn.create_stages([C])
i0, i1 = stages[C].axis(0), stages[C].axis(1)
stages[C].reorder([i1, i0])

fn = cinn.lower("fn", stages, [A.to_tensor(), B.to_tensor(), C])
print(fn)

Out:

function fn (_A, _B, _C)
{
  for (j, 0, 8)
  {
    for (i, 0, 32)
    {
      C[i, j] = (A[i, j] * B[i, j])
    }
  }
}
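
For an elementwise computation, swapping the loops changes only the order in which elements are visited. The NumPy sketch below replays the reordered loop nest and compares the result with a direct elementwise product. The array names a_np, b_np and c_np are illustrative stand-ins for the CINN tensors.

```python
import numpy as np

rng = np.random.default_rng(0)
a_np = rng.random((32, 8), dtype=np.float32)
b_np = rng.random((32, 8), dtype=np.float32)

c_np = np.empty((32, 8), dtype=np.float32)
reorder_visits = []
for j in range(8):          # j is now the outer loop
    for i in range(32):     # i is now the inner loop
        c_np[i, j] = a_np[i, j] * b_np[i, j]
        reorder_visits.append((i, j))

print(np.array_equal(c_np, a_np * b_np))
print(reorder_visits[:3])
```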

unroll

unroll unrolls a specific axis, replacing the loop with one copy of its body per iteration.

A = cinn.Placeholder('float32', 'A', (m, n))
B = cinn.Placeholder('float32', 'B', (m, n))
C = cinn.compute((m, n), lambda v: A(v[0], v[1]) * B(v[0], v[1]), name='C')

stages = cinn.create_stages([C])
i1 = stages[C].axis(1)
stages[C].unroll(i1)

fn = cinn.lower("fn", stages, [A.to_tensor(), B.to_tensor(), C])
print(fn)

Out:

function fn (_A, _B, _C)
{
  for (i, 0, 32)
  {
    C[i, 0] = (A[i, 0] * B[i, 0])
    C[i, 1] = (A[i, 1] * B[i, 1])
    C[i, 2] = (A[i, 2] * B[i, 2])
    C[i, 3] = (A[i, 3] * B[i, 3])
    C[i, 4] = (A[i, 4] * B[i, 4])
    C[i, 5] = (A[i, 5] * B[i, 5])
    C[i, 6] = (A[i, 6] * B[i, 6])
    C[i, 7] = (A[i, 7] * B[i, 7])
  }
}
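
As a toy illustration of what unrolling means, the sketch below generates the unrolled statements by textual substitution of the loop variable. This is only a string-level model for intuition, not how CINN performs the transformation; the helper name unroll_body is hypothetical.

```python
import re


def unroll_body(template, var, extent):
    """Return one copy of `template` per iteration, with `var` replaced
    by the iteration number (toy text substitution, not a CINN API)."""
    pattern = r"\b" + re.escape(var) + r"\b"
    return [re.sub(pattern, str(k), template) for k in range(extent)]


unrolled = unroll_body("C[i, j] = (A[i, j] * B[i, j])", "j", 8)
for statement in unrolled:
    print(statement)
```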

compute_inline

compute_inline marks a stage as inline. The stage's computation body is then expanded and inserted wherever the tensor is referenced, so no separate loop nest is generated for it.

A = cinn.Placeholder('float32', 'A', (m, n))
B = cinn.Placeholder('float32', 'B', (m, n))
C = cinn.compute((m, n), lambda v: A(v[0], v[1]) * B(v[0], v[1]), name='C')

# C1[i,j] = C[i,j] + B[i,j]
C1 = cinn.compute([m, n], lambda v: C(v[0], v[1]) + B(v[0], v[1]), "C1")
# C2[i,j] = C1[i,j] + B[i,j]
C2 = cinn.compute([m, n], lambda v: C1(v[0], v[1]) + B(v[0], v[1]), "C2")

stages = cinn.create_stages([C, C1, C2])

stages[C].compute_inline()
stages[C1].compute_inline()

fn = cinn.lower("fn", stages, [A.to_tensor(), B.to_tensor(), C2])
print(fn)

Out:

function fn (_A, _B, _C2)
{
  for (i, 0, 32)
  {
    for (j, 0, 8)
    {
      C2[i, j] = ((2 * B[i, j]) + (A[i, j] * B[i, j]))
    }
  }
}
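
The inlined expression can be checked numerically against the staged definition. The NumPy sketch below computes C, C1 and C2 as separate arrays and compares the result with the single expression from the lowered code. The array names are illustrative stand-ins for the CINN tensors, and np.allclose is used because the two forms add the terms in a different order.

```python
import numpy as np

rng = np.random.default_rng(0)
a_np = rng.random((32, 8))
b_np = rng.random((32, 8))

# Staged form: every intermediate tensor is materialized.
c_np = a_np * b_np
c1_np = c_np + b_np
c2_staged = c1_np + b_np

# Inlined form, as printed in the lowered function.
c2_inlined = (2 * b_np) + (a_np * b_np)

print(np.allclose(c2_staged, c2_inlined))
```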

bind

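bind binds a specified axis to a thread axis such as blockIdx.x or threadIdx.x. In the printed output below, the bound loops still appear as ordinary for loops.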

A = cinn.Placeholder('float32', 'A', (m, n))
B = cinn.Placeholder('float32', 'B', (m, n))
C = cinn.compute((m, n), lambda v: A(v[0], v[1]) * B(v[0], v[1]), name='C')

stages = cinn.create_stages([C])
stages[C].bind(0, "blockIdx.x")
stages[C].bind(1, "threadIdx.x")

fn = cinn.lower("fn", stages, [A.to_tensor(), B.to_tensor(), C])
print(fn)

Out:

function fn (_A, _B, _C)
{
  for (i, 0, 32)
  {
    for (j, 0, 8)
    {
      C[i, j] = (A[i, j] * B[i, j])
    }
  }
}

compute_at

compute_at computes a stage inside the scope of another stage. The parameter other specifies the other stage. The parameter level specifies which loop of that stage's scope the computation is placed in.

A = cinn.Placeholder('float32', 'A', (m, n, n))
B = cinn.Placeholder('float32', 'B', (m, n, n))
C = cinn.compute(
    (m, n), lambda v: A(v[0], v[1], v[1]) * B(v[0], v[1], v[1]), name='C')
D = cinn.compute((m, n), lambda v: C(v[0], v[1]) + 1., name='D')
stages = cinn.create_stages([C, D])

print("---------Before compute_at---------")
fn = cinn.lower("fn", stages, [A.to_tensor(), B.to_tensor(), C, D])
print(fn)

print("---------After compute_at---------")
stages[C].compute_at(other=stages[D], level=1)
fn2 = cinn.lower("fn", stages, [A.to_tensor(), B.to_tensor(), C, D])
print(fn2)

Out:

---------Before compute_at---------
function fn (_A, _B, _C, _D)
{
  for (i, 0, 32)
  {
    for (j, 0, 8)
    {
      C[i, j] = (A[i, j, j] * B[i, j, j])
    }
  }
  for (i, 0, 32)
  {
    for (j, 0, 8)
    {
      D[i, j] = (1 + C[i, j])
    }
  }
}
---------After compute_at---------
function fn (_A, _B, _C, _D)
{
  for (i, 0, 32)
  {
    for (j, 0, 8)
    {
      C[i, j] = (A[i, j, j] * B[i, j, j])
      D[i, j] = (1 + C[i, j])
    }
  }
}
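
The two lowered functions can be replayed in NumPy to confirm that moving C into D's loop nest does not change the results. This is a sketch of the printed loops only; the function names separate_loops and merged_loops and the array names are hypothetical.

```python
import numpy as np


def separate_loops(a, b):
    """Replay the loops printed before compute_at."""
    rows, cols = a.shape[0], a.shape[1]
    c = np.empty((rows, cols))
    d = np.empty((rows, cols))
    for i in range(rows):
        for j in range(cols):
            c[i, j] = a[i, j, j] * b[i, j, j]
    for i in range(rows):
        for j in range(cols):
            d[i, j] = 1 + c[i, j]
    return c, d


def merged_loops(a, b):
    """Replay the loops printed after compute_at: C is produced
    right before D consumes it, inside the same loop body."""
    rows, cols = a.shape[0], a.shape[1]
    c = np.empty((rows, cols))
    d = np.empty((rows, cols))
    for i in range(rows):
        for j in range(cols):
            c[i, j] = a[i, j, j] * b[i, j, j]
            d[i, j] = 1 + c[i, j]
    return c, d


rng = np.random.default_rng(0)
a3_np = rng.random((32, 8, 8))
b3_np = rng.random((32, 8, 8))
c_sep, d_sep = separate_loops(a3_np, b3_np)
c_mrg, d_mrg = merged_loops(a3_np, b3_np)
print(np.array_equal(c_sep, c_mrg), np.array_equal(d_sep, d_mrg))
```

In the merged form each value of C is consumed immediately after it is produced, which is the locality benefit the primitive is aiming for.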

cache_read

cache_read creates a cache tensor and loads the original tensor's data into it. All reads of the original tensor in the given readers are replaced by reads from the cache.

A = cinn.Placeholder('float32', 'A', (m, n))
B = cinn.compute((m, n), lambda v: A(v[0], v[1]) * 2., name='B')

stages = cinn.create_stages([B])
ACR = stages[A.to_tensor()].cache_read("local", [B], stages)
fn = cinn.lower("fn", stages, [A.to_tensor(), ACR, B])
print(fn)

Out:

function fn (_A, _A_read_cache, _B)
{
  for (i, 0, 32)
  {
    for (j, 0, 8)
    {
      A_read_cache[i, j] = A[i, j]
    }
  }
  for (i, 0, 32)
  {
    for (j, 0, 8)
    {
      B[i, j] = (2 * A_read_cache[i, j])
    }
  }
}

cache_write

cache_write creates a cache for writes to the original tensor. Results are stored in the cache memory first and then written to the output tensor.

A = cinn.Placeholder('float32', 'A', (m, n))
B = cinn.compute((m, n), lambda v: A(v[0], v[1]) * 2., name='B')

stages = cinn.create_stages([B])
BCW = stages[B].cache_write("local", stages, B)
fn = cinn.lower("fn", stages, [A.to_tensor(), B, BCW])
print(fn)

Out:

function fn (_A, _B, _B_write_cache)
{
  for (i, 0, 32)
  {
    for (j, 0, 8)
    {
      B_write_cache[i, j] = (2 * A[i, j])
    }
  }
  for (i, 0, 32)
  {
    for (j, 0, 8)
    {
      B[i, j] = B_write_cache[i, j]
    }
  }
}
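
The data flow of the two caching primitives, cache_read and cache_write, can be mimicked in NumPy by replaying the printed loops. The sketch below is illustrative only: the "local" memory scope is modeled as an ordinary array, and all array names are hypothetical.

```python
import numpy as np

a_np = np.arange(32 * 8, dtype=np.float32).reshape(32, 8)

# cache_read pattern: copy the input into a cache, then read only the cache.
a_read_cache = np.empty_like(a_np)
for i in range(32):
    for j in range(8):
        a_read_cache[i, j] = a_np[i, j]
b_via_read_cache = np.empty_like(a_np)
for i in range(32):
    for j in range(8):
        b_via_read_cache[i, j] = 2 * a_read_cache[i, j]

# cache_write pattern: compute into a cache, then copy the cache to the output.
b_write_cache = np.empty_like(a_np)
for i in range(32):
    for j in range(8):
        b_write_cache[i, j] = 2 * a_np[i, j]
b_via_write_cache = np.empty_like(a_np)
for i in range(32):
    for j in range(8):
        b_via_write_cache[i, j] = b_write_cache[i, j]

print(np.array_equal(b_via_read_cache, b_via_write_cache))
```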

parallel

parallel marks one loop to be executed in parallel (only used in x86 backends).

A = cinn.Placeholder('float32', 'A', (m, n))
B = cinn.compute((m, n), lambda v: A(v[0], v[1]) * 2., name='B')

stages = cinn.create_stages([B])
stages[B].parallel(0)
fn = cinn.lower("fn", stages, [A.to_tensor(), B])
print(fn)

Out:

function fn (_A, _B)
{
  parallel for (i, 0, 32)
  {
    for (j, 0, 8)
    {
      B[i, j] = (2 * A[i, j])
    }
  }
}
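
Marking the outer loop as parallel is safe here because each row of B depends only on the same row of A. The sketch below illustrates that independence with Python threads as a stand-in. The source does not describe how CINN's runtime executes a parallel loop, so the thread pool and the worker count of 4 are assumptions made for illustration.

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np

a_np = np.arange(32 * 8, dtype=np.float32).reshape(32, 8)
b_par = np.empty_like(a_np)


def compute_row(i):
    """Body of the parallel loop: writes only row i of the output."""
    for j in range(a_np.shape[1]):
        b_par[i, j] = 2 * a_np[i, j]


# Each iteration of the outer loop is handed to a worker thread.
with ThreadPoolExecutor(max_workers=4) as pool:
    list(pool.map(compute_row, range(a_np.shape[0])))

print(np.array_equal(b_par, 2 * a_np))
```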

vectorize

vectorize vectorizes the loop given by the level parameter; the second argument sets the vector width, which is 10 lanes in the output below (only used in x86 backends).

A = cinn.Placeholder('float32', 'A', (m, n))
B = cinn.compute((m, n), lambda v: A(v[0], v[1]) * 2., name='B')

stages = cinn.create_stages([B])
stages[B].vectorize(0, 10)
fn = cinn.lower("fn", stages, [A.to_tensor(), B])
print(fn)

Out:

function fn (_A, _B)
{
  for (i, 0, 4)
  {
    for (j, 0, 8)
    {
      B[Ramp((10 * i),1,10), Broadcast(j,10)] = (Broadcast(2,10) * A[Ramp((10 * i),1,10), Broadcast(j,10)])
    }
  }
}
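
The Ramp and Broadcast nodes in the output can be read as index vectors. The NumPy sketch below assumes the conventional meaning of these names in vectorized IRs: Ramp(base, stride, lanes) is the vector base, base + stride, and so on, and Broadcast(value, lanes) repeats one value across all lanes. The helpers ramp and broadcast are hypothetical, and the sketch uses a width of 8 lanes rather than 10 so that the width divides the extent of 32 evenly.

```python
import numpy as np


def ramp(base, stride, lanes):
    """Vector [base, base + stride, ..., base + stride * (lanes - 1)]."""
    return base + stride * np.arange(lanes)


def broadcast(value, lanes):
    """Vector holding `value` in every lane."""
    return np.full(lanes, value)


LANES = 8  # assumed width for this sketch; the tutorial output uses 10
a_np = np.arange(32 * 8, dtype=np.float32).reshape(32, 8)
b_vec = np.zeros_like(a_np)

for i in range(32 // LANES):
    for j in range(8):
        row_idx = ramp(LANES * i, 1, LANES)   # LANES consecutive rows
        col_idx = broadcast(j, LANES)         # the same column in each lane
        # One vector statement updates LANES elements at once.
        b_vec[row_idx, col_idx] = broadcast(2, LANES) * a_np[row_idx, col_idx]

print(np.array_equal(b_vec, 2 * a_np))
```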

An example of optimizing performance in CUDA backends

In this section, we show a practical example of optimizing performance using schedule primitives.

Optimize an elementwise kernel (B = A * 2) using fuse, split and bind

A = cinn.Placeholder('float32', 'A', (m, m))
B = cinn.compute((m, m), lambda v: A(v[0], v[1]) * 2., name='B')

stages = cinn.create_stages([B])
fn0 = cinn.lower("fn", stages, [A.to_tensor(), B])
print("Original kernel before optimizing:\n", fn0)
stages[B].fuse(0, 1)
stages[B].split(level=0, factor=256)
stages[B].bind(0, "blockIdx.x")
stages[B].bind(1, "threadIdx.x")
fn1 = cinn.lower("fn", stages, [A.to_tensor(), B])
print("\n======================================\nThe optimized kernel:\n", fn1)

Out:

Original kernel before optimizing:
 function fn (_A, _B)
{
  for (i, 0, 32)
  {
    for (j, 0, 32)
    {
      B[i, j] = (2 * A[i, j])
    }
  }
}

======================================
The optimized kernel:
 function fn (_A, _B)
{
  for (i_j_fused_outer, 0, 4)
  {
    for (i_j_fused_inner, 0, 256)
    {
      B[((i_j_fused_inner / 32) + (8 * i_j_fused_outer)), (i_j_fused_inner % 32)] = (2 * A[((i_j_fused_inner / 32) + (8 * i_j_fused_outer)), (i_j_fused_inner % 32)])
    }
  }
}

Thus we get an optimized kernel.
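
The index expressions in the optimized kernel can be verified in plain Python. The sketch below assumes the usual CUDA reading of the bindings, with the outer loop variable playing the role of blockIdx.x (4 blocks) and the inner one the role of threadIdx.x (256 threads per block). The helper name optimized_index is hypothetical.

```python
def optimized_index(block, thread):
    """Element (i, j) handled by a given (block, thread) pair, using the
    index expressions printed in the optimized kernel."""
    i = (thread // 32) + (8 * block)
    j = thread % 32
    return i, j


kernel_coords = [optimized_index(block, thread)
                 for block in range(4)
                 for thread in range(256)]
expected_coords = [(i, j) for i in range(32) for j in range(32)]
print(kernel_coords == expected_coords)
```

Each of the 1024 (block, thread) pairs maps to a distinct element of the 32 x 32 tensor, so the work is spread across the threads with no overlap and no gaps.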
