From 3144e1856a627bf0658bf9f3acfbb003aade3746 Mon Sep 17 00:00:00 2001 From: Maksim Kazantsev Date: Wed, 20 May 2026 11:34:50 +0000 Subject: [PATCH] Pull request 2657: AGDNS-3720-add-tls-config-provider Squashed commit of the following: commit 1748706d70718ae68d64cb0b26d30be5c3635a8d Author: Maksim Kazantsev Date: Wed May 20 12:21:56 2026 +0300 all: imp docs; commit 90f314adeadd167765a0a86493877f042f4b9805 Author: Maksim Kazantsev Date: Tue May 19 20:02:09 2026 +0300 home: imp code; commit 76265a91fd138ee344acc644bc3a8cfbb0c458f9 Author: Maksim Kazantsev Date: Tue May 19 19:39:35 2026 +0300 all: add tls config provider; imp tests; --- internal/aghtest/interface.go | 26 +++++++ internal/aghtls/configprovider.go | 39 ++++++++++ internal/dnsforward/config.go | 21 +++-- internal/dnsforward/configvalidator.go | 9 ++- internal/home/config.go | 4 +- internal/home/controlupdate.go | 16 ++-- internal/home/dns.go | 69 +++++++++-------- internal/home/home.go | 14 ++-- internal/home/testdata/cert.pem | 14 ++++ internal/home/testdata/key.pem | 16 ++++ internal/home/tls.go | 89 +++++++++++---------- internal/home/tls_internal_test.go | 103 ++++++++++++------------- 12 files changed, 272 insertions(+), 148 deletions(-) create mode 100644 internal/aghtls/configprovider.go create mode 100644 internal/home/testdata/cert.pem create mode 100644 internal/home/testdata/key.pem diff --git a/internal/aghtest/interface.go b/internal/aghtest/interface.go index 4219321f25c..d2c40c2a01e 100644 --- a/internal/aghtest/interface.go +++ b/internal/aghtest/interface.go @@ -2,6 +2,8 @@ package aghtest import ( "context" + "crypto/tls" + "crypto/x509" "net/http" "net/netip" "time" @@ -9,6 +11,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/agh" "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/AdGuardHome/internal/aghos" + "github.com/AdguardTeam/AdGuardHome/internal/aghtls" nextagh "github.com/AdguardTeam/AdGuardHome/internal/next/agh" "github.com/AdguardTeam/AdGuardHome/internal/rdns" "github.com/AdguardTeam/AdGuardHome/internal/whois" @@ -198,3 +201,26 @@ var _ aghhttp.Registrar = (*Registrar)(nil) func (m *Registrar) Register(method, path string, h http.HandlerFunc) { m.OnRegister(method, path, h) } + +// TLSConfigProvider is a fake [aghtls.TLSConfigProvider] implementation for +// tests. +// TODO(m.kazantsev): Use in tests. +type TLSConfigProvider struct { + OnTLSConfig func() (conf *tls.Config) + OnRootCAs func() (cert *x509.CertPool) +} + +// type check +var _ aghtls.TLSConfigProvider = (*TLSConfigProvider)(nil) + +// TLSConfig implements the [aghtls.TLSConfigProvider] interface for +// *TLSConfigProvider. +func (t *TLSConfigProvider) TLSConfig() (conf *tls.Config) { + return t.OnTLSConfig() +} + +// RootCAs implements the [aghtls.TLSConfigProvider] interface for +// *TLSConfigProvider. +func (t *TLSConfigProvider) RootCAs() (pool *x509.CertPool) { + return t.OnRootCAs() +} diff --git a/internal/aghtls/configprovider.go b/internal/aghtls/configprovider.go new file mode 100644 index 00000000000..77e58399b80 --- /dev/null +++ b/internal/aghtls/configprovider.go @@ -0,0 +1,39 @@ +package aghtls + +import ( + "crypto/tls" + "crypto/x509" +) + +// TLSConfigProvider provides TLS configuration to consumers. Implementations +// must be safe for concurrent use. +// +// TODO(m.kazantsev): Merge with the Manager interface. +// TODO(m.kazantsev): Add at least one real implementation. +type TLSConfigProvider interface { + // TLSConfig returns a clone of the current TLS configuration. conf + // provides its certificates via GetConfigForClient method. + TLSConfig() (conf *tls.Config) + + // RootCAs returns the current root CA pool. + RootCAs() (root *x509.CertPool) +} + +// type check +var _ TLSConfigProvider = EmptyTLSConfigProvider{} + +// EmptyTLSConfigProvider is the implementation of the [TLSConfigProvider] +// interface that does nothing. +type EmptyTLSConfigProvider struct{} + +// TLSConfig implements the [TLSConfigProvider] interface for +// EmptyTLSConfigProvider. It always returns nil. +func (EmptyTLSConfigProvider) TLSConfig() (conf *tls.Config) { + return nil +} + +// RootCAs implements the [TLSConfigProvider] interface for +// EmptyTLSConfigProvider. It always returns nil. +func (EmptyTLSConfigProvider) RootCAs() (root *x509.CertPool) { + return nil +} diff --git a/internal/dnsforward/config.go b/internal/dnsforward/config.go index 46cc45d584a..83f5c55ce15 100644 --- a/internal/dnsforward/config.go +++ b/internal/dnsforward/config.go @@ -544,7 +544,12 @@ func (conf *ServerConfig) loadUpstreams( upstreams = stringutil.SplitTrimmed(string(data), "\n") - l.DebugContext(ctx, "got upstreams", "number", len(upstreams), "filename", conf.UpstreamDNSFileName) + l.DebugContext( + ctx, + "got upstreams", + "number", len(upstreams), + "filename", conf.UpstreamDNSFileName, + ) return stringutil.FilterOut(upstreams, aghnet.IsCommentOrEmpty), nil } @@ -652,7 +657,10 @@ func filterOutAddrs(upsConf *proxy.UpstreamConfig, set addrPortSet) (err error) // ourAddrsSet returns an addrPortSet that contains all the configured listening // addresses. l must not be nil. -func (conf *ServerConfig) ourAddrsSet(ctx context.Context, l *slog.Logger) (m addrPortSet, err error) { +func (conf *ServerConfig) ourAddrsSet( + ctx context.Context, + l *slog.Logger, +) (m addrPortSet, err error) { addrs, unspecPorts := conf.collectDNSAddrs() switch { case addrs.Len() == 0: @@ -781,8 +789,9 @@ func anyNameMatches(dnsNames []string, sni string) (ok bool) { return false } -// Called by 'tls' package when Client Hello is received -// If the server name (from SNI) supplied by client is incorrect - we terminate the ongoing TLS handshake. +// onGetCertificate is called by [tls] package when Client Hello is received. If +// the server name (from SNI) supplied by client is incorrect - we terminate the +// ongoing TLS handshake. func (s *Server) onGetCertificate(ch *tls.ClientHelloInfo) (*tls.Certificate, error) { if s.conf.TLSConf.StrictSNICheck && !anyNameMatches(s.dnsNames, ch.ServerName) { // TODO(s.chzhen): Pass context. @@ -798,8 +807,8 @@ func (s *Server) onGetCertificate(ch *tls.ClientHelloInfo) (*tls.Certificate, er return s.conf.TLSConf.Cert, nil } -// preparePlain prepares the plain-DNS configuration for the DNS proxy. -// preparePlain assumes that prepareTLS has already been called. +// preparePlain prepares the plain-DNS configuration for the DNS proxy. The +// method assumes that prepareTLS has already been called. func (s *Server) preparePlain(ctx context.Context, proxyConf *proxy.Config) (err error) { if s.conf.ServePlainDNS { proxyConf.UDPListenAddr = s.conf.UDPListenAddrs diff --git a/internal/dnsforward/configvalidator.go b/internal/dnsforward/configvalidator.go index 4a362384f86..24abfcae7c4 100644 --- a/internal/dnsforward/configvalidator.go +++ b/internal/dnsforward/configvalidator.go @@ -90,7 +90,12 @@ func newUpstreamConfigValidator( // collectErrResults parses err and returns parsing results containing the // original upstream configuration line and the corresponding error. err can be // nil. l must not be nil. -func collectErrResults(ctx context.Context, l *slog.Logger, lines []string, err error) (results []*parseResult) { +func collectErrResults( + ctx context.Context, + l *slog.Logger, + lines []string, + err error, +) (results []*parseResult) { if err == nil { return nil } @@ -132,7 +137,7 @@ func collectErrResults(ctx context.Context, l *slog.Logger, lines []string, err } // insertConfResults parses conf and inserts the upstream result into results. -// It can insert multiple results as well as none. +// It can insert multiple results as well as none. conf must not be nil. func insertConfResults(conf *proxy.UpstreamConfig, results map[string]*upstreamResult) { insertListResults(conf.Upstreams, results, false) diff --git a/internal/home/config.go b/internal/home/config.go index 7f633c5070a..ed4e6008fdc 100644 --- a/internal/home/config.go +++ b/internal/home/config.go @@ -884,8 +884,8 @@ func (c *configuration) write( } if tlsMgr != nil { - tlsConf := tlsMgr.config() - config.TLS = *tlsConf + extTLSConf := tlsMgr.extendedTLSConfig() + config.TLS = *extTLSConf } if globalContext.stats != nil { diff --git a/internal/home/controlupdate.go b/internal/home/controlupdate.go index 96d844007b5..cb14ab62c93 100644 --- a/internal/home/controlupdate.go +++ b/internal/home/controlupdate.go @@ -178,6 +178,10 @@ type versionResponse struct { Disabled bool `json:"disabled"` } +// maxPrivilegedPort is the maximum port number. This only applies to Unix, as +// on Windows, [aghnet.CanBindPrivilegedPorts] always returns `true`, `nil`. +const maxPrivilegedPort = 1024 + // setAllowedToAutoUpdate sets CanAutoUpdate to true if AdGuard Home is actually // allowed to perform an automatic update by the OS. l and tlsMgr must not be // nil. @@ -191,9 +195,9 @@ func (vr *versionResponse) setAllowedToAutoUpdate( } canUpdate := true - if tlsConfUsesPrivilegedPorts(tlsMgr.config()) || - config.HTTPConfig.Address.Port() < 1024 || - config.DNS.Port < 1024 { + if tlsConfUsesPrivilegedPorts(tlsMgr.extendedTLSConfig()) || + config.HTTPConfig.Address.Port() < maxPrivilegedPort || + config.DNS.Port < maxPrivilegedPort { canUpdate, err = aghnet.CanBindPrivilegedPorts(ctx, l) if err != nil { return fmt.Errorf("checking ability to bind privileged ports: %w", err) @@ -206,9 +210,11 @@ func (vr *versionResponse) setAllowedToAutoUpdate( } // tlsConfUsesPrivilegedPorts returns true if the provided TLS configuration -// indicates that privileged ports are used. +// indicates that privileged ports are used. c must be valid func tlsConfUsesPrivilegedPorts(c *tlsConfigSettings) (ok bool) { - return c.Enabled && (c.PortHTTPS < 1024 || c.PortDNSOverTLS < 1024 || c.PortDNSOverQUIC < 1024) + return c.Enabled && (c.PortHTTPS < maxPrivilegedPort || + c.PortDNSOverTLS < maxPrivilegedPort || + c.PortDNSOverQUIC < maxPrivilegedPort) } // finishUpdate completes an update procedure. It is intended to be used as a diff --git a/internal/home/dns.go b/internal/home/dns.go index dc776088d78..fbe0d388901 100644 --- a/internal/home/dns.go +++ b/internal/home/dns.go @@ -170,7 +170,7 @@ func initDNSServer( dnsConf, err := newServerConfig( &config.DNS, config.Clients.Sources, - tlsMgr.config(), + tlsMgr.extendedTLSConfig(), config.HTTPConfig.DoH, tlsMgr, httpReg, @@ -212,7 +212,8 @@ func parseSubnetSet(nets []netutil.Prefix) (s netutil.SubnetSet) { } } -func isRunning() bool { +// isRunning checks whether the DNS server is running. +func isRunning() (ok bool) { return globalContext.dnsServer != nil && globalContext.dnsServer.IsRunning() } @@ -262,7 +263,7 @@ func ipsToUDPAddrs(ips []netip.Addr, port uint16) (udpAddrs []*net.UDPAddr) { func newServerConfig( dnsConf *dnsConfig, clientSrcConf *clientSourcesConfig, - tlsConf *tlsConfigSettings, + extTLSConf *tlsConfigSettings, dohConf *doHConfig, tlsMgr *tlsManager, httpReg aghhttp.Registrar, @@ -274,7 +275,7 @@ func newServerConfig( fwdConf := dnsConf.Config fwdConf.ClientsContainer = clientsContainer - intTLSConf, err := newDNSTLSConfig(tlsConf, hosts, dohConf.InsecureEnabled) + intTLSConf, err := newDNSTLSConfig(extTLSConf, hosts, dohConf.InsecureEnabled) if err != nil { return nil, fmt.Errorf("constructing tls config: %w", err) } @@ -322,19 +323,19 @@ func newServerConfig( } // newDNSTLSConfig converts values from the configuration file into the internal -// TLS settings for the DNS server. conf must not be nil. +// TLS settings for the DNS server. extTLSConf must not be nil. func newDNSTLSConfig( - conf *tlsConfigSettings, + extTLSConf *tlsConfigSettings, addrs []netip.Addr, allowUnencryptedDoH bool, ) (dnsConf *dnsforward.TLSConfig, err error) { - if !conf.Enabled { + if !extTLSConf.Enabled { return &dnsforward.TLSConfig{}, nil } // TODO(e.burkov): Add tracking for DNSCrypt configuration file changes to // the [aghtls.Manager]. - dnsCryptConf, err := newDNSCryptConfig(conf, addrs) + dnsCryptConf, err := newDNSCryptConfig(extTLSConf, addrs) if err != nil { // Don't wrap the error, because it's informative enough as is. return nil, err @@ -342,23 +343,23 @@ func newDNSTLSConfig( dnsConf = &dnsforward.TLSConfig{ DNSCryptConf: dnsCryptConf, - ServerName: conf.ServerName, - StrictSNICheck: conf.StrictSNICheck, + ServerName: extTLSConf.ServerName, + StrictSNICheck: extTLSConf.StrictSNICheck, } - if conf.PortHTTPS != 0 { - dnsConf.HTTPSListenAddrs = ipsToAddrPorts(addrs, conf.PortHTTPS) + if extTLSConf.PortHTTPS != 0 { + dnsConf.HTTPSListenAddrs = ipsToAddrPorts(addrs, extTLSConf.PortHTTPS) } - if conf.PortDNSOverTLS != 0 { - dnsConf.TLSListenAddrs = ipsToTCPAddrs(addrs, conf.PortDNSOverTLS) + if extTLSConf.PortDNSOverTLS != 0 { + dnsConf.TLSListenAddrs = ipsToTCPAddrs(addrs, extTLSConf.PortDNSOverTLS) } - if conf.PortDNSOverQUIC != 0 { - dnsConf.QUICListenAddrs = ipsToUDPAddrs(addrs, conf.PortDNSOverQUIC) + if extTLSConf.PortDNSOverQUIC != 0 { + dnsConf.QUICListenAddrs = ipsToUDPAddrs(addrs, extTLSConf.PortDNSOverQUIC) } - cert, err := tls.X509KeyPair(conf.CertificateChainData, conf.PrivateKeyData) + cert, err := tls.X509KeyPair(extTLSConf.CertificateChainData, extTLSConf.PrivateKeyData) if err != nil { err = fmt.Errorf("parsing tls key pair: %w", err) if allowUnencryptedDoH || dnsCryptConf != nil { @@ -378,20 +379,20 @@ func newDNSTLSConfig( } // newDNSCryptConfig converts values from the configuration file into the -// internal DNSCrypt settings for the DNS server. conf must not be nil. +// internal DNSCrypt settings for the DNS server. extTLSConf must not be nil. func newDNSCryptConfig( - conf *tlsConfigSettings, + extTLSConf *tlsConfigSettings, addrs []netip.Addr, ) (dnsCryptConf *dnsforward.DNSCryptConfig, err error) { - if conf.PortDNSCrypt == 0 { + if extTLSConf.PortDNSCrypt == 0 { return nil, nil } - if conf.DNSCryptConfigFile == "" { + if extTLSConf.DNSCryptConfigFile == "" { return nil, fmt.Errorf("dnscrypt_config_file: %w", errors.ErrEmptyValue) } - f, err := os.Open(conf.DNSCryptConfigFile) + f, err := os.Open(extTLSConf.DNSCryptConfigFile) if err != nil { return nil, fmt.Errorf("opening dnscrypt config: %w", err) } @@ -410,8 +411,8 @@ func newDNSCryptConfig( return &dnsforward.DNSCryptConfig{ ResolverCert: cert, - UDPListenAddrs: ipsToUDPAddrs(addrs, conf.PortDNSCrypt), - TCPListenAddrs: ipsToTCPAddrs(addrs, conf.PortDNSCrypt), + UDPListenAddrs: ipsToUDPAddrs(addrs, extTLSConf.PortDNSCrypt), + TCPListenAddrs: ipsToTCPAddrs(addrs, extTLSConf.PortDNSCrypt), ProviderName: rc.ProviderName, }, nil } @@ -426,16 +427,16 @@ type dnsEncryption struct { // getDNSEncryption returns the TLS encryption addresses that AdGuard Home // listens on. tlsMgr must not be nil. func getDNSEncryption(tlsMgr *tlsManager) (de dnsEncryption) { - tlsConf := tlsMgr.config() + extTLSConf := tlsMgr.extendedTLSConfig() - if !tlsConf.Enabled || len(tlsConf.ServerName) == 0 { + if !extTLSConf.Enabled || extTLSConf.ServerName == "" { return dnsEncryption{} } - hostname := tlsConf.ServerName - if tlsConf.PortHTTPS != 0 { + hostname := extTLSConf.ServerName + if extTLSConf.PortHTTPS != 0 { addr := hostname - if p := tlsConf.PortHTTPS; p != defaultPortHTTPS { + if p := extTLSConf.PortHTTPS; p != defaultPortHTTPS { addr = netutil.JoinHostPort(addr, p) } @@ -446,14 +447,14 @@ func getDNSEncryption(tlsMgr *tlsManager) (de dnsEncryption) { }).String() } - if p := tlsConf.PortDNSOverTLS; p != 0 { + if p := extTLSConf.PortDNSOverTLS; p != 0 { de.tls = (&url.URL{ Scheme: "tls", Host: netutil.JoinHostPort(hostname, p), }).String() } - if p := tlsConf.PortDNSOverQUIC; p != 0 { + if p := extTLSConf.PortDNSOverQUIC; p != 0 { de.quic = (&url.URL{ Scheme: "quic", Host: netutil.JoinHostPort(hostname, p), @@ -463,7 +464,9 @@ func getDNSEncryption(tlsMgr *tlsManager) (de dnsEncryption) { return de } -func startDNSServer() error { +// startDNSServer starts the DNS server, clients container, filters, stats and +// the query log. +func startDNSServer() (err error) { config.RLock() defer config.RUnlock() @@ -475,7 +478,7 @@ func startDNSServer() error { // TODO(s.chzhen): Pass context. ctx := context.TODO() - err := globalContext.clients.Start(ctx) + err = globalContext.clients.Start(ctx) if err != nil { return fmt.Errorf("starting clients container: %w", err) } diff --git a/internal/home/home.go b/internal/home/home.go index e60d4c195f4..b99caf0b996 100644 --- a/internal/home/home.go +++ b/internal/home/home.go @@ -1269,22 +1269,24 @@ func printWebAddrs(ctx context.Context, l *slog.Logger, proto, addr string, port } // printHTTPAddresses prints the IP addresses which user can use to access the -// admin interface. proto is either schemeHTTP or schemeHTTPS. +// admin interface. proto is either [urlutil.SchemeHTTPS] or +// [urlutil.SchemeHTTP]. l must not be nil. If proto is [urlutil.SchemeHTTPS], +// then tlsMgr must not be nil. // // TODO(s.chzhen): Implement separate functions for HTTP and HTTPS. func printHTTPAddresses(ctx context.Context, l *slog.Logger, proto string, tlsMgr *tlsManager) { - var tlsConf *tlsConfigSettings + var extTLSConf *tlsConfigSettings if tlsMgr != nil { - tlsConf = tlsMgr.config() + extTLSConf = tlsMgr.extendedTLSConfig() } port := config.HTTPConfig.Address.Port() if proto == urlutil.SchemeHTTPS { - port = tlsConf.PortHTTPS + port = extTLSConf.PortHTTPS } - if proto == urlutil.SchemeHTTPS && tlsConf.ServerName != "" { - printWebAddrs(ctx, l, proto, tlsConf.ServerName, tlsConf.PortHTTPS) + if proto == urlutil.SchemeHTTPS && extTLSConf.ServerName != "" { + printWebAddrs(ctx, l, proto, extTLSConf.ServerName, extTLSConf.PortHTTPS) return } diff --git a/internal/home/testdata/cert.pem b/internal/home/testdata/cert.pem new file mode 100644 index 00000000000..c1970967648 --- /dev/null +++ b/internal/home/testdata/cert.pem @@ -0,0 +1,14 @@ +-----BEGIN CERTIFICATE----- +MIICKzCCAZSgAwIBAgIJAMT9kPVJdM7LMA0GCSqGSIb3DQEBCwUAMC0xFDASBgNV +BAoMC0FkR3VhcmQgTHRkMRUwEwYDVQQDDAxBZEd1YXJkIEhvbWUwHhcNMTkwMjI3 +MDkyNDIzWhcNNDYwNzE0MDkyNDIzWjAtMRQwEgYDVQQKDAtBZEd1YXJkIEx0ZDEV +MBMGA1UEAwwMQWRHdWFyZCBIb21lMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKB +gQCwvwUnPJiOvLcOaWmGu6Y68ksFr13nrXBcsDlhxlXy8PaohVi3XxEmt2OrVjKW +QFw/bdV4fZ9tdWFAVRRkgeGbIZzP7YBD1Ore/O5SQ+DbCCEafvjJCcXQIrTeKFE6 +i9G3aSMHs0Pwq2LgV8U5mYotLrvyFiE8QPInJbDDMpaFYwIDAQABo1MwUTAdBgNV +HQ4EFgQUdLUmQpEqrhn4eKO029jYd2AAZEQwHwYDVR0jBBgwFoAUdLUmQpEqrhn4 +eKO029jYd2AAZEQwDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOBgQB8 +LwlXfbakf7qkVTlCNXgoY7RaJ8rJdPgOZPoCTVToEhT6u/cb1c2qp8QB0dNExDna +b0Z+dnODTZqQOJo6z/wIXlcUrnR4cQVvytXt8lFn+26l6Y6EMI26twC/xWr+1swq +Muj4FeWHVDerquH4yMr1jsYLD3ci+kc5sbIX6TfVxQ== +-----END CERTIFICATE----- diff --git a/internal/home/testdata/key.pem b/internal/home/testdata/key.pem new file mode 100644 index 00000000000..0094ccb4c07 --- /dev/null +++ b/internal/home/testdata/key.pem @@ -0,0 +1,16 @@ +-----BEGIN PRIVATE KEY----- +MIICeAIBADANBgkqhkiG9w0BAQEFAASCAmIwggJeAgEAAoGBALC/BSc8mI68tw5p +aYa7pjrySwWvXeetcFywOWHGVfLw9qiFWLdfESa3Y6tWMpZAXD9t1Xh9n211YUBV +FGSB4ZshnM/tgEPU6t787lJD4NsIIRp++MkJxdAitN4oUTqL0bdpIwezQ/CrYuBX +xTmZii0uu/IWITxA8iclsMMyloVjAgMBAAECgYEAmjzoG1h27UDkIlB9BVWl95TP +QVPLB81D267xNFDnWk1Lgr5zL/pnNjkdYjyjgpkBp1yKyE4gHV4skv5sAFWTcOCU +QCgfPfUn/rDFcxVzAdJVWAa/CpJNaZgjTPR8NTGU+Ztod+wfBESNCP5tbnuw0GbL +MuwdLQJGbzeJYpsNysECQQDfFHYoRNfgxHwMbX24GCoNZIgk12uDmGTA9CS5E+72 +9t3V1y4CfXxSkfhqNbd5RWrUBRLEw9BKofBS7L9NMDKDAkEAytQoIueE1vqEAaRg +a3A1YDUekKesU5wKfKfKlXvNgB7Hwh4HuvoQS9RCvVhf/60Dvq8KSu6hSjkFRquj +FQ5roQJBAMwKwyiCD5MfJPeZDmzcbVpiocRQ5Z4wPbffl9dRTDnIA5AciZDthlFg +An/jMjZSMCxNl6UyFcqt5Et1EGVhuFECQQCZLXxaT+qcyHjlHJTMzuMgkz1QFbEp +O5EX70gpeGQMPDK0QSWpaazg956njJSDbNCFM4BccrdQbJu1cW4qOsfBAkAMgZuG +O88slmgTRHX4JGFmy3rrLiHNI2BbJSuJ++Yllz8beVzh6NfvuY+HKRCmPqoBPATU +kXS9jgARhhiWXJrk +-----END PRIVATE KEY----- diff --git a/internal/home/tls.go b/internal/home/tls.go index ffabd18d907..ac5071fbd89 100644 --- a/internal/home/tls.go +++ b/internal/home/tls.go @@ -36,7 +36,7 @@ type tlsManager struct { // logger is used for logging the operation of the TLS Manager. logger *slog.Logger - // mu protects status, certLastMod, conf, and servePlainDNS. + // mu protects status, certLastMod, extTLSConf, and servePlainDNS. mu *sync.Mutex // status is the current status of the configuration. It is never nil. @@ -54,8 +54,12 @@ type tlsManager struct { // Resolve it. web *webAPI - // conf contains the TLS configuration settings. It must not be nil. - conf *tlsConfigSettings + // extTLSConf contains extended TLS configuration settings. It must not be + // nil. + // TODO(m.kazantsev): Add a field of a type of [*tls.Config] which will + // represent the TLS settings. This is why these settings are called + // 'extended'. + extTLSConf *tlsConfigSettings // confModifier is used to update the global configuration. confModifier agh.ConfigModifier @@ -111,7 +115,7 @@ func newTLSManager(ctx context.Context, conf *tlsManagerConfig) (m *tlsManager, httpReg: conf.httpReg, manager: conf.manager, status: &tlsConfigStatus{}, - conf: &conf.tlsSettings, + extTLSConf: &conf.tlsSettings, servePlainDNS: conf.servePlainDNS, } @@ -133,21 +137,21 @@ func newTLSManager(ctx context.Context, conf *tlsManagerConfig) (m *tlsManager, m.mu.Lock() defer m.mu.Unlock() - if !m.conf.Enabled { + if !m.extTLSConf.Enabled { return m, nil } err = m.manager.Set(ctx, aghtls.TLSPair{ - CertPath: m.conf.CertificatePath, - KeyPath: m.conf.PrivateKeyPath, + CertPath: m.extTLSConf.CertificatePath, + KeyPath: m.extTLSConf.PrivateKeyPath, }) if err != nil { m.logger.ErrorContext(ctx, "setting tls files", slogutil.KeyError, err) } - err = m.loadTLSConfig(ctx, m.conf, m.status) + err = m.loadTLSConfig(ctx, m.extTLSConf, m.status) if err != nil { - m.conf.Enabled = false + m.extTLSConf.Enabled = false return m, err } @@ -166,22 +170,22 @@ func (m *tlsManager) setWebAPI(webAPI *webAPI) { m.web = webAPI } -// config returns a deep copy of the stored TLS configuration. -func (m *tlsManager) config() (conf *tlsConfigSettings) { +// extendedTLSConfig returns a deep copy of the stored TLS configuration. +func (m *tlsManager) extendedTLSConfig() (extTLSConf *tlsConfigSettings) { m.mu.Lock() defer m.mu.Unlock() - return m.conf.clone() + return m.extTLSConf.clone() } // setCertFileTime sets [tlsManager.certLastMod] from the certificate. If there // are errors, setCertFileTime logs them. m.mu is expected to be locked. func (m *tlsManager) setCertFileTime(ctx context.Context) { - if len(m.conf.CertificatePath) == 0 { + if len(m.extTLSConf.CertificatePath) == 0 { return } - fi, err := os.Stat(m.conf.CertificatePath) + fi, err := os.Stat(m.extTLSConf.CertificatePath) if err != nil { m.logger.ErrorContext(ctx, "looking up certificate path", slogutil.KeyError, err) @@ -203,7 +207,7 @@ func (m *tlsManager) start(ctx context.Context) { // The background context is used because the TLSConfigChanged wraps context // with timeout on its own and shuts down the server, which handles current // request. - m.web.tlsConfigChanged(context.Background(), m.conf) + m.web.tlsConfigChanged(context.Background(), m.extTLSConf) go m.handleCertFileChange(ctx) } @@ -235,7 +239,7 @@ func (m *tlsManager) reload(ctx context.Context) { m.mu.Lock() defer m.mu.Unlock() - tlsConfPtr := m.conf + tlsConfPtr := m.extTLSConf if !tlsConfPtr.Enabled || len(tlsConfPtr.CertificatePath) == 0 { return @@ -267,7 +271,7 @@ func (m *tlsManager) reload(ctx context.Context) { return } - m.conf = &tlsConf + m.extTLSConf = &tlsConf m.status = status m.certLastMod = fi.ModTime().UTC() @@ -280,7 +284,7 @@ func (m *tlsManager) reload(ctx context.Context) { // The background context is used because the TLSConfigChanged wraps context // with timeout on its own and shuts down the server, which handles current // request. - m.web.tlsConfigChanged(context.Background(), m.conf) + m.web.tlsConfigChanged(context.Background(), m.extTLSConf) } // reconfigureDNSServer updates the DNS server configuration using the stored @@ -289,7 +293,7 @@ func (m *tlsManager) reconfigureDNSServer(ctx context.Context) (err error) { newConf, err := newServerConfig( &config.DNS, config.Clients.Sources, - m.conf, + m.extTLSConf, config.HTTPConfig.DoH, m, m.httpReg, @@ -314,7 +318,7 @@ func (m *tlsManager) reconfigureDNSServer(ctx context.Context) (err error) { // set in status.WarningValidation. func (m *tlsManager) loadTLSConfig( ctx context.Context, - tlsConf *tlsConfigSettings, + extTLSConf *tlsConfigSettings, status *tlsConfigStatus, ) (err error) { defer func() { @@ -327,13 +331,13 @@ func (m *tlsManager) loadTLSConfig( } }() - err = loadCertificateChainData(tlsConf) + err = loadCertificateChainData(extTLSConf) if err != nil { // Don't wrap the error, because it's informative enough as is. return err } - err = loadPrivateKeyData(tlsConf) + err = loadPrivateKeyData(extTLSConf) if err != nil { // Don't wrap the error, because it's informative enough as is. return err @@ -342,9 +346,9 @@ func (m *tlsManager) loadTLSConfig( err = m.validateCertificates( ctx, status, - tlsConf.CertificateChainData, - tlsConf.PrivateKeyData, - tlsConf.ServerName, + extTLSConf.CertificateChainData, + extTLSConf.PrivateKeyData, + extTLSConf.ServerName, ) return errors.Annotate(err, "validating certificate pair: %w") @@ -353,15 +357,15 @@ func (m *tlsManager) loadTLSConfig( // loadCertificateChainData loads PEM-encoded certificates chain data to the // TLS configuration. tlsConf must be not nil. tlsConf.CertificateChainData // struct field will be modified in case tlsConfig.CertificatePath is not an -// empty string. -func loadCertificateChainData(tlsConf *tlsConfigSettings) (err error) { - tlsConf.CertificateChainData = []byte(tlsConf.CertificateChain) - if tlsConf.CertificatePath != "" { - if tlsConf.CertificateChain != "" { +// empty string. extTLSConf must not be nil. +func loadCertificateChainData(extTLSConf *tlsConfigSettings) (err error) { + extTLSConf.CertificateChainData = []byte(extTLSConf.CertificateChain) + if extTLSConf.CertificatePath != "" { + if extTLSConf.CertificateChain != "" { return errors.Error("certificate data and file can't be set together") } - tlsConf.CertificateChainData, err = os.ReadFile(tlsConf.CertificatePath) + extTLSConf.CertificateChainData, err = os.ReadFile(extTLSConf.CertificatePath) if err != nil { return fmt.Errorf("reading cert file: %w", err) } @@ -373,14 +377,15 @@ func loadCertificateChainData(tlsConf *tlsConfigSettings) (err error) { // loadPrivateKeyData loads PEM-encoded private key data to the TLS // configuration. tlsConf must be not nil. tlsConf.PrivateKeyData struct field // will be modified in case tlsConfig.PrivateKeyPath is not an empty string. -func loadPrivateKeyData(tlsConf *tlsConfigSettings) (err error) { - tlsConf.PrivateKeyData = []byte(tlsConf.PrivateKey) - if tlsConf.PrivateKeyPath != "" { - if tlsConf.PrivateKey != "" { +// extTLSConf must not be nil. +func loadPrivateKeyData(extTLSConf *tlsConfigSettings) (err error) { + extTLSConf.PrivateKeyData = []byte(extTLSConf.PrivateKey) + if extTLSConf.PrivateKeyPath != "" { + if extTLSConf.PrivateKey != "" { return errors.Error("private key data and file can't be set together") } - tlsConf.PrivateKeyData, err = os.ReadFile(tlsConf.PrivateKeyPath) + extTLSConf.PrivateKeyData, err = os.ReadFile(extTLSConf.PrivateKeyPath) if err != nil { return fmt.Errorf("reading key file: %w", err) } @@ -460,7 +465,7 @@ func (m *tlsManager) handleTLSStatus(w http.ResponseWriter, r *http.Request) { m.mu.Lock() defer m.mu.Unlock() - tlsConf = m.conf.clone() + tlsConf = m.extTLSConf.clone() servePlainDNS = m.servePlainDNS }() @@ -494,7 +499,7 @@ func (m *tlsManager) handleTLSValidate(w http.ResponseWriter, r *http.Request) { defer m.mu.Unlock() if setts.PrivateKeySaved { - setts.PrivateKey = m.conf.PrivateKey + setts.PrivateKey = m.extTLSConf.PrivateKey } if err = m.validateTLSSettings(setts); err != nil { @@ -525,14 +530,14 @@ func (m *tlsManager) setConfig( status *tlsConfigStatus, servePlain aghalg.NullBool, ) (restartHTTPS bool) { - if !m.conf.setPrivateFieldsAndCompare(&newConf) { + if !m.extTLSConf.setPrivateFieldsAndCompare(&newConf) { m.logger.InfoContext(ctx, "config has changed, restarting https server") restartHTTPS = true } else { m.logger.InfoContext(ctx, "config has not changed") } - m.conf = &newConf + m.extTLSConf = &newConf m.status = status @@ -587,10 +592,10 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request) defer m.mu.Unlock() if req.PrivateKeySaved { - req.PrivateKey = m.conf.PrivateKey + req.PrivateKey = m.extTLSConf.PrivateKey } - req.StrictSNICheck = m.conf.StrictSNICheck + req.StrictSNICheck = m.extTLSConf.StrictSNICheck if err = m.validateTLSSettings(req); err != nil { aghhttp.ErrorAndLog(ctx, m.logger, r, w, http.StatusBadRequest, "%s", err) diff --git a/internal/home/tls_internal_test.go b/internal/home/tls_internal_test.go index 161d43f9410..0cb219fe4e7 100644 --- a/internal/home/tls_internal_test.go +++ b/internal/home/tls_internal_test.go @@ -25,44 +25,18 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghtls" "github.com/AdguardTeam/AdGuardHome/internal/client" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward" + "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/timeutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -// TODO(s.chzhen): Consider moving to testdata. -var testCertChainData = []byte(`-----BEGIN CERTIFICATE----- -MIICKzCCAZSgAwIBAgIJAMT9kPVJdM7LMA0GCSqGSIb3DQEBCwUAMC0xFDASBgNV -BAoMC0FkR3VhcmQgTHRkMRUwEwYDVQQDDAxBZEd1YXJkIEhvbWUwHhcNMTkwMjI3 -MDkyNDIzWhcNNDYwNzE0MDkyNDIzWjAtMRQwEgYDVQQKDAtBZEd1YXJkIEx0ZDEV -MBMGA1UEAwwMQWRHdWFyZCBIb21lMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKB -gQCwvwUnPJiOvLcOaWmGu6Y68ksFr13nrXBcsDlhxlXy8PaohVi3XxEmt2OrVjKW -QFw/bdV4fZ9tdWFAVRRkgeGbIZzP7YBD1Ore/O5SQ+DbCCEafvjJCcXQIrTeKFE6 -i9G3aSMHs0Pwq2LgV8U5mYotLrvyFiE8QPInJbDDMpaFYwIDAQABo1MwUTAdBgNV -HQ4EFgQUdLUmQpEqrhn4eKO029jYd2AAZEQwHwYDVR0jBBgwFoAUdLUmQpEqrhn4 -eKO029jYd2AAZEQwDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOBgQB8 -LwlXfbakf7qkVTlCNXgoY7RaJ8rJdPgOZPoCTVToEhT6u/cb1c2qp8QB0dNExDna -b0Z+dnODTZqQOJo6z/wIXlcUrnR4cQVvytXt8lFn+26l6Y6EMI26twC/xWr+1swq -Muj4FeWHVDerquH4yMr1jsYLD3ci+kc5sbIX6TfVxQ== ------END CERTIFICATE-----`) - -var testPrivateKeyData = []byte(`-----BEGIN PRIVATE KEY----- -MIICeAIBADANBgkqhkiG9w0BAQEFAASCAmIwggJeAgEAAoGBALC/BSc8mI68tw5p -aYa7pjrySwWvXeetcFywOWHGVfLw9qiFWLdfESa3Y6tWMpZAXD9t1Xh9n211YUBV -FGSB4ZshnM/tgEPU6t787lJD4NsIIRp++MkJxdAitN4oUTqL0bdpIwezQ/CrYuBX -xTmZii0uu/IWITxA8iclsMMyloVjAgMBAAECgYEAmjzoG1h27UDkIlB9BVWl95TP -QVPLB81D267xNFDnWk1Lgr5zL/pnNjkdYjyjgpkBp1yKyE4gHV4skv5sAFWTcOCU -QCgfPfUn/rDFcxVzAdJVWAa/CpJNaZgjTPR8NTGU+Ztod+wfBESNCP5tbnuw0GbL -MuwdLQJGbzeJYpsNysECQQDfFHYoRNfgxHwMbX24GCoNZIgk12uDmGTA9CS5E+72 -9t3V1y4CfXxSkfhqNbd5RWrUBRLEw9BKofBS7L9NMDKDAkEAytQoIueE1vqEAaRg -a3A1YDUekKesU5wKfKfKlXvNgB7Hwh4HuvoQS9RCvVhf/60Dvq8KSu6hSjkFRquj -FQ5roQJBAMwKwyiCD5MfJPeZDmzcbVpiocRQ5Z4wPbffl9dRTDnIA5AciZDthlFg -An/jMjZSMCxNl6UyFcqt5Et1EGVhuFECQQCZLXxaT+qcyHjlHJTMzuMgkz1QFbEp -O5EX70gpeGQMPDK0QSWpaazg956njJSDbNCFM4BccrdQbJu1cW4qOsfBAkAMgZuG -O88slmgTRHX4JGFmy3rrLiHNI2BbJSuJ++Yllz8beVzh6NfvuY+HKRCmPqoBPATU -kXS9jgARhhiWXJrk ------END PRIVATE KEY-----`) +// Paths to the test TLS-related data. +const ( + testCertificatePath = "./testdata/cert.pem" + testPrivateKeyPath = "./testdata/key.pem" +) func TestValidateCertificates(t *testing.T) { ctx := testutil.ContextWithTimeout(t, testTimeout) @@ -92,6 +66,10 @@ func TestValidateCertificates(t *testing.T) { t.Run("valid", func(t *testing.T) { status := &tlsConfigStatus{} + + testCertChainData := requireReadFile(t, testCertificatePath) + testPrivateKeyData := requireReadFile(t, testPrivateKeyPath) + err = m.validateCertificates(ctx, status, testCertChainData, testPrivateKeyData, "") assert.Error(t, err) @@ -213,7 +191,7 @@ func newCertAndKey(tb testing.TB, n int64) (certDER []byte, key *rsa.PrivateKey) } // writeCertAndKey is a helper function that writes certificate and key to -// specified paths. +// specified paths. key must not be nil. func writeCertAndKey( tb testing.TB, certDER []byte, @@ -310,8 +288,8 @@ func TestTLSManager_Reload(t *testing.T) { web := newTestWeb(t, &webConfig{}) m.setWebAPI(web) - conf := m.config() - assertCertSerialNumber(t, conf, snBefore) + extTLSConf := m.extendedTLSConfig() + assertCertSerialNumber(t, extTLSConf, snBefore) certDER, key = newCertAndKey(t, snAfter) writeCertAndKey(t, certDER, certPath, key, keyPath) @@ -324,8 +302,8 @@ func TestTLSManager_Reload(t *testing.T) { return globalContext.dnsServer.Stop(testutil.ContextWithTimeout(t, testTimeout)) }) - conf = m.config() - assertCertSerialNumber(t, conf, snAfter) + extTLSConf = m.extendedTLSConfig() + assertCertSerialNumber(t, extTLSConf, snAfter) } func TestTLSManager_HandleTLSStatus(t *testing.T) { @@ -334,13 +312,16 @@ func TestTLSManager_HandleTLSStatus(t *testing.T) { err error ) + testCertChain := requireReadFile(t, testCertificatePath) + testPrivateKeyData := requireReadFile(t, testPrivateKeyPath) + m, err := newTLSManager(ctx, &tlsManagerConfig{ logger: testLogger, confModifier: agh.EmptyConfigModifier{}, manager: aghtls.EmptyManager{}, tlsSettings: tlsConfigSettings{ Enabled: true, - CertificateChain: string(testCertChainData), + CertificateChain: string(testCertChain), PrivateKey: string(testPrivateKeyData), }, servePlainDNS: false, @@ -355,7 +336,7 @@ func TestTLSManager_HandleTLSStatus(t *testing.T) { err = json.NewDecoder(w.Body).Decode(res) require.NoError(t, err) - wantCertificateChain := base64.StdEncoding.EncodeToString(testCertChainData) + wantCertificateChain := base64.StdEncoding.EncodeToString(testCertChain) assert.True(t, res.Enabled) assert.Equal(t, wantCertificateChain, res.CertificateChain) assert.True(t, res.PrivateKeySaved) @@ -470,9 +451,9 @@ func TestTLSManager_HandleTLSValidate(t *testing.T) { confModifier: agh.EmptyConfigModifier{}, manager: aghtls.EmptyManager{}, tlsSettings: tlsConfigSettings{ - Enabled: true, - CertificateChain: string(testCertChainData), - PrivateKey: string(testPrivateKeyData), + Enabled: true, + CertificatePath: testCertificatePath, + PrivateKeyPath: testPrivateKeyPath, }, servePlainDNS: false, }) @@ -483,9 +464,9 @@ func TestTLSManager_HandleTLSValidate(t *testing.T) { setts := &tlsConfigSettingsExt{ tlsConfigSettings: tlsConfigSettings{ - Enabled: true, - CertificateChain: base64.StdEncoding.EncodeToString(testCertChainData), - PrivateKey: base64.StdEncoding.EncodeToString(testPrivateKeyData), + Enabled: true, + CertificatePath: testCertificatePath, + PrivateKeyPath: testPrivateKeyPath, }, } @@ -500,6 +481,9 @@ func TestTLSManager_HandleTLSValidate(t *testing.T) { err = json.NewDecoder(w.Body).Decode(res) require.NoError(t, err) + testCertChainData := requireReadFile(t, testCertificatePath) + testPrivateKeyData := requireReadFile(t, testPrivateKeyPath) + cert, err := tls.X509KeyPair(testCertChainData, testPrivateKeyData) require.NoError(t, err) @@ -541,7 +525,7 @@ func TestTLSManager_HandleTLSConfigure(t *testing.T) { }) require.NoError(t, err) - config.DNS.BindHosts = []netip.Addr{netip.MustParseAddr("127.0.0.1")} + config.DNS.BindHosts = []netip.Addr{netutil.IPv4Localhost()} config.DNS.Port = 0 const wantSerialNumber int64 = 1 @@ -571,16 +555,16 @@ func TestTLSManager_HandleTLSConfigure(t *testing.T) { web := newTestWeb(t, &webConfig{}) m.setWebAPI(web) - conf := m.config() - assertCertSerialNumber(t, conf, wantSerialNumber) + extTLSConf := m.extendedTLSConfig() + assertCertSerialNumber(t, extTLSConf, wantSerialNumber) // Prepare a request with the new TLS configuration. setts := &tlsConfigSettingsExt{ tlsConfigSettings: tlsConfigSettings{ - Enabled: true, - PortHTTPS: 4433, - CertificateChain: base64.StdEncoding.EncodeToString(testCertChainData), - PrivateKey: base64.StdEncoding.EncodeToString(testPrivateKeyData), + Enabled: true, + PortHTTPS: 4433, + CertificatePath: testCertificatePath, + PrivateKeyPath: testPrivateKeyPath, }, } @@ -606,6 +590,9 @@ func TestTLSManager_HandleTLSConfigure(t *testing.T) { err = json.NewDecoder(w.Body).Decode(res) require.NoError(t, err) + testCertChainData := requireReadFile(t, testCertificatePath) + testPrivateKeyData := requireReadFile(t, testPrivateKeyPath) + cert, err := tls.X509KeyPair(testCertChainData, testPrivateKeyData) require.NoError(t, err) @@ -629,3 +616,15 @@ func TestTLSManager_HandleTLSConfigure(t *testing.T) { return true }, testTimeout, testTimeout/10) } + +// requireReadFile reads the file at the specified path and returns its content. +// +// TODO(m.kazantsev): Move to golibs/testutil. +func requireReadFile(tb testing.TB, path string) (data []byte) { + tb.Helper() + + data, err := os.ReadFile(path) + require.NoError(tb, err) + + return data +}