From ffb1cdcd693673bdde8d6d153a484d61cd8bc14d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Roman=20R=C3=A4dle?= Date: Tue, 30 Aug 2022 13:40:07 -0700 Subject: [PATCH] Change isNumber check for tensor indexing API (#117) Summary: Pull Request resolved: https://github.com/facebookresearch/playtorch/pull/117 The tensor indexing API as implemented before was relying on exceptions when a prop name was not a number. This fails with RN 0.69.x on iOS with the following error: ``` Exception in HostObject::get for prop 'reshape': stoi: no conversion ``` This change implements an `isNumber` function using the `::isdigit` function Differential Revision: D39141720 fbshipit-source-id: f5e930acfd39285a91fea362116f5297ab72411a --- .../src/torchlive/torch/TensorHostObject.cpp | 41 +++++++++++-------- 1 file changed, 25 insertions(+), 16 deletions(-) diff --git a/react-native-pytorch-core/cxx/src/torchlive/torch/TensorHostObject.cpp b/react-native-pytorch-core/cxx/src/torchlive/torch/TensorHostObject.cpp index c4a288fc0..3497a3c52 100644 --- a/react-native-pytorch-core/cxx/src/torchlive/torch/TensorHostObject.cpp +++ b/react-native-pytorch-core/cxx/src/torchlive/torch/TensorHostObject.cpp @@ -38,6 +38,17 @@ using namespace facebook; namespace { +/** + * Check if the input string is zero or a positive integer. + * + * @param s String to check for zero or positive integer. + * @return `true` if input string is zero or a positive integer, or `false` + * otherwise. + */ +bool isZeroOrPositiveInteger(const std::string& s) { + return !s.empty() && std::all_of(s.begin(), s.end(), ::isdigit); +} + jsi::Value absImpl( jsi::Runtime& runtime, const jsi::Value& thisValue, @@ -702,22 +713,20 @@ jsi::Value TensorHostObject::get( return jsi::Value(runtime, toString_); } - int idx = -1; - try { - idx = std::stoi(name.c_str()); - } catch (const std::exception& e) { - // Cannot parse name value to int. This can happen when the name in bracket - // or dot notion is not an int (e.g., tensor['foo']). - // Let's ignore this exception here since this function will return - // undefined if it reaches the function end. - } - // Check if index is within bounds of dimension 0 - if (idx >= 0 && idx < this->tensor.size(0)) { - auto outputTensor = this->tensor.index({idx}); - auto tensorHostObject = - std::make_shared( - runtime, std::move(outputTensor)); - return jsi::Object::createFromHostObject(runtime, tensorHostObject); + // Check if prop name is zero or a positive integer, and if so it will access + // the tensor via the tensor indexing API: + // + // https://pytorch.org/cppdocs/notes/tensor_indexing.html + if (isZeroOrPositiveInteger(name)) { + int idx = std::stoi(name.c_str()); + // Check if index is within bounds of dimension 0 + if (idx >= 0 && idx < this->tensor.size(0)) { + auto outputTensor = this->tensor.index({idx}); + auto tensorHostObject = + std::make_shared( + runtime, std::move(outputTensor)); + return jsi::Object::createFromHostObject(runtime, tensorHostObject); + } } return BaseHostObject::get(runtime, propNameId);