From 582ff0a5e4293869f31620461f63a76fc875a645 Mon Sep 17 00:00:00 2001 From: ParsaKSH Date: Mon, 30 Mar 2026 03:09:01 +0330 Subject: [PATCH 1/4] Revert checker and keepalive system - Add connection pool now, in split packet strategy, we have 3 reserve connection --- cmd/centralserver/main.go | 4 - cmd/slipstreamplus/main.go | 3 - internal/health/checker.go | 27 +--- internal/tunnel/pool.go | 261 ++++++++++++++++++++----------------- 4 files changed, 144 insertions(+), 151 deletions(-) diff --git a/cmd/centralserver/main.go b/cmd/centralserver/main.go index 4a8f7ab..99a82cc 100644 --- a/cmd/centralserver/main.go +++ b/cmd/centralserver/main.go @@ -311,10 +311,6 @@ func (cs *centralServer) parseHeader(hdr [tunnel.HeaderSize]byte, conn net.Conn, } func (cs *centralServer) dispatchFrame(frame *tunnel.Frame, source *sourceConn) { - // ConnID 0 is a keepalive from TunnelPool — ignore silently - if frame.ConnID == 0 { - return - } if frame.IsSYN() { cs.handleSYN(frame, source) return diff --git a/cmd/slipstreamplus/main.go b/cmd/slipstreamplus/main.go index 1490131..780fca9 100644 --- a/cmd/slipstreamplus/main.go +++ b/cmd/slipstreamplus/main.go @@ -99,9 +99,6 @@ func main() { if isPacketSplit { tunnelPool = tunnel.NewTunnelPool(mgr) tunnelPool.Start() - // Tell health checker to use TunnelPool for probes instead of - // creating separate TCP connections that interfere with the tunnel. - checker.SetTunnelPool(tunnelPool) log.Printf("Packet-split mode: central_server=%s, chunk_size=%d", cfg.CentralServer.Address, cfg.CentralServer.ChunkSize) } diff --git a/internal/health/checker.go b/internal/health/checker.go index ef8f834..4afbbb2 100644 --- a/internal/health/checker.go +++ b/internal/health/checker.go @@ -25,18 +25,12 @@ import ( // Latency is only set from successful tunnel probes (real RTT). const maxConsecutiveFailures = 3 -// tunnelHealthChecker is implemented by TunnelPool to report active tunnels. -type tunnelHealthChecker interface { - HasActiveTunnel(instID int) bool -} - type Checker struct { manager *engine.Manager interval time.Duration timeout time.Duration target string // health_check.target (e.g. "google.com") packetSplit bool // true when strategy=packet_split - tunnelPool tunnelHealthChecker // set in packet_split mode ctx context.Context cancel context.CancelFunc @@ -81,14 +75,6 @@ func (c *Checker) SetPacketSplit(enabled bool) { c.packetSplit = enabled } -// SetTunnelPool provides the TunnelPool so the health checker can skip -// creating separate TCP connections when the TunnelPool already has an -// active connection to an instance. This prevents interference with the -// persistent tunnel connections used in packet_split mode. -func (c *Checker) SetTunnelPool(pool tunnelHealthChecker) { - c.tunnelPool = pool -} - func (c *Checker) Stop() { c.cancel() } @@ -190,22 +176,13 @@ func (c *Checker) checkOne(inst *engine.Instance) { } // Step 3: End-to-end probe. - // In packet_split mode: if TunnelPool has an active connection, skip the - // separate framing probe entirely — the TunnelPool's keepalive + stale - // detection provides continuous health monitoring. Opening a separate TCP - // connection can interfere with the DNS tunnel's persistent connection. - // Only fall back to probeFramingProtocol for initial validation (before - // TunnelPool has connected). + // In packet_split mode: test if instance's upstream speaks our framing protocol. // In normal mode: full SOCKS5 CONNECT + HTTP through the tunnel. if c.target != "" && inst.Config.Mode != "ssh" { var e2eRtt time.Duration var e2eErr error - if c.packetSplit && c.tunnelPool != nil && c.tunnelPool.HasActiveTunnel(inst.ID()) { - // TunnelPool is connected — tunnel is working, skip separate probe - e2eRtt = rtt - } else if c.packetSplit { - // TunnelPool not yet connected — need separate probe for initial health check + if c.packetSplit { e2eRtt, e2eErr = c.probeFramingProtocol(inst) } else { e2eRtt, e2eErr = c.probeEndToEnd(inst) diff --git a/internal/tunnel/pool.go b/internal/tunnel/pool.go index ce44d11..aa86678 100644 --- a/internal/tunnel/pool.go +++ b/internal/tunnel/pool.go @@ -18,48 +18,46 @@ import ( const writeTimeout = 10 * time.Second // staleThreshold: if we've sent data but haven't received anything -// in this long, the connection is considered half-dead and will be -// force-closed by refreshConnections (which triggers reconnect). +// in this long, the connection is considered half-dead. const staleThreshold = 15 * time.Second -// maxTunnelAge: force-reconnect tunnels older than this, even if they -// appear healthy. Prevents long-lived connection degradation in DNS tunnels. -const maxTunnelAge = 3 * time.Minute - -// keepaliveInterval: how often to send keepalive frames to detect dead tunnels. -const keepaliveInterval = 10 * time.Second +// ConnsPerInstance is the number of persistent connections to maintain +// per healthy instance. When one dies, it's replaced on the next refresh. +// Multiple connections provide redundancy — if one degrades, others serve traffic. +const ConnsPerInstance = 3 // TunnelConn wraps a persistent TCP connection to a single instance. type TunnelConn struct { - inst *engine.Instance - mu sync.Mutex - conn net.Conn - writeMu sync.Mutex - closed bool - createdAt time.Time + inst *engine.Instance + mu sync.Mutex + conn net.Conn + writeMu sync.Mutex + closed bool lastRead atomic.Int64 // unix millis of last successful read lastWrite atomic.Int64 // unix millis of last successful write } -// TunnelPool manages ONE persistent connection per healthy instance. +// TunnelPool manages multiple persistent connections per healthy instance. +// All connections serve the same handler map, providing redundancy. type TunnelPool struct { - mgr *engine.Manager - mu sync.RWMutex - tunnels map[int]*TunnelConn - handlers sync.Map // ConnID (uint32) → chan *Frame - stopCh chan struct{} - wg sync.WaitGroup - ready chan struct{} // closed when at least one tunnel is connected - readyOnce sync.Once + mgr *engine.Manager + mu sync.RWMutex + conns map[int][]*TunnelConn // instance ID → pool of connections + handlers sync.Map // ConnID (uint32) → chan *Frame + stopCh chan struct{} + wg sync.WaitGroup + ready chan struct{} + readyOnce sync.Once + roundRobin atomic.Uint64 } func NewTunnelPool(mgr *engine.Manager) *TunnelPool { return &TunnelPool{ - mgr: mgr, - tunnels: make(map[int]*TunnelConn), - stopCh: make(chan struct{}), - ready: make(chan struct{}), + mgr: mgr, + conns: make(map[int][]*TunnelConn), + stopCh: make(chan struct{}), + ready: make(chan struct{}), } } @@ -76,17 +74,28 @@ func (p *TunnelPool) WaitReady(ctx context.Context) bool { // HasTunnels returns true if at least one tunnel is currently connected. func (p *TunnelPool) HasTunnels() bool { p.mu.RLock() - n := len(p.tunnels) - p.mu.RUnlock() - return n > 0 + defer p.mu.RUnlock() + for _, conns := range p.conns { + for _, tc := range conns { + if !tc.closed { + return true + } + } + } + return false } -// HasActiveTunnel returns true if the given instance has a non-closed tunnel. +// HasActiveTunnel returns true if the given instance has at least one non-closed tunnel. func (p *TunnelPool) HasActiveTunnel(instID int) bool { p.mu.RLock() - tc, ok := p.tunnels[instID] + conns := p.conns[instID] p.mu.RUnlock() - return ok && !tc.closed + for _, tc := range conns { + if !tc.closed { + return true + } + } + return false } func (p *TunnelPool) Start() { @@ -95,33 +104,31 @@ func (p *TunnelPool) Start() { p.wg.Add(1) go func() { defer p.wg.Done() - refreshTicker := time.NewTicker(5 * time.Second) - keepaliveTicker := time.NewTicker(keepaliveInterval) - defer refreshTicker.Stop() - defer keepaliveTicker.Stop() + ticker := time.NewTicker(5 * time.Second) + defer ticker.Stop() for { select { case <-p.stopCh: return - case <-refreshTicker.C: + case <-ticker.C: p.refreshConnections() - case <-keepaliveTicker.C: - p.sendKeepalives() } } }() - log.Printf("[tunnel-pool] started (stale_threshold=%s, max_age=%s, keepalive=%s)", - staleThreshold, maxTunnelAge, keepaliveInterval) + log.Printf("[tunnel-pool] started (conns_per_instance=%d, stale_threshold=%s)", + ConnsPerInstance, staleThreshold) } func (p *TunnelPool) Stop() { close(p.stopCh) p.mu.Lock() - for _, tc := range p.tunnels { - tc.close() + for _, conns := range p.conns { + for _, tc := range conns { + tc.close() + } } - p.tunnels = make(map[int]*TunnelConn) + p.conns = make(map[int][]*TunnelConn) p.mu.Unlock() p.wg.Wait() log.Printf("[tunnel-pool] stopped") @@ -140,49 +147,39 @@ func (p *TunnelPool) UnregisterConn(connID uint32) { } } +// SendFrame sends a frame through one of the instance's pooled connections. +// Uses round-robin with automatic failover to the next connection on error. func (p *TunnelPool) SendFrame(instID int, f *Frame) error { p.mu.RLock() - tc, ok := p.tunnels[instID] + conns := p.conns[instID] + n := len(conns) p.mu.RUnlock() - if !ok { + if n == 0 { return fmt.Errorf("no tunnel for instance %d", instID) } - return tc.writeFrame(f) -} - -// sendKeepalives writes a small keepalive frame through each tunnel. -// This detects dead tunnels faster than waiting for stale detection, -// and keeps DNS tunnel sessions alive. -func (p *TunnelPool) sendKeepalives() { - p.mu.RLock() - tunnels := make([]*TunnelConn, 0, len(p.tunnels)) - for _, tc := range p.tunnels { - tunnels = append(tunnels, tc) - } - p.mu.RUnlock() - - keepalive := &Frame{ - ConnID: 0, // reserved — CentralServer ignores ConnID 0 - SeqNum: 0, - Flags: FlagData, - Payload: nil, - } - - for _, tc := range tunnels { - if err := tc.writeFrame(keepalive); err != nil { - log.Printf("[tunnel-pool] instance %d: keepalive failed: %v", tc.inst.ID(), err) + // Round-robin with fallback: try each connection once + start := int(p.roundRobin.Add(1)) % n + for i := 0; i < n; i++ { + tc := conns[(start+i)%n] + if tc.closed { + continue + } + if err := tc.writeFrame(f); err != nil { + // Mark dead — will be cleaned up and replaced by refreshConnections tc.close() + continue } + return nil } + return fmt.Errorf("all %d tunnels for instance %d failed", n, instID) } -// refreshConnections reconnects dead/stale/old tunnels and adds new ones. +// refreshConnections removes dead/stale connections and replenishes pools. func (p *TunnelPool) refreshConnections() { healthy := p.mgr.HealthyInstances() - now := time.Now() - nowMs := now.UnixMilli() + nowMs := time.Now().UnixMilli() p.mu.Lock() defer p.mu.Unlock() @@ -192,58 +189,82 @@ func (p *TunnelPool) refreshConnections() { activeIDs[inst.ID()] = true } - for id, tc := range p.tunnels { - shouldRemove := false - reason := "" - + // Phase 1: Remove dead connections, clean up unhealthy instances + for id, conns := range p.conns { if !activeIDs[id] { - shouldRemove = true - reason = "instance unhealthy" - } else if tc.closed { - shouldRemove = true - reason = "connection closed" - } else if now.Sub(tc.createdAt) > maxTunnelAge { - // Force-reconnect old connections to prevent DNS tunnel degradation - shouldRemove = true - reason = fmt.Sprintf("max age exceeded (%s)", now.Sub(tc.createdAt).Round(time.Second)) - } else { - // Detect half-dead connections - lastW := tc.lastWrite.Load() - lastR := tc.lastRead.Load() - if lastW > 0 && (nowMs-lastR) > staleThreshold.Milliseconds() { + for _, tc := range conns { + tc.close() + } + delete(p.conns, id) + continue + } + + // Filter out closed/stale connections (new slice to avoid sharing backing array) + var alive []*TunnelConn + for _, tc := range conns { + shouldRemove := false + if tc.closed { shouldRemove = true - reason = fmt.Sprintf("stale (last_read=%dms ago, last_write=%dms ago)", nowMs-lastR, nowMs-lastW) + } else { + lastW := tc.lastWrite.Load() + lastR := tc.lastRead.Load() + if lastW > 0 && (nowMs-lastR) > staleThreshold.Milliseconds() { + log.Printf("[tunnel-pool] instance %d: stale conn (last_read=%dms ago), replacing", + id, nowMs-lastR) + shouldRemove = true + } + } + + if shouldRemove { + tc.close() + } else { + alive = append(alive, tc) } } - if shouldRemove { - log.Printf("[tunnel-pool] instance %d: removing (%s)", id, reason) - tc.close() - delete(p.tunnels, id) + if len(alive) == 0 { + delete(p.conns, id) + } else { + p.conns[id] = alive } } + // Phase 2: Replenish — ensure each healthy instance has ConnsPerInstance connections for _, inst := range healthy { if inst.Config.Mode == "ssh" { continue } - if _, exists := p.tunnels[inst.ID()]; exists { - continue - } - tc, err := p.connectInstance(inst) - if err != nil { - continue + id := inst.ID() + current := len(p.conns[id]) + need := ConnsPerInstance - current + + if need > 0 { + added := 0 + for i := 0; i < need; i++ { + tc, err := p.connectInstance(inst) + if err != nil { + break // dial failed, stop trying this instance + } + p.conns[id] = append(p.conns[id], tc) + added++ + } + if added > 0 { + log.Printf("[tunnel-pool] instance %d: +%d connections (now %d/%d)", + id, added, len(p.conns[id]), ConnsPerInstance) + } } - p.tunnels[inst.ID()] = tc - log.Printf("[tunnel-pool] connected to instance %d (%s:%d)", - inst.ID(), inst.Config.Domain, inst.Config.Port) } - // Signal readiness once we have at least one tunnel - if len(p.tunnels) > 0 { + // Signal readiness + total := 0 + for _, conns := range p.conns { + total += len(conns) + } + if total > 0 { p.readyOnce.Do(func() { close(p.ready) - log.Printf("[tunnel-pool] ready (%d tunnels connected)", len(p.tunnels)) + log.Printf("[tunnel-pool] ready (%d total connections across %d instances)", + total, len(p.conns)) }) } } @@ -261,21 +282,20 @@ func (p *TunnelPool) connectInstance(inst *engine.Instance) (*TunnelConn, error) } now := time.Now() - tunnel := &TunnelConn{ - inst: inst, - conn: conn, - createdAt: now, + tc := &TunnelConn{ + inst: inst, + conn: conn, } - tunnel.lastRead.Store(now.UnixMilli()) - tunnel.lastWrite.Store(0) + tc.lastRead.Store(now.UnixMilli()) + tc.lastWrite.Store(0) p.wg.Add(1) go func() { defer p.wg.Done() - p.readLoop(tunnel) + p.readLoop(tc) }() - return tunnel, nil + return tc, nil } // readLoop reads frames WITHOUT any read deadline. @@ -283,6 +303,9 @@ func (p *TunnelPool) connectInstance(inst *engine.Instance) (*TunnelConn, error) // header reads corrupt the entire frame stream (all subsequent reads get // misaligned). The only way to stop readLoop is to close the connection // (via tc.close()), which makes ReadFrame return an error. +// +// Each connection in the pool has its own readLoop. All readLoops dispatch +// to the same handlers map, so frames from any connection reach the right handler. func (p *TunnelPool) readLoop(tc *TunnelConn) { for { select { @@ -294,7 +317,7 @@ func (p *TunnelPool) readLoop(tc *TunnelConn) { frame, err := ReadFrame(tc.conn) if err != nil { if err != io.EOF && !isClosedErr(err) { - log.Printf("[tunnel-pool] instance %d read error: %v", tc.inst.ID(), err) + log.Printf("[tunnel-pool] instance %d: read error: %v", tc.inst.ID(), err) } tc.close() return @@ -302,7 +325,7 @@ func (p *TunnelPool) readLoop(tc *TunnelConn) { tc.lastRead.Store(time.Now().UnixMilli()) - // Skip keepalive responses (ConnID 0) + // ConnID 0 is reserved (keepalive) — skip if frame.ConnID == 0 { continue } From 9083b9c7a8d4477f654e1d01198b1bf842ff9d4d Mon Sep 17 00:00:00 2001 From: ParsaKSH Date: Tue, 31 Mar 2026 04:11:13 +0330 Subject: [PATCH 2/4] =?UTF-8?q?Per-instance=20real-time=20blocking:**=20-?= =?UTF-8?q?=20`blockedInstances()`=20scans=20the=20pending=20map=20and=20r?= =?UTF-8?q?eturns=20all=20instance=20=20=20IDs=20that=20have=20at=20least?= =?UTF-8?q?=20one=20frame=20older=20than=20`ackTimeout`=20(3s)=20without?= =?UTF-8?q?=20=20=20ACK=20-=20`pickInstance()`=20snapshots=20the=20blocked?= =?UTF-8?q?=20set=20first=20(under=20`pendingMu`),=20=20=20then=20selects?= =?UTF-8?q?=20under=20`mu`=20=E2=80=94=20no=20deadlock=20risk=20-=20Blocke?= =?UTF-8?q?d=20instances=20are=20skipped=20in=20weighted=20round-robin=20a?= =?UTF-8?q?nd=20in=20fallback=20=20=20selection=20-=20Last-resort=20fallba?= =?UTF-8?q?ck:=20if=20ALL=20instances=20are=20blocked,=20still=20picks=20a?= =?UTF-8?q?=20=20=20healthy=20one=20to=20avoid=20total=20stall?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit **No concurrency/performance impact:** - `blockedInstances()` is O(n) where n = pending frames — small for DNS tunnel bandwidth (~36-375 frames at typical speeds) - Lazy `blocked` map allocation (nil until first overdue frame found) - Lock held briefly — no I/O under lock - `pickInstanceExcluding()` also has fallback to avoid retry deadlock when all alternatives are blocked **Flow:** 1. Frame sent to Instance A → tracked in `pending[SeqNum] = {instID: A, sentAt: now}` 2. ACK arrives → `delete(pending, SeqNum)` → Instance A stays unblocked 3. No ACK after 3s → `blockedInstances()` returns `{A: true}` → new packets skip A 4. Retry goroutine resends frame via Instance B, updates `pf.instID = B` 5. If ACK arrives via B → A may become unblocked (if no other overdue frames) --- cmd/centralserver/main.go | 8 +++ internal/health/checker.go | 33 +++++++-- internal/tunnel/pool.go | 2 +- internal/tunnel/splitter.go | 139 +++++++++++++++++++++++++++++++++++- 4 files changed, 173 insertions(+), 9 deletions(-) diff --git a/cmd/centralserver/main.go b/cmd/centralserver/main.go index 99a82cc..19702ae 100644 --- a/cmd/centralserver/main.go +++ b/cmd/centralserver/main.go @@ -637,6 +637,14 @@ func (cs *centralServer) handleData(frame *tunnel.Frame, source *sourceConn) { state.lastActive = time.Now() state.reorderer.Insert(frame.SeqNum, frame.Payload) + // ACK: confirm receipt to the client so it can stop tracking this frame. + // Sent async through the delivering source to avoid blocking frame dispatch. + go source.WriteFrame(&tunnel.Frame{ + ConnID: frame.ConnID, + SeqNum: frame.SeqNum, + Flags: tunnel.FlagACK, + }) + // If upstream not connected yet, data stays in reorderer for later flush writeCh := state.writeCh if writeCh == nil { diff --git a/internal/health/checker.go b/internal/health/checker.go index 4afbbb2..bafa3ae 100644 --- a/internal/health/checker.go +++ b/internal/health/checker.go @@ -20,11 +20,15 @@ import ( // An instance is HEALTHY only after a successful tunnel probe (SOCKS5/SSH). // An instance is UNHEALTHY if: // - TCP connect to local port fails (process dead) -// - Tunnel probe fails 3 consecutive times (tunnel broken) +// - Tunnel probe fails N consecutive times (tunnel broken) // // Latency is only set from successful tunnel probes (real RTT). const maxConsecutiveFailures = 3 +// In packet_split mode, a single unhealthy instance drops packets for ALL +// connections, so we use a stricter threshold to remove bad instances faster. +const maxConsecutiveFailuresPacketSplit = 1 + type Checker struct { manager *engine.Manager interval time.Duration @@ -59,6 +63,13 @@ func NewChecker(mgr *engine.Manager, cfg *config.HealthCheckConfig) *Checker { } } +func (c *Checker) maxFailures() int { + if c.packetSplit { + return maxConsecutiveFailuresPacketSplit + } + return maxConsecutiveFailures +} + func (c *Checker) Start() { go c.run() mode := "connection" @@ -66,7 +77,7 @@ func (c *Checker) Start() { mode = "packet-split" } log.Printf("[health] checker started (interval=%s, tunnel_timeout=%s, target=%s, mode=%s, unhealthy_after=%d failures)", - c.interval, c.timeout, c.target, mode, maxConsecutiveFailures) + c.interval, c.timeout, c.target, mode, c.maxFailures()) } // SetPacketSplit enables framing protocol health checks. @@ -156,7 +167,8 @@ func (c *Checker) checkOne(inst *engine.Instance) { if err != nil { failCount := c.recordFailure(inst.ID()) - if failCount >= maxConsecutiveFailures { + maxFail := c.maxFailures() + if failCount >= maxFail { if inst.State() != engine.StateUnhealthy { log.Printf("[health] instance %d (%s:%d) UNHEALTHY after %d tunnel failures: %v", inst.ID(), inst.Config.Domain, inst.Config.Port, failCount, err) @@ -170,7 +182,7 @@ func (c *Checker) checkOne(inst *engine.Instance) { } else { log.Printf("[health] instance %d (%s:%d) tunnel probe failed (%d/%d): %v", inst.ID(), inst.Config.Domain, inst.Config.Port, - failCount, maxConsecutiveFailures, err) + failCount, maxFail, err) } return } @@ -190,17 +202,26 @@ func (c *Checker) checkOne(inst *engine.Instance) { if e2eErr != nil { failCount := c.recordFailure(inst.ID()) - if failCount >= maxConsecutiveFailures { + maxFail := c.maxFailures() + if failCount >= maxFail { if inst.State() != engine.StateUnhealthy { log.Printf("[health] instance %d (%s:%d) UNHEALTHY after %d e2e failures: %v", inst.ID(), inst.Config.Domain, inst.Config.Port, failCount, e2eErr) inst.SetState(engine.StateUnhealthy) inst.SetLastPingMs(-1) + // In packet_split mode, auto-restart on e2e failure too — + // a broken upstream path drops packets for ALL connections. + if c.packetSplit { + go func() { + log.Printf("[health] auto-restarting instance %d (e2e failure in packet_split)", inst.ID()) + c.manager.RestartInstance(inst.ID()) + }() + } } } else { log.Printf("[health] instance %d (%s:%d) e2e probe failed (%d/%d): %v", inst.ID(), inst.Config.Domain, inst.Config.Port, - failCount, maxConsecutiveFailures, e2eErr) + failCount, maxFail, e2eErr) } return } diff --git a/internal/tunnel/pool.go b/internal/tunnel/pool.go index aa86678..80d0cfc 100644 --- a/internal/tunnel/pool.go +++ b/internal/tunnel/pool.go @@ -24,7 +24,7 @@ const staleThreshold = 15 * time.Second // ConnsPerInstance is the number of persistent connections to maintain // per healthy instance. When one dies, it's replaced on the next refresh. // Multiple connections provide redundancy — if one degrades, others serve traffic. -const ConnsPerInstance = 3 +const ConnsPerInstance = 8 // TunnelConn wraps a persistent TCP connection to a single instance. type TunnelConn struct { diff --git a/internal/tunnel/splitter.go b/internal/tunnel/splitter.go index eea6924..a6f959d 100644 --- a/internal/tunnel/splitter.go +++ b/internal/tunnel/splitter.go @@ -3,6 +3,7 @@ package tunnel import ( "context" "io" + "log" "sync" "sync/atomic" "time" @@ -10,6 +11,20 @@ import ( "github.com/ParsaKSH/SlipStream-Plus/internal/engine" ) +// ackTimeout is how long to wait for a frame ACK before retrying via another instance. +const ackTimeout = 3 * time.Second + +// maxFrameRetries is the maximum number of retry attempts per frame. +const maxFrameRetries = 2 + +// pendingFrame tracks a sent frame awaiting ACK from the CentralServer. +type pendingFrame struct { + frame *Frame + sentAt time.Time + instID int + retries int +} + // PacketSplitter distributes data from a client connection across multiple // instances at the packet/chunk level, and reassembles reverse-direction // frames back to the client. @@ -22,6 +37,11 @@ type PacketSplitter struct { txSeq atomic.Uint32 + // ACK tracking: frames awaiting confirmation from CentralServer + pendingMu sync.Mutex + pending map[uint32]*pendingFrame // SeqNum → pending + stopRetry chan struct{} + // Weighted round-robin state mu sync.Mutex weights []int @@ -36,12 +56,17 @@ func NewPacketSplitter(connID uint32, pool *TunnelPool, instances []*engine.Inst instances: instances, chunkSize: chunkSize, incoming: pool.RegisterConn(connID), + pending: make(map[uint32]*pendingFrame), + stopRetry: make(chan struct{}), } ps.recalcWeights() + go ps.retryLoop() return ps } func (ps *PacketSplitter) Close() { + close(ps.stopRetry) + fin := &Frame{ ConnID: ps.connID, SeqNum: ps.txSeq.Add(1) - 1, @@ -53,6 +78,63 @@ func (ps *PacketSplitter) Close() { ps.pool.UnregisterConn(ps.connID) } +// retryLoop periodically checks for unACKed frames and resends them +// through a different instance. +func (ps *PacketSplitter) retryLoop() { + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + for { + select { + case <-ps.stopRetry: + return + case <-ticker.C: + ps.retryPending() + } + } +} + +func (ps *PacketSplitter) retryPending() { + now := time.Now() + + ps.pendingMu.Lock() + var toRetry []*pendingFrame + for seq, pf := range ps.pending { + if now.Sub(pf.sentAt) < ackTimeout { + continue + } + if pf.retries >= maxFrameRetries { + delete(ps.pending, seq) + continue + } + toRetry = append(toRetry, pf) + } + ps.pendingMu.Unlock() + + for _, pf := range toRetry { + // Pick a different instance than the one that failed + inst := ps.pickInstanceExcluding(pf.instID) + if inst == nil { + inst = ps.pickInstance() + } + if inst == nil { + continue + } + + if err := ps.pool.SendFrame(inst.ID(), pf.frame); err != nil { + continue + } + + ps.pendingMu.Lock() + pf.retries++ + pf.sentAt = now + pf.instID = inst.ID() + ps.pendingMu.Unlock() + + log.Printf("[splitter] conn=%d: retried seq=%d via instance %d (attempt %d/%d)", + ps.connID, pf.frame.SeqNum, inst.ID(), pf.retries, maxFrameRetries) + } +} + func (ps *PacketSplitter) SendSYN(atyp byte, addr []byte, port []byte) error { payload := EncodeSYNPayload(atyp, addr, port) frame := &Frame{ @@ -103,12 +185,23 @@ func (ps *PacketSplitter) RelayClientToUpstream(ctx context.Context, client io.R } copy(frame.Payload, buf[:n]) + sentInstID := inst.ID() if sendErr := ps.pool.SendFrame(inst.ID(), frame); sendErr != nil { inst2 := ps.pickInstanceExcluding(inst.ID()) if inst2 != nil { ps.pool.SendFrame(inst2.ID(), frame) + sentInstID = inst2.ID() } } + + // Track frame for ACK — retryLoop will resend if no ACK arrives + ps.pendingMu.Lock() + ps.pending[frame.SeqNum] = &pendingFrame{ + frame: frame, + sentAt: time.Now(), + instID: sentInstID, + } + ps.pendingMu.Unlock() } if err != nil { @@ -173,7 +266,13 @@ func (ps *PacketSplitter) RelayUpstreamToClient(ctx context.Context, client io.W return totalBytes } - if frame.IsACK() || frame.IsSYN() { + if frame.IsACK() { + ps.pendingMu.Lock() + delete(ps.pending, frame.SeqNum) + ps.pendingMu.Unlock() + continue + } + if frame.IsSYN() { continue } if len(frame.Payload) == 0 { @@ -223,7 +322,29 @@ func (ps *PacketSplitter) recalcWeights() { } } +// blockedInstances returns the set of instance IDs that have at least one +// unACKed frame past ackTimeout. These instances are considered unreliable +// at this moment and must not receive new packets. +// Lock order: pendingMu must be acquired BEFORE mu to avoid deadlock. +func (ps *PacketSplitter) blockedInstances() map[int]bool { + now := time.Now() + ps.pendingMu.Lock() + defer ps.pendingMu.Unlock() + var blocked map[int]bool + for _, pf := range ps.pending { + if now.Sub(pf.sentAt) > ackTimeout { + if blocked == nil { + blocked = make(map[int]bool) + } + blocked[pf.instID] = true + } + } + return blocked +} + func (ps *PacketSplitter) pickInstance() *engine.Instance { + blocked := ps.blockedInstances() + ps.mu.Lock() defer ps.mu.Unlock() @@ -239,13 +360,20 @@ func (ps *PacketSplitter) pickInstance() *engine.Instance { ps.counter++ inst := ps.instances[ps.current] - if inst.IsHealthy() { + if inst.IsHealthy() && !blocked[inst.ID()] { return inst } ps.counter = 0 ps.current = (ps.current + 1) % len(ps.instances) } + // Fallback: any healthy non-blocked instance + for _, inst := range ps.instances { + if inst.IsHealthy() && !blocked[inst.ID()] { + return inst + } + } + // Last resort: any healthy instance even if blocked (avoid total stall) for _, inst := range ps.instances { if inst.IsHealthy() { return inst @@ -255,6 +383,13 @@ func (ps *PacketSplitter) pickInstance() *engine.Instance { } func (ps *PacketSplitter) pickInstanceExcluding(excludeID int) *engine.Instance { + blocked := ps.blockedInstances() + for _, inst := range ps.instances { + if inst.ID() != excludeID && inst.IsHealthy() && !blocked[inst.ID()] { + return inst + } + } + // Fallback: any healthy instance excluding the specified one for _, inst := range ps.instances { if inst.ID() != excludeID && inst.IsHealthy() { return inst From 3af414ba73028d62e81a5a1023320343a8140701 Mon Sep 17 00:00:00 2001 From: ParsaKSH Date: Tue, 31 Mar 2026 04:56:53 +0330 Subject: [PATCH 3/4] fuck the ack system --- cmd/centralserver/main.go | 8 --- internal/health/checker.go | 33 ++------- internal/tunnel/splitter.go | 139 +----------------------------------- 3 files changed, 8 insertions(+), 172 deletions(-) diff --git a/cmd/centralserver/main.go b/cmd/centralserver/main.go index 19702ae..99a82cc 100644 --- a/cmd/centralserver/main.go +++ b/cmd/centralserver/main.go @@ -637,14 +637,6 @@ func (cs *centralServer) handleData(frame *tunnel.Frame, source *sourceConn) { state.lastActive = time.Now() state.reorderer.Insert(frame.SeqNum, frame.Payload) - // ACK: confirm receipt to the client so it can stop tracking this frame. - // Sent async through the delivering source to avoid blocking frame dispatch. - go source.WriteFrame(&tunnel.Frame{ - ConnID: frame.ConnID, - SeqNum: frame.SeqNum, - Flags: tunnel.FlagACK, - }) - // If upstream not connected yet, data stays in reorderer for later flush writeCh := state.writeCh if writeCh == nil { diff --git a/internal/health/checker.go b/internal/health/checker.go index bafa3ae..4afbbb2 100644 --- a/internal/health/checker.go +++ b/internal/health/checker.go @@ -20,15 +20,11 @@ import ( // An instance is HEALTHY only after a successful tunnel probe (SOCKS5/SSH). // An instance is UNHEALTHY if: // - TCP connect to local port fails (process dead) -// - Tunnel probe fails N consecutive times (tunnel broken) +// - Tunnel probe fails 3 consecutive times (tunnel broken) // // Latency is only set from successful tunnel probes (real RTT). const maxConsecutiveFailures = 3 -// In packet_split mode, a single unhealthy instance drops packets for ALL -// connections, so we use a stricter threshold to remove bad instances faster. -const maxConsecutiveFailuresPacketSplit = 1 - type Checker struct { manager *engine.Manager interval time.Duration @@ -63,13 +59,6 @@ func NewChecker(mgr *engine.Manager, cfg *config.HealthCheckConfig) *Checker { } } -func (c *Checker) maxFailures() int { - if c.packetSplit { - return maxConsecutiveFailuresPacketSplit - } - return maxConsecutiveFailures -} - func (c *Checker) Start() { go c.run() mode := "connection" @@ -77,7 +66,7 @@ func (c *Checker) Start() { mode = "packet-split" } log.Printf("[health] checker started (interval=%s, tunnel_timeout=%s, target=%s, mode=%s, unhealthy_after=%d failures)", - c.interval, c.timeout, c.target, mode, c.maxFailures()) + c.interval, c.timeout, c.target, mode, maxConsecutiveFailures) } // SetPacketSplit enables framing protocol health checks. @@ -167,8 +156,7 @@ func (c *Checker) checkOne(inst *engine.Instance) { if err != nil { failCount := c.recordFailure(inst.ID()) - maxFail := c.maxFailures() - if failCount >= maxFail { + if failCount >= maxConsecutiveFailures { if inst.State() != engine.StateUnhealthy { log.Printf("[health] instance %d (%s:%d) UNHEALTHY after %d tunnel failures: %v", inst.ID(), inst.Config.Domain, inst.Config.Port, failCount, err) @@ -182,7 +170,7 @@ func (c *Checker) checkOne(inst *engine.Instance) { } else { log.Printf("[health] instance %d (%s:%d) tunnel probe failed (%d/%d): %v", inst.ID(), inst.Config.Domain, inst.Config.Port, - failCount, maxFail, err) + failCount, maxConsecutiveFailures, err) } return } @@ -202,26 +190,17 @@ func (c *Checker) checkOne(inst *engine.Instance) { if e2eErr != nil { failCount := c.recordFailure(inst.ID()) - maxFail := c.maxFailures() - if failCount >= maxFail { + if failCount >= maxConsecutiveFailures { if inst.State() != engine.StateUnhealthy { log.Printf("[health] instance %d (%s:%d) UNHEALTHY after %d e2e failures: %v", inst.ID(), inst.Config.Domain, inst.Config.Port, failCount, e2eErr) inst.SetState(engine.StateUnhealthy) inst.SetLastPingMs(-1) - // In packet_split mode, auto-restart on e2e failure too — - // a broken upstream path drops packets for ALL connections. - if c.packetSplit { - go func() { - log.Printf("[health] auto-restarting instance %d (e2e failure in packet_split)", inst.ID()) - c.manager.RestartInstance(inst.ID()) - }() - } } } else { log.Printf("[health] instance %d (%s:%d) e2e probe failed (%d/%d): %v", inst.ID(), inst.Config.Domain, inst.Config.Port, - failCount, maxFail, e2eErr) + failCount, maxConsecutiveFailures, e2eErr) } return } diff --git a/internal/tunnel/splitter.go b/internal/tunnel/splitter.go index a6f959d..eea6924 100644 --- a/internal/tunnel/splitter.go +++ b/internal/tunnel/splitter.go @@ -3,7 +3,6 @@ package tunnel import ( "context" "io" - "log" "sync" "sync/atomic" "time" @@ -11,20 +10,6 @@ import ( "github.com/ParsaKSH/SlipStream-Plus/internal/engine" ) -// ackTimeout is how long to wait for a frame ACK before retrying via another instance. -const ackTimeout = 3 * time.Second - -// maxFrameRetries is the maximum number of retry attempts per frame. -const maxFrameRetries = 2 - -// pendingFrame tracks a sent frame awaiting ACK from the CentralServer. -type pendingFrame struct { - frame *Frame - sentAt time.Time - instID int - retries int -} - // PacketSplitter distributes data from a client connection across multiple // instances at the packet/chunk level, and reassembles reverse-direction // frames back to the client. @@ -37,11 +22,6 @@ type PacketSplitter struct { txSeq atomic.Uint32 - // ACK tracking: frames awaiting confirmation from CentralServer - pendingMu sync.Mutex - pending map[uint32]*pendingFrame // SeqNum → pending - stopRetry chan struct{} - // Weighted round-robin state mu sync.Mutex weights []int @@ -56,17 +36,12 @@ func NewPacketSplitter(connID uint32, pool *TunnelPool, instances []*engine.Inst instances: instances, chunkSize: chunkSize, incoming: pool.RegisterConn(connID), - pending: make(map[uint32]*pendingFrame), - stopRetry: make(chan struct{}), } ps.recalcWeights() - go ps.retryLoop() return ps } func (ps *PacketSplitter) Close() { - close(ps.stopRetry) - fin := &Frame{ ConnID: ps.connID, SeqNum: ps.txSeq.Add(1) - 1, @@ -78,63 +53,6 @@ func (ps *PacketSplitter) Close() { ps.pool.UnregisterConn(ps.connID) } -// retryLoop periodically checks for unACKed frames and resends them -// through a different instance. -func (ps *PacketSplitter) retryLoop() { - ticker := time.NewTicker(1 * time.Second) - defer ticker.Stop() - for { - select { - case <-ps.stopRetry: - return - case <-ticker.C: - ps.retryPending() - } - } -} - -func (ps *PacketSplitter) retryPending() { - now := time.Now() - - ps.pendingMu.Lock() - var toRetry []*pendingFrame - for seq, pf := range ps.pending { - if now.Sub(pf.sentAt) < ackTimeout { - continue - } - if pf.retries >= maxFrameRetries { - delete(ps.pending, seq) - continue - } - toRetry = append(toRetry, pf) - } - ps.pendingMu.Unlock() - - for _, pf := range toRetry { - // Pick a different instance than the one that failed - inst := ps.pickInstanceExcluding(pf.instID) - if inst == nil { - inst = ps.pickInstance() - } - if inst == nil { - continue - } - - if err := ps.pool.SendFrame(inst.ID(), pf.frame); err != nil { - continue - } - - ps.pendingMu.Lock() - pf.retries++ - pf.sentAt = now - pf.instID = inst.ID() - ps.pendingMu.Unlock() - - log.Printf("[splitter] conn=%d: retried seq=%d via instance %d (attempt %d/%d)", - ps.connID, pf.frame.SeqNum, inst.ID(), pf.retries, maxFrameRetries) - } -} - func (ps *PacketSplitter) SendSYN(atyp byte, addr []byte, port []byte) error { payload := EncodeSYNPayload(atyp, addr, port) frame := &Frame{ @@ -185,23 +103,12 @@ func (ps *PacketSplitter) RelayClientToUpstream(ctx context.Context, client io.R } copy(frame.Payload, buf[:n]) - sentInstID := inst.ID() if sendErr := ps.pool.SendFrame(inst.ID(), frame); sendErr != nil { inst2 := ps.pickInstanceExcluding(inst.ID()) if inst2 != nil { ps.pool.SendFrame(inst2.ID(), frame) - sentInstID = inst2.ID() } } - - // Track frame for ACK — retryLoop will resend if no ACK arrives - ps.pendingMu.Lock() - ps.pending[frame.SeqNum] = &pendingFrame{ - frame: frame, - sentAt: time.Now(), - instID: sentInstID, - } - ps.pendingMu.Unlock() } if err != nil { @@ -266,13 +173,7 @@ func (ps *PacketSplitter) RelayUpstreamToClient(ctx context.Context, client io.W return totalBytes } - if frame.IsACK() { - ps.pendingMu.Lock() - delete(ps.pending, frame.SeqNum) - ps.pendingMu.Unlock() - continue - } - if frame.IsSYN() { + if frame.IsACK() || frame.IsSYN() { continue } if len(frame.Payload) == 0 { @@ -322,29 +223,7 @@ func (ps *PacketSplitter) recalcWeights() { } } -// blockedInstances returns the set of instance IDs that have at least one -// unACKed frame past ackTimeout. These instances are considered unreliable -// at this moment and must not receive new packets. -// Lock order: pendingMu must be acquired BEFORE mu to avoid deadlock. -func (ps *PacketSplitter) blockedInstances() map[int]bool { - now := time.Now() - ps.pendingMu.Lock() - defer ps.pendingMu.Unlock() - var blocked map[int]bool - for _, pf := range ps.pending { - if now.Sub(pf.sentAt) > ackTimeout { - if blocked == nil { - blocked = make(map[int]bool) - } - blocked[pf.instID] = true - } - } - return blocked -} - func (ps *PacketSplitter) pickInstance() *engine.Instance { - blocked := ps.blockedInstances() - ps.mu.Lock() defer ps.mu.Unlock() @@ -360,20 +239,13 @@ func (ps *PacketSplitter) pickInstance() *engine.Instance { ps.counter++ inst := ps.instances[ps.current] - if inst.IsHealthy() && !blocked[inst.ID()] { + if inst.IsHealthy() { return inst } ps.counter = 0 ps.current = (ps.current + 1) % len(ps.instances) } - // Fallback: any healthy non-blocked instance - for _, inst := range ps.instances { - if inst.IsHealthy() && !blocked[inst.ID()] { - return inst - } - } - // Last resort: any healthy instance even if blocked (avoid total stall) for _, inst := range ps.instances { if inst.IsHealthy() { return inst @@ -383,13 +255,6 @@ func (ps *PacketSplitter) pickInstance() *engine.Instance { } func (ps *PacketSplitter) pickInstanceExcluding(excludeID int) *engine.Instance { - blocked := ps.blockedInstances() - for _, inst := range ps.instances { - if inst.ID() != excludeID && inst.IsHealthy() && !blocked[inst.ID()] { - return inst - } - } - // Fallback: any healthy instance excluding the specified one for _, inst := range ps.instances { if inst.ID() != excludeID && inst.IsHealthy() { return inst From 4ed957e5785194244b5016b369bf0d7d54164932 Mon Sep 17 00:00:00 2001 From: ParsaKSH Date: Tue, 31 Mar 2026 05:21:03 +0330 Subject: [PATCH 4/4] refactor: split large files into smaller focused modules MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit No logic or function signatures changed — pure reorganization: - cmd/centralserver: main.go → + state.go, server.go, upstream.go - internal/config: config.go → + expand.go - internal/gui: server.go → + handlers.go - internal/health: checker.go → + probes.go - internal/proxy: socks5.go → + relay.go - internal/tunnel: splitter.go → + reorder.go, pool.go → + tunnelconn.go - internal/users: manager.go → + user.go, io.go - fix .gitignore: scope centralserver ignore to binary only Co-Authored-By: Claude Sonnet 4.6 --- .gitignore | 4 +- cmd/centralserver/main.go | 720 ---------------------------------- cmd/centralserver/server.go | 289 ++++++++++++++ cmd/centralserver/state.go | 114 ++++++ cmd/centralserver/upstream.go | 346 ++++++++++++++++ internal/config/config.go | 125 ------ internal/config/expand.go | 131 +++++++ internal/gui/handlers.go | 319 +++++++++++++++ internal/gui/server.go | 308 --------------- internal/health/checker.go | 210 ---------- internal/health/probes.go | 218 ++++++++++ internal/proxy/relay.go | 274 +++++++++++++ internal/proxy/socks5.go | 261 ------------ internal/tunnel/pool.go | 55 --- internal/tunnel/reorder.go | 100 +++++ internal/tunnel/splitter.go | 96 ----- internal/tunnel/tunnelconn.go | 66 ++++ internal/users/io.go | 90 +++++ internal/users/manager.go | 265 ------------- internal/users/user.go | 191 +++++++++ 20 files changed, 2141 insertions(+), 2041 deletions(-) create mode 100644 cmd/centralserver/server.go create mode 100644 cmd/centralserver/state.go create mode 100644 cmd/centralserver/upstream.go create mode 100644 internal/config/expand.go create mode 100644 internal/gui/handlers.go create mode 100644 internal/health/probes.go create mode 100644 internal/proxy/relay.go create mode 100644 internal/tunnel/reorder.go create mode 100644 internal/tunnel/tunnelconn.go create mode 100644 internal/users/io.go create mode 100644 internal/users/user.go diff --git a/.gitignore b/.gitignore index be38a91..05f81d5 100644 --- a/.gitignore +++ b/.gitignore @@ -6,7 +6,8 @@ config.json slipstrampluss slipstreampluss SlipStreamPlus.tar.gz -centralserver +centralserver/centralserver +/centralserver central slipstreamplusb # Rust core binaries (built by CI from source) @@ -14,3 +15,4 @@ slipstreamorg/slipstream-client-* slipstreamorg/slipstream-client.exe slipstreamplus-v1.0.4-linux-amd64 +slipstreamplusbb diff --git a/cmd/centralserver/main.go b/cmd/centralserver/main.go index 99a82cc..8abfa20 100644 --- a/cmd/centralserver/main.go +++ b/cmd/centralserver/main.go @@ -1,75 +1,15 @@ package main import ( - "context" - "encoding/binary" "flag" - "fmt" - "io" "log" "net" "os" "os/signal" - "strings" - "sync" "syscall" "time" - - "github.com/ParsaKSH/SlipStream-Plus/internal/tunnel" ) -// sourceConn wraps a tunnel connection with a write mutex so that -// concurrent WriteFrame calls from different ConnIDs don't interleave bytes. -type sourceConn struct { - conn net.Conn - writeMu sync.Mutex -} - -// writeTimeout prevents tunnel writes from blocking forever on congested TCP. -const sourceWriteTimeout = 10 * time.Second - -func (sc *sourceConn) WriteFrame(f *tunnel.Frame) error { - sc.writeMu.Lock() - defer sc.writeMu.Unlock() - sc.conn.SetWriteDeadline(time.Now().Add(sourceWriteTimeout)) - err := tunnel.WriteFrame(sc.conn, f) - sc.conn.SetWriteDeadline(time.Time{}) - return err -} - -// connState tracks a single reassembled connection. -type connState struct { - mu sync.Mutex - target net.Conn // connection to the SOCKS upstream - reorderer *tunnel.Reorderer - txSeq uint32 // next sequence number for reverse data - cancel context.CancelFunc - created time.Time - lastActive time.Time // last time data was sent or received - - // Sources: all tunnel connections that can carry reverse data. - // We round-robin responses across them (not broadcast). - sources []*sourceConn - sourceIdx int - - // Async write queue: handleData sends chunks here instead of writing - // to target synchronously (which would block the frame dispatch loop). - // The upstreamWriter goroutine drains this channel and writes to target. - writeCh chan []byte -} - -// centralServer manages all active connections. -type centralServer struct { - socksUpstream string - - mu sync.RWMutex - conns map[uint32]*connState // ConnID → state - - // sourceMu protects the sources map (net.Conn → *sourceConn). - sourceMu sync.Mutex - sourceMap map[net.Conn]*sourceConn -} - func main() { listenAddr := flag.String("listen", "0.0.0.0:9500", "listen address for tunnel connections") socksUpstream := flag.String("socks-upstream", "127.0.0.1:1080", "upstream SOCKS5 proxy address") @@ -118,663 +58,3 @@ func main() { go cs.handleIncoming(conn) } } - -// handleIncoming detects the protocol from the first byte: -// -// 0x05 → SOCKS5 health probe → passthrough to socks-upstream -// else → framing protocol → read frames -func (cs *centralServer) handleIncoming(conn net.Conn) { - defer conn.Close() - remoteAddr := conn.RemoteAddr().String() - - firstByte := make([]byte, 1) - conn.SetReadDeadline(time.Now().Add(10 * time.Second)) - if _, err := io.ReadFull(conn, firstByte); err != nil { - return - } - conn.SetReadDeadline(time.Time{}) - - if firstByte[0] == 0x05 { - cs.handleSOCKS5Passthrough(conn, firstByte[0], remoteAddr) - } else { - cs.handleFrameConn(conn, firstByte[0], remoteAddr) - } -} - -// handleSOCKS5Passthrough transparently proxies a SOCKS5 connection. -func (cs *centralServer) handleSOCKS5Passthrough(clientConn net.Conn, firstByte byte, remoteAddr string) { - upstream, err := net.DialTimeout("tcp", cs.socksUpstream, 10*time.Second) - if err != nil { - return - } - defer upstream.Close() - if tc, ok := upstream.(*net.TCPConn); ok { - tc.SetNoDelay(true) - } - upstream.Write([]byte{firstByte}) - - done := make(chan struct{}, 2) - go func() { - io.Copy(upstream, clientConn) - if tc, ok := upstream.(*net.TCPConn); ok { - tc.CloseWrite() - } - done <- struct{}{} - }() - go func() { - io.Copy(clientConn, upstream) - if tc, ok := clientConn.(*net.TCPConn); ok { - tc.CloseWrite() - } - done <- struct{}{} - }() - <-done -} - -// handleFrameConn reads framed packets from a tunnel connection. -// When this function returns (source TCP died), it cleans up all -// connStates that had this as their only source. -func (cs *centralServer) handleFrameConn(conn net.Conn, firstByte byte, remoteAddr string) { - log.Printf("[central] frame connection from %s", remoteAddr) - - sc := cs.getSourceConn(conn) - - // Track which ConnIDs this source served - servedIDs := make(map[uint32]bool) - - defer func() { - // Source TCP died — clean up sourceConn and connStates - cs.removeSourceConn(conn) - cs.cleanupSource(sc, servedIDs, remoteAddr) - }() - - // Read remaining header bytes (we already read 1) - var hdrRest [tunnel.HeaderSize - 1]byte - if _, err := io.ReadFull(conn, hdrRest[:]); err != nil { - return - } - - var fullHdr [tunnel.HeaderSize]byte - fullHdr[0] = firstByte - copy(fullHdr[1:], hdrRest[:]) - - firstFrame := cs.parseHeader(fullHdr, conn, remoteAddr) - if firstFrame != nil { - servedIDs[firstFrame.ConnID] = true - cs.dispatchFrame(firstFrame, sc) - } - - for { - frame, err := tunnel.ReadFrame(conn) - if err != nil { - if err != io.EOF && !isClosedConnErr(err) { - log.Printf("[central] %s: read error: %v", remoteAddr, err) - } - return - } - servedIDs[frame.ConnID] = true - cs.dispatchFrame(frame, sc) - } -} - -// cleanupSource removes a dead source connection from all connStates. -// If a connState has no remaining sources, it is fully cleaned up. -func (cs *centralServer) cleanupSource(deadSource *sourceConn, servedIDs map[uint32]bool, remoteAddr string) { - cs.mu.Lock() - defer cs.mu.Unlock() - - cleaned := 0 - for connID := range servedIDs { - state, ok := cs.conns[connID] - if !ok { - continue - } - - state.mu.Lock() - // Remove dead source from sources list - for i, src := range state.sources { - if src == deadSource { - state.sources = append(state.sources[:i], state.sources[i+1:]...) - break - } - } - - // If no sources left, fully clean up this connState - if len(state.sources) == 0 { - state.mu.Unlock() - if state.cancel != nil { - state.cancel() - } - if state.target != nil { - state.target.Close() - } - delete(cs.conns, connID) - cleaned++ - } else { - state.mu.Unlock() - } - } - - if cleaned > 0 { - log.Printf("[central] %s: source disconnected, cleaned %d orphaned connections", remoteAddr, cleaned) - } -} - -// getSourceConn returns the sourceConn wrapper for a raw net.Conn, -// creating one if it doesn't exist yet. -func (cs *centralServer) getSourceConn(conn net.Conn) *sourceConn { - cs.sourceMu.Lock() - defer cs.sourceMu.Unlock() - sc, ok := cs.sourceMap[conn] - if !ok { - sc = &sourceConn{conn: conn} - cs.sourceMap[conn] = sc - } - return sc -} - -// removeSourceConn removes the sourceConn wrapper when the raw conn dies. -func (cs *centralServer) removeSourceConn(conn net.Conn) { - cs.sourceMu.Lock() - delete(cs.sourceMap, conn) - cs.sourceMu.Unlock() -} - -func isClosedConnErr(err error) bool { - if err == nil { - return false - } - s := err.Error() - return strings.Contains(s, "use of closed network connection") || - strings.Contains(s, "connection reset by peer") -} - -func (cs *centralServer) parseHeader(hdr [tunnel.HeaderSize]byte, conn net.Conn, remoteAddr string) *tunnel.Frame { - length := binary.BigEndian.Uint16(hdr[9:11]) - if length > tunnel.MaxPayloadSize { - log.Printf("[central] %s: frame payload too large: %d", remoteAddr, length) - return nil - } - var payload []byte - if length > 0 { - payload = make([]byte, length) - if _, err := io.ReadFull(conn, payload); err != nil { - return nil - } - } - return &tunnel.Frame{ - ConnID: binary.BigEndian.Uint32(hdr[0:4]), - SeqNum: binary.BigEndian.Uint32(hdr[4:8]), - Flags: hdr[8], - Payload: payload, - } -} - -func (cs *centralServer) dispatchFrame(frame *tunnel.Frame, source *sourceConn) { - if frame.IsSYN() { - cs.handleSYN(frame, source) - return - } - if frame.IsFIN() { - cs.handleFIN(frame) - return - } - if frame.IsRST() { - cs.handleRST(frame) - return - } - cs.handleData(frame, source) -} - -func (cs *centralServer) handleSYN(frame *tunnel.Frame, source *sourceConn) { - connID := frame.ConnID - - cs.mu.Lock() - if existing, ok := cs.conns[connID]; ok { - // Another instance's SYN → just register additional source - existing.mu.Lock() - existing.sources = append(existing.sources, source) - existing.mu.Unlock() - cs.mu.Unlock() - return - } - - atyp, addr, port, err := tunnel.DecodeSYNPayload(frame.Payload) - if err != nil { - cs.mu.Unlock() - log.Printf("[central] conn=%d: bad SYN payload: %v", connID, err) - return - } - - var targetAddr string - switch atyp { - case 0x01: - targetAddr = fmt.Sprintf("%s:%d", net.IP(addr).String(), binary.BigEndian.Uint16(port)) - case 0x03: - domLen := int(addr[0]) - targetAddr = fmt.Sprintf("%s:%d", string(addr[1:1+domLen]), binary.BigEndian.Uint16(port)) - case 0x04: - targetAddr = fmt.Sprintf("[%s]:%d", net.IP(addr).String(), binary.BigEndian.Uint16(port)) - } - - now := time.Now() - ctx, cancel := context.WithCancel(context.Background()) - state := &connState{ - reorderer: tunnel.NewReordererAt(frame.SeqNum + 1), // skip SYN's SeqNum - sources: []*sourceConn{source}, - cancel: cancel, - created: now, - lastActive: now, - } - cs.conns[connID] = state - cs.mu.Unlock() - - log.Printf("[central] conn=%d: SYN → target=%s", connID, targetAddr) - go cs.connectUpstream(ctx, connID, state, atyp, addr, port, targetAddr) -} - -func (cs *centralServer) connectUpstream(ctx context.Context, connID uint32, state *connState, - atyp byte, addr, port []byte, targetAddr string) { - - upConn, err := net.DialTimeout("tcp", cs.socksUpstream, 10*time.Second) - if err != nil { - log.Printf("[central] conn=%d: upstream dial failed: %v", connID, err) - cs.sendFrame(connID, &tunnel.Frame{ConnID: connID, Flags: tunnel.FlagRST | tunnel.FlagReverse}) - cs.removeConn(connID) - return - } - if tc, ok := upConn.(*net.TCPConn); ok { - tc.SetKeepAlive(true) - tc.SetKeepAlivePeriod(30 * time.Second) - tc.SetNoDelay(true) - } - - // SOCKS5 handshake - pipelined := make([]byte, 0, 3+4+len(addr)+2) - pipelined = append(pipelined, 0x05, 0x01, 0x00) - pipelined = append(pipelined, 0x05, 0x01, 0x00, atyp) - pipelined = append(pipelined, addr...) - pipelined = append(pipelined, port...) - - if _, err := upConn.Write(pipelined); err != nil { - log.Printf("[central] conn=%d: upstream write failed: %v", connID, err) - upConn.Close() - cs.sendFrame(connID, &tunnel.Frame{ConnID: connID, Flags: tunnel.FlagRST | tunnel.FlagReverse}) - cs.removeConn(connID) - return - } - - // Read greeting response (2 bytes) + CONNECT response header (4 bytes) - resp := make([]byte, 6) - if _, err := io.ReadFull(upConn, resp); err != nil { - log.Printf("[central] conn=%d: upstream response read failed: %v", connID, err) - upConn.Close() - cs.sendFrame(connID, &tunnel.Frame{ConnID: connID, Flags: tunnel.FlagRST | tunnel.FlagReverse}) - cs.removeConn(connID) - return - } - - // Check CONNECT result BEFORE draining bind address. - // resp[3] = REP field (0x00 = success). If non-zero, upstream may close - // without sending bind address, so don't try to drain it. - if resp[3] != 0x00 { - log.Printf("[central] conn=%d: upstream CONNECT rejected: 0x%02x", connID, resp[3]) - upConn.Close() - cs.sendFrame(connID, &tunnel.Frame{ConnID: connID, Flags: tunnel.FlagRST | tunnel.FlagReverse}) - cs.removeConn(connID) - return - } - - // Drain bind address (only on success) - switch resp[5] { - case 0x01: - io.ReadFull(upConn, make([]byte, 6)) - case 0x03: - lb := make([]byte, 1) - io.ReadFull(upConn, lb) - io.ReadFull(upConn, make([]byte, int(lb[0])+2)) - case 0x04: - io.ReadFull(upConn, make([]byte, 18)) - default: - io.ReadFull(upConn, make([]byte, 6)) - } - - // Create async write channel and set target under lock - writeCh := make(chan []byte, 256) - - state.mu.Lock() - state.target = upConn - state.writeCh = writeCh - - // Drain any data that arrived before upstream was ready - var flushChunks [][]byte - for { - data := state.reorderer.Next() - if data == nil { - break - } - flushChunks = append(flushChunks, data) - } - state.mu.Unlock() - - // Start async writer goroutine — all writes to upstream go through writeCh - go cs.upstreamWriter(ctx, connID, upConn, writeCh) - - // Send flush data through the channel (channel is empty, won't block) - for _, data := range flushChunks { - select { - case writeCh <- data: - default: - } - } - - log.Printf("[central] conn=%d: upstream connected to %s", connID, targetAddr) - - // Read upstream data and send back through tunnel (NO broadcast — round-robin) - cs.relayUpstreamToTunnel(ctx, connID, state, upConn) -} - -func (cs *centralServer) relayUpstreamToTunnel(ctx context.Context, connID uint32, - state *connState, upstream net.Conn) { - - defer func() { - upstream.Close() - cs.sendFrame(connID, &tunnel.Frame{ - ConnID: connID, - Flags: tunnel.FlagFIN | tunnel.FlagReverse, - }) - cs.removeConn(connID) - }() - - buf := make([]byte, tunnel.MaxPayloadSize) - for { - select { - case <-ctx.Done(): - return - default: - } - - n, err := upstream.Read(buf) - if n > 0 { - state.mu.Lock() - seq := state.txSeq - state.txSeq++ - state.lastActive = time.Now() - state.mu.Unlock() - - frame := &tunnel.Frame{ - ConnID: connID, - SeqNum: seq, - Flags: tunnel.FlagReverse, - Payload: make([]byte, n), - } - copy(frame.Payload, buf[:n]) - - // Send through ONE source (round-robin), NOT all - cs.sendFrame(connID, frame) - } - if err != nil { - if err != io.EOF && !isClosedConnErr(err) { - log.Printf("[central] conn=%d: upstream read error: %v", connID, err) - } - return - } - } -} - -// upstreamWriter is a dedicated goroutine that drains writeCh and writes -// to the upstream (Xray) connection. This decouples upstream write speed -// from frame dispatch speed — handleData never blocks the frame loop. -func (cs *centralServer) upstreamWriter(ctx context.Context, connID uint32, upstream net.Conn, writeCh chan []byte) { - for { - select { - case <-ctx.Done(): - // Context cancelled (removeConn or cleanup) — drain any remaining data best-effort - for { - select { - case data := <-writeCh: - upstream.SetWriteDeadline(time.Now().Add(2 * time.Second)) - upstream.Write(data) - default: - return - } - } - case data, ok := <-writeCh: - if !ok { - return - } - upstream.SetWriteDeadline(time.Now().Add(upstreamWriteTimeout)) - if _, err := upstream.Write(data); err != nil { - log.Printf("[central] conn=%d: upstream write failed: %v", connID, err) - upstream.SetWriteDeadline(time.Time{}) - // Drain channel to unblock any senders, then exit - for { - select { - case <-writeCh: - default: - return - } - } - } - upstream.SetWriteDeadline(time.Time{}) - } - } -} - -// sendFrame picks ONE source via round-robin and writes the frame. -// CRITICAL: state.mu is only held briefly to pick the source and advance -// the index — the actual TCP write happens OUTSIDE the lock to prevent -// cascading lock contention that freezes frame dispatch for other ConnIDs. -func (cs *centralServer) sendFrame(connID uint32, frame *tunnel.Frame) { - cs.mu.RLock() - state, ok := cs.conns[connID] - cs.mu.RUnlock() - if !ok { - return - } - - // Snapshot sources under lock, then write outside lock - state.mu.Lock() - n := len(state.sources) - if n == 0 { - state.mu.Unlock() - return - } - // Build ordered list starting from current sourceIdx - sources := make([]*sourceConn, n) - startIdx := state.sourceIdx % n - state.sourceIdx++ - for i := 0; i < n; i++ { - sources[i] = state.sources[(startIdx+i)%n] - } - state.mu.Unlock() - - // Try each source — write happens outside state.mu - for _, sc := range sources { - if err := sc.WriteFrame(frame); err != nil { - // Remove dead source under lock - state.mu.Lock() - for i, s := range state.sources { - if s == sc { - state.sources = append(state.sources[:i], state.sources[i+1:]...) - break - } - } - state.mu.Unlock() - continue - } - return // success - } - log.Printf("[central] conn=%d: all sources failed", connID) -} - -// upstreamWriteTimeout prevents writes to Xray upstream from blocking forever. -const upstreamWriteTimeout = 10 * time.Second - -func (cs *centralServer) handleData(frame *tunnel.Frame, source *sourceConn) { - cs.mu.RLock() - state, ok := cs.conns[frame.ConnID] - cs.mu.RUnlock() - if !ok { - return - } - - // Insert frame and collect ready data under lock (fast, no I/O) - state.mu.Lock() - - // If this source isn't known yet (e.g., after tunnel recycling), add it. - found := false - for _, s := range state.sources { - if s == source { - found = true - break - } - } - if !found { - state.sources = append(state.sources, source) - } - - state.lastActive = time.Now() - state.reorderer.Insert(frame.SeqNum, frame.Payload) - - // If upstream not connected yet, data stays in reorderer for later flush - writeCh := state.writeCh - if writeCh == nil { - state.mu.Unlock() - return - } - - // Drain all ready data from reorderer - var chunks [][]byte - for { - data := state.reorderer.Next() - if data == nil { - break - } - chunks = append(chunks, data) - } - state.mu.Unlock() - - // Send to async writer (non-blocking — MUST NOT stall the frame dispatch loop) - for _, data := range chunks { - select { - case writeCh <- data: - default: - log.Printf("[central] conn=%d: write queue full, dropping %d bytes", frame.ConnID, len(data)) - } - } -} - -func (cs *centralServer) handleFIN(frame *tunnel.Frame) { - cs.mu.RLock() - state, ok := cs.conns[frame.ConnID] - cs.mu.RUnlock() - if !ok { - return - } - - // Drain remaining data and send to async writer (non-blocking) - state.mu.Lock() - var chunks [][]byte - writeCh := state.writeCh - if writeCh != nil { - for { - data := state.reorderer.Next() - if data == nil { - break - } - chunks = append(chunks, data) - } - } - state.writeCh = nil // prevent further sends from handleData - state.mu.Unlock() - - // Send remaining data to writer (non-blocking, best-effort) - if writeCh != nil { - for _, data := range chunks { - select { - case writeCh <- data: - default: - } - } - } - - // removeConn cancels ctx → upstreamWriter exits → upstream closed by relayUpstreamToTunnel - cs.removeConn(frame.ConnID) - log.Printf("[central] conn=%d: FIN received, cleaned up", frame.ConnID) -} - -func (cs *centralServer) handleRST(frame *tunnel.Frame) { - cs.removeConn(frame.ConnID) - log.Printf("[central] conn=%d: RST received", frame.ConnID) -} - -func (cs *centralServer) removeConn(connID uint32) { - cs.mu.Lock() - state, ok := cs.conns[connID] - if ok { - delete(cs.conns, connID) - } - cs.mu.Unlock() - if ok && state.cancel != nil { - state.cancel() - } -} - -func (cs *centralServer) closeAll() { - cs.mu.Lock() - defer cs.mu.Unlock() - for id, state := range cs.conns { - if state.cancel != nil { - state.cancel() - } - if state.target != nil { - state.target.Close() - } - delete(cs.conns, id) - } -} - -func (cs *centralServer) cleanupLoop() { - ticker := time.NewTicker(15 * time.Second) - defer ticker.Stop() - for range ticker.C { - cs.mu.Lock() - now := time.Now() - cleaned := 0 - for id, state := range cs.conns { - state.mu.Lock() - shouldClean := false - - // No upstream established after 60 seconds = stuck - if state.target == nil && now.Sub(state.created) > 60*time.Second { - shouldClean = true - } - // No sources left = all tunnel connections died - if len(state.sources) == 0 && now.Sub(state.created) > 30*time.Second { - shouldClean = true - } - // Idle for too long — no data sent or received in 60 seconds. - // This catches stuck connections where both sides stopped talking. - if now.Sub(state.lastActive) > 60*time.Second { - shouldClean = true - } - - state.mu.Unlock() - if shouldClean { - if state.cancel != nil { - state.cancel() - } - if state.target != nil { - state.target.Close() - } - delete(cs.conns, id) - cleaned++ - } - } - if cleaned > 0 { - log.Printf("[central] cleanup: removed %d stale connections (%d active)", cleaned, len(cs.conns)) - } - cs.mu.Unlock() - } -} diff --git a/cmd/centralserver/server.go b/cmd/centralserver/server.go new file mode 100644 index 0000000..4cb18a6 --- /dev/null +++ b/cmd/centralserver/server.go @@ -0,0 +1,289 @@ +package main + +import ( + "encoding/binary" + "io" + "log" + "net" + "strings" + "time" + + "github.com/ParsaKSH/SlipStream-Plus/internal/tunnel" +) + +// handleIncoming detects the protocol from the first byte: +// +// 0x05 → SOCKS5 health probe → passthrough to socks-upstream +// else → framing protocol → read frames +func (cs *centralServer) handleIncoming(conn net.Conn) { + defer conn.Close() + remoteAddr := conn.RemoteAddr().String() + + firstByte := make([]byte, 1) + conn.SetReadDeadline(time.Now().Add(10 * time.Second)) + if _, err := io.ReadFull(conn, firstByte); err != nil { + return + } + conn.SetReadDeadline(time.Time{}) + + if firstByte[0] == 0x05 { + cs.handleSOCKS5Passthrough(conn, firstByte[0], remoteAddr) + } else { + cs.handleFrameConn(conn, firstByte[0], remoteAddr) + } +} + +// handleSOCKS5Passthrough transparently proxies a SOCKS5 connection. +func (cs *centralServer) handleSOCKS5Passthrough(clientConn net.Conn, firstByte byte, remoteAddr string) { + upstream, err := net.DialTimeout("tcp", cs.socksUpstream, 10*time.Second) + if err != nil { + return + } + defer upstream.Close() + if tc, ok := upstream.(*net.TCPConn); ok { + tc.SetNoDelay(true) + } + upstream.Write([]byte{firstByte}) + + done := make(chan struct{}, 2) + go func() { + io.Copy(upstream, clientConn) + if tc, ok := upstream.(*net.TCPConn); ok { + tc.CloseWrite() + } + done <- struct{}{} + }() + go func() { + io.Copy(clientConn, upstream) + if tc, ok := clientConn.(*net.TCPConn); ok { + tc.CloseWrite() + } + done <- struct{}{} + }() + <-done +} + +// handleFrameConn reads framed packets from a tunnel connection. +// When this function returns (source TCP died), it cleans up all +// connStates that had this as their only source. +func (cs *centralServer) handleFrameConn(conn net.Conn, firstByte byte, remoteAddr string) { + log.Printf("[central] frame connection from %s", remoteAddr) + + sc := cs.getSourceConn(conn) + + // Track which ConnIDs this source served + servedIDs := make(map[uint32]bool) + + defer func() { + // Source TCP died — clean up sourceConn and connStates + cs.removeSourceConn(conn) + cs.cleanupSource(sc, servedIDs, remoteAddr) + }() + + // Read remaining header bytes (we already read 1) + var hdrRest [tunnel.HeaderSize - 1]byte + if _, err := io.ReadFull(conn, hdrRest[:]); err != nil { + return + } + + var fullHdr [tunnel.HeaderSize]byte + fullHdr[0] = firstByte + copy(fullHdr[1:], hdrRest[:]) + + firstFrame := cs.parseHeader(fullHdr, conn, remoteAddr) + if firstFrame != nil { + servedIDs[firstFrame.ConnID] = true + cs.dispatchFrame(firstFrame, sc) + } + + for { + frame, err := tunnel.ReadFrame(conn) + if err != nil { + if err != io.EOF && !isClosedConnErr(err) { + log.Printf("[central] %s: read error: %v", remoteAddr, err) + } + return + } + servedIDs[frame.ConnID] = true + cs.dispatchFrame(frame, sc) + } +} + +// cleanupSource removes a dead source connection from all connStates. +// If a connState has no remaining sources, it is fully cleaned up. +func (cs *centralServer) cleanupSource(deadSource *sourceConn, servedIDs map[uint32]bool, remoteAddr string) { + cs.mu.Lock() + defer cs.mu.Unlock() + + cleaned := 0 + for connID := range servedIDs { + state, ok := cs.conns[connID] + if !ok { + continue + } + + state.mu.Lock() + // Remove dead source from sources list + for i, src := range state.sources { + if src == deadSource { + state.sources = append(state.sources[:i], state.sources[i+1:]...) + break + } + } + + // If no sources left, fully clean up this connState + if len(state.sources) == 0 { + state.mu.Unlock() + if state.cancel != nil { + state.cancel() + } + if state.target != nil { + state.target.Close() + } + delete(cs.conns, connID) + cleaned++ + } else { + state.mu.Unlock() + } + } + + if cleaned > 0 { + log.Printf("[central] %s: source disconnected, cleaned %d orphaned connections", remoteAddr, cleaned) + } +} + +// getSourceConn returns the sourceConn wrapper for a raw net.Conn, +// creating one if it doesn't exist yet. +func (cs *centralServer) getSourceConn(conn net.Conn) *sourceConn { + cs.sourceMu.Lock() + defer cs.sourceMu.Unlock() + sc, ok := cs.sourceMap[conn] + if !ok { + sc = &sourceConn{conn: conn} + cs.sourceMap[conn] = sc + } + return sc +} + +// removeSourceConn removes the sourceConn wrapper when the raw conn dies. +func (cs *centralServer) removeSourceConn(conn net.Conn) { + cs.sourceMu.Lock() + delete(cs.sourceMap, conn) + cs.sourceMu.Unlock() +} + +func isClosedConnErr(err error) bool { + if err == nil { + return false + } + s := err.Error() + return strings.Contains(s, "use of closed network connection") || + strings.Contains(s, "connection reset by peer") +} + +func (cs *centralServer) parseHeader(hdr [tunnel.HeaderSize]byte, conn net.Conn, remoteAddr string) *tunnel.Frame { + length := binary.BigEndian.Uint16(hdr[9:11]) + if length > tunnel.MaxPayloadSize { + log.Printf("[central] %s: frame payload too large: %d", remoteAddr, length) + return nil + } + var payload []byte + if length > 0 { + payload = make([]byte, length) + if _, err := io.ReadFull(conn, payload); err != nil { + return nil + } + } + return &tunnel.Frame{ + ConnID: binary.BigEndian.Uint32(hdr[0:4]), + SeqNum: binary.BigEndian.Uint32(hdr[4:8]), + Flags: hdr[8], + Payload: payload, + } +} + +func (cs *centralServer) dispatchFrame(frame *tunnel.Frame, source *sourceConn) { + if frame.IsSYN() { + cs.handleSYN(frame, source) + return + } + if frame.IsFIN() { + cs.handleFIN(frame) + return + } + if frame.IsRST() { + cs.handleRST(frame) + return + } + cs.handleData(frame, source) +} + +func (cs *centralServer) cleanupLoop() { + ticker := time.NewTicker(15 * time.Second) + defer ticker.Stop() + for range ticker.C { + cs.mu.Lock() + now := time.Now() + cleaned := 0 + for id, state := range cs.conns { + state.mu.Lock() + shouldClean := false + + // No upstream established after 60 seconds = stuck + if state.target == nil && now.Sub(state.created) > 60*time.Second { + shouldClean = true + } + // No sources left = all tunnel connections died + if len(state.sources) == 0 && now.Sub(state.created) > 30*time.Second { + shouldClean = true + } + // Idle for too long — no data sent or received in 60 seconds. + // This catches stuck connections where both sides stopped talking. + if now.Sub(state.lastActive) > 60*time.Second { + shouldClean = true + } + + state.mu.Unlock() + if shouldClean { + if state.cancel != nil { + state.cancel() + } + if state.target != nil { + state.target.Close() + } + delete(cs.conns, id) + cleaned++ + } + } + if cleaned > 0 { + log.Printf("[central] cleanup: removed %d stale connections (%d active)", cleaned, len(cs.conns)) + } + cs.mu.Unlock() + } +} + +func (cs *centralServer) removeConn(connID uint32) { + cs.mu.Lock() + state, ok := cs.conns[connID] + if ok { + delete(cs.conns, connID) + } + cs.mu.Unlock() + if ok && state.cancel != nil { + state.cancel() + } +} + +func (cs *centralServer) closeAll() { + cs.mu.Lock() + defer cs.mu.Unlock() + for id, state := range cs.conns { + if state.cancel != nil { + state.cancel() + } + if state.target != nil { + state.target.Close() + } + delete(cs.conns, id) + } +} diff --git a/cmd/centralserver/state.go b/cmd/centralserver/state.go new file mode 100644 index 0000000..94370bb --- /dev/null +++ b/cmd/centralserver/state.go @@ -0,0 +1,114 @@ +package main + +import ( + "context" + "log" + "net" + "sync" + "time" + + "github.com/ParsaKSH/SlipStream-Plus/internal/tunnel" +) + +// t +// sourceConn wraps a tunnel connection with a write mutex so that +// concurrent WriteFrame calls from different ConnIDs don't interleave bytes. +type sourceConn struct { + conn net.Conn + writeMu sync.Mutex +} + +// writeTimeout prevents tunnel writes from blocking forever on congested TCP. +const sourceWriteTimeout = 10 * time.Second + +func (sc *sourceConn) WriteFrame(f *tunnel.Frame) error { + sc.writeMu.Lock() + defer sc.writeMu.Unlock() + sc.conn.SetWriteDeadline(time.Now().Add(sourceWriteTimeout)) + err := tunnel.WriteFrame(sc.conn, f) + sc.conn.SetWriteDeadline(time.Time{}) + return err +} + +// connState tracks a single reassembled connection. +type connState struct { + mu sync.Mutex + target net.Conn // connection to the SOCKS upstream + reorderer *tunnel.Reorderer + txSeq uint32 // next sequence number for reverse data + cancel context.CancelFunc + created time.Time + lastActive time.Time // last time data was sent or received + + // Sources: all tunnel connections that can carry reverse data. + // We round-robin responses across them (not broadcast). + sources []*sourceConn + sourceIdx int + + // Async write queue: handleData sends chunks here instead of writing + // to target synchronously (which would block the frame dispatch loop). + // The upstreamWriter goroutine drains this channel and writes to target. + writeCh chan []byte +} + +// centralServer manages all active connections. +type centralServer struct { + socksUpstream string + + mu sync.RWMutex + conns map[uint32]*connState // ConnID → state + + // sourceMu protects the sources map (net.Conn → *sourceConn). + sourceMu sync.Mutex + sourceMap map[net.Conn]*sourceConn +} + +// upstreamWriteTimeout prevents writes to Xray upstream from blocking forever. +const upstreamWriteTimeout = 10 * time.Second + +// sendFrame picks ONE source via round-robin and writes the frame. +// CRITICAL: state.mu is only held briefly to pick the source and advance +// the index — the actual TCP write happens OUTSIDE the lock to prevent +// cascading lock contention that freezes frame dispatch for other ConnIDs. +func (cs *centralServer) sendFrame(connID uint32, frame *tunnel.Frame) { + cs.mu.RLock() + state, ok := cs.conns[connID] + cs.mu.RUnlock() + if !ok { + return + } + + // Snapshot sources under lock, then write outside lock + state.mu.Lock() + n := len(state.sources) + if n == 0 { + state.mu.Unlock() + return + } + // Build ordered list starting from current sourceIdx + sources := make([]*sourceConn, n) + startIdx := state.sourceIdx % n + state.sourceIdx++ + for i := 0; i < n; i++ { + sources[i] = state.sources[(startIdx+i)%n] + } + state.mu.Unlock() + + // Try each source — write happens outside state.mu + for _, sc := range sources { + if err := sc.WriteFrame(frame); err != nil { + // Remove dead source under lock + state.mu.Lock() + for i, s := range state.sources { + if s == sc { + state.sources = append(state.sources[:i], state.sources[i+1:]...) + break + } + } + state.mu.Unlock() + continue + } + return // success + } + log.Printf("[central] conn=%d: all sources failed", connID) +} diff --git a/cmd/centralserver/upstream.go b/cmd/centralserver/upstream.go new file mode 100644 index 0000000..92458fe --- /dev/null +++ b/cmd/centralserver/upstream.go @@ -0,0 +1,346 @@ +package main + +import ( + "context" + "encoding/binary" + "fmt" + "io" + "log" + "net" + "time" + + "github.com/ParsaKSH/SlipStream-Plus/internal/tunnel" +) + +func (cs *centralServer) handleSYN(frame *tunnel.Frame, source *sourceConn) { + connID := frame.ConnID + + cs.mu.Lock() + if existing, ok := cs.conns[connID]; ok { + // Another instance's SYN → just register additional source + existing.mu.Lock() + existing.sources = append(existing.sources, source) + existing.mu.Unlock() + cs.mu.Unlock() + return + } + + atyp, addr, port, err := tunnel.DecodeSYNPayload(frame.Payload) + if err != nil { + cs.mu.Unlock() + log.Printf("[central] conn=%d: bad SYN payload: %v", connID, err) + return + } + + var targetAddr string + switch atyp { + case 0x01: + targetAddr = fmt.Sprintf("%s:%d", net.IP(addr).String(), binary.BigEndian.Uint16(port)) + case 0x03: + domLen := int(addr[0]) + targetAddr = fmt.Sprintf("%s:%d", string(addr[1:1+domLen]), binary.BigEndian.Uint16(port)) + case 0x04: + targetAddr = fmt.Sprintf("[%s]:%d", net.IP(addr).String(), binary.BigEndian.Uint16(port)) + } + + now := time.Now() + ctx, cancel := context.WithCancel(context.Background()) + state := &connState{ + reorderer: tunnel.NewReordererAt(frame.SeqNum + 1), // skip SYN's SeqNum + sources: []*sourceConn{source}, + cancel: cancel, + created: now, + lastActive: now, + } + cs.conns[connID] = state + cs.mu.Unlock() + + log.Printf("[central] conn=%d: SYN → target=%s", connID, targetAddr) + go cs.connectUpstream(ctx, connID, state, atyp, addr, port, targetAddr) +} + +func (cs *centralServer) connectUpstream(ctx context.Context, connID uint32, state *connState, + atyp byte, addr, port []byte, targetAddr string) { + + upConn, err := net.DialTimeout("tcp", cs.socksUpstream, 10*time.Second) + if err != nil { + log.Printf("[central] conn=%d: upstream dial failed: %v", connID, err) + cs.sendFrame(connID, &tunnel.Frame{ConnID: connID, Flags: tunnel.FlagRST | tunnel.FlagReverse}) + cs.removeConn(connID) + return + } + if tc, ok := upConn.(*net.TCPConn); ok { + tc.SetKeepAlive(true) + tc.SetKeepAlivePeriod(30 * time.Second) + tc.SetNoDelay(true) + } + + // SOCKS5 handshake + pipelined := make([]byte, 0, 3+4+len(addr)+2) + pipelined = append(pipelined, 0x05, 0x01, 0x00) + pipelined = append(pipelined, 0x05, 0x01, 0x00, atyp) + pipelined = append(pipelined, addr...) + pipelined = append(pipelined, port...) + + if _, err := upConn.Write(pipelined); err != nil { + log.Printf("[central] conn=%d: upstream write failed: %v", connID, err) + upConn.Close() + cs.sendFrame(connID, &tunnel.Frame{ConnID: connID, Flags: tunnel.FlagRST | tunnel.FlagReverse}) + cs.removeConn(connID) + return + } + + // Read greeting response (2 bytes) + CONNECT response header (4 bytes) + resp := make([]byte, 6) + if _, err := io.ReadFull(upConn, resp); err != nil { + log.Printf("[central] conn=%d: upstream response read failed: %v", connID, err) + upConn.Close() + cs.sendFrame(connID, &tunnel.Frame{ConnID: connID, Flags: tunnel.FlagRST | tunnel.FlagReverse}) + cs.removeConn(connID) + return + } + + // Check CONNECT result BEFORE draining bind address. + // resp[3] = REP field (0x00 = success). If non-zero, upstream may close + // without sending bind address, so don't try to drain it. + if resp[3] != 0x00 { + log.Printf("[central] conn=%d: upstream CONNECT rejected: 0x%02x", connID, resp[3]) + upConn.Close() + cs.sendFrame(connID, &tunnel.Frame{ConnID: connID, Flags: tunnel.FlagRST | tunnel.FlagReverse}) + cs.removeConn(connID) + return + } + + // Drain bind address (only on success) + switch resp[5] { + case 0x01: + io.ReadFull(upConn, make([]byte, 6)) + case 0x03: + lb := make([]byte, 1) + io.ReadFull(upConn, lb) + io.ReadFull(upConn, make([]byte, int(lb[0])+2)) + case 0x04: + io.ReadFull(upConn, make([]byte, 18)) + default: + io.ReadFull(upConn, make([]byte, 6)) + } + + // Create async write channel and set target under lock + writeCh := make(chan []byte, 256) + + state.mu.Lock() + state.target = upConn + state.writeCh = writeCh + + // Drain any data that arrived before upstream was ready + var flushChunks [][]byte + for { + data := state.reorderer.Next() + if data == nil { + break + } + flushChunks = append(flushChunks, data) + } + state.mu.Unlock() + + // Start async writer goroutine — all writes to upstream go through writeCh + go cs.upstreamWriter(ctx, connID, upConn, writeCh) + + // Send flush data through the channel (channel is empty, won't block) + for _, data := range flushChunks { + select { + case writeCh <- data: + default: + } + } + + log.Printf("[central] conn=%d: upstream connected to %s", connID, targetAddr) + + // Read upstream data and send back through tunnel (NO broadcast — round-robin) + cs.relayUpstreamToTunnel(ctx, connID, state, upConn) +} + +func (cs *centralServer) relayUpstreamToTunnel(ctx context.Context, connID uint32, + state *connState, upstream net.Conn) { + + defer func() { + upstream.Close() + cs.sendFrame(connID, &tunnel.Frame{ + ConnID: connID, + Flags: tunnel.FlagFIN | tunnel.FlagReverse, + }) + cs.removeConn(connID) + }() + + buf := make([]byte, tunnel.MaxPayloadSize) + for { + select { + case <-ctx.Done(): + return + default: + } + + n, err := upstream.Read(buf) + if n > 0 { + state.mu.Lock() + seq := state.txSeq + state.txSeq++ + state.lastActive = time.Now() + state.mu.Unlock() + + frame := &tunnel.Frame{ + ConnID: connID, + SeqNum: seq, + Flags: tunnel.FlagReverse, + Payload: make([]byte, n), + } + copy(frame.Payload, buf[:n]) + + // Send through ONE source (round-robin), NOT all + cs.sendFrame(connID, frame) + } + if err != nil { + if err != io.EOF && !isClosedConnErr(err) { + log.Printf("[central] conn=%d: upstream read error: %v", connID, err) + } + return + } + } +} + +// upstreamWriter is a dedicated goroutine that drains writeCh and writes +// to the upstream (Xray) connection. This decouples upstream write speed +// from frame dispatch speed — handleData never blocks the frame loop. +func (cs *centralServer) upstreamWriter(ctx context.Context, connID uint32, upstream net.Conn, writeCh chan []byte) { + for { + select { + case <-ctx.Done(): + // Context cancelled (removeConn or cleanup) — drain any remaining data best-effort + for { + select { + case data := <-writeCh: + upstream.SetWriteDeadline(time.Now().Add(2 * time.Second)) + upstream.Write(data) + default: + return + } + } + case data, ok := <-writeCh: + if !ok { + return + } + upstream.SetWriteDeadline(time.Now().Add(upstreamWriteTimeout)) + if _, err := upstream.Write(data); err != nil { + log.Printf("[central] conn=%d: upstream write failed: %v", connID, err) + upstream.SetWriteDeadline(time.Time{}) + // Drain channel to unblock any senders, then exit + for { + select { + case <-writeCh: + default: + return + } + } + } + upstream.SetWriteDeadline(time.Time{}) + } + } +} + +func (cs *centralServer) handleData(frame *tunnel.Frame, source *sourceConn) { + cs.mu.RLock() + state, ok := cs.conns[frame.ConnID] + cs.mu.RUnlock() + if !ok { + return + } + + // Insert frame and collect ready data under lock (fast, no I/O) + state.mu.Lock() + + // If this source isn't known yet (e.g., after tunnel recycling), add it. + found := false + for _, s := range state.sources { + if s == source { + found = true + break + } + } + if !found { + state.sources = append(state.sources, source) + } + + state.lastActive = time.Now() + state.reorderer.Insert(frame.SeqNum, frame.Payload) + + // If upstream not connected yet, data stays in reorderer for later flush + writeCh := state.writeCh + if writeCh == nil { + state.mu.Unlock() + return + } + + // Drain all ready data from reorderer + var chunks [][]byte + for { + data := state.reorderer.Next() + if data == nil { + break + } + chunks = append(chunks, data) + } + state.mu.Unlock() + + // Send to async writer (non-blocking — MUST NOT stall the frame dispatch loop) + for _, data := range chunks { + select { + case writeCh <- data: + default: + log.Printf("[central] conn=%d: write queue full, dropping %d bytes", frame.ConnID, len(data)) + } + } +} + +func (cs *centralServer) handleFIN(frame *tunnel.Frame) { + cs.mu.RLock() + state, ok := cs.conns[frame.ConnID] + cs.mu.RUnlock() + if !ok { + return + } + + // Drain remaining data and send to async writer (non-blocking) + state.mu.Lock() + var chunks [][]byte + writeCh := state.writeCh + if writeCh != nil { + for { + data := state.reorderer.Next() + if data == nil { + break + } + chunks = append(chunks, data) + } + } + state.writeCh = nil // prevent further sends from handleData + state.mu.Unlock() + + // Send remaining data to writer (non-blocking, best-effort) + if writeCh != nil { + for _, data := range chunks { + select { + case writeCh <- data: + default: + } + } + } + + // removeConn cancels ctx → upstreamWriter exits → upstream closed by relayUpstreamToTunnel + cs.removeConn(frame.ConnID) + log.Printf("[central] conn=%d: FIN received, cleaned up", frame.ConnID) +} + +func (cs *centralServer) handleRST(frame *tunnel.Frame) { + cs.removeConn(frame.ConnID) + log.Printf("[central] conn=%d: RST received", frame.ConnID) +} diff --git a/internal/config/config.go b/internal/config/config.go index 81426ee..fd60eef 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -4,8 +4,6 @@ import ( "encoding/json" "fmt" "os" - "strconv" - "strings" "time" ) @@ -96,129 +94,6 @@ type InstanceConfig struct { SSHKey string `json:"ssh_key,omitempty"` } -func (ic *InstanceConfig) ParsePorts() ([]int, error) { - raw := strings.TrimSpace(string(ic.Port)) - if port, err := strconv.Atoi(raw); err == nil { - return []int{port}, nil - } - var portStr string - if err := json.Unmarshal(ic.Port, &portStr); err == nil { - return parsePortRange(portStr) - } - var portNum int - if err := json.Unmarshal(ic.Port, &portNum); err == nil { - return []int{portNum}, nil - } - return nil, fmt.Errorf("invalid port value: %s", raw) -} - -func parsePortRange(s string) ([]int, error) { - s = strings.TrimSpace(s) - if port, err := strconv.Atoi(s); err == nil { - return []int{port}, nil - } - parts := strings.SplitN(s, "-", 2) - if len(parts) != 2 { - return nil, fmt.Errorf("invalid port range: %s", s) - } - start, err := strconv.Atoi(strings.TrimSpace(parts[0])) - if err != nil { - return nil, fmt.Errorf("invalid port range start: %s", parts[0]) - } - end, err := strconv.Atoi(strings.TrimSpace(parts[1])) - if err != nil { - return nil, fmt.Errorf("invalid port range end: %s", parts[1]) - } - if start > end { - return nil, fmt.Errorf("port range start (%d) must be <= end (%d)", start, end) - } - if start <= 0 || end > 65535 { - return nil, fmt.Errorf("ports must be between 1 and 65535") - } - ports := make([]int, 0, end-start+1) - for p := start; p <= end; p++ { - ports = append(ports, p) - } - return ports, nil -} - -type ExpandedInstance struct { - Domain string - Resolver string - Port int - Mode string // "socks" or "ssh" (per-instance) - Authoritative bool - Cert string - SSHPort int - SSHUser string - SSHPassword string - SSHKey string - OriginalIndex int - ReplicaIndex int -} - -func (c *Config) ExpandInstances() ([]ExpandedInstance, error) { - var result []ExpandedInstance - - for i, inst := range c.Instances { - replicas := inst.Replicas - if replicas <= 0 { - replicas = 1 - } - - ports, err := inst.ParsePorts() - if err != nil { - return nil, fmt.Errorf("instances[%d]: %w", i, err) - } - - if len(ports) == 1 && replicas > 1 { - basePort := ports[0] - ports = make([]int, replicas) - for r := 0; r < replicas; r++ { - ports[r] = basePort + r - } - } - - if len(ports) < replicas { - return nil, fmt.Errorf("instances[%d]: port range provides %d ports but replicas=%d", - i, len(ports), replicas) - } - - // Resolve mode: per-instance overrides default "socks" - mode := inst.Mode - if mode == "" { - mode = "socks" - } - - for r := 0; r < replicas; r++ { - result = append(result, ExpandedInstance{ - Domain: inst.Domain, - Resolver: inst.Resolver, - Port: ports[r], - Mode: mode, - Authoritative: inst.Authoritative, - Cert: inst.Cert, - SSHPort: inst.SSHPort, - SSHUser: inst.SSHUser, - SSHPassword: inst.SSHPassword, - SSHKey: inst.SSHKey, - OriginalIndex: i, - ReplicaIndex: r, - }) - } - } - - portSet := make(map[int]bool) - for _, ei := range result { - if portSet[ei.Port] { - return nil, fmt.Errorf("duplicate port %d after expansion", ei.Port) - } - portSet[ei.Port] = true - } - - return result, nil -} - func (c *HealthCheckConfig) IntervalDuration() time.Duration { d, err := time.ParseDuration(c.Interval) if err != nil { diff --git a/internal/config/expand.go b/internal/config/expand.go new file mode 100644 index 0000000..1c50d37 --- /dev/null +++ b/internal/config/expand.go @@ -0,0 +1,131 @@ +package config + +import ( + "encoding/json" + "fmt" + "strconv" + "strings" +) + +type ExpandedInstance struct { + Domain string + Resolver string + Port int + Mode string // "socks" or "ssh" (per-instance) + Authoritative bool + Cert string + SSHPort int + SSHUser string + SSHPassword string + SSHKey string + OriginalIndex int + ReplicaIndex int +} + +func (ic *InstanceConfig) ParsePorts() ([]int, error) { + raw := strings.TrimSpace(string(ic.Port)) + if port, err := strconv.Atoi(raw); err == nil { + return []int{port}, nil + } + var portStr string + if err := json.Unmarshal(ic.Port, &portStr); err == nil { + return parsePortRange(portStr) + } + var portNum int + if err := json.Unmarshal(ic.Port, &portNum); err == nil { + return []int{portNum}, nil + } + return nil, fmt.Errorf("invalid port value: %s", raw) +} + +func parsePortRange(s string) ([]int, error) { + s = strings.TrimSpace(s) + if port, err := strconv.Atoi(s); err == nil { + return []int{port}, nil + } + parts := strings.SplitN(s, "-", 2) + if len(parts) != 2 { + return nil, fmt.Errorf("invalid port range: %s", s) + } + start, err := strconv.Atoi(strings.TrimSpace(parts[0])) + if err != nil { + return nil, fmt.Errorf("invalid port range start: %s", parts[0]) + } + end, err := strconv.Atoi(strings.TrimSpace(parts[1])) + if err != nil { + return nil, fmt.Errorf("invalid port range end: %s", parts[1]) + } + if start > end { + return nil, fmt.Errorf("port range start (%d) must be <= end (%d)", start, end) + } + if start <= 0 || end > 65535 { + return nil, fmt.Errorf("ports must be between 1 and 65535") + } + ports := make([]int, 0, end-start+1) + for p := start; p <= end; p++ { + ports = append(ports, p) + } + return ports, nil +} + +func (c *Config) ExpandInstances() ([]ExpandedInstance, error) { + var result []ExpandedInstance + + for i, inst := range c.Instances { + replicas := inst.Replicas + if replicas <= 0 { + replicas = 1 + } + + ports, err := inst.ParsePorts() + if err != nil { + return nil, fmt.Errorf("instances[%d]: %w", i, err) + } + + if len(ports) == 1 && replicas > 1 { + basePort := ports[0] + ports = make([]int, replicas) + for r := 0; r < replicas; r++ { + ports[r] = basePort + r + } + } + + if len(ports) < replicas { + return nil, fmt.Errorf("instances[%d]: port range provides %d ports but replicas=%d", + i, len(ports), replicas) + } + + // Resolve mode: per-instance overrides default "socks" + mode := inst.Mode + if mode == "" { + mode = "socks" + } + + for r := 0; r < replicas; r++ { + result = append(result, ExpandedInstance{ + Domain: inst.Domain, + Resolver: inst.Resolver, + Port: ports[r], + Mode: mode, + Authoritative: inst.Authoritative, + Cert: inst.Cert, + SSHPort: inst.SSHPort, + SSHUser: inst.SSHUser, + SSHPassword: inst.SSHPassword, + SSHKey: inst.SSHKey, + OriginalIndex: i, + ReplicaIndex: r, + }) + } + } + + portSet := make(map[int]bool) + for _, ei := range result { + if portSet[ei.Port] { + return nil, fmt.Errorf("duplicate port %d after expansion", ei.Port) + } + portSet[ei.Port] = true + } + + return result, nil +} diff --git a/internal/gui/handlers.go b/internal/gui/handlers.go new file mode 100644 index 0000000..c58ec6b --- /dev/null +++ b/internal/gui/handlers.go @@ -0,0 +1,319 @@ +package gui + +import ( + "encoding/json" + "fmt" + "log" + "net/http" + "os" + "strconv" + "strings" + "syscall" + "time" + + "github.com/ParsaKSH/SlipStream-Plus/internal/config" + "github.com/ParsaKSH/SlipStream-Plus/internal/health" + "github.com/ParsaKSH/SlipStream-Plus/internal/users" +) + +func (s *APIServer) handleStatus(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Access-Control-Allow-Origin", "*") + status := s.manager.StatusAll() + json.NewEncoder(w).Encode(map[string]any{"instances": status, "strategy": s.cfg.Strategy, "socks": s.cfg.Socks.Listen}) +} + +func (s *APIServer) handleBandwidth(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Access-Control-Allow-Origin", "*") + s.bwMu.RLock() + data := make([]bwPoint, len(s.bwHistory)) + copy(data, s.bwHistory) + s.bwMu.RUnlock() + json.NewEncoder(w).Encode(data) +} + +func (s *APIServer) handleConfig(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type") + if r.Method == "OPTIONS" { + return + } + if r.Method == "GET" { + json.NewEncoder(w).Encode(s.cfg) + return + } + if r.Method == "POST" { + var newCfg config.Config + if err := json.NewDecoder(r.Body).Decode(&newCfg); err != nil { + http.Error(w, fmt.Sprintf("invalid JSON: %v", err), http.StatusBadRequest) + return + } + if err := newCfg.Validate(); err != nil { + http.Error(w, fmt.Sprintf("validation: %v", err), http.StatusBadRequest) + return + } + if err := newCfg.Save(s.configPath); err != nil { + http.Error(w, fmt.Sprintf("save: %v", err), http.StatusInternalServerError) + return + } + *s.cfg = newCfg + json.NewEncoder(w).Encode(map[string]string{"status": "saved"}) + return + } + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) +} + +func (s *APIServer) handleReload(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type") + if r.Method == "OPTIONS" { + return + } + if r.Method != "POST" { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + newCfg, err := config.Load(s.configPath) + if err != nil { + http.Error(w, fmt.Sprintf("load config: %v", err), http.StatusBadRequest) + return + } + if newCfg.SlipstreamBinary == "" { + newCfg.SlipstreamBinary = s.cfg.SlipstreamBinary + } + + // Stop health checker before reload to prevent race conditions + if s.checker != nil { + s.checker.Stop() + } + + if err := s.manager.Reload(newCfg); err != nil { + http.Error(w, fmt.Sprintf("reload: %v", err), http.StatusInternalServerError) + return + } + *s.cfg = *newCfg + + // Restart health checker with fresh state + if s.checker != nil { + s.checker = health.NewChecker(s.manager, &s.cfg.HealthCheck) + s.checker.Start() + } + + // Reload user manager if users changed + if s.userMgr != nil && len(s.cfg.Socks.Users) > 0 { + s.userMgr = users.NewManager(s.cfg.Socks.Users) + } + + log.Printf("[gui] config reloaded and instances restarted") + json.NewEncoder(w).Encode(map[string]string{"status": "reloaded"}) +} + +func (s *APIServer) handleRestart(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type") + if r.Method == "OPTIONS" { + return + } + if r.Method != "POST" { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + json.NewEncoder(w).Encode(map[string]string{"status": "restarting"}) + + // Schedule restart after response is sent + go func() { + time.Sleep(500 * time.Millisecond) + log.Printf("[gui] full restart requested, shutting down before re-exec...") + + // Stop health checker first + if s.checker != nil { + s.checker.Stop() + } + + // Shutdown manager — this kills all child processes and frees ports + s.manager.Shutdown() + + time.Sleep(200 * time.Millisecond) // brief pause for port release + + exe, err := os.Executable() + if err != nil { + log.Printf("[gui] restart failed: %v", err) + return + } + syscall.Exec(exe, os.Args, os.Environ()) + }() +} + +func (s *APIServer) handleInstance(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type") + if r.Method == "OPTIONS" { + return + } + parts := strings.Split(strings.Trim(r.URL.Path, "/"), "/") + if len(parts) < 4 || parts[3] != "restart" { + http.Error(w, "not found", http.StatusNotFound) + return + } + id, err := strconv.Atoi(parts[2]) + if err != nil { + http.Error(w, "invalid id", http.StatusBadRequest) + return + } + if r.Method != "POST" { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + if err := s.manager.RestartInstance(id); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + json.NewEncoder(w).Encode(map[string]string{"status": "restarting"}) +} + +func (s *APIServer) handleUsers(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type") + if r.Method == "OPTIONS" { + return + } + if r.Method == "GET" { + if s.userMgr == nil { + json.NewEncoder(w).Encode([]any{}) + return + } + allUsers := s.userMgr.AllUsers() + result := make([]users.UserStatus, len(allUsers)) + for i, u := range allUsers { + result[i] = u.Status() + } + json.NewEncoder(w).Encode(result) + return + } + if r.Method == "POST" { + // Add new user + var uc config.UserConfig + if err := json.NewDecoder(r.Body).Decode(&uc); err != nil { + http.Error(w, "invalid JSON", http.StatusBadRequest) + return + } + if uc.Username == "" || uc.Password == "" { + http.Error(w, "username and password required", http.StatusBadRequest) + return + } + // Check duplicate + for _, u := range s.cfg.Socks.Users { + if u.Username == uc.Username { + http.Error(w, "user already exists", http.StatusConflict) + return + } + } + s.cfg.Socks.Users = append(s.cfg.Socks.Users, uc) + if err := s.cfg.Save(s.configPath); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + s.userMgr = users.NewManager(s.cfg.Socks.Users) + json.NewEncoder(w).Encode(map[string]string{"status": "added"}) + return + } + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) +} + +func (s *APIServer) handleUserAction(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "POST, DELETE, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type") + if r.Method == "OPTIONS" { + return + } + + parts := strings.Split(strings.Trim(r.URL.Path, "/"), "/") + if len(parts) < 4 { + http.Error(w, "not found", http.StatusNotFound) + return + } + username := parts[2] + action := parts[3] + + switch action { + case "reset": + if s.userMgr == nil { + http.Error(w, "no users", http.StatusNotFound) + return + } + user := s.userMgr.GetUser(username) + if user == nil { + http.Error(w, "not found", http.StatusNotFound) + return + } + user.ResetUsedBytes() + json.NewEncoder(w).Encode(map[string]string{"status": "reset"}) + + case "edit": + var uc config.UserConfig + if err := json.NewDecoder(r.Body).Decode(&uc); err != nil { + http.Error(w, "invalid JSON", http.StatusBadRequest) + return + } + found := false + for i, u := range s.cfg.Socks.Users { + if u.Username == username { + uc.Username = username // can't change username + s.cfg.Socks.Users[i] = uc + found = true + break + } + } + if !found { + http.Error(w, "not found", http.StatusNotFound) + return + } + if err := s.cfg.Save(s.configPath); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + s.userMgr = users.NewManager(s.cfg.Socks.Users) + json.NewEncoder(w).Encode(map[string]string{"status": "updated"}) + + case "delete": + newUsers := make([]config.UserConfig, 0) + for _, u := range s.cfg.Socks.Users { + if u.Username != username { + newUsers = append(newUsers, u) + } + } + s.cfg.Socks.Users = newUsers + if err := s.cfg.Save(s.configPath); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if len(s.cfg.Socks.Users) > 0 { + s.userMgr = users.NewManager(s.cfg.Socks.Users) + } else { + s.userMgr = nil + } + json.NewEncoder(w).Encode(map[string]string{"status": "deleted"}) + + default: + http.Error(w, "not found", http.StatusNotFound) + } +} + +func (s *APIServer) handleDashboard(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.Write([]byte(dashboardHTML)) +} diff --git a/internal/gui/server.go b/internal/gui/server.go index 862e961..4d6fe51 100644 --- a/internal/gui/server.go +++ b/internal/gui/server.go @@ -1,15 +1,9 @@ package gui import ( - "encoding/json" - "fmt" "log" "net/http" - "os" - "strconv" - "strings" "sync" - "syscall" "time" "github.com/ParsaKSH/SlipStream-Plus/internal/config" @@ -112,305 +106,3 @@ func (s *APIServer) Start() error { }() return nil } - -func (s *APIServer) handleStatus(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Access-Control-Allow-Origin", "*") - status := s.manager.StatusAll() - json.NewEncoder(w).Encode(map[string]any{"instances": status, "strategy": s.cfg.Strategy, "socks": s.cfg.Socks.Listen}) -} - -func (s *APIServer) handleBandwidth(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Access-Control-Allow-Origin", "*") - s.bwMu.RLock() - data := make([]bwPoint, len(s.bwHistory)) - copy(data, s.bwHistory) - s.bwMu.RUnlock() - json.NewEncoder(w).Encode(data) -} - -func (s *APIServer) handleConfig(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS") - w.Header().Set("Access-Control-Allow-Headers", "Content-Type") - if r.Method == "OPTIONS" { - return - } - if r.Method == "GET" { - json.NewEncoder(w).Encode(s.cfg) - return - } - if r.Method == "POST" { - var newCfg config.Config - if err := json.NewDecoder(r.Body).Decode(&newCfg); err != nil { - http.Error(w, fmt.Sprintf("invalid JSON: %v", err), http.StatusBadRequest) - return - } - if err := newCfg.Validate(); err != nil { - http.Error(w, fmt.Sprintf("validation: %v", err), http.StatusBadRequest) - return - } - if err := newCfg.Save(s.configPath); err != nil { - http.Error(w, fmt.Sprintf("save: %v", err), http.StatusInternalServerError) - return - } - *s.cfg = newCfg - json.NewEncoder(w).Encode(map[string]string{"status": "saved"}) - return - } - http.Error(w, "method not allowed", http.StatusMethodNotAllowed) -} - -func (s *APIServer) handleReload(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS") - w.Header().Set("Access-Control-Allow-Headers", "Content-Type") - if r.Method == "OPTIONS" { - return - } - if r.Method != "POST" { - http.Error(w, "method not allowed", http.StatusMethodNotAllowed) - return - } - newCfg, err := config.Load(s.configPath) - if err != nil { - http.Error(w, fmt.Sprintf("load config: %v", err), http.StatusBadRequest) - return - } - if newCfg.SlipstreamBinary == "" { - newCfg.SlipstreamBinary = s.cfg.SlipstreamBinary - } - - // Stop health checker before reload to prevent race conditions - if s.checker != nil { - s.checker.Stop() - } - - if err := s.manager.Reload(newCfg); err != nil { - http.Error(w, fmt.Sprintf("reload: %v", err), http.StatusInternalServerError) - return - } - *s.cfg = *newCfg - - // Restart health checker with fresh state - if s.checker != nil { - s.checker = health.NewChecker(s.manager, &s.cfg.HealthCheck) - s.checker.Start() - } - - // Reload user manager if users changed - if s.userMgr != nil && len(s.cfg.Socks.Users) > 0 { - s.userMgr = users.NewManager(s.cfg.Socks.Users) - } - - log.Printf("[gui] config reloaded and instances restarted") - json.NewEncoder(w).Encode(map[string]string{"status": "reloaded"}) -} - -func (s *APIServer) handleRestart(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS") - w.Header().Set("Access-Control-Allow-Headers", "Content-Type") - if r.Method == "OPTIONS" { - return - } - if r.Method != "POST" { - http.Error(w, "method not allowed", http.StatusMethodNotAllowed) - return - } - - json.NewEncoder(w).Encode(map[string]string{"status": "restarting"}) - - // Schedule restart after response is sent - go func() { - time.Sleep(500 * time.Millisecond) - log.Printf("[gui] full restart requested, shutting down before re-exec...") - - // Stop health checker first - if s.checker != nil { - s.checker.Stop() - } - - // Shutdown manager — this kills all child processes and frees ports - s.manager.Shutdown() - - time.Sleep(200 * time.Millisecond) // brief pause for port release - - exe, err := os.Executable() - if err != nil { - log.Printf("[gui] restart failed: %v", err) - return - } - syscall.Exec(exe, os.Args, os.Environ()) - }() -} - -func (s *APIServer) handleInstance(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS") - w.Header().Set("Access-Control-Allow-Headers", "Content-Type") - if r.Method == "OPTIONS" { - return - } - parts := strings.Split(strings.Trim(r.URL.Path, "/"), "/") - if len(parts) < 4 || parts[3] != "restart" { - http.Error(w, "not found", http.StatusNotFound) - return - } - id, err := strconv.Atoi(parts[2]) - if err != nil { - http.Error(w, "invalid id", http.StatusBadRequest) - return - } - if r.Method != "POST" { - http.Error(w, "method not allowed", http.StatusMethodNotAllowed) - return - } - if err := s.manager.RestartInstance(id); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - json.NewEncoder(w).Encode(map[string]string{"status": "restarting"}) -} - -func (s *APIServer) handleUsers(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS") - w.Header().Set("Access-Control-Allow-Headers", "Content-Type") - if r.Method == "OPTIONS" { - return - } - if r.Method == "GET" { - if s.userMgr == nil { - json.NewEncoder(w).Encode([]any{}) - return - } - allUsers := s.userMgr.AllUsers() - result := make([]users.UserStatus, len(allUsers)) - for i, u := range allUsers { - result[i] = u.Status() - } - json.NewEncoder(w).Encode(result) - return - } - if r.Method == "POST" { - // Add new user - var uc config.UserConfig - if err := json.NewDecoder(r.Body).Decode(&uc); err != nil { - http.Error(w, "invalid JSON", http.StatusBadRequest) - return - } - if uc.Username == "" || uc.Password == "" { - http.Error(w, "username and password required", http.StatusBadRequest) - return - } - // Check duplicate - for _, u := range s.cfg.Socks.Users { - if u.Username == uc.Username { - http.Error(w, "user already exists", http.StatusConflict) - return - } - } - s.cfg.Socks.Users = append(s.cfg.Socks.Users, uc) - if err := s.cfg.Save(s.configPath); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - s.userMgr = users.NewManager(s.cfg.Socks.Users) - json.NewEncoder(w).Encode(map[string]string{"status": "added"}) - return - } - http.Error(w, "method not allowed", http.StatusMethodNotAllowed) -} - -func (s *APIServer) handleUserAction(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Methods", "POST, DELETE, OPTIONS") - w.Header().Set("Access-Control-Allow-Headers", "Content-Type") - if r.Method == "OPTIONS" { - return - } - - parts := strings.Split(strings.Trim(r.URL.Path, "/"), "/") - if len(parts) < 4 { - http.Error(w, "not found", http.StatusNotFound) - return - } - username := parts[2] - action := parts[3] - - switch action { - case "reset": - if s.userMgr == nil { - http.Error(w, "no users", http.StatusNotFound) - return - } - user := s.userMgr.GetUser(username) - if user == nil { - http.Error(w, "not found", http.StatusNotFound) - return - } - user.ResetUsedBytes() - json.NewEncoder(w).Encode(map[string]string{"status": "reset"}) - - case "edit": - var uc config.UserConfig - if err := json.NewDecoder(r.Body).Decode(&uc); err != nil { - http.Error(w, "invalid JSON", http.StatusBadRequest) - return - } - found := false - for i, u := range s.cfg.Socks.Users { - if u.Username == username { - uc.Username = username // can't change username - s.cfg.Socks.Users[i] = uc - found = true - break - } - } - if !found { - http.Error(w, "not found", http.StatusNotFound) - return - } - if err := s.cfg.Save(s.configPath); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - s.userMgr = users.NewManager(s.cfg.Socks.Users) - json.NewEncoder(w).Encode(map[string]string{"status": "updated"}) - - case "delete": - newUsers := make([]config.UserConfig, 0) - for _, u := range s.cfg.Socks.Users { - if u.Username != username { - newUsers = append(newUsers, u) - } - } - s.cfg.Socks.Users = newUsers - if err := s.cfg.Save(s.configPath); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - if len(s.cfg.Socks.Users) > 0 { - s.userMgr = users.NewManager(s.cfg.Socks.Users) - } else { - s.userMgr = nil - } - json.NewEncoder(w).Encode(map[string]string{"status": "deleted"}) - - default: - http.Error(w, "not found", http.StatusNotFound) - } -} - -func (s *APIServer) handleDashboard(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/html; charset=utf-8") - w.Write([]byte(dashboardHTML)) -} diff --git a/internal/health/checker.go b/internal/health/checker.go index 4afbbb2..5760320 100644 --- a/internal/health/checker.go +++ b/internal/health/checker.go @@ -2,10 +2,6 @@ package health import ( "context" - "crypto/rand" - "encoding/binary" - "fmt" - "io" "log" "net" "sync" @@ -14,7 +10,6 @@ import ( "github.com/ParsaKSH/SlipStream-Plus/internal/config" "github.com/ParsaKSH/SlipStream-Plus/internal/engine" - "github.com/ParsaKSH/SlipStream-Plus/internal/tunnel" ) // An instance is HEALTHY only after a successful tunnel probe (SOCKS5/SSH). @@ -224,208 +219,3 @@ func (c *Checker) checkOne(inst *engine.Instance) { inst.SetState(engine.StateHealthy) } } - -func (c *Checker) probeSOCKS(inst *engine.Instance) (time.Duration, error) { - start := time.Now() - conn, err := net.DialTimeout("tcp", inst.Addr(), c.timeout) - if err != nil { - return 0, fmt.Errorf("tcp connect: %w", err) - } - defer conn.Close() - conn.SetDeadline(time.Now().Add(c.timeout)) - - _, err = conn.Write([]byte{0x05, 0x01, 0x00}) - if err != nil { - return 0, fmt.Errorf("socks5 write: %w", err) - } - resp := make([]byte, 2) - _, err = io.ReadFull(conn, resp) - if err != nil { - return 0, fmt.Errorf("socks5 read: %w", err) - } - if resp[0] != 0x05 { - return 0, fmt.Errorf("socks5 bad version: %d", resp[0]) - } - return time.Since(start), nil -} - -func (c *Checker) probeSSH(inst *engine.Instance) (time.Duration, error) { - start := time.Now() - conn, err := net.DialTimeout("tcp", inst.Addr(), c.timeout) - if err != nil { - return 0, fmt.Errorf("tcp connect: %w", err) - } - defer conn.Close() - conn.SetDeadline(time.Now().Add(c.timeout)) - - banner := make([]byte, 64) - n, err := conn.Read(banner) - if err != nil { - return 0, fmt.Errorf("ssh banner read: %w", err) - } - if n < 4 || string(banner[:4]) != "SSH-" { - return 0, fmt.Errorf("ssh bad banner: %q", string(banner[:n])) - } - return time.Since(start), nil -} - -// probeEndToEnd does a full SOCKS5 CONNECT through the tunnel to the health -// check target (port 80), sends an HTTP HEAD request, and verifies a response. -// This tests the entire path: instance → DNS tunnel → centralserver → SOCKS upstream → internet. -func (c *Checker) probeEndToEnd(inst *engine.Instance) (time.Duration, error) { - start := time.Now() - - conn, err := net.DialTimeout("tcp", inst.Addr(), c.timeout) - if err != nil { - return 0, fmt.Errorf("e2e connect: %w", err) - } - defer conn.Close() - conn.SetDeadline(time.Now().Add(c.timeout)) - - // SOCKS5 greeting (no auth) - if _, err := conn.Write([]byte{0x05, 0x01, 0x00}); err != nil { - return 0, fmt.Errorf("e2e socks greeting: %w", err) - } - greeting := make([]byte, 2) - if _, err := io.ReadFull(conn, greeting); err != nil { - return 0, fmt.Errorf("e2e socks greeting resp: %w", err) - } - if greeting[0] != 0x05 { - return 0, fmt.Errorf("e2e bad socks version: %d", greeting[0]) - } - - // SOCKS5 CONNECT to target:80 - domain := c.target - connectReq := make([]byte, 0, 4+1+len(domain)+2) - connectReq = append(connectReq, 0x05, 0x01, 0x00, 0x03) // VER CMD RSV ATYP(domain) - connectReq = append(connectReq, byte(len(domain))) // domain length - connectReq = append(connectReq, []byte(domain)...) // domain - portBuf := make([]byte, 2) - binary.BigEndian.PutUint16(portBuf, 80) - connectReq = append(connectReq, portBuf...) - - if _, err := conn.Write(connectReq); err != nil { - return 0, fmt.Errorf("e2e socks connect: %w", err) - } - - // Read CONNECT response (VER REP RSV ATYP) - connectResp := make([]byte, 4) - if _, err := io.ReadFull(conn, connectResp); err != nil { - return 0, fmt.Errorf("e2e socks connect resp: %w", err) - } - if connectResp[1] != 0x00 { - return 0, fmt.Errorf("e2e socks connect rejected: 0x%02x", connectResp[1]) - } - - // Drain bind address - switch connectResp[3] { - case 0x01: - io.ReadFull(conn, make([]byte, 4+2)) - case 0x03: - lb := make([]byte, 1) - io.ReadFull(conn, lb) - io.ReadFull(conn, make([]byte, int(lb[0])+2)) - case 0x04: - io.ReadFull(conn, make([]byte, 16+2)) - default: - io.ReadFull(conn, make([]byte, 4+2)) - } - - // Send HTTP HEAD request - httpReq := fmt.Sprintf("HEAD / HTTP/1.1\r\nHost: %s\r\nConnection: close\r\n\r\n", domain) - if _, err := conn.Write([]byte(httpReq)); err != nil { - return 0, fmt.Errorf("e2e http write: %w", err) - } - - // Read HTTP response (at least status line) - respBuf := make([]byte, 128) - n, err := conn.Read(respBuf) - if err != nil && n == 0 { - return 0, fmt.Errorf("e2e http read: %w", err) - } - if n < 12 || string(respBuf[:4]) != "HTTP" { - return 0, fmt.Errorf("e2e bad http response: %q", string(respBuf[:n])) - } - - return time.Since(start), nil -} - -// probeFramingProtocol tests if the instance's upstream speaks our framing protocol -// (i.e., is connected to centralserver). It sends a SYN frame and expects a valid -// frame response. Instances whose upstream is a plain SOCKS5 proxy will fail. -func (c *Checker) probeFramingProtocol(inst *engine.Instance) (time.Duration, error) { - start := time.Now() - - conn, err := net.DialTimeout("tcp", inst.Addr(), c.timeout) - if err != nil { - return 0, fmt.Errorf("frame probe connect: %w", err) - } - defer conn.Close() - - if tc, ok := conn.(*net.TCPConn); ok { - tc.SetNoDelay(true) - } - conn.SetDeadline(time.Now().Add(c.timeout)) - - // Build SYN targeting health_check.target:80 - domain := c.target - synPayload := make([]byte, 0, 1+1+len(domain)+2) - synPayload = append(synPayload, 0x03) // ATYP = domain - synPayload = append(synPayload, byte(len(domain))) // domain length - synPayload = append(synPayload, []byte(domain)...) // domain - synPayload = append(synPayload, 0x00, 0x50) // port 80 - - // Use a unique probe ConnID combining high-range prefix, instance ID, - // monotonic counter, and random bits to avoid collisions with real connections - // and previous probes that haven't been cleaned up yet. - seq := c.probeSeq.Add(1) - var rndBuf [2]byte - rand.Read(rndBuf[:]) - rnd := uint32(binary.BigEndian.Uint16(rndBuf[:])) - probeConnID := uint32(0xFE000000) | (uint32(inst.ID())&0xFF)<<16 | (seq&0xFF)<<8 | (rnd & 0xFF) - - synFrame := &tunnel.Frame{ - ConnID: probeConnID, - SeqNum: 0, - Flags: tunnel.FlagSYN, - Payload: synPayload, - } - - if err := tunnel.WriteFrame(conn, synFrame); err != nil { - return 0, fmt.Errorf("frame probe write SYN: %w", err) - } - - // Send DATA with HTTP HEAD so the target actually responds - httpReq := fmt.Sprintf("HEAD / HTTP/1.1\r\nHost: %s\r\nConnection: close\r\n\r\n", domain) - dataFrame := &tunnel.Frame{ - ConnID: probeConnID, - SeqNum: 1, - Flags: tunnel.FlagData, - Payload: []byte(httpReq), - } - if err := tunnel.WriteFrame(conn, dataFrame); err != nil { - return 0, fmt.Errorf("frame probe write DATA: %w", err) - } - - // Read response frame from centralserver. - // centralserver → connects target, forwards HTTP, target responds → reverse frame. - // plain SOCKS5 → can't parse frame → timeout/error. - respFrame, err := tunnel.ReadFrame(conn) - if err != nil { - return 0, fmt.Errorf("frame probe read: %w", err) - } - - if respFrame.ConnID != probeConnID { - return 0, fmt.Errorf("frame probe wrong connID: got %d, want %d", - respFrame.ConnID, probeConnID) - } - - // Valid frame = centralserver is there. Send FIN to clean up. - tunnel.WriteFrame(conn, &tunnel.Frame{ - ConnID: probeConnID, - SeqNum: 2, - Flags: tunnel.FlagFIN, - }) - - return time.Since(start), nil -} diff --git a/internal/health/probes.go b/internal/health/probes.go new file mode 100644 index 0000000..e614a2e --- /dev/null +++ b/internal/health/probes.go @@ -0,0 +1,218 @@ +package health + +import ( + "crypto/rand" + "encoding/binary" + "fmt" + "io" + "net" + "time" + + "github.com/ParsaKSH/SlipStream-Plus/internal/engine" + "github.com/ParsaKSH/SlipStream-Plus/internal/tunnel" +) + +func (c *Checker) probeSOCKS(inst *engine.Instance) (time.Duration, error) { + start := time.Now() + conn, err := net.DialTimeout("tcp", inst.Addr(), c.timeout) + if err != nil { + return 0, fmt.Errorf("tcp connect: %w", err) + } + defer conn.Close() + conn.SetDeadline(time.Now().Add(c.timeout)) + + _, err = conn.Write([]byte{0x05, 0x01, 0x00}) + if err != nil { + return 0, fmt.Errorf("socks5 write: %w", err) + } + resp := make([]byte, 2) + _, err = io.ReadFull(conn, resp) + if err != nil { + return 0, fmt.Errorf("socks5 read: %w", err) + } + if resp[0] != 0x05 { + return 0, fmt.Errorf("socks5 bad version: %d", resp[0]) + } + return time.Since(start), nil +} + +func (c *Checker) probeSSH(inst *engine.Instance) (time.Duration, error) { + start := time.Now() + conn, err := net.DialTimeout("tcp", inst.Addr(), c.timeout) + if err != nil { + return 0, fmt.Errorf("tcp connect: %w", err) + } + defer conn.Close() + conn.SetDeadline(time.Now().Add(c.timeout)) + + banner := make([]byte, 64) + n, err := conn.Read(banner) + if err != nil { + return 0, fmt.Errorf("ssh banner read: %w", err) + } + if n < 4 || string(banner[:4]) != "SSH-" { + return 0, fmt.Errorf("ssh bad banner: %q", string(banner[:n])) + } + return time.Since(start), nil +} + +// probeEndToEnd does a full SOCKS5 CONNECT through the tunnel to the health +// check target (port 80), sends an HTTP HEAD request, and verifies a response. +// This tests the entire path: instance → DNS tunnel → centralserver → SOCKS upstream → internet. +func (c *Checker) probeEndToEnd(inst *engine.Instance) (time.Duration, error) { + start := time.Now() + + conn, err := net.DialTimeout("tcp", inst.Addr(), c.timeout) + if err != nil { + return 0, fmt.Errorf("e2e connect: %w", err) + } + defer conn.Close() + conn.SetDeadline(time.Now().Add(c.timeout)) + + // SOCKS5 greeting (no auth) + if _, err := conn.Write([]byte{0x05, 0x01, 0x00}); err != nil { + return 0, fmt.Errorf("e2e socks greeting: %w", err) + } + greeting := make([]byte, 2) + if _, err := io.ReadFull(conn, greeting); err != nil { + return 0, fmt.Errorf("e2e socks greeting resp: %w", err) + } + if greeting[0] != 0x05 { + return 0, fmt.Errorf("e2e bad socks version: %d", greeting[0]) + } + + // SOCKS5 CONNECT to target:80 + domain := c.target + connectReq := make([]byte, 0, 4+1+len(domain)+2) + connectReq = append(connectReq, 0x05, 0x01, 0x00, 0x03) // VER CMD RSV ATYP(domain) + connectReq = append(connectReq, byte(len(domain))) // domain length + connectReq = append(connectReq, []byte(domain)...) // domain + portBuf := make([]byte, 2) + binary.BigEndian.PutUint16(portBuf, 80) + connectReq = append(connectReq, portBuf...) + + if _, err := conn.Write(connectReq); err != nil { + return 0, fmt.Errorf("e2e socks connect: %w", err) + } + + // Read CONNECT response (VER REP RSV ATYP) + connectResp := make([]byte, 4) + if _, err := io.ReadFull(conn, connectResp); err != nil { + return 0, fmt.Errorf("e2e socks connect resp: %w", err) + } + if connectResp[1] != 0x00 { + return 0, fmt.Errorf("e2e socks connect rejected: 0x%02x", connectResp[1]) + } + + // Drain bind address + switch connectResp[3] { + case 0x01: + io.ReadFull(conn, make([]byte, 4+2)) + case 0x03: + lb := make([]byte, 1) + io.ReadFull(conn, lb) + io.ReadFull(conn, make([]byte, int(lb[0])+2)) + case 0x04: + io.ReadFull(conn, make([]byte, 16+2)) + default: + io.ReadFull(conn, make([]byte, 4+2)) + } + + // Send HTTP HEAD request + httpReq := fmt.Sprintf("HEAD / HTTP/1.1\r\nHost: %s\r\nConnection: close\r\n\r\n", domain) + if _, err := conn.Write([]byte(httpReq)); err != nil { + return 0, fmt.Errorf("e2e http write: %w", err) + } + + // Read HTTP response (at least status line) + respBuf := make([]byte, 128) + n, err := conn.Read(respBuf) + if err != nil && n == 0 { + return 0, fmt.Errorf("e2e http read: %w", err) + } + if n < 12 || string(respBuf[:4]) != "HTTP" { + return 0, fmt.Errorf("e2e bad http response: %q", string(respBuf[:n])) + } + + return time.Since(start), nil +} + +// probeFramingProtocol tests if the instance's upstream speaks our framing protocol +// (i.e., is connected to centralserver). It sends a SYN frame and expects a valid +// frame response. Instances whose upstream is a plain SOCKS5 proxy will fail. +func (c *Checker) probeFramingProtocol(inst *engine.Instance) (time.Duration, error) { + start := time.Now() + + conn, err := net.DialTimeout("tcp", inst.Addr(), c.timeout) + if err != nil { + return 0, fmt.Errorf("frame probe connect: %w", err) + } + defer conn.Close() + + if tc, ok := conn.(*net.TCPConn); ok { + tc.SetNoDelay(true) + } + conn.SetDeadline(time.Now().Add(c.timeout)) + + // Build SYN targeting health_check.target:80 + domain := c.target + synPayload := make([]byte, 0, 1+1+len(domain)+2) + synPayload = append(synPayload, 0x03) // ATYP = domain + synPayload = append(synPayload, byte(len(domain))) // domain length + synPayload = append(synPayload, []byte(domain)...) // domain + synPayload = append(synPayload, 0x00, 0x50) // port 80 + + // Use a unique probe ConnID combining high-range prefix, instance ID, + // monotonic counter, and random bits to avoid collisions with real connections + // and previous probes that haven't been cleaned up yet. + seq := c.probeSeq.Add(1) + var rndBuf [2]byte + rand.Read(rndBuf[:]) + rnd := uint32(binary.BigEndian.Uint16(rndBuf[:])) + probeConnID := uint32(0xFE000000) | (uint32(inst.ID())&0xFF)<<16 | (seq&0xFF)<<8 | (rnd & 0xFF) + + synFrame := &tunnel.Frame{ + ConnID: probeConnID, + SeqNum: 0, + Flags: tunnel.FlagSYN, + Payload: synPayload, + } + + if err := tunnel.WriteFrame(conn, synFrame); err != nil { + return 0, fmt.Errorf("frame probe write SYN: %w", err) + } + + // Send DATA with HTTP HEAD so the target actually responds + httpReq := fmt.Sprintf("HEAD / HTTP/1.1\r\nHost: %s\r\nConnection: close\r\n\r\n", domain) + dataFrame := &tunnel.Frame{ + ConnID: probeConnID, + SeqNum: 1, + Flags: tunnel.FlagData, + Payload: []byte(httpReq), + } + if err := tunnel.WriteFrame(conn, dataFrame); err != nil { + return 0, fmt.Errorf("frame probe write DATA: %w", err) + } + + // Read response frame from centralserver. + // centralserver → connects target, forwards HTTP, target responds → reverse frame. + // plain SOCKS5 → can't parse frame → timeout/error. + respFrame, err := tunnel.ReadFrame(conn) + if err != nil { + return 0, fmt.Errorf("frame probe read: %w", err) + } + + if respFrame.ConnID != probeConnID { + return 0, fmt.Errorf("frame probe wrong connID: got %d, want %d", + respFrame.ConnID, probeConnID) + } + + // Valid frame = centralserver is there. Send FIN to clean up. + tunnel.WriteFrame(conn, &tunnel.Frame{ + ConnID: probeConnID, + SeqNum: 2, + Flags: tunnel.FlagFIN, + }) + + return time.Since(start), nil +} diff --git a/internal/proxy/relay.go b/internal/proxy/relay.go new file mode 100644 index 0000000..f5d1d04 --- /dev/null +++ b/internal/proxy/relay.go @@ -0,0 +1,274 @@ +package proxy + +import ( + "context" + "encoding/binary" + "io" + "log" + "net" + "sync" + "time" + + "github.com/ParsaKSH/SlipStream-Plus/internal/engine" + "github.com/ParsaKSH/SlipStream-Plus/internal/tunnel" + "github.com/ParsaKSH/SlipStream-Plus/internal/users" +) + +// handlePacketSplit handles a connection using packet-level load balancing. +func (s *Server) handlePacketSplit(clientConn net.Conn, connID uint64, atyp byte, addrBytes, portBytes []byte, user *users.User) { + // Fast-fail if no tunnels are connected yet (e.g., right after restart) + if !s.tunnelPool.HasTunnels() { + clientConn.Write([]byte{0x05, 0x01, 0x00, 0x01, 0, 0, 0, 0, 0, 0}) + return + } + + healthy := s.manager.HealthyInstances() + socksHealthy := make([]*engine.Instance, 0, len(healthy)) + for _, inst := range healthy { + if inst.Config.Mode != "ssh" { + socksHealthy = append(socksHealthy, inst) + } + } + if len(socksHealthy) == 0 { + clientConn.Write([]byte{0x05, 0x01, 0x00, 0x01, 0, 0, 0, 0, 0, 0}) + return + } + + // Track connection count on ONE instance only (first healthy one). + // In packet_split, the single user connection is multiplexed across all + // instances, so incrementing all of them inflates the count by N×. + trackInst := socksHealthy[0] + trackInst.IncrConns() + defer trackInst.DecrConns() + + // Create a packet splitter for this connection + tunnelConnID := s.connIDGen.Next() + splitter := tunnel.NewPacketSplitter(tunnelConnID, s.tunnelPool, socksHealthy, s.chunkSize) + defer splitter.Close() + + // Send SYN with target address info + if err := splitter.SendSYN(atyp, addrBytes, portBytes); err != nil { + log.Printf("[proxy] conn#%d: SYN failed: %v", connID, err) + clientConn.Write([]byte{0x05, 0x01, 0x00, 0x01, 0, 0, 0, 0, 0, 0}) + return + } + + // Tell client: connection successful + clientConn.Write([]byte{0x05, 0x00, 0x00, 0x01, 0, 0, 0, 0, 0, 0}) + log.Printf("[proxy] conn#%d: packet-split SYN sent, connID=%d", connID, tunnelConnID) + + port := binary.BigEndian.Uint16(portBytes) + log.Printf("[proxy] conn#%d: packet-split mode, %d instances, port %d", + connID, len(socksHealthy), port) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // When context is cancelled (either direction finished), close the client + // connection so that blocking Read/Write calls unblock immediately. + go func() { + <-ctx.Done() + clientConn.Close() + }() + + var txN, rxN int64 + var wg sync.WaitGroup + + wg.Add(1) + go func() { + defer wg.Done() + if user != nil && user.NeedsRateLimit() { + wrapped := user.WrapReader(clientConn) + txN = splitter.RelayClientToUpstream(ctx, wrapped) + } else { + txN = splitter.RelayClientToUpstream(ctx, clientConn) + } + cancel() + }() + + wg.Add(1) + go func() { + defer wg.Done() + if user != nil && user.NeedsRateLimit() { + wrapped := user.WrapWriter(clientConn) + rxN = splitter.RelayUpstreamToClient(ctx, wrapped) + } else { + rxN = splitter.RelayUpstreamToClient(ctx, clientConn) + } + cancel() + }() + + wg.Wait() + log.Printf("[proxy] conn#%d: packet-split done, tx=%d rx=%d", connID, txN, rxN) + + // Track TX/RX on instances (distributed proportionally) + nInstances := int64(len(socksHealthy)) + if nInstances > 0 { + txPer := txN / nInstances + rxPer := rxN / nInstances + for _, inst := range socksHealthy { + inst.AddTx(txPer) + inst.AddRx(rxPer) + } + } + + // Track bytes for user data quota + if user != nil { + user.AddUsedBytes(txN + rxN) + } +} + +// handleConnectionLevel handles a connection using traditional per-connection load balancing. +func (s *Server) handleConnectionLevel(clientConn net.Conn, connID uint64, atyp byte, addrBytes, portBytes []byte, user *users.User) { + healthy := s.manager.HealthyInstances() + socksHealthy := make([]*engine.Instance, 0, len(healthy)) + for _, inst := range healthy { + if inst.Config.Mode != "ssh" { + socksHealthy = append(socksHealthy, inst) + } + } + if len(socksHealthy) == 0 { + clientConn.Write([]byte{0x05, 0x01, 0x00, 0x01, 0, 0, 0, 0, 0, 0}) + return + } + + inst := s.balancer.Pick(socksHealthy) + if inst == nil { + clientConn.Write([]byte{0x05, 0x01, 0x00, 0x01, 0, 0, 0, 0, 0, 0}) + return + } + + upstreamConn, err := inst.Dial() + if err != nil { + log.Printf("[proxy] conn#%d: dial instance %d failed: %v", connID, inst.ID(), err) + clientConn.Write([]byte{0x05, 0x05, 0x00, 0x01, 0, 0, 0, 0, 0, 0}) + return + } + defer upstreamConn.Close() + + if tc, ok := upstreamConn.(*net.TCPConn); ok { + tc.SetKeepAlive(true) + tc.SetKeepAlivePeriod(30 * time.Second) + tc.SetNoDelay(true) + tc.SetReadBuffer(s.bufferSize) + tc.SetWriteBuffer(s.bufferSize) + } + + // ──── Pipelined SOCKS5 negotiation with upstream ──── + pipelined := make([]byte, 0, 3+4+len(addrBytes)+2) + pipelined = append(pipelined, 0x05, 0x01, 0x00) // greeting + pipelined = append(pipelined, 0x05, 0x01, 0x00, atyp) // CONNECT header + pipelined = append(pipelined, addrBytes...) // target addr + pipelined = append(pipelined, portBytes...) // target port + if _, err := upstreamConn.Write(pipelined); err != nil { + clientConn.Write([]byte{0x05, 0x01, 0x00, 0x01, 0, 0, 0, 0, 0, 0}) + return + } + + // Read greeting response (2 bytes) + CONNECT response header (4 bytes) = 6 bytes + resp := make([]byte, 6) + if _, err := io.ReadFull(upstreamConn, resp); err != nil { + clientConn.Write([]byte{0x05, 0x01, 0x00, 0x01, 0, 0, 0, 0, 0, 0}) + return + } + + // Drain the CONNECT reply's bind address + port + repAtyp := resp[5] + switch repAtyp { + case 0x01: + io.ReadFull(upstreamConn, make([]byte, 4+2)) + case 0x03: + lenBuf := make([]byte, 1) + io.ReadFull(upstreamConn, lenBuf) + io.ReadFull(upstreamConn, make([]byte, int(lenBuf[0])+2)) + case 0x04: + io.ReadFull(upstreamConn, make([]byte, 16+2)) + default: + io.ReadFull(upstreamConn, make([]byte, 4+2)) + } + + if resp[3] != 0x00 { // CONNECT reply status + clientConn.Write([]byte{0x05, resp[3], 0x00, 0x01, 0, 0, 0, 0, 0, 0}) + return + } + + // ──── Success! Tell client and start relay ──── + clientConn.Write([]byte{0x05, 0x00, 0x00, 0x01, 0, 0, 0, 0, 0, 0}) + + inst.IncrConns() + defer inst.DecrConns() + + port := binary.BigEndian.Uint16(portBytes) + log.Printf("[proxy] conn#%d: connected via instance %d, port %d", connID, inst.ID(), port) + + s.relay(inst.ConnCtx, clientConn, upstreamConn, inst, user, connID) +} + +func (s *Server) relay(ctx context.Context, clientConn, upstreamConn net.Conn, inst *engine.Instance, user *users.User, connID uint64) { + // Wrap both connections with idle timeout so stuck connections are cleaned up. + idleClient := newIdleConn(clientConn, idleTimeout) + idleUpstream := newIdleConn(upstreamConn, idleTimeout) + + // If the instance is stopped (context cancelled), close both connections + // so the relay goroutines unblock and exit. + go func() { + select { + case <-ctx.Done(): + clientConn.Close() + upstreamConn.Close() + } + }() + + // Determine if we need rate-limited (wrapped) relay or can use zero-copy. + needsWrap := user != nil && user.NeedsRateLimit() + + var clientToUpstream, upstreamToClient int64 + done := make(chan struct{}, 2) + + go func() { + if needsWrap { + src := user.WrapReader(idleClient) + bufPtr := inst.BufPool.Get().(*[]byte) + n, _ := io.CopyBuffer(idleUpstream, src, *bufPtr) + inst.BufPool.Put(bufPtr) + clientToUpstream = n + } else { + // Zero-copy path with idle timeout + n, _ := io.Copy(idleUpstream, idleClient) + clientToUpstream = n + } + if tc, ok := upstreamConn.(*net.TCPConn); ok { + tc.CloseWrite() + } + done <- struct{}{} + }() + + go func() { + if needsWrap { + dst := user.WrapWriter(idleClient) + bufPtr := inst.BufPool.Get().(*[]byte) + n, _ := io.CopyBuffer(dst, idleUpstream, *bufPtr) + inst.BufPool.Put(bufPtr) + upstreamToClient = n + } else { + // Zero-copy path with idle timeout + n, _ := io.Copy(idleClient, idleUpstream) + upstreamToClient = n + } + if tc, ok := clientConn.(*net.TCPConn); ok { + tc.CloseWrite() + } + done <- struct{}{} + }() + + <-done + <-done + + inst.AddTx(clientToUpstream) + inst.AddRx(upstreamToClient) + + // Track bytes for user data quota (non-rate-limited path) + if user != nil && !needsWrap { + user.AddUsedBytes(clientToUpstream + upstreamToClient) + } +} diff --git a/internal/proxy/socks5.go b/internal/proxy/socks5.go index 6c01164..4b2b313 100644 --- a/internal/proxy/socks5.go +++ b/internal/proxy/socks5.go @@ -2,12 +2,10 @@ package proxy import ( "context" - "encoding/binary" "fmt" "io" "log" "net" - "sync" "sync/atomic" "time" @@ -235,265 +233,6 @@ func (s *Server) handleConnection(clientConn net.Conn, connID uint64) { } } -// handlePacketSplit handles a connection using packet-level load balancing. -func (s *Server) handlePacketSplit(clientConn net.Conn, connID uint64, atyp byte, addrBytes, portBytes []byte, user *users.User) { - // Fast-fail if no tunnels are connected yet (e.g., right after restart) - if !s.tunnelPool.HasTunnels() { - clientConn.Write([]byte{0x05, 0x01, 0x00, 0x01, 0, 0, 0, 0, 0, 0}) - return - } - - healthy := s.manager.HealthyInstances() - socksHealthy := make([]*engine.Instance, 0, len(healthy)) - for _, inst := range healthy { - if inst.Config.Mode != "ssh" { - socksHealthy = append(socksHealthy, inst) - } - } - if len(socksHealthy) == 0 { - clientConn.Write([]byte{0x05, 0x01, 0x00, 0x01, 0, 0, 0, 0, 0, 0}) - return - } - - // Track connection count on ONE instance only (first healthy one). - // In packet_split, the single user connection is multiplexed across all - // instances, so incrementing all of them inflates the count by N×. - trackInst := socksHealthy[0] - trackInst.IncrConns() - defer trackInst.DecrConns() - - // Create a packet splitter for this connection - tunnelConnID := s.connIDGen.Next() - splitter := tunnel.NewPacketSplitter(tunnelConnID, s.tunnelPool, socksHealthy, s.chunkSize) - defer splitter.Close() - - // Send SYN with target address info - if err := splitter.SendSYN(atyp, addrBytes, portBytes); err != nil { - log.Printf("[proxy] conn#%d: SYN failed: %v", connID, err) - clientConn.Write([]byte{0x05, 0x01, 0x00, 0x01, 0, 0, 0, 0, 0, 0}) - return - } - - // Tell client: connection successful - clientConn.Write([]byte{0x05, 0x00, 0x00, 0x01, 0, 0, 0, 0, 0, 0}) - log.Printf("[proxy] conn#%d: packet-split SYN sent, connID=%d", connID, tunnelConnID) - - port := binary.BigEndian.Uint16(portBytes) - log.Printf("[proxy] conn#%d: packet-split mode, %d instances, port %d", - connID, len(socksHealthy), port) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - // When context is cancelled (either direction finished), close the client - // connection so that blocking Read/Write calls unblock immediately. - go func() { - <-ctx.Done() - clientConn.Close() - }() - - var txN, rxN int64 - var wg sync.WaitGroup - - wg.Add(1) - go func() { - defer wg.Done() - if user != nil && user.NeedsRateLimit() { - wrapped := user.WrapReader(clientConn) - txN = splitter.RelayClientToUpstream(ctx, wrapped) - } else { - txN = splitter.RelayClientToUpstream(ctx, clientConn) - } - cancel() - }() - - wg.Add(1) - go func() { - defer wg.Done() - if user != nil && user.NeedsRateLimit() { - wrapped := user.WrapWriter(clientConn) - rxN = splitter.RelayUpstreamToClient(ctx, wrapped) - } else { - rxN = splitter.RelayUpstreamToClient(ctx, clientConn) - } - cancel() - }() - - wg.Wait() - log.Printf("[proxy] conn#%d: packet-split done, tx=%d rx=%d", connID, txN, rxN) - - // Track TX/RX on instances (distributed proportionally) - nInstances := int64(len(socksHealthy)) - if nInstances > 0 { - txPer := txN / nInstances - rxPer := rxN / nInstances - for _, inst := range socksHealthy { - inst.AddTx(txPer) - inst.AddRx(rxPer) - } - } - - // Track bytes for user data quota - if user != nil { - user.AddUsedBytes(txN + rxN) - } -} - -// handleConnectionLevel handles a connection using traditional per-connection load balancing. -func (s *Server) handleConnectionLevel(clientConn net.Conn, connID uint64, atyp byte, addrBytes, portBytes []byte, user *users.User) { - healthy := s.manager.HealthyInstances() - socksHealthy := make([]*engine.Instance, 0, len(healthy)) - for _, inst := range healthy { - if inst.Config.Mode != "ssh" { - socksHealthy = append(socksHealthy, inst) - } - } - if len(socksHealthy) == 0 { - clientConn.Write([]byte{0x05, 0x01, 0x00, 0x01, 0, 0, 0, 0, 0, 0}) - return - } - - inst := s.balancer.Pick(socksHealthy) - if inst == nil { - clientConn.Write([]byte{0x05, 0x01, 0x00, 0x01, 0, 0, 0, 0, 0, 0}) - return - } - - upstreamConn, err := inst.Dial() - if err != nil { - log.Printf("[proxy] conn#%d: dial instance %d failed: %v", connID, inst.ID(), err) - clientConn.Write([]byte{0x05, 0x05, 0x00, 0x01, 0, 0, 0, 0, 0, 0}) - return - } - defer upstreamConn.Close() - - if tc, ok := upstreamConn.(*net.TCPConn); ok { - tc.SetKeepAlive(true) - tc.SetKeepAlivePeriod(30 * time.Second) - tc.SetNoDelay(true) - tc.SetReadBuffer(s.bufferSize) - tc.SetWriteBuffer(s.bufferSize) - } - - // ──── Pipelined SOCKS5 negotiation with upstream ──── - pipelined := make([]byte, 0, 3+4+len(addrBytes)+2) - pipelined = append(pipelined, 0x05, 0x01, 0x00) // greeting - pipelined = append(pipelined, 0x05, 0x01, 0x00, atyp) // CONNECT header - pipelined = append(pipelined, addrBytes...) // target addr - pipelined = append(pipelined, portBytes...) // target port - if _, err := upstreamConn.Write(pipelined); err != nil { - clientConn.Write([]byte{0x05, 0x01, 0x00, 0x01, 0, 0, 0, 0, 0, 0}) - return - } - - // Read greeting response (2 bytes) + CONNECT response header (4 bytes) = 6 bytes - resp := make([]byte, 6) - if _, err := io.ReadFull(upstreamConn, resp); err != nil { - clientConn.Write([]byte{0x05, 0x01, 0x00, 0x01, 0, 0, 0, 0, 0, 0}) - return - } - - // Drain the CONNECT reply's bind address + port - repAtyp := resp[5] - switch repAtyp { - case 0x01: - io.ReadFull(upstreamConn, make([]byte, 4+2)) - case 0x03: - lenBuf := make([]byte, 1) - io.ReadFull(upstreamConn, lenBuf) - io.ReadFull(upstreamConn, make([]byte, int(lenBuf[0])+2)) - case 0x04: - io.ReadFull(upstreamConn, make([]byte, 16+2)) - default: - io.ReadFull(upstreamConn, make([]byte, 4+2)) - } - - if resp[3] != 0x00 { // CONNECT reply status - clientConn.Write([]byte{0x05, resp[3], 0x00, 0x01, 0, 0, 0, 0, 0, 0}) - return - } - - // ──── Success! Tell client and start relay ──── - clientConn.Write([]byte{0x05, 0x00, 0x00, 0x01, 0, 0, 0, 0, 0, 0}) - - inst.IncrConns() - defer inst.DecrConns() - - port := binary.BigEndian.Uint16(portBytes) - log.Printf("[proxy] conn#%d: connected via instance %d, port %d", connID, inst.ID(), port) - - s.relay(inst.ConnCtx, clientConn, upstreamConn, inst, user, connID) -} - -func (s *Server) relay(ctx context.Context, clientConn, upstreamConn net.Conn, inst *engine.Instance, user *users.User, connID uint64) { - // Wrap both connections with idle timeout so stuck connections are cleaned up. - idleClient := newIdleConn(clientConn, idleTimeout) - idleUpstream := newIdleConn(upstreamConn, idleTimeout) - - // If the instance is stopped (context cancelled), close both connections - // so the relay goroutines unblock and exit. - go func() { - select { - case <-ctx.Done(): - clientConn.Close() - upstreamConn.Close() - } - }() - - // Determine if we need rate-limited (wrapped) relay or can use zero-copy. - needsWrap := user != nil && user.NeedsRateLimit() - - var clientToUpstream, upstreamToClient int64 - done := make(chan struct{}, 2) - - go func() { - if needsWrap { - src := user.WrapReader(idleClient) - bufPtr := inst.BufPool.Get().(*[]byte) - n, _ := io.CopyBuffer(idleUpstream, src, *bufPtr) - inst.BufPool.Put(bufPtr) - clientToUpstream = n - } else { - // Zero-copy path with idle timeout - n, _ := io.Copy(idleUpstream, idleClient) - clientToUpstream = n - } - if tc, ok := upstreamConn.(*net.TCPConn); ok { - tc.CloseWrite() - } - done <- struct{}{} - }() - - go func() { - if needsWrap { - dst := user.WrapWriter(idleClient) - bufPtr := inst.BufPool.Get().(*[]byte) - n, _ := io.CopyBuffer(dst, idleUpstream, *bufPtr) - inst.BufPool.Put(bufPtr) - upstreamToClient = n - } else { - // Zero-copy path with idle timeout - n, _ := io.Copy(idleClient, idleUpstream) - upstreamToClient = n - } - if tc, ok := clientConn.(*net.TCPConn); ok { - tc.CloseWrite() - } - done <- struct{}{} - }() - - <-done - <-done - - inst.AddTx(clientToUpstream) - inst.AddRx(upstreamToClient) - - // Track bytes for user data quota (non-rate-limited path) - if user != nil && !needsWrap { - user.AddUsedBytes(clientToUpstream + upstreamToClient) - } -} - func (s *Server) ActiveConnections() int64 { return s.activeConns.Load() } diff --git a/internal/tunnel/pool.go b/internal/tunnel/pool.go index 80d0cfc..6a5af85 100644 --- a/internal/tunnel/pool.go +++ b/internal/tunnel/pool.go @@ -6,7 +6,6 @@ import ( "io" "log" "net" - "strings" "sync" "sync/atomic" "time" @@ -14,30 +13,11 @@ import ( "github.com/ParsaKSH/SlipStream-Plus/internal/engine" ) -// writeTimeout prevents writes from blocking forever on stalled connections. -const writeTimeout = 10 * time.Second - -// staleThreshold: if we've sent data but haven't received anything -// in this long, the connection is considered half-dead. -const staleThreshold = 15 * time.Second - // ConnsPerInstance is the number of persistent connections to maintain // per healthy instance. When one dies, it's replaced on the next refresh. // Multiple connections provide redundancy — if one degrades, others serve traffic. const ConnsPerInstance = 8 -// TunnelConn wraps a persistent TCP connection to a single instance. -type TunnelConn struct { - inst *engine.Instance - mu sync.Mutex - conn net.Conn - writeMu sync.Mutex - closed bool - - lastRead atomic.Int64 // unix millis of last successful read - lastWrite atomic.Int64 // unix millis of last successful write -} - // TunnelPool manages multiple persistent connections per healthy instance. // All connections serve the same handler map, providing redundancy. type TunnelPool struct { @@ -341,38 +321,3 @@ func (p *TunnelPool) readLoop(tc *TunnelConn) { } } } - -func (tc *TunnelConn) writeFrame(f *Frame) error { - tc.writeMu.Lock() - defer tc.writeMu.Unlock() - - if tc.closed { - return fmt.Errorf("tunnel closed") - } - - tc.conn.SetWriteDeadline(time.Now().Add(writeTimeout)) - err := WriteFrame(tc.conn, f) - tc.conn.SetWriteDeadline(time.Time{}) - - if err == nil { - tc.lastWrite.Store(time.Now().UnixMilli()) - } - return err -} - -func (tc *TunnelConn) close() { - tc.mu.Lock() - defer tc.mu.Unlock() - if !tc.closed { - tc.closed = true - tc.conn.Close() - } -} - -func isClosedErr(err error) bool { - if err == nil { - return false - } - return strings.Contains(err.Error(), "use of closed network connection") || - strings.Contains(err.Error(), "connection reset by peer") -} diff --git a/internal/tunnel/reorder.go b/internal/tunnel/reorder.go new file mode 100644 index 0000000..b2e324b --- /dev/null +++ b/internal/tunnel/reorder.go @@ -0,0 +1,100 @@ +package tunnel + +import "time" + +// Reorderer buffers out-of-order frames and delivers them in sequence order. +// If a frame is missing for longer than gapTimeout, it is skipped to prevent +// permanent stalls from lost frames. +type Reorderer struct { + nextSeq uint32 + buffer map[uint32][]byte + gapTimeout time.Duration + waitingSince time.Time // when we first started waiting for nextSeq +} + +func NewReorderer() *Reorderer { + return &Reorderer{ + nextSeq: 0, + buffer: make(map[uint32][]byte), + gapTimeout: 2 * time.Second, + } +} + +func NewReordererAt(startSeq uint32) *Reorderer { + return &Reorderer{ + nextSeq: startSeq, + buffer: make(map[uint32][]byte), + gapTimeout: 2 * time.Second, + } +} + +func (r *Reorderer) Insert(seq uint32, data []byte) { + if seq < r.nextSeq { + return + } + r.buffer[seq] = data +} + +func (r *Reorderer) Next() []byte { + // Fast path: next seq is available + data, ok := r.buffer[r.nextSeq] + if ok { + delete(r.buffer, r.nextSeq) + r.nextSeq++ + r.waitingSince = time.Time{} // reset wait timer + return data + } + + // Nothing buffered at all — nothing to skip to + if len(r.buffer) == 0 { + r.waitingSince = time.Time{} + return nil + } + + // There are buffered frames but nextSeq is missing. + // Start or check the gap timer. + now := time.Now() + if r.waitingSince.IsZero() { + r.waitingSince = now + return nil + } + + if now.Sub(r.waitingSince) < r.gapTimeout { + return nil // still within grace period + } + + // Gap timeout expired — skip to the lowest available seq + r.skipToLowest() + r.waitingSince = time.Time{} + + data, ok = r.buffer[r.nextSeq] + if ok { + delete(r.buffer, r.nextSeq) + r.nextSeq++ + return data + } + return nil +} + +// skipToLowest advances nextSeq to the lowest seq number in the buffer. +func (r *Reorderer) skipToLowest() { + minSeq := r.nextSeq + found := false + for seq := range r.buffer { + if !found || seq < minSeq { + minSeq = seq + found = true + } + } + if found && minSeq > r.nextSeq { + r.nextSeq = minSeq + } +} + +func (r *Reorderer) Pending() int { + return len(r.buffer) +} + +func (r *Reorderer) SkipGap() { + r.nextSeq++ +} diff --git a/internal/tunnel/splitter.go b/internal/tunnel/splitter.go index eea6924..8db91f7 100644 --- a/internal/tunnel/splitter.go +++ b/internal/tunnel/splitter.go @@ -263,99 +263,3 @@ func (ps *PacketSplitter) pickInstanceExcluding(excludeID int) *engine.Instance return nil } -// Reorderer buffers out-of-order frames and delivers them in sequence order. -// If a frame is missing for longer than gapTimeout, it is skipped to prevent -// permanent stalls from lost frames. -type Reorderer struct { - nextSeq uint32 - buffer map[uint32][]byte - gapTimeout time.Duration - waitingSince time.Time // when we first started waiting for nextSeq -} - -func NewReorderer() *Reorderer { - return &Reorderer{ - nextSeq: 0, - buffer: make(map[uint32][]byte), - gapTimeout: 2 * time.Second, - } -} - -func NewReordererAt(startSeq uint32) *Reorderer { - return &Reorderer{ - nextSeq: startSeq, - buffer: make(map[uint32][]byte), - gapTimeout: 2 * time.Second, - } -} - -func (r *Reorderer) Insert(seq uint32, data []byte) { - if seq < r.nextSeq { - return - } - r.buffer[seq] = data -} - -func (r *Reorderer) Next() []byte { - // Fast path: next seq is available - data, ok := r.buffer[r.nextSeq] - if ok { - delete(r.buffer, r.nextSeq) - r.nextSeq++ - r.waitingSince = time.Time{} // reset wait timer - return data - } - - // Nothing buffered at all — nothing to skip to - if len(r.buffer) == 0 { - r.waitingSince = time.Time{} - return nil - } - - // There are buffered frames but nextSeq is missing. - // Start or check the gap timer. - now := time.Now() - if r.waitingSince.IsZero() { - r.waitingSince = now - return nil - } - - if now.Sub(r.waitingSince) < r.gapTimeout { - return nil // still within grace period - } - - // Gap timeout expired — skip to the lowest available seq - r.skipToLowest() - r.waitingSince = time.Time{} - - data, ok = r.buffer[r.nextSeq] - if ok { - delete(r.buffer, r.nextSeq) - r.nextSeq++ - return data - } - return nil -} - -// skipToLowest advances nextSeq to the lowest seq number in the buffer. -func (r *Reorderer) skipToLowest() { - minSeq := r.nextSeq - found := false - for seq := range r.buffer { - if !found || seq < minSeq { - minSeq = seq - found = true - } - } - if found && minSeq > r.nextSeq { - r.nextSeq = minSeq - } -} - -func (r *Reorderer) Pending() int { - return len(r.buffer) -} - -func (r *Reorderer) SkipGap() { - r.nextSeq++ -} diff --git a/internal/tunnel/tunnelconn.go b/internal/tunnel/tunnelconn.go new file mode 100644 index 0000000..0f5939a --- /dev/null +++ b/internal/tunnel/tunnelconn.go @@ -0,0 +1,66 @@ +package tunnel + +import ( + "fmt" + "net" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/ParsaKSH/SlipStream-Plus/internal/engine" +) + +// writeTimeout prevents writes from blocking forever on stalled connections. +const writeTimeout = 10 * time.Second + +// staleThreshold: if we've sent data but haven't received anything +// in this long, the connection is considered half-dead. +const staleThreshold = 15 * time.Second + +// TunnelConn wraps a persistent TCP connection to a single instance. +type TunnelConn struct { + inst *engine.Instance + mu sync.Mutex + conn net.Conn + writeMu sync.Mutex + closed bool + + lastRead atomic.Int64 // unix millis of last successful read + lastWrite atomic.Int64 // unix millis of last successful write +} + +func (tc *TunnelConn) writeFrame(f *Frame) error { + tc.writeMu.Lock() + defer tc.writeMu.Unlock() + + if tc.closed { + return fmt.Errorf("tunnel closed") + } + + tc.conn.SetWriteDeadline(time.Now().Add(writeTimeout)) + err := WriteFrame(tc.conn, f) + tc.conn.SetWriteDeadline(time.Time{}) + + if err == nil { + tc.lastWrite.Store(time.Now().UnixMilli()) + } + return err +} + +func (tc *TunnelConn) close() { + tc.mu.Lock() + defer tc.mu.Unlock() + if !tc.closed { + tc.closed = true + tc.conn.Close() + } +} + +func isClosedErr(err error) bool { + if err == nil { + return false + } + return strings.Contains(err.Error(), "use of closed network connection") || + strings.Contains(err.Error(), "connection reset by peer") +} diff --git a/internal/users/io.go b/internal/users/io.go new file mode 100644 index 0000000..18faba4 --- /dev/null +++ b/internal/users/io.go @@ -0,0 +1,90 @@ +package users + +import ( + "context" + "io" + "time" + + "golang.org/x/time/rate" +) + +type trackingReader struct { + r io.Reader + user *User +} + +func (r *trackingReader) Read(p []byte) (int, error) { + n, err := r.r.Read(p) + if n > 0 { + r.user.AddUsedBytes(int64(n)) + } + return n, err +} + +type trackingWriter struct { + w io.Writer + user *User +} + +func (w *trackingWriter) Write(p []byte) (int, error) { + n, err := w.w.Write(p) + if n > 0 { + w.user.AddUsedBytes(int64(n)) + } + return n, err +} + +type rateLimitedReader struct { + r io.Reader + limiter *rate.Limiter + user *User +} + +func (r *rateLimitedReader) Read(p []byte) (int, error) { + // Limit read size to burst for proper WaitN + burst := r.limiter.Burst() + if burst > 0 && len(p) > burst { + p = p[:burst] + } + n, err := r.r.Read(p) + if n > 0 { + r.user.AddUsedBytes(int64(n)) + // Block until tokens available — this creates backpressure + if waitErr := r.limiter.WaitN(context.Background(), n); waitErr != nil { + // If WaitN fails (shouldn't with Background()), sleep as fallback + time.Sleep(time.Duration(n) * time.Second / time.Duration(r.limiter.Limit())) + } + } + return n, err +} + +type rateLimitedWriter struct { + w io.Writer + limiter *rate.Limiter + user *User +} + +func (w *rateLimitedWriter) Write(p []byte) (int, error) { + total := 0 + burst := w.limiter.Burst() + for len(p) > 0 { + chunk := len(p) + if burst > 0 && chunk > burst { + chunk = burst + } + // Wait BEFORE writing — this is the correct rate-limiting approach + if waitErr := w.limiter.WaitN(context.Background(), chunk); waitErr != nil { + time.Sleep(time.Duration(chunk) * time.Second / time.Duration(w.limiter.Limit())) + } + n, err := w.w.Write(p[:chunk]) + total += n + if n > 0 { + w.user.AddUsedBytes(int64(n)) + } + if err != nil { + return total, err + } + p = p[n:] + } + return total, nil +} diff --git a/internal/users/manager.go b/internal/users/manager.go index 1e2bb69..04d33df 100644 --- a/internal/users/manager.go +++ b/internal/users/manager.go @@ -1,13 +1,8 @@ package users import ( - "context" - "fmt" - "io" "log" - "net" "sync" - "sync/atomic" "time" "golang.org/x/time/rate" @@ -15,20 +10,6 @@ import ( "github.com/ParsaKSH/SlipStream-Plus/internal/config" ) -// User represents a runtime user with usage tracking. -type User struct { - Config config.UserConfig - limiter *rate.Limiter // bandwidth rate limiter (nil = unlimited) - - usedBytes atomic.Int64 // total bytes consumed - dataLimit int64 // max bytes (0 = unlimited) - - mu sync.Mutex - activeIPs map[string]int // IP → active connection count - cooldownIPs map[string]time.Time // IP → disconnect time (for cooldown) - ipLimit int // max concurrent IPs (0 = unlimited) -} - // Manager handles user auth, rate limiting, quotas, and connection limits. type Manager struct { mu sync.RWMutex @@ -93,249 +74,3 @@ func (m *Manager) Authenticate(username, password string) (*User, bool) { } return u, true } - -// CheckConnect verifies a user can open a new connection from the given IP. -func (u *User) CheckConnect(clientIP string) string { - // Check data quota - if u.dataLimit > 0 && u.usedBytes.Load() >= u.dataLimit { - return fmt.Sprintf("data quota exceeded (%d bytes used of %d)", - u.usedBytes.Load(), u.dataLimit) - } - - // Check IP limit - if u.ipLimit <= 0 { - return "" - } - - u.mu.Lock() - defer u.mu.Unlock() - - now := time.Now() - cooldown := 10 * time.Second - - // Clean up expired cooldowns - for ip, disconnectTime := range u.cooldownIPs { - if now.Sub(disconnectTime) > cooldown { - delete(u.cooldownIPs, ip) - } - } - - // If this IP already has active connections, allow - if u.activeIPs[clientIP] > 0 { - return "" - } - - // If this IP is in cooldown, deny - if disconnectTime, inCooldown := u.cooldownIPs[clientIP]; inCooldown { - remaining := cooldown - now.Sub(disconnectTime) - return fmt.Sprintf("ip cooldown (%s remaining)", remaining.Round(time.Second)) - } - - // Count distinct active IPs - activeCount := len(u.activeIPs) - if activeCount >= u.ipLimit { - return fmt.Sprintf("ip limit reached (%d/%d active IPs)", activeCount, u.ipLimit) - } - - return "" -} - -// MarkConnect records that a connection from this IP is active. -func (u *User) MarkConnect(clientIP string) { - u.mu.Lock() - u.activeIPs[clientIP]++ - // Remove from cooldown if reconnecting - delete(u.cooldownIPs, clientIP) - u.mu.Unlock() -} - -// MarkDisconnect decrements active count; when zero, start cooldown. -func (u *User) MarkDisconnect(clientIP string) { - u.mu.Lock() - u.activeIPs[clientIP]-- - if u.activeIPs[clientIP] <= 0 { - delete(u.activeIPs, clientIP) - if u.ipLimit > 0 { - u.cooldownIPs[clientIP] = time.Now() - } - } - u.mu.Unlock() -} - -// AddUsedBytes adds to the total bytes consumed. -func (u *User) AddUsedBytes(n int64) { - u.usedBytes.Add(n) -} - -// UsedBytes returns total bytes consumed. -func (u *User) UsedBytes() int64 { - return u.usedBytes.Load() -} - -// ResetUsedBytes resets the data counter. -func (u *User) ResetUsedBytes() { - u.usedBytes.Store(0) -} - -// NeedsRateLimit returns true if this user has an active bandwidth limiter. -// When false, the proxy can use zero-copy (splice) relay for much better performance. -func (u *User) NeedsRateLimit() bool { - return u.limiter != nil -} - -// WrapReader wraps a reader with rate limiting and byte tracking for this user. -func (u *User) WrapReader(r io.Reader) io.Reader { - if u.limiter == nil { - return &trackingReader{r: r, user: u} - } - return &rateLimitedReader{r: r, limiter: u.limiter, user: u} -} - -// WrapWriter wraps a writer with rate limiting and byte tracking for this user. -func (u *User) WrapWriter(w io.Writer) io.Writer { - if u.limiter == nil { - return &trackingWriter{w: w, user: u} - } - return &rateLimitedWriter{w: w, limiter: u.limiter, user: u} -} - -// AllUsers returns all users in config insertion order. -func (m *Manager) AllUsers() []*User { - m.mu.RLock() - defer m.mu.RUnlock() - result := make([]*User, 0, len(m.ordering)) - for _, name := range m.ordering { - if u, ok := m.users[name]; ok { - result = append(result, u) - } - } - return result -} - -// GetUser returns a user by username. -func (m *Manager) GetUser(username string) *User { - m.mu.RLock() - defer m.mu.RUnlock() - return m.users[username] -} - -// UserStatus returns JSON-friendly status for a user. -type UserStatus struct { - Username string `json:"username"` - BandwidthLimit int `json:"bandwidth_limit"` - BandwidthUnit string `json:"bandwidth_unit"` - DataLimit int `json:"data_limit"` - DataUnit string `json:"data_unit"` - DataUsedBytes int64 `json:"data_used_bytes"` - IPLimit int `json:"ip_limit"` - ActiveIPs int `json:"active_ips"` -} - -func (u *User) Status() UserStatus { - u.mu.Lock() - activeCount := len(u.activeIPs) - u.mu.Unlock() - - return UserStatus{ - Username: u.Config.Username, - BandwidthLimit: u.Config.BandwidthLimit, - BandwidthUnit: u.Config.BandwidthUnit, - DataLimit: u.Config.DataLimit, - DataUnit: u.Config.DataUnit, - DataUsedBytes: u.UsedBytes(), - IPLimit: u.ipLimit, - ActiveIPs: activeCount, - } -} - -// --- Rate-limited I/O --- - -type rateLimitedReader struct { - r io.Reader - limiter *rate.Limiter - user *User -} - -func (r *rateLimitedReader) Read(p []byte) (int, error) { - // Limit read size to burst for proper WaitN - burst := r.limiter.Burst() - if burst > 0 && len(p) > burst { - p = p[:burst] - } - n, err := r.r.Read(p) - if n > 0 { - r.user.AddUsedBytes(int64(n)) - // Block until tokens available — this creates backpressure - if waitErr := r.limiter.WaitN(context.Background(), n); waitErr != nil { - // If WaitN fails (shouldn't with Background()), sleep as fallback - time.Sleep(time.Duration(n) * time.Second / time.Duration(r.limiter.Limit())) - } - } - return n, err -} - -type rateLimitedWriter struct { - w io.Writer - limiter *rate.Limiter - user *User -} - -func (w *rateLimitedWriter) Write(p []byte) (int, error) { - total := 0 - burst := w.limiter.Burst() - for len(p) > 0 { - chunk := len(p) - if burst > 0 && chunk > burst { - chunk = burst - } - // Wait BEFORE writing — this is the correct rate-limiting approach - if waitErr := w.limiter.WaitN(context.Background(), chunk); waitErr != nil { - time.Sleep(time.Duration(chunk) * time.Second / time.Duration(w.limiter.Limit())) - } - n, err := w.w.Write(p[:chunk]) - total += n - if n > 0 { - w.user.AddUsedBytes(int64(n)) - } - if err != nil { - return total, err - } - p = p[n:] - } - return total, nil -} - -type trackingReader struct { - r io.Reader - user *User -} - -func (r *trackingReader) Read(p []byte) (int, error) { - n, err := r.r.Read(p) - if n > 0 { - r.user.AddUsedBytes(int64(n)) - } - return n, err -} - -type trackingWriter struct { - w io.Writer - user *User -} - -func (w *trackingWriter) Write(p []byte) (int, error) { - n, err := w.w.Write(p) - if n > 0 { - w.user.AddUsedBytes(int64(n)) - } - return n, err -} - -// ExtractIP gets the IP portion from a net.Addr. -func ExtractIP(addr net.Addr) string { - host, _, err := net.SplitHostPort(addr.String()) - if err != nil { - return addr.String() - } - return host -} diff --git a/internal/users/user.go b/internal/users/user.go new file mode 100644 index 0000000..4382134 --- /dev/null +++ b/internal/users/user.go @@ -0,0 +1,191 @@ +package users + +import ( + "fmt" + "io" + "net" + "sync" + "sync/atomic" + "time" + + "golang.org/x/time/rate" + + "github.com/ParsaKSH/SlipStream-Plus/internal/config" +) + +// User represents a runtime user with usage tracking. +type User struct { + Config config.UserConfig + limiter *rate.Limiter // bandwidth rate limiter (nil = unlimited) + + usedBytes atomic.Int64 // total bytes consumed + dataLimit int64 // max bytes (0 = unlimited) + + mu sync.Mutex + activeIPs map[string]int // IP → active connection count + cooldownIPs map[string]time.Time // IP → disconnect time (for cooldown) + ipLimit int // max concurrent IPs (0 = unlimited) +} + +// CheckConnect verifies a user can open a new connection from the given IP. +func (u *User) CheckConnect(clientIP string) string { + // Check data quota + if u.dataLimit > 0 && u.usedBytes.Load() >= u.dataLimit { + return fmt.Sprintf("data quota exceeded (%d bytes used of %d)", + u.usedBytes.Load(), u.dataLimit) + } + + // Check IP limit + if u.ipLimit <= 0 { + return "" + } + + u.mu.Lock() + defer u.mu.Unlock() + + now := time.Now() + cooldown := 10 * time.Second + + // Clean up expired cooldowns + for ip, disconnectTime := range u.cooldownIPs { + if now.Sub(disconnectTime) > cooldown { + delete(u.cooldownIPs, ip) + } + } + + // If this IP already has active connections, allow + if u.activeIPs[clientIP] > 0 { + return "" + } + + // If this IP is in cooldown, deny + if disconnectTime, inCooldown := u.cooldownIPs[clientIP]; inCooldown { + remaining := cooldown - now.Sub(disconnectTime) + return fmt.Sprintf("ip cooldown (%s remaining)", remaining.Round(time.Second)) + } + + // Count distinct active IPs + activeCount := len(u.activeIPs) + if activeCount >= u.ipLimit { + return fmt.Sprintf("ip limit reached (%d/%d active IPs)", activeCount, u.ipLimit) + } + + return "" +} + +// MarkConnect records that a connection from this IP is active. +func (u *User) MarkConnect(clientIP string) { + u.mu.Lock() + u.activeIPs[clientIP]++ + // Remove from cooldown if reconnecting + delete(u.cooldownIPs, clientIP) + u.mu.Unlock() +} + +// MarkDisconnect decrements active count; when zero, start cooldown. +func (u *User) MarkDisconnect(clientIP string) { + u.mu.Lock() + u.activeIPs[clientIP]-- + if u.activeIPs[clientIP] <= 0 { + delete(u.activeIPs, clientIP) + if u.ipLimit > 0 { + u.cooldownIPs[clientIP] = time.Now() + } + } + u.mu.Unlock() +} + +// AddUsedBytes adds to the total bytes consumed. +func (u *User) AddUsedBytes(n int64) { + u.usedBytes.Add(n) +} + +// UsedBytes returns total bytes consumed. +func (u *User) UsedBytes() int64 { + return u.usedBytes.Load() +} + +// ResetUsedBytes resets the data counter. +func (u *User) ResetUsedBytes() { + u.usedBytes.Store(0) +} + +// NeedsRateLimit returns true if this user has an active bandwidth limiter. +// When false, the proxy can use zero-copy (splice) relay for much better performance. +func (u *User) NeedsRateLimit() bool { + return u.limiter != nil +} + +// WrapReader wraps a reader with rate limiting and byte tracking for this user. +func (u *User) WrapReader(r io.Reader) io.Reader { + if u.limiter == nil { + return &trackingReader{r: r, user: u} + } + return &rateLimitedReader{r: r, limiter: u.limiter, user: u} +} + +// WrapWriter wraps a writer with rate limiting and byte tracking for this user. +func (u *User) WrapWriter(w io.Writer) io.Writer { + if u.limiter == nil { + return &trackingWriter{w: w, user: u} + } + return &rateLimitedWriter{w: w, limiter: u.limiter, user: u} +} + +// UserStatus returns JSON-friendly status for a user. +type UserStatus struct { + Username string `json:"username"` + BandwidthLimit int `json:"bandwidth_limit"` + BandwidthUnit string `json:"bandwidth_unit"` + DataLimit int `json:"data_limit"` + DataUnit string `json:"data_unit"` + DataUsedBytes int64 `json:"data_used_bytes"` + IPLimit int `json:"ip_limit"` + ActiveIPs int `json:"active_ips"` +} + +func (u *User) Status() UserStatus { + u.mu.Lock() + activeCount := len(u.activeIPs) + u.mu.Unlock() + + return UserStatus{ + Username: u.Config.Username, + BandwidthLimit: u.Config.BandwidthLimit, + BandwidthUnit: u.Config.BandwidthUnit, + DataLimit: u.Config.DataLimit, + DataUnit: u.Config.DataUnit, + DataUsedBytes: u.UsedBytes(), + IPLimit: u.ipLimit, + ActiveIPs: activeCount, + } +} + +// AllUsers returns all users in config insertion order. +func (m *Manager) AllUsers() []*User { + m.mu.RLock() + defer m.mu.RUnlock() + result := make([]*User, 0, len(m.ordering)) + for _, name := range m.ordering { + if u, ok := m.users[name]; ok { + result = append(result, u) + } + } + return result +} + +// GetUser returns a user by username. +func (m *Manager) GetUser(username string) *User { + m.mu.RLock() + defer m.mu.RUnlock() + return m.users[username] +} + +// ExtractIP gets the IP portion from a net.Addr. +func ExtractIP(addr net.Addr) string { + host, _, err := net.SplitHostPort(addr.String()) + if err != nil { + return addr.String() + } + return host +}