Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import "github.com/abczzz13/clientip"
## Compatibility

- Core module (`github.com/abczzz13/clientip`) supports Go `1.21+`.
- Optional Prometheus adapter (`github.com/abczzz13/clientip/prometheus`) supports Go `1.21+` and validates in consumer mode on Go `1.21.x` and `1.26.x`.
- Optional Prometheus adapter (`github.com/abczzz13/clientip/prometheus`) has a minimum Go version of `1.21`; CI currently validates consumer mode on Go `1.21.x` and `1.26.x`.
- Prometheus client dependency in the adapter is pinned to `github.com/prometheus/client_golang v1.21.1`.

## Quick start
Expand Down
20 changes: 20 additions & 0 deletions benchmark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,26 @@ func BenchmarkExtract_Forwarded_Simple(b *testing.B) {
}
}

func BenchmarkExtract_Forwarded_WithParams(b *testing.B) {
extractor, _ := New(
TrustLoopbackProxy(),
Priority(SourceForwarded, SourceRemoteAddr),
)
req := &http.Request{
RemoteAddr: "127.0.0.1:12345",
Header: make(http.Header),
}
req.Header.Set("Forwarded", `for="[2606:4700:4700::1]:8080";proto=https;by=10.0.0.1, for=1.1.1.1`)

b.ResetTimer()
for i := 0; i < b.N; i++ {
result, err := extractor.Extract(req)
if err != nil || !result.IP.IsValid() {
b.Fatal("extraction failed")
}
}
}

func BenchmarkExtract_XForwardedFor_LongChain(b *testing.B) {
cidrs, _ := ParseCIDRs("10.0.0.0/8")
extractor, _ := New(
Expand Down
163 changes: 86 additions & 77 deletions forwarded.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package clientip

import (
"errors"
"fmt"
"strings"
)
Expand All @@ -21,29 +22,25 @@ func (e *Extractor) parseForwardedValues(values []string) ([]string, error) {
parts := make([]string, 0, typicalChainCapacity)

for _, value := range values {
elements, err := splitForwardedHeaderValue(value, ',')
if err != nil {
return nil, invalidForwardedHeaderError(err)
}

for _, element := range elements {
element = strings.TrimSpace(element)
if element == "" {
continue
}

forwardedFor, hasFor, err := parseForwardedElement(element)
if err != nil {
return nil, invalidForwardedHeaderError(err)
err := scanForwardedSegments(value, ',', func(element string) error {
forwardedFor, hasFor, parseErr := parseForwardedElement(element)
if parseErr != nil {
return parseErr
}
if !hasFor {
continue
return nil
}

parts, err = e.appendChainPart(parts, forwardedFor, SourceForwarded)
if err != nil {
var appendErr error
parts, appendErr = e.appendChainPart(parts, forwardedFor, SourceForwarded)
return appendErr
})
if err != nil {
if errors.Is(err, ErrChainTooLong) {
return nil, err
}

return nil, invalidForwardedHeaderError(err)
}
}

Expand All @@ -66,51 +63,96 @@ func invalidForwardedHeaderError(err error) error {
// case-insensitively, and rejects duplicate for parameters in the same
// element.
func parseForwardedElement(element string) (forwardedFor string, hasFor bool, err error) {
params, err := splitForwardedHeaderValue(element, ';')
if err != nil {
return "", false, err
}

for _, param := range params {
param = strings.TrimSpace(param)
if param == "" {
continue
}

err = scanForwardedSegments(element, ';', func(param string) error {
eq := strings.IndexByte(param, '=')
if eq <= 0 {
return "", false, fmt.Errorf("invalid forwarded parameter %q", param)
return fmt.Errorf("invalid forwarded parameter %q", param)
}

key := strings.TrimSpace(param[:eq])
value := strings.TrimSpace(param[eq+1:])
if key == "" {
return "", false, fmt.Errorf("empty parameter key in %q", param)
return fmt.Errorf("empty parameter key in %q", param)
}
if value == "" {
return "", false, fmt.Errorf("empty parameter value for %q", key)
return fmt.Errorf("empty parameter value for %q", key)
}

if !strings.EqualFold(key, "for") {
continue
return nil
}

if hasFor {
return "", false, fmt.Errorf("duplicate for parameter in element %q", element)
return fmt.Errorf("duplicate for parameter in element %q", element)
}

parsedValue, parseErr := parseForwardedForValue(value)
if parseErr != nil {
return "", false, parseErr
return parseErr
}

forwardedFor = parsedValue
hasFor = true
return nil
})
if err != nil {
return "", false, err
}

return forwardedFor, hasFor, nil
}

// scanForwardedSegments splits value by delimiter while respecting quoted
// segments and escape sequences inside quoted strings.
func scanForwardedSegments(value string, delimiter byte, onSegment func(string) error) error {
start := 0
inQuotes := false
escaped := false

for i := 0; i <= len(value); i++ {
if i == len(value) {
if inQuotes {
return fmt.Errorf("unterminated quoted string in %q", value)
}
if escaped {
return fmt.Errorf("unterminated escape in %q", value)
}
} else {
ch := value[i]

if escaped {
escaped = false
continue
}

if ch == '\\' && inQuotes {
escaped = true
continue
}

if ch == '"' {
inQuotes = !inQuotes
continue
}

if ch != delimiter || inQuotes {
continue
}
}

segment := strings.TrimSpace(value[start:i])
if segment != "" {
if err := onSegment(segment); err != nil {
return err
}
}

start = i + 1
}

return nil
}

// parseForwardedForValue parses a Forwarded for parameter value.
//
// The value may be an unquoted token or a quoted string. For quoted strings,
Expand Down Expand Up @@ -143,7 +185,17 @@ func unquoteForwardedValue(value string) (string, error) {
return "", fmt.Errorf("invalid quoted string %q", value)
}

inner := value[1 : len(value)-1]
if strings.IndexByte(inner, '\\') == -1 {
if strings.IndexByte(inner, '"') != -1 {
return "", fmt.Errorf("unexpected quote in %q", value)
}

return inner, nil
}

var b strings.Builder
b.Grow(len(inner))
escaped := false

for i := 1; i < len(value)-1; i++ {
Expand Down Expand Up @@ -173,46 +225,3 @@ func unquoteForwardedValue(value string) (string, error) {

return b.String(), nil
}

// splitForwardedHeaderValue splits value by delimiter while respecting quoted
// segments and escape sequences inside quoted strings.
func splitForwardedHeaderValue(value string, delimiter byte) ([]string, error) {
segments := make([]string, 0, 4)
start := 0
inQuotes := false
escaped := false

for i := 0; i < len(value); i++ {
ch := value[i]

if escaped {
escaped = false
continue
}

if ch == '\\' && inQuotes {
escaped = true
continue
}

if ch == '"' {
inQuotes = !inQuotes
continue
}

if ch == delimiter && !inQuotes {
segments = append(segments, value[start:i])
start = i + 1
}
}

if inQuotes {
return nil, fmt.Errorf("unterminated quoted string in %q", value)
}
if escaped {
return nil, fmt.Errorf("unterminated escape in %q", value)
}

segments = append(segments, value[start:])
return segments, nil
}
20 changes: 20 additions & 0 deletions forwarded_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,21 @@ func TestParseForwardedValues(t *testing.T) {
values: []string{"for=\"[2606:4700:4700::1]:8080\""},
want: []string{"[2606:4700:4700::1]:8080"},
},
{
name: "quoted comma is not treated as element delimiter",
values: []string{"for=\"1.1.1.1,8.8.8.8\";proto=https"},
want: []string{"1.1.1.1,8.8.8.8"},
},
{
name: "quoted semicolon is not treated as param delimiter",
values: []string{"for=\"1.1.1.1;edge\";proto=https"},
want: []string{"1.1.1.1;edge"},
},
{
name: "escaped quote remains inside quoted value",
values: []string{`for="1.1.1.1\";edge";proto=https`},
want: []string{`1.1.1.1";edge`},
},
{
name: "ignores element without for parameter",
values: []string{"proto=https;by=10.0.0.1, for=8.8.8.8"},
Expand All @@ -63,6 +78,11 @@ func TestParseForwardedValues(t *testing.T) {
values: []string{"for=1.1.1.1;for=8.8.8.8"},
wantErr: ErrInvalidForwardedHeader,
},
{
name: "trailing escape in quoted value",
values: []string{`for="1.1.1.1\`},
wantErr: ErrInvalidForwardedHeader,
},
}

for _, tt := range tests {
Expand Down
2 changes: 2 additions & 0 deletions source_chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"net/http"
"net/textproto"
"strings"
)

Expand Down Expand Up @@ -44,6 +45,7 @@ func newSingleHeaderSource(extractor *Extractor, headerName string) sourceExtrac
return &singleHeaderSource{
extractor: extractor,
headerName: headerName,
headerKey: textproto.CanonicalMIMEHeaderKey(headerName),
sourceName: sourceName,
unavailableErr: &ExtractionError{Err: ErrSourceUnavailable, Source: sourceName},
}
Expand Down
14 changes: 14 additions & 0 deletions source_chain_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"net/http"
"net/netip"
"net/textproto"
"testing"
)

Expand Down Expand Up @@ -61,6 +62,7 @@ func TestChainedSource_Extract(t *testing.T) {
&singleHeaderSource{
extractor: extractor,
headerName: "X-Real-IP",
headerKey: textproto.CanonicalMIMEHeaderKey("X-Real-IP"),
sourceName: SourceXRealIP,
},
&forwardedForSource{extractor: extractor},
Expand Down Expand Up @@ -120,6 +122,7 @@ func TestChainedSource_Name(t *testing.T) {
&singleHeaderSource{
extractor: extractor,
headerName: "X-Real-IP",
headerKey: textproto.CanonicalMIMEHeaderKey("X-Real-IP"),
sourceName: SourceXRealIP,
},
&remoteAddrSource{extractor: extractor},
Expand Down Expand Up @@ -156,6 +159,16 @@ func TestSourceFactories(t *testing.T) {
if source.Name() != SourceXRealIP {
t.Errorf("newSingleHeaderSource(X-Real-IP) source name = %q, want %q", source.Name(), SourceXRealIP)
}

single, ok := source.(*singleHeaderSource)
if !ok {
t.Fatalf("newSingleHeaderSource() type = %T, want *singleHeaderSource", source)
}

wantHeaderKey := textproto.CanonicalMIMEHeaderKey("X-Real-IP")
if single.headerKey != wantHeaderKey {
t.Errorf("newSingleHeaderSource(X-Real-IP) headerKey = %q, want %q", single.headerKey, wantHeaderKey)
}
})

t.Run("RemoteAddr source", func(t *testing.T) {
Expand Down Expand Up @@ -244,6 +257,7 @@ func TestSourceUnavailableErrors(t *testing.T) {
source := &singleHeaderSource{
extractor: extractor,
headerName: "X-Real-IP",
headerKey: textproto.CanonicalMIMEHeaderKey("X-Real-IP"),
sourceName: SourceXRealIP,
}
req := &http.Request{Header: make(http.Header)}
Expand Down
4 changes: 2 additions & 2 deletions source_forwarded.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func (s *forwardedSource) sourceUnavailableError() error {
}

func (s *forwardedSource) Extract(ctx context.Context, r *http.Request) (extractionResult, error) {
forwardedValues := r.Header.Values("Forwarded")
forwardedValues := r.Header["Forwarded"]

if len(forwardedValues) == 0 {
return extractionResult{}, s.sourceUnavailableError()
Expand Down Expand Up @@ -71,7 +71,7 @@ func (s *forwardedSource) Extract(ctx context.Context, r *http.Request) (extract
}

func (s *forwardedForSource) Extract(ctx context.Context, r *http.Request) (extractionResult, error) {
xffValues := r.Header.Values("X-Forwarded-For")
xffValues := r.Header["X-Forwarded-For"]

if len(xffValues) == 0 {
return extractionResult{}, s.sourceUnavailableError()
Expand Down
3 changes: 2 additions & 1 deletion source_single_header.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
type singleHeaderSource struct {
extractor *Extractor
headerName string
headerKey string
sourceName string
unavailableErr error
}
Expand All @@ -25,7 +26,7 @@ func (s *singleHeaderSource) sourceUnavailableError() error {
}

func (s *singleHeaderSource) Extract(ctx context.Context, r *http.Request) (extractionResult, error) {
headerValues := r.Header.Values(s.headerName)
headerValues := r.Header[s.headerKey]
if len(headerValues) == 0 {
return extractionResult{}, s.sourceUnavailableError()
}
Expand Down
Loading