@@ -332,166 +332,50 @@ class LeakyReLUOp : public Operator {
332
332
}; // class LeakyReLUOp
333
333
334
334
template <typename xpu>
335
- Operator* CreateOp (LeakyReLUParam type, int dtype);
335
+ void LeakyReLUCompute (const nnvm::NodeAttrs& attrs,
336
+ const OpContext& ctx, const std::vector<TBlob>& inputs,
337
+ const std::vector<OpReqType>& req,
338
+ const std::vector<TBlob>& outputs) {
339
+ const LeakyReLUParam ¶m = nnvm::get<LeakyReLUParam>(attrs.parsed );
340
+ const std::vector<TBlob> no_use_but_adapt_origin_api;
341
+ size_t expected = param.act_type == leakyrelu::kPReLU ? 2 : 1 ;
342
+ CHECK_EQ (inputs.size (), expected);
336
343
337
- #if DMLC_USE_CXX11
338
- class LeakyReLUProp : public OperatorProperty {
339
- public:
340
- void Init (const std::vector<std::pair<std::string, std::string> >& kwargs) override {
341
- param_.Init (kwargs);
342
- }
343
-
344
- std::map<std::string, std::string> GetParams () const override {
345
- return param_.__DICT__ ();
346
- }
347
-
348
- bool InferShape (mxnet::ShapeVector *in_shape,
349
- mxnet::ShapeVector *out_shape,
350
- mxnet::ShapeVector *aux_shape) const override {
351
- using namespace mshadow ;
352
- if (param_.act_type == leakyrelu::kPReLU ) {
353
- CHECK_EQ (in_shape->size (), 2U ) << " Input:[data, gamma]" ;
354
- } else {
355
- CHECK_EQ (in_shape->size (), 1U ) << " Input:[data]" ;
356
- }
357
- const mxnet::TShape &dshape = in_shape->at (leakyrelu::kData );
358
- if (!mxnet::ndim_is_known (dshape)) return false ;
359
- if (param_.act_type == leakyrelu::kPReLU ) {
360
- const mxnet::TShape &gshape = in_shape->at (leakyrelu::kGamma );
361
- if (!mxnet::ndim_is_known (gshape)) {
362
- in_shape->at (leakyrelu::kGamma ) = mxnet::TShape (Shape1 (dshape[1 ]));
363
- }
364
- if (dshape == gshape) {
365
- SHAPE_ASSIGN_CHECK (*out_shape, 0 , dshape);
366
- }
367
- }
368
- out_shape->clear ();
369
- out_shape->push_back (dshape);
370
- if (param_.act_type == leakyrelu::kRReLU ) {
371
- out_shape->push_back (dshape);
372
- }
373
- return true ;
374
- }
375
-
376
- bool InferType (std::vector<int > *in_type,
377
- std::vector<int > *out_type,
378
- std::vector<int > *aux_type) const override {
379
- int dtype = -1 ;
380
- for (const int & type : *in_type) {
381
- type_assign (&dtype, type);
382
- }
383
- for (const int & type : *out_type) {
384
- type_assign (&dtype, type);
385
- }
386
-
387
- for (size_t i = 0 ; i < in_type->size (); ++i) {
388
- TYPE_ASSIGN_CHECK (*in_type, i, dtype);
389
- }
390
- for (size_t i = 0 ; i < out_type->size (); ++i) {
391
- TYPE_ASSIGN_CHECK (*out_type, i, dtype);
392
- }
393
- return dtype != -1 ;
394
- }
395
-
396
- OperatorProperty* Copy () const override {
397
- auto ptr = new LeakyReLUProp ();
398
- ptr->param_ = param_;
399
- return ptr;
400
- }
401
-
402
- std::string TypeString () const override {
403
- return " LeakyReLU" ;
404
- }
405
-
406
- // decalre dependency and inplace optimization options
407
- std::vector<int > DeclareBackwardDependency (
408
- const std::vector<int > &out_grad,
409
- const std::vector<int > &in_data,
410
- const std::vector<int > &out_data) const override {
411
- if (param_.act_type == leakyrelu::kPReLU ) {
412
- return {out_grad[leakyrelu::kOut ],
413
- out_data[leakyrelu::kOut ],
414
- in_data[leakyrelu::kData ],
415
- in_data[leakyrelu::kGamma ]};
416
- } else if (param_.act_type == leakyrelu::kRReLU ) {
417
- return {out_grad[leakyrelu::kOut ], out_data[leakyrelu::kMask ], out_data[leakyrelu::kOut ]};
418
- } else {
419
- return {out_grad[leakyrelu::kOut ], out_data[leakyrelu::kData ]};
420
- }
421
- }
344
+ MSHADOW_REAL_TYPE_SWITCH (inputs[leakyrelu::kData ].type_flag_ , DType, {
345
+ LeakyReLUOp<xpu, DType> op (param);
346
+ op.Forward (ctx, inputs, req, outputs, no_use_but_adapt_origin_api);
347
+ });
348
+ }
422
349
423
- std::vector<std::pair<int , void *> > BackwardInplaceOption (
424
- const std::vector<int > &out_grad,
425
- const std::vector<int > &in_data,
426
- const std::vector<int > &out_data,
427
- const std::vector<void *> &in_grad) const override {
428
- return {{out_grad[leakyrelu::kOut ], in_grad[leakyrelu::kData ]}};
429
- }
430
-
431
- std::vector<std::pair<int , void *> > ForwardInplaceOption (
432
- const std::vector<int > &in_data,
433
- const std::vector<void *> &out_data) const override {
434
- if (param_.act_type == leakyrelu::kPReLU ) {
435
- return {};
436
- } else {
437
- return {{in_data[leakyrelu::kData ], out_data[leakyrelu::kOut ]}};
438
- }
439
- }
440
-
441
- std::vector<std::string> ListArguments () const override {
442
- if (param_.act_type == leakyrelu::kPReLU ) {
443
- return {" data" , " gamma" };
444
- } else {
445
- return {" data" };
446
- }
447
- }
448
-
449
- std::vector<std::string> ListOutputs () const override {
450
- if (param_.act_type == leakyrelu::kRReLU ) {
451
- return {" output" , " mask" };
452
- } else {
453
- return {" output" };
454
- }
455
- }
456
-
457
- int NumOutputs () const override {
458
- if (param_.act_type == leakyrelu::kRReLU ) {
459
- return 2 ;
460
- } else {
461
- return 1 ;
462
- }
463
- }
464
-
465
- int NumVisibleOutputs () const override {
466
- return 1 ;
467
- }
468
-
469
- std::vector<ResourceRequest> ForwardResource (
470
- const mxnet::ShapeVector &in_shape) const override {
471
- if (param_.act_type == leakyrelu::kRReLU ) {
472
- return {ResourceRequest::kRandom };
473
- } else {
474
- return std::vector<ResourceRequest>();
475
- }
476
- }
350
+ template <typename xpu>
351
+ void LeakyReLUGradCompute (const nnvm::NodeAttrs& attrs,
352
+ const OpContext& ctx,
353
+ const std::vector<TBlob>& inputs,
354
+ const std::vector<OpReqType>& req,
355
+ const std::vector<TBlob>& outputs) {
356
+ const LeakyReLUParam& param = nnvm::get<LeakyReLUParam>(attrs.parsed );
357
+ const std::vector<TBlob> no_use_but_adapt_origin_api;
358
+ // inputs: out_grad, input_data, input_gamma, output, output_mask
359
+ size_t expected_in = param.act_type == leakyrelu::kPReLU ? 2 : 1 ;
360
+ size_t expected_out = param.act_type == leakyrelu::kRReLU ? 2 : 1 ;
477
361
478
- std::vector<ResourceRequest> BackwardResource (
479
- const mxnet::ShapeVector &in_shape) const override {
480
- return {ResourceRequest::kTempSpace };
481
- }
362
+ CHECK_GE (inputs.size (), 1 + expected_in + expected_out);
363
+ std::vector<TBlob> out_grad{inputs[0 ]};
364
+ std::vector<TBlob> in_data (inputs.begin () + 1 ,
365
+ inputs.begin () + 1 + expected_in);
366
+ std::vector<TBlob> out_data (inputs.begin () + 1 + expected_in,
367
+ inputs.begin () + 1 + expected_in + expected_out);
482
368
483
- Operator* CreateOperator (Context ctx) const override {
484
- LOG (FATAL) << " Not Implemented." ;
485
- return NULL ;
486
- }
369
+ CHECK_EQ (req.size (), outputs.size ());
370
+ int dtype = inputs[0 ].type_flag_ ;
371
+ const std::vector<TBlob> &in_grad = outputs;
487
372
488
- Operator* CreateOperatorEx (Context ctx, mxnet::ShapeVector *in_shape,
489
- std::vector<int > *in_type) const override ;
373
+ MSHADOW_REAL_TYPE_SWITCH (dtype, DType, {
374
+ LeakyReLUOp<xpu, DType> op (param);
375
+ op.Backward (ctx, out_grad, in_data, out_data, req, in_grad, no_use_but_adapt_origin_api);
376
+ });
377
+ }
490
378
491
- private:
492
- LeakyReLUParam param_;
493
- };
494
- #endif // DMLC_USE_CXX11
495
379
} // namespace op
496
380
} // namespace mxnet
497
381
0 commit comments