diff --git a/config.go b/config.go index 4bf70e1..0f1cfb6 100644 --- a/config.go +++ b/config.go @@ -10,27 +10,29 @@ import ( // SignConfig contains additional configuration for the signer. type SignConfig struct { - signAlg bool - signCreated bool - fakeCreated int64 - expires int64 - nonce string - tag string - keyID *string - maxBodySize int64 + signAlg bool + signCreated bool + fakeCreated int64 + expires int64 + expiresAfter int64 + nonce string + tag string + keyID *string + maxBodySize int64 schemeFromRequest func(*http.Request) string } // NewSignConfig generates a default configuration. func NewSignConfig() *SignConfig { return &SignConfig{ - signAlg: true, - signCreated: true, - fakeCreated: 0, - expires: 0, - nonce: "", - tag: "", // we disallow an empty tag - keyID: nil, + signAlg: true, + signCreated: true, + fakeCreated: 0, + expires: 0, + expiresAfter: 0, + nonce: "", + tag: "", // we disallow an empty tag + keyID: nil, } } @@ -60,6 +62,14 @@ func (c *SignConfig) SetExpires(expires int64) *SignConfig { return c } +// SetExpiresAfter sets the "expires" parameter to createdTime + delay (seconds). +// Use this for a relative validity window instead of an absolute timestamp. +// Default: 0 (do not add the parameter). +func (c *SignConfig) SetExpiresAfter(delay int64) *SignConfig { + c.expiresAfter = delay + return c +} + // SetNonce adds a "nonce" string parameter whose content should be unique per signed message. // Default: empty string (do not add the parameter). func (c *SignConfig) SetNonce(nonce string) *SignConfig { @@ -99,17 +109,17 @@ func (c *SignConfig) SetSchemeFromRequest(f func(*http.Request) string) *SignCon // VerifyConfig contains additional configuration for the verifier. type VerifyConfig struct { - verifyCreated bool - notNewerThan time.Duration - notOlderThan time.Duration - allowedAlgs []string - rejectExpired bool - keyID *string - dateWithin time.Duration - allowedTags []string - maxBodySize int64 + verifyCreated bool + notNewerThan time.Duration + notOlderThan time.Duration + allowedAlgs []string + rejectExpired bool + keyID *string + dateWithin time.Duration + allowedTags []string + maxBodySize int64 schemeFromRequest func(*http.Request) string - nonceValidator func(string) error + nonceValidator func(string) error } // SetNonceValidator sets a callback to validate the nonce parameter during verification. diff --git a/handler_test.go b/handler_test.go index 5feac04..18d3ce5 100644 --- a/handler_test.go +++ b/handler_test.go @@ -59,13 +59,15 @@ func Test_WrapHandler(t *testing.T) { // test various failures func TestWrapHandlerServerSigns(t *testing.T) { - serverSignsTestCase := func(t *testing.T, nilSigner, dontSignResponse, earlyExpires, noSigner, badKey, badAlgs, verifyRequest bool) { + serverSignsTestCase := func(t *testing.T, nilSigner, dontSignResponse, earlyExpires, earlyExpiresAfter, noSigner, badKey, badAlgs, verifyRequest bool) { // Callback to let the server locate its signing key and configuration var signConfig *SignConfig - if !earlyExpires { - signConfig = NewSignConfig() - } else { + if earlyExpires { signConfig = NewSignConfig().SetExpires(2000) + } else if earlyExpiresAfter { + signConfig = NewSignConfig().SetExpiresAfter(1).setFakeCreated(1000) + } else { + signConfig = NewSignConfig() } fetchSigner := func(res http.Response, r *http.Request) (string, *Signer) { sigName := "sig1" @@ -128,29 +130,33 @@ func TestWrapHandlerServerSigns(t *testing.T) { } } nilSigner := func(t *testing.T) { - serverSignsTestCase(t, true, false, false, false, false, false, false) + serverSignsTestCase(t, true, false, false, false, false, false, false, false) } dontSignResponse := func(t *testing.T) { - serverSignsTestCase(t, false, true, false, false, false, false, false) + serverSignsTestCase(t, false, true, false, false, false, false, false, false) } earlyExpires := func(t *testing.T) { - serverSignsTestCase(t, false, false, true, false, false, false, false) + serverSignsTestCase(t, false, false, true, false, false, false, false, false) + } + earlyExpiresAfter := func(t *testing.T) { + serverSignsTestCase(t, false, false, false, true, false, false, false, false) } noSigner := func(t *testing.T) { - serverSignsTestCase(t, false, false, false, true, false, false, false) + serverSignsTestCase(t, false, false, false, false, true, false, false, false) } badKey := func(t *testing.T) { - serverSignsTestCase(t, false, false, false, false, true, false, false) + serverSignsTestCase(t, false, false, false, false, false, true, false, false) } badAlgs := func(t *testing.T) { - serverSignsTestCase(t, false, false, false, false, false, true, false) + serverSignsTestCase(t, false, false, false, false, false, false, true, false) } failVerify := func(t *testing.T) { - serverSignsTestCase(t, false, false, false, false, false, false, true) + serverSignsTestCase(t, false, false, false, false, false, false, false, true) } t.Run("nil Signer", nilSigner) t.Run("don't sign response", dontSignResponse) t.Run("early expires field", earlyExpires) + t.Run("early expires after field", earlyExpiresAfter) t.Run("bad fetch Signer", noSigner) t.Run("wrong verification key", badKey) t.Run("failed algorithm check", badAlgs) diff --git a/signatures.go b/signatures.go index f90b83d..b3fc3f0 100644 --- a/signatures.go +++ b/signatures.go @@ -279,6 +279,12 @@ func generateSigParams(config *SignConfig, alg string, foreignSigner interface{} if config.expires != 0 { p.Add("expires", config.expires) } + if config.expiresAfter != 0 { + if config.expires != 0 { + return "", fmt.Errorf("cannot use both expires and expiresAfter") + } + p.Add("expires", config.expiresAfter+createdTime) + } if config.nonce != "" { qNonce, err := quotedString(config.nonce) if err != nil { diff --git a/signatures_test.go b/signatures_test.go index 2576ed1..ec16d5f 100644 --- a/signatures_test.go +++ b/signatures_test.go @@ -981,6 +981,36 @@ func TestMessageSignAndVerifyResponseHMAC(t *testing.T) { } } +func TestExpiresAfterCalculation(t *testing.T) { + fields := Headers("@status", "date", "content-type") + signatureName := "sigres" + key, _ := base64.StdEncoding.DecodeString("uzvJfB4u3N0Jy4T7NZ75MDVcr8zSTInedJtkgcu46YW4XByzNJjxBdtjUkdJPBtbmHhIDi6pcl8jsasjlTMtDQ==") + config := NewSignConfig().SetExpiresAfter(60).setFakeCreated(1000).SetKeyID("test-shared-secret") + signer, _ := NewHMACSHA256Signer(key, config, fields) + res := readResponse(httpres2) + sigInput, _, err := SignResponse(signatureName, *signer, res, nil) + if err != nil { + t.Fatalf("SignResponse failed: %s", err) + } + // expires should be fakeCreated + expiresAfter = 1000 + 60 = 1060 + if !strings.Contains(sigInput, "expires=1060") { + t.Errorf("expected expires=1060 in signature input, got: %s", sigInput) + } +} + +func TestExpiresAndExpiresAfterConflict(t *testing.T) { + fields := Headers("@status", "date", "content-type") + signatureName := "sigres" + key, _ := base64.StdEncoding.DecodeString("uzvJfB4u3N0Jy4T7NZ75MDVcr8zSTInedJtkgcu46YW4XByzNJjxBdtjUkdJPBtbmHhIDi6pcl8jsasjlTMtDQ==") + config := NewSignConfig().SetExpires(2000).SetExpiresAfter(60).SetKeyID("test-shared-secret") + signer, _ := NewHMACSHA256Signer(key, config, fields) + res := readResponse(httpres2) + _, _, err := SignResponse(signatureName, *signer, res, nil) + if err == nil { + t.Errorf("expected error when both SetExpires and SetExpiresAfter are set") + } +} + func TestSignAndVerifyRSAPSS(t *testing.T) { config := NewSignConfig().SignAlg(false).setFakeCreated(1618884475).SetKeyID("test-key-rsa-pss") fields := Headers("@authority", "date", "content-type")