Skip to content

Commit 9173dad

Browse files
ElaineBao authored and pengzhao-intel committed Aug 30, 2019
[MKLDNN] fix uint8 batch norm memory misuse (apache#16034)
* fix uint8 bn reorder memory * update resnet152-v2 int8 acc
1 parent 65928b1 commit 9173dad

File tree

4 files changed

+14
-6
lines changed

4 files changed

+14
-6
lines changed
 

‎example/quantization/README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ The following models have been tested on Linux systems. Accuracy is collected on
9797
|[MobileNet 1.0](#5)|[Gluon-CV](https://gluon-cv.mxnet.io/model_zoo/classification.html)|[Validation Dataset](http://data.mxnet.io/data/val_256_q90.rec)|72.23%/90.64%|72.06%/90.53%|
9898
|[MobileNetV2 1.0](#6)|[Gluon-CV](https://gluon-cv.mxnet.io/model_zoo/classification.html)|[Validation Dataset](http://data.mxnet.io/data/val_256_q90.rec)|70.27%/89.62%|69.82%/89.35%|
9999
|[Inception V3](#7)|[Gluon-CV](https://gluon-cv.mxnet.io/model_zoo/classification.html)|[Validation Dataset](http://data.mxnet.io/data/val_256_q90.rec)|77.76%/93.83% |78.05%/93.91% |
100-
|[ResNet152-V2](#8)|[MXNet ModelZoo](http://data.mxnet.io/models/imagenet/resnet/152-layers/)|[Validation Dataset](http://data.mxnet.io/data/val_256_q90.rec)|76.65%/93.07%|76.19%/92.88%|
100+
|[ResNet152-V2](#8)|[MXNet ModelZoo](http://data.mxnet.io/models/imagenet/resnet/152-layers/)|[Validation Dataset](http://data.mxnet.io/data/val_256_q90.rec)|76.65%/93.07%|76.25%/92.89%|
101101
|[Inception-BN](#9)|[MXNet ModelZoo](http://data.mxnet.io/models/imagenet/inception-bn/)|[Validation Dataset](http://data.mxnet.io/data/val_256_q90.rec)|72.28%/90.63%|72.02%/90.53%|
102102
| [SSD-VGG16](#10) | [example/ssd](https://github.com/apache/incubator-mxnet/tree/master/example/ssd) | VOC2007/2012 | 0.8366 mAP | 0.8357 mAP |
103103
| [SSD-VGG16](#10) | [example/ssd](https://github.com/apache/incubator-mxnet/tree/master/example/ssd) | COCO2014 | 0.2552 mAP | 0.253 mAP |

‎example/quantization/imagenet_gen_qsym_mkldnn.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,8 @@ def save_params(fname, arg_params, aux_params, logger=None):
202202
if args.model == 'imagenet1k-resnet-152':
203203
rgb_mean = '0,0,0'
204204
rgb_std = '1,1,1'
205-
excluded_sym_names += ['flatten0']
205+
# stage1_unit1_bn1 & stage4_unit1_bn1 are excluded for the sake of accuracy
206+
excluded_sym_names += ['flatten0', 'stage1_unit1_bn1', 'stage4_unit1_bn1']
206207
if exclude_first_conv:
207208
excluded_sym_names += ['conv0']
208209
elif args.model == 'imagenet1k-inception-bn':

‎src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h

+10-3
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ class MKLDNNBNForward {
184184

185185
template<typename DType>
186186
static MKLDNNBNForward &GetBNForward(const BatchNormParam& param,
187-
const OpContext &ctx, const NDArray &in_data,
187+
const OpContext &ctx, const mkldnn::memory *data_mem,
188188
unsigned flags) {
189189
#if DMLC_CXX11_THREAD_LOCAL
190190
static thread_local std::unordered_map<MKLDNNBNSignature, MKLDNNBNForward, OpHash> fwds;
@@ -194,18 +194,25 @@ static MKLDNNBNForward &GetBNForward(const BatchNormParam& param,
194194
MKLDNNBNSignature key(param);
195195
key.AddSign(ctx.is_train);
196196
key.AddSign(param.use_global_stats);
197-
key.AddSign(in_data);
197+
key.AddSign(*data_mem);
198198

199199
auto it = fwds.find(key);
200200
if (it == fwds.end()) {
201-
auto fwd_pd = _GetFwd(*in_data.GetMKLDNNData(), ctx.is_train,
201+
auto fwd_pd = _GetFwd(*data_mem, ctx.is_train,
202202
(DType) param.eps, flags);
203203
MKLDNNBNForward fwd(fwd_pd, ctx.is_train && !param.use_global_stats);
204204
it = AddToCache(&fwds, key, fwd);
205205
}
206206
return it->second;
207207
}
208208

209+
template<typename DType>
210+
static MKLDNNBNForward &GetBNForward(const BatchNormParam& param,
211+
const OpContext &ctx, const NDArray &in_data,
212+
unsigned flags) {
213+
return GetBNForward<DType>(param, ctx, in_data.GetMKLDNNData(), flags);
214+
}
215+
209216
template <typename DType>
210217
void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam &param,
211218
const std::vector<NDArray> &in_data,

‎src/operator/quantization/mkldnn/mkldnn_quantized_batch_norm.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ static void MKLDNNQuantizedBatchNormForward(const nnvm::NodeAttrs &attrs, const
8080
const float max_abs_output = std::max(std::abs(*min_output_ptr), std::abs(*max_output_ptr));
8181

8282
unsigned flags = mkldnn::use_global_stats | mkldnn::use_scale_shift;
83-
auto &fwd = GetBNForward<float>(param, ctx, data, flags);
83+
auto &fwd = GetBNForward<float>(param, ctx, data_mem, flags);
8484
const mkldnn::memory &weight_mem = fwd.GetWeight();
8585
CHECK_EQ(weight_mem.get_primitive_desc().get_size(), channel_count * sizeof(float) * 2);
8686
float *weight_buf = reinterpret_cast<float *>(weight_mem.get_data_handle());

0 commit comments

Comments
 (0)
Please sign in to comment.