// Unpacks densely packed 2-bit or 4-bit integers from `src_buffer` into one
// int8_t per value in `dst_buffer`.
//
// Values are stored least-significant field first: element 0 occupies the
// lowest `bit_width` bits of `src_buffer[0]`, element 1 the next `bit_width`
// bits, and so on. A trailing partially-filled byte is read only when
// `num_elements` is not a multiple of the number of values per byte.
//
// Parameters:
//   src_buffer      - packed input; ceil(num_elements * bit_width / 8) bytes
//                     are read.
//   num_elements    - number of logical values (not bytes) to unpack.
//   bit_width       - 2 or 4. Any other value trips an assert in debug builds
//                     and leaves `dst_buffer` untouched otherwise.
//   dst_buffer      - output; `num_elements` bytes are written.
//   unpack_unsigned - when false (default) each field is sign extended
//                     (4-bit range [-8, 7], 2-bit range [-2, 1]); when true
//                     each field is zero extended (ranges [0, 15] / [0, 3]).
//
// For 4-bit unpacking: e.g., `src_buffer = {0xF8, 0x07}` (num_elements = 3)
// yields `dst_buffer = {-8, -1, 7}` signed, or `{8, 15, 7}` unsigned.
// For 2-bit unpacking: e.g., `src_buffer = {0x12}` (num_elements = 4)
// yields `dst_buffer = {-2, 0, 1, 0}` signed, or `{2, 0, 1, 0}` unsigned.
void UnpackPackedIntToInt8(const int8_t* src_buffer, int num_elements,
                           int bit_width, int8_t* dst_buffer,
                           bool unpack_unsigned = false);

namespace {

// Returns the `bit_width`-bit field number `index` (0 = least significant) of
// `packed`, zero extended when `unpack_unsigned` is true and sign extended
// otherwise.
inline int8_t ExtractPackedField(uint8_t packed, int index, int bit_width,
                                 bool unpack_unsigned) {
  const int mask = (1 << bit_width) - 1;
  // `packed` is unsigned, so the promoted shift operand is never negative.
  const int field = (packed >> (index * bit_width)) & mask;
  if (unpack_unsigned) {
    return static_cast<int8_t>(field);
  }
  // Sign extend without shifting negative values: flipping the field's top
  // bit and subtracting its weight maps [0, 2^w) onto [-2^(w-1), 2^(w-1)).
  const int sign_bit = 1 << (bit_width - 1);
  return static_cast<int8_t>((field ^ sign_bit) - sign_bit);
}

}  // namespace

void UnpackPackedIntToInt8(const int8_t* src_buffer, int num_elements,
                           int bit_width, int8_t* dst_buffer,
                           bool unpack_unsigned) {
  assert(bit_width == 2 || bit_width == 4);
  if (bit_width != 2 && bit_width != 4) {
    // Unsupported widths write nothing when asserts are compiled out.
    return;
  }

  // num_elements counts logical values regardless of packing. For example,
  // 3 int4 values occupy 12 bits, padded to 2 source bytes, and unpack into
  // 3 destination bytes.
  const int values_per_byte = 8 / bit_width;
  const int full_bytes = num_elements / values_per_byte;

  for (int i = 0; i < full_bytes; ++i) {
    const uint8_t packed = static_cast<uint8_t>(src_buffer[i]);
    for (int j = 0; j < values_per_byte; ++j) {
      dst_buffer[values_per_byte * i + j] =
          ExtractPackedField(packed, j, bit_width, unpack_unsigned);
    }
  }

  // Handle the values held in a final, partially used byte. The `> 0` check
  // also keeps a negative `num_elements` from producing a negative index.
  const int remaining_elements = num_elements % values_per_byte;
  if (remaining_elements > 0) {
    const uint8_t packed = static_cast<uint8_t>(src_buffer[full_bytes]);
    for (int j = 0; j < remaining_elements; ++j) {
      dst_buffer[num_elements - remaining_elements + j] =
          ExtractPackedField(packed, j, bit_width, unpack_unsigned);
    }
  }
}