
Dynamic custom operator GPU support #17270

Merged · 16 commits · Jan 31, 2020
Changes from 5 commits
7 changes: 5 additions & 2 deletions example/extensions/lib_custom_op/Makefile
@@ -15,10 +15,13 @@
# specific language governing permissions and limitations
# under the License.

-all: gemm_lib
+all: gemm_lib relu_lib

gemm_lib:
	g++ -shared -fPIC -std=c++11 gemm_lib.cc -o libgemm_lib.so -I ../../../include/mxnet

+relu_lib:
+	nvcc -shared -std=c++11 -Xcompiler -fPIC relu_lib.cu -o librelu_lib.so -I ../../../include/mxnet

clean:
-	rm -rf libgemm_lib.so
+	rm -rf libgemm_lib.so librelu_lib.so
4 changes: 2 additions & 2 deletions example/extensions/lib_custom_op/gemm_lib.cc
@@ -167,8 +167,8 @@ MXReturnValue inferShape(std::map<std::string, std::string> attrs,
}

REGISTER_OP(my_gemm)
-.setForward(forward)
-.setBackward(backward)
+.setForward(forward, "cpu")
+.setBackward(backward, "cpu")
.setParseAttrs(parseAttrs)
.setInferType(inferType)
.setInferShape(inferShape);
111 changes: 111 additions & 0 deletions example/extensions/lib_custom_op/relu_lib.cu
@@ -0,0 +1,111 @@
#include <iostream>
#include "lib_api.h"

// CUDA kernel: elementwise ReLU forward; each thread computes one output element
__global__ void relu_gpu_forward(float *out, float *in, int64_t N) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid < N) {
    out[tid] = in[tid] > 0 ? in[tid] : 0;
  }
}

// CUDA kernel: ReLU gradient mask; writes 1 where the input is positive, 0 elsewhere
__global__ void relu_gpu_backward(float *out, float *in, int64_t N) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid < N) {
    out[tid] = in[tid] > 0 ? 1 : 0;
  }
}

// CPU forward: plain elementwise loop over the input tensor
MXReturnValue forwardCPU(std::map<std::string, std::string> attrs,
                         std::vector<MXTensor> inputs,
                         std::vector<MXTensor> outputs,
                         OpResource res) {
  float* in_data = inputs[0].data<float>();
  float* out_data = outputs[0].data<float>();
  for (int64_t i = 0; i < inputs[0].size(); i++) {
    out_data[i] = in_data[i] > 0 ? in_data[i] : 0;
  }
  return MX_SUCCESS;
}

// CPU backward: writes the ReLU gradient mask for each element
MXReturnValue backwardCPU(std::map<std::string, std::string> attrs,
                          std::vector<MXTensor> inputs,
                          std::vector<MXTensor> outputs,
                          OpResource res) {
  float* in_data = inputs[0].data<float>();
  float* out_data = outputs[0].data<float>();
  for (int64_t i = 0; i < inputs[0].size(); i++) {
    out_data[i] = in_data[i] > 0 ? 1 : 0;
  }
  return MX_SUCCESS;
}

// GPU forward: launches relu_gpu_forward on the CUDA stream MXNet provides via OpResource
MXReturnValue forwardGPU(std::map<std::string, std::string> attrs,
                         std::vector<MXTensor> inputs,
                         std::vector<MXTensor> outputs,
                         OpResource res) {
  float* in_data = inputs[0].data<float>();
  float* out_data = outputs[0].data<float>();

  cudaStream_t gpu_stream = reinterpret_cast<cudaStream_t>(res.get_gpu_stream());
  int64_t N = inputs[0].size();
  int grid = (N + 255) / 256;  // ceil(N/256): enough blocks for one thread per element
  int block = 256;             // threads per block
  relu_gpu_forward<<<grid, block, 0, gpu_stream>>>(out_data, in_data, N);

  return MX_SUCCESS;
}

// GPU backward: launches relu_gpu_backward on MXNet's stream with the same launch config
MXReturnValue backwardGPU(std::map<std::string, std::string> attrs,
                          std::vector<MXTensor> inputs,
                          std::vector<MXTensor> outputs,
                          OpResource res) {
  float* in_data = inputs[0].data<float>();
  float* out_data = outputs[0].data<float>();

  cudaStream_t gpu_stream = reinterpret_cast<cudaStream_t>(res.get_gpu_stream());
  int64_t N = inputs[0].size();
  int grid = (N + 255) / 256;
  int block = 256;
  relu_gpu_backward<<<grid, block, 0, gpu_stream>>>(out_data, in_data, N);

  return MX_SUCCESS;
}

// my_relu takes exactly one input tensor and produces one output tensor
MXReturnValue parseAttrs(std::map<std::string, std::string> attrs, int* num_in, int* num_out) {
  *num_in = 1;
  *num_out = 1;
  return MX_SUCCESS;
}

// output dtype matches the input dtype
MXReturnValue inferType(std::map<std::string, std::string> attrs,
                        std::vector<int> &intypes,
                        std::vector<int> &outtypes) {
  outtypes[0] = intypes[0];
  return MX_SUCCESS;
}

// output shape matches the input shape (ReLU is elementwise)
MXReturnValue inferShape(std::map<std::string, std::string> attrs,
                         std::vector<std::vector<unsigned int>> &inshapes,
                         std::vector<std::vector<unsigned int>> &outshapes) {
  outshapes[0] = inshapes[0];
  return MX_SUCCESS;
}

// register both implementations under one op name; MXNet dispatches on the
// device type ("cpu" or "gpu") of the context the op runs on
REGISTER_OP(my_relu)
.setParseAttrs(parseAttrs)
.setInferType(inferType)
.setInferShape(inferShape)
.setForward(forwardCPU, "cpu")
.setForward(forwardGPU, "gpu")
.setBackward(backwardCPU, "cpu")
.setBackward(backwardGPU, "gpu");

// called once when the library is loaded; 10400 encodes MXNet version 1.4.0
MXReturnValue initialize(int version) {
  if (version >= 10400) {
    std::cout << "MXNet version " << version << " supported" << std::endl;
    return MX_SUCCESS;
  } else {
    std::cout << "MXNet version " << version << " not supported" << std::endl;
    return MX_FAIL;
  }
}
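
A note on the listing above: the example does not check for kernel-launch failures. A minimal sketch of what that could look like after the launch in forwardGPU, assuming it is acceptable to report errors via stdout and the MX_FAIL return code (both already used by initialize); cudaGetLastError and cudaGetErrorString are standard CUDA runtime calls, not part of lib_api.h:

// sketch only: surface launch errors (e.g. an invalid launch configuration)
relu_gpu_forward<<<grid, block, 0, gpu_stream>>>(out_data, in_data, N);
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
  std::cout << "relu_gpu_forward launch failed: "
            << cudaGetErrorString(err) << std::endl;
  return MX_FAIL;
}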
67 changes: 67 additions & 0 deletions example/extensions/lib_custom_op/test_relu.py
@@ -0,0 +1,67 @@
#!/usr/bin/env python3

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable=arguments-differ

# This test checks dynamic loading of a custom library into MXNet
# and checks end-to-end compute of a simple relu custom op on CPU and GPU

import mxnet as mx
import os
import time

# load library
if (os.name=='posix'):
    path = os.path.abspath('librelu_lib.so')
    mx.library.load(path)

a = mx.nd.array([[-2,-1],[1,2]], ctx=mx.cpu())
b = mx.nd.array([[-2,-1],[1,2]], ctx=mx.gpu())

print("--------start ndarray compute---------")
print(mx.nd.my_relu(a))
print(mx.nd.my_relu(b))

print("--------start symbolic compute--------")
c = mx.sym.Variable('c')
d = mx.sym.my_relu(c)
in_grad = [mx.nd.empty((2,2), ctx=mx.gpu())]
exe = d.bind(ctx=mx.gpu(), args={'c':b}, args_grad=in_grad)
out = exe.forward()
print(out)

print("--------start backward compute--------")
out_grad = mx.nd.ones((2,2), ctx=mx.gpu())
exe.backward([out_grad])
print(in_grad)

print("--------start stress test---------")
a = mx.nd.uniform(shape=(1000,1000,100), ctx=mx.cpu())
b = mx.nd.uniform(shape=(1000,1000,100), ctx=mx.gpu())
t1 = time.time()
r1 = mx.nd.my_relu(a)
t2 = time.time()
r2 = mx.nd.my_relu(b)
t3 = time.time()
print("CPU running time:")
print(t2 - t1)
print("GPU running time:")
print(t3 - t2)

4 changes: 2 additions & 2 deletions example/extensions/lib_subgraph/subgraph_lib.cc
@@ -84,7 +84,7 @@ MXReturnValue myExecutor(std::vector<MXTensor> inputs,
// get input tensor based on node ID inputs from data storage
MXTensor &input = data[node_inputs.list[0].list[0].num];
// create temporary storage
-MXTensor tmp(malloc(input.size()*4), input.shape, input.dtype, 0);
+MXTensor tmp(malloc(input.size()*4), input.shape, input.dtype, 0, {"cpu", 0});
// save allocated ptr to free later
to_free.push_back(tmp.data_ptr);
// execute log operator
@@ -95,7 +95,7 @@ MXReturnValue myExecutor(std::vector<MXTensor> inputs,
// get input tensor based on node ID inputs from data storage
MXTensor &input = data[node_inputs.list[0].list[0].num];
// create temporary storage
-MXTensor tmp(malloc(input.size()*4), input.shape, input.dtype, 0);
+MXTensor tmp(malloc(input.size()*4), input.shape, input.dtype, 0, {"cpu", 0});
// save allocated ptr to free later
to_free.push_back(tmp.data_ptr);
// execute exp operator
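
The new fifth constructor argument gives each MXTensor an explicit device context: a device-type string plus a device id. This executor runs on the CPU, hence {"cpu", 0}. Purely as a hypothetical illustration, a device-resident temporary would presumably pair a GPU allocation with a "gpu" context; the cudaMalloc call and the "gpu" string below are assumptions by analogy with the registration strings above, not part of this diff:

// hypothetical sketch: a GPU-resident temporary tensor on device 0
void* gpu_ptr;
cudaMalloc(&gpu_ptr, input.size()*4);  // device memory instead of malloc
MXTensor tmp(gpu_ptr, input.shape, input.dtype, 0, {"gpu", 0});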