From 0bdd789a7a0cfe6bf4ec67b685f9d4951152893b Mon Sep 17 00:00:00 2001 From: bubbajoe Date: Fri, 11 Jul 2025 08:19:57 +0900 Subject: [PATCH 1/2] add dgate cli to the readme add allow list for proxy fix other misc smells and bugs --- README.md | 17 +++------- internal/config/config.go | 1 + internal/config/loader.go | 10 ++++++ internal/proxy/change_log.go | 2 +- internal/proxy/dynamic_proxy.go | 34 ++++++++++++++++--- internal/proxy/proxy_handler.go | 12 +++---- internal/proxy/proxy_handler_test.go | 12 +++---- internal/proxy/proxy_state.go | 6 ++-- internal/proxy/request_context.go | 6 ++-- internal/proxy/reverse_proxy/reverse_proxy.go | 17 +++------- pkg/util/iplist/iplist.go | 18 ++++++++++ {internal => pkg/util}/pattern/pattern.go | 0 .../util}/pattern/pattern_test.go | 2 +- 13 files changed, 85 insertions(+), 52 deletions(-) rename {internal => pkg/util}/pattern/pattern.go (100%) rename {internal => pkg/util}/pattern/pattern_test.go (97%) diff --git a/README.md b/README.md index 595cb5c..f4bad54 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,12 @@ http://dgate.io/docs/getting-started ```bash # requires go 1.22+ + +# install dgate-server go install github.com/dgate-io/dgate-api/cmd/dgate-server@latest + +# install dgate-cli +go install github.com/dgate-io/dgate-api/cmd/dgate-cli@latest ``` ## Application Architecture @@ -33,15 +38,3 @@ DGate Server is proxy and admin server bundled into one. the admin server is res ### DGate CLI (dgate-cli) DGate CLI is a command-line interface that can be used to interact with the DGate Server. It can be used to deploy modules, manage the state of the cluster, and more. - -#### Proxy Modules - -- Fetch Upstream Module (`fetchUpstream`) - executed before the request is sent to the upstream server. This module is used to decided which upstream server to send the current request to. (Essentially a custom load balancer module) - -- Request Modifier Module (`requestModifier`) - executed before the request is sent to the upstream server. This module is used to modify the request before it is sent to the upstream server. - -- Response Modifier Module (`responseModifier`) - executed after the response is received from the upstream server. This module is used to modify the response before it is sent to the client. - -- Error Handler Module (`errorHandler`) - executed when an error occurs when sending a request to the upstream server. This module is used to modify the response before it is sent to the client. - -- Request Handler Module (`requestHandler`) - executed when a request is received from the client. This module is used to handle arbitrary requests, instead of using an upstream service. diff --git a/internal/config/config.go b/internal/config/config.go index a8491c4..ac311fa 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -46,6 +46,7 @@ type ( DisableXForwardedHeaders bool `koanf:"disable_x_forwarded_headers"` StrictMode bool `koanf:"strict_mode"` XForwardedForDepth int `koanf:"x_forwarded_for_depth"` + AllowList []string `koanf:"allow_list"` // WARN: debug use only InitResources *DGateResources `koanf:"init_resources"` diff --git a/internal/config/loader.go b/internal/config/loader.go index 6c02469..ecb93c9 100644 --- a/internal/config/loader.go +++ b/internal/config/loader.go @@ -12,6 +12,7 @@ import ( "strings" "github.com/dgate-io/dgate-api/pkg/util" + "github.com/dgate-io/dgate-api/pkg/util/iplist" "github.com/hashicorp/raft" kjson "github.com/knadh/koanf/parsers/json" ktoml "github.com/knadh/koanf/parsers/toml" @@ -158,6 +159,15 @@ func LoadConfig(dgateConfigPath string) (*DGateConfig, error) { kDefault(k, "proxy.enable_http2", false) kDefault(k, "proxy.console_log_level", k.Get("log_level")) + if k.Exists("proxy.allow_list") { + var ips []string = k.Get("proxy.allow_list").([]string) + ipList := iplist.NewIPList() + err = ipList.AddAll(ips) + if err != nil { + return nil, errors.New("proxy.allow_list error: " + err.Error()) + } + } + if k.Get("proxy.enable_h2c") == true && k.Get("proxy.enable_http2") == false { return nil, errors.New("proxy: enable_h2c is true but enable_http2 is false") diff --git a/internal/proxy/change_log.go b/internal/proxy/change_log.go index 4aca221..13a64e1 100644 --- a/internal/proxy/change_log.go +++ b/internal/proxy/change_log.go @@ -376,7 +376,7 @@ func (ps *ProxyState) restoreFromChangeLogs(directApply bool) error { } // DISABLED: compaction of change logs needs to have better testing - if len(logs) < 0 { + if false { removed, err := ps.compactChangeLogs(logs) if err != nil { ps.logger.Error("failed to compact state change logs", zap.Error(err)) diff --git a/internal/proxy/dynamic_proxy.go b/internal/proxy/dynamic_proxy.go index a3b7fb7..ebe7eeb 100644 --- a/internal/proxy/dynamic_proxy.go +++ b/internal/proxy/dynamic_proxy.go @@ -7,12 +7,14 @@ import ( "math" "net/http" "os" + "strings" "time" "github.com/dgate-io/dgate-api/internal/router" "github.com/dgate-io/dgate-api/pkg/modules/extractors" "github.com/dgate-io/dgate-api/pkg/spec" "github.com/dgate-io/dgate-api/pkg/typescript" + "github.com/dgate-io/dgate-api/pkg/util/iplist" "github.com/dgate-io/dgate-api/pkg/util/tree/avl" "github.com/dop251/goja" "go.uber.org/zap" @@ -391,11 +393,33 @@ func (ps *ProxyState) Stop() { } } -func (ps *ProxyState) HandleRoute(requestCtxProvider *RequestContextProvider, pattern string) http.HandlerFunc { +func (ps *ProxyState) HandleRoute(ctxProvider *RequestContextProvider, pattern string) http.HandlerFunc { + ipList := iplist.NewIPList() + if len(ps.config.ProxyConfig.AllowList) > 0 { + for _, address := range ps.config.ProxyConfig.AllowList { + if strings.Contains(address, "/") { + if err := ipList.AddCIDRString(address); err != nil { + panic(fmt.Errorf("invalid cidr address in proxy.allow_list: %s", address)) + } + } else { + if err := ipList.AddIPString(address); err != nil { + panic(fmt.Errorf("invalid ip address in proxy.allow_list: %s", address)) + } + } + } + } return func(w http.ResponseWriter, r *http.Request) { - // ctx, cancel := context.WithCancel(requestCtxPrdovider.ctx) - // defer cancel() - ps.ProxyHandler(ps, requestCtxProvider. - CreateRequestContext(requestCtxProvider.ctx, w, r, pattern)) + if ipList.Len() > 0 { + allowed, err := ipList.Contains(r.RemoteAddr) + if err != nil { + ps.logger.Error("Error checking ") + } + + if !allowed { + http.Error(w, "Forbidden", http.StatusForbidden) + } + } + reqContext := ctxProvider.CreateRequestContext(w, r, pattern) + ps.ProxyHandler(ps, reqContext) } } diff --git a/internal/proxy/proxy_handler.go b/internal/proxy/proxy_handler.go index d8be7db..4612eda 100644 --- a/internal/proxy/proxy_handler.go +++ b/internal/proxy/proxy_handler.go @@ -81,7 +81,7 @@ func proxyHandler(ps *ProxyState, reqCtx *RequestContext) { } func handleServiceProxy(ps *ProxyState, reqCtx *RequestContext, modExt ModuleExtractor) { - var host string + var upstreamUrlString string if fetchUpstreamUrl, ok := modExt.FetchUpstreamUrlFunc(); ok { fetchUpstreamStart := time.Now() hostUrl, err := fetchUpstreamUrl(modExt.ModuleContext()) @@ -98,9 +98,9 @@ func handleServiceProxy(ps *ProxyState, reqCtx *RequestContext, modExt ModuleExt util.WriteStatusCodeError(reqCtx.rw, http.StatusInternalServerError) return } - host = hostUrl.String() + upstreamUrlString = hostUrl.String() } else { - if reqCtx.route.Service.URLs == nil || len(reqCtx.route.Service.URLs) == 0 { + if len(reqCtx.route.Service.URLs) == 0 { ps.logger.Error("Error getting service urls", zap.String("service", reqCtx.route.Service.Name), zap.String("namespace", reqCtx.route.Namespace.Name), @@ -108,7 +108,7 @@ func handleServiceProxy(ps *ProxyState, reqCtx *RequestContext, modExt ModuleExt util.WriteStatusCodeError(reqCtx.rw, http.StatusInternalServerError) return } - host = reqCtx.route.Service.URLs[0].String() + upstreamUrlString = reqCtx.route.Service.URLs[0].String() } if reqCtx.route.Service.HideDGateHeaders { @@ -122,10 +122,10 @@ func handleServiceProxy(ps *ProxyState, reqCtx *RequestContext, modExt ModuleExt // downstream headers if ps.debugMode { - reqCtx.rw.Header().Set("X-Upstream-URL", host) + reqCtx.rw.Header().Set("X-Upstream-URL", upstreamUrlString) } } - upstreamUrl, err := url.Parse(host) + upstreamUrl, err := url.Parse(upstreamUrlString) if err != nil { ps.logger.Error("Error parsing upstream url", zap.String("error", err.Error()), diff --git a/internal/proxy/proxy_handler_test.go b/internal/proxy/proxy_handler_test.go index eeb6f9a..99054ab 100644 --- a/internal/proxy/proxy_handler_test.go +++ b/internal/proxy/proxy_handler_test.go @@ -1,7 +1,6 @@ package proxy_test import ( - "context" "errors" "io" "net/http" @@ -57,8 +56,7 @@ func TestProxyHandler_ReverseProxy(t *testing.T) { wr.SetWriteFallThrough() wr.On("Header").Return(http.Header{}) wr.On("Write", mock.Anything).Return(0, nil).Maybe() - reqCtx := reqCtxProvider.CreateRequestContext( - context.Background(), wr, req, "/") + reqCtx := reqCtxProvider.CreateRequestContext(wr, req, "/") modExt := NewMockModuleExtractor() modExt.ConfigureDefaultMock(req, wr, ps, rt) @@ -78,7 +76,7 @@ func TestProxyHandler_ReverseProxy(t *testing.T) { modPool.AssertExpectations(t) modExt.AssertExpectations(t) rpBuilder.AssertExpectations(t) - // rpe.AssertExpectations(t) + rpe.AssertExpectations(t) } } @@ -129,8 +127,7 @@ func TestProxyHandler_ProxyHandler(t *testing.T) { modPool.On("Return", modExt).Return().Once() reqCtxProvider.UpdateModulePool(modPool) - reqCtx := reqCtxProvider.CreateRequestContext( - context.Background(), wr, req, "/") + reqCtx := reqCtxProvider.CreateRequestContext(wr, req, "/") ps.ProxyHandler(ps, reqCtx) wr.AssertExpectations(t) @@ -181,8 +178,7 @@ func TestProxyHandler_ProxyHandlerError(t *testing.T) { modPool.On("Return", modExt).Return().Once() reqCtxProvider := proxy.NewRequestContextProvider(rt, ps) reqCtxProvider.UpdateModulePool(modPool) - reqCtx := reqCtxProvider.CreateRequestContext( - context.Background(), wr, req, "/") + reqCtx := reqCtxProvider.CreateRequestContext(wr, req, "/") ps.ProxyHandler(ps, reqCtx) wr.AssertExpectations(t) diff --git a/internal/proxy/proxy_state.go b/internal/proxy/proxy_state.go index 95e1991..c903fca 100644 --- a/internal/proxy/proxy_state.go +++ b/internal/proxy/proxy_state.go @@ -15,7 +15,6 @@ import ( "time" "github.com/dgate-io/dgate-api/internal/config" - "github.com/dgate-io/dgate-api/internal/pattern" "github.com/dgate-io/dgate-api/internal/proxy/proxy_transport" "github.com/dgate-io/dgate-api/internal/proxy/proxystore" "github.com/dgate-io/dgate-api/internal/proxy/reverse_proxy" @@ -28,6 +27,7 @@ import ( "github.com/dgate-io/dgate-api/pkg/spec" "github.com/dgate-io/dgate-api/pkg/storage" "github.com/dgate-io/dgate-api/pkg/util" + "github.com/dgate-io/dgate-api/pkg/util/pattern" "github.com/dgate-io/dgate-api/pkg/util/tree/avl" "github.com/dop251/goja" "github.com/dop251/goja_nodejs/console" @@ -299,7 +299,7 @@ func (ps *ProxyState) ApplyChangeLog(log *spec.ChangeLog) error { if err != nil { return err } - raftLog := raft.Log{ Data: encodedCL } + raftLog := raft.Log{Data: encodedCL} now := time.Now() future := r.ApplyLog(raftLog, time.Second*15) err = future.Error() @@ -448,7 +448,7 @@ func (ps *ProxyState) getDomainCertificate( var err error var cached bool defer ps.metrics.MeasureCertResolutionDuration( - ctx, start, domain,cached, err, + ctx, start, domain, cached, err, ) certBucket := ps.sharedCache.Bucket("certs") key := fmt.Sprintf("cert:%s:%s:%d", d.Namespace.Name, diff --git a/internal/proxy/request_context.go b/internal/proxy/request_context.go index d542c6b..d1a90ec 100644 --- a/internal/proxy/request_context.go +++ b/internal/proxy/request_context.go @@ -97,7 +97,7 @@ func (reqCtxProvider *RequestContextProvider) ModulePool() ModulePool { } func (reqCtxProvider *RequestContextProvider) CreateRequestContext( - ctx context.Context, rw http.ResponseWriter, + rw http.ResponseWriter, req *http.Request, pattern string, ) *RequestContext { pathParams := make(map[string]string) @@ -107,12 +107,12 @@ func (reqCtxProvider *RequestContextProvider) CreateRequestContext( } } return &RequestContext{ - ctx: ctx, + ctx: reqCtxProvider.ctx, pattern: pattern, params: pathParams, provider: reqCtxProvider, route: reqCtxProvider.route, - req: req.WithContext(ctx), + req: req.WithContext(reqCtxProvider.ctx), rw: spec.NewResponseWriterTracker(rw), } } diff --git a/internal/proxy/reverse_proxy/reverse_proxy.go b/internal/proxy/reverse_proxy/reverse_proxy.go index b409781..3a5d695 100644 --- a/internal/proxy/reverse_proxy/reverse_proxy.go +++ b/internal/proxy/reverse_proxy/reverse_proxy.go @@ -152,22 +152,13 @@ func (b *reverseProxyBuilder) rewriteStripPath(strip bool) RewriteFunc { func (b *reverseProxyBuilder) rewritePreserveHost(preserve bool) RewriteFunc { return func(in, out *http.Request) { - scheme := "http" - out.URL.Host = b.upstreamUrl.Host if preserve { out.Host = in.Host - if out.Host == "" { - out.Host = out.URL.Host + inHost := in.Header.Get("Host") + if inHost == "" { + inHost = in.Host } - if in.TLS != nil { - scheme = "https" - } - } else { - out.Host = out.URL.Host - scheme = b.upstreamUrl.Scheme - } - if out.URL.Scheme == "" { - out.URL.Scheme = scheme + out.Header.Set("Host", inHost) } } } diff --git a/pkg/util/iplist/iplist.go b/pkg/util/iplist/iplist.go index 952f739..c5851b4 100644 --- a/pkg/util/iplist/iplist.go +++ b/pkg/util/iplist/iplist.go @@ -2,8 +2,10 @@ package iplist import ( "bytes" + "errors" "fmt" "net" + "strings" "github.com/dgate-io/dgate-api/pkg/util/linkedlist" ) @@ -20,6 +22,22 @@ func NewIPList() *IPList { } } + +func (l *IPList) AddAll(list []string) error { + for _, ip := range list { + if strings.Contains(ip, "/") { + if err := l.AddCIDRString(ip); err != nil { + return errors.New(ip + ": " + err.Error()) + } + } else { + if err := l.AddIPString(ip); err != nil { + return errors.New(ip + ": " + err.Error()) + } + } + } + return nil +} + func (l *IPList) AddCIDRString(cidr string) error { _, ipn, err := net.ParseCIDR(cidr) if err != nil { diff --git a/internal/pattern/pattern.go b/pkg/util/pattern/pattern.go similarity index 100% rename from internal/pattern/pattern.go rename to pkg/util/pattern/pattern.go diff --git a/internal/pattern/pattern_test.go b/pkg/util/pattern/pattern_test.go similarity index 97% rename from internal/pattern/pattern_test.go rename to pkg/util/pattern/pattern_test.go index fe39d27..1bed596 100644 --- a/internal/pattern/pattern_test.go +++ b/pkg/util/pattern/pattern_test.go @@ -4,7 +4,7 @@ import ( "strings" "testing" - "github.com/dgate-io/dgate-api/internal/pattern" + "github.com/dgate-io/dgate-api/pkg/util/pattern" "github.com/stretchr/testify/assert" ) From 21dfce0ef3b6dfef218a7e802cd96e1980bcacb5 Mon Sep 17 00:00:00 2001 From: bubbajoe Date: Sat, 12 Jul 2025 09:26:16 +0900 Subject: [PATCH 2/2] fix deadlock and fix reverse proxy scheme issue --- .gitignore | 3 +- config.dgate.yaml | 2 +- internal/proxy/change_log.go | 25 +++++---- internal/proxy/proxy_state.go | 53 ++++++++++++++++--- internal/proxy/reverse_proxy/reverse_proxy.go | 16 +++--- .../proxy/reverse_proxy/reverse_proxy_test.go | 51 ++++++++++-------- pkg/util/keylock/keylock.go | 4 +- 7 files changed, 104 insertions(+), 50 deletions(-) diff --git a/.gitignore b/.gitignore index f64bac2..df61396 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,5 @@ .dgate cov.out go.work.sum -.env \ No newline at end of file +.env +coverage.txt \ No newline at end of file diff --git a/config.dgate.yaml b/config.dgate.yaml index dc946f1..6526637 100644 --- a/config.dgate.yaml +++ b/config.dgate.yaml @@ -1,6 +1,6 @@ version: v1 debug: true -log_level: ${LOG_LEVEL:-info} +log_level: ${LOG_LEVEL:-debug} disable_default_namespace: true tags: [debug, local, test] storage: diff --git a/internal/proxy/change_log.go b/internal/proxy/change_log.go index 13a64e1..9f0b176 100644 --- a/internal/proxy/change_log.go +++ b/internal/proxy/change_log.go @@ -14,7 +14,7 @@ import ( ) // processChangeLog - processes a change log and applies the change to the proxy state -func (ps *ProxyState) processChangeLog(cl *spec.ChangeLog, reload, store bool) (err error) { +func (ps *ProxyState) processChangeLog(cl *spec.ChangeLog, reload, store bool) (restartNeeded bool, err error) { if reload { defer func(start time.Time) { if err != nil { @@ -68,12 +68,6 @@ func (ps *ProxyState) processChangeLog(cl *spec.ChangeLog, reload, store bool) ( } else if !ps.changeHash.CompareAndSwap(oldHash, newHash) { goto hash_retry } - } else { - go ps.restartState(func(err error) { - if err != nil { - ps.Stop() - } - }) } }() if cl.Cmd.Resource() == spec.Documents { @@ -99,16 +93,17 @@ func (ps *ProxyState) processChangeLog(cl *spec.ChangeLog, reload, store bool) ( // apply state changes to the proxy if reload { + ps.logger.Debug("Reloading change log", zap.String("id", cl.ID)) overrideReload := cl.Cmd.IsNoop() || ps.pendingChanges if overrideReload || cl.Cmd.Resource().IsRelatedTo(spec.Routes) { if err := ps.storeCachedDocuments(); err != nil { ps.logger.Error("error storing cached documents", zap.Error(err)) - return err + return false, err } ps.logger.Debug("Reloading change log", zap.String("id", cl.ID)) if err = ps.reconfigureState(cl); err != nil { ps.logger.Error("Error registering change log", zap.Error(err)) - return + return false, err } ps.pendingChanges = false } @@ -116,7 +111,7 @@ func (ps *ProxyState) processChangeLog(cl *spec.ChangeLog, reload, store bool) ( ps.pendingChanges = true } - return nil + return restartNeeded, nil } func decode[T any](input any) (T, error) { @@ -353,7 +348,8 @@ func (ps *ProxyState) restoreFromChangeLogs(directApply bool) error { if cl.Cmd.Resource() == spec.Documents { continue } - if err = ps.processChangeLog(cl, false, false); err != nil { + _, err = ps.processChangeLog(cl, false, false) + if err != nil { ps.logger.Error("error processing change log", zap.Bool("skip", ps.debugMode), zap.Error(err), @@ -371,8 +367,11 @@ func (ps *ProxyState) restoreFromChangeLogs(directApply bool) error { if err = ps.reconfigureState(cl); err != nil { return err } - } else if err = ps.processChangeLog(cl, true, false); err != nil { - return err + } else { + _, err = ps.processChangeLog(cl, true, false) + if err != nil { + return err + } } // DISABLED: compaction of change logs needs to have better testing diff --git a/internal/proxy/proxy_state.go b/internal/proxy/proxy_state.go index c903fca..b94d1d2 100644 --- a/internal/proxy/proxy_state.go +++ b/internal/proxy/proxy_state.go @@ -292,9 +292,17 @@ func (ps *ProxyState) ApplyChangeLog(log *spec.ChangeLog) error { if r.State() != raft.Leader { return raft.ErrNotLeader } - if err := ps.processChangeLog(log, true, false); err != nil { + restartNeeded, err := ps.processChangeLog(log, true, false) + if err != nil { return err } + if restartNeeded { + go ps.restartState(func(err error) { + if err != nil { + ps.Stop() + } + }) + } encodedCL, err := json.Marshal(log) if err != nil { return err @@ -316,7 +324,15 @@ func (ps *ProxyState) ApplyChangeLog(log *spec.ChangeLog) error { } return err } else { - return ps.processChangeLog(log, true, true) + restartNeeded, err := ps.processChangeLog(log, true, true) + if restartNeeded { + go ps.restartState(func(err error) { + if err != nil { + ps.Stop() + } + }) + } + return err } } @@ -337,7 +353,6 @@ func (ps *ProxyState) SharedCache() cache.TCache { func (ps *ProxyState) restartState(fn func(error)) { ps.logger.Info("Attempting to restart state...") ps.proxyLock.Lock() - defer ps.proxyLock.Unlock() ps.changeHash.Store(0) ps.pendingChanges = false ps.rm.Empty() @@ -346,6 +361,8 @@ func (ps *ProxyState) restartState(fn func(error)) { ps.routers.Clear() ps.sharedCache.Clear() ps.skdr.Stop() + ps.proxyLock.Unlock() // unlock before resource init and restore + if err := ps.initConfigResources(ps.config.ProxyConfig.InitResources); err != nil { go fn(err) return @@ -371,16 +388,32 @@ func (ps *ProxyState) ReloadState(check bool, logs ...*spec.ChangeLog) error { } } if reload { - return ps.processChangeLog(nil, true, false) + restartNeeded, err := ps.processChangeLog(nil, true, false) + if restartNeeded { + go ps.restartState(func(err error) { + if err != nil { + ps.Stop() + } + }) + } + return err } return nil } func (ps *ProxyState) ProcessChangeLog(log *spec.ChangeLog, reload bool) error { - if err := ps.processChangeLog(log, reload, true); err != nil { + restartNeeded, err := ps.processChangeLog(log, reload, true) + if err != nil { ps.logger.Error("processing error", zap.Error(err)) return err } + if restartNeeded { + go ps.restartState(func(err error) { + if err != nil { + ps.Stop() + } + }) + } return nil } @@ -478,7 +511,15 @@ func (ps *ProxyState) getDomainCertificate( func (ps *ProxyState) initConfigResources(resources *config.DGateResources) error { processCL := func(cl *spec.ChangeLog) error { - return ps.processChangeLog(cl, false, false) + restartNeeded, err := ps.processChangeLog(cl, false, false) + if restartNeeded { + go ps.restartState(func(err error) { + if err != nil { + ps.Stop() + } + }) + } + return err } if resources != nil { numChanges, err := resources.Validate() diff --git a/internal/proxy/reverse_proxy/reverse_proxy.go b/internal/proxy/reverse_proxy/reverse_proxy.go index 3a5d695..485ab83 100644 --- a/internal/proxy/reverse_proxy/reverse_proxy.go +++ b/internal/proxy/reverse_proxy/reverse_proxy.go @@ -133,18 +133,13 @@ func (b *reverseProxyBuilder) rewriteStripPath(strip bool) RewriteFunc { in.URL = b.upstreamUrl if strip { if strings.HasSuffix(proxyPatternPath, "*") { - // this will remove the proxy path before the wildcard from the upstream url - // ex. (upstreamPath: /v1, proxyPattern: '/path/*', reqCall: '/path/test') -> '/v1/test' proxyPattern := strings.TrimSuffix(proxyPatternPath, "*") reqCallNoProxy := strings.TrimPrefix(reqCall, proxyPattern) out.URL.Path = path.Join(upstreamPath, reqCallNoProxy) } else { - // this will remove the proxy path from the upstream url - // ex. (upstreamPath: /v1, proxyPattern: '/path/{id}', reqCall: '/path/1') -> '/v1' out.URL.Path = upstreamPath } } else { - // ex. (upstreamPath: /v1, proxyPattern: '/path/*', reqCall: '/path/test') -> '/v1/path/test' out.URL.Path = path.Join(upstreamPath, reqCall) } } @@ -159,6 +154,9 @@ func (b *reverseProxyBuilder) rewritePreserveHost(preserve bool) RewriteFunc { inHost = in.Host } out.Header.Set("Host", inHost) + } else { + out.Host = b.upstreamUrl.Host + out.Header.Set("Host", b.upstreamUrl.Host) } } } @@ -240,6 +238,10 @@ func (b *reverseProxyBuilder) Build(upstreamUrl *url.URL, proxyPattern string) ( proxy.Transport = b.transport proxy.ErrorLog = b.errorLogger proxy.Rewrite = func(pr *httputil.ProxyRequest) { + // Ensure scheme and host are set correctly + pr.Out.URL.Scheme = b.upstreamUrl.Scheme + pr.Out.URL.Host = b.upstreamUrl.Host + b.rewriteStripPath(b.stripPath)(pr.In, pr.Out) b.rewritePreserveHost(b.preserveHost)(pr.In, pr.Out) b.rewriteDisableQueryParams(b.disableQueryParams)(pr.In, pr.Out) @@ -247,8 +249,8 @@ func (b *reverseProxyBuilder) Build(upstreamUrl *url.URL, proxyPattern string) ( if b.customRewrite != nil { b.customRewrite(pr.In, pr.Out) } - if pr.Out.URL.Path == "/" { - pr.Out.URL.Path = "" + if pr.Out.Host == "" { + pr.Out.Host = upstreamUrl.Host } } return proxy, nil diff --git a/internal/proxy/reverse_proxy/reverse_proxy_test.go b/internal/proxy/reverse_proxy/reverse_proxy_test.go index 6fd33bb..7f3a606 100644 --- a/internal/proxy/reverse_proxy/reverse_proxy_test.go +++ b/internal/proxy/reverse_proxy/reverse_proxy_test.go @@ -17,10 +17,12 @@ import ( type ProxyParams struct { host string - newHost string + expectedHost string upstreamUrl *url.URL - newUpsteamURL *url.URL + expectedUpsteamURL *url.URL + expectedScheme string + expectedPath string proxyPattern string proxyPath string @@ -75,14 +77,21 @@ func testDGateProxyRewrite( mockTp.On("RoundTrip", mock.Anything).Run(func(args mock.Arguments) { req := args.Get(0).(*http.Request) - if req.URL.String() != params.newUpsteamURL.String() { - t.Errorf("FAIL: Expected URL %s, got %s", params.newUpsteamURL, req.URL) + if req.URL.String() != params.expectedUpsteamURL.String() { + t.Errorf("FAIL: Expected URL %s, got %s", params.expectedUpsteamURL, req.URL) } else { t.Logf("PASS: upstreamUrl: %s, proxyPattern: %s, proxyPath: %s, newUpsteamURL: %s", - params.upstreamUrl, params.proxyPattern, params.proxyPath, params.newUpsteamURL) + params.upstreamUrl, params.proxyPattern, params.proxyPath, params.expectedUpsteamURL) } - if params.newHost != "" && req.Host != params.newHost { - t.Errorf("FAIL: Expected Host %s, got %s", params.newHost, req.Host) + if params.expectedHost != "" && + (req.Host != params.expectedHost || req.Header.Get("Host") != params.expectedHost) { + t.Errorf("FAIL: Expected Host %s, got (%s | %s)", params.expectedHost, req.Host, req.Header.Get("Host")) + } + if params.expectedScheme != "" && req.URL.Scheme != params.expectedScheme { + t.Errorf("FAIL: Expected Scheme %s, got %s", params.expectedScheme, req.URL.Scheme) + } + if params.expectedPath != "" && req.URL.Path != params.expectedPath { + t.Errorf("FAIL: Expected Path %s, got %s", params.expectedPath, req.URL.Path) } if rewriteParams.xForwardedHeaders { if req.Header.Get("X-Forwarded-For") == "" { @@ -182,10 +191,10 @@ func TestDGateProxyRewriteStripPath(t *testing.T) { // if proxy pattern is a prefix (ends with *) testDGateProxyRewrite(t, ProxyParams{ host: "test.net", - newHost: "example.com", + expectedHost: "example.com", upstreamUrl: mustParseURL(t, "http://example.com"), - newUpsteamURL: mustParseURL(t, "http://example.com/test/ing"), + expectedUpsteamURL: mustParseURL(t, "http://example.com/test/ing"), proxyPattern: "/test/*", proxyPath: "/test/test/ing", @@ -198,10 +207,10 @@ func TestDGateProxyRewriteStripPath(t *testing.T) { testDGateProxyRewrite(t, ProxyParams{ host: "test.net", - newHost: "example.com", + expectedHost: "example.com", upstreamUrl: mustParseURL(t, "http://example.com/pre"), - newUpsteamURL: mustParseURL(t, "http://example.com/pre"), + expectedUpsteamURL: mustParseURL(t, "http://example.com/pre"), proxyPattern: "/test/*", proxyPath: "/test/", @@ -216,10 +225,10 @@ func TestDGateProxyRewriteStripPath(t *testing.T) { func TestDGateProxyRewritePreserveHost(t *testing.T) { testDGateProxyRewrite(t, ProxyParams{ upstreamUrl: mustParseURL(t, "http://example.com"), - newUpsteamURL: mustParseURL(t, "http://example.com/test"), + expectedUpsteamURL: mustParseURL(t, "http://example.com/test"), host: "test.net", - newHost: "test.net", + expectedHost: "test.net", proxyPattern: "/test", proxyPath: "/test", @@ -234,10 +243,10 @@ func TestDGateProxyRewritePreserveHost(t *testing.T) { func TestDGateProxyRewriteDisableQueryParams(t *testing.T) { testDGateProxyRewrite(t, ProxyParams{ upstreamUrl: mustParseURL(t, "http://example.com"), - newUpsteamURL: mustParseURL(t, "http://example.com/test"), + expectedUpsteamURL: mustParseURL(t, "http://example.com/test"), host: "test.net", - newHost: "example.com", + expectedHost: "example.com", proxyPattern: "/test", proxyPath: "/test?testing=testing", @@ -250,10 +259,10 @@ func TestDGateProxyRewriteDisableQueryParams(t *testing.T) { testDGateProxyRewrite(t, ProxyParams{ upstreamUrl: mustParseURL(t, "http://example.com"), - newUpsteamURL: mustParseURL(t, "http://example.com/test?testing=testing"), + expectedUpsteamURL: mustParseURL(t, "http://example.com/test?testing=testing"), host: "test.net", - newHost: "example.com", + expectedHost: "example.com", proxyPattern: "/test", proxyPath: "/test?testing=testing", @@ -268,10 +277,10 @@ func TestDGateProxyRewriteDisableQueryParams(t *testing.T) { func TestDGateProxyRewriteXForwardedHeaders(t *testing.T) { testDGateProxyRewrite(t, ProxyParams{ upstreamUrl: mustParseURL(t, "http://example.com"), - newUpsteamURL: mustParseURL(t, "http://example.com/test"), + expectedUpsteamURL: mustParseURL(t, "http://example.com/test"), host: "test.net", - newHost: "example.com", + expectedHost: "example.com", proxyPattern: "/test", proxyPath: "/test", @@ -284,10 +293,10 @@ func TestDGateProxyRewriteXForwardedHeaders(t *testing.T) { testDGateProxyRewrite(t, ProxyParams{ upstreamUrl: mustParseURL(t, "http://example.com"), - newUpsteamURL: mustParseURL(t, "http://example.com/test"), + expectedUpsteamURL: mustParseURL(t, "http://example.com/test"), host: "test.net", - newHost: "example.com", + expectedHost: "example.com", proxyPattern: "/test", proxyPath: "/test", diff --git a/pkg/util/keylock/keylock.go b/pkg/util/keylock/keylock.go index 3714eb6..dabbadc 100644 --- a/pkg/util/keylock/keylock.go +++ b/pkg/util/keylock/keylock.go @@ -1,6 +1,8 @@ package keylock -import "sync" +import ( + "sync" +) type KeyLock struct { locks map[string]*sync.RWMutex