I wrote a custom operator for 8-bit quantization, and I get stuck when I train my network with 4 GPUs. My program only consumes GPU memory and shows 0% GPU utilization in nvidia-smi, so obviously it isn't running the training loop. After I changed my settings to use only 1 GPU, it finally worked well.
Here is the op definition:
from __future__ import absolute_import
import mxnet as mx
class Quantize(mx.operator.CustomOp):
    """Simulated ("fake") quantization: quantize to signed integers, then dequantize.

    The scale is derived from ``alpha`` times the population standard
    deviation of the input.  Elements whose distance from the mean exceeds
    that threshold are saturated to ``sign(x) * q_max``; all other elements
    are truncated to an integer code and clipped to ``[-q_max, q_max]``,
    where ``q_max = 2 ** (quanti_bits - 1) - 1``.  The codes are then
    multiplied by the scale again, so the output is float32.

    An alternative that was considered is abs-max scaling, i.e.
    ``scale = max(|x|) / q_max`` with no outlier saturation.

    Parameters
    ----------
    alpha : number, default 3
        Multiplier applied to the standard deviation to get the
        saturation threshold.
    quanti_bits : int, default 8
        Bit width of the signed integer code; must be at least 2.

    Raises
    ------
    ValueError
        If ``quanti_bits`` is smaller than 2.
    """

    # Lower bound for the scale so that a constant input (std == 0) does not
    # produce a division by zero followed by an undefined float->int cast.
    _MIN_SCALE = 1e-12

    def __init__(self, alpha=3, quanti_bits=8):
        super(Quantize, self).__init__()
        if quanti_bits < 2:
            raise ValueError(
                "quanti_bits must be >= 2, got {}".format(quanti_bits))
        self._alpha = alpha
        self._quanti_bits = quanti_bits

    def forward(self, is_train, req, in_data, out_data, aux):
        """Write the quantize-dequantize result of ``in_data[0]`` to ``out_data[0]``."""
        data = in_data[0]
        # Largest representable magnitude, e.g. 127 for 8 bits.
        q_max = 2 ** (self._quanti_bits - 1) - 1

        mean = mx.nd.mean(data)
        # alpha * population standard deviation.
        std_v = self._alpha * mx.nd.sqrt(mx.nd.mean(mx.nd.square(data - mean)))

        # +1 / -1 where the element is an outlier (by sign of x), 0 elsewhere.
        outlier = mx.nd.abs(data - mean) > std_v
        mask = mx.nd.cast(mx.nd.sign(data) * outlier, dtype='int32')

        # Floor the scale so that std_v == 0 cannot yield a zero divisor.
        scale = mx.nd.maximum(std_v / float(q_max), self._MIN_SCALE)

        # The outlier test is centred on the mean while the code is not, so
        # with a non-zero mean an unmasked element can still map outside
        # [-q_max, q_max]; clip before the cast to keep the code in range.
        codes = mx.nd.clip(data / scale, -q_max, q_max)
        quantized = mx.nd.cast(codes, dtype='int32')

        # Replace outliers with the saturated code, keep the rest.
        quantized = (1 - mx.nd.abs(mask)) * quantized + mask * q_max

        recover_val = mx.nd.cast(quantized, dtype='float32') * scale
        self.assign(out_data[0], req[0], recover_val)

    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
        """Pass the output gradient to the input unchanged."""
        self.assign(in_grad[0], req[0], out_grad[0])
@mx.operator.register("quantize")  # exposes the op under the name "quantize"
class QuantizeProp(mx.operator.CustomOpProp):
    """Property class describing the ``quantize`` custom operator.

    Declares one input (``data``) and one output (``output``) of identical
    shape, and builds the ``Quantize`` operator instance.
    """

    def __init__(self):
        # The single positional flag is forwarded to CustomOpProp as-is.
        super(QuantizeProp, self).__init__(True)

    def list_arguments(self):
        """Return the names of the operator inputs."""
        return ['data']

    def list_outputs(self):
        """Return the names of the operator outputs."""
        return ['output']

    def infer_shape(self, in_shapes):
        """Derive output shapes from input shapes.

        ``in_shapes`` is a list of shapes (tuples of int).  The output has
        the same shape as the first input.  Returns three sequences:
        input shapes, output shapes and auxiliary-state shapes (none here).
        """
        shape = in_shapes[0]
        return (shape,), (shape,), ()

    def create_operator(self, ctx, in_shapes, in_dtypes):
        """Instantiate the operator that does the actual computation."""
        return Quantize()