Skip to content

Commit 2644bd0

Browse files
[WebGPU] Make dataToGPU upload to GPU if data is on CPU (#8483)
1 parent cb6206c commit 2644bd0

File tree

2 files changed

+18
-3
lines changed

2 files changed

+18
-3
lines changed

tfjs-backend-webgpu/src/backend_webgpu.ts

+6-3
Original file line numberDiff line numberDiff line change
@@ -594,16 +594,19 @@ export class WebGPUBackend extends KernelBackend {
594594
* @param dataId The source tensor.
595595
*/
596596
override readToGPU(dataId: DataId): GPUData {
597-
const srcTensorData = this.tensorMap.get(dataId);
598-
const {values, dtype, shape, resource} = srcTensorData;
597+
let srcTensorData = this.tensorMap.get(dataId);
598+
const {values, dtype, shape} = srcTensorData;
599+
let resource = srcTensorData.resource;
599600

600601
if (dtype === 'complex64') {
601602
throw new Error('Does not support reading buffer for complex64 dtype.');
602603
}
603604

604605
if (resource == null) {
605606
if (values != null) {
606-
throw new Error('Data is not on GPU but on CPU.');
607+
this.uploadToGPU(dataId);
608+
srcTensorData = this.tensorMap.get(dataId);
609+
resource = srcTensorData.resource;
607610
} else {
608611
throw new Error('There is no data on GPU or CPU.');
609612
}

tfjs-backend-webgpu/src/backend_webgpu_test.ts

+12
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,18 @@ describeWebGPU('backend webgpu', () => {
200200
await c3.data();
201201
tf.env().set('WEBGPU_DEFERRED_SUBMIT_BATCH_SIZE', savedFlag);
202202
});
203+
204+
it('dataToGPU uploads to GPU if the tensor is on CPU', async () => {
205+
const webGPUBackend = (tf.backend() as WebGPUBackend);
206+
const data = [1,2,3,4,5];
207+
const tensor = tf.tensor1d(data);
208+
const res = tensor.dataToGPU();
209+
expect(res.buffer).toBeDefined();
210+
const resData = await webGPUBackend.getBufferData(res.buffer);
211+
const values = tf.util.convertBackendValuesAndArrayBuffer(
212+
resData, res.tensorRef.dtype);
213+
expectArraysEqual(values, data);
214+
});
203215
});
204216

205217
describeWebGPU('backendWebGPU', () => {

0 commit comments

Comments
 (0)