Skip to content

Commit a0d74e6

Browse files
committed
Implement GPU mode
fixes #67 Signed-off-by: Marcel Klehr <[email protected]>
1 parent d2849a3 commit a0d74e6

File tree

10 files changed

+396
-129
lines changed

10 files changed

+396
-129
lines changed

Makefile

+2
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ remove-binaries:
3939
# make it download appropriate tf binaries
4040
rm -rf node_modules/@tensorflow/tfjs-node/deps/lib/*
4141
rm -rf node_modules/@tensorflow/tfjs-node/lib/*
42+
rm -rf node_modules/@tensorflow/tfjs-node-gpu/lib/*
43+
rm -rf node_modules/@tensorflow/tfjs-node-gpu/deps/lib/*
4244

4345
remove-devdeps:
4446
rm -rf node_modules

appinfo/routes.php

+1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
['name' => 'admin#nodejs', 'url' => '/admin/nodejs', 'verb' => 'GET'],
3434
['name' => 'admin#libtensorflow', 'url' => '/admin/libtensorflow', 'verb' => 'GET'],
3535
['name' => 'admin#wasmtensorflow', 'url' => '/admin/wasmtensorflow', 'verb' => 'GET'],
36+
['name' => 'admin#gputensorflow', 'url' => '/admin/gputensorflow', 'verb' => 'GET'],
3637
['name' => 'admin#cron', 'url' => '/admin/cron', 'verb' => 'GET'],
3738
['name' => 'admin#get_setting', 'url' => '/admin/settings/{setting}', 'verb' => 'GET'],
3839
['name' => 'admin#set_setting', 'url' => '/admin/settings/{setting}', 'verb' => 'PUT'],

lib/Controller/AdminController.php

+14
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,20 @@ public function wasmtensorflow(): JSONResponse {
167167
return new JSONResponse(['wasmtensorflow' => true]);
168168
}
169169

170+
public function gputensorflow(): JSONResponse {
171+
try {
172+
exec($this->settingsService->getSetting('node_binary') . ' ' . __DIR__ . '/../../src/test_gputensorflow.js' . ' 2>&1', $output, $returnCode);
173+
} catch (\Throwable $e) {
174+
return new JSONResponse(['gputensorflow' => false]);
175+
}
176+
177+
if ($returnCode !== 0) {
178+
return new JSONResponse(['gputensorflow' => false]);
179+
}
180+
181+
return new JSONResponse(['gputensorflow' => true]);
182+
}
183+
170184
public function cron(): JSONResponse {
171185
$cron = $this->config->getAppValue('core', 'backgroundjobs_mode', '');
172186
return new JSONResponse(['cron' => $cron]);

lib/Migration/InstallDeps.php

+23
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,18 @@ class InstallDeps implements IRepairStep {
5050
private string $tfjsPath;
5151
private IClientService $clientService;
5252
private LoggerInterface $logger;
53+
private string $runTfjsGpuInstall;
54+
private string $tfjsGPUPath;
5355

5456
public function __construct(IConfig $config, IClientService $clientService, LoggerInterface $logger) {
5557
$this->config = $config;
5658
$this->binaryDir = dirname(__DIR__, 2) . '/bin/';
5759
$this->preGypBinaryDir = dirname(__DIR__, 2) . '/node_modules/@mapbox/node-pre-gyp/bin/';
5860
$this->ffmpegDir = dirname(__DIR__, 2) . '/node_modules/ffmpeg-static/';
5961
$this->tfjsInstallScript = dirname(__DIR__, 2) . '/node_modules/@tensorflow/tfjs-node/scripts/install.js';
62+
$this->runTfjsGpuInstall = dirname(__DIR__, 2) . '/node_modules/@tensorflow/tfjs-node-gpu/scripts/install.js';
6063
$this->tfjsPath = dirname(__DIR__, 2) . '/node_modules/@tensorflow/tfjs-node/';
64+
$this->tfjsGPUPath = dirname(__DIR__, 2) . '/node_modules/@tensorflow/tfjs-node-gpu/';
6165
$this->clientService = $clientService;
6266
$this->logger = $logger;
6367
}
@@ -116,6 +120,7 @@ public function run(IOutput $output): void {
116120
$this->setBinariesPermissions();
117121

118122
$this->runTfjsInstall($binaryPath);
123+
$this->runTfjsGpuInstall($binaryPath);
119124
}
120125

121126
protected function testBinary(string $binaryPath): ?string {
@@ -147,10 +152,28 @@ protected function runTfjsInstall(string $nodeBinary) : void {
147152
}
148153
chdir($oriCwd);
149154
if ($returnCode !== 0) {
155+
$this->logger->error('Failed to install Tensorflow.js: '.trim(implode("\n", $output)));
150156
throw new \Exception('Failed to install Tensorflow.js: '.trim(implode("\n", $output)));
151157
}
152158
}
153159

160+
protected function runTfjsGpuInstall(string $nodeBinary) : void {
161+
$oriCwd = getcwd();
162+
chdir($this->tfjsGPUPath);
163+
$cmd = 'PATH='.escapeshellcmd($this->preGypBinaryDir).':'.escapeshellcmd($this->binaryDir).':$PATH ' . escapeshellcmd($nodeBinary) . ' ' . escapeshellarg($this->tfjsGpuInstallScript) . ' gpu ' . escapeshellarg('download');
164+
try {
165+
exec($cmd . ' 2>&1', $output, $returnCode); // Appending 2>&1 to avoid leaking sterr
166+
} catch (\Throwable $e) {
167+
$this->logger->error('Failed to install Tensorflow.js for GPU: '.$e->getMessage(), ['exception' => $e]);
168+
throw new \Exception('Failed to install Tensorflow.js for GPU: '.$e->getMessage());
169+
}
170+
chdir($oriCwd);
171+
if ($returnCode !== 0) {
172+
$this->logger->error('Failed to install Tensorflow.js for GPU: '.trim(implode("\n", $output)));
173+
throw new \Exception('Failed to install Tensorflow.js for GPU: '.trim(implode("\n", $output)));
174+
}
175+
}
176+
154177
protected function downloadNodeBinary(string $server, string $version, string $arch, string $flavor = '') : string {
155178
$name = 'node-'.$version.'-linux-'.$arch;
156179
if ($flavor !== '') {

0 commit comments

Comments
 (0)