Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions .bazelrc
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,7 @@ build --workspace_status_command=./tools/workspace_status.sh
build --noenable_bzlmod

# Use the following C++ standard
build --cxxopt -std=c++17
build:windows --cxxopt=/std:c++17

build --cxxopt=-std=c++17
build:windows --cxxopt=/std:c++17

# Common options for --config=ci
Expand All @@ -46,6 +44,7 @@ build:ci --verbose_failures
build:ci --test_output=errors

# Windows CI options
build:windows_ci --config=windows
build:windows_ci --curses=no
build:windows_ci --color=no
build:windows_ci --noshow_progress
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/lite/build_def.bzl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
def tflite_copts():
"""Defines common compile time flags for TFLite libraries."""
return select({
"@bazel_tools//src/conditions:windows_msvc": [
"@bazel_tools//src/conditions:windows": [
"/DFARMHASH_NO_CXX_STRING",
"/EHs-", # -fno-exceptions
"/GR-",
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/lite/micro/build_def.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def tflm_copts():
be useful when additively overriding the defaults for a particular target.
"""
return select({
"@bazel_tools//src/conditions:windows_msvc": [
"@bazel_tools//src/conditions:windows": [
"/EHs-",
"/GR-",
"/DFLATBUFFERS_LOCALE_INDEPENDENT=0",
Expand Down
83 changes: 39 additions & 44 deletions tensorflow/lite/micro/tools/generate_cc_arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,32 +31,29 @@ def generate_file(out_fname, array_name, array_type, array_contents, size):
"""Write an array of values to a CC or header file."""
os.makedirs(os.path.dirname(out_fname), exist_ok=True)
if out_fname.endswith('.cc'):
out_cc_file = open(out_fname, 'w')
out_cc_file.write('#include <cstdint>\n\n')
out_cc_file.write('#include "{}"\n\n'.format(
out_fname.split('genfiles/')[-1].replace('.cc', '.h')))
out_cc_file.write('alignas(16) const {} {}[] = {{'.format(
array_type, array_name))
out_cc_file.write(array_contents)
out_cc_file.write('};\n')
out_cc_file.close()
with open(out_fname, 'w') as out_cc_file:
out_cc_file.write('#include <cstdint>\n\n')
# Header include path logic, maintaining compatibility with genfiles/ structure.
header_path = out_fname.split('genfiles/')[-1].replace('.cc', '.h')
out_cc_file.write('#include "{}"\n\n'.format(header_path))
out_cc_file.write('alignas(16) const {} {}[] = {{'.format(
array_type, array_name))
out_cc_file.write(array_contents)
out_cc_file.write('};\n')
elif out_fname.endswith('.h'):
out_hdr_file = open(out_fname, 'w')
out_hdr_file.write('#include <cstdint>\n\n')
out_hdr_file.write('constexpr unsigned int {}_size = {};\n'.format(
array_name, str(size)))
out_hdr_file.write('extern const {} {}[];\n'.format(
array_type, array_name))
out_hdr_file.close()
with open(out_fname, 'w') as out_hdr_file:
out_hdr_file.write('#include <cstdint>\n\n')
out_hdr_file.write('constexpr unsigned int {}_size = {};\n'.format(
array_name, str(size)))
out_hdr_file.write('extern const {} {}[];\n'.format(
array_type, array_name))
else:
raise ValueError('generated file must be end with .cc or .h')


def bytes_to_hexstring(buffer):
"""Convert a byte array to a hex string."""
hex_values = [hex(buffer[i]) for i in range(len(buffer))]
out_string = ','.join(hex_values)
return out_string
return ','.join([hex(b) for b in buffer])


def generate_array(input_fname):
Expand Down Expand Up @@ -92,31 +89,31 @@ def generate_array(input_fname):
data_1d = data.flatten()
out_string = ','.join([str(x) for x in data_1d])
return [len(data_1d), out_string]

else:
raise ValueError('input file must be .tflite, .bmp, .wav or .csv')


def get_array_name(input_fname):
# Normalize potential relative path to remove additional dot.
abs_fname = os.path.abspath(input_fname)
base_array_name = 'g_' + abs_fname.split('.')[-2].split('/')[-1]
def get_array_name_and_type(input_fname):
"""Return the array name and type for a given input file."""
# Use os.path.basename to correctly handle both Unix and Windows paths.
base_fname = os.path.basename(input_fname)
# Original logic extracted the filename part between the last two dots.
name_parts = base_fname.split('.')
name_no_ext = name_parts[-2] if len(name_parts) >= 2 else base_fname
base_array_name = 'g_' + name_no_ext

if input_fname.endswith('.tflite'):
return [base_array_name + '_model_data', 'unsigned char']
elif input_fname.endswith('.bmp'):
return [base_array_name + '_image_data', 'unsigned char']
elif input_fname.endswith('.wav'):
return [base_array_name + '_audio_data', 'int16_t']
elif input_fname.endswith('_int32.csv'):
return [base_array_name + '_test_data', 'int32_t']
elif input_fname.endswith('_int16.csv'):
return [base_array_name + '_test_data', 'int16_t']
elif input_fname.endswith('_int8.csv'):
return [base_array_name + '_test_data', 'int8_t']
elif input_fname.endswith('_float.csv'):
return [base_array_name + '_test_data', 'float']
elif input_fname.endswith('npy'):
return [base_array_name + '_test_data', 'float']
elif input_fname.endswith(('_int32.csv', '_int16.csv', '_int8.csv', '_float.csv', '.csv', '.npy')):
return [base_array_name + '_test_data', 'int32_t' if '_int32.csv' in input_fname else
'int16_t' if '_int16.csv' in input_fname else
'int8_t' if '_int8.csv' in input_fname else 'float']
else:
return [base_array_name + '_data', 'unsigned char']


def main():
Expand All @@ -135,7 +132,7 @@ def main():
if args.output.endswith('.cc') or args.output.endswith('.h'):
assert len(args.inputs) == 1
size, cc_array = generate_array(args.inputs[0])
generated_array_name, array_type = get_array_name(args.inputs[0])
generated_array_name, array_type = get_array_name_and_type(args.inputs[0])
generate_file(args.output, generated_array_name, array_type, cc_array,
size)
else:
Expand All @@ -144,15 +141,13 @@ def main():
output_base_fname = os.path.join(args.output,
os.path.splitext(input_file)[0])
if input_file.endswith('.tflite'):
output_base_fname = output_base_fname + '_model_data'
output_base_fname += '_model_data'
elif input_file.endswith('.bmp'):
output_base_fname = output_base_fname + '_image_data'
output_base_fname += '_image_data'
elif input_file.endswith('.wav'):
output_base_fname = output_base_fname + '_audio_data'
elif input_file.endswith('.csv'):
output_base_fname = output_base_fname + '_test_data'
elif input_file.endswith('.npy'):
output_base_fname = output_base_fname + '_test_data'
output_base_fname += '_audio_data'
elif input_file.endswith(('.csv', '.npy')):
output_base_fname += '_test_data'
else:
raise ValueError(
'input file must be .tflite, .bmp, .wav , .npy or .csv')
Expand All @@ -162,12 +157,12 @@ def main():
print(output_cc_fname)
output_hdr_fname = output_base_fname + '.h'
size, cc_array = generate_array(input_file)
generated_array_name, array_type = get_array_name(input_file)
generated_array_name, array_type = get_array_name_and_type(input_file)
generate_file(output_cc_fname, generated_array_name, array_type,
cc_array, size)
generate_file(output_hdr_fname, generated_array_name, array_type,
cc_array, size)


if __name__ == '__main__':
main()
main()
Loading