Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 28 additions & 4 deletions tensorflow/lite/micro/memory_helpers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,37 +106,61 @@ TfLiteStatus TfLiteTypeSizeOf(TfLiteType type, size_t* size) {

TfLiteStatus BytesRequiredForTensor(const tflite::Tensor& flatbuffer_tensor,
size_t* bytes, size_t* type_size) {
int element_count = 1;
size_t element_count = 1;
// If flatbuffer_tensor.shape == nullptr, then flatbuffer_tensor is a scalar
// so has 1 element.
if (flatbuffer_tensor.shape() != nullptr) {
for (size_t n = 0; n < flatbuffer_tensor.shape()->size(); ++n) {
element_count *= flatbuffer_tensor.shape()->Get(n);
int dim = flatbuffer_tensor.shape()->Get(n);
if (dim <= 0) {
return kTfLiteError;
}
size_t prev = element_count;
element_count *= static_cast<size_t>(dim);
if (element_count / static_cast<size_t>(dim) != prev) {
return kTfLiteError;
}
}
}

TfLiteType tf_lite_type;
TF_LITE_ENSURE_STATUS(
ConvertTensorType(flatbuffer_tensor.type(), &tf_lite_type));
TF_LITE_ENSURE_STATUS(TfLiteTypeSizeOf(tf_lite_type, type_size));
size_t prev = element_count;
*bytes = element_count * (*type_size);
if (*type_size != 0 && *bytes / (*type_size) != prev) {
return kTfLiteError;
}
return kTfLiteOk;
}

TfLiteStatus TfLiteEvalTensorByteLength(const TfLiteEvalTensor* eval_tensor,
size_t* out_bytes) {
TFLITE_DCHECK(out_bytes != nullptr);

int element_count = 1;
size_t element_count = 1;
// If eval_tensor->dims == nullptr, then tensor is a scalar so has 1 element.
if (eval_tensor->dims != nullptr) {
for (int n = 0; n < eval_tensor->dims->size; ++n) {
element_count *= eval_tensor->dims->data[n];
int dim = eval_tensor->dims->data[n];
if (dim <= 0) {
return kTfLiteError;
}
size_t prev = element_count;
element_count *= static_cast<size_t>(dim);
if (element_count / static_cast<size_t>(dim) != prev) {
return kTfLiteError;
}
}
}
size_t type_size;
TF_LITE_ENSURE_STATUS(TfLiteTypeSizeOf(eval_tensor->type, &type_size));
size_t prev = element_count;
*out_bytes = element_count * type_size;
if (type_size != 0 && *out_bytes / type_size != prev) {
return kTfLiteError;
}
return kTfLiteOk;
}

Expand Down
Loading