From 68821de84d87979207128a35b4df35e01e982cee Mon Sep 17 00:00:00 2001 From: Thomas de Jong Date: Sun, 15 Feb 2026 20:46:15 +0100 Subject: [PATCH] perf: optimize header extraction paths and harden Forwarded parsing --- README.md | 2 +- benchmark_test.go | 20 +++++ forwarded.go | 163 ++++++++++++++++++----------------- forwarded_test.go | 20 +++++ source_chain.go | 2 + source_chain_test.go | 14 +++ source_forwarded.go | 4 +- source_single_header.go | 3 +- source_single_header_test.go | 4 + 9 files changed, 151 insertions(+), 81 deletions(-) diff --git a/README.md b/README.md index 9caeeb5..712ed78 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,7 @@ import "github.com/abczzz13/clientip" ## Compatibility - Core module (`github.com/abczzz13/clientip`) supports Go `1.21+`. -- Optional Prometheus adapter (`github.com/abczzz13/clientip/prometheus`) supports Go `1.21+` and validates in consumer mode on Go `1.21.x` and `1.26.x`. +- Optional Prometheus adapter (`github.com/abczzz13/clientip/prometheus`) has a minimum Go version of `1.21`; CI currently validates consumer mode on Go `1.21.x` and `1.26.x`. - Prometheus client dependency in the adapter is pinned to `github.com/prometheus/client_golang v1.21.1`. ## Quick start diff --git a/benchmark_test.go b/benchmark_test.go index 6e99d90..764031b 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -83,6 +83,26 @@ func BenchmarkExtract_Forwarded_Simple(b *testing.B) { } } +func BenchmarkExtract_Forwarded_WithParams(b *testing.B) { + extractor, _ := New( + TrustLoopbackProxy(), + Priority(SourceForwarded, SourceRemoteAddr), + ) + req := &http.Request{ + RemoteAddr: "127.0.0.1:12345", + Header: make(http.Header), + } + req.Header.Set("Forwarded", `for="[2606:4700:4700::1]:8080";proto=https;by=10.0.0.1, for=1.1.1.1`) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + result, err := extractor.Extract(req) + if err != nil || !result.IP.IsValid() { + b.Fatal("extraction failed") + } + } +} + func BenchmarkExtract_XForwardedFor_LongChain(b *testing.B) { cidrs, _ := ParseCIDRs("10.0.0.0/8") extractor, _ := New( diff --git a/forwarded.go b/forwarded.go index eff7113..fd57ab7 100644 --- a/forwarded.go +++ b/forwarded.go @@ -1,6 +1,7 @@ package clientip import ( + "errors" "fmt" "strings" ) @@ -21,29 +22,25 @@ func (e *Extractor) parseForwardedValues(values []string) ([]string, error) { parts := make([]string, 0, typicalChainCapacity) for _, value := range values { - elements, err := splitForwardedHeaderValue(value, ',') - if err != nil { - return nil, invalidForwardedHeaderError(err) - } - - for _, element := range elements { - element = strings.TrimSpace(element) - if element == "" { - continue - } - - forwardedFor, hasFor, err := parseForwardedElement(element) - if err != nil { - return nil, invalidForwardedHeaderError(err) + err := scanForwardedSegments(value, ',', func(element string) error { + forwardedFor, hasFor, parseErr := parseForwardedElement(element) + if parseErr != nil { + return parseErr } if !hasFor { - continue + return nil } - parts, err = e.appendChainPart(parts, forwardedFor, SourceForwarded) - if err != nil { + var appendErr error + parts, appendErr = e.appendChainPart(parts, forwardedFor, SourceForwarded) + return appendErr + }) + if err != nil { + if errors.Is(err, ErrChainTooLong) { return nil, err } + + return nil, invalidForwardedHeaderError(err) } } @@ -66,51 +63,96 @@ func invalidForwardedHeaderError(err error) error { // case-insensitively, and rejects duplicate for parameters in the same // element. func parseForwardedElement(element string) (forwardedFor string, hasFor bool, err error) { - params, err := splitForwardedHeaderValue(element, ';') - if err != nil { - return "", false, err - } - - for _, param := range params { - param = strings.TrimSpace(param) - if param == "" { - continue - } - + err = scanForwardedSegments(element, ';', func(param string) error { eq := strings.IndexByte(param, '=') if eq <= 0 { - return "", false, fmt.Errorf("invalid forwarded parameter %q", param) + return fmt.Errorf("invalid forwarded parameter %q", param) } key := strings.TrimSpace(param[:eq]) value := strings.TrimSpace(param[eq+1:]) if key == "" { - return "", false, fmt.Errorf("empty parameter key in %q", param) + return fmt.Errorf("empty parameter key in %q", param) } if value == "" { - return "", false, fmt.Errorf("empty parameter value for %q", key) + return fmt.Errorf("empty parameter value for %q", key) } if !strings.EqualFold(key, "for") { - continue + return nil } if hasFor { - return "", false, fmt.Errorf("duplicate for parameter in element %q", element) + return fmt.Errorf("duplicate for parameter in element %q", element) } parsedValue, parseErr := parseForwardedForValue(value) if parseErr != nil { - return "", false, parseErr + return parseErr } forwardedFor = parsedValue hasFor = true + return nil + }) + if err != nil { + return "", false, err } return forwardedFor, hasFor, nil } +// scanForwardedSegments splits value by delimiter while respecting quoted +// segments and escape sequences inside quoted strings. +func scanForwardedSegments(value string, delimiter byte, onSegment func(string) error) error { + start := 0 + inQuotes := false + escaped := false + + for i := 0; i <= len(value); i++ { + if i == len(value) { + if inQuotes { + return fmt.Errorf("unterminated quoted string in %q", value) + } + if escaped { + return fmt.Errorf("unterminated escape in %q", value) + } + } else { + ch := value[i] + + if escaped { + escaped = false + continue + } + + if ch == '\\' && inQuotes { + escaped = true + continue + } + + if ch == '"' { + inQuotes = !inQuotes + continue + } + + if ch != delimiter || inQuotes { + continue + } + } + + segment := strings.TrimSpace(value[start:i]) + if segment != "" { + if err := onSegment(segment); err != nil { + return err + } + } + + start = i + 1 + } + + return nil +} + // parseForwardedForValue parses a Forwarded for parameter value. // // The value may be an unquoted token or a quoted string. For quoted strings, @@ -143,7 +185,17 @@ func unquoteForwardedValue(value string) (string, error) { return "", fmt.Errorf("invalid quoted string %q", value) } + inner := value[1 : len(value)-1] + if strings.IndexByte(inner, '\\') == -1 { + if strings.IndexByte(inner, '"') != -1 { + return "", fmt.Errorf("unexpected quote in %q", value) + } + + return inner, nil + } + var b strings.Builder + b.Grow(len(inner)) escaped := false for i := 1; i < len(value)-1; i++ { @@ -173,46 +225,3 @@ func unquoteForwardedValue(value string) (string, error) { return b.String(), nil } - -// splitForwardedHeaderValue splits value by delimiter while respecting quoted -// segments and escape sequences inside quoted strings. -func splitForwardedHeaderValue(value string, delimiter byte) ([]string, error) { - segments := make([]string, 0, 4) - start := 0 - inQuotes := false - escaped := false - - for i := 0; i < len(value); i++ { - ch := value[i] - - if escaped { - escaped = false - continue - } - - if ch == '\\' && inQuotes { - escaped = true - continue - } - - if ch == '"' { - inQuotes = !inQuotes - continue - } - - if ch == delimiter && !inQuotes { - segments = append(segments, value[start:i]) - start = i + 1 - } - } - - if inQuotes { - return nil, fmt.Errorf("unterminated quoted string in %q", value) - } - if escaped { - return nil, fmt.Errorf("unterminated escape in %q", value) - } - - segments = append(segments, value[start:]) - return segments, nil -} diff --git a/forwarded_test.go b/forwarded_test.go index cc8c4c7..57fe136 100644 --- a/forwarded_test.go +++ b/forwarded_test.go @@ -43,6 +43,21 @@ func TestParseForwardedValues(t *testing.T) { values: []string{"for=\"[2606:4700:4700::1]:8080\""}, want: []string{"[2606:4700:4700::1]:8080"}, }, + { + name: "quoted comma is not treated as element delimiter", + values: []string{"for=\"1.1.1.1,8.8.8.8\";proto=https"}, + want: []string{"1.1.1.1,8.8.8.8"}, + }, + { + name: "quoted semicolon is not treated as param delimiter", + values: []string{"for=\"1.1.1.1;edge\";proto=https"}, + want: []string{"1.1.1.1;edge"}, + }, + { + name: "escaped quote remains inside quoted value", + values: []string{`for="1.1.1.1\";edge";proto=https`}, + want: []string{`1.1.1.1";edge`}, + }, { name: "ignores element without for parameter", values: []string{"proto=https;by=10.0.0.1, for=8.8.8.8"}, @@ -63,6 +78,11 @@ func TestParseForwardedValues(t *testing.T) { values: []string{"for=1.1.1.1;for=8.8.8.8"}, wantErr: ErrInvalidForwardedHeader, }, + { + name: "trailing escape in quoted value", + values: []string{`for="1.1.1.1\`}, + wantErr: ErrInvalidForwardedHeader, + }, } for _, tt := range tests { diff --git a/source_chain.go b/source_chain.go index 6405b91..8ce880c 100644 --- a/source_chain.go +++ b/source_chain.go @@ -4,6 +4,7 @@ import ( "context" "errors" "net/http" + "net/textproto" "strings" ) @@ -44,6 +45,7 @@ func newSingleHeaderSource(extractor *Extractor, headerName string) sourceExtrac return &singleHeaderSource{ extractor: extractor, headerName: headerName, + headerKey: textproto.CanonicalMIMEHeaderKey(headerName), sourceName: sourceName, unavailableErr: &ExtractionError{Err: ErrSourceUnavailable, Source: sourceName}, } diff --git a/source_chain_test.go b/source_chain_test.go index 1298cfb..a8f408d 100644 --- a/source_chain_test.go +++ b/source_chain_test.go @@ -5,6 +5,7 @@ import ( "errors" "net/http" "net/netip" + "net/textproto" "testing" ) @@ -61,6 +62,7 @@ func TestChainedSource_Extract(t *testing.T) { &singleHeaderSource{ extractor: extractor, headerName: "X-Real-IP", + headerKey: textproto.CanonicalMIMEHeaderKey("X-Real-IP"), sourceName: SourceXRealIP, }, &forwardedForSource{extractor: extractor}, @@ -120,6 +122,7 @@ func TestChainedSource_Name(t *testing.T) { &singleHeaderSource{ extractor: extractor, headerName: "X-Real-IP", + headerKey: textproto.CanonicalMIMEHeaderKey("X-Real-IP"), sourceName: SourceXRealIP, }, &remoteAddrSource{extractor: extractor}, @@ -156,6 +159,16 @@ func TestSourceFactories(t *testing.T) { if source.Name() != SourceXRealIP { t.Errorf("newSingleHeaderSource(X-Real-IP) source name = %q, want %q", source.Name(), SourceXRealIP) } + + single, ok := source.(*singleHeaderSource) + if !ok { + t.Fatalf("newSingleHeaderSource() type = %T, want *singleHeaderSource", source) + } + + wantHeaderKey := textproto.CanonicalMIMEHeaderKey("X-Real-IP") + if single.headerKey != wantHeaderKey { + t.Errorf("newSingleHeaderSource(X-Real-IP) headerKey = %q, want %q", single.headerKey, wantHeaderKey) + } }) t.Run("RemoteAddr source", func(t *testing.T) { @@ -244,6 +257,7 @@ func TestSourceUnavailableErrors(t *testing.T) { source := &singleHeaderSource{ extractor: extractor, headerName: "X-Real-IP", + headerKey: textproto.CanonicalMIMEHeaderKey("X-Real-IP"), sourceName: SourceXRealIP, } req := &http.Request{Header: make(http.Header)} diff --git a/source_forwarded.go b/source_forwarded.go index 56bd698..e7fb75a 100644 --- a/source_forwarded.go +++ b/source_forwarded.go @@ -42,7 +42,7 @@ func (s *forwardedSource) sourceUnavailableError() error { } func (s *forwardedSource) Extract(ctx context.Context, r *http.Request) (extractionResult, error) { - forwardedValues := r.Header.Values("Forwarded") + forwardedValues := r.Header["Forwarded"] if len(forwardedValues) == 0 { return extractionResult{}, s.sourceUnavailableError() @@ -71,7 +71,7 @@ func (s *forwardedSource) Extract(ctx context.Context, r *http.Request) (extract } func (s *forwardedForSource) Extract(ctx context.Context, r *http.Request) (extractionResult, error) { - xffValues := r.Header.Values("X-Forwarded-For") + xffValues := r.Header["X-Forwarded-For"] if len(xffValues) == 0 { return extractionResult{}, s.sourceUnavailableError() diff --git a/source_single_header.go b/source_single_header.go index c8fc978..524aceb 100644 --- a/source_single_header.go +++ b/source_single_header.go @@ -8,6 +8,7 @@ import ( type singleHeaderSource struct { extractor *Extractor headerName string + headerKey string sourceName string unavailableErr error } @@ -25,7 +26,7 @@ func (s *singleHeaderSource) sourceUnavailableError() error { } func (s *singleHeaderSource) Extract(ctx context.Context, r *http.Request) (extractionResult, error) { - headerValues := r.Header.Values(s.headerName) + headerValues := r.Header[s.headerKey] if len(headerValues) == 0 { return extractionResult{}, s.sourceUnavailableError() } diff --git a/source_single_header_test.go b/source_single_header_test.go index 4aa4414..6f30df6 100644 --- a/source_single_header_test.go +++ b/source_single_header_test.go @@ -5,6 +5,7 @@ import ( "errors" "net/http" "net/netip" + "net/textproto" "testing" ) @@ -64,6 +65,7 @@ func TestSingleHeaderSource_Extract(t *testing.T) { source := &singleHeaderSource{ extractor: extractor, headerName: tt.headerName, + headerKey: textproto.CanonicalMIMEHeaderKey(tt.headerName), sourceName: NormalizeSourceName(tt.headerName), } @@ -102,6 +104,7 @@ func TestSingleHeaderSource_Extract_MultipleHeaderValues(t *testing.T) { source := &singleHeaderSource{ extractor: extractor, headerName: "X-Real-IP", + headerKey: textproto.CanonicalMIMEHeaderKey("X-Real-IP"), sourceName: NormalizeSourceName("X-Real-IP"), } @@ -161,6 +164,7 @@ func TestSingleHeaderSource_Name(t *testing.T) { source := &singleHeaderSource{ extractor: extractor, headerName: tt.headerName, + headerKey: textproto.CanonicalMIMEHeaderKey(tt.headerName), sourceName: NormalizeSourceName(tt.headerName), }