diff --git a/compiler/cpp/src/thrift/generate/t_go_generator.cc b/compiler/cpp/src/thrift/generate/t_go_generator.cc index 9f8b5de622..721e3a7bc5 100644 --- a/compiler/cpp/src/thrift/generate/t_go_generator.cc +++ b/compiler/cpp/src/thrift/generate/t_go_generator.cc @@ -1769,6 +1769,13 @@ void t_go_generator::generate_go_struct_reader(ostream& out, out << indent() << "func (p *" << tstruct_name << ") " << read_method_name_ << "(ctx context.Context, iprot thrift.TProtocol) error {" << '\n'; indent_up(); + out << indent() << "ctx, err := thrift.CheckRecursionDepth(ctx)" << '\n'; + out << indent() << "if err != nil {" << '\n'; + indent_up(); + out << indent() << "return err" << '\n'; + indent_down(); + out << indent() << "}" << '\n'; + out << indent() << "defer thrift.DecrementRecursionDepth(ctx)" << '\n'; out << indent() << "if _, err := iprot.ReadStructBegin(ctx); err != nil {" << '\n'; indent_up(); out << indent() << "return thrift.PrependError(fmt.Sprintf(\"%T read error: \", p), err)" @@ -1946,6 +1953,13 @@ void t_go_generator::generate_go_struct_writer(ostream& out, vector::const_iterator f_iter; indent(out) << "func (p *" << tstruct_name << ") " << write_method_name_ << "(ctx context.Context, oprot thrift.TProtocol) error {" << '\n'; indent_up(); + out << indent() << "ctx, err := thrift.CheckRecursionDepth(ctx)" << '\n'; + out << indent() << "if err != nil {" << '\n'; + indent_up(); + out << indent() << "return err" << '\n'; + indent_down(); + out << indent() << "}" << '\n'; + out << indent() << "defer thrift.DecrementRecursionDepth(ctx)" << '\n'; if (tstruct->is_union() && uses_countsetfields) { std::string tstruct_name(publicize(tstruct->get_name())); out << indent() << "if c := p.CountSetFields" << tstruct_name << "(); c != 1 {" << '\n'; diff --git a/lib/go/thrift/recursion_tracker.go b/lib/go/thrift/recursion_tracker.go new file mode 100644 index 0000000000..2b7b9ee572 --- /dev/null +++ b/lib/go/thrift/recursion_tracker.go @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "context" + "errors" +) + +type recursionDepthKey struct{} + +type recursionDepthTracker struct { + depth int + limit int +} + +// CheckRecursionDepth increments the per-context struct nesting depth and +// returns an error when it exceeds DEFAULT_RECURSION_DEPTH. The returned +// context must be passed to DecrementRecursionDepth when the struct is done +// being read or written. +func CheckRecursionDepth(ctx context.Context) (context.Context, error) { + tracker, _ := ctx.Value(recursionDepthKey{}).(*recursionDepthTracker) + if tracker == nil { + tracker = &recursionDepthTracker{limit: DEFAULT_RECURSION_DEPTH} + ctx = context.WithValue(ctx, recursionDepthKey{}, tracker) + } + tracker.depth++ + if tracker.depth > tracker.limit { + tracker.depth-- + return ctx, NewTProtocolExceptionWithType(DEPTH_LIMIT, errors.New("maximum recursion depth exceeded")) + } + return ctx, nil +} + +// DecrementRecursionDepth decrements the per-context struct nesting depth. +// It must be called after CheckRecursionDepth returns nil, typically via defer. +func DecrementRecursionDepth(ctx context.Context) { + if tracker, ok := ctx.Value(recursionDepthKey{}).(*recursionDepthTracker); ok && tracker != nil { + tracker.depth-- + } +} diff --git a/lib/go/thrift/recursion_tracker_test.go b/lib/go/thrift/recursion_tracker_test.go new file mode 100644 index 0000000000..1c1c5f39d0 --- /dev/null +++ b/lib/go/thrift/recursion_tracker_test.go @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "context" + "testing" +) + +func TestCheckRecursionDepthShallow(t *testing.T) { + ctx := context.Background() + for i := 0; i < DEFAULT_RECURSION_DEPTH; i++ { + var err error + ctx, err = CheckRecursionDepth(ctx) + if err != nil { + t.Fatalf("unexpected error at depth %d: %v", i+1, err) + } + } +} + +func TestCheckRecursionDepthExceeded(t *testing.T) { + ctx := context.Background() + var err error + for i := 0; i < DEFAULT_RECURSION_DEPTH; i++ { + ctx, err = CheckRecursionDepth(ctx) + if err != nil { + t.Fatalf("unexpected error at depth %d: %v", i+1, err) + } + } + // One more should fail + _, err = CheckRecursionDepth(ctx) + if err == nil { + t.Fatal("expected DEPTH_LIMIT error but got nil") + } + protoErr, ok := err.(TProtocolException) + if !ok || protoErr.TypeId() != DEPTH_LIMIT { + t.Fatalf("expected DEPTH_LIMIT TProtocolException, got: %T %v", err, err) + } +} + +func TestDecrementRecursionDepth(t *testing.T) { + ctx := context.Background() + var err error + for i := 0; i < DEFAULT_RECURSION_DEPTH; i++ { + ctx, err = CheckRecursionDepth(ctx) + if err != nil { + t.Fatalf("unexpected error at depth %d: %v", i+1, err) + } + } + // Decrement back to zero + for i := 0; i < DEFAULT_RECURSION_DEPTH; i++ { + DecrementRecursionDepth(ctx) + } + // Should be able to go deep again + for i := 0; i < DEFAULT_RECURSION_DEPTH; i++ { + ctx, err = CheckRecursionDepth(ctx) + if err != nil { + t.Fatalf("unexpected error after decrement at depth %d: %v", i+1, err) + } + } +}