How to store weight and bias as dtype int8 for a HybridBlock

Hi,
I want to use a quantized version of nn.Dense for the forward pass only. Below is the implementation I have come up with:

import mxnet as mx
from mxnet.gluon import nn
from mxnet.gluon.nn import Activation


class QuantizedDense(nn.HybridBlock):
    def __init__(self, units, activation=None, use_bias=True, flatten=True,
                 dtype='float32', weight_initializer=None, bias_initializer='zeros',
                 in_units=0, weight_min_initializer=None, weight_max_initializer=None,
                 bias_min_initializer=None, bias_max_initializer=None, **kwargs):
        super(QuantizedDense, self).__init__(**kwargs)
        with self.name_scope():
            self._units = units
            self._in_units = in_units
            self.no_bias = not use_bias
            self.flatten = flatten
            self.weight_quantize = self.params.get('weight_quantize', shape=(units, in_units),
                                          init=weight_initializer, 
                                          allow_deferred_init=True)
            self.weight_min = self.params.get('weight_min', shape=(1,),
                                         init=weight_min_initializer,
                                         allow_deferred_init=True, dtype=dtype)
            self.weight_max = self.params.get('weight_max', shape=(1,),
                                         init=weight_max_initializer,
                                         allow_deferred_init=True, dtype=dtype)
            if use_bias:
                self.bias_quantize = self.params.get('bias_quantize', shape=(units,),
                                            init=bias_initializer, dtype=dtype,
                                            allow_deferred_init=True)
                self.bias_min = self.params.get('bias_min', shape=(1,),
                                            init=bias_min_initializer,
                                            allow_deferred_init=True, dtype=dtype)
                self.bias_max = self.params.get('bias_max', shape=(1,),
                                            init=bias_max_initializer,
                                            allow_deferred_init=True, dtype=dtype)
            else:
                self.bias_quantize = None
                self.bias_min = None
                self.bias_max = None
            if activation is not None:
                self.act = Activation(activation, prefix=activation+'_')
            else:
                self.act = None
                
    def hybrid_forward(self, F, x, weight_quantize, weight_min, weight_max, bias_quantize=None, 
                       bias_min=None, bias_max=None):
        # Quantize the input activations to uint8 on the fly.
        q_inputs, q_inputs_min, q_inputs_max = F.contrib.quantize(
            data=x,
            min_range=x.min(),
            max_range=x.max(),
            out_type='uint8')
        # Quantized fully connected: uint8 data with int8 weight/bias, int32 output.
        q_output, q_output_min, q_output_max = F.contrib.quantized_fully_connected(
            data=q_inputs,
            weight=weight_quantize.astype('int8'),
            bias=bias_quantize.astype('int8') if bias_quantize is not None else None,
            min_data=q_inputs_min,
            max_data=q_inputs_max,
            min_weight=weight_min,
            max_weight=weight_max,
            min_bias=bias_min,
            max_bias=bias_max,
            num_hidden=self._units,
            no_bias=self.no_bias,
            flatten=self.flatten,
            name='quant_dense_fwd'
        )
        # Requantize the int32 result down to int8.
        q_8_out, q_8_out_min, q_8_out_max = F.contrib.requantize(
            data=q_output.astype('int32'),
            min_range=q_output_min,
            max_range=q_output_max,
            name='requant_fwd',
        )
        # Dequantize back to float for the (optional) activation.
        act = F.contrib.dequantize(
            data=q_8_out,
            min_range=q_8_out_min,
            max_range=q_8_out_max,
            name='dequant_fwd')
        if self.act is not None:
            act = self.act(act)
        return act

This seems to be working. However, I would like to store the weight and bias of the layer as int8, rather than casting them in hybrid_forward via weight_quantize.astype('int8') and bias_quantize.astype('int8'). But if I specify dtype as int8 in __init__, roughly like this:
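
self.weight_quantize = self.params.get('weight_quantize', shape=(units, in_units),
                                       init=weight_initializer, dtype='int8',
                                       allow_deferred_init=True)

I get the following error: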

.....

~/code/mxnet/incubator-mxnet/python/mxnet/ndarray/random.py in _random_helper(random, sampler, params, shape, dtype, ctx, out, kwargs)
     46                 "Distribution parameters must all have the same type, but got " \
     47                 "both %s and %s."%(type(params[0]), type(i))
---> 48         return random(*params, shape=shape, dtype=dtype, ctx=ctx, out=out, **kwargs)
     49 
     50     raise ValueError("Distribution parameters must be either NDArray or numbers, "

........
MXNetError: [11:18:15] src/operator/random/./sample_op.h:754: Check failed: dtype_ok: Output type must be float16, float32, float64: dtype is 5 vs 2 or 0 or 1

What is the correct way of setting the weight and bias in the following code as int8 rather than fp32?

self.weight_quantize = self.params.get('weight_quantize', shape=(units, in_units),
                                          init=weight_initializer, 
                                          allow_deferred_init=True)
self.bias_quantize = self.params.get('bias_quantize', shape=(units,),
                                            init=bias_initializer,
                                            allow_deferred_init=True)

You’re already setting the dtype correctly for the bias parameter in this snippet:

self.bias_quantize = self.params.get('bias_quantize', shape=(units,),
                                     init=bias_initializer, dtype=dtype,
                                     allow_deferred_init=True)

But you don’t pass dtype for the weight parameter:

self.weight_quantize = self.params.get('weight_quantize', shape=(units, in_units),
                                       init=weight_initializer, 
                                       allow_deferred_init=True)

Although, judging from the error, int8 isn’t supported by the random sampling operator that the initializer uses (dtype 5 in the message is int8; only float16, float32 and float64 are allowed), so a random weight initializer will still fail even once dtype is set.
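
A possible (untested) workaround sketch: keep dtype='int8' on weight_quantize, but avoid random initialization entirely, e.g. initialize with zeros and then overwrite the parameter with an offline-quantized copy of a float32 weight via set_data. The sizes below are arbitrary example values:

import mxnet as mx

# Assumes dtype='int8' has been added to the weight_quantize params.get(...) call.
# Random initializers only support float types (hence the error above), so use a
# zero init for the int8 parameter and then load quantized values explicitly.
layer = QuantizedDense(units=16, in_units=8, weight_initializer='zeros')
layer.initialize()

fp32_weight = mx.nd.random.uniform(-1, 1, shape=(16, 8))
q_weight, w_min, w_max = mx.nd.contrib.quantize(
    data=fp32_weight,
    min_range=fp32_weight.min(),
    max_range=fp32_weight.max(),
    out_type='int8')

layer.weight_quantize.set_data(q_weight)  # stored as int8, no astype needed in hybrid_forward
layer.weight_min.set_data(w_min)
layer.weight_max.set_data(w_max)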