From 2c6d256e9c8b17985bbae30644dbbd57491bba44 Mon Sep 17 00:00:00 2001 From: Benjamin Cane Date: Mon, 11 May 2026 09:13:32 -0700 Subject: [PATCH 1/3] Harden tests, file output paths, and CI lint/tooling --- .github/workflows/go.yaml | 6 ++-- .github/workflows/lint.yml | 7 ++-- .golangci.yml | 5 +++ README.md | 35 +++++++++--------- gencerts_test.go | 7 ++++ kpconfig.go | 2 +- testcerts.go | 53 ++++++++++++++------------- testcerts_test.go | 73 +++++++++++++++++++++++++------------- 8 files changed, 115 insertions(+), 73 deletions(-) create mode 100644 .golangci.yml diff --git a/.github/workflows/go.yaml b/.github/workflows/go.yaml index b813b5f..4906e58 100644 --- a/.github/workflows/go.yaml +++ b/.github/workflows/go.yaml @@ -15,13 +15,13 @@ jobs: os: [ubuntu-latest, macos-latest] steps: - name: Set up Go 1.x - uses: actions/setup-go@v2 + uses: actions/setup-go@v5 with: go-version: ^1.20 id: go - name: Check out code into the Go module directory - uses: actions/checkout@v2 + uses: actions/checkout@v4 - name: Test run: | @@ -29,6 +29,6 @@ jobs: go test -v -race -covermode=atomic -coverprofile=coverage.out ./... - name: Upload coverage to Codecov - uses: codecov/codecov-action@v3 + uses: codecov/codecov-action@v5 with: token: ${{ secrets.CODECOV_TOKEN }} diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index b641583..ecab68e 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -12,12 +12,11 @@ jobs: name: golangci runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: golangci-lint - uses: golangci/golangci-lint-action@v2 + uses: golangci/golangci-lint-action@v8 with: - # Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version - version: latest + version: v2.1.6 # Optional: working directory, useful for monorepos # working-directory: somedir diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 0000000..0762930 --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,5 @@ +version: "2" +linters: + enable: + - misspell + - revive diff --git a/README.md b/README.md index bcb20ed..21ee79a 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ Stop saving test certificates in your code repos. Start generating them in your ```go func TestFunc(t *testing.T) { // Create and write self-signed Certificate and Key to temporary files - cert, key, err := testcerts.GenerateToTempFile("/tmp/") + cert, key, err := testcerts.GenerateCertsToTempFile("/tmp/") if err != nil { // do something } @@ -33,20 +33,20 @@ func TestFunc(t *testing.T) { // Generate Certificate Authority ca := testcerts.NewCA() - go func() { - // Create a signed Certificate and Key for "localhost" - certs, err := ca.NewKeyPair("localhost") - if err != nil { - // do something - } + // Create a signed Certificate and Key for "localhost" + certs, err := ca.NewKeyPair("localhost") + if err != nil { + // do something + } - // Write certificates to a file - err = certs.ToFile("/tmp/cert", "/tmp/key") - if err { - // do something - } + // Write certificates to a file + err = certs.ToFile("/tmp/cert", "/tmp/key") + if err != nil { + // do something + } - // Start HTTP Listener + // Start HTTP Listener + go func() { err = http.ListenAndServeTLS("localhost:443", "/tmp/cert", "/tmp/key", someHandler) if err != nil { // do something @@ -54,9 +54,14 @@ func TestFunc(t *testing.T) { }() // Create a client with the self-signed CA + tlsConfig, err := certs.ConfigureTLSConfig(ca.GenerateTLSConfig()) + if err != nil { + // do something + } + client := &http.Client{ Transport: &http.Transport{ - TLSClientConfig: certs.ConfigureTLSConfig(ca.GenerateTLSConfig()), + TLSClientConfig: tlsConfig, }, } @@ -75,5 +80,3 @@ If you find a bug or have an idea for a feature, please open an issue or a pull testcerts is released under the MIT License. See [LICENSE](./LICENSE) for details. - - diff --git a/gencerts_test.go b/gencerts_test.go index 6c01c31..1875dfa 100644 --- a/gencerts_test.go +++ b/gencerts_test.go @@ -3,6 +3,7 @@ package testcerts import ( "os" "path/filepath" + "runtime" "testing" ) @@ -109,6 +110,9 @@ func TestGeneratingCertsToFile(t *testing.T) { }) t.Run("Testing the unhappy path for insufficient permissions", func(t *testing.T) { + if runtime.GOOS != "windows" && os.Geteuid() == 0 { + t.Skip("running as root bypasses directory permission checks") + } dir, err := os.MkdirTemp("", "permission-test") if err != nil { t.Errorf("Error creating temp directory - %s", err) @@ -167,6 +171,9 @@ func TestGenerateCertsToTempFile(t *testing.T) { }) t.Run("Testing the unhappy path for insufficient permissions when creating temp file", func(t *testing.T) { + if runtime.GOOS != "windows" && os.Geteuid() == 0 { + t.Skip("running as root bypasses directory permission checks") + } dir, err := os.MkdirTemp("", "permission-test") if err != nil { t.Errorf("Error creating temp directory - %s", err) diff --git a/kpconfig.go b/kpconfig.go index da0237c..aab66ff 100644 --- a/kpconfig.go +++ b/kpconfig.go @@ -52,7 +52,7 @@ func (c *KeyPairConfig) Validate() error { return nil } -// IPAddresses returns a list of IP addresses in Net.IP format. +// IPNetAddresses returns a list of IP addresses in Net.IP format. func (c *KeyPairConfig) IPNetAddresses() ([]net.IP, error) { var ips []net.IP for _, ip := range c.IPAddresses { diff --git a/testcerts.go b/testcerts.go index d9356ba..801560a 100644 --- a/testcerts.go +++ b/testcerts.go @@ -76,6 +76,8 @@ import ( "time" ) +const fileMode = 0640 + // CertificateAuthority represents a self-signed x509 certificate authority. type CertificateAuthority struct { cert *x509.Certificate @@ -210,32 +212,36 @@ func (ca *CertificateAuthority) CertPool() *x509.CertPool { // PrivateKey returns the private key of the CertificateAuthority. func (ca *CertificateAuthority) PrivateKey() []byte { + if ca == nil || ca.privateKey == nil { + return nil + } return pem.EncodeToMemory(ca.privateKey) } // PublicKey returns the public key of the CertificateAuthority. func (ca *CertificateAuthority) PublicKey() []byte { + if ca == nil || ca.publicKey == nil { + return nil + } return pem.EncodeToMemory(ca.publicKey) } -// ToFile saves the CertificateAuthority certificate and private key to the specified files. -// Returns an error if any file operation fails. -func (ca *CertificateAuthority) ToFile(certFile, keyFile string) error { - // Write Certificate - err := os.WriteFile(certFile, ca.PublicKey(), 0640) - if err != nil { +func writePairToFiles(certData []byte, certFile string, keyData []byte, keyFile string) error { + if err := os.WriteFile(certFile, certData, fileMode); err != nil { return fmt.Errorf("unable to create certificate file - %w", err) } - - // Write Key - err = os.WriteFile(keyFile, ca.PrivateKey(), 0640) - if err != nil { - return fmt.Errorf("unable to create certificate file - %w", err) + if err := os.WriteFile(keyFile, keyData, fileMode); err != nil { + return fmt.Errorf("unable to create key file - %w", err) } - return nil } +// ToFile saves the CertificateAuthority certificate and private key to the specified files. +// Returns an error if any file operation fails. +func (ca *CertificateAuthority) ToFile(certFile, keyFile string) error { + return writePairToFiles(ca.PublicKey(), certFile, ca.PrivateKey(), keyFile) +} + // ToTempFile saves the CertificateAuthority certificate and private key to temporary files. // The temporary files are created in the specified directory and have random names. func (ca *CertificateAuthority) ToTempFile(dir string) (cfh *os.File, kfh *os.File, err error) { @@ -287,30 +293,24 @@ func (kp *KeyPair) Cert() *x509.Certificate { // PrivateKey returns the private key of the KeyPair. func (kp *KeyPair) PrivateKey() []byte { + if kp == nil || kp.privateKey == nil { + return nil + } return pem.EncodeToMemory(kp.privateKey) } // PublicKey returns the public key of the KeyPair. func (kp *KeyPair) PublicKey() []byte { + if kp == nil || kp.publicKey == nil { + return nil + } return pem.EncodeToMemory(kp.publicKey) } // ToFile saves the KeyPair certificate and private key to the specified files. // Returns an error if any file operation fails. func (kp *KeyPair) ToFile(certFile, keyFile string) error { - // Write Certificate - err := os.WriteFile(certFile, kp.PublicKey(), 0640) - if err != nil { - return fmt.Errorf("unable to create certificate file - %w", err) - } - - // Write Key - err = os.WriteFile(keyFile, kp.PrivateKey(), 0640) - if err != nil { - return fmt.Errorf("unable to create key file - %w", err) - } - - return nil + return writePairToFiles(kp.PublicKey(), certFile, kp.PrivateKey(), keyFile) } // ToTempFile saves the KeyPair certificate and private key to temporary files. @@ -352,6 +352,9 @@ func (kp *KeyPair) ToTempFile(dir string) (cfh *os.File, kfh *os.File, err error // ConfigureTLSConfig will configure the tls.Config with the KeyPair certificate and private key. // The returned tls.Config can be used for a server or client. func (kp *KeyPair) ConfigureTLSConfig(tlsConfig *tls.Config) (*tls.Config, error) { + if tlsConfig == nil { + tlsConfig = &tls.Config{} + } cert, err := tls.X509KeyPair(kp.PublicKey(), kp.PrivateKey()) if err != nil { return nil, fmt.Errorf("could not create x509 key pair - %w", err) diff --git a/testcerts_test.go b/testcerts_test.go index 3b37c34..39ca5a5 100644 --- a/testcerts_test.go +++ b/testcerts_test.go @@ -1,17 +1,18 @@ package testcerts import ( + "context" "crypto/tls" "crypto/x509" "errors" "fmt" "math/big" + "net" "net/http" "os" "path/filepath" "strings" "testing" - "time" ) func TestCertsUsage(t *testing.T) { @@ -484,12 +485,18 @@ func TestFullFlow(t *testing.T) { } // Setup HTTP Server + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Error creating listener - %s", err) + } + t.Cleanup(func() { + _ = listener.Close() + }) + server := &http.Server{ - Addr: c.listenAddr + ":8443", Handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - _, err := w.Write([]byte("Hello, World!")) - if err != nil { - t.Errorf("Error writing response - %s", err) + if _, writeErr := w.Write([]byte("Hello, World!")); writeErr != nil { + t.Errorf("Error writing response - %s", writeErr) } }), TLSConfig: serverTLSConfig, @@ -504,22 +511,22 @@ func TestFullFlow(t *testing.T) { t.Fatalf("Error writing certs to temp files - %s", err) } + serverErrCh := make(chan error, 1) go func() { // Start HTTP Listener - err = server.ListenAndServeTLS(certFile.Name(), keyFile.Name()) - if err != nil && err != http.ErrServerClosed { - t.Errorf("Listener returned error - %s", err) - } + serverErrCh <- server.ServeTLS(listener, certFile.Name(), keyFile.Name()) }() - // Wait for Listener to start - <-time.After(3 * time.Second) - // Setup HTTP Client + baseTransport := &http.Transport{ + TLSClientConfig: clientTLSConfig, + } + baseTransport.DialContext = func(ctx context.Context, network, _ string) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, network, listener.Addr().String()) + } client := &http.Client{ - Transport: &http.Transport{ - TLSClientConfig: clientTLSConfig, - }, + Transport: baseTransport, } // Make an HTTPS request @@ -530,7 +537,15 @@ func TestFullFlow(t *testing.T) { for _, a := range addr { t.Run("Client Request to "+a, func(t *testing.T) { - rsp, err := client.Get("https://" + a + ":8443") + host := a + if strings.Contains(a, ":") { + host = "[" + a + "]" + } + req, reqErr := http.NewRequest(http.MethodGet, "https://"+host, nil) + if reqErr != nil { + t.Fatalf("could not create request: %v", reqErr) + } + rsp, err := client.Do(req) if err != nil && c.clientErr == nil { t.Fatalf("client returned unexpected error: %v", err) @@ -558,6 +573,13 @@ func TestFullFlow(t *testing.T) { } }) } + + if closeErr := server.Close(); closeErr != nil { + t.Errorf("error closing server: %v", closeErr) + } + if serveErr := <-serverErrCh; serveErr != nil && serveErr != http.ErrServerClosed { + t.Errorf("Listener returned error - %s", serveErr) + } }) } } @@ -570,18 +592,21 @@ func ExampleNewCA() { certs, err := ca.NewKeyPair("localhost") if err != nil { fmt.Printf("Error generating keypair - %s", err) + return } // Write the certificates to a file cert, key, err := certs.ToTempFile("") if err != nil { fmt.Printf("Error writing certs to temp files - %s", err) + return } // Setup Server TLS Config serverTLSConfig, err := certs.ConfigureTLSConfig(ca.GenerateTLSConfig()) if err != nil { fmt.Printf("Error configuring server TLS - %s", err) + return } // Require Valid Client Cert @@ -589,11 +614,10 @@ func ExampleNewCA() { // Create an HTTP Server server := &http.Server{ - Addr: "0.0.0.0:8443", + Addr: "127.0.0.1:8443", Handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - _, err := w.Write([]byte("Hello, World!")) - if err != nil { - fmt.Printf("Error writing response - %s", err) + if _, writeErr := w.Write([]byte("Hello, World!")); writeErr != nil { + fmt.Printf("Error writing response - %s", writeErr) } }), TLSConfig: serverTLSConfig, @@ -609,14 +633,11 @@ func ExampleNewCA() { fmt.Printf("Listener returned error - %s", err) } }() - - // Wait for Listener to start - <-time.After(3 * time.Second) - // Client TLS Config clientTLSConfig, err := certs.ConfigureTLSConfig(ca.GenerateTLSConfig()) if err != nil { fmt.Printf("Error configuring client TLS - %s", err) + return } // Setup HTTP Client with Cert Pool @@ -630,7 +651,11 @@ func ExampleNewCA() { rsp, err := client.Get("https://localhost:8443") if err != nil { fmt.Printf("Client returned error - %s", err) + return } + defer func() { + _ = rsp.Body.Close() + }() // Print the response fmt.Println(rsp.Status) From ee7d1d3d1a8350cb5d8e077c73ddb00aad2038cd Mon Sep 17 00:00:00 2001 From: Benjamin Cane Date: Sat, 16 May 2026 13:19:42 -0700 Subject: [PATCH 2/3] Address review feedback in TLS tests and CA tempfile errors --- testcerts.go | 4 ++-- testcerts_test.go | 32 +++++++++++++------------------- 2 files changed, 15 insertions(+), 21 deletions(-) diff --git a/testcerts.go b/testcerts.go index 801560a..eefbbb8 100644 --- a/testcerts.go +++ b/testcerts.go @@ -263,7 +263,7 @@ func (ca *CertificateAuthority) ToTempFile(dir string) (cfh *os.File, kfh *os.Fi // Write Key kfh, err = os.CreateTemp(dir, "*.key") if err != nil { - return cfh, &os.File{}, fmt.Errorf("unable to create certificate file - %w", err) + return cfh, &os.File{}, fmt.Errorf("unable to create key file - %w", err) } defer func() { if closeErr := kfh.Close(); closeErr != nil { @@ -272,7 +272,7 @@ func (ca *CertificateAuthority) ToTempFile(dir string) (cfh *os.File, kfh *os.Fi }() _, err = kfh.Write(ca.PrivateKey()) if err != nil { - return cfh, kfh, fmt.Errorf("unable to create certificate file - %w", err) + return cfh, kfh, fmt.Errorf("unable to create key file - %w", err) } return cfh, kfh, nil diff --git a/testcerts_test.go b/testcerts_test.go index 39ca5a5..8880126 100644 --- a/testcerts_test.go +++ b/testcerts_test.go @@ -371,35 +371,31 @@ func TestKeyPairConfig(t *testing.T) { } type FullFlowTestCase struct { - name string - listenAddr string - domains []string - kpCfg KeyPairConfig - kpErr error - clientErr error + name string + domains []string + kpCfg KeyPairConfig + kpErr error + clientErr error } func TestFullFlow(t *testing.T) { tc := []FullFlowTestCase{ { - name: "Localhost Domain", - listenAddr: "0.0.0.0", - domains: []string{"localhost"}, - kpCfg: KeyPairConfig{}, - kpErr: nil, + name: "Localhost Domain", + domains: []string{"localhost"}, + kpCfg: KeyPairConfig{}, + kpErr: nil, }, { - name: "Localhost IP", - listenAddr: "0.0.0.0", + name: "Localhost IP", kpCfg: KeyPairConfig{ IPAddresses: []string{"127.0.0.1"}, }, kpErr: nil, }, { - name: "Localhost IP and Domain", - listenAddr: "0.0.0.0", + name: "Localhost IP and Domain", kpCfg: KeyPairConfig{ IPAddresses: []string{"127.0.0.1", "::1"}, Domains: []string{"localhost"}, @@ -407,8 +403,7 @@ func TestFullFlow(t *testing.T) { kpErr: nil, }, { - name: "Localhost IP, Domain, Serial Number, and Common Name", - listenAddr: "0.0.0.0", + name: "Localhost IP, Domain, Serial Number, and Common Name", kpCfg: KeyPairConfig{ IPAddresses: []string{"127.0.0.1", "::1"}, Domains: []string{"localhost"}, @@ -418,8 +413,7 @@ func TestFullFlow(t *testing.T) { kpErr: nil, }, { - name: "Expired certificate", - listenAddr: "0.0.0.0", + name: "Expired certificate", kpCfg: KeyPairConfig{ IPAddresses: []string{"127.0.0.1"}, Expired: true, From 9b2761ae3ef0bbaed27d890c6eae01e83aa296eb Mon Sep 17 00:00:00 2001 From: Benjamin Cane Date: Sun, 24 May 2026 15:47:33 -0700 Subject: [PATCH 3/3] =?UTF-8?q?fix:=20address=20PR=20review=20hardening=20?= =?UTF-8?q?notes=20=F0=9F=94=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Validate cert/key PEM before writing files, clean up orphaned certs on key write failure, and make ExampleNewCA avoid listener races. --- gencerts_test.go | 2 + kpconfig.go | 2 +- testcerts.go | 41 +++++++++++++- testcerts_test.go | 137 +++++++++++++++++++++++++++++++++++++++++++--- 4 files changed, 171 insertions(+), 11 deletions(-) diff --git a/gencerts_test.go b/gencerts_test.go index 1875dfa..e5afd78 100644 --- a/gencerts_test.go +++ b/gencerts_test.go @@ -111,6 +111,7 @@ func TestGeneratingCertsToFile(t *testing.T) { t.Run("Testing the unhappy path for insufficient permissions", func(t *testing.T) { if runtime.GOOS != "windows" && os.Geteuid() == 0 { + // CAP_DAC_OVERRIDE can also bypass this in some rootless Linux containers. t.Skip("running as root bypasses directory permission checks") } dir, err := os.MkdirTemp("", "permission-test") @@ -172,6 +173,7 @@ func TestGenerateCertsToTempFile(t *testing.T) { t.Run("Testing the unhappy path for insufficient permissions when creating temp file", func(t *testing.T) { if runtime.GOOS != "windows" && os.Geteuid() == 0 { + // CAP_DAC_OVERRIDE can also bypass this in some rootless Linux containers. t.Skip("running as root bypasses directory permission checks") } dir, err := os.MkdirTemp("", "permission-test") diff --git a/kpconfig.go b/kpconfig.go index aab66ff..42973c5 100644 --- a/kpconfig.go +++ b/kpconfig.go @@ -52,7 +52,7 @@ func (c *KeyPairConfig) Validate() error { return nil } -// IPNetAddresses returns a list of IP addresses in Net.IP format. +// IPNetAddresses returns a list of IP addresses in net.IP format. func (c *KeyPairConfig) IPNetAddresses() ([]net.IP, error) { var ips []net.IP for _, ip := range c.IPAddresses { diff --git a/testcerts.go b/testcerts.go index eefbbb8..f3efcdb 100644 --- a/testcerts.go +++ b/testcerts.go @@ -78,6 +78,20 @@ import ( const fileMode = 0640 +var ( + // ErrEmptyCertificateData is returned when certificate PEM data is empty. + ErrEmptyCertificateData = errors.New("empty certificate data") + + // ErrEmptyKeyData is returned when private key PEM data is empty. + ErrEmptyKeyData = errors.New("empty key data") + + // ErrInvalidCertificateData is returned when certificate data is not valid PEM. + ErrInvalidCertificateData = errors.New("invalid certificate data") + + // ErrInvalidKeyData is returned when private key data is not valid PEM. + ErrInvalidKeyData = errors.New("invalid key data") +) + // CertificateAuthority represents a self-signed x509 certificate authority. type CertificateAuthority struct { cert *x509.Certificate @@ -227,15 +241,39 @@ func (ca *CertificateAuthority) PublicKey() []byte { } func writePairToFiles(certData []byte, certFile string, keyData []byte, keyFile string) error { + if err := validatePEMData(certData, ErrEmptyCertificateData, ErrInvalidCertificateData); err != nil { + return err + } + if err := validatePEMData(keyData, ErrEmptyKeyData, ErrInvalidKeyData); err != nil { + return err + } + if err := os.WriteFile(certFile, certData, fileMode); err != nil { return fmt.Errorf("unable to create certificate file - %w", err) } if err := os.WriteFile(keyFile, keyData, fileMode); err != nil { + if removeErr := os.Remove(certFile); removeErr != nil && !errors.Is(removeErr, os.ErrNotExist) { + return errors.Join( + fmt.Errorf("unable to create key file - %w", err), + fmt.Errorf("unable to remove certificate file after key write failure - %w", removeErr), + ) + } return fmt.Errorf("unable to create key file - %w", err) } return nil } +func validatePEMData(data []byte, emptyErr, invalidErr error) error { + if len(data) == 0 { + return emptyErr + } + block, _ := pem.Decode(data) + if block == nil || len(block.Bytes) == 0 { + return invalidErr + } + return nil +} + // ToFile saves the CertificateAuthority certificate and private key to the specified files. // Returns an error if any file operation fails. func (ca *CertificateAuthority) ToFile(certFile, keyFile string) error { @@ -349,7 +387,8 @@ func (kp *KeyPair) ToTempFile(dir string) (cfh *os.File, kfh *os.File, err error return cfh, kfh, nil } -// ConfigureTLSConfig will configure the tls.Config with the KeyPair certificate and private key. +// ConfigureTLSConfig configures tlsConfig with the KeyPair certificate and private key. +// If tlsConfig is nil, it creates one. Otherwise, it mutates and returns the provided config. // The returned tls.Config can be used for a server or client. func (kp *KeyPair) ConfigureTLSConfig(tlsConfig *tls.Config) (*tls.Config, error) { if tlsConfig == nil { diff --git a/testcerts_test.go b/testcerts_test.go index 8880126..98d1647 100644 --- a/testcerts_test.go +++ b/testcerts_test.go @@ -6,6 +6,7 @@ import ( "crypto/x509" "errors" "fmt" + "io" "math/big" "net" "net/http" @@ -119,6 +120,81 @@ func TestCertsUsage(t *testing.T) { }) }) + t.Run("Write Missing Data to File", func(t *testing.T) { + tempDir, err := os.MkdirTemp("", "") + if err != nil { + t.Fatalf("Error creating temporary directory: %s", err) + } + t.Cleanup(func() { + _ = os.RemoveAll(tempDir) + }) + + certPath := filepath.Join(tempDir, "cert") + keyPath := filepath.Join(tempDir, "key") + + var emptyCA *CertificateAuthority + err = emptyCA.ToFile(certPath, keyPath) + if !errors.Is(err, ErrEmptyCertificateData) { + t.Fatalf("expected ErrEmptyCertificateData, got %v", err) + } + if _, statErr := os.Stat(certPath); !os.IsNotExist(statErr) { + t.Fatalf("expected no certificate file, got %v", statErr) + } + if _, statErr := os.Stat(keyPath); !os.IsNotExist(statErr) { + t.Fatalf("expected no key file, got %v", statErr) + } + }) + + t.Run("Reject Invalid File Data", func(t *testing.T) { + validCert := ca.PublicKey() + validKey := ca.PrivateKey() + for _, tc := range []struct { + name string + certData []byte + keyData []byte + wantErr error + }{ + { + name: "invalid cert", + certData: []byte("not pem"), + keyData: validKey, + wantErr: ErrInvalidCertificateData, + }, + { + name: "empty key", + certData: validCert, + keyData: nil, + wantErr: ErrEmptyKeyData, + }, + { + name: "invalid key", + certData: validCert, + keyData: []byte("not pem"), + wantErr: ErrInvalidKeyData, + }, + } { + t.Run(tc.name, func(t *testing.T) { + tempDir, err := os.MkdirTemp("", "") + if err != nil { + t.Fatalf("Error creating temporary directory: %s", err) + } + t.Cleanup(func() { + _ = os.RemoveAll(tempDir) + }) + + err = writePairToFiles( + tc.certData, + filepath.Join(tempDir, "cert"), + tc.keyData, + filepath.Join(tempDir, "key"), + ) + if !errors.Is(err, tc.wantErr) { + t.Fatalf("expected %v, got %v", tc.wantErr, err) + } + }) + } + }) + t.Run("Write to Invalid TempFile", func(t *testing.T) { _, _, err := ca.ToTempFile("/notValidPath/") if err == nil { @@ -220,6 +296,27 @@ func TestCertsUsage(t *testing.T) { }) }) + t.Run("Remove Cert When Key Write Fails", func(t *testing.T) { + tempDir, err := os.MkdirTemp("", "") + if err != nil { + t.Fatalf("Error creating temporary directory: %s", err) + } + t.Cleanup(func() { + _ = os.RemoveAll(tempDir) + }) + + certPath := filepath.Join(tempDir, "cert") + keyPath := filepath.Join(tempDir, "doesntexist", "key") + + err = kp.ToFile(certPath, keyPath) + if err == nil { + t.Fatalf("expected key write error, got nil") + } + if _, statErr := os.Stat(certPath); !os.IsNotExist(statErr) { + t.Fatalf("expected certificate file cleanup, got %v", statErr) + } + }) + t.Run("Write to Invalid TempFile", func(t *testing.T) { _, _, err := kp.ToTempFile("/notValidPath/") if err == nil { @@ -607,8 +704,16 @@ func ExampleNewCA() { serverTLSConfig.ClientAuth = tls.RequireAndVerifyClientCert // Create an HTTP Server + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + fmt.Printf("Error creating listener - %s", err) + return + } + defer func() { + _ = listener.Close() + }() + server := &http.Server{ - Addr: "127.0.0.1:8443", Handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { if _, writeErr := w.Write([]byte("Hello, World!")); writeErr != nil { fmt.Printf("Error writing response - %s", writeErr) @@ -620,13 +725,12 @@ func ExampleNewCA() { _ = server.Close() }() + serverErrCh := make(chan error, 1) go func() { // Start HTTP Listener - err = server.ListenAndServeTLS(cert.Name(), key.Name()) - if err != nil && err != http.ErrServerClosed { - fmt.Printf("Listener returned error - %s", err) - } + serverErrCh <- server.ServeTLS(listener, cert.Name(), key.Name()) }() + // Client TLS Config clientTLSConfig, err := certs.ConfigureTLSConfig(ca.GenerateTLSConfig()) if err != nil { @@ -635,14 +739,19 @@ func ExampleNewCA() { } // Setup HTTP Client with Cert Pool + transport := &http.Transport{ + TLSClientConfig: clientTLSConfig, + } + transport.DialContext = func(ctx context.Context, network, _ string) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, network, listener.Addr().String()) + } client := &http.Client{ - Transport: &http.Transport{ - TLSClientConfig: clientTLSConfig, - }, + Transport: transport, } // Make an HTTPS request - rsp, err := client.Get("https://localhost:8443") + rsp, err := client.Get("https://localhost") if err != nil { fmt.Printf("Client returned error - %s", err) return @@ -653,6 +762,16 @@ func ExampleNewCA() { // Print the response fmt.Println(rsp.Status) + _, _ = io.Copy(io.Discard, rsp.Body) + + if closeErr := server.Close(); closeErr != nil { + fmt.Printf("Error closing server - %s", closeErr) + return + } + if serveErr := <-serverErrCh; serveErr != nil && serveErr != http.ErrServerClosed { + fmt.Printf("Listener returned error - %s", serveErr) + return + } // Output: // 200 OK