Custom operator causes training to hang in multi-GPU training

I wrote a custom operator for 8-bit quantization, and training hangs when I train my network on 4 GPUs. The program only consumes GPU memory and shows 0% GPU utilization in nvidia-smi, so it evidently never runs the training loop. After I changed my settings to use only 1 GPU, it finally worked well.

Here is the op definition:

from __future__ import absolute_import
import mxnet as mx
class Quantize(mx.operator.CustomOp):
    """Quantize the input to signed integer steps, then scale it back.

    The step size is ``alpha * std / (2**(quanti_bits - 1) - 1)``, where
    ``std`` is the standard deviation of all elements of the input.
    Elements whose distance from the mean exceeds ``alpha * std`` are
    replaced by ``sign(x) * (2**(quanti_bits - 1) - 1)`` steps; every other
    element is divided by the step size and cast to int32. The integer
    result is cast to float32 and multiplied by the step size again.

    Parameters
    ----------
    alpha : number, optional
        Multiple of the standard deviation used as the saturation
        threshold. Defaults to 3.
    quanti_bits : int, optional
        Bit width of the signed integer grid. Defaults to 8.

    Raises
    ------
    ValueError
        If ``alpha`` is not positive or ``quanti_bits`` is smaller than 2
        (either would make the step size zero or negative).
    """

    def __init__(self, alpha=3, quanti_bits=8):
        super(Quantize, self).__init__()
        if alpha <= 0:
            raise ValueError("alpha must be positive, got %r" % (alpha,))
        if quanti_bits < 2:
            raise ValueError(
                "quanti_bits must be at least 2, got %r" % (quanti_bits,))
        self._alpha = alpha
        self._quanti_bits = quanti_bits

    def forward(self, is_train, req, in_data, out_data, aux):
        """Write the quantize/de-quantize result of in_data[0] to out_data[0]."""
        data = in_data[0]
        # Largest magnitude on the signed grid, e.g. 127 for 8 bits.
        q_max = 2 ** (self._quanti_bits - 1) - 1

        mean = mx.nd.mean(data)
        centered = data - mean
        threshold = self._alpha * mx.nd.sqrt(mx.nd.mean(mx.nd.square(centered)))

        # +1 / -1 (sign of the raw value) where the element lies outside the
        # threshold around the mean, 0 elsewhere.
        outlier = mx.nd.cast(
            mx.nd.sign(data) * (mx.nd.abs(centered) > threshold),
            dtype='int32')

        # NOTE(review): for a constant input the threshold is 0, so the
        # division below divides by zero -- confirm whether callers can
        # feed such tensors (e.g. zero-initialised weights).
        scale = threshold / float(q_max)
        quantized = mx.nd.cast(data / scale, dtype='int32')
        # Keep regular elements as they are and saturate the outliers.
        quantized = (1 - mx.nd.abs(outlier)) * quantized + outlier * q_max
        recovered = mx.nd.cast(quantized, dtype='float32') * scale

        self.assign(out_data[0], req[0], recovered)

    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
        """Pass the output gradient through unchanged as the input gradient."""
        self.assign(in_grad[0], req[0], out_grad[0])
        
@mx.operator.register("quantize")  # exposed under the op_type "quantize"
class QuantizeProp(mx.operator.CustomOpProp):
    """Property class describing the ``quantize`` custom operator.

    It declares one input (``data``) and one output (``output``) of the
    same shape, and builds a ``Quantize`` instance as the operator.
    """

    def __init__(self):
        # The backward pass reads out_grad, so the top gradient is required.
        super(QuantizeProp, self).__init__(need_top_grad=True)

    def list_arguments(self):
        """Return the names of the operator inputs."""
        return ['data']

    def list_outputs(self):
        """Return the names of the operator outputs."""
        return ['output']

    def infer_shape(self, in_shapes):
        """Derive the output shape from the input shapes.

        in_shapes : list of shape. Each shape is a tuple of int.

        Returns three sequences: input shapes, output shapes and auxiliary
        state shapes. The output takes the shape of the first input and
        there are no auxiliary states.
        """
        shape = in_shapes[0]
        return (shape,), (shape,), ()

    def create_operator(self, ctx, in_shapes, in_dtypes):
        """Instantiate the operator; ctx, shapes and dtypes are not used."""
        return Quantize()