# Layer table for produce_generator. Each row is
# (op, layer_name, num_filter, kernel, stride, dilate, pad); kernel, stride,
# dilate and pad are applied identically to both spatial axes. The BatchNorm
# following row i is named "norm<i+1>". Names are fixed so that parameter
# names (e.g. "conv1_weight", "norm1_gamma") stay stable across runs.
_GENERATOR_LAYERS = (
    # Encoder: one full-resolution 5x5 conv, then two stride-2 downsamplings
    # (3x3, stride 2, pad 1 -> ceil(size / 2)).
    ("conv", "conv1", 64, 5, 1, 1, 2),
    ("conv", "conv2", 128, 3, 2, 1, 1),
    ("conv", "conv3", 128, 3, 1, 1, 1),
    ("conv", "conv4", 256, 3, 2, 1, 1),
    ("conv", "conv5", 256, 3, 1, 1, 1),
    ("conv", "conv6", 256, 3, 1, 1, 1),
    # Dilated 3x3 convs: pad == dilate keeps the spatial size unchanged.
    ("conv", "dilated_conv1", 256, 3, 1, 2, 2),
    ("conv", "dilated_conv2", 256, 3, 1, 4, 4),
    ("conv", "dilated_conv3", 256, 3, 1, 8, 8),
    ("conv", "dilated_conv4", 256, 3, 1, 16, 16),
    ("conv", "conv7", 256, 3, 1, 1, 1),
    ("conv", "conv8", 256, 3, 1, 1, 1),
    # Decoder: 4x4, stride 2, pad 1 deconvolutions exactly double the size.
    ("deconv", "deconv1", 128, 4, 2, 1, 1),
    ("conv", "conv9", 128, 3, 1, 1, 1),
    ("deconv", "deconv2", 64, 4, 2, 1, 1),
    ("conv", "conv10", 32, 3, 1, 1, 1),
    # Output projection to 3 channels (followed by BatchNorm + tanh).
    ("conv", "conv11", 3, 3, 1, 1, 1),
)


def _generator_conv_bn(data, op, layer_name, norm_name, num_filter, kernel,
                       stride, dilate, pad, fix_gamma):
    """Apply one (de)convolution followed by BatchNorm; return the BatchNorm symbol."""
    layer_op = mx.sym.Deconvolution if op == "deconv" else mx.sym.Convolution
    layer = layer_op(data=data, kernel=(kernel, kernel), num_filter=num_filter,
                     stride=(stride, stride), dilate=(dilate, dilate),
                     pad=(pad, pad), name=layer_name)
    return mx.sym.BatchNorm(data=layer, fix_gamma=fix_gamma, name=norm_name)


def produce_generator(input_and_mask, fix_gamma=True, *, verbose=False):
    """Build the generator network as an MXNet symbol.

    The network is an encoder (two stride-2 downsamplings), four dilated
    convolutions (dilation 2, 4, 8, 16), and a decoder (two 2x deconvolution
    upsamplings). Every layer is followed by BatchNorm; hidden layers use ELU
    (alpha = 1) and the final 3-channel layer uses tanh. When height and width
    are multiples of 4, the output has the same spatial size as the input.

    Args:
        input_and_mask: Input symbol. Presumably (batch, channels + 1, height,
            width), i.e. image plus mask, as a commented-out hint in the
            original suggested -- TODO confirm against the caller.
        fix_gamma: Passed to every BatchNorm layer.
        verbose: If True, print each layer's inferred output shapes. This only
            yields real shapes when ``input_and_mask`` carries shape
            information (e.g. ``mx.sym.Variable(..., shape=...)``); otherwise
            MXNet returns None and emits a warning.

    Returns:
        The tanh output symbol with 3 channels, values in (-1, 1).
    """
    net = input_and_mask
    last = len(_GENERATOR_LAYERS) - 1
    for i, (op, layer_name, num_filter, kernel, stride, dilate, pad) in enumerate(_GENERATOR_LAYERS):
        net = _generator_conv_bn(net, op, layer_name, "norm" + str(i + 1),
                                 num_filter, kernel, stride, dilate, pad,
                                 fix_gamma)
        if i < last:
            # Same op that gluon.nn.ELU() emits, without building a Gluon block.
            net = mx.sym.LeakyReLU(data=net, act_type="elu", slope=1.0)
        else:
            # NOTE(review): BatchNorm is applied to the 3-channel output right
            # before tanh (kept from the original design) -- confirm intended.
            net = mx.sym.Activation(data=net, act_type="tanh")
        if verbose:
            # infer_shape re-walks the whole graph; keep it opt-in.
            print(layer_name, net.infer_shape()[1])
    return net