Skip to content

Commit d677d1a

Browse files
authoredJul 13, 2019
FP16 Support for C Predict API (apache#15245)
* Upgrade archive utility and add back FC improvement This reverts commit 6543488. * Change permissions for Ubuntu AR * Extract and cd into binutils dir * Allow AR path to be chosen by user * Add AR path to build * Fix AR paths * Revert AR flag in makefile * Build from source doc updated * Commit for C Predict API * Add FP16 predict support * Add Test Predictor fixes * Add test for predictor * Cleanup fixes * Fixes * Add support for forward pass only for gpu * Fix Reshape, move test to gpu * Add monitor callback for C Predict API * Add tests, default dtype and set_monitor_callback * Improve error * Fix c_str_array
1 parent cbb6f7f commit d677d1a

File tree

4 files changed

+454
-10
lines changed

4 files changed

+454
-10
lines changed
 

‎amalgamation/python/mxnet_predict.py

+126-5
Original file line numberDiff line numberDiff line change
@@ -25,17 +25,77 @@
2525

2626
import os
2727
import sys
28+
from array import array
2829
import ctypes
2930
import logging
3031
import numpy as np
3132

33+
# pylint: disable= no-member
34+
# Lookup from numpy scalar type (None meaning "unspecified") to the integer
# dtype code exchanged with the C predict API.
_DTYPE_NP_TO_MX = {
    None: -1,
    np.float32: 0,
    np.float64: 1,
    np.float16: 2,
    np.uint8: 3,
    np.int32: 4,
    np.int8: 5,
    np.int64: 6,
}

# Reverse lookup (integer dtype code -> numpy scalar type), derived from the
# forward table so the two can never disagree.
_DTYPE_MX_TO_NP = {code: np_type for np_type, code in _DTYPE_NP_TO_MX.items()}
55+
3256
__all__ = ["Predictor", "load_ndarray_file"]
3357

3458
if sys.version_info[0] == 3:
    def py_str(x):
        """Decode UTF-8 encoded bytes into a Python string."""
        return x.decode('utf-8')

    def c_str_array(strings):
        """Build a ctypes ``const char **`` from Python strings.

        Parameters
        ----------
        strings : list of string
            Python strings; each one is UTF-8 encoded before being stored.

        Returns
        -------
        (ctypes.c_char_p * len(strings))
            An array usable wherever the C API expects ``const char **``.
        """
        encoded = [item.encode('utf-8') for item in strings]
        return (ctypes.c_char_p * len(strings))(*encoded)

else:
    def py_str(x):
        """Return the string unchanged (no decoding on this Python version)."""
        return x

    def c_str_array(strings):
        """Build a ctypes ``const char **`` from Python strings.

        Parameters
        ----------
        strings : list of strings
            Python strings, stored as-is.

        Returns
        -------
        (ctypes.c_char_p * len(strings))
            An array usable wherever the C API expects ``const char **``.
        """
        return (ctypes.c_char_p * len(strings))(*strings)
98+
3999

40100
def c_str(string):
41101
""""Convert a python string to C string."""
@@ -48,6 +108,11 @@ def c_array(ctype, values):
48108
"""Create ctypes array from a python array."""
49109
return (ctype * len(values))(*values)
50110

111+
def c_array_buf(ctype, buf):
    """Wrap a writable Python buffer as a ctypes array of ``ctype``.

    The array has ``len(buf)`` elements and is created with ``from_buffer``,
    so it shares memory with ``buf`` instead of copying it.
    """
    array_type = ctype * len(buf)
    return array_type.from_buffer(buf)
114+
115+
51116

52117
def _find_lib_path():
53118
"""Find mxnet library."""
@@ -87,9 +152,18 @@ def _check_call(ret):
87152
if ret != 0:
88153
raise RuntimeError(py_str(_LIB.MXGetLastError()))
89154

155+
156+
def _monitor_callback_wrapper(callback):
    """Adapt a two-argument user callback to the three-argument C signature.

    The returned function accepts ``(name, array, extra)`` and forwards only
    ``name`` and ``array`` to ``callback``; the third argument is ignored and
    nothing is returned.
    """
    def callback_handle(name, array, _unused):
        """ ctypes function """
        callback(name, array)

    return callback_handle
162+
90163
_LIB = _load_lib()
91164
# type definitions
92165
mx_uint = ctypes.c_uint
166+
mx_int = ctypes.c_int
93167
mx_float = ctypes.c_float
94168
mx_float_p = ctypes.POINTER(mx_float)
95169
PredictorHandle = ctypes.c_void_p
@@ -116,10 +190,13 @@ class Predictor(object):
116190
117191
dev_id : int, optional
118192
The device id of the predictor.
193+
194+
type_dict : Dict of str->numpy.dtype
195+
Input type dictionary, name->dtype
119196
"""
120197
def __init__(self, symbol_file,
121198
param_raw_bytes, input_shapes,
122-
dev_type="cpu", dev_id=0):
199+
dev_type="cpu", dev_id=0, type_dict=None):
123200
dev_type = devstr2type[dev_type]
124201
indptr = [0]
125202
sdata = []
@@ -133,15 +210,38 @@ def __init__(self, symbol_file,
133210
handle = PredictorHandle()
134211
param_raw_bytes = bytearray(param_raw_bytes)
135212
ptr = (ctypes.c_char * len(param_raw_bytes)).from_buffer(param_raw_bytes)
136-
_check_call(_LIB.MXPredCreate(
213+
214+
# data types
215+
num_provided_arg_types = 0
216+
# provided type argument names
217+
provided_arg_type_names = ctypes.POINTER(ctypes.c_char_p)()
218+
# provided types
219+
provided_arg_type_data = ctypes.POINTER(mx_uint)()
220+
if type_dict is not None:
221+
provided_arg_type_names = []
222+
provided_arg_type_data = []
223+
for k, v in type_dict.items():
224+
v = np.dtype(v).type
225+
if v in _DTYPE_NP_TO_MX:
226+
provided_arg_type_names.append(k)
227+
provided_arg_type_data.append(_DTYPE_NP_TO_MX[v])
228+
num_provided_arg_types = mx_uint(len(provided_arg_type_names))
229+
provided_arg_type_names = c_str_array(provided_arg_type_names)
230+
provided_arg_type_data = c_array_buf(ctypes.c_int, array('i', provided_arg_type_data))
231+
232+
_check_call(_LIB.MXPredCreateEx(
137233
c_str(symbol_file),
138234
ptr, len(param_raw_bytes),
139235
ctypes.c_int(dev_type), ctypes.c_int(dev_id),
140236
mx_uint(len(indptr) - 1),
141237
c_array(ctypes.c_char_p, keys),
142238
c_array(mx_uint, indptr),
143239
c_array(mx_uint, sdata),
240+
num_provided_arg_types,
241+
provided_arg_type_names,
242+
provided_arg_type_data,
144243
ctypes.byref(handle)))
244+
self.type_dict = type_dict
145245
self.handle = handle
146246

147247
def __del__(self):
@@ -160,10 +260,18 @@ def forward(self, **kwargs):
160260
>>> predictor.forward(data=mydata)
161261
>>> out = predictor.get_output(0)
162262
"""
263+
if self.type_dict and len(self.type_dict) != len(kwargs.items()):
264+
raise ValueError("number of kwargs should be same as len of type_dict" \
265+
"Please check your forward pass inputs" \
266+
"or type_dict passed to Predictor instantiation")
267+
163268
for k, v in kwargs.items():
164269
if not isinstance(v, np.ndarray):
165270
raise ValueError("Expect numpy ndarray as input")
166-
v = np.asarray(v, dtype=np.float32, order='C')
271+
if self.type_dict and k in self.type_dict:
272+
v = np.asarray(v, dtype=self.type_dict[k], order='C')
273+
else:
274+
v = np.asarray(v, dtype=np.float32, order='C')
167275
_check_call(_LIB.MXPredSetInput(
168276
self.handle, c_str(k),
169277
v.ctypes.data_as(mx_float_p),
@@ -218,18 +326,30 @@ def get_output(self, index):
218326
"""
219327
pdata = ctypes.POINTER(mx_uint)()
220328
ndim = mx_uint()
329+
out_type = mx_int()
221330
_check_call(_LIB.MXPredGetOutputShape(
222331
self.handle, index,
223332
ctypes.byref(pdata),
224333
ctypes.byref(ndim)))
334+
_check_call(_LIB.MXPredGetOutputType(
335+
self.handle, index,
336+
ctypes.byref(out_type)))
225337
shape = tuple(pdata[:ndim.value])
226-
data = np.empty(shape, dtype=np.float32)
338+
data = np.empty(shape, dtype=_DTYPE_MX_TO_NP[out_type.value])
227339
_check_call(_LIB.MXPredGetOutput(
228340
self.handle, mx_uint(index),
229341
data.ctypes.data_as(mx_float_p),
230342
mx_uint(data.size)))
231343
return data
232344

345+
def set_monitor_callback(self, callback, monitor_all=False):
    """Install a monitor callback on this predictor.

    Parameters
    ----------
    callback : callable
        Called as ``callback(name, array)`` with the values the C library
        delivers through a ``(char*, void*, void*)`` function pointer.
    monitor_all : bool, optional
        Forwarded to ``MXPredSetMonitorCallback`` as a C int.
    """
    c_callback_type = ctypes.CFUNCTYPE(None, ctypes.c_char_p, ctypes.c_void_p, ctypes.c_void_p)
    # The ctypes function object is stored on the instance; ctypes requires the
    # wrapper to stay referenced for as long as C code may call it.
    self._monitor_callback = c_callback_type(_monitor_callback_wrapper(callback))
    status = _LIB.MXPredSetMonitorCallback(
        self.handle,
        self._monitor_callback,
        None,
        ctypes.c_int(monitor_all))
    _check_call(status)
352+
233353

234354
def load_ndarray_file(nd_bytes):
235355
"""Load ndarray file and return as list of numpy array.
@@ -273,4 +393,5 @@ def load_ndarray_file(nd_bytes):
273393
if len(keys) == 0 or len(keys[0]) == 0:
274394
return arrs
275395
else:
276-
return {keys[i] : arrs[i] for i in range(len(keys))}
396+
return {keys[i] : arrs[i] for i in range(len(keys))
397+
}

‎include/mxnet/c_predict_api.h

+65
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,12 @@ typedef float mx_float;
4949
typedef void *PredictorHandle;
5050
/*! \brief handle to NDArray list */
5151
typedef void *NDListHandle;
52+
/*! \brief handle to NDArray */
53+
typedef void *NDArrayHandle;
54+
/*! \brief callback used to add monitoring to nodes in the graph */
55+
typedef void (*PredMonitorCallback)(const char*,
56+
NDArrayHandle,
57+
void*);
5258

5359
/*!
5460
* \brief Get the last error that happened.
@@ -85,6 +91,44 @@ MXNET_DLL int MXPredCreate(const char* symbol_json_str,
8591
const mx_uint* input_shape_data,
8692
PredictorHandle* out);
8793

94+
/*!
95+
* \brief create a predictor
96+
* \param symbol_json_str The JSON string of the symbol.
97+
* \param param_bytes The in-memory raw bytes of parameter ndarray file.
98+
* \param param_size The size of parameter ndarray file.
99+
* \param dev_type The device type, 1: cpu, 2: gpu
100+
* \param dev_id The device id of the predictor.
101+
* \param num_input_nodes Number of input nodes to the net.
102+
* For feedforward net, this is 1.
103+
* \param input_keys The name of the input argument.
104+
* For feedforward net, this is {"data"}
105+
* \param input_shape_indptr Index pointer of shapes of each input node.
106+
* The length of this array = num_input_nodes + 1.
107+
* For feedforward net that takes 4 dimensional input, this is {0, 4}.
108+
* \param input_shape_data A flattened data of shapes of each input node.
109+
* For feedforward net that takes 4 dimensional input, this is the shape data.
110+
* \param num_provided_arg_dtypes
111+
* The length of provided_arg_dtypes.
112+
* \param provided_arg_dtype_names
113+
* The names of the args for which dtypes are provided.
114+
* \param provided_arg_dtypes
115+
* The dtype (as an integer type code) provided for each name in provided_arg_dtype_names.
116+
* \param out The created predictor handle.
117+
* \return 0 when success, -1 when failure.
118+
*/
119+
MXNET_DLL int MXPredCreateEx(const char* symbol_json_str,
120+
const void* param_bytes,
121+
int param_size,
122+
int dev_type, int dev_id,
123+
const mx_uint num_input_nodes,
124+
const char** input_keys,
125+
const mx_uint* input_shape_indptr,
126+
const mx_uint* input_shape_data,
127+
const mx_uint num_provided_arg_dtypes,
128+
const char** provided_arg_dtype_names,
129+
const int* provided_arg_dtypes,
130+
PredictorHandle* out);
131+
88132
/*!
89133
* \brief create a predictor with customized outputs
90134
* \param symbol_json_str The JSON string of the symbol.
@@ -186,6 +230,18 @@ MXNET_DLL int MXPredGetOutputShape(PredictorHandle handle,
186230
mx_uint index,
187231
mx_uint** shape_data,
188232
mx_uint* shape_ndim);
233+
234+
/*!
235+
* \brief Get the dtype of output node.
236+
* The returned data type is only valid before next call to MXPred function.
237+
* \param handle The handle of the predictor.
238+
* \param out_index The index of the output node, set to 0 if there is only one output.
239+
* \param out_dtype The dtype of the output node
240+
*/
241+
MXNET_DLL int MXPredGetOutputType(PredictorHandle handle,
242+
mx_uint out_index,
243+
int* out_dtype);
244+
189245
/*!
190246
* \brief Set the input data of predictor.
191247
* \param handle The predictor handle.
@@ -269,6 +325,15 @@ MXNET_DLL int MXNDListGet(NDListHandle handle,
269325
const mx_float** out_data,
270326
const mx_uint** out_shape,
271327
mx_uint* out_ndim);
328+
329+
/*!
330+
* \brief set a call back to notify the completion of operation and allow for
331+
* additional monitoring
332+
*/
333+
MXNET_DLL int MXPredSetMonitorCallback(PredictorHandle handle,
334+
PredMonitorCallback callback,
335+
void* callback_handle,
336+
bool monitor_all);
272337
/*!
273338
* \brief Free a MXAPINDList
274339
* \param handle The handle of the MXAPINDList.

‎src/c_api/c_predict_api.cc

+135-5
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ struct MXAPIPredictor {
4747
std::vector<NDArray> aux_arrays;
4848
// output shapes
4949
mxnet::ShapeVector out_shapes;
50+
// output types
51+
nnvm::DTypeVector out_dtypes;
52+
5053
// uint32_t buffer for output shapes
5154
std::vector<uint32_t> out_shapes_buffer;
5255
// key to arguments
@@ -88,7 +91,7 @@ int _CreatePartialOut(const char* symbol_json_str,
8891
const void* param_bytes,
8992
int param_size,
9093
int dev_type, int dev_id,
91-
mx_uint num_input_nodes,
94+
const mx_uint num_input_nodes,
9295
const char** input_keys,
9396
const mx_uint* input_shape_indptr,
9497
const mx_uint* input_shape_data,
@@ -97,6 +100,9 @@ int _CreatePartialOut(const char* symbol_json_str,
97100
// This is used for parallel inference.
98101
int num_threads,
99102
bool lazy,
103+
const mx_uint num_provided_arg_dtypes,
104+
const char** provided_arg_dtype_names,
105+
const int* provided_arg_dtypes,
100106
PredictorHandle* out) {
101107
using nnvm::Symbol;
102108

@@ -135,6 +141,7 @@ int _CreatePartialOut(const char* symbol_json_str,
135141

136142
// load the parameters
137143
std::unordered_map<std::string, NDArray> arg_params, aux_params;
144+
std::unordered_map<std::string, int> arg_types, aux_types;
138145
{
139146
std::unordered_set<std::string> arg_names, aux_names;
140147
std::vector<std::string> arg_names_vec = sym.ListInputNames(Symbol::kReadOnlyArgs);
@@ -156,12 +163,23 @@ int _CreatePartialOut(const char* symbol_json_str,
156163
std::string name(names[i].c_str() + 4);
157164
if (aux_names.count(name) != 0) {
158165
aux_params[name] = data[i];
166+
aux_types[name] = data[i].dtype();
159167
}
160168
}
161169
if (!strncmp(names[i].c_str(), "arg:", 4)) {
162170
std::string name(names[i].c_str() + 4);
163171
if (arg_names.count(name) != 0) {
164172
arg_params[name] = data[i];
173+
arg_types[name] = data[i].dtype();
174+
}
175+
}
176+
}
177+
178+
if (num_provided_arg_dtypes > 0) {
179+
for (mx_uint i = 0; i < num_provided_arg_dtypes; ++i) {
180+
if (aux_types.count(provided_arg_dtype_names[i]) == 0 &&
181+
arg_types.count(provided_arg_dtype_names[i]) == 0) {
182+
arg_types[provided_arg_dtype_names[i]] = provided_arg_dtypes[i];
165183
}
166184
}
167185
}
@@ -179,6 +197,7 @@ int _CreatePartialOut(const char* symbol_json_str,
179197
mxnet::ShapeVector out_shapes(sym.ListOutputNames().size());
180198
mxnet::ShapeVector aux_shapes(aux_names.size());
181199
mxnet::ShapeVector arg_shapes;
200+
nnvm::DTypeVector result_arg_types, result_out_types, result_aux_types;
182201
std::unordered_map<std::string, size_t> key2arg;
183202
for (size_t i = 0; i < arg_names.size(); ++i) {
184203
std::string key = arg_names[i];
@@ -187,21 +206,46 @@ int _CreatePartialOut(const char* symbol_json_str,
187206

188207
try {
189208
mxnet::ShapeVector in_shapes;
209+
nnvm::DTypeVector in_types;
190210
for (std::string key : sym.ListInputNames(Symbol::kAll)) {
191211
if (known_shape.count(key) != 0) {
192212
in_shapes.push_back(known_shape[key]);
193213
} else {
194214
in_shapes.emplace_back();
195215
}
196216
}
217+
218+
for (std::string key : sym.ListInputNames(Symbol::kAll)) {
219+
if (arg_types.count(key) != 0) {
220+
in_types.push_back(arg_types[key]);
221+
} else if (aux_types.count(key) != 0) {
222+
in_types.push_back(aux_types[key]);
223+
} else {
224+
// if key not in arg_types or aux_types set to FP32
225+
in_types.push_back(0);
226+
}
227+
}
197228
nnvm::Graph g; g.outputs = sym.outputs;
198229
g = mxnet::exec::InferShape(std::move(g), std::move(in_shapes), "__shape__");
230+
g = mxnet::exec::InferType(std::move(g), std::move(in_types), "__dtype__");
199231
bool infer_complete = (g.GetAttr<size_t>("shape_num_unknown_nodes") == 0);
232+
// This is tricky for AMP Use case, for example, with only weights input types
233+
// cannot be inferred in AMP. Thus for AMP converted model type_dict will be
234+
// required
235+
bool infer_type_complete = (g.GetAttr<size_t>("dtype_num_unknown_nodes") == 0);
200236
CHECK(infer_complete)
201237
<< "The shape information of is not enough to get the shapes";
238+
CHECK(infer_type_complete)
239+
<< "The type information is not enough, please provide input arg_types "
240+
"with provided_arg_dtype_names and provided_arg_dtypes."
241+
"If using amalgamation python frontend you can use type_dict in Predictor API"
242+
"to provide this information";
202243
CopyAttr(g.indexed_graph(),
203244
g.GetAttr<mxnet::ShapeVector>("shape"),
204245
&arg_shapes, &out_shapes, &aux_shapes);
246+
CopyAttr(g.indexed_graph(),
247+
g.GetAttr<nnvm::DTypeVector>("dtype"),
248+
&result_arg_types, &result_out_types, &result_aux_types);
205249
} catch (const mxnet::op::InferShapeError &err) {
206250
throw dmlc::Error(err.msg);
207251
}
@@ -210,19 +254,31 @@ int _CreatePartialOut(const char* symbol_json_str,
210254

211255
std::vector<NDArray> arg_arrays, aux_arrays;
212256
for (size_t i = 0; i < arg_shapes.size(); ++i) {
213-
NDArray nd = NDArray(arg_shapes[i], ctx);
257+
NDArray nd;
258+
if (result_arg_types[i] != -1) {
259+
nd = NDArray(arg_shapes[i], ctx, false, result_arg_types[i]);
260+
} else {
261+
nd = NDArray(arg_shapes[i], ctx);
262+
}
214263
if (arg_params.count(arg_names[i]) != 0) {
215264
CopyFromTo(arg_params[arg_names[i]], &nd);
216265
}
217266
arg_arrays.push_back(nd);
218267
}
268+
219269
for (size_t i = 0; i < aux_shapes.size(); ++i) {
220-
NDArray nd = NDArray(aux_shapes[i], ctx);
270+
NDArray nd;
271+
if (result_aux_types[i] != -1) {
272+
nd = NDArray(aux_shapes[i], ctx, false, result_aux_types[i]);
273+
} else {
274+
nd = NDArray(aux_shapes[i], ctx);
275+
}
221276
if (aux_params.count(aux_names[i]) != 0) {
222277
CopyFromTo(aux_params[aux_names[i]], &nd);
223278
}
224279
aux_arrays.push_back(nd);
225280
}
281+
226282
// bind
227283
for (int i = 0; i < num_threads; i++) {
228284
std::unique_ptr<MXAPIPredictor> ret(new MXAPIPredictor());
@@ -232,6 +288,7 @@ int _CreatePartialOut(const char* symbol_json_str,
232288
ret->arg_arrays = arg_arrays;
233289
ret->aux_arrays = aux_arrays;
234290
ret->out_shapes = out_shapes;
291+
ret->out_dtypes = result_out_types;
235292

236293
if (!lazy) {
237294
std::map<std::string, Context> ctx_map;
@@ -272,6 +329,9 @@ int MXPredCreatePartialOut(const char* symbol_json_str,
272329
output_keys,
273330
1,
274331
false,
332+
0,
333+
nullptr,
334+
nullptr,
275335
out);
276336
}
277337

@@ -295,9 +355,44 @@ int MXPredCreate(const char* symbol_json_str,
295355
input_shape_indptr,
296356
input_shape_data,
297357
0,
298-
NULL,
358+
nullptr,
359+
1,
360+
false,
361+
0,
362+
nullptr,
363+
nullptr,
364+
out);
365+
}
366+
367+
int MXPredCreateEx(const char* symbol_json_str,
368+
const void* param_bytes,
369+
int param_size,
370+
int dev_type, int dev_id,
371+
mx_uint num_input_nodes,
372+
const char** input_keys,
373+
const mx_uint* input_shape_indptr,
374+
const mx_uint* input_shape_data,
375+
const mx_uint num_provided_arg_dtypes,
376+
const char** provided_arg_dtype_names,
377+
const int* provided_arg_dtypes,
378+
PredictorHandle* out) {
379+
return _CreatePartialOut(
380+
symbol_json_str,
381+
param_bytes,
382+
param_size,
383+
dev_type,
384+
dev_id,
385+
num_input_nodes,
386+
input_keys,
387+
input_shape_indptr,
388+
input_shape_data,
389+
0,
390+
nullptr,
299391
1,
300392
false,
393+
num_provided_arg_dtypes,
394+
provided_arg_dtype_names,
395+
provided_arg_dtypes,
301396
out);
302397
}
303398

@@ -330,9 +425,12 @@ int MXPredCreateMultiThread(const char* symbol_json_str,
330425
input_shape_indptr,
331426
input_shape_data,
332427
0,
333-
NULL,
428+
nullptr,
334429
num_threads,
335430
true,
431+
0,
432+
nullptr,
433+
nullptr,
336434
out);
337435
}
338436

@@ -421,6 +519,7 @@ int MXPredReshape(mx_uint num_input_nodes,
421519
p->exec.get()));
422520
ret->out_shapes = out_shapes;
423521
ret->out_arrays = ret->exec->outputs();
522+
ret->out_dtypes = p->out_dtypes;
424523
}
425524
*out = ret.release();
426525
API_END();
@@ -444,6 +543,21 @@ int MXPredGetOutputShape(PredictorHandle handle,
444543
API_END();
445544
}
446545

546+
/*!
 * \brief Get the dtype of one output node of the predictor.
 * \param handle The predictor handle.
 * \param out_index The index of the output node.
 * \param out_dtype Receives the dtype code of that single output node.
 * \return Result of the API_BEGIN/API_END error-handling wrapper.
 */
int MXPredGetOutputType(PredictorHandle handle,
                        mx_uint out_index,
                        int* out_dtype) {
  MXAPIPredictor* p = static_cast<MXAPIPredictor*>(handle);
  API_BEGIN();
  CHECK_LT(out_index, p->out_arrays.size())
    << "Index exceed number of outputs, provided out_index should be less than "
    << p->out_arrays.size();

  const int s = p->out_dtypes[out_index];
  CHECK_GE(s, 0);
  // out_dtype points at a single int supplied by the caller, not at an array
  // indexed by output: writing out_dtype[out_index] would go out of bounds
  // for any out_index > 0.
  *out_dtype = s;
  API_END();
}
560+
447561
int MXPredSetInput(PredictorHandle handle,
448562
const char* key,
449563
const mx_float* data,
@@ -543,6 +657,22 @@ int MXNDListGet(NDListHandle handle,
543657
API_END();
544658
}
545659

660+
/*!
 * Register a monitor callback on the predictor's executor.
 * callback_handle is passed back unchanged as the callback's third argument.
 */
int MXPredSetMonitorCallback(PredictorHandle handle,
                             PredMonitorCallback callback,
                             void* callback_handle,
                             bool monitor_all) {
  MXAPIPredictor* p = static_cast<MXAPIPredictor*>(handle);
  API_BEGIN();
  // Adapt the three-argument C callback to the two-argument functor taken by
  // the executor, binding callback_handle as the trailing argument.
  std::function<void(const char*, void*)> monitor =
      [callback, callback_handle](const char* name, void* arr_handle) {
        callback(name, arr_handle, callback_handle);
      };
  p->exec->SetMonitorCallback(monitor, monitor_all);
  API_END();
}
675+
546676
int MXNDListFree(NDListHandle handle) {
547677
API_BEGIN();
548678
delete static_cast<MXAPINDList*>(handle);

‎tests/python/gpu/test_predictor.py

+128
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
from __future__ import print_function
19+
import sys, os
20+
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
21+
sys.path.append(os.path.join(curr_path, "../../../amalgamation/python/"))
22+
from mxnet_predict import Predictor, load_ndarray_file
23+
24+
import ctypes
25+
import numpy as np
26+
import mxnet as mx
27+
import mxnet.ndarray as nd
28+
from mxnet.ndarray import NDArray
29+
from mxnet import gluon
30+
from mxnet.test_utils import assert_almost_equal, download_model
31+
from mxnet.contrib.amp import amp
32+
from mxnet.base import NDArrayHandle, py_str
33+
sys.path.insert(0, os.path.join(curr_path, '../unittest'))
34+
from common import setup_module, with_seed, teardown
35+
36+
@with_seed()
def test_predictor_with_dtype():
    """Compare a float16 Gluon block against the C predict API on GPU.

    A two-layer dense block is cast to float16, run once, exported, and then
    loaded through ``Predictor`` with ``type_dict`` declaring the input dtype;
    both outputs must agree.
    """
    prefix = 'test_predictor_simple_dense'
    symbol_file = "%s-symbol.json" % prefix
    param_file = "%s-0000.params" % prefix

    input1 = np.random.uniform(size=(1, 3))
    input1 = input1.astype(np.float16)

    block = mx.gluon.nn.HybridSequential()
    block.add(mx.gluon.nn.Dense(7))
    block.add(mx.gluon.nn.Dense(3))
    block.cast(np.float16)
    block.hybridize()
    block.initialize(ctx=mx.gpu(0))
    tmp = mx.nd.array(input1, dtype=np.float16, ctx=mx.gpu(0))
    out1 = block.forward(tmp)
    block.export(prefix)

    # Read the exported files through context managers so the handles are
    # closed deterministically instead of being left to the garbage collector.
    with open(symbol_file, "r") as symbol_fp:
        symbol_json = symbol_fp.read()
    with open(param_file, "rb") as param_fp:
        param_bytes = param_fp.read()

    predictor = Predictor(symbol_json,
                          param_bytes,
                          {"data": input1.shape},
                          dev_type="gpu",
                          dev_id=0,
                          type_dict={"data": input1.dtype})
    predictor.forward(data=input1)
    predictor_out1 = predictor.get_output(0)

    assert_almost_equal(out1.asnumpy(), predictor_out1, rtol=1e-5, atol=1e-6)
65+
66+
def compare_module_cpredict(result_sym, result_arg_params, result_aux_params, monitor_callback=False):
    """Check that the C predict API reproduces a Module's forward output on GPU.

    The symbol and params are bound to a ``Module``, run once, saved as a
    checkpoint, and reloaded through ``Predictor``. When ``monitor_callback``
    is true, a monitor callback is installed and must have been invoked at
    least once.
    """
    # Dummy inputs
    input1 = np.ones((1, 3, 224, 224))
    input1 = input1.astype(np.float32)
    nd_dict = {}

    def pred_mon_callback(name, arr):
        nd_dict[name] = arr

    mod = mx.mod.Module(result_sym, data_names=["data"], label_names=["softmax_label"], context=mx.gpu())
    mod.bind(data_shapes=[['data', (1, 3, 224, 224)]], label_shapes=[['softmax_label', (1,)]], for_training=False)
    mod.set_params(result_arg_params, result_aux_params)
    mod.forward(mx.io.DataBatch(data=[mx.nd.array(input1, ctx=mx.gpu())],
                                label=[mx.nd.ones((1,), ctx=mx.gpu())]))
    prefix = "test_predictor_amp"
    mod.save_checkpoint(prefix, 0, remove_amp_cast=False)
    sym_file = "{}-symbol.json".format(prefix)
    params_file = "{}-0000.params".format(prefix)

    # Read the checkpoint files through context managers so the handles are
    # closed deterministically instead of being left to the garbage collector.
    with open(sym_file, "r") as sym_fp:
        symbol_json = sym_fp.read()
    with open(params_file, "rb") as params_fp:
        param_bytes = params_fp.read()

    predictor = Predictor(symbol_json,
                          param_bytes,
                          {'data': (1, 3, 224, 224),
                           'softmax_label': (1,)},
                          dev_type="gpu",
                          dev_id=0)
    if monitor_callback:
        predictor.set_monitor_callback(pred_mon_callback, monitor_all=True)
    predictor.forward(data=input1, softmax_label=mx.nd.ones((1,)).asnumpy())
    predictor_out1 = predictor.get_output(0)
    if monitor_callback:
        assert len(nd_dict) > 0, "Callback not called"
    assert_almost_equal(mod.get_outputs()[0].asnumpy(), predictor_out1, atol=1e-1, rtol=1e-1)
95+
96+
97+
@with_seed()
def test_predictor_amp():
    """Run the Module-vs-Predictor comparison on AMP-converted ResNet-18.

    The model is converted twice with ``amp.convert_model``: once with the
    default parameter handling, and once with ``cast_optional_params=True``
    (also exercising the monitor callback).
    """
    model_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'model')
    if not os.path.isdir(model_path):
        os.mkdir(model_path)
    prefix, epoch = download_model("imagenet1k-resnet-18", dst_dir=model_path)

    sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)

    # Convert model to mixed precision model, params in FP32
    fp32_sym, fp32_args, fp32_auxs = amp.convert_model(
        sym, arg_params, aux_params,
        target_dtype="float16",
        target_dtype_ops=["Convolution"])
    compare_module_cpredict(fp32_sym, fp32_args, fp32_auxs)

    # Convert model to mixed precision model, params in FP16
    fp16_sym, fp16_args, fp16_auxs = amp.convert_model(
        sym, arg_params, aux_params,
        target_dtype="float16",
        target_dtype_ops=["Convolution"],
        cast_optional_params=True)
    compare_module_cpredict(fp16_sym, fp16_args, fp16_auxs, monitor_callback=True)
124+
125+
126+
if __name__ == '__main__':
127+
import nose
128+
nose.runmodule()

0 commit comments

Comments
 (0)
Please sign in to comment.