Skip to content
This repository was archived by the owner on Oct 23, 2023. It is now read-only.

Commit 68d976f

Browse files
committed
Tensor Indexing API for set tensor (#174)
Summary: Pull Request resolved: #174 Implement Tensor Indexing API for inplace put `index_put_`, which will allow the following notions: ```javascript const tensor = torch.zeros([2]); tensor[0] = 1; tensor[1] = torch.tensor([2]); ``` re #172 Reviewed By: justinhaaheim Differential Revision: D41746329 fbshipit-source-id: b49a207d4c70f551eb1288bea68e09671c389ae4
1 parent d8c37c3 commit 68d976f

File tree

3 files changed

+84
-5
lines changed

3 files changed

+84
-5
lines changed

react-native-pytorch-core/cxx/src/torchlive/torch/TensorHostObject.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -737,6 +737,39 @@ jsi::Value TensorHostObject::get(
737737
return BaseHostObject::get(runtime, propNameId);
738738
}
739739

740+
void TensorHostObject::set(
741+
jsi::Runtime& runtime,
742+
const jsi::PropNameID& propNameId,
743+
const jsi::Value& value) {
744+
auto name = propNameId.utf8(runtime);
745+
746+
int idx = -1;
747+
try {
748+
idx = std::stoi(name.c_str());
749+
} catch (...) {
750+
// Cannot parse name value to int. This can happen when the name in bracket
751+
// or dot notion is not an int (e.g., tensor['foo']).
752+
// Let's ignore this exception here and have the PyTorch C++ API throw an
753+
// error for index out of bounds.
754+
// Note: The Tensor Indexing API allows for a much broader range of indices
755+
// but for now, the PlayTorch API only supports single value index values.
756+
}
757+
if (value.isObject()) {
758+
// Get TensorHostObject with wrapped tensor, otherwise it will be nullptr
759+
auto tensorHostObject =
760+
value.asObject(runtime).asHostObject<TensorHostObject>(runtime);
761+
if (tensorHostObject != nullptr) {
762+
this->tensor.index_put_({idx}, tensorHostObject->tensor);
763+
}
764+
} else if (value.isNumber()) {
765+
this->tensor.index_put_({idx}, value.asNumber());
766+
} else {
767+
throw jsi::JSError(
768+
runtime,
769+
"Invalid value! The value has to be of type tensor or a number");
770+
}
771+
}
772+
740773
jsi::Function TensorHostObject::createToString(jsi::Runtime& runtime) {
741774
auto toStringFunc = [this](
742775
jsi::Runtime& runtime,

react-native-pytorch-core/cxx/src/torchlive/torch/TensorHostObject.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ class JSI_EXPORT TensorHostObject : public common::BaseHostObject {
3838
facebook::jsi::Value get(
3939
facebook::jsi::Runtime&,
4040
const facebook::jsi::PropNameID& name) override;
41+
void set(
42+
facebook::jsi::Runtime&,
43+
const facebook::jsi::PropNameID& name,
44+
const facebook::jsi::Value& value) override;
4145
std::vector<facebook::jsi::PropNameID> getPropertyNames(
4246
facebook::jsi::Runtime& rt) override;
4347

react-native-pytorch-core/cxx/test/TensorTests.cpp

Lines changed: 47 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -360,14 +360,14 @@ TEST_F(TorchliveTensorRuntimeTest, TensorIndexing) {
360360
)";
361361
EXPECT_TRUE(eval(tensorAccessWithIndex).getBool());
362362

363-
std::string nestedTensorAcessWithIndex =
363+
std::string nestedTensorAccessWithIndex =
364364
R"(
365365
const tensor = torch.tensor([[128], [255]]);
366-
const tensor1 = tensor[0];
367-
const tensor2 = tensor[1];
368-
tensor1[0].item() == 128 && tensor2[0].item() == 255;
366+
const tensor0 = tensor[0];
367+
const tensor1 = tensor[1];
368+
tensor0[0].item() == 128 && tensor1[0].item() == 255;
369369
)";
370-
EXPECT_TRUE(eval(nestedTensorAcessWithIndex).getBool());
370+
EXPECT_TRUE(eval(nestedTensorAccessWithIndex).getBool());
371371

372372
EXPECT_TRUE(eval("torch.tensor([[128], [255]])['foo']").isUndefined());
373373

@@ -376,6 +376,48 @@ TEST_F(TorchliveTensorRuntimeTest, TensorIndexing) {
376376
EXPECT_TRUE(eval("torch.tensor([[128], [255]])[2]").isUndefined());
377377
}
378378

379+
TEST_F(TorchliveTensorRuntimeTest, TensorIndexingPut) {
380+
std::string tensorPutWithIndex =
381+
R"(
382+
const tensor = torch.zeros([3]);
383+
tensor[0] = torch.tensor([1]);
384+
tensor[1] = torch.tensor([2]);
385+
tensor[2] = torch.tensor([3]);
386+
tensor[0].item() === 1 && tensor[1].item() === 2 && tensor[2].item() === 3;
387+
)";
388+
EXPECT_TRUE(eval(tensorPutWithIndex).getBool());
389+
390+
std::string tensorPutWithIndexAndNumberValue =
391+
R"(
392+
const tensor = torch.zeros([3]);
393+
tensor[0] = 1;
394+
tensor[1] = 2;
395+
tensor[2] = 3;
396+
tensor[0].item() === 1 && tensor[1].item() === 2 && tensor[2].item() === 3;
397+
)";
398+
EXPECT_TRUE(eval(tensorPutWithIndexAndNumberValue).getBool());
399+
400+
std::string nestedTensorPutWithIndex =
401+
R"(
402+
const tensor = torch.tensor([[128], [0]]);
403+
tensor[1] = torch.tensor([[255]]);
404+
const tensor0 = tensor[0];
405+
const tensor1 = tensor[1];
406+
tensor0[0].item() == 128 && tensor1[0].item() == 255;
407+
)";
408+
EXPECT_TRUE(eval(nestedTensorPutWithIndex).getBool());
409+
410+
EXPECT_THROW(
411+
eval("torch.tensor([[128], [255]])['foo'] = 'bar'"),
412+
facebook::jsi::JSError);
413+
414+
EXPECT_THROW(
415+
eval("torch.tensor([[128], [255]])[-1] = 'bar'"), facebook::jsi::JSError);
416+
417+
EXPECT_THROW(
418+
eval("torch.tensor([[128], [255]])[2] = 'bar'"), facebook::jsi::JSError);
419+
}
420+
379421
TEST_F(TorchliveTensorRuntimeTest, TensorDivTest) {
380422
std::string tensorDivWithNumber =
381423
R"(

0 commit comments

Comments
 (0)