From 0c28b5581b0cc63d74f32de5d795463146a893a1 Mon Sep 17 00:00:00 2001 From: Riryan Date: Mon, 25 May 2026 17:19:22 -0500 Subject: [PATCH] Add collection count validation in NetDataReaderExtension Added validation for collection count in GetArrayExtension, GetArrayObject, and GetList methods. Introduced a constant to limit maximum collection size. --- .../Extensions/NetDataReaderExtension.cs | 124 ++++++++---------- 1 file changed, 55 insertions(+), 69 deletions(-) diff --git a/ThirdParty/LiteNetLibManager/Scripts/Extensions/NetDataReaderExtension.cs b/ThirdParty/LiteNetLibManager/Scripts/Extensions/NetDataReaderExtension.cs index 1f8dd0c4f..e18b897bf 100644 --- a/ThirdParty/LiteNetLibManager/Scripts/Extensions/NetDataReaderExtension.cs +++ b/ThirdParty/LiteNetLibManager/Scripts/Extensions/NetDataReaderExtension.cs @@ -7,6 +7,10 @@ namespace LiteNetLib.Utils { public static class NetDataReaderExtension { + // A defensive cap to prevent malformed packets from allocating huge collections. + // Raise this only if you intentionally send very large collections in one message. + public const int MaxCollectionElementCount = 65535; + public static TType GetValue(this NetDataReader reader) { return (TType)GetValue(reader, typeof(TType)); @@ -14,25 +18,49 @@ public static TType GetValue(this NetDataReader reader) public static object GetValue(this NetDataReader reader, Type type) { + if (reader == null) + throw new ArgumentNullException(nameof(reader)); + if (type == null) + throw new ArgumentNullException(nameof(type)); + + Type originalType = type; + if (type.IsEnum) type = type.GetEnumUnderlyingType(); if (ReaderRegistry.TryGetReader(type, out Func readerFunc)) { - return readerFunc(reader); + object value = readerFunc(reader); + return originalType.IsEnum ? Enum.ToObject(originalType, value) : value; } -#if UNITY_EDITOR || DEVELOPMENT_BUILD - Debug.LogWarning($"No reader registered for type: {type.FullName}"); -#endif - if (typeof(INetSerializable).IsAssignableFrom(type)) + if (typeof(INetSerializable).IsAssignableFrom(originalType)) { - object instance = Activator.CreateInstance(type); - (instance as INetSerializable).Deserialize(reader); + if (originalType.IsAbstract || originalType.IsInterface) + throw new ArgumentException($"NetDataReader cannot create an instance of abstract/interface type {originalType.FullName}"); + + object instance; + try + { + instance = Activator.CreateInstance(originalType); + } + catch (MissingMethodException exception) + { + throw new ArgumentException($"NetDataReader cannot create type {originalType.FullName}. INetSerializable types must have a public parameterless constructor.", exception); + } + + INetSerializable serializable = instance as INetSerializable; + if (serializable == null) + throw new ArgumentException($"Created value must implement {nameof(INetSerializable)} for type {originalType.FullName}"); + + serializable.Deserialize(reader); return instance; } - throw new ArgumentException("NetDataReader cannot read type " + type.Name); +#if UNITY_EDITOR || DEVELOPMENT_BUILD + Debug.LogWarning($"No reader registered for type: {originalType.FullName}"); +#endif + throw new ArgumentException("NetDataReader cannot read type " + originalType.Name); } public static Color GetColor(this NetDataReader reader) @@ -64,61 +92,6 @@ public static Vector3 GetVector3(this NetDataReader reader) return new Vector3(reader.GetFloat(), reader.GetFloat(), reader.GetFloat()); } - /// - ///Read a quantized Vector3 from the reader, the vector is quantized into integers based on the cell size and compression mode, which determines how many bits are used for each component. - public static Vector3 GetQuantizedVector3(this NetDataReader reader, ushort cellSize, out int compressionMode) - { - //Read Mode from the first byte, the mode is determined by the top 2 bits of the first byte, and the remaining 6 bits are used for quantized data. The mode determines how many bits are used for each component of the vector, which can be 3, 4, 5, or 6 bits for x, y, and z respectively. - byte first = reader.GetByte(); - compressionMode = ((first >> 6) & 0b11) + 3; - - int bx, by, bz; - - switch (compressionMode) - { - case 3: bx = 10; by = 4; bz = 10; break; - case 4: bx = 11; by = 10; bz = 11; break; - case 5: bx = 14; by = 12; bz = 14; break; - case 6: bx = 16; by = 16; bz = 16; break; - default: throw new Exception("Invalid mode"); - } - - int totalBits = bx + by + bz; - int byteCount = (totalBits + 7) / 8; - - ulong data = (ulong)(first & 0x3F); // first 6 bits - int shift = 6; - - for (int i = 1; i < byteCount; i++) - { - byte b = reader.GetByte(); - data |= ((ulong)b << shift); - shift += 8; - } - - - //Extract quantized values for x, y, and z from the combined data using bitwise operations. The values are extracted in the order of x, z, and y, based on the number of bits allocated for each component. - int s = 0; - - int qx = (int)((data >> s) & ((1UL << bx) - 1)); s += bx; - int qz = (int)((data >> s) & ((1UL << bz) - 1)); s += bz; - int qy = (int)((data >> s) & ((1UL << by) - 1)); - - float x = Dequantize(qx, cellSize, bx); - float y = Dequantize(qy, cellSize, by); - float z = Dequantize(qz, cellSize, bz); - - return new Vector3(x, y, z); - } - - /// - /// Dequantize an integer value to a float based on the cell size and the number of bits used for quantization. The value is first normalized to the range [0, 1] by dividing it by the maximum integer value that can be represented with the given number of bits, and then scaled by the cell size to get the final float value. - static float Dequantize(int value, ushort cellSize, int bits) - { - float maxInt = (1 << bits) - 1; - return ((float)value / maxInt) * cellSize; - } - public static Vector3Int GetVector3Int(this NetDataReader reader) { return new Vector3Int(reader.GetInt(), reader.GetInt(), reader.GetInt()); @@ -131,7 +104,7 @@ public static Vector4 GetVector4(this NetDataReader reader) public static TValue[] GetArrayExtension(this NetDataReader reader) { - int count = reader.GetInt(); + int count = reader.GetValidatedCollectionCount(nameof(GetArrayExtension)); TValue[] result = new TValue[count]; for (int i = 0; i < count; ++i) { @@ -142,7 +115,10 @@ public static TValue[] GetArrayExtension(this NetDataReader reader) public static object GetArrayObject(this NetDataReader reader, Type type) { - int count = reader.GetInt(); + if (type == null) + throw new ArgumentNullException(nameof(type)); + + int count = reader.GetValidatedCollectionCount(nameof(GetArrayObject)); Array array = Array.CreateInstance(type, count); for (int i = 0; i < count; ++i) { @@ -153,8 +129,8 @@ public static object GetArrayObject(this NetDataReader reader, Type type) public static List GetList(this NetDataReader reader) { - int count = reader.GetInt(); - List result = new List(); + int count = reader.GetValidatedCollectionCount(nameof(GetList)); + List result = new List(count); for (int i = 0; i < count; ++i) { result.Add(reader.GetValue()); @@ -164,8 +140,8 @@ public static List GetList(this NetDataReader reader) public static Dictionary GetDictionary(this NetDataReader reader) { - int count = reader.GetInt(); - Dictionary result = new Dictionary(); + int count = reader.GetValidatedCollectionCount(nameof(GetDictionary)); + Dictionary result = new Dictionary(count); for (int i = 0; i < count; ++i) { result.Add(reader.GetValue(), reader.GetValue()); @@ -173,6 +149,16 @@ public static Dictionary GetDictionary(this NetDataR return result; } + private static int GetValidatedCollectionCount(this NetDataReader reader, string source) + { + int count = reader.GetInt(); + if (count < 0) + throw new ArgumentOutOfRangeException(source, count, "Collection count cannot be negative."); + if (count > MaxCollectionElementCount) + throw new ArgumentOutOfRangeException(source, count, $"Collection count exceeds MaxCollectionElementCount ({MaxCollectionElementCount})."); + return count; + } + #region Packed Signed Int (Ref: https://developers.google.com/protocol-buffers/docs/encoding#signed-integers) public static short GetPackedShort(this NetDataReader reader) {