diff --git a/runtime/core/evalue.cpp b/runtime/core/evalue.cpp index 121a9a29fa2..6fd118dadd0 100644 --- a/runtime/core/evalue.cpp +++ b/runtime/core/evalue.cpp @@ -10,6 +10,10 @@ namespace executorch { namespace runtime { + +// Specialize for list of optional tensors, as nullptr is a valid std::nullopt. +// For non-optional types, nullptr is invalid. + template <> executorch::aten::ArrayRef> BoxedEvalueList>::get() const { @@ -27,5 +31,26 @@ BoxedEvalueList>::get() const { return executorch::aten::ArrayRef>{ unwrapped_vals_, wrapped_vals_.size()}; } + +template <> +Result>> +BoxedEvalueList>::tryGet() const { + for (typename executorch::aten::ArrayRef< + std::optional>::size_type i = 0; + i < wrapped_vals_.size(); + i++) { + if (wrapped_vals_[i] == nullptr) { + unwrapped_vals_[i] = std::nullopt; + continue; + } + auto r = wrapped_vals_[i]->tryToOptional(); + if (!r.ok()) { + return r.error(); + } + unwrapped_vals_[i] = std::move(r.get()); + } + return executorch::aten::ArrayRef>{ + unwrapped_vals_, wrapped_vals_.size()}; +} } // namespace runtime } // namespace executorch diff --git a/runtime/core/evalue.h b/runtime/core/evalue.h index 8d75b1ace97..eed52bb74f7 100644 --- a/runtime/core/evalue.h +++ b/runtime/core/evalue.h @@ -8,6 +8,7 @@ #pragma once #include +#include #include #include @@ -71,6 +72,16 @@ class BoxedEvalueList { */ executorch::aten::ArrayRef get() const; + /** + * Result-returning counterpart of get(). Validates each wrapped EValue's + * tag before materializing; returns Error::InvalidType if any element's + * tag does not match T and Error::InvalidState if any element pointer is + * null. Use this when materializing lists from untrusted .pte data so that + * a malformed program cannot force a process abort inside to() / + * ET_CHECK. + */ + Result> tryGet() const; + /** * Destroys the unwrapped elements without re-dereferencing wrapped_vals_. * This is safe to call during EValue destruction because it does not @@ -107,6 +118,10 @@ template <> executorch::aten::ArrayRef> BoxedEvalueList>::get() const; +template <> +Result>> +BoxedEvalueList>::tryGet() const; + // Aggregate typing system similar to IValue only slimmed down with less // functionality, no dependencies on atomic, and fewer supported types to better // suit embedded systems (ie no intrusive ptr) @@ -193,6 +208,13 @@ struct EValue { return payload.copyable_union.as_int; } + Result tryToInt() const { + if (!isInt()) { + return Error::InvalidType; + } + return payload.copyable_union.as_int; + } + /****** Double Type ******/ /*implicit*/ EValue(double d) : tag(Tag::Double) { payload.copyable_union.as_double = d; @@ -207,6 +229,13 @@ struct EValue { return payload.copyable_union.as_double; } + Result tryToDouble() const { + if (!isDouble()) { + return Error::InvalidType; + } + return payload.copyable_union.as_double; + } + /****** Bool Type ******/ /*implicit*/ EValue(bool b) : tag(Tag::Bool) { payload.copyable_union.as_bool = b; @@ -221,6 +250,13 @@ struct EValue { return payload.copyable_union.as_bool; } + Result tryToBool() const { + if (!isBool()) { + return Error::InvalidType; + } + return payload.copyable_union.as_bool; + } + /****** Scalar Type ******/ /// Construct an EValue using the implicit value of a Scalar. /*implicit*/ EValue(executorch::aten::Scalar s) { @@ -256,6 +292,19 @@ struct EValue { } } + Result tryToScalar() const { + if (isDouble()) { + return executorch::aten::Scalar(payload.copyable_union.as_double); + } + if (isInt()) { + return executorch::aten::Scalar(payload.copyable_union.as_int); + } + if (isBool()) { + return executorch::aten::Scalar(payload.copyable_union.as_bool); + } + return Error::InvalidType; + } + /****** Tensor Type ******/ /*implicit*/ EValue(executorch::aten::Tensor t) : tag(Tag::Tensor) { // When built in aten mode, at::Tensor has a non trivial constructor @@ -305,6 +354,16 @@ struct EValue { return payload.as_tensor; } + // Returns a copy of the Tensor handle (one intrusive_ptr refcount bump in + // ATen mode; free in lean mode). Unlike toTensor()'s const& / & overloads, + // tryToTensor() cannot return a reference — Result wraps by value. + Result tryToTensor() const { + if (!isTensor()) { + return Error::InvalidType; + } + return payload.as_tensor; + } + /****** String Type ******/ /*implicit*/ EValue(executorch::aten::ArrayRef* s) : tag(Tag::String) { ET_CHECK_MSG(s != nullptr, "ArrayRef pointer cannot be null"); @@ -325,6 +384,18 @@ struct EValue { payload.copyable_union.as_string_ptr->size()); } + Result tryToString() const { + if (!isString()) { + return Error::InvalidType; + } + if (payload.copyable_union.as_string_ptr == nullptr) { + return Error::InvalidState; + } + return std::string_view( + payload.copyable_union.as_string_ptr->data(), + payload.copyable_union.as_string_ptr->size()); + } + /****** Int List Type ******/ /*implicit*/ EValue(BoxedEvalueList* i) : tag(Tag::ListInt) { ET_CHECK_MSG( @@ -344,6 +415,16 @@ struct EValue { return (payload.copyable_union.as_int_list_ptr)->get(); } + Result> tryToIntList() const { + if (!isIntList()) { + return Error::InvalidType; + } + if (payload.copyable_union.as_int_list_ptr == nullptr) { + return Error::InvalidState; + } + return (payload.copyable_union.as_int_list_ptr)->tryGet(); + } + /****** Bool List Type ******/ /*implicit*/ EValue(executorch::aten::ArrayRef* b) : tag(Tag::ListBool) { @@ -363,6 +444,16 @@ struct EValue { return *(payload.copyable_union.as_bool_list_ptr); } + Result> tryToBoolList() const { + if (!isBoolList()) { + return Error::InvalidType; + } + if (payload.copyable_union.as_bool_list_ptr == nullptr) { + return Error::InvalidState; + } + return *(payload.copyable_union.as_bool_list_ptr); + } + /****** Double List Type ******/ /*implicit*/ EValue(executorch::aten::ArrayRef* d) : tag(Tag::ListDouble) { @@ -382,6 +473,16 @@ struct EValue { return *(payload.copyable_union.as_double_list_ptr); } + Result> tryToDoubleList() const { + if (!isDoubleList()) { + return Error::InvalidType; + } + if (payload.copyable_union.as_double_list_ptr == nullptr) { + return Error::InvalidState; + } + return *(payload.copyable_union.as_double_list_ptr); + } + /****** Tensor List Type ******/ /*implicit*/ EValue(BoxedEvalueList* t) : tag(Tag::ListTensor) { @@ -402,6 +503,17 @@ struct EValue { return payload.copyable_union.as_tensor_list_ptr->get(); } + Result> tryToTensorList() + const { + if (!isTensorList()) { + return Error::InvalidType; + } + if (payload.copyable_union.as_tensor_list_ptr == nullptr) { + return Error::InvalidState; + } + return payload.copyable_union.as_tensor_list_ptr->tryGet(); + } + /****** List Optional Tensor Type ******/ /*implicit*/ EValue( BoxedEvalueList>* t) @@ -426,6 +538,17 @@ struct EValue { return payload.copyable_union.as_list_optional_tensor_ptr->get(); } + Result>> + tryToListOptionalTensor() const { + if (!isListOptionalTensor()) { + return Error::InvalidType; + } + if (payload.copyable_union.as_list_optional_tensor_ptr == nullptr) { + return Error::InvalidState; + } + return payload.copyable_union.as_list_optional_tensor_ptr->tryGet(); + } + /****** ScalarType Type ******/ executorch::aten::ScalarType toScalarType() const { ET_CHECK_MSG(isInt(), "EValue is not a ScalarType."); @@ -433,6 +556,14 @@ struct EValue { payload.copyable_union.as_int); } + Result tryToScalarType() const { + if (!isInt()) { + return Error::InvalidType; + } + return static_cast( + payload.copyable_union.as_int); + } + /****** MemoryFormat Type ******/ executorch::aten::MemoryFormat toMemoryFormat() const { ET_CHECK_MSG(isInt(), "EValue is not a MemoryFormat."); @@ -440,12 +571,27 @@ struct EValue { payload.copyable_union.as_int); } + Result tryToMemoryFormat() const { + if (!isInt()) { + return Error::InvalidType; + } + return static_cast( + payload.copyable_union.as_int); + } + /****** Layout Type ******/ executorch::aten::Layout toLayout() const { ET_CHECK_MSG(isInt(), "EValue is not a Layout."); return static_cast(payload.copyable_union.as_int); } + Result tryToLayout() const { + if (!isInt()) { + return Error::InvalidType; + } + return static_cast(payload.copyable_union.as_int); + } + /****** Device Type ******/ executorch::aten::Device toDevice() const { ET_CHECK_MSG(isInt(), "EValue is not a Device."); @@ -455,6 +601,16 @@ struct EValue { -1); } + Result tryToDevice() const { + if (!isInt()) { + return Error::InvalidType; + } + return executorch::aten::Device( + static_cast( + payload.copyable_union.as_int), + -1); + } + template T to() &&; template @@ -462,6 +618,15 @@ struct EValue { template typename internal::evalue_to_ref_overload_return::type to() &; + /** + * Result-returning equivalent of `to()`. Tag mismatch returns + * `Error::InvalidType`; a null list/string payload returns + * `Error::InvalidState`. Specializations are defined below via + * `EVALUE_DEFINE_TRY_TO`. + */ + template + Result tryTo() const; + /** * Converts the EValue to an optional object that can represent both T and * an uninitialized state. @@ -474,6 +639,23 @@ struct EValue { return this->to(); } + /** + * Result-returning equivalent of `toOptional()`. None maps to an empty + * optional; any other tag that doesn't match T propagates `tryTo()`'s + * error (`Error::InvalidType`). + */ + template + inline Result> tryToOptional() const { + if (this->isNone()) { + return std::optional(std::nullopt); + } + auto r = this->tryTo(); + if (!r.ok()) { + return r.error(); + } + return std::optional(std::move(r.get())); + } + private: // Pre cond: the payload value has had its destructor called void clearToNone() noexcept { @@ -591,6 +773,59 @@ EVALUE_DEFINE_TO( toListOptionalTensor) #undef EVALUE_DEFINE_TO +#define EVALUE_DEFINE_TRY_TO(T, method_name) \ + template <> \ + inline Result EValue::tryTo() const { \ + return this->method_name(); \ + } + +EVALUE_DEFINE_TRY_TO(executorch::aten::Scalar, tryToScalar) +EVALUE_DEFINE_TRY_TO(int64_t, tryToInt) +EVALUE_DEFINE_TRY_TO(bool, tryToBool) +EVALUE_DEFINE_TRY_TO(double, tryToDouble) +EVALUE_DEFINE_TRY_TO(std::string_view, tryToString) +EVALUE_DEFINE_TRY_TO(executorch::aten::ScalarType, tryToScalarType) +EVALUE_DEFINE_TRY_TO(executorch::aten::MemoryFormat, tryToMemoryFormat) +EVALUE_DEFINE_TRY_TO(executorch::aten::Layout, tryToLayout) +EVALUE_DEFINE_TRY_TO(executorch::aten::Device, tryToDevice) +// Tensor and Optional Tensor +EVALUE_DEFINE_TRY_TO(executorch::aten::Tensor, tryToTensor) +EVALUE_DEFINE_TRY_TO( + std::optional, + tryToOptional) + +// IntList and Optional IntList +EVALUE_DEFINE_TRY_TO(executorch::aten::ArrayRef, tryToIntList) +EVALUE_DEFINE_TRY_TO( + std::optional>, + tryToOptional>) + +// DoubleList and Optional DoubleList +EVALUE_DEFINE_TRY_TO(executorch::aten::ArrayRef, tryToDoubleList) +EVALUE_DEFINE_TRY_TO( + std::optional>, + tryToOptional>) + +// BoolList and Optional BoolList +EVALUE_DEFINE_TRY_TO(executorch::aten::ArrayRef, tryToBoolList) +EVALUE_DEFINE_TRY_TO( + std::optional>, + tryToOptional>) + +// TensorList and Optional TensorList +EVALUE_DEFINE_TRY_TO( + executorch::aten::ArrayRef, + tryToTensorList) +EVALUE_DEFINE_TRY_TO( + std::optional>, + tryToOptional>) + +// List of Optional Tensor +EVALUE_DEFINE_TRY_TO( + executorch::aten::ArrayRef>, + tryToListOptionalTensor) +#undef EVALUE_DEFINE_TRY_TO + template executorch::aten::ArrayRef BoxedEvalueList::get() const { for (typename executorch::aten::ArrayRef::size_type i = 0; @@ -602,6 +837,23 @@ executorch::aten::ArrayRef BoxedEvalueList::get() const { return executorch::aten::ArrayRef{unwrapped_vals_, wrapped_vals_.size()}; } +template +Result> BoxedEvalueList::tryGet() const { + for (typename executorch::aten::ArrayRef::size_type i = 0; + i < wrapped_vals_.size(); + i++) { + if (wrapped_vals_[i] == nullptr) { + return Error::InvalidState; + } + auto r = wrapped_vals_[i]->template tryTo(); + if (!r.ok()) { + return r.error(); + } + unwrapped_vals_[i] = std::move(r.get()); + } + return executorch::aten::ArrayRef{unwrapped_vals_, wrapped_vals_.size()}; +} + } // namespace runtime } // namespace executorch diff --git a/runtime/core/test/evalue_test.cpp b/runtime/core/test/evalue_test.cpp index edf6a1b12c1..1b0b86c1392 100644 --- a/runtime/core/test/evalue_test.cpp +++ b/runtime/core/test/evalue_test.cpp @@ -16,8 +16,12 @@ using namespace ::testing; +using executorch::aten::DeviceType; +using executorch::aten::Layout; +using executorch::aten::MemoryFormat; using executorch::aten::ScalarType; using executorch::runtime::BoxedEvalueList; +using executorch::runtime::Error; using executorch::runtime::EValue; using executorch::runtime::Tag; using executorch::runtime::testing::TensorFactory; @@ -214,6 +218,56 @@ TEST_F(EValueTest, BoxedEvalueList) { EXPECT_EQ(unwrapped[2], 3); } +TEST_F(EValueTest, BoxedEvalueListTryGetSuccess) { + EValue values[3] = { + EValue((int64_t)1), EValue((int64_t)2), EValue((int64_t)3)}; + EValue* values_p[3] = {&values[0], &values[1], &values[2]}; + int64_t storage[3] = {0, 0, 0}; + BoxedEvalueList x{values_p, storage, 3}; + auto result = x.tryGet(); + EXPECT_TRUE(result.ok()); + EXPECT_EQ(result->size(), 3); + EXPECT_EQ((*result)[0], 1); + EXPECT_EQ((*result)[1], 2); + EXPECT_EQ((*result)[2], 3); +} + +TEST_F(EValueTest, BoxedEvalueListTryGetWrongElementTag) { + // Second element is a Double, not an Int; tryGet should reject it rather + // than abort inside to(). + EValue values[3] = {EValue((int64_t)1), EValue(3.14), EValue((int64_t)3)}; + EValue* values_p[3] = {&values[0], &values[1], &values[2]}; + int64_t storage[3] = {0, 0, 0}; + BoxedEvalueList x{values_p, storage, 3}; + auto result = x.tryGet(); + EXPECT_EQ(result.error(), executorch::runtime::Error::InvalidType); +} + +TEST_F(EValueTest, BoxedEvalueListTryGetNullElement) { + // A null value is a malformed program for non-optional lists. + EValue a((int64_t)1); + EValue c((int64_t)3); + EValue* values_p[3] = {&a, nullptr, &c}; + int64_t storage[3] = {0, 0, 0}; + BoxedEvalueList x{values_p, storage, 3}; + auto result = x.tryGet(); + EXPECT_EQ(result.error(), executorch::runtime::Error::InvalidState); +} + +TEST_F(EValueTest, BoxedEvalueListTryGetOptionalTensorNullIsNone) { + // For optional, null value is valid. + EValue a; + EValue* values_p[2] = {&a, nullptr}; + std::optional storage[2]; + BoxedEvalueList> x{ + values_p, storage, 2}; + auto result = x.tryGet(); + EXPECT_TRUE(result.ok()); + EXPECT_EQ(result->size(), 2); + EXPECT_FALSE((*result)[0].has_value()); + EXPECT_FALSE((*result)[1].has_value()); +} + TEST_F(EValueTest, toOptionalTensorList) { // create list, empty evalue ctor gets tag::None EValue values[2] = {EValue(), EValue()}; @@ -417,3 +471,116 @@ TEST_F(EValueTest, toListOptionalTensorNullPointerCheck) { EXPECT_TRUE(e.isListOptionalTensor()); ET_EXPECT_DEATH({ e.toListOptionalTensor(); }, "pointer is null"); } + +// Per-type tryTo* coverage. +// For each type: +// - success and failure for named method tryTo[Int/Double/Bool/Tensor/..] +// - success and failure for templated tryTo() specialization + +TEST_F(EValueTest, TryToInt) { + EValue e_int(static_cast(42)); + EValue e_mismatch(3.14); + EXPECT_EQ(e_int.tryToInt().get(), 42); + EXPECT_EQ(e_mismatch.tryToInt().error(), Error::InvalidType); + EXPECT_EQ(e_int.tryTo().get(), 42); + EXPECT_EQ(e_mismatch.tryTo().error(), Error::InvalidType); +} + +TEST_F(EValueTest, TryToDouble) { + EValue e_double(3.14); + EValue e_mismatch(static_cast(42)); + EXPECT_DOUBLE_EQ(e_double.tryToDouble().get(), 3.14); + EXPECT_EQ(e_mismatch.tryToDouble().error(), Error::InvalidType); + EXPECT_DOUBLE_EQ(e_double.tryTo().get(), 3.14); + EXPECT_EQ(e_mismatch.tryTo().error(), Error::InvalidType); +} + +TEST_F(EValueTest, TryToBool) { + EValue e_bool(true); + EValue e_mismatch(static_cast(42)); + EXPECT_EQ(e_bool.tryToBool().get(), true); + EXPECT_EQ(e_mismatch.tryToBool().error(), Error::InvalidType); + EXPECT_EQ(e_bool.tryTo().get(), true); + EXPECT_EQ(e_mismatch.tryTo().error(), Error::InvalidType); +} + +TEST_F(EValueTest, TryToTensor) { + TensorFactory tf; + EValue e_tensor(tf.ones({3, 2})); + EValue e_mismatch(static_cast(42)); + EXPECT_EQ(e_tensor.tryToTensor()->numel(), 6); + EXPECT_EQ(e_mismatch.tryToTensor().error(), Error::InvalidType); + EXPECT_EQ(e_tensor.tryTo()->numel(), 6); + EXPECT_EQ( + e_mismatch.tryTo().error(), Error::InvalidType); +} + +TEST_F(EValueTest, TryToOptionalTensor) { + TensorFactory tf; + EValue e_tensor(tf.ones({3, 2})); + EValue e_none; + EValue e_mismatch(static_cast(42)); + // Named tryToOptional: value, None, mismatch. + auto r_val = e_tensor.tryToOptional(); + EXPECT_TRUE(r_val->has_value()); + EXPECT_EQ(r_val->value().numel(), 6); + EXPECT_FALSE(e_none.tryToOptional()->has_value()); + EXPECT_EQ( + e_mismatch.tryToOptional().error(), + Error::InvalidType); + // Templated tryTo>: None path. + EXPECT_FALSE( + e_none.tryTo>()->has_value()); +} + +TEST_F(EValueTest, TryToScalar) { + EValue e_int(static_cast(7)); + EValue e_double(2.5); + EValue e_bool(true); + EValue e_none; + EXPECT_EQ(e_int.tryToScalar()->to(), 7); + EXPECT_DOUBLE_EQ(e_double.tryToScalar()->to(), 2.5); + EXPECT_EQ(e_bool.tryToScalar()->to(), true); + // None is neither Int/Double/Bool. + EXPECT_EQ(e_none.tryToScalar().error(), Error::InvalidType); +} + +TEST_F(EValueTest, TryToScalarType) { + EValue e(static_cast(ScalarType::Float)); + EValue e_mismatch(3.14); + EXPECT_EQ(e.tryToScalarType().get(), ScalarType::Float); + EXPECT_EQ(e_mismatch.tryToScalarType().error(), Error::InvalidType); + EXPECT_EQ(e.tryTo().get(), ScalarType::Float); + EXPECT_EQ(e_mismatch.tryTo().error(), Error::InvalidType); +} + +TEST_F(EValueTest, TryToMemoryFormat) { + EValue e(static_cast(MemoryFormat::Contiguous)); + EValue e_mismatch(3.14); + EXPECT_EQ(e.tryToMemoryFormat().get(), MemoryFormat::Contiguous); + EXPECT_EQ(e_mismatch.tryToMemoryFormat().error(), Error::InvalidType); +} + +TEST_F(EValueTest, TryToLayout) { + EValue e(static_cast(Layout::Strided)); + EValue e_mismatch(3.14); + EXPECT_EQ(e.tryToLayout().get(), Layout::Strided); + EXPECT_EQ(e_mismatch.tryToLayout().error(), Error::InvalidType); +} + +TEST_F(EValueTest, TryToDevice) { + EValue e(static_cast(DeviceType::CPU)); + EValue e_mismatch(3.14); + EXPECT_EQ(e.tryToDevice().get().type(), DeviceType::CPU); + EXPECT_EQ(e_mismatch.tryToDevice().error(), Error::InvalidType); +} + +TEST_F(EValueTest, TryToTensorList) { + EValue e(static_cast(42)); + EXPECT_EQ(e.tryToTensorList().error(), Error::InvalidType); +} + +TEST_F(EValueTest, TryToListOptionalTensor) { + EValue e(static_cast(42)); + EXPECT_EQ(e.tryToListOptionalTensor().error(), Error::InvalidType); +}