Struct BatchNormHelper

Struct Documentation

struct cinn::frontend::decomposer::BatchNormHelper

Public Functions

BatchNormHelper(CinnBuilder *cinn_builder, const std::vector<int> &arg_x_shape, const std::vector<int> &arg_param_shape, std::string data_layout, std::string bn_op_type)
~BatchNormHelper()
template<typename T>
Variable GetTensorFromScalar(T value, std::string name, const std::vector<int> &shape)
std::vector<Variable> MeanAndVariance(Variable x)
std::vector<Variable> GradBiasAndScale(Variable x, Variable x_mean, Variable y_grad)
Variable Mean(Variable x)
Variable Variance(Variable x, Variable mean)
Variable StdVarianceInv1d(Variable variance, float epsilon)
Variable StdVarianceInv4d(Variable variance, float epsilon)
Variable UpdateMeanVariance(Variable moving_value, Variable saved_value, float momentum)
Variable Reduce(Variable x)

Public Members

CinnBuilder *builder = {nullptr}
std::vector<int> x_shape
std::vector<int> param_shape
std::vector<int> reduce_dim
float element_count = {0}
int channel_dim = {0}
std::string op_type
int num_instructions = {0}