from mxnet.gluon import HybridBlock
from mxnet.initializer import Constant
from mxnet.symbol import Variable, BlockGrad


class ShiftScaleLayer(HybridBlock):
    def __init__(self, axis=-1, momentum=0.9, epsilon=1e-5, center=True, scale=True,
                 use_global_stats=True, fuse_relu=False,
                 beta_initializer='zeros', gamma_initializer='ones',
                 running_mean_initializer='zeros', running_variance_initializer='ones',
                 in_channels=0, **kwargs):
        super(ShiftScaleLayer, self).__init__(**kwargs)
        self._kwargs = {'axis': axis}
        self.fuse_relu = fuse_relu
        if in_channels != 0:
            self.in_channels = in_channels
        # Per-channel constant symbols used by the running-statistics update.
        self.momentum = Variable('momentum', shape=(in_channels,), init=Constant(momentum))
        self.epsilon = Variable('epsilon', shape=(in_channels,), init=Constant(epsilon))
        # BlockGrad keeps gradients from flowing into these constants.
        self.momentum = BlockGrad(self.momentum)
        self.epsilon = BlockGrad(self.epsilon)
        self.gamma = self.params.get('gamma', grad_req='write' if scale else 'null',
                                     shape=(in_channels,), init=gamma_initializer,
                                     allow_deferred_init=True,
                                     differentiable=scale)
        self.beta = self.params.get('beta', grad_req='write' if center else 'null',
                                    shape=(in_channels,), init=beta_initializer,
                                    allow_deferred_init=True,
                                    differentiable=center)
        # Running statistics, wired to the initializers passed to __init__.
        self.running_mean = Variable('running_mean', shape=(in_channels,),
                                     init=running_mean_initializer)
        self.running_var = Variable('running_var', shape=(in_channels,),
                                    init=running_variance_initializer)
    def batch_norm(self, F, X, gamma, beta, moving_mean, moving_var, eps, momentum, axis):
        # Batch statistics per channel: reduce over every axis except `axis`.
        # Broadcasting below assumes `axis` is the trailing channel axis (the default).
        mean = F.mean(data=X, axis=axis, exclude=True)
        var = F.mean(data=F.square(F.broadcast_sub(X, mean)), axis=axis, exclude=True)
        # Exponential moving average of the batch statistics.
        moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
        moving_var = momentum * moving_var + (1.0 - momentum) * var
        # Normalize with the running statistics, then apply the affine transform.
        X_hat = F.broadcast_div(F.broadcast_sub(X, moving_mean),
                                F.sqrt(moving_var + eps))
        Y = F.broadcast_add(F.broadcast_mul(gamma, X_hat), beta)
        return Y, moving_mean, moving_var
    def hybrid_forward(self, F, x, gamma, beta):
        Y, self.running_mean, self.running_var = self.batch_norm(
            F, x, gamma, beta, self.running_mean, self.running_var,
            self.epsilon, self.momentum, **self._kwargs)
        if self.fuse_relu:
            # Honor the fuse_relu flag by folding the activation into the layer.
            Y = F.relu(Y)
        return Y
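
Because the running statistics live in raw symbols, the layer is easiest to exercise symbolically: calling a HybridBlock on a Symbol returns a Symbol, with no parameter initialization required. A minimal smoke test, assuming MXNet 1.x (the 64-channel width and the 'data' name are arbitrary choices for illustration):

import mxnet as mx

layer = ShiftScaleLayer(axis=-1, in_channels=64)
data = mx.sym.Variable('data')
out = layer(data)
# Inputs of the resulting graph: 'data', gamma, beta, momentum, epsilon,
# and the two running-statistics variables.
print(out.list_arguments())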