Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions tensorflow/lite/core/api/flatbuffer_conversions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1882,9 +1882,7 @@ TfLiteStatus ParseMul(const Operator* op, ErrorReporter* error_reporter,
params->activation =
ConvertActivation(schema_params->fused_activation_function());
} else {
// TODO(b/157480169): We should either return kTfLiteError or fill in some
// reasonable defaults in the params struct. We are not doing so until we
// better understand the ramifications of changing the legacy behavior.
// Default activation is none.
}

*builtin_data = params.release();
Expand Down Expand Up @@ -2430,6 +2428,18 @@ TfLiteStatus ParseStablehloComposite(const Operator* op,
const StableHLOCompositeOptions* schema_params =
op->builtin_options_2_as_StableHLOCompositeOptions();
if (schema_params) {
if (schema_params->name() == nullptr) {
TF_LITE_REPORT_ERROR(
error_reporter,
"'stablehlo.composite' missing required option 'name'.");
return kTfLiteError;
}
if (schema_params->composite_attributes() == nullptr) {
TF_LITE_REPORT_ERROR(error_reporter,
"'stablehlo.composite' missing required option "
"'composite_attributes'.");
return kTfLiteError;
}
params->name = schema_params->name()->c_str();
params->version = schema_params->version();
params->subgraph_index = schema_params->decomposition_subgraph_index();
Expand Down
30 changes: 16 additions & 14 deletions tensorflow/lite/kernels/internal/reference/batch_matmul.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,11 +202,13 @@ inline void BatchMatMul(const RuntimeShape& lhs_shape, const int8_t* lhs_data,
}
}

template <typename T, typename AccumT>
template <typename lhsT, typename AccumT, typename rhsT = lhsT,
typename outputT = lhsT>
inline void BatchMatMul(const FullyConnectedParams& params,
const RuntimeShape& lhs_shape, const T* lhs_data,
const RuntimeShape& rhs_shape, const T* rhs_data,
const RuntimeShape& output_shape, T* output_data) {
const RuntimeShape& lhs_shape, const lhsT* lhs_data,
const RuntimeShape& rhs_shape, const rhsT* rhs_data,
const RuntimeShape& output_shape,
outputT* output_data) {
const RuntimeShape extended_lhs_shape =
RuntimeShape::ExtendedShape(5, lhs_shape);
const RuntimeShape extended_rhs_shape =
Expand Down Expand Up @@ -241,17 +243,17 @@ inline void BatchMatMul(const FullyConnectedParams& params,
TFLITE_DCHECK_LE(output_activation_min, output_activation_max);

for (int b0 = 0; b0 < batch_dim0; ++b0) {
const T* lhs_ptr0 = lhs_data + (b0 * lhs_ext0);
const T* rhs_ptr0 = rhs_data + (b0 * rhs_ext0);
const lhsT* lhs_ptr0 = lhs_data + (b0 * lhs_ext0);
const rhsT* rhs_ptr0 = rhs_data + (b0 * rhs_ext0);
for (int b1 = 0; b1 < batch_dim1; ++b1) {
const T* lhs_ptr1 = lhs_ptr0 + b1 * lhs_ext1;
const T* rhs_ptr1 = rhs_ptr0 + b1 * rhs_ext1;
const lhsT* lhs_ptr1 = lhs_ptr0 + b1 * lhs_ext1;
const rhsT* rhs_ptr1 = rhs_ptr0 + b1 * rhs_ext1;
for (int b2 = 0; b2 < batch_dim2; ++b2) {
const T* lhs_ptr2 = lhs_ptr1 + b2 * lhs_ext2;
const T* rhs_ptr2 = rhs_ptr1 + b2 * rhs_ext2;
T* out_ptr = output_data +
((b0 * batch_dim1 * batch_dim2) + b1 * batch_dim2 + b2) *
lhs_rows * rhs_cols;
const lhsT* lhs_ptr2 = lhs_ptr1 + b2 * lhs_ext2;
const rhsT* rhs_ptr2 = rhs_ptr1 + b2 * rhs_ext2;
outputT* out_ptr = output_data + ((b0 * batch_dim1 * batch_dim2) +
b1 * batch_dim2 + b2) *
lhs_rows * rhs_cols;

for (int j = 0; j < rhs_cols; ++j) {
for (int i = 0; i < lhs_rows; ++i) {
Expand All @@ -267,7 +269,7 @@ inline void BatchMatMul(const FullyConnectedParams& params,
total_scaled = std::max(total_scaled, output_activation_min);
total_scaled = std::min(total_scaled, output_activation_max);
const int idx = lhs_rows * j + i;
out_ptr[idx] = static_cast<T>(total_scaled);
out_ptr[idx] = static_cast<outputT>(total_scaled);
}
}
}
Expand Down
29 changes: 20 additions & 9 deletions tensorflow/lite/kernels/internal/reference/leaky_relu.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,20 +43,31 @@ inline void QuantizeLeakyRelu(const LeakyReluParams& params,
const int flat_size = MatchingFlatSize(input_shape, output_shape);
static const int32_t quantized_min = std::numeric_limits<T>::min();
static const int32_t quantized_max = std::numeric_limits<T>::max();

// Extract the sign and create a safely positive multiplier outside the loop.
// This supports negative alpha values (matching float execution behavior)
// while preventing assertion failures, as MultiplyByQuantizedMultiplier
// strictly requires a non-negative multiplier.
const bool is_alpha_negative = params.output_multiplier_alpha < 0;
const int32_t safe_alpha_multiplier = is_alpha_negative
? -params.output_multiplier_alpha
: params.output_multiplier_alpha;

for (int i = 0; i < flat_size; ++i) {
const int32_t input_value = input_data[i] - params.input_offset;
int32_t unclamped_output;

int32_t unclamped_output = params.output_offset;
if (input_value >= 0) {
unclamped_output = params.output_offset +
MultiplyByQuantizedMultiplier(
input_value, params.output_multiplier_identity,
params.output_shift_identity);
unclamped_output += MultiplyByQuantizedMultiplier(
input_value, params.output_multiplier_identity,
params.output_shift_identity);
} else {
unclamped_output = params.output_offset +
MultiplyByQuantizedMultiplier(
input_value, params.output_multiplier_alpha,
params.output_shift_alpha);
int32_t scaled_alpha_value = MultiplyByQuantizedMultiplier(
input_value, safe_alpha_multiplier, params.output_shift_alpha);
unclamped_output +=
is_alpha_negative ? -scaled_alpha_value : scaled_alpha_value;
}

const T clamped_output =
std::min(quantized_max, std::max(quantized_min, unclamped_output));
output_data[i] = static_cast<T>(clamped_output);
Expand Down
59 changes: 29 additions & 30 deletions tensorflow/lite/kernels/kernel_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -528,51 +528,50 @@ TfLiteStatus CalculateShapeForBroadcast(TfLiteContext* context,

// The size of a string is not constant, so 0 is returned for string types.
int TfLiteTypeGetSize(TfLiteType type) {
  // Derive the byte size from the bit width. Sub-byte types (e.g. int2,
  // int4) have no whole-byte size, so 0 is reported for them, just as it is
  // for string and other unsized types (whose bit width is 0).
  const int size_bits = TfLiteTypeGetSizeBits(type);
  return (size_bits % 8 == 0) ? size_bits / 8 : 0;
}

int TfLiteTypeGetSizeBits(TfLiteType type) {
switch (type) {
case kTfLiteInt2:
return 2;
case kTfLiteInt4:
case kTfLiteUInt4:
return 4;
case kTfLiteUInt8:
static_assert(sizeof(uint8_t) == 1, "");
return 1;
case kTfLiteInt8:
static_assert(sizeof(int8_t) == 1, "");
return 1;
case kTfLiteBool:
return sizeof(bool);
return 8;
case kTfLiteUInt16:
static_assert(sizeof(uint16_t) == 2, "");
return 2;
case kTfLiteInt16:
static_assert(sizeof(int16_t) == 2, "");
return 2;
case kTfLiteFloat16:
static_assert(sizeof(int16_t) == 2, "");
return 2;
case kTfLiteBFloat16:
return 16;
case kTfLiteFloat32:
static_assert(sizeof(float) == 4, "");
return 4;
case kTfLiteInt32:
static_assert(sizeof(int32_t) == 4, "");
return 4;
case kTfLiteUInt32:
static_assert(sizeof(uint32_t) == 4, "");
return 4;
return 32;
case kTfLiteInt64:
static_assert(sizeof(int64_t) == 8, "");
return 8;
case kTfLiteUInt64:
static_assert(sizeof(uint64_t) == 8, "");
return 8;
case kTfLiteFloat64:
static_assert(sizeof(double) == 8, "");
return 8;
case kTfLiteComplex64:
static_assert(sizeof(std::complex<float>) == 8, "");
return 8;
return 64;
case kTfLiteComplex128:
static_assert(sizeof(std::complex<double>) == 16, "");
return 16;
default:
return 0;
return 128;
case kTfLiteBool:
return sizeof(bool) * 8;
case kTfLiteString:
case kTfLiteNoType:
case kTfLiteResource:
case kTfLiteVariant:
break;
}
return 0;
}

bool IsMobilePlatform() {
Expand Down
3 changes: 3 additions & 0 deletions tensorflow/lite/kernels/kernel_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,9 @@ TfLiteStatus CalculateShapeForBroadcast(TfLiteContext* context,
// Returns the size of the given type in bytes. Returns 0 in case of string.
int TfLiteTypeGetSize(TfLiteType type);

// Returns the size of the given type in bits. Returns 0 in case of string.
int TfLiteTypeGetSizeBits(TfLiteType type);

// Whether the current platform is mobile (Android or iOS).
bool IsMobilePlatform();

Expand Down
Loading