diff --git a/cmd/pki/create/create.go b/cmd/pki/create/create.go index 0af5455..a6b3556 100644 --- a/cmd/pki/create/create.go +++ b/cmd/pki/create/create.go @@ -32,8 +32,11 @@ package create import ( "context" + "crypto/x509" + "encoding/pem" "fmt" "io" + "os" aeCMD "github.com/aurae-runtime/ae/cmd" "github.com/aurae-runtime/ae/pkg/cli" @@ -48,6 +51,11 @@ type option struct { directory string domain string user string + caPath string + caKeyPath string + csrPath string + csr *pki.CertificateRequest + ca *pki.Certificate silent bool writer io.Writer } @@ -60,17 +68,97 @@ func (o *option) Complete(args []string) error { return fmt.Errorf("too many arguments for command 'create', expect %d, got %d", 1, len(args)) } + if o.caPath != "" { + b, err := os.ReadFile(o.caPath) + if err != nil { + return fmt.Errorf("failed to read ca certificate: %w", err) + } + + o.ca = &pki.Certificate{} + o.ca.Certificate = string(b) + + if o.caKeyPath != "" { + b, err := os.ReadFile(o.caKeyPath) + if err != nil { + return fmt.Errorf("failed to read ca private key: %w", err) + } + + o.ca.PrivateKey = string(b) + } else { + return fmt.Errorf("must provide --caKey and --csr when using --ca") + } + + if o.csrPath != "" { + b, err := os.ReadFile(o.csrPath) + if err != nil { + return fmt.Errorf("failed to read csr: %w", err) + } + + o.csr = &pki.CertificateRequest{} + o.csr.CSR = string(b) + } else { + return fmt.Errorf("must provide --caKey and --csr when using --ca") + } + } + + if o.csrPath != "" { + b, err := os.ReadFile(o.csrPath) + if err != nil { + return fmt.Errorf("failed to read csr: %w", err) + } + + o.csr = &pki.CertificateRequest{} + o.csr.CSR = string(b) + } + o.domain = args[0] return nil } func (o *option) Validate() error { + if o.caPath != "" { + caPem, _ := pem.Decode([]byte(o.ca.Certificate)) + _, err := x509.ParseCertificate(caPem.Bytes) + if err != nil { + return fmt.Errorf("could not parse ca file") + } + } + + if o.caKeyPath != "" { + caKeyPem, _ := pem.Decode([]byte(o.ca.PrivateKey)) + _, err := x509.ParsePKCS1PrivateKey(caKeyPem.Bytes) + if err != nil { + return fmt.Errorf("could not parse key file") + } + } + + if o.csrPath != "" { + csrPem, _ := pem.Decode([]byte(o.csr.CSR)) + _, err := x509.ParseCertificateRequest(csrPem.Bytes) + if err != nil { + return fmt.Errorf("could not parse csr file") + } + } + return nil } func (o *option) Execute(_ context.Context) error { if o.user != "" { + + if o.caPath != "" { + clientCrt, err := pki.CreateClientCertificate(o.directory, o.csr.CSR, o.ca, o.user) + if err != nil { + return fmt.Errorf("failed to create client certificate: %w", err) + } + if !o.silent { + o.outputFormat.ToPrinter().Print(o.writer, &clientCrt) + } + + return nil + } + clientCSR, err := pki.CreateClientCSR(o.directory, o.domain, o.user) if err != nil { return fmt.Errorf("failed to create client csr: %w", err) @@ -118,6 +206,9 @@ ae pki create --dir ./pki/ my.domain.com`, o.outputFormat.AddFlags(cmd) cmd.Flags().StringVarP(&o.directory, "dir", "d", o.directory, "Output directory to store CA files.") cmd.Flags().StringVarP(&o.user, "user", "u", o.user, "Creates client certificate for a given user.") + cmd.Flags().StringVar(&o.caPath, "ca", o.caPath, "Use the given CA certificate.") + cmd.Flags().StringVar(&o.caKeyPath, "caKey", o.caKeyPath, "The corresponding CA key.") + cmd.Flags().StringVar(&o.csrPath, "csr", o.csrPath, "CSR input file.") cmd.Flags().BoolVarP(&o.silent, "silent", "s", o.silent, "Silent mode, omits output") return cmd diff --git a/pkg/pki/pki.go b/pkg/pki/pki.go index a37ca0e..2daa55e 100644 --- a/pkg/pki/pki.go +++ b/pkg/pki/pki.go @@ -31,6 +31,7 @@ package pki import ( + "bytes" "crypto/rand" "crypto/rsa" "crypto/sha1" @@ -55,10 +56,10 @@ type CertificateRequest struct { User string `json:"user" yaml:"user"` } -func CreateAuraeRootCA(path string, domainName string) (*Certificate, error) { +func createCA(domainName string) ([]byte, *rsa.PrivateKey, error) { priv, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { - return &Certificate{}, fmt.Errorf("failed to generate private key: %w", err) + return nil, nil, fmt.Errorf("failed to generate private key: %w", err) } subj := pkix.Name{ @@ -75,7 +76,7 @@ func CreateAuraeRootCA(path string, domainName string) (*Certificate, error) { serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) if err != nil { - return &Certificate{}, fmt.Errorf("failed to generate serial number: %w", err) + return nil, nil, fmt.Errorf("failed to generate serial number: %w", err) } template := x509.Certificate{ @@ -107,22 +108,31 @@ func CreateAuraeRootCA(path string, domainName string) (*Certificate, error) { crtBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) if err != nil { - return &Certificate{}, fmt.Errorf("failed to create certificate: %w", err) + return nil, nil, fmt.Errorf("failed to create certificate: %w", err) } - crtPem := pem.EncodeToMemory(&pem.Block{ - Type: "CERTIFICATE", - Bytes: crtBytes, - }) + return crtBytes, priv, nil +} - keyPem := pem.EncodeToMemory(&pem.Block{ - Type: "RSA PRIVATE KEY", - Bytes: x509.MarshalPKCS1PrivateKey(priv), - }) +func CreateAuraeRootCA(path string, domainName string) (*Certificate, error) { + crtBytes, priv, err := createCA(domainName) + if err != nil { + return nil, err + } + + crtPem, err := getPemBuffer(crtBytes, "CERTIFICATE") + if err != nil { + return nil, err + } + + keyPem, err := getPemBuffer(x509.MarshalPKCS1PrivateKey(priv), "RSA PRIVATE KEY") + if err != nil { + return nil, err + } ca := &Certificate{ - Certificate: string(crtPem), - PrivateKey: string(keyPem), + Certificate: crtPem.String(), + PrivateKey: keyPem.String(), } if path != "" { @@ -187,6 +197,108 @@ func CreateClientCSR(path, domain, user string) (*CertificateRequest, error) { return csr, nil } +func CreateClientCertificate(path, csrStr string, ca *Certificate, user string) (*Certificate, error) { + csrPem, _ := pem.Decode([]byte(csrStr)) + if csrPem == nil || csrPem.Type != "CERTIFICATE REQUEST" { + return &Certificate{}, fmt.Errorf("failed to decode certificate request") + } + + csr, err := x509.ParseCertificateRequest(csrPem.Bytes) + if err != nil { + return &Certificate{}, fmt.Errorf("failed to parse certificate request: %w", err) + } + + caCrtPem, _ := pem.Decode([]byte(ca.Certificate)) + if caCrtPem == nil || caCrtPem.Type != "CERTIFICATE" { + return &Certificate{}, fmt.Errorf("failed to decode certificate") + } + + caCrt, err := x509.ParseCertificate(caCrtPem.Bytes) + if err != nil { + return &Certificate{}, fmt.Errorf("failed to parse certificate: %w", err) + } + + caPrivPem, _ := pem.Decode([]byte(ca.PrivateKey)) + if caPrivPem == nil || caPrivPem.Type != "RSA PRIVATE KEY" { + return &Certificate{}, fmt.Errorf("failed to decode private key") + } + + caPriv, err := x509.ParsePKCS1PrivateKey(caPrivPem.Bytes) + if err != nil { + return &Certificate{}, fmt.Errorf("failed to parse private key: %w", err) + } + + now := time.Now() + + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + return &Certificate{}, fmt.Errorf("failed to generate serial number: %w", err) + } + + template := x509.Certificate{ + Subject: csr.Subject, + NotBefore: now, + NotAfter: now.Add(24 * time.Hour * 9999), + SerialNumber: serialNumber, + IsCA: false, + BasicConstraintsValid: true, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageContentCommitment | x509.KeyUsageKeyEncipherment | x509.KeyUsageDataEncipherment, + } + + crtPrivKey, err := rsa.GenerateKey(rand.Reader, 4096) + if err != nil { + return &Certificate{}, fmt.Errorf("failed to generate private key: %w", err) + } + + // TODO: is this the correct Subject Key Identifier? + pubHash := sha1.Sum(crtPrivKey.PublicKey.N.Bytes()) + template.SubjectKeyId = pubHash[:] + + crtPrivPem := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(crtPrivKey), + }) + + clientCrtBytes, err := x509.CreateCertificate(rand.Reader, &template, caCrt, &crtPrivKey.PublicKey, caPriv) + if err != nil { + return &Certificate{}, fmt.Errorf("failed to create certificate: %w", err) + } + + clientCrtPem := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: clientCrtBytes, + }) + + clientCert := &Certificate{ + Certificate: string(clientCrtPem), + PrivateKey: string(crtPrivPem), + } + + // if path != "" { + // err = createCsrFiles(path, csr) + // if err != nil { + // return csr, err + // } + // } + + return clientCert, nil +} + +func getPemBuffer(b []byte, t string) (*bytes.Buffer, error) { + // var certBytes []byte + pemBuffer := bytes.NewBuffer([]byte{}) + err := pem.Encode(pemBuffer, &pem.Block{ + Type: t, + Bytes: b, + }) + if err != nil { + return &bytes.Buffer{}, fmt.Errorf("failed to write \"%s\" pem buffer of type: %w", t, err) + } + + return pemBuffer, nil +} + func createCAFiles(path string, ca *Certificate) error { path = filepath.Clean(path) err := os.MkdirAll(path, os.ModePerm)