diff --git a/internal/pdcp/writer.go b/internal/pdcp/writer.go index 3adf91b6..c9c364d0 100644 --- a/internal/pdcp/writer.go +++ b/internal/pdcp/writer.go @@ -65,7 +65,7 @@ func NewUploadWriterCallback(ctx context.Context, creds *pdcpauth.PDCPCredential u := &UploadWriter{ creds: creds, done: make(chan struct{}, 1), - data: make(chan *clients.Response, 8), // default buffer size + data: make(chan *clients.Response, 1000), // increased buffer size TeamID: "", } var err error @@ -91,7 +91,11 @@ func NewUploadWriterCallback(ctx context.Context, creds *pdcpauth.PDCPCredential // GetWriterCallback returns the writer callback func (u *UploadWriter) GetWriterCallback() func(*clients.Response) { return func(resp *clients.Response) { - u.data <- resp + select { + case u.data <- resp: + default: + gologger.Warning().Msgf("PDCP upload buffer full, skipping result") + } } } diff --git a/pkg/tlsx/tls/tls.go b/pkg/tlsx/tls/tls.go index c07a5ed2..2f15e66d 100644 --- a/pkg/tlsx/tls/tls.go +++ b/pkg/tlsx/tls/tls.go @@ -236,10 +236,12 @@ func (c *Client) EnumerateCiphers(hostname, ip, port string, options clients.Con conn := tls.Client(baseConn, baseCfg) - if err := conn.Handshake(); err == nil { + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(c.options.Timeout)*time.Second) + if err := conn.HandshakeContext(ctx); err == nil { ciphersuite := conn.ConnectionState().CipherSuite enumeratedCiphers = append(enumeratedCiphers, tls.CipherSuiteName(ciphersuite)) } + cancel() _ = conn.Close() // close baseConn internally } return enumeratedCiphers, nil diff --git a/pkg/tlsx/ztls/ztls.go b/pkg/tlsx/ztls/ztls.go index a03b7267..f89265cd 100644 --- a/pkg/tlsx/ztls/ztls.go +++ b/pkg/tlsx/ztls/ztls.go @@ -257,10 +257,12 @@ func (c *Client) EnumerateCiphers(hostname, ip, port string, options clients.Con conn := tls.Client(baseConn, baseCfg) baseCfg.CipherSuites = []uint16{ztlsCiphers[v]} - if err := c.tlsHandshakeWithTimeout(conn, context.TODO()); err == nil { + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(c.options.Timeout)*time.Second) + if err := c.tlsHandshakeWithTimeout(conn, ctx); err == nil { h1 := conn.GetHandshakeLog() enumeratedCiphers = append(enumeratedCiphers, h1.ServerHello.CipherSuite.String()) } + cancel() _ = conn.Close() // also closes baseConn internally } return enumeratedCiphers, nil @@ -323,17 +325,18 @@ func (c *Client) getConfig(hostname, ip, port string, options clients.ConnectOpt // tlsHandshakeWithCtx attempts tls handshake with given timeout func (c *Client) tlsHandshakeWithTimeout(tlsConn *tls.Conn, ctx context.Context) error { errChan := make(chan error, 1) - defer close(errChan) + + go func() { + errChan <- tlsConn.Handshake() + }() select { case <-ctx.Done(): return errorutil.NewWithTag("ztls", "timeout while attempting handshake") //nolint - case errChan <- tlsConn.Handshake(): - } - - err := <-errChan - if err == tls.ErrCertsOnly { - err = nil + case err := <-errChan: + if err == tls.ErrCertsOnly { + err = nil + } + return err } - return err }