diff --git a/cmd/centralserver/main.go b/cmd/centralserver/main.go index cc37e50..bb198e4 100644 --- a/cmd/centralserver/main.go +++ b/cmd/centralserver/main.go @@ -17,7 +17,20 @@ import ( "github.com/ParsaKSH/SlipStream-Plus/internal/tunnel" ) -//cc + +// sourceConn wraps a tunnel connection with a write mutex so that +// concurrent WriteFrame calls from different ConnIDs don't interleave bytes. +type sourceConn struct { + conn net.Conn + writeMu sync.Mutex +} + +func (sc *sourceConn) WriteFrame(f *tunnel.Frame) error { + sc.writeMu.Lock() + defer sc.writeMu.Unlock() + return tunnel.WriteFrame(sc.conn, f) +} + // connState tracks a single reassembled connection. type connState struct { mu sync.Mutex @@ -29,7 +42,7 @@ type connState struct { // Sources: all tunnel connections that can carry reverse data. // We round-robin responses across them (not broadcast). - sources []io.Writer + sources []*sourceConn sourceIdx int } @@ -39,6 +52,10 @@ type centralServer struct { mu sync.RWMutex conns map[uint32]*connState // ConnID → state + + // sourceMu protects the sources map (net.Conn → *sourceConn). 
+ sourceMu sync.Mutex + sourceMap map[net.Conn]*sourceConn } func main() { @@ -54,6 +71,7 @@ func main() { cs := &centralServer{ socksUpstream: *socksUpstream, conns: make(map[uint32]*connState), + sourceMap: make(map[net.Conn]*sourceConn), } sigCh := make(chan os.Signal, 1) @@ -147,12 +165,15 @@ func (cs *centralServer) handleSOCKS5Passthrough(clientConn net.Conn, firstByte func (cs *centralServer) handleFrameConn(conn net.Conn, firstByte byte, remoteAddr string) { log.Printf("[central] frame connection from %s", remoteAddr) + sc := cs.getSourceConn(conn) + // Track which ConnIDs this source served servedIDs := make(map[uint32]bool) defer func() { - // Source TCP died — clean up connStates that only had this source - cs.cleanupSource(conn, servedIDs, remoteAddr) + // Source TCP died — clean up sourceConn and connStates + cs.removeSourceConn(conn) + cs.cleanupSource(sc, servedIDs, remoteAddr) }() // Read remaining header bytes (we already read 1) @@ -168,7 +189,7 @@ func (cs *centralServer) handleFrameConn(conn net.Conn, firstByte byte, remoteAd firstFrame := cs.parseHeader(fullHdr, conn, remoteAddr) if firstFrame != nil { servedIDs[firstFrame.ConnID] = true - cs.dispatchFrame(firstFrame, conn) + cs.dispatchFrame(firstFrame, sc) } for { @@ -180,13 +201,13 @@ func (cs *centralServer) handleFrameConn(conn net.Conn, firstByte byte, remoteAd return } servedIDs[frame.ConnID] = true - cs.dispatchFrame(frame, conn) + cs.dispatchFrame(frame, sc) } } // cleanupSource removes a dead source connection from all connStates. // If a connState has no remaining sources, it is fully cleaned up. 
-func (cs *centralServer) cleanupSource(deadSource net.Conn, servedIDs map[uint32]bool, remoteAddr string) { +func (cs *centralServer) cleanupSource(deadSource *sourceConn, servedIDs map[uint32]bool, remoteAddr string) { cs.mu.Lock() defer cs.mu.Unlock() @@ -227,6 +248,26 @@ func (cs *centralServer) cleanupSource(deadSource net.Conn, servedIDs map[uint32 } } +// getSourceConn returns the sourceConn wrapper for a raw net.Conn, +// creating one if it doesn't exist yet. +func (cs *centralServer) getSourceConn(conn net.Conn) *sourceConn { + cs.sourceMu.Lock() + defer cs.sourceMu.Unlock() + sc, ok := cs.sourceMap[conn] + if !ok { + sc = &sourceConn{conn: conn} + cs.sourceMap[conn] = sc + } + return sc +} + +// removeSourceConn removes the sourceConn wrapper when the raw conn dies. +func (cs *centralServer) removeSourceConn(conn net.Conn) { + cs.sourceMu.Lock() + delete(cs.sourceMap, conn) + cs.sourceMu.Unlock() +} + func isClosedConnErr(err error) bool { if err == nil { return false @@ -257,7 +298,7 @@ func (cs *centralServer) parseHeader(hdr [tunnel.HeaderSize]byte, conn net.Conn, } } -func (cs *centralServer) dispatchFrame(frame *tunnel.Frame, source net.Conn) { +func (cs *centralServer) dispatchFrame(frame *tunnel.Frame, source *sourceConn) { if frame.IsSYN() { cs.handleSYN(frame, source) return @@ -273,7 +314,7 @@ func (cs *centralServer) dispatchFrame(frame *tunnel.Frame, source net.Conn) { cs.handleData(frame, source) } -func (cs *centralServer) handleSYN(frame *tunnel.Frame, source net.Conn) { +func (cs *centralServer) handleSYN(frame *tunnel.Frame, source *sourceConn) { connID := frame.ConnID cs.mu.Lock() @@ -307,7 +348,7 @@ func (cs *centralServer) handleSYN(frame *tunnel.Frame, source net.Conn) { ctx, cancel := context.WithCancel(context.Background()) state := &connState{ reorderer: tunnel.NewReordererAt(frame.SeqNum + 1), // skip SYN's SeqNum - sources: []io.Writer{source}, + sources: []*sourceConn{source}, cancel: cancel, created: time.Now(), } @@ -315,11 
+356,11 @@ func (cs *centralServer) handleSYN(frame *tunnel.Frame, source net.Conn) { cs.mu.Unlock() log.Printf("[central] conn=%d: SYN → target=%s", connID, targetAddr) - go cs.connectUpstream(ctx, connID, state, atyp, addr, port, targetAddr, source) + go cs.connectUpstream(ctx, connID, state, atyp, addr, port, targetAddr) } func (cs *centralServer) connectUpstream(ctx context.Context, connID uint32, state *connState, - atyp byte, addr, port []byte, targetAddr string, source net.Conn) { + atyp byte, addr, port []byte, targetAddr string) { upConn, err := net.DialTimeout("tcp", cs.socksUpstream, 10*time.Second) if err != nil { @@ -349,6 +390,7 @@ func (cs *centralServer) connectUpstream(ctx context.Context, connID uint32, sta return } + // Read greeting response (2 bytes) + CONNECT response header (4 bytes) resp := make([]byte, 6) if _, err := io.ReadFull(upConn, resp); err != nil { log.Printf("[central] conn=%d: upstream response read failed: %v", connID, err) @@ -358,7 +400,18 @@ func (cs *centralServer) connectUpstream(ctx context.Context, connID uint32, sta return } - // Drain bind address + // Check CONNECT result BEFORE draining bind address. + // resp[3] = REP field (0x00 = success). If non-zero, upstream may close + // without sending bind address, so don't try to drain it. 
+ if resp[3] != 0x00 { + log.Printf("[central] conn=%d: upstream CONNECT rejected: 0x%02x", connID, resp[3]) + upConn.Close() + cs.sendFrame(connID, &tunnel.Frame{ConnID: connID, Flags: tunnel.FlagRST | tunnel.FlagReverse}) + cs.removeConn(connID) + return + } + + // Drain bind address (only on success) switch resp[5] { case 0x01: io.ReadFull(upConn, make([]byte, 6)) @@ -372,14 +425,6 @@ func (cs *centralServer) connectUpstream(ctx context.Context, connID uint32, sta io.ReadFull(upConn, make([]byte, 6)) } - if resp[3] != 0x00 { - log.Printf("[central] conn=%d: upstream CONNECT rejected: 0x%02x", connID, resp[3]) - upConn.Close() - cs.sendFrame(connID, &tunnel.Frame{ConnID: connID, Flags: tunnel.FlagRST | tunnel.FlagReverse}) - cs.removeConn(connID) - return - } - state.mu.Lock() state.target = upConn @@ -453,7 +498,8 @@ func (cs *centralServer) relayUpstreamToTunnel(ctx context.Context, connID uint3 } // sendFrame picks ONE source via round-robin and writes the frame. -// If that source fails, tries the next one. Much better than broadcasting. +// If that source fails, tries the next one. Uses sourceConn.WriteFrame +// which is mutex-protected per TCP connection, preventing interleaved writes. func (cs *centralServer) sendFrame(connID uint32, frame *tunnel.Frame) { cs.mu.RLock() state, ok := cs.conns[connID] @@ -473,9 +519,9 @@ func (cs *centralServer) sendFrame(connID uint32, frame *tunnel.Frame) { for tries := 0; tries < len(state.sources); tries++ { idx := state.sourceIdx % len(state.sources) state.sourceIdx++ - w := state.sources[idx] + sc := state.sources[idx] - if err := tunnel.WriteFrame(w, frame); err != nil { + if err := sc.WriteFrame(frame); err != nil { // Remove dead source state.sources = append(state.sources[:idx], state.sources[idx+1:]...) 
if state.sourceIdx > 0 { @@ -488,7 +534,7 @@ func (cs *centralServer) sendFrame(connID uint32, frame *tunnel.Frame) { log.Printf("[central] conn=%d: all sources failed", connID) } -func (cs *centralServer) handleData(frame *tunnel.Frame, source net.Conn) { +func (cs *centralServer) handleData(frame *tunnel.Frame, source *sourceConn) { cs.mu.RLock() state, ok := cs.conns[frame.ConnID] cs.mu.RUnlock() @@ -548,7 +594,10 @@ func (cs *centralServer) handleFIN(frame *tunnel.Frame) { state.target.Close() } state.mu.Unlock() - log.Printf("[central] conn=%d: FIN received", frame.ConnID) + + // Remove from map and cancel context so relayUpstreamToTunnel exits cleanly + cs.removeConn(frame.ConnID) + log.Printf("[central] conn=%d: FIN received, cleaned up", frame.ConnID) } func (cs *centralServer) handleRST(frame *tunnel.Frame) { @@ -601,10 +650,8 @@ func (cs *centralServer) cleanupLoop() { if len(state.sources) == 0 && now.Sub(state.created) > 30*time.Second { shouldClean = true } - // Connection too old (5 min max lifetime) - if now.Sub(state.created) > 5*time.Minute { - shouldClean = true - } + // No max lifetime — long-lived connections (downloads, streams) + // are valid. Cleanup only based on actual broken state above. state.mu.Unlock() if shouldClean { diff --git a/internal/gui/dashboard.go b/internal/gui/dashboard.go index a200d64..cde306b 100644 --- a/internal/gui/dashboard.go +++ b/internal/gui/dashboard.go @@ -224,8 +224,8 @@ canvas{width:100%;height:200px;border-radius:var(--rs);background:var(--bg2);bor