diff --git a/tensorflow/lite/micro/memory_helpers.cc b/tensorflow/lite/micro/memory_helpers.cc index d78e34d4d96..0cf6666c785 100644 --- a/tensorflow/lite/micro/memory_helpers.cc +++ b/tensorflow/lite/micro/memory_helpers.cc @@ -106,12 +106,20 @@ TfLiteStatus TfLiteTypeSizeOf(TfLiteType type, size_t* size) { TfLiteStatus BytesRequiredForTensor(const tflite::Tensor& flatbuffer_tensor, size_t* bytes, size_t* type_size) { - int element_count = 1; + size_t element_count = 1; // If flatbuffer_tensor.shape == nullptr, then flatbuffer_tensor is a scalar // so has 1 element. if (flatbuffer_tensor.shape() != nullptr) { for (size_t n = 0; n < flatbuffer_tensor.shape()->size(); ++n) { - element_count *= flatbuffer_tensor.shape()->Get(n); + int dim = flatbuffer_tensor.shape()->Get(n); + if (dim <= 0) { + return kTfLiteError; + } + size_t prev = element_count; + element_count *= static_cast(dim); + if (element_count / static_cast(dim) != prev) { + return kTfLiteError; + } } } @@ -119,7 +127,11 @@ TfLiteStatus BytesRequiredForTensor(const tflite::Tensor& flatbuffer_tensor, TF_LITE_ENSURE_STATUS( ConvertTensorType(flatbuffer_tensor.type(), &tf_lite_type)); TF_LITE_ENSURE_STATUS(TfLiteTypeSizeOf(tf_lite_type, type_size)); + size_t prev = element_count; *bytes = element_count * (*type_size); + if (*type_size != 0 && *bytes / (*type_size) != prev) { + return kTfLiteError; + } return kTfLiteOk; } @@ -127,16 +139,28 @@ TfLiteStatus TfLiteEvalTensorByteLength(const TfLiteEvalTensor* eval_tensor, size_t* out_bytes) { TFLITE_DCHECK(out_bytes != nullptr); - int element_count = 1; + size_t element_count = 1; // If eval_tensor->dims == nullptr, then tensor is a scalar so has 1 element. if (eval_tensor->dims != nullptr) { for (int n = 0; n < eval_tensor->dims->size; ++n) { - element_count *= eval_tensor->dims->data[n]; + int dim = eval_tensor->dims->data[n]; + if (dim <= 0) { + return kTfLiteError; + } + size_t prev = element_count; + element_count *= static_cast(dim); + if (element_count / static_cast(dim) != prev) { + return kTfLiteError; + } } } size_t type_size; TF_LITE_ENSURE_STATUS(TfLiteTypeSizeOf(eval_tensor->type, &type_size)); + size_t prev = element_count; *out_bytes = element_count * type_size; + if (type_size != 0 && *out_bytes / type_size != prev) { + return kTfLiteError; + } return kTfLiteOk; }