From 974e2c608f74ae884f066c2ea77a51a31e6bf807 Mon Sep 17 00:00:00 2001 From: "yuxuan.wang1" Date: Mon, 3 Nov 2025 17:38:14 +0800 Subject: [PATCH] feat: support injecting customized codec Impl --- fastpb_impl.go | 7 +++- fastpb_impl_test.go | 99 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 105 insertions(+), 1 deletion(-) diff --git a/fastpb_impl.go b/fastpb_impl.go index 01f8ce7..35cac38 100644 --- a/fastpb_impl.go +++ b/fastpb_impl.go @@ -26,7 +26,7 @@ import ( ) // Impl implements Protocol. -var Impl impl +var Impl Protocol = impl{} // When encoding length-prefixed fields, we speculatively set aside some number of bytes // for the length, encode the data, and then encode the length (shifting the data if necessary @@ -44,6 +44,11 @@ func SetSpanCache(enable bool) { spanCacheEnable = enable } +// SetImpl replaces the specific codec implementation to support function hijacking etc... +func SetImpl(impl Protocol) { + Impl = impl +} + type impl struct{} // WriteMessage implements TLV(tag, length, value) and V(value). diff --git a/fastpb_impl_test.go b/fastpb_impl_test.go index a92d330..8338832 100644 --- a/fastpb_impl_test.go +++ b/fastpb_impl_test.go @@ -19,6 +19,7 @@ package fastpb import ( "fmt" "testing" + "unicode/utf8" "google.golang.org/protobuf/encoding/protowire" ) @@ -493,3 +494,101 @@ func AssertConsumeTag(name string, got_n protowire.Number, got_t protowire.Type, panic(fmt.Errorf("%s ConsumeTag num[%d]type[%d] != except[%d][%d]", name, got_n, got_t, exp_n, exp_t)) } } + +type customizedImpl struct { + Protocol +} + +func (impl customizedImpl) WriteString(buf []byte, number int32, value string) (n int) { + if number != SkipTagNumber { + n += AppendTag(buf[n:], protowire.Number(number), protowire.BytesType) + } + if !utf8.ValidString(value) { + n += AppendString(buf[n:], "") + } else { + n += AppendString(buf[n:], value) + } + return n +} + +func (impl customizedImpl) SizeString(number int32, value string) (n int) { + if number != SkipTagNumber { + n += protowire.SizeVarint(protowire.EncodeTag(protowire.Number(number), protowire.BytesType)) + } + if !utf8.ValidString(value) { + // empty string + n += 1 + } else { + n += protowire.SizeBytes(len(value)) + } + return n +} + +func Test_InjectCustomizedImpl(t *testing.T) { + orig := Impl + defer func() { + SetImpl(orig) + }() + SetImpl(customizedImpl{ + Protocol: orig, + }) + + // utf-8 valid + // write + var num int32 = 255 + value := "hello world" + exceptSize := 14 + size := Impl.SizeString(num, value) + if size != exceptSize { + panic(fmt.Errorf("SizeString[%d] != except[%d]", size, exceptSize)) + } + buf := make([]byte, 64) + exceptWs := "fa0f0b68656c6c6f20776f726c64" + wn := Impl.WriteString(buf, num, value) + ws := fmt.Sprintf("%x", buf[:wn]) + if wn != size || ws != exceptWs { + panic(fmt.Errorf("WriteString[%d][%s] != except[%d][%s]", wn, ws, size, exceptWs)) + } + + // read + _type := protowire.BytesType + gotRn, gotRt, offset := protowire.ConsumeTag(buf) + AssertConsumeTag("ReadString", gotRn, gotRt, num, _type) + rv, rn, err := Impl.ReadString(buf[offset:], int8(_type)) + if err != nil { + panic(err) + } + rn += offset + if rn != wn || rv != value { + panic(fmt.Errorf("ReadString[%d][%s] != except[%d][%s]", rn, rv, wn, value)) + } + + // utf-8 invalid + // write + value = "'\xff'" + exceptSize = 3 + size = Impl.SizeString(num, value) + if size != exceptSize { + panic(fmt.Errorf("SizeString[%d] != except[%d]", size, exceptSize)) + } + buf = make([]byte, 64) + exceptWs = "fa0f00" + wn = Impl.WriteString(buf, num, value) + ws = fmt.Sprintf("%x", buf[:wn]) + if wn != exceptSize || ws != exceptWs { + panic(fmt.Errorf("WriteString[%d][%s] != except[%d][%s]", wn, ws, size, exceptWs)) + } + + // read + _type = protowire.BytesType + gotRn, gotRt, offset = protowire.ConsumeTag(buf) + AssertConsumeTag("ReadString", gotRn, gotRt, num, _type) + rv, rn, err = Impl.ReadString(buf[offset:], int8(_type)) + if err != nil { + panic(err) + } + rn += offset + if rn != wn || rv != "" { + panic(fmt.Errorf("ReadString[%d][%s] != except[%d][%s]", rn, rv, wn, value)) + } +}