@@ -50,14 +50,18 @@ class InstallDeps implements IRepairStep {
50
50
private string $ tfjsPath ;
51
51
private IClientService $ clientService ;
52
52
private LoggerInterface $ logger ;
53
+ private string $ runTfjsGpuInstall ;
54
+ private string $ tfjsGPUPath ;
53
55
54
56
public function __construct (IConfig $ config , IClientService $ clientService , LoggerInterface $ logger ) {
55
57
$ this ->config = $ config ;
56
58
$ this ->binaryDir = dirname (__DIR__ , 2 ) . '/bin/ ' ;
57
59
$ this ->preGypBinaryDir = dirname (__DIR__ , 2 ) . '/node_modules/@mapbox/node-pre-gyp/bin/ ' ;
58
60
$ this ->ffmpegDir = dirname (__DIR__ , 2 ) . '/node_modules/ffmpeg-static/ ' ;
59
61
$ this ->tfjsInstallScript = dirname (__DIR__ , 2 ) . '/node_modules/@tensorflow/tfjs-node/scripts/install.js ' ;
62
+ $ this ->runTfjsGpuInstall = dirname (__DIR__ , 2 ) . '/node_modules/@tensorflow/tfjs-node-gpu/scripts/install.js ' ;
60
63
$ this ->tfjsPath = dirname (__DIR__ , 2 ) . '/node_modules/@tensorflow/tfjs-node/ ' ;
64
+ $ this ->tfjsGPUPath = dirname (__DIR__ , 2 ) . '/node_modules/@tensorflow/tfjs-node-gpu/ ' ;
61
65
$ this ->clientService = $ clientService ;
62
66
$ this ->logger = $ logger ;
63
67
}
@@ -116,6 +120,7 @@ public function run(IOutput $output): void {
116
120
$ this ->setBinariesPermissions ();
117
121
118
122
$ this ->runTfjsInstall ($ binaryPath );
123
+ $ this ->runTfjsGpuInstall ($ binaryPath );
119
124
}
120
125
121
126
protected function testBinary (string $ binaryPath ): ?string {
@@ -147,10 +152,28 @@ protected function runTfjsInstall(string $nodeBinary) : void {
147
152
}
148
153
chdir ($ oriCwd );
149
154
if ($ returnCode !== 0 ) {
155
+ $ this ->logger ->error ('Failed to install Tensorflow.js: ' .trim (implode ("\n" , $ output )));
150
156
throw new \Exception ('Failed to install Tensorflow.js: ' .trim (implode ("\n" , $ output )));
151
157
}
152
158
}
153
159
160
+ protected function runTfjsGpuInstall (string $ nodeBinary ) : void {
161
+ $ oriCwd = getcwd ();
162
+ chdir ($ this ->tfjsGPUPath );
163
+ $ cmd = 'PATH= ' .escapeshellcmd ($ this ->preGypBinaryDir ).': ' .escapeshellcmd ($ this ->binaryDir ).':$PATH ' . escapeshellcmd ($ nodeBinary ) . ' ' . escapeshellarg ($ this ->tfjsGpuInstallScript ) . ' gpu ' . escapeshellarg ('download ' );
164
+ try {
165
+ exec ($ cmd . ' 2>&1 ' , $ output , $ returnCode ); // Appending 2>&1 to avoid leaking sterr
166
+ } catch (\Throwable $ e ) {
167
+ $ this ->logger ->error ('Failed to install Tensorflow.js for GPU: ' .$ e ->getMessage (), ['exception ' => $ e ]);
168
+ throw new \Exception ('Failed to install Tensorflow.js for GPU: ' .$ e ->getMessage ());
169
+ }
170
+ chdir ($ oriCwd );
171
+ if ($ returnCode !== 0 ) {
172
+ $ this ->logger ->error ('Failed to install Tensorflow.js for GPU: ' .trim (implode ("\n" , $ output )));
173
+ throw new \Exception ('Failed to install Tensorflow.js for GPU: ' .trim (implode ("\n" , $ output )));
174
+ }
175
+ }
176
+
154
177
protected function downloadNodeBinary (string $ server , string $ version , string $ arch , string $ flavor = '' ) : string {
155
178
$ name = 'node- ' .$ version .'-linux- ' .$ arch ;
156
179
if ($ flavor !== '' ) {
0 commit comments