diff --git a/dart/packages/fory/lib/fory.dart b/dart/packages/fory/lib/fory.dart index 7011281e1b..ffffef41fc 100644 --- a/dart/packages/fory/lib/fory.dart +++ b/dart/packages/fory/lib/fory.dart @@ -28,6 +28,7 @@ export 'src/serializer/enum_serializer.dart'; export 'src/serializer/serializer.dart'; export 'src/serializer/union_serializer.dart'; export 'src/types/fixed_ints.dart'; +export 'src/types/bfloat16.dart'; export 'src/types/float16.dart'; export 'src/types/float32.dart'; export 'src/types/local_date.dart'; diff --git a/dart/packages/fory/lib/src/buffer.dart b/dart/packages/fory/lib/src/buffer.dart index ded8b5f5ce..f176d04c30 100644 --- a/dart/packages/fory/lib/src/buffer.dart +++ b/dart/packages/fory/lib/src/buffer.dart @@ -3,6 +3,7 @@ import 'dart:typed_data'; import 'package:meta/meta.dart'; +import 'package:fory/src/types/bfloat16.dart'; import 'package:fory/src/types/float16.dart'; final BigInt _mask64Big = (BigInt.one << 64) - BigInt.one; @@ -236,6 +237,12 @@ final class Buffer { /// Reads a half-precision floating-point value. Float16 readFloat16() => Float16.fromBits(readUint16()); + /// Writes a brain floating-point (bfloat16) value. + void writeBfloat16(BFloat16 value) => writeUint16(value.toBits()); + + /// Reads a brain floating-point (bfloat16) value. + BFloat16 readBfloat16() => BFloat16.fromBits(readUint16()); + /// Writes [value] verbatim. void writeBytes(List value) { ensureWritable(value.length); diff --git a/dart/packages/fory/lib/src/codegen/fory_generator.dart b/dart/packages/fory/lib/src/codegen/fory_generator.dart index 424301df19..e837cb8b44 100644 --- a/dart/packages/fory/lib/src/codegen/fory_generator.dart +++ b/dart/packages/fory/lib/src/codegen/fory_generator.dart @@ -68,9 +68,7 @@ final class ForyGenerator extends Generator { final fields = element.fields .where( - (field) => - !field.isStatic && - !field.isSynthetic, + (field) => !field.isStatic && !field.isSynthetic, ) .where((field) => !_isSkipped(field)) .map(_analyzeField) @@ -669,7 +667,8 @@ final class ForyGenerator extends Generator { " throw ArgumentError('Exactly one registration mode is required: id, or namespace + typeName.');", ) ..writeln(' }') - ..writeln(' if (hasNamed && (namespace == null || typeName == null)) {') + ..writeln( + ' if (hasNamed && (namespace == null || typeName == null)) {') ..writeln( " throw ArgumentError('Both namespace and typeName are required for named registration.');", ) @@ -1002,6 +1001,7 @@ GeneratedFieldType( case TypeIds.int16: case TypeIds.uint16: case TypeIds.float16: + case TypeIds.bfloat16: return 2; case TypeIds.int32: case TypeIds.uint32: @@ -1065,6 +1065,8 @@ GeneratedFieldType( return 'buffer.writeTaggedUint64(${_directGeneratedScalarExpression(field, valueExpression)})'; case TypeIds.float16: return 'buffer.writeFloat16($valueExpression)'; + case TypeIds.bfloat16: + return 'buffer.writeBfloat16($valueExpression)'; case TypeIds.float32: return 'buffer.writeFloat32(${_directGeneratedScalarExpression(field, valueExpression)})'; case TypeIds.float64: @@ -1088,6 +1090,7 @@ GeneratedFieldType( case TypeIds.uint32Array: case TypeIds.uint64Array: case TypeIds.float16Array: + case TypeIds.bfloat16Array: case TypeIds.float32Array: case TypeIds.float64Array: return 'writeGeneratedFixedArrayValue(context, $valueExpression)'; @@ -1138,6 +1141,8 @@ GeneratedFieldType( return '$cursorExpression.writeTaggedUint64(${_directGeneratedScalarExpression(field, valueExpression)})'; case TypeIds.float16: return '$cursorExpression.writeFloat16($valueExpression)'; + case TypeIds.bfloat16: + return '$cursorExpression.writeBfloat16($valueExpression)'; case TypeIds.float32: return '$cursorExpression.writeFloat32(${_directGeneratedScalarExpression(field, valueExpression)})'; case TypeIds.float64: @@ -1201,6 +1206,8 @@ GeneratedFieldType( return 'buffer.readTaggedUint64()'; case TypeIds.float16: return 'buffer.readFloat16()'; + case TypeIds.bfloat16: + return 'buffer.readBfloat16()'; case TypeIds.float32: return field.type.isDartCoreDouble ? 'buffer.readFloat32()' @@ -1229,6 +1236,7 @@ GeneratedFieldType( return 'readGeneratedBinaryValue(context)'; case TypeIds.uint16Array: case TypeIds.float16Array: + case TypeIds.bfloat16Array: return 'readGeneratedTypedArrayValue(context, 2, (bytes) => bytes.buffer.asUint16List(bytes.offsetInBytes, bytes.lengthInBytes ~/ 2))'; case TypeIds.uint32Array: return 'readGeneratedTypedArrayValue(context, 4, (bytes) => bytes.buffer.asUint32List(bytes.offsetInBytes, bytes.lengthInBytes ~/ 4))'; @@ -1296,6 +1304,8 @@ GeneratedFieldType( return '$cursorExpression.readTaggedUint64()'; case TypeIds.float16: return '$cursorExpression.readFloat16()'; + case TypeIds.bfloat16: + return '$cursorExpression.readBfloat16()'; case TypeIds.float32: return field.type.isDartCoreDouble ? '$cursorExpression.readFloat32()' @@ -1349,6 +1359,7 @@ GeneratedFieldType( } switch (field.fieldType.typeId) { case TypeIds.float16: + case TypeIds.bfloat16: return valueExpression; default: return '$valueExpression.value'; @@ -1577,6 +1588,7 @@ GeneratedFieldType( case TypeIds.int16: case TypeIds.uint16: case TypeIds.float16: + case TypeIds.bfloat16: return 2; case TypeIds.int32: case TypeIds.varInt32: @@ -1629,6 +1641,7 @@ GeneratedFieldType( case TypeIds.varUint64: case TypeIds.taggedUint64: case TypeIds.float16: + case TypeIds.bfloat16: case TypeIds.float32: case TypeIds.float64: return true; @@ -1653,6 +1666,7 @@ GeneratedFieldType( case TypeIds.uint32Array: case TypeIds.uint64Array: case TypeIds.float16Array: + case TypeIds.bfloat16Array: case TypeIds.float32Array: case TypeIds.float64Array: return true; @@ -1792,6 +1806,8 @@ GeneratedFieldType( return TypeIds.uint32; case 'Float16': return TypeIds.float16; + case 'BFloat16': + return TypeIds.bfloat16; case 'Float32': return TypeIds.float32; case 'Timestamp': diff --git a/dart/packages/fory/lib/src/codegen/generated_support.dart b/dart/packages/fory/lib/src/codegen/generated_support.dart index ddaac63357..56c243a0cd 100644 --- a/dart/packages/fory/lib/src/codegen/generated_support.dart +++ b/dart/packages/fory/lib/src/codegen/generated_support.dart @@ -101,6 +101,10 @@ final class GeneratedWriteCursor { writeUint16(value.toBits()); } + void writeBfloat16(BFloat16 value) { + writeUint16(value.toBits()); + } + void writeFloat32(double value) { _view.setFloat32(_offset, value, Endian.little); _offset += 4; @@ -271,6 +275,8 @@ final class GeneratedReadCursor { Float16 readFloat16() => Float16.fromBits(readUint16()); + BFloat16 readBfloat16() => BFloat16.fromBits(readUint16()); + double readFloat32() { final value = _view.getFloat32(_offset, Endian.little); _offset += 4; diff --git a/dart/packages/fory/lib/src/context/meta_string_reader.dart b/dart/packages/fory/lib/src/context/meta_string_reader.dart index d5833876b2..a91bafc70d 100644 --- a/dart/packages/fory/lib/src/context/meta_string_reader.dart +++ b/dart/packages/fory/lib/src/context/meta_string_reader.dart @@ -4,8 +4,13 @@ import 'package:fory/src/buffer.dart'; import 'package:fory/src/meta/meta_string.dart'; import 'package:fory/src/resolver/type_resolver.dart'; -typedef _MetaStringWords = - ({int length, int word0, int word1, int word2, int word3}); +typedef _MetaStringWords = ({ + int length, + int word0, + int word1, + int word2, + int word3 +}); /// Read-side state for meta-string references in one deserialization stream. final class MetaStringReader { diff --git a/dart/packages/fory/lib/src/context/read_context.dart b/dart/packages/fory/lib/src/context/read_context.dart index ec78c2f7d2..b47f444a21 100644 --- a/dart/packages/fory/lib/src/context/read_context.dart +++ b/dart/packages/fory/lib/src/context/read_context.dart @@ -13,6 +13,7 @@ import 'package:fory/src/serializer/primitive_serializers.dart'; import 'package:fory/src/serializer/scalar_serializers.dart'; import 'package:fory/src/serializer/serializer.dart'; import 'package:fory/src/serializer/typed_array_serializers.dart'; +import 'package:fory/src/types/bfloat16.dart'; import 'package:fory/src/types/float16.dart'; /// Read-side serializer context. @@ -157,6 +158,9 @@ final class ReadContext { /// Reads a half-precision floating-point value. Float16 readFloat16() => _buffer.readFloat16(); + /// Reads a brain floating-point (bfloat16) value. + BFloat16 readBfloat16() => _buffer.readBfloat16(); + /// Reads a single-precision floating-point value. double readFloat32() => _buffer.readFloat32(); diff --git a/dart/packages/fory/lib/src/context/write_context.dart b/dart/packages/fory/lib/src/context/write_context.dart index f039e2f422..796bd50c17 100644 --- a/dart/packages/fory/lib/src/context/write_context.dart +++ b/dart/packages/fory/lib/src/context/write_context.dart @@ -16,6 +16,7 @@ import 'package:fory/src/serializer/map_serializers.dart'; import 'package:fory/src/serializer/primitive_serializers.dart'; import 'package:fory/src/serializer/scalar_serializers.dart'; import 'package:fory/src/serializer/typed_array_serializers.dart'; +import 'package:fory/src/types/bfloat16.dart'; import 'package:fory/src/types/float16.dart'; import 'package:fory/src/types/local_date.dart'; import 'package:fory/src/types/timestamp.dart'; @@ -140,6 +141,9 @@ final class WriteContext { /// Writes a half-precision floating-point value. void writeFloat16(Float16 value) => _buffer.writeFloat16(value); + /// Writes a brain floating-point (bfloat16) value. + void writeBfloat16(BFloat16 value) => _buffer.writeBfloat16(value); + /// Writes a single-precision floating-point value. void writeFloat32(double value) => _buffer.writeFloat32(value); @@ -272,6 +276,7 @@ final class WriteContext { case TypeIds.varUint64: case TypeIds.taggedUint64: case TypeIds.float16: + case TypeIds.bfloat16: case TypeIds.float32: case TypeIds.float64: PrimitiveSerializer.writePayload(this, resolved.typeId, value); diff --git a/dart/packages/fory/lib/src/meta/type_meta.dart b/dart/packages/fory/lib/src/meta/type_meta.dart index e9c629f93e..0de461ca18 100644 --- a/dart/packages/fory/lib/src/meta/type_meta.dart +++ b/dart/packages/fory/lib/src/meta/type_meta.dart @@ -239,6 +239,7 @@ final class WireTypeMetaDecoder { wireTypeId == TypeIds.varUint64 || wireTypeId == TypeIds.taggedUint64 || wireTypeId == TypeIds.float16 || + wireTypeId == TypeIds.bfloat16 || wireTypeId == TypeIds.float32 || wireTypeId == TypeIds.float64 || wireTypeId == TypeIds.string || @@ -258,6 +259,7 @@ final class WireTypeMetaDecoder { wireTypeId == TypeIds.uint32Array || wireTypeId == TypeIds.uint64Array || wireTypeId == TypeIds.float16Array || + wireTypeId == TypeIds.bfloat16Array || wireTypeId == TypeIds.float32Array || wireTypeId == TypeIds.float64Array; } diff --git a/dart/packages/fory/lib/src/resolver/type_resolver.dart b/dart/packages/fory/lib/src/resolver/type_resolver.dart index 4aeceb9767..dc26c7c817 100644 --- a/dart/packages/fory/lib/src/resolver/type_resolver.dart +++ b/dart/packages/fory/lib/src/resolver/type_resolver.dart @@ -22,6 +22,7 @@ import 'package:fory/src/serializer/struct_serializer.dart'; import 'package:fory/src/serializer/typed_array_serializers.dart'; import 'package:fory/src/serializer/union_serializer.dart'; import 'package:fory/src/types/fixed_ints.dart'; +import 'package:fory/src/types/bfloat16.dart'; import 'package:fory/src/types/float16.dart'; import 'package:fory/src/types/float32.dart'; import 'package:fory/src/types/local_date.dart'; @@ -46,6 +47,7 @@ abstract final class TypeIds { static const int varUint64 = 14; static const int taggedUint64 = 15; static const int float16 = 17; + static const int bfloat16 = 18; static const int float32 = 19; static const int float64 = 20; static const int string = 21; @@ -77,6 +79,7 @@ abstract final class TypeIds { static const int uint32Array = 50; static const int uint64Array = 51; static const int float16Array = 53; + static const int bfloat16Array = 54; static const int float32Array = 55; static const int float64Array = 56; @@ -97,6 +100,7 @@ abstract final class TypeIds { typeId == varUint64 || typeId == taggedUint64 || typeId == float16 || + typeId == bfloat16 || typeId == float32 || typeId == float64; @@ -132,6 +136,7 @@ abstract final class TypeIds { typeId == uint32Array || typeId == uint64Array || typeId == float16Array || + typeId == bfloat16Array || typeId == float32Array || typeId == float64Array; @@ -153,6 +158,7 @@ abstract final class TypeIds { case uint32Array: case uint64Array: case float16Array: + case bfloat16Array: case float32Array: case float64Array: return false; @@ -443,6 +449,9 @@ final class TypeResolver { if (value is Float16) { return _builtin(Float16, TypeIds.float16); } + if (value is BFloat16) { + return _builtin(BFloat16, TypeIds.bfloat16); + } if (value is Float32) { return _builtin(Float32, TypeIds.float32); } @@ -525,6 +534,7 @@ final class TypeResolver { case TypeIds.varUint64: case TypeIds.taggedUint64: case TypeIds.float16: + case TypeIds.bfloat16: case TypeIds.float32: case TypeIds.float64: case TypeIds.string: @@ -544,6 +554,7 @@ final class TypeResolver { case TypeIds.uint32Array: case TypeIds.uint64Array: case TypeIds.float16Array: + case TypeIds.bfloat16Array: case TypeIds.float32Array: case TypeIds.float64Array: return _builtin(fieldType.type, fieldType.typeId); @@ -1117,6 +1128,8 @@ final class TypeResolver { return _builtin(int, TypeIds.taggedUint64); case TypeIds.float16: return _builtin(Float16, TypeIds.float16); + case TypeIds.bfloat16: + return _builtin(BFloat16, TypeIds.bfloat16); case TypeIds.float32: return _builtin(Float32, TypeIds.float32); case TypeIds.float64: @@ -1155,6 +1168,8 @@ final class TypeResolver { return _builtin(Uint64List, TypeIds.uint64Array); case TypeIds.float16Array: return _builtin(Uint16List, TypeIds.float16Array); + case TypeIds.bfloat16Array: + return _builtin(Uint16List, TypeIds.bfloat16Array); case TypeIds.float32Array: return _builtin(Float32List, TypeIds.float32Array); case TypeIds.float64Array: @@ -1222,6 +1237,8 @@ final class TypeResolver { return taggedUint64Serializer as Serializer; case TypeIds.float16: return float16Serializer as Serializer; + case TypeIds.bfloat16: + return bfloat16Serializer as Serializer; case TypeIds.float32: return float32Serializer as Serializer; case TypeIds.float64: @@ -1307,6 +1324,9 @@ final class TypeResolver { if (type == Float16) { return TypeIds.float16; } + if (type == BFloat16) { + return TypeIds.bfloat16; + } if (type == Float32) { return TypeIds.float32; } diff --git a/dart/packages/fory/lib/src/serializer/map_serializers.dart b/dart/packages/fory/lib/src/serializer/map_serializers.dart index 8f81ae64ec..04a39a70e0 100644 --- a/dart/packages/fory/lib/src/serializer/map_serializers.dart +++ b/dart/packages/fory/lib/src/serializer/map_serializers.dart @@ -277,13 +277,13 @@ Map readTypedMapPayload( final valueTypeInfo = valueDeclared ? null : context.readTypeMetaValue(); final tracksDepth = ((keyDeclared ? declaredKeyTypeInfo : keyTypeInfo) != null && - tracksNestedPayloadDepth( - keyDeclared ? declaredKeyTypeInfo! : keyTypeInfo!, - )) || - ((valueDeclared ? declaredValueTypeInfo : valueTypeInfo) != null && - tracksNestedPayloadDepth( - valueDeclared ? declaredValueTypeInfo! : valueTypeInfo!, - )); + tracksNestedPayloadDepth( + keyDeclared ? declaredKeyTypeInfo! : keyTypeInfo!, + )) || + ((valueDeclared ? declaredValueTypeInfo : valueTypeInfo) != null && + tracksNestedPayloadDepth( + valueDeclared ? declaredValueTypeInfo! : valueTypeInfo!, + )); if (tracksDepth) { context.increaseDepth(); } diff --git a/dart/packages/fory/lib/src/serializer/primitive_serializers.dart b/dart/packages/fory/lib/src/serializer/primitive_serializers.dart index d90c8ba127..b808688da5 100644 --- a/dart/packages/fory/lib/src/serializer/primitive_serializers.dart +++ b/dart/packages/fory/lib/src/serializer/primitive_serializers.dart @@ -3,6 +3,7 @@ import 'package:fory/src/context/write_context.dart'; import 'package:fory/src/resolver/type_resolver.dart'; import 'package:fory/src/serializer/serializer.dart'; import 'package:fory/src/types/fixed_ints.dart'; +import 'package:fory/src/types/bfloat16.dart'; import 'package:fory/src/types/float16.dart'; import 'package:fory/src/types/float32.dart'; @@ -83,6 +84,9 @@ final class PrimitiveSerializer extends Serializer { case TypeIds.float16: buffer.writeFloat16(value as Float16); return; + case TypeIds.bfloat16: + buffer.writeBfloat16(value as BFloat16); + return; case TypeIds.float32: buffer.writeFloat32((value as Float32).value); return; @@ -132,6 +136,8 @@ final class PrimitiveSerializer extends Serializer { return buffer.readTaggedUint64(); case TypeIds.float16: return buffer.readFloat16(); + case TypeIds.bfloat16: + return buffer.readBfloat16(); case TypeIds.float32: return Float32(buffer.readFloat32()); case TypeIds.float64: @@ -171,8 +177,7 @@ const PrimitiveSerializer varInt64Serializer = PrimitiveSerializer( TypeIds.varInt64, supportsRef: false, ); -const PrimitiveSerializer taggedInt64Serializer = - PrimitiveSerializer( +const PrimitiveSerializer taggedInt64Serializer = PrimitiveSerializer( TypeIds.taggedInt64, supportsRef: false, ); @@ -213,6 +218,11 @@ const PrimitiveSerializer float16Serializer = TypeIds.float16, supportsRef: false, ); +const PrimitiveSerializer bfloat16Serializer = + PrimitiveSerializer( + TypeIds.bfloat16, + supportsRef: false, +); const PrimitiveSerializer float32Serializer = PrimitiveSerializer( TypeIds.float32, diff --git a/dart/packages/fory/lib/src/types/bfloat16.dart b/dart/packages/fory/lib/src/types/bfloat16.dart new file mode 100644 index 0000000000..e0b35e2b7c --- /dev/null +++ b/dart/packages/fory/lib/src/types/bfloat16.dart @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import 'dart:typed_data'; + +/// Brain floating-point (bfloat16) wrapper used by the xlang type system. +/// +/// BFloat16 uses 1 sign bit, 8 exponent bits, and 7 mantissa bits. +/// It has the same exponent range as float32 but with reduced precision. +final class BFloat16 implements Comparable { + final int _bits; + + /// Creates a value directly from raw bfloat16 bits. + const BFloat16.fromBits(int bits) : _bits = bits & 0xffff; + + /// Converts [value] to the closest representable bfloat16 value. + factory BFloat16(num value) => BFloat16.fromFloat32(value.toDouble()); + + /// Converts a float32 [value] to the closest representable bfloat16 value + /// using round-to-nearest, ties-to-even. + factory BFloat16.fromFloat32(double value) { + final f32 = Float32List(1); + final u32 = f32.buffer.asUint32List(); + f32[0] = value; + final bits = u32[0]; + final exponent = (bits >> 23) & 0xff; + + // NaN/Inf: preserve sign and truncate mantissa (keeps NaN payload bits). + if (exponent == 255) { + return BFloat16.fromBits((bits >> 16) & 0xffff); + } + + // Round-to-nearest, ties-to-even. + final remainder = bits & 0x1ffff; + var u = (bits + 0x8000) >> 16; + if (remainder == 0x8000 && (u & 1) != 0) { + u--; + } + return BFloat16.fromBits(u & 0xffff); + } + + /// Returns the raw bfloat16 bits for this value. + int toBits() => _bits; + + /// Expands this bfloat16 value to a Dart [double] (via float32). + double toDouble() { + final f32 = Float32List(1); + final u32 = f32.buffer.asUint32List(); + u32[0] = _bits << 16; + return f32[0]; + } + + @override + bool operator ==(Object other) => + identical(this, other) || other is BFloat16 && other._bits == _bits; + + @override + int get hashCode => _bits.hashCode; + + @override + int compareTo(BFloat16 other) => toDouble().compareTo(other.toDouble()); + + @override + String toString() => toDouble().toString(); +} diff --git a/dart/packages/fory/test/bfloat16_test.dart b/dart/packages/fory/test/bfloat16_test.dart new file mode 100644 index 0000000000..4f21cfc936 --- /dev/null +++ b/dart/packages/fory/test/bfloat16_test.dart @@ -0,0 +1,234 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import 'dart:typed_data'; + +import 'package:fory/fory.dart'; +import 'package:test/test.dart'; + +void main() { + group('BFloat16', () { + group('scalar conversions', () { + test('positive zero', () { + final bf = BFloat16.fromFloat32(0.0); + expect(bf.toBits(), equals(0x0000)); + expect(bf.toDouble(), equals(0.0)); + }); + + test('negative zero', () { + final bf = BFloat16.fromFloat32(-0.0); + expect(bf.toBits(), equals(0x8000)); + expect(bf.toDouble(), equals(-0.0)); + expect(bf.toDouble().isNegative, isTrue); + }); + + test('positive infinity', () { + final bf = BFloat16.fromFloat32(double.infinity); + expect(bf.toBits(), equals(0x7F80)); + expect(bf.toDouble(), equals(double.infinity)); + }); + + test('negative infinity', () { + final bf = BFloat16.fromFloat32(double.negativeInfinity); + expect(bf.toBits(), equals(0xFF80)); + expect(bf.toDouble(), equals(double.negativeInfinity)); + }); + + test('NaN', () { + final bf = BFloat16.fromFloat32(double.nan); + expect(bf.toDouble().isNaN, isTrue); + }); + + test('one', () { + final bf = BFloat16.fromFloat32(1.0); + expect(bf.toBits(), equals(0x3F80)); + expect(bf.toDouble(), equals(1.0)); + }); + + test('negative one', () { + final bf = BFloat16.fromFloat32(-1.0); + expect(bf.toBits(), equals(0xBF80)); + expect(bf.toDouble(), equals(-1.0)); + }); + + test('small value round-trip', () { + final bf = BFloat16.fromFloat32(0.5); + expect(bf.toDouble(), equals(0.5)); + }); + + test('large value round-trip', () { + final bf = BFloat16.fromFloat32(256.0); + expect(bf.toDouble(), equals(256.0)); + }); + + test('fromBits masks to 16 bits', () { + final bf = BFloat16.fromBits(0x13F80); + expect(bf.toBits(), equals(0x3F80)); + }); + + test('num constructor', () { + final bf = BFloat16(1.5); + expect(bf.toDouble(), equals(1.5)); + }); + + test('subnormal bfloat16', () { + // Smallest positive subnormal: 0x0001 + final bf = BFloat16.fromBits(0x0001); + expect(bf.toDouble(), isNot(equals(0.0))); + // Round-trip through fromFloat32 + final bf2 = BFloat16.fromFloat32(bf.toDouble()); + expect(bf2.toBits(), equals(bf.toBits())); + }); + + test('max normal bfloat16', () { + // Largest finite bfloat16: 0x7F7F + final bf = BFloat16.fromBits(0x7F7F); + expect(bf.toDouble().isFinite, isTrue); + expect(bf.toDouble().isNaN, isFalse); + }); + + test('min normal bfloat16', () { + // Smallest positive normal: 0x0080 + final bf = BFloat16.fromBits(0x0080); + expect(bf.toDouble(), greaterThan(0.0)); + expect(bf.toDouble().isFinite, isTrue); + }); + }); + + group('round-to-nearest ties-to-even', () { + test('ties-to-even rounding', () { + // 1.0 in bfloat16 is 0x3F80, next is 0x3F81 + // midpoint between them should round to even + final bf1 = BFloat16.fromBits(0x3F80); + final bf2 = BFloat16.fromBits(0x3F81); + final mid = (bf1.toDouble() + bf2.toDouble()) / 2; + final bfMid = BFloat16.fromFloat32(mid); + // Should round to even (the one with LSB = 0) + expect(bfMid.toBits() & 1, equals(0)); + }); + }); + + group('equality and comparison', () { + test('equal values', () { + final a = BFloat16.fromBits(0x3F80); + final b = BFloat16.fromBits(0x3F80); + expect(a, equals(b)); + expect(a.hashCode, equals(b.hashCode)); + }); + + test('different values', () { + final a = BFloat16.fromFloat32(1.0); + final b = BFloat16.fromFloat32(2.0); + expect(a, isNot(equals(b))); + }); + + test('compareTo', () { + final a = BFloat16.fromFloat32(1.0); + final b = BFloat16.fromFloat32(2.0); + expect(a.compareTo(b), lessThan(0)); + expect(b.compareTo(a), greaterThan(0)); + expect(a.compareTo(a), equals(0)); + }); + + test('toString', () { + final bf = BFloat16.fromFloat32(1.0); + expect(bf.toString(), equals('1.0')); + }); + }); + + group('buffer read/write', () { + test('round-trip through buffer', () { + final buffer = Buffer(64); + final original = BFloat16.fromFloat32(3.14); + buffer.writeBfloat16(original); + + final readBuffer = Buffer.wrap(buffer.toBytes()); + final result = readBuffer.readBfloat16(); + + expect(result.toBits(), equals(original.toBits())); + expect(result.toDouble(), equals(original.toDouble())); + }); + + test('multiple values through buffer', () { + final buffer = Buffer(64); + final values = [ + BFloat16.fromFloat32(0.0), + BFloat16.fromFloat32(1.0), + BFloat16.fromFloat32(-1.0), + BFloat16.fromFloat32(double.infinity), + BFloat16.fromFloat32(double.nan), + ]; + + for (final v in values) { + buffer.writeBfloat16(v); + } + + final readBuffer = Buffer.wrap(buffer.toBytes()); + for (int i = 0; i < values.length; i++) { + final result = readBuffer.readBfloat16(); + if (values[i].toDouble().isNaN) { + expect(result.toDouble().isNaN, isTrue); + } else { + expect(result.toBits(), equals(values[i].toBits())); + } + } + }); + }); + + group('bfloat16 array serialization', () { + test('packed Uint16List round-trip through buffer', () { + final buffer = Buffer(128); + + final values = [1.0, 2.0, 0.5, -1.0, 0.0]; + final packed = Uint16List(values.length); + for (int i = 0; i < values.length; i++) { + packed[i] = BFloat16.fromFloat32(values[i]).toBits(); + } + + // Write length + raw bytes + buffer.writeVarUint32(packed.length); + final bytes = packed.buffer.asUint8List( + packed.offsetInBytes, + packed.lengthInBytes, + ); + buffer.writeBytes(bytes); + + // Read back + final readBuffer = Buffer.wrap(buffer.toBytes()); + final length = readBuffer.readVarUint32(); + expect(length, equals(values.length)); + final rawBytes = readBuffer.copyBytes(length * 2); + final result = rawBytes.buffer.asUint16List( + rawBytes.offsetInBytes, + rawBytes.lengthInBytes ~/ 2, + ); + + for (int i = 0; i < values.length; i++) { + final bf = BFloat16.fromBits(result[i]); + if (values[i] == 0.0) { + expect(bf.toDouble(), equals(0.0)); + } else { + expect(bf.toDouble(), + equals(BFloat16.fromFloat32(values[i]).toDouble())); + } + } + }); + }); + }); +}