Quantization questions

@pengzhao-intel @xinyu-intel @anirudh2290

I have some questions regarding quantization.

I studied this blog post: https://medium.com/apache-mxnet/model-quantization-for-production-level-neural-network-inference-f54462ebba05 and related code.

  1. I was under the impression that int8 quantization was possible on CPU; however, I am now finding out that only uint8 is possible on CPU. Is that correct? Is there a plan to implement int8 with MKLDNN?

  2. When I perform int8 quantization and run inference on GPU, I get results very similar to the fp32 version of the model. When I perform uint8 quantization, my results are completely out of whack, even when I exclude every symbol except the convolutions. This is when I use calib_mode='none'. How do I make sure that my model outputs results similar to the fp32 one when using uint8 quantization?

  3. What is the role of MXNET_SUBGRAPH_BACKEND=MKLDNN? What does it do, and how does it relate to quantization? (edit: found the answer thanks to @anirudh2290: https://mxnet.incubator.apache.org/versions/master/tutorials/mkldnn/operator_list.html. I would suggest adding this link to the quantization.py file to help people find it.)

  4. What is the role of sym_q.get_backend_symbol('MKLDNN_QUANTIZE')? What does it do? I find it a bit confusing because my understanding is that MKLDNN_QUANTIZE actually does operator fusion and not quantization.

  5. Am I correct to assume that even if some symbols appear to be split, they are actually fused? See https://cwiki.apache.org/confluence/display/MXNET/MXNet+Graph+Optimization+and+Quantization+based+on+subgraph+and+MKL-DNN

For example, here I ran the quantization step and applied MKLDNN_QUANTIZE, but it looks like this conv + batchnorm has not been fused. Is that normal?

  6. Just a suggestion: it would be great if quantization were more Gluon friendly! I'll modify the SymbolBlock to at least be able to load a quantized model from scratch. I also think the calibration step could be implemented differently; for example, it might be simpler to pass a "Calibrator" instance to the quantization function that would hold all the calibration parameters and take care of running inference passes, and this Calibrator could be subclassed to support different types of models with different data types, iterators, etc. A rough sketch of what I have in mind is below.
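
Something along these lines (all the names below are made up, this is just to illustrate the idea):

```python
# Purely illustrative sketch of the Calibrator idea; none of these names exist in MXNet today.
class Calibrator:
    """Holds the calibration settings and knows how to feed data to the model."""

    def __init__(self, calib_mode='entropy', num_calib_batches=5):
        self.calib_mode = calib_mode
        self.num_calib_batches = num_calib_batches

    def data_iter(self):
        """Yield input batches; subclasses override this for images, text, multiple inputs, etc."""
        raise NotImplementedError

    def run(self, module):
        """Run forward passes so the quantizer can collect per-layer statistics."""
        for i, batch in enumerate(self.data_iter()):
            if i >= self.num_calib_batches:
                break
            module.forward(batch, is_train=False)

# Hypothetical usage:
# quantize_model(sym, arg_params, aux_params, calibrator=MyImageCalibrator(...))
```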

Thanks for your work on quantization!

Symbol of the quantized resnet18:

Can you share your whole symbol JSON after the transformation?

@ThomasDelteil It could be that this is a bug. For example, when I examine the graph I see that fusion is being done on the operators (sg_mkldnn_conv_bn_act_0) before the quantized part of the graph. Can you confirm that you did not set any env variables and just called get_backend_symbol("MKLDNN_QUANTIZE")?

I ran the following script from example/quantization: python imagenet_gen_qsym_mkldnn.py --model=resnet18_v1 --num-calib-batches=5 --calib-mode=naive. I was able to get fused graphs for all such patterns. I need to see what is different between what you are doing and this script.

@ThomasDelteil thanks a lot for raising these questions :slight_smile:

  1. INT8 is supported since PR#13697.

  2. It is expected that GPU INT8 performance is slower than FP32; see the comment.
    Are the following questions about GPU quantization?

  3. Sure, we will try to add that to quantization.py.

  4. Previously, INT8 would inherit the fused graph from FP32 (enabled by MXNET_SUBGRAPH_BACKEND=MKLDNN), but the fusion patterns of FP32 and INT8 have become more and more divergent and complex, so it is hard to handle both cases in one path.
    Thus, we separated the FP32 and INT8 fusion in PR#4819.

    More details below:

    a) Graph fusion for INT8 happens here, which is different from the FP32 fusion (though not by much for now).

    b) Quantization by a separate API.
    So you are right that MKLDNN_QUANTIZE itself does not do quantization.

    c) Post-fusion by applying MKLDNN_QUANTIZE again, here.

    Yes, it is a little confusing and we are considering further improvements. A rough sketch of the full a/b/c flow is shown after this list.

  5. Thanks to @anirudh2290 for the answer.

  6. Yes, we are working on Gluon models now in GluonCV and GluonNLP, but not all of them are complete.
    Please try one of our internal improvements, https://github.com/xinyu-intel/gluon-cv/pull/1, and we welcome further suggestions on how to make the Gluon interface friendlier.
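
To make the a/b/c flow above concrete, here is a rough sketch in symbolic mode (the model prefix is a placeholder and the exact quantize_model arguments may differ between MXNet versions):

```python
import mxnet as mx
from mxnet.contrib.quantization import quantize_model

# Load an FP32 symbolic model (prefix/epoch are placeholders).
sym, arg_params, aux_params = mx.model.load_checkpoint('resnet18_v1', 0)

# a) graph fusion for INT8
sym = sym.get_backend_symbol('MKLDNN_QUANTIZE')

# b) quantization by the separate API
qsym, qarg_params, aux_params = quantize_model(
    sym=sym, arg_params=arg_params, aux_params=aux_params,
    ctx=mx.cpu(), calib_mode='none', quantized_dtype='int8')

# c) post-fusion with MKLDNN_QUANTIZE again
qsym = qsym.get_backend_symbol('MKLDNN_QUANTIZE')
```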

Thanks again for your great feedback.

@ThomasDelteil We are working on Gluon quantization; you can try https://github.com/xinyu-intel/gluon-cv/pull/1 along with https://github.com/xinyu-intel/incubator-mxnet/pull/2. Suggestions are very welcome :)

Thanks @pengzhao-intel for the detailed answer. Let me clarify some points:

  1. Thanks, I got it working with @anirudh2290's help; indeed it requires fusing operators with MKLDNN_QUANTIZE both before and after quantization.

  2. I have tried both GPU and CPU quantization; most of my questions are about CPU quantization. As a side note, my tests have shown that on GPU at higher batch sizes, int8 quantization is actually faster than fp32.

  3. Thanks

  4. OK, I think I got the flow now. Side question: how can one perform fp32 fusion right now? As you mentioned, it is quite confusing because the naming does not match the function and because of the use of environment variables, which can be hard to use in managed environments (Jupyter notebooks, Google Colab, restricted enterprise environments, etc.). I will provide some suggestions on what I think could be done to improve the UX at the bottom of this post.

  5. Thanks, indeed the double MKLDNN_QUANTIZE fixed it.

  6. Thanks, I have looked at it and will comment on the PR.

New Questions

  7. Related to question 2 of my first post: when I perform quantization without calibration, the results are completely different from the non-quantized model. Is this expected? Is calibration necessary to get acceptable results?

  8. With calibration and logging enabled, I sometimes get nan values for min_divergence; is that expected?

INFO:root:layer=resnetv10_stage4_batchnorm3_gamma, min_val=0.110424, max_val=0.398399, min_divergence=nan, optimal_threshold=0.025445
INFO:root:layer=sg_mkldnn_conv_bn_act_18_gamma, min_val=0.110424, max_val=0.398399, min_divergence=nan, optimal_threshold=0.025445
  9. Why are the parameters stored in fp32 instead of int8? An advantage of int8 quantization should be the lower memory footprint of the model. For the non-fused operators, for example with GPU int8 quantization, the parameters are stored as int8, which makes the final parameters file 12MB vs 45MB for the MKLDNN version.
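
For reference, a quick way to check which parameters end up as int8 in a saved file (the file name below is just a placeholder):

```python
import mxnet as mx

# Inspect the dtypes of the saved parameters of a quantized model.
params = mx.nd.load('model-quantized-0000.params')
for name, arr in params.items():
    print(name, arr.dtype, arr.shape)
```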

Suggestion / Feedback

  • I would stay away from environment variables as a means to control functionality such as quantization:

    • They are hard to discover; MXNET_SUBGRAPH_BACKEND=MKLDNN is not documented in the quantization API right now.
    • The current naming is confusing and does not seem to relate to quantization.
    • They can be hard to set in controlled environments like Jupyter, Google Colab or restricted enterprise environments.
    • They cannot be coded in scripts for easy replication of results.
  • Currently, quantization and fusion for MKLDNN are too intertwined, and this is not clear in the documentation. For example, you cannot do int8 inference in plain MKLDNN convolutions, only in MKLDNN fused convolutions, and you need to re-fuse the graph after having fused it and then quantized it. Maybe add such information to this error message:

MXNetError: [21:24:54] src/operator/quantization/mkldnn/mkldnn_quantized_conv.cc:41: Check failed: in_data[0].dtype() == mshadow::kUint8 (5 vs. 3) : mkldnn_quantized_conv op only supports uint8 as input type
  • I suggest instead using a well-crafted and well-named API that hides the implementation details from the user. Instead of a generic quantize_model, maybe a quantize_model_mkldnn and quantize_model_cudnn that encapsulate the necessary steps, so that the user does not need to manually call MKLDNN_QUANTIZE, set the backend env variable, etc. (see the sketch after this list).

  • Rename MKLDNN_QUANTIZE to MKLDNN_FUSE_INT8. Add some documentation to get_backend_symbol to clarify what it does, and list the available backend names.

  • As mentioned in my first post, the current calibration mechanism is too strict about how the data should look. A different mechanism, maybe using injection of a Calibrator or something that allows more flexibility, would let quantization cover more use cases. It also forces the use of label_names right now, which should not be required.

  • It is quite obscure which symbols should be excluded and which should be quantized; documentation on that or a pre-defined list would be good.

  • Current calibration is extremely slow and single-threaded (20min+ for resnet18 with entropy calibration). It would be great to take advantage of multiple CPUs by using process pools, or to vectorize the operations where possible. The calibration loop seems to have quadratic complexity in the number of bins and requires 32M iterations per layer. See here and here.
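
To illustrate the kind of API I mean, a hypothetical quantize_model_mkldnn could simply bundle the fuse / quantize / re-fuse steps behind one call (nothing below exists today, this is just a sketch):

```python
import mxnet as mx
from mxnet.contrib.quantization import quantize_model

def quantize_model_mkldnn(sym, arg_params, aux_params, **kwargs):
    """Hypothetical wrapper: quantize a symbolic model for MKLDNN int8/uint8 inference."""
    fused = sym.get_backend_symbol('MKLDNN_QUANTIZE')             # pre-fusion
    qsym, qarg_params, qaux_params = quantize_model(
        sym=fused, arg_params=arg_params, aux_params=aux_params,
        ctx=mx.cpu(), **kwargs)                                   # quantization
    qsym = qsym.get_backend_symbol('MKLDNN_QUANTIZE')             # post-fusion
    return qsym, qarg_params, qaux_params
```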

Thanks for your hard work on the quantization front!

@ThomasDelteil thanks for the really great suggestions. I will discuss the details internally and get back to you.

@ThomasDelteil very good suggestions for the quantization flow. Please see my comments below; further suggestions are highly appreciated.

-4. In symbolic mode, you can still use MXNET_SUBGRAPH_BACKEND=MKLDNN or the API interface qsym.get_backend_symbol('MKLDNN'); a rough sketch of both options is below.
In Gluon mode, it cannot fuse the FP32 graph because there is no graph, except by reloading the static graph through the cached op. Any suggestion for the Gluon interface?
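
For example (symbolic mode only; the exact names are as I understand them and may differ between versions):

```python
import os
import mxnet as mx

# Option 1: environment variable, read when the executor is bound,
# so it needs to be set before bind/simple_bind is called.
os.environ['MXNET_SUBGRAPH_BACKEND'] = 'MKLDNN'

# Option 2: explicit API call on the FP32 symbol (prefix/epoch are placeholders).
sym, arg_params, aux_params = mx.model.load_checkpoint('resnet18_v1', 0)
fused_sym = sym.get_backend_symbol('MKLDNN')
```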

-7. Yes, per-channel quantization is enabled by default in the calibration stage, but it is not available for online calibration, where the min/max are calculated on the fly at the tensor level and the performance would be poor because of the extra memory accesses needed to get the min/max.
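
To illustrate the difference between tensor-level and per-channel ranges (plain numpy sketch, not the actual implementation):

```python
import numpy as np

def per_tensor_minmax(w):
    """One (min, max) pair for the whole tensor."""
    return float(w.min()), float(w.max())

def per_channel_minmax(w):
    """One (min, max) pair per output channel, assuming layout (out_channels, ...)."""
    flat = w.reshape(w.shape[0], -1)
    return flat.min(axis=1), flat.max(axis=1)
```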

-8. Yes, but those results are meaningless; only the calibration results of the input data are useful. Currently, scalar inputs such as gamma are also calibrated because we do not know which one is the real input data, but this does not affect the final accuracy. As a next step, we will add a name attribute through NNVM for some ops to avoid this situation as much as possible.

-9. The weights of FC are saved as INT8 now, which is the most memory-consuming part of the NN, but the weights of convolutions are still FP32. For int8 input to a convolution, the weight is padded to store the extra offset information of the int8 input for better performance; as a result, the loader cannot handle a data size bigger than the size calculated from the shape information. Do you have any suggestions? @ThomasDelteil @anirudh2290

Reply to your "Suggestion / Feedback":

  1. I would stay away from environment variables as a means to control functionality such as quantization:

We are still lacking documentation :frowning: The blog post (here) is a starting point for end users; documentation for developers is WIP.
In general, the quantization flow has only just been enabled and is not mature enough yet. With the 2nd generation of Xeon Scalable processors launching on AWS (C5.12xlarge and C5.24xlarge), we are actively improving the quality and usability of the quantization flow.

  2. Currently quantization and fusion for MKLDNN are too intertwined but this is not clear in the documentation. For example you cannot do int8 inference in MKLDNN convolutions, only in MKLDNN fused convolutions. You need to re-fuse the graph after having fused it and then quantized it. Maybe add such information in this error message:

Agree. As a next step, we will enhance the MKLDNN convolution to support INT8 inputs.
The background is that the MKLDNN convolution is a stateless API right now, and we do not have a place to cache the temporary INT8 weights (with offset padding). We could convert the weights from FP32 to INT8 every time, but then there would be no benefit. We will change the MKLDNN convolution to a stateful API in the next version.

  3. I suggest instead to use a well-crafted and well-named API that hides the implementation details from the user. Instead of a generic quantize_model, maybe a quantize_model_mkldnn and quantize_model_cudnn that encapsulate the necessary steps, so that the user does not need to manually call MKLDNN_QUANTIZE, set the backend env variable, etc.

Yes, good suggestion; we will provide these APIs.

  4. "As mentioned in my first post, the current calibration mechanism is too strict about how the data should look; a different mechanism, maybe using injection of a Calibrator or something that allows more flexibility, would allow more use cases to be covered by quantization. Also it forces using label_names right now, which should not be required."

@xinyu-intel is working on this part and the new API will be more flexible about the data. I will keep you in the loop on our development.

  5. It is quite obscure which symbols should be excluded and which should be quantized; documentation on that or a pre-defined list would be good.

We are working on this and will improve the flow and make the script easier to use.

  6. Current calibration is extremely slow and single-threaded (20min+ for resnet18 with entropy calibration); it would be great to take advantage of multiple CPUs by using process pools or vectorizing the operations where possible. The calibration loop seems to have quadratic complexity in the number of bins and requires 32M iterations per layer. See here and here.

The entropy algorithm is implemented in numpy, so it is really slow. We plan to implement an MXNet op for this, which will be much faster.
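
For reference, even in numpy the inner KL-divergence computation can be fully vectorized instead of looping over bins in Python (illustrative sketch only, not the actual MXNet implementation):

```python
import numpy as np

def kl_divergence(p_hist, q_hist):
    """KL(p || q) between two histograms, vectorized over the bins."""
    p = p_hist.astype(np.float64)
    q = q_hist.astype(np.float64)
    p /= p.sum()
    q /= q.sum()
    # clip q to avoid division by zero where the quantized histogram is empty
    q = np.clip(q, 1e-12, None)
    mask = p > 0
    return float(np.sum(p[mask] * np.log(p[mask] / q[mask])))
```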
