Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 79d8d86

Browse files
authored
Fix the bug of MXEnginePushAsyncND and MXEnginePushSyncND (#15751)
* fix push sync nd api * align code * update test for syncnd * fix bug in tests/cpp/engine/threaded_engine_test * add more testcases for MXEnginePushSyncND and MXEnginePushAsyncND * fix test * fix * fix * lint * ci * retrigger CI
1 parent be49b3b commit 79d8d86

File tree

3 files changed

+105
-74
lines changed

3 files changed

+105
-74
lines changed

include/mxnet/c_api.h

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2940,12 +2940,12 @@ MXNET_DLL int MXShallowCopySymbol(SymbolHandle src, SymbolHandle * out);
29402940
* \param wait Whether this is a WaitForVar operation.
29412941
*/
29422942
MXNET_DLL int MXEnginePushAsyncND(EngineAsyncFunc async_func, void* func_param,
2943-
EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
2944-
NDArrayHandle const_nds_handle, int num_const_nds,
2945-
NDArrayHandle mutable_nds_handle, int num_mutable_nds,
2946-
EngineFnPropertyHandle prop_handle DEFAULT(NULL),
2947-
int priority DEFAULT(0), const char* opr_name DEFAULT(NULL),
2948-
bool wait DEFAULT(false));
2943+
EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
2944+
NDArrayHandle* const_nds_handle, int num_const_nds,
2945+
NDArrayHandle* mutable_nds_handle, int num_mutable_nds,
2946+
EngineFnPropertyHandle prop_handle DEFAULT(NULL),
2947+
int priority DEFAULT(0), const char* opr_name DEFAULT(NULL),
2948+
bool wait DEFAULT(false));
29492949

29502950
/*!
29512951
* \brief Push a synchronous operation to the engine.
@@ -2963,11 +2963,11 @@ MXNET_DLL int MXEnginePushAsyncND(EngineAsyncFunc async_func, void* func_param,
29632963
* \param opr_name The operation name.
29642964
*/
29652965
MXNET_DLL int MXEnginePushSyncND(EngineSyncFunc sync_func, void* func_param,
2966-
EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
2967-
NDArrayHandle const_nds_handle, int num_const_nds,
2968-
NDArrayHandle mutable_nds_handle, int num_mutable_nds,
2969-
EngineFnPropertyHandle prop_handle DEFAULT(NULL),
2970-
int priority DEFAULT(0), const char* opr_name DEFAULT(NULL));
2966+
EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
2967+
NDArrayHandle* const_nds_handle, int num_const_nds,
2968+
NDArrayHandle* mutable_nds_handle, int num_mutable_nds,
2969+
EngineFnPropertyHandle prop_handle DEFAULT(NULL),
2970+
int priority DEFAULT(0), const char* opr_name DEFAULT(NULL));
29712971

29722972
#ifdef __cplusplus
29732973
}

src/c_api/c_api.cc

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1559,18 +1559,18 @@ int MXEnginePushSync(EngineSyncFunc sync_func, void* func_param,
15591559
}
15601560

15611561
int MXEnginePushAsyncND(EngineAsyncFunc async_func, void* func_param,
1562-
EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
1563-
NDArrayHandle const_nds_handle, int num_const_nds,
1564-
NDArrayHandle mutable_nds_handle, int num_mutable_nds,
1565-
EngineFnPropertyHandle prop_handle, int priority,
1566-
const char* opr_name, bool wait) {
1567-
API_BEGIN();
1568-
NDArray* const_nds = static_cast<NDArray*>(const_nds_handle);
1569-
NDArray* mutable_nds = static_cast<NDArray*>(mutable_nds_handle);
1562+
EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
1563+
NDArrayHandle* const_nds_handle, int num_const_nds,
1564+
NDArrayHandle* mutable_nds_handle, int num_mutable_nds,
1565+
EngineFnPropertyHandle prop_handle, int priority,
1566+
const char* opr_name, bool wait) {
1567+
API_BEGIN();
1568+
NDArray** const_nds = reinterpret_cast<NDArray**>(const_nds_handle);
1569+
NDArray** mutable_nds = reinterpret_cast<NDArray**>(mutable_nds_handle);
15701570
std::vector<VarHandle> const_var_vec(num_const_nds);
1571-
for (int i = 0; i < num_const_nds; ++i) const_var_vec[i] = (const_nds+i)->var();
1571+
for (int i = 0; i < num_const_nds; ++i) const_var_vec[i] = const_nds[i]->var();
15721572
std::vector<VarHandle> mutable_var_vec(num_mutable_nds);
1573-
for (int i = 0; i < num_mutable_nds; ++i) mutable_var_vec[i] = (mutable_nds+i)->var();
1573+
for (int i = 0; i < num_mutable_nds; ++i) mutable_var_vec[i] = mutable_nds[i]->var();
15741574
return MXEnginePushAsync(async_func, func_param, deleter, ctx_handle,
15751575
const_var_vec.data(), num_const_nds,
15761576
mutable_var_vec.data(), num_mutable_nds,
@@ -1579,18 +1579,18 @@ int MXEnginePushAsyncND(EngineAsyncFunc async_func, void* func_param,
15791579
}
15801580

15811581
int MXEnginePushSyncND(EngineSyncFunc sync_func, void* func_param,
1582-
EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
1583-
NDArrayHandle const_nds_handle, int num_const_nds,
1584-
NDArrayHandle mutable_nds_handle, int num_mutable_nds,
1585-
EngineFnPropertyHandle prop_handle, int priority,
1586-
const char* opr_name) {
1587-
API_BEGIN();
1588-
NDArray* const_nds = static_cast<NDArray*>(const_nds_handle);
1589-
NDArray* mutable_nds = static_cast<NDArray*>(mutable_nds_handle);
1582+
EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
1583+
NDArrayHandle* const_nds_handle, int num_const_nds,
1584+
NDArrayHandle* mutable_nds_handle, int num_mutable_nds,
1585+
EngineFnPropertyHandle prop_handle, int priority,
1586+
const char* opr_name) {
1587+
API_BEGIN();
1588+
NDArray** const_nds = reinterpret_cast<NDArray**>(const_nds_handle);
1589+
NDArray** mutable_nds = reinterpret_cast<NDArray**>(mutable_nds_handle);
15901590
std::vector<VarHandle> const_var_vec(num_const_nds);
1591-
for (int i = 0; i < num_const_nds; ++i) const_var_vec[i] = (const_nds+i)->var();
1591+
for (int i = 0; i < num_const_nds; ++i) const_var_vec[i] = const_nds[i]->var();
15921592
std::vector<VarHandle> mutable_var_vec(num_mutable_nds);
1593-
for (int i = 0; i < num_mutable_nds; ++i) mutable_var_vec[i] = (mutable_nds+i)->var();
1593+
for (int i = 0; i < num_mutable_nds; ++i) mutable_var_vec[i] = mutable_nds[i]->var();
15941594
return MXEnginePushSync(sync_func, func_param, deleter, ctx_handle,
15951595
const_var_vec.data(), num_const_nds,
15961596
mutable_var_vec.data(), num_mutable_nds,

tests/cpp/engine/threaded_engine_test.cc

Lines changed: 74 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -257,49 +257,80 @@ TEST(Engine, PushFunc) {
257257

258258
TEST(Engine, PushFuncND) {
259259
auto ctx = mxnet::Context{};
260-
mxnet::NDArray nd(ctx);
261-
262-
// Test #1
263-
LOG(INFO) << "===== Test #1: PushAsyncND param and deleter =====";
264-
int* a = new int(100);
265-
int res = MXEnginePushAsyncND(FooAsyncFunc, a, FooFuncDeleter, &ctx, &nd, 1, nullptr, 0);
266-
EXPECT_EQ(res, 0);
267-
268-
// Test #2
269-
LOG(INFO) << "===== Test #2: PushAsyncND NULL param and NULL deleter =====";
270-
res = MXEnginePushAsyncND(FooAsyncFunc, nullptr, nullptr, &ctx, nullptr, 0, &nd, 0);
271-
EXPECT_EQ(res, 0);
272-
273-
// Test #3
274-
LOG(INFO) << "===== Test #3: PushAsyncND invalid number of const nds =====";
275-
res = MXEnginePushAsyncND(FooAsyncFunc, nullptr, nullptr, &ctx, &nd, -1, nullptr, 0);
276-
EXPECT_EQ(res, -1);
277-
278-
// Test #4
279-
LOG(INFO) << "===== Test #4: PushAsyncND invalid number of mutable nds =====";
280-
res = MXEnginePushAsyncND(FooAsyncFunc, nullptr, nullptr, &ctx, nullptr, 0, &nd, -1);
281-
EXPECT_EQ(res, -1);
282-
283-
// Test #5
284-
LOG(INFO) << "===== Test #5: PushSyncND param and deleter =====";
285-
int* b = new int(101);
286-
res = MXEnginePushSyncND(FooSyncFunc, b, FooFuncDeleter, &ctx, &nd, 1, nullptr, 0);
287-
EXPECT_EQ(res, 0);
288-
289-
// Test #6
290-
LOG(INFO) << "===== Test #6: PushSyncND NULL param and NULL deleter =====";
291-
res = MXEnginePushSyncND(FooSyncFunc, nullptr, nullptr, &ctx, nullptr, 0, &nd, 1);
292-
EXPECT_EQ(res, 0);
293-
294-
// Test #7
295-
LOG(INFO) << "===== Test #7: PushSyncND invalid number of const nds =====";
296-
res = MXEnginePushSyncND(FooSyncFunc, nullptr, nullptr, &ctx, &nd, -1, nullptr, 0);
297-
EXPECT_EQ(res, -1);
298-
299-
// Test #8
300-
LOG(INFO) << "===== Test #8: PushSyncND invalid number of mutable nds =====";
301-
res = MXEnginePushSyncND(FooSyncFunc, nullptr, nullptr, &ctx, nullptr, 0, &nd, -1);
302-
EXPECT_EQ(res, -1);
260+
std::vector<mxnet::NDArray*> nds;
261+
const int num_nds = 5;
262+
for (int i = 0; i < num_nds; ++i) {
263+
mxnet::NDArray *pnd = new mxnet::NDArray(ctx);
264+
nds.push_back(pnd);
265+
}
266+
for (int num_const_nds = 0; num_const_nds <= num_nds; ++num_const_nds) {
267+
int num_mutable_nds = num_nds - num_const_nds;
268+
void** const_nds_handle = num_const_nds > 0 ?
269+
reinterpret_cast<void**>(nds.data()) : nullptr;
270+
void** mutable_nds_handle = num_mutable_nds > 0 ?
271+
reinterpret_cast<void**>(nds.data() + num_const_nds) : nullptr;
272+
273+
// Test #1
274+
LOG(INFO) << "===== Test #1: PushAsyncND param and deleter =====";
275+
int* a = new int(100);
276+
int res = MXEnginePushAsyncND(FooAsyncFunc, a, FooFuncDeleter, &ctx,
277+
const_nds_handle, num_const_nds,
278+
mutable_nds_handle, num_mutable_nds);
279+
EXPECT_EQ(res, 0);
280+
281+
// Test #2
282+
LOG(INFO) << "===== Test #2: PushAsyncND NULL param and NULL deleter =====";
283+
res = MXEnginePushAsyncND(FooAsyncFunc, nullptr, nullptr, &ctx,
284+
const_nds_handle, num_const_nds,
285+
mutable_nds_handle, num_mutable_nds);
286+
EXPECT_EQ(res, 0);
287+
288+
// Test #3
289+
LOG(INFO) << "===== Test #3: PushAsyncND invalid number of const nds =====";
290+
res = MXEnginePushAsyncND(FooAsyncFunc, nullptr, nullptr, &ctx,
291+
const_nds_handle, -1,
292+
mutable_nds_handle, num_mutable_nds);
293+
EXPECT_EQ(res, -1);
294+
295+
// Test #4
296+
LOG(INFO) << "===== Test #4: PushAsyncND invalid number of mutable nds =====";
297+
res = MXEnginePushAsyncND(FooAsyncFunc, nullptr, nullptr, &ctx,
298+
const_nds_handle, num_const_nds,
299+
mutable_nds_handle, -1);
300+
EXPECT_EQ(res, -1);
301+
302+
// Test #5
303+
LOG(INFO) << "===== Test #5: PushSyncND param and deleter =====";
304+
int* b = new int(101);
305+
res = MXEnginePushSyncND(FooSyncFunc, b, FooFuncDeleter, &ctx,
306+
const_nds_handle, num_const_nds,
307+
mutable_nds_handle, num_mutable_nds);
308+
EXPECT_EQ(res, 0);
309+
310+
// Test #6
311+
LOG(INFO) << "===== Test #6: PushSyncND NULL param and NULL deleter =====";
312+
res = MXEnginePushSyncND(FooSyncFunc, nullptr, nullptr, &ctx,
313+
const_nds_handle, num_const_nds,
314+
mutable_nds_handle, num_mutable_nds);
315+
EXPECT_EQ(res, 0);
316+
317+
// Test #7
318+
LOG(INFO) << "===== Test #7: PushSyncND invalid number of const nds =====";
319+
res = MXEnginePushSyncND(FooSyncFunc, nullptr, nullptr, &ctx,
320+
const_nds_handle, -1,
321+
mutable_nds_handle, num_mutable_nds);
322+
EXPECT_EQ(res, -1);
323+
324+
// Test #8
325+
LOG(INFO) << "===== Test #8: PushSyncND invalid number of mutable nds =====";
326+
res = MXEnginePushSyncND(FooSyncFunc, nullptr, nullptr, &ctx,
327+
const_nds_handle, num_const_nds,
328+
mutable_nds_handle, -1);
329+
EXPECT_EQ(res, -1);
330+
}
331+
for (mxnet::NDArray* pnd : nds) {
332+
delete pnd;
333+
}
303334
}
304335

305336
TEST(Engine, basics) {

0 commit comments

Comments
 (0)