@@ -257,49 +257,80 @@ TEST(Engine, PushFunc) {
257
257
258
258
TEST (Engine, PushFuncND) {
259
259
auto ctx = mxnet::Context{};
260
- mxnet::NDArray nd (ctx);
261
-
262
- // Test #1
263
- LOG (INFO) << " ===== Test #1: PushAsyncND param and deleter =====" ;
264
- int * a = new int (100 );
265
- int res = MXEnginePushAsyncND (FooAsyncFunc, a, FooFuncDeleter, &ctx, &nd, 1 , nullptr , 0 );
266
- EXPECT_EQ (res, 0 );
267
-
268
- // Test #2
269
- LOG (INFO) << " ===== Test #2: PushAsyncND NULL param and NULL deleter =====" ;
270
- res = MXEnginePushAsyncND (FooAsyncFunc, nullptr , nullptr , &ctx, nullptr , 0 , &nd, 0 );
271
- EXPECT_EQ (res, 0 );
272
-
273
- // Test #3
274
- LOG (INFO) << " ===== Test #3: PushAsyncND invalid number of const nds =====" ;
275
- res = MXEnginePushAsyncND (FooAsyncFunc, nullptr , nullptr , &ctx, &nd, -1 , nullptr , 0 );
276
- EXPECT_EQ (res, -1 );
277
-
278
- // Test #4
279
- LOG (INFO) << " ===== Test #4: PushAsyncND invalid number of mutable nds =====" ;
280
- res = MXEnginePushAsyncND (FooAsyncFunc, nullptr , nullptr , &ctx, nullptr , 0 , &nd, -1 );
281
- EXPECT_EQ (res, -1 );
282
-
283
- // Test #5
284
- LOG (INFO) << " ===== Test #5: PushSyncND param and deleter =====" ;
285
- int * b = new int (101 );
286
- res = MXEnginePushSyncND (FooSyncFunc, b, FooFuncDeleter, &ctx, &nd, 1 , nullptr , 0 );
287
- EXPECT_EQ (res, 0 );
288
-
289
- // Test #6
290
- LOG (INFO) << " ===== Test #6: PushSyncND NULL param and NULL deleter =====" ;
291
- res = MXEnginePushSyncND (FooSyncFunc, nullptr , nullptr , &ctx, nullptr , 0 , &nd, 1 );
292
- EXPECT_EQ (res, 0 );
293
-
294
- // Test #7
295
- LOG (INFO) << " ===== Test #7: PushSyncND invalid number of const nds =====" ;
296
- res = MXEnginePushSyncND (FooSyncFunc, nullptr , nullptr , &ctx, &nd, -1 , nullptr , 0 );
297
- EXPECT_EQ (res, -1 );
298
-
299
- // Test #8
300
- LOG (INFO) << " ===== Test #8: PushSyncND invalid number of mutable nds =====" ;
301
- res = MXEnginePushSyncND (FooSyncFunc, nullptr , nullptr , &ctx, nullptr , 0 , &nd, -1 );
302
- EXPECT_EQ (res, -1 );
260
+ std::vector<mxnet::NDArray*> nds;
261
+ const int num_nds = 5 ;
262
+ for (int i = 0 ; i < num_nds; ++i) {
263
+ mxnet::NDArray *pnd = new mxnet::NDArray (ctx);
264
+ nds.push_back (pnd);
265
+ }
266
+ for (int num_const_nds = 0 ; num_const_nds <= num_nds; ++num_const_nds) {
267
+ int num_mutable_nds = num_nds - num_const_nds;
268
+ void ** const_nds_handle = num_const_nds > 0 ?
269
+ reinterpret_cast <void **>(nds.data ()) : nullptr ;
270
+ void ** mutable_nds_handle = num_mutable_nds > 0 ?
271
+ reinterpret_cast <void **>(nds.data () + num_const_nds) : nullptr ;
272
+
273
+ // Test #1
274
+ LOG (INFO) << " ===== Test #1: PushAsyncND param and deleter =====" ;
275
+ int * a = new int (100 );
276
+ int res = MXEnginePushAsyncND (FooAsyncFunc, a, FooFuncDeleter, &ctx,
277
+ const_nds_handle, num_const_nds,
278
+ mutable_nds_handle, num_mutable_nds);
279
+ EXPECT_EQ (res, 0 );
280
+
281
+ // Test #2
282
+ LOG (INFO) << " ===== Test #2: PushAsyncND NULL param and NULL deleter =====" ;
283
+ res = MXEnginePushAsyncND (FooAsyncFunc, nullptr , nullptr , &ctx,
284
+ const_nds_handle, num_const_nds,
285
+ mutable_nds_handle, num_mutable_nds);
286
+ EXPECT_EQ (res, 0 );
287
+
288
+ // Test #3
289
+ LOG (INFO) << " ===== Test #3: PushAsyncND invalid number of const nds =====" ;
290
+ res = MXEnginePushAsyncND (FooAsyncFunc, nullptr , nullptr , &ctx,
291
+ const_nds_handle, -1 ,
292
+ mutable_nds_handle, num_mutable_nds);
293
+ EXPECT_EQ (res, -1 );
294
+
295
+ // Test #4
296
+ LOG (INFO) << " ===== Test #4: PushAsyncND invalid number of mutable nds =====" ;
297
+ res = MXEnginePushAsyncND (FooAsyncFunc, nullptr , nullptr , &ctx,
298
+ const_nds_handle, num_const_nds,
299
+ mutable_nds_handle, -1 );
300
+ EXPECT_EQ (res, -1 );
301
+
302
+ // Test #5
303
+ LOG (INFO) << " ===== Test #5: PushSyncND param and deleter =====" ;
304
+ int * b = new int (101 );
305
+ res = MXEnginePushSyncND (FooSyncFunc, b, FooFuncDeleter, &ctx,
306
+ const_nds_handle, num_const_nds,
307
+ mutable_nds_handle, num_mutable_nds);
308
+ EXPECT_EQ (res, 0 );
309
+
310
+ // Test #6
311
+ LOG (INFO) << " ===== Test #6: PushSyncND NULL param and NULL deleter =====" ;
312
+ res = MXEnginePushSyncND (FooSyncFunc, nullptr , nullptr , &ctx,
313
+ const_nds_handle, num_const_nds,
314
+ mutable_nds_handle, num_mutable_nds);
315
+ EXPECT_EQ (res, 0 );
316
+
317
+ // Test #7
318
+ LOG (INFO) << " ===== Test #7: PushSyncND invalid number of const nds =====" ;
319
+ res = MXEnginePushSyncND (FooSyncFunc, nullptr , nullptr , &ctx,
320
+ const_nds_handle, -1 ,
321
+ mutable_nds_handle, num_mutable_nds);
322
+ EXPECT_EQ (res, -1 );
323
+
324
+ // Test #8
325
+ LOG (INFO) << " ===== Test #8: PushSyncND invalid number of mutable nds =====" ;
326
+ res = MXEnginePushSyncND (FooSyncFunc, nullptr , nullptr , &ctx,
327
+ const_nds_handle, num_const_nds,
328
+ mutable_nds_handle, -1 );
329
+ EXPECT_EQ (res, -1 );
330
+ }
331
+ for (mxnet::NDArray* pnd : nds) {
332
+ delete pnd;
333
+ }
303
334
}
304
335
305
336
TEST (Engine, basics) {
0 commit comments