This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Dynamic custom operator GPU support #17270

Merged: 16 commits, Jan 31, 2020
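This PR extends dynamically loaded custom operators to GPU. setForward, setBackward, and setCreateOpState now take a context string ("cpu" or "gpu") so an operator can register one compute function per device; OpResource gains alloc_cpu/alloc_gpu and get_cuda_stream; and a new ReLU example (relu_lib.cu, compiled with nvcc) plus test_relu.py exercise the GPU path end to end.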
Changes from 11 commits
7 changes: 5 additions & 2 deletions example/extensions/lib_custom_op/Makefile
@@ -15,10 +15,13 @@
# specific language governing permissions and limitations
# under the License.

-all: gemm_lib
+all: gemm_lib relu_lib

gemm_lib:
g++ -shared -fPIC -std=c++11 gemm_lib.cc -o libgemm_lib.so -I ../../../include/mxnet

+relu_lib:
+	nvcc -shared -std=c++11 -Xcompiler -fPIC relu_lib.cu -o librelu_lib.so -I ../../../include/mxnet

clean:
-	rm -rf libgemm_lib.so
+	rm -rf libgemm_lib.so librelu_lib.so
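Once built, the library is loaded into MXNet at runtime rather than linked in at build time. A minimal sketch of the loading step (mirroring test_relu.py further down; it assumes the .so was built in the current directory):

    import os
    import mxnet as mx

    # dynamically load the compiled custom-operator library
    mx.library.load(os.path.abspath('librelu_lib.so'))
    # ops registered by the library are now available, e.g. mx.nd.my_relu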
25 changes: 16 additions & 9 deletions example/extensions/lib_custom_op/gemm_lib.cc
@@ -103,7 +103,7 @@ MXReturnValue backward(std::map<std::string, std::string> attrs,
unsigned m = inputs[2].shape[1];
// allocate temporary workspace memory through resource manager
// for multiple arrays better to request a big memory pool
-  void *workspace = res.alloc((k*n + m*k) * sizeof(float));
+  void *workspace = res.alloc_cpu((k*n + m*k) * sizeof(float));
float *At = static_cast<float*>(workspace);
float *Bt = static_cast<float*>(workspace) + (k*n);

@@ -167,8 +167,8 @@ MXReturnValue inferShape(std::map<std::string, std::string> attrs,
}

REGISTER_OP(my_gemm)
-.setForward(forward)
-.setBackward(backward)
+.setForward(forward, "cpu")
+.setBackward(backward, "cpu")
.setParseAttrs(parseAttrs)
.setInferType(inferType)
.setInferShape(inferShape);
@@ -182,16 +182,23 @@ class MyStatefulGemm : public CustomStatefulOp {
MXReturnValue Forward(std::vector<MXTensor> inputs,
std::vector<MXTensor> outputs,
OpResource op_res) {
-    ++count;
-    std::cout << "Info: keyword + number of forward: " << count << std::endl;
+    std::cout << "Info: keyword + number of forward: " << ++count << std::endl;
std::map<std::string, std::string> attrs;
+    if (inputs[0].ctx.dev_type != "cpu") {
+      std::cout << "Forward is not implemented for " << inputs[0].ctx.dev_type << std::endl;
+      return MX_FAIL;
+    }
return forward(attrs, inputs, outputs, op_res);
}

MXReturnValue Backward(std::vector<MXTensor> inputs,
std::vector<MXTensor> outputs,
OpResource op_res) {
std::map<std::string, std::string> attrs;
+    if (inputs[0].ctx.dev_type != "cpu") {
+      std::cout << "Backward is not implemented for " << inputs[0].ctx.dev_type << std::endl;
+      return MX_FAIL;
+    }
return backward(attrs, inputs, outputs, op_res);
}

@@ -203,9 +210,9 @@ class MyStatefulGemm : public CustomStatefulOp {

MXReturnValue createOpState(std::map<std::string, std::string> attrs,
CustomStatefulOp** op_inst) {
-  int count = 0;
-  if (attrs.count("test_kw") > 0)
-    count = std::stoi(attrs["test_kw"]);
+  // testing passing of keyword arguments
+  int count = attrs.count("test_kw") > 0 ? std::stoi(attrs["test_kw"]) : 0;
+  // creating a stateful operator instance
*op_inst = new MyStatefulGemm(count);
std::cout << "Info: stateful operator created" << std::endl;
return MX_SUCCESS;
@@ -222,7 +229,7 @@ REGISTER_OP(state_gemm)
.setInferType(inferType)
.setInferShape(inferShape)
.setMutateInputs(mutateInputs)
-.setCreateOpState(createOpState);
+.setCreateOpState(createOpState, "cpu");

MXReturnValue initialize(int version) {
if (version >= 10400) {
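For context, a hedged sketch of driving the stateful op from Python once libgemm_lib.so is loaded; the op name and the test_kw attribute come from the registration above, while the exact call shape is an assumption:

    import os
    import mxnet as mx

    mx.library.load(os.path.abspath('libgemm_lib.so'))
    a = mx.nd.array([[1, 2], [3, 4]])
    b = mx.nd.array([[5, 6], [7, 8]])
    # hypothetical invocation: test_kw seeds the forward-call counter in MyStatefulGemm
    print(mx.nd.state_gemm(a, b, test_kw=100))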
193 changes: 193 additions & 0 deletions example/extensions/lib_custom_op/relu_lib.cu
@@ -0,0 +1,193 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2020 by Contributors
* \file relu_lib.cu
 * \brief simple custom ReLU operator implemented with CUDA kernels
*/

#include <iostream>
#include "lib_api.h"

// ReLU forward kernel: one GPU thread per element, out[i] = max(in[i], 0)
__global__ void relu_gpu_forward(float *out, float *in, int64_t N) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < N)
out[tid] = in[tid] > 0 ? in[tid] : 0;
}

// ReLU backward kernel: writes the derivative, 1 where the input was positive, else 0
__global__ void relu_gpu_backward(float *out, float *in, int64_t N) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < N)
out[tid] = in[tid] > 0 ? 1 : 0;
}

MXReturnValue forwardCPU(std::map<std::string, std::string> attrs,
std::vector<MXTensor> inputs,
std::vector<MXTensor> outputs,
OpResource res) {
float* in_data = inputs[0].data<float>();
float* out_data = outputs[0].data<float>();
  for (int64_t i = 0; i < inputs[0].size(); i++) {
    out_data[i] = in_data[i] > 0 ? in_data[i] : 0;
  }
return MX_SUCCESS;
}

MXReturnValue backwardCPU(std::map<std::string, std::string> attrs,
std::vector<MXTensor> inputs,
std::vector<MXTensor> outputs,
OpResource res) {
float* in_data = inputs[0].data<float>();
float* out_data = outputs[0].data<float>();
  for (int64_t i = 0; i < inputs[0].size(); i++) {
    out_data[i] = in_data[i] > 0 ? 1 : 0;
  }
return MX_SUCCESS;
}

MXReturnValue forwardGPU(std::map<std::string, std::string> attrs,
std::vector<MXTensor> inputs,
std::vector<MXTensor> outputs,
OpResource res) {
float* in_data = inputs[0].data<float>();
float* out_data = outputs[0].data<float>();

  // exercise memory resource allocation on both CPU and GPU
  // (these workspaces are requested only to test the allocators; they are not used below)
  void *workspace_cpu = res.alloc_cpu(8 * sizeof(float));
  void *workspace_gpu = res.alloc_gpu(8 * sizeof(float));

mx_stream_t cuda_stream = res.get_cuda_stream();
int64_t N = inputs[0].size();
int block = 256;
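  // ceiling division: enough blocks so that grid * block >= N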
int grid = (N + (block - 1)) / block;
relu_gpu_forward<<<grid,block,0,cuda_stream>>>(out_data, in_data, N);

return MX_SUCCESS;
}

MXReturnValue backwardGPU(std::map<std::string, std::string> attrs,
std::vector<MXTensor> inputs,
std::vector<MXTensor> outputs,
OpResource res) {
float* in_data = inputs[0].data<float>();
float* out_data = outputs[0].data<float>();

mx_stream_t cuda_stream = res.get_cuda_stream();
int64_t N = inputs[0].size();
int block = 256;
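  // same one-thread-per-element launch configuration as in forwardGPU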
int grid = (N + (block - 1)) / block;
relu_gpu_backward<<<grid,block,0,cuda_stream>>>(out_data, in_data, N);

return MX_SUCCESS;
}

MXReturnValue parseAttrs(std::map<std::string, std::string> attrs, int* num_in, int* num_out) {
*num_in = 1;
*num_out = 1;
return MX_SUCCESS;
}

MXReturnValue inferType(std::map<std::string, std::string> attrs,
std::vector<int> &intypes,
std::vector<int> &outtypes) {
outtypes[0] = intypes[0];
return MX_SUCCESS;
}

MXReturnValue inferShape(std::map<std::string, std::string> attrs,
std::vector<std::vector<unsigned int>> &inshapes,
std::vector<std::vector<unsigned int>> &outshapes) {
outshapes[0] = inshapes[0];
return MX_SUCCESS;
}

REGISTER_OP(my_relu)
.setParseAttrs(parseAttrs)
.setInferType(inferType)
.setInferShape(inferShape)
.setForward(forwardCPU, "cpu")
.setForward(forwardGPU, "gpu")
.setBackward(backwardCPU, "cpu")
.setBackward(backwardGPU, "gpu");
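
// With both variants registered under the same name, MXNet dispatches to the
// "cpu" or "gpu" function based on the context of the input NDArray, so
// my_relu runs on either device (exercised by test_relu.py).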

class MyStatefulReluCPU : public CustomStatefulOp {
public:
explicit MyStatefulReluCPU() {}
MXReturnValue Forward(std::vector<MXTensor> inputs,
std::vector<MXTensor> outputs,
OpResource op_res) {
std::map<std::string, std::string> attrs;
return forwardCPU(attrs, inputs, outputs, op_res);
}
MXReturnValue Backward(std::vector<MXTensor> inputs,
std::vector<MXTensor> outputs,
OpResource op_res) {
std::map<std::string, std::string> attrs;
return backwardCPU(attrs, inputs, outputs, op_res);
}
~MyStatefulReluCPU() {}
};

class MyStatefulReluGPU : public CustomStatefulOp {
public:
explicit MyStatefulReluGPU() {}
MXReturnValue Forward(std::vector<MXTensor> inputs,
std::vector<MXTensor> outputs,
OpResource op_res) {
std::map<std::string, std::string> attrs;
return forwardGPU(attrs, inputs, outputs, op_res);
}
MXReturnValue Backward(std::vector<MXTensor> inputs,
std::vector<MXTensor> outputs,
OpResource op_res) {
std::map<std::string, std::string> attrs;
return backwardGPU(attrs, inputs, outputs, op_res);
}
~MyStatefulReluGPU() {}
};

MXReturnValue createOpStateCPU(std::map<std::string, std::string> attrs,
CustomStatefulOp** op_inst) {
*op_inst = new MyStatefulReluCPU();
return MX_SUCCESS;
}

MXReturnValue createOpStateGPU(std::map<std::string, std::string> attrs,
CustomStatefulOp** op_inst) {
*op_inst = new MyStatefulReluGPU();
return MX_SUCCESS;
}

REGISTER_OP(my_state_relu)
.setParseAttrs(parseAttrs)
.setInferType(inferType)
.setInferShape(inferShape)
.setCreateOpState(createOpStateCPU, "cpu")
.setCreateOpState(createOpStateGPU, "gpu");

MXReturnValue initialize(int version) {
if (version >= 10400) {
std::cout << "MXNet version " << version << " supported" << std::endl;
return MX_SUCCESS;
} else {
std::cout << "MXNet version " << version << " not supported" << std::endl;
return MX_FAIL;
}
}
69 changes: 69 additions & 0 deletions example/extensions/lib_custom_op/test_relu.py
@@ -0,0 +1,69 @@
#!/usr/bin/env python3

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable=arguments-differ

# This test checks dynamic loading of a custom library into MXNet
# and end-to-end compute of a simple ReLU custom op on CPU and GPU

import mxnet as mx
import os
import time

# load library
if os.name == 'posix':
path = os.path.abspath('librelu_lib.so')
mx.library.load(path)

a = mx.nd.array([[-2,-1],[1,2]], ctx=mx.cpu())
b = mx.nd.array([[-2,-1],[1,2]], ctx=mx.gpu())

print("--------start ndarray compute---------")
print(mx.nd.my_relu(a))
print(mx.nd.my_relu(b))
print(mx.nd.my_state_relu(a))
print(mx.nd.my_state_relu(b))

print("--------start symbolic compute--------")
c = mx.sym.Variable('c')
d = mx.sym.my_relu(c)
in_grad = [mx.nd.empty((2,2), ctx=mx.gpu())]
exe = d.bind(ctx=mx.gpu(), args={'c':b}, args_grad=in_grad)
out = exe.forward()
print(out)

print("--------start backward compute--------")
out_grad = mx.nd.ones((2,2), ctx=mx.gpu())
exe.backward([out_grad])
print(in_grad)

print("--------start stress test---------")
a = mx.nd.uniform(shape=(1000,1000,100), ctx=mx.cpu())
b = mx.nd.uniform(shape=(1000,1000,100), ctx=mx.gpu())
mx.nd.waitall()  # finish the asynchronous array initializations before timing
t1 = time.time()
r1 = mx.nd.my_relu(a)
mx.nd.waitall()  # MXNet executes asynchronously; block until the CPU run finishes
t2 = time.time()
r2 = mx.nd.my_relu(b)
mx.nd.waitall()  # block until the GPU kernel finishes
t3 = time.time()
print("CPU running time:")
print(t2 - t1)
print("GPU running time:")
print(t3 - t2)
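
# Note: this test requires a CUDA-enabled MXNet build with at least one GPU,
# since it creates NDArrays with ctx=mx.gpu().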

6 changes: 3 additions & 3 deletions example/extensions/lib_subgraph/subgraph_lib.cc
@@ -84,7 +84,7 @@ MXReturnValue myExecutor(std::vector<MXTensor> inputs,
// get input tensor based on node ID inputs from data storage
MXTensor &input = data[node_inputs.list[0].list[0].num];
// create temporary storage
-      MXTensor tmp(malloc(input.size()*4), input.shape, input.dtype, 0);
+      MXTensor tmp(malloc(input.size()*4), input.shape, input.dtype, 0, {"cpu", 0});
// save allocated ptr to free later
to_free.push_back(tmp.data_ptr);
// execute log operator
@@ -95,7 +95,7 @@ MXReturnValue myExecutor(std::vector<MXTensor> inputs,
// get input tensor based on node ID inputs from data storage
MXTensor &input = data[node_inputs.list[0].list[0].num];
// create temporary storage
-      MXTensor tmp(malloc(input.size()*4), input.shape, input.dtype, 0);
+      MXTensor tmp(malloc(input.size()*4), input.shape, input.dtype, 0, {"cpu", 0});
// save allocated ptr to free later
to_free.push_back(tmp.data_ptr);
// execute exp operator
@@ -172,7 +172,7 @@ MXReturnValue createOpState(std::map<std::string, std::string> attrs,

REGISTER_OP(_custom_subgraph_op)
.setIsSubgraphOp()
-.setCreateOpState(createOpState);
+.setCreateOpState(createOpState, "cpu");

const std::vector<std::string> op_names({"exp","log"});
