Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 34d03fd

Browse files
committed
Refactor around lib loading
1 parent 51c4091 commit 34d03fd

File tree

5 files changed

+114
-173
lines changed

5 files changed

+114
-173
lines changed

src/c_api/c_api.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,12 @@
4646
#include "mxnet/libinfo.h"
4747
#include "mxnet/imperative.h"
4848
#include "mxnet/lib_api.h"
49+
#include "../initialize.h"
4950
#include "./c_api_common.h"
5051
#include "../operator/custom/custom-inl.h"
5152
#include "../operator/tensor/matrix_op-inl.h"
5253
#include "../operator/tvmop/op_module.h"
5354
#include "../common/utils.h"
54-
#include "../common/library.h"
5555

5656
using namespace mxnet;
5757

@@ -95,7 +95,7 @@ inline int MXAPIGetFunctionRegInfo(const FunRegType *e,
9595
// Loads library and initializes it
9696
int MXLoadLib(const char *path) {
9797
API_BEGIN();
98-
void *lib = load_lib(path);
98+
void *lib = LibraryInitializer::Get()->lib_load(path);
9999
if (!lib)
100100
LOG(FATAL) << "Unable to load library";
101101

src/common/library.cc

Lines changed: 0 additions & 98 deletions
This file was deleted.

src/common/library.h

Lines changed: 0 additions & 40 deletions
This file was deleted.

src/initialize.cc

Lines changed: 90 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,20 +28,15 @@
2828
#include <mxnet/engine.h>
2929
#include "./engine/openmp.h"
3030
#include "./operator/custom/custom-inl.h"
31-
#include "./common/library.h"
3231
#if MXNET_USE_OPENCV
3332
#include <opencv2/opencv.hpp>
3433
#endif // MXNET_USE_OPENCV
3534
#include "common/utils.h"
3635
#include "engine/openmp.h"
3736

38-
#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
39-
#include <windows.h>
40-
#else
41-
#include <dlfcn.h>
42-
#endif
4337

4438
#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
39+
#include <windows.h>
4540
/*!
4641
* \brief Retrieve the system error message for the last-error code
4742
* \param err string that gets the error message
@@ -58,9 +53,10 @@ void win_err(char **err) {
5853
reinterpret_cast<char*>(err),
5954
0, NULL);
6055
}
56+
#else
57+
#include <dlfcn.h>
6158
#endif
6259

63-
6460
namespace mxnet {
6561

6662
#if MXNET_USE_SIGNAL_HANDLER && DMLC_LOG_STACK_TRACE
@@ -106,6 +102,91 @@ LibraryInitializer::~LibraryInitializer() {
106102
close_open_libs();
107103
}
108104

105+
bool LibraryInitializer::lib_is_loaded(const std::string& path) const {
106+
return loaded_libs.count(path) > 0;
107+
}
108+
109+
/*!
110+
* \brief Loads the dynamic shared library file
111+
* \param path library file location
112+
* \return handle a pointer for the loaded library, throws dmlc::error if library can't be loaded
113+
*/
114+
void* LibraryInitializer::lib_load(const char* path) {
115+
void *handle = nullptr;
116+
// check if library was already loaded
117+
if (!lib_is_loaded(path)) {
118+
// if not, load it
119+
#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
120+
handle = LoadLibrary(path);
121+
if (!handle) {
122+
char *err_msg = nullptr;
123+
win_err(&err_msg);
124+
LOG(FATAL) << "Error loading library: '" << path << "'\n" << err_msg;
125+
LocalFree(err_msg);
126+
return nullptr;
127+
}
128+
#else
129+
handle = dlopen(path, RTLD_LAZY);
130+
if (!handle) {
131+
LOG(FATAL) << "Error loading library: '" << path << "'\n" << dlerror();
132+
return nullptr;
133+
}
134+
#endif // _WIN32 or _WIN64 or __WINDOWS__
135+
// then store the pointer to the library
136+
loaded_libs[path] = handle;
137+
} else {
138+
loaded_libs.at(path);
139+
}
140+
return handle;
141+
}
142+
143+
/*!
144+
* \brief Closes the loaded dynamic shared library file
145+
* \param handle library file handle
146+
*/
147+
void LibraryInitializer::lib_close(void* handle) {
148+
std::string libpath;
149+
for (const auto& l: loaded_libs) {
150+
if (l.second == handle) {
151+
libpath = l.first;
152+
break;
153+
}
154+
}
155+
CHECK(!libpath.empty());
156+
#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
157+
FreeLibrary((HMODULE)handle);
158+
#else
159+
if (dlclose(handle)) {
160+
LOG(WARNING) << "LibraryInitializer::lib_close: couldn't close library at address: " << handle
161+
<< " loaded from: '" << libpath << "': " << dlerror();
162+
}
163+
#endif // _WIN32 or _WIN64 or __WINDOWS__
164+
loaded_libs.erase(libpath);
165+
}
166+
167+
/*!
168+
* \brief Obtains address of given function in the loaded library
169+
* \param handle pointer for the loaded library
170+
* \param func function pointer that gets output address
171+
* \param name function name to be fetched
172+
*/
173+
void LibraryInitializer::get_sym(void* handle, void** func, char* name) {
174+
#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
175+
*func = GetProcAddress((HMODULE)handle, name);
176+
if (!(*func)) {
177+
char *err_msg = nullptr;
178+
win_err(&err_msg);
179+
LOG(FATAL) << "Error getting function '" << name << "' from library\n" << err_msg;
180+
LocalFree(err_msg);
181+
}
182+
#else
183+
*func = dlsym(handle, name);
184+
if (!(*func)) {
185+
LOG(FATAL) << "Error getting function '" << name << "' from library\n" << dlerror();
186+
}
187+
#endif // _WIN32 or _WIN64 or __WINDOWS__
188+
}
189+
109190
bool LibraryInitializer::was_forked() const {
110191
return common::current_process_id() != original_pid_;
111192
}
@@ -153,15 +234,11 @@ void LibraryInitializer::install_signal_handlers() {
153234
}
154235

155236
void LibraryInitializer::close_open_libs() {
156-
for (auto const& lib : loaded_libs) {
157-
close_lib(lib.second);
237+
for (const auto& l: loaded_libs) {
238+
lib_close(l.second);
158239
}
159240
}
160241

161-
void LibraryInitializer::dynlib_defer_close(const std::string &path, void *handle) {
162-
loaded_libs.emplace(path, handle);
163-
}
164-
165242
/**
166243
* Perform static initialization
167244
*/

src/initialize.h

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -26,27 +26,14 @@
2626
#include <cstdlib>
2727
#include <string>
2828
#include <map>
29+
#include "dmlc/io.h"
30+
2931

3032
#ifndef MXNET_INITIALIZE_H_
3133
#define MXNET_INITIALIZE_H_
3234

3335
namespace mxnet {
3436

35-
/*!
36-
* \brief fetches from the library a function pointer of any given datatype and name
37-
* \param T a template parameter for data type of function pointer
38-
* \param lib library handle
39-
* \param func_name function name to search for in the library
40-
* \return func a function pointer
41-
*/
42-
template<typename T>
43-
T get_func(void *lib, char *func_name) {
44-
T func;
45-
get_sym(lib, reinterpret_cast<void**>(&func), func_name);
46-
if (!func)
47-
LOG(FATAL) << "Unable to get function '" << func_name << "' from library";
48-
return func;
49-
}
5037

5138

5239
void pthread_atfork_prepare();
@@ -58,7 +45,7 @@ void pthread_atfork_child();
5845
*/
5946
class LibraryInitializer {
6047
public:
61-
typedef static std::map<std::string, void*> loaded_libs_t;
48+
typedef std::map<std::string, void*> loaded_libs_t;
6249
static LibraryInitializer* Get() {
6350
static LibraryInitializer inst;
6451
return &inst;
@@ -79,8 +66,10 @@ class LibraryInitializer {
7966

8067

8168
// Library loading
82-
void lib_defer_close(const std::string& path, void* handle);
83-
void lib_is_loaded()
69+
bool lib_is_loaded(const std::string& path) const;
70+
void* lib_load(const char* path);
71+
void lib_close(void* handle);
72+
static void get_sym(void* handle, void** func, char* name);
8473

8574
/**
8675
* Original pid of the process which first loaded and initialized the library
@@ -92,7 +81,6 @@ class LibraryInitializer {
9281
size_t mp_cv_num_threads_;
9382

9483
// Actual code for the atfork handlers as member functions.
95-
9684
void atfork_prepare();
9785
void atfork_parent();
9886
void atfork_child();
@@ -115,10 +103,24 @@ class LibraryInitializer {
115103

116104
void close_open_libs();
117105

118-
119106
loaded_libs_t loaded_libs;
120107
};
121108

109+
/*!
110+
* \brief fetches from the library a function pointer of any given datatype and name
111+
* \param T a template parameter for data type of function pointer
112+
* \param lib library handle
113+
* \param func_name function name to search for in the library
114+
* \return func a function pointer
115+
*/
116+
template<typename T>
117+
T get_func(void *lib, char *func_name) {
118+
T func;
119+
LibraryInitializer::Get()->get_sym(lib, reinterpret_cast<void**>(&func), func_name);
120+
if (!func)
121+
LOG(FATAL) << "Unable to get function '" << func_name << "' from library";
122+
return func;
123+
}
122124

123125
} // namespace mxnet
124126
#endif // MXNET_INITIALIZE_H_

0 commit comments

Comments
 (0)