diff --git a/src/rust/src/backend/dsa.rs b/src/rust/src/backend/dsa.rs index c398c7faad6b..2717deeb6bbd 100644 --- a/src/rust/src/backend/dsa.rs +++ b/src/rust/src/backend/dsa.rs @@ -2,6 +2,8 @@ // 2.0, and the BSD License. See the LICENSE file in the root of this repository // for complete details. +use std::cmp::Ordering; + use pyo3::types::PyAnyMethods; use crate::backend::utils; @@ -38,18 +40,22 @@ struct DsaParameters { pub(crate) fn private_key_from_pkey( pkey: &openssl::pkey::PKeyRef, -) -> DsaPrivateKey { - DsaPrivateKey { +) -> CryptographyResult { + let dsa = pkey.dsa()?; + check_dsa_private_key(dsa.p(), dsa.q(), dsa.g(), dsa.pub_key(), dsa.priv_key())?; + Ok(DsaPrivateKey { pkey: pkey.to_owned(), - } + }) } pub(crate) fn public_key_from_pkey( pkey: &openssl::pkey::PKeyRef, -) -> DsaPublicKey { - DsaPublicKey { +) -> CryptographyResult { + let dsa = pkey.dsa()?; + check_dsa_public_key(dsa.p(), dsa.q(), dsa.g(), dsa.pub_key())?; + Ok(DsaPublicKey { pkey: pkey.to_owned(), - } + }) } #[pyo3::pyfunction] @@ -270,67 +276,128 @@ fn check_dsa_parameters( py: pyo3::Python<'_>, parameters: &DsaParameterNumbers, ) -> CryptographyResult<()> { - if ![1024, 2048, 3072, 4096].contains( - ¶meters - .p - .bind(py) - .call_method0("bit_length")? - .extract::()?, - ) { - return Err(CryptographyError::from( - pyo3::exceptions::PyValueError::new_err( - "p must be exactly 1024, 2048, 3072, or 4096 bits long", - ), + let p = py_int_to_signed_bn(py, parameters.p.bind(py))?; + let q = py_int_to_signed_bn(py, parameters.q.bind(py))?; + let g = py_int_to_signed_bn(py, parameters.g.bind(py))?; + + check_dsa_parameters_bignums(&p, &q, &g) +} + +fn dsa_value_error(message: &'static str) -> CryptographyError { + CryptographyError::from(pyo3::exceptions::PyValueError::new_err(message)) +} + +fn py_int_to_signed_bn( + py: pyo3::Python<'_>, + value: &pyo3::Bound<'_, pyo3::PyAny>, +) -> CryptographyResult { + let negative = value.lt(0)?; + let magnitude = value.call_method0(pyo3::intern!(py, "__abs__"))?; + let mut bn = utils::py_int_to_bn(py, &magnitude)?; + bn.set_negative(negative); + Ok(bn) +} + +fn check_dsa_parameters_bignums( + p: &openssl::bn::BigNumRef, + q: &openssl::bn::BigNumRef, + g: &openssl::bn::BigNumRef, +) -> CryptographyResult<()> { + if p.is_negative() || ![1024, 2048, 3072, 4096].contains(&(p.num_bits() as usize)) { + return Err(dsa_value_error( + "p must be exactly 1024, 2048, 3072, or 4096 bits long", )); } - if ![160, 224, 256].contains( - ¶meters - .q - .bind(py) - .call_method0("bit_length")? - .extract::()?, - ) { - return Err(CryptographyError::from( - pyo3::exceptions::PyValueError::new_err("q must be exactly 160, 224, or 256 bits long"), + if q.is_negative() || ![160, 224, 256].contains(&(q.num_bits() as usize)) { + return Err(dsa_value_error( + "q must be exactly 160, 224, or 256 bits long", )); } - if parameters.g.bind(py).le(1)? || parameters.g.bind(py).ge(parameters.p.bind(py))? { - return Err(CryptographyError::from( - pyo3::exceptions::PyValueError::new_err("g, p don't satisfy 1 < g < p."), - )); + let one = openssl::bn::BigNum::from_u32(1)?; + if g.cmp(&one) != Ordering::Greater || g.cmp(p) != Ordering::Less { + return Err(dsa_value_error("g, p don't satisfy 1 < g < p.")); } Ok(()) } -fn check_dsa_private_numbers( - py: pyo3::Python<'_>, - numbers: &DsaPrivateNumbers, +fn check_dsa_public_key( + p: &openssl::bn::BigNumRef, + q: &openssl::bn::BigNumRef, + g: &openssl::bn::BigNumRef, + y: &openssl::bn::BigNumRef, ) -> CryptographyResult<()> { - let params = numbers.public_numbers.get().parameter_numbers.get(); - check_dsa_parameters(py, params)?; + check_dsa_parameters_bignums(p, q, g)?; - if numbers.x.bind(py).le(0)? || numbers.x.bind(py).ge(params.q.bind(py))? { - return Err(CryptographyError::from( - pyo3::exceptions::PyValueError::new_err("x must be > 0 and < q."), - )); + let one = openssl::bn::BigNum::from_u32(1)?; + if y.cmp(&one) != Ordering::Greater || y.cmp(p) != Ordering::Less { + return Err(dsa_value_error("y must be > 1 and < p.")); } - if (**numbers.public_numbers.get().y.bind(py)).ne(params - .g - .bind(py) - .pow(numbers.x.bind(py), Some(params.p.bind(py)))?)? - { - return Err(CryptographyError::from( - pyo3::exceptions::PyValueError::new_err("y must be equal to (g ** x % p)."), - )); + let mut ctx = openssl::bn::BigNumContext::new()?; + let mut result = openssl::bn::BigNum::new()?; + result.mod_exp(y, q, p, &mut ctx)?; + if result.cmp(&one) != Ordering::Equal { + return Err(dsa_value_error("y ** q mod p must be 1.")); + } + + Ok(()) +} + +fn check_dsa_private_key( + p: &openssl::bn::BigNumRef, + q: &openssl::bn::BigNumRef, + g: &openssl::bn::BigNumRef, + y: &openssl::bn::BigNumRef, + x: &openssl::bn::BigNumRef, +) -> CryptographyResult<()> { + check_dsa_public_key(p, q, g, y)?; + + let zero = openssl::bn::BigNum::from_u32(0)?; + if x.cmp(&zero) != Ordering::Greater || x.cmp(q) != Ordering::Less { + return Err(dsa_value_error("x must be > 0 and < q.")); + } + + let mut ctx = openssl::bn::BigNumContext::new()?; + let mut expected_y = openssl::bn::BigNum::new()?; + expected_y.mod_exp(g, x, p, &mut ctx)?; + if y.cmp(&expected_y) != Ordering::Equal { + return Err(dsa_value_error("y must be equal to (g ** x % p).")); } Ok(()) } +fn check_dsa_public_numbers( + py: pyo3::Python<'_>, + numbers: &DsaPublicNumbers, +) -> CryptographyResult<()> { + let params = numbers.parameter_numbers.get(); + let p = py_int_to_signed_bn(py, params.p.bind(py))?; + let q = py_int_to_signed_bn(py, params.q.bind(py))?; + let g = py_int_to_signed_bn(py, params.g.bind(py))?; + let y = py_int_to_signed_bn(py, numbers.y.bind(py))?; + + check_dsa_public_key(&p, &q, &g, &y) +} + +fn check_dsa_private_numbers( + py: pyo3::Python<'_>, + numbers: &DsaPrivateNumbers, +) -> CryptographyResult<()> { + let public_numbers = numbers.public_numbers.get(); + let params = public_numbers.parameter_numbers.get(); + let p = py_int_to_signed_bn(py, params.p.bind(py))?; + let q = py_int_to_signed_bn(py, params.q.bind(py))?; + let g = py_int_to_signed_bn(py, params.g.bind(py))?; + let y = py_int_to_signed_bn(py, public_numbers.y.bind(py))?; + let x = py_int_to_signed_bn(py, numbers.x.bind(py))?; + + check_dsa_private_key(&p, &q, &g, &y, &x) +} + #[pyo3::pyclass( frozen, module = "cryptography.hazmat.primitives.asymmetric.dsa", @@ -440,7 +507,7 @@ impl DsaPublicNumbers { let parameter_numbers = self.parameter_numbers.get(); - check_dsa_parameters(py, parameter_numbers)?; + check_dsa_public_numbers(py, self)?; let dsa = openssl::dsa::Dsa::from_public_components( utils::py_int_to_bn(py, parameter_numbers.p.bind(py))?, diff --git a/src/rust/src/backend/keys.rs b/src/rust/src/backend/keys.rs index 5acebea690b1..8751f6a8fb6d 100644 --- a/src/rust/src/backend/keys.rs +++ b/src/rust/src/backend/keys.rs @@ -165,7 +165,7 @@ fn private_key_from_pkey<'p>( openssl::pkey::Id::ED448 => Ok(crate::backend::ed448::private_key_from_pkey(pkey) .into_pyobject(py)? .into_any()), - openssl::pkey::Id::DSA => Ok(crate::backend::dsa::private_key_from_pkey(pkey) + openssl::pkey::Id::DSA => Ok(crate::backend::dsa::private_key_from_pkey(pkey)? .into_pyobject(py)? .into_any()), openssl::pkey::Id::DH => Ok(crate::backend::dh::private_key_from_pkey(pkey) @@ -363,7 +363,7 @@ fn public_key_from_pkey<'p>( .into_pyobject(py)? .into_any()), - openssl::pkey::Id::DSA => Ok(crate::backend::dsa::public_key_from_pkey(pkey) + openssl::pkey::Id::DSA => Ok(crate::backend::dsa::public_key_from_pkey(pkey)? .into_pyobject(py)? .into_any()), openssl::pkey::Id::DH => Ok(crate::backend::dh::public_key_from_pkey(pkey) diff --git a/tests/hazmat/primitives/test_dsa.py b/tests/hazmat/primitives/test_dsa.py index 94e25eef8cd4..d910f0105ef4 100644 --- a/tests/hazmat/primitives/test_dsa.py +++ b/tests/hazmat/primitives/test_dsa.py @@ -270,6 +270,24 @@ def test_invalid_parameters_values(self, p, q, g, backend): DSA_KEY_1024.public_numbers.y, 2**200, ), + ( + DSA_KEY_1024.public_numbers.parameter_numbers.p, + DSA_KEY_1024.public_numbers.parameter_numbers.q, + DSA_KEY_1024.public_numbers.parameter_numbers.g, + -1, + DSA_KEY_1024.x, + ), + ( + DSA_KEY_1024.public_numbers.parameter_numbers.p, + DSA_KEY_1024.public_numbers.parameter_numbers.q, + DSA_KEY_1024.public_numbers.parameter_numbers.g, + pow( + DSA_KEY_1024.public_numbers.parameter_numbers.g, + DSA_KEY_1024.x + 1, + DSA_KEY_1024.public_numbers.parameter_numbers.p, + ), + DSA_KEY_1024.x, + ), ( DSA_KEY_1024.public_numbers.parameter_numbers.p, DSA_KEY_1024.public_numbers.parameter_numbers.q, @@ -352,6 +370,30 @@ def test_invalid_dsa_private_key_arguments(self, p, q, g, y, x, backend): 2**1200, DSA_KEY_1024.public_numbers.y, ), + ( + DSA_KEY_1024.public_numbers.parameter_numbers.p, + DSA_KEY_1024.public_numbers.parameter_numbers.q, + DSA_KEY_1024.public_numbers.parameter_numbers.g, + -1, + ), + ( + DSA_KEY_1024.public_numbers.parameter_numbers.p, + DSA_KEY_1024.public_numbers.parameter_numbers.q, + DSA_KEY_1024.public_numbers.parameter_numbers.g, + 1, + ), + ( + DSA_KEY_1024.public_numbers.parameter_numbers.p, + DSA_KEY_1024.public_numbers.parameter_numbers.q, + DSA_KEY_1024.public_numbers.parameter_numbers.g, + DSA_KEY_1024.public_numbers.parameter_numbers.p + 1, + ), + ( + DSA_KEY_1024.public_numbers.parameter_numbers.p, + DSA_KEY_1024.public_numbers.parameter_numbers.q, + DSA_KEY_1024.public_numbers.parameter_numbers.g, + DSA_KEY_1024.public_numbers.parameter_numbers.p - 1, + ), ], ) def test_invalid_dsa_public_key_arguments(self, p, q, g, y, backend): @@ -478,12 +520,17 @@ def test_dsa_verification(self, backend, subtests): backend, algorithm, vector["p"], vector["q"], vector["g"] ) - public_key = dsa.DSAPublicNumbers( - parameter_numbers=dsa.DSAParameterNumbers( - vector["p"], vector["q"], vector["g"] - ), - y=vector["y"], - ).public_key(backend) + try: + public_key = dsa.DSAPublicNumbers( + parameter_numbers=dsa.DSAParameterNumbers( + vector["p"], vector["q"], vector["g"] + ), + y=vector["y"], + ).public_key(backend) + except ValueError: + assert vector["result"] == "F" + continue + sig = encode_dss_signature(vector["r"], vector["s"]) if vector["result"] == "F": @@ -610,16 +657,13 @@ def test_prehashed_digest_mismatch(self, backend): skip_message="Requires OpenSSL 3.0.9+, LibreSSL, BoringSSL, or AWS-LC", ) def test_nilpotent(self): - key = load_vectors_from_file( - os.path.join("asymmetric", "DSA", "custom", "nilpotent.pem"), - lambda pemfile: serialization.load_pem_private_key( - pemfile.read().encode(), password=None - ), - ) - assert isinstance(key, dsa.DSAPrivateKey) - with pytest.raises(ValueError): - key.sign(b"anything", hashes.SHA256()) + load_vectors_from_file( + os.path.join("asymmetric", "DSA", "custom", "nilpotent.pem"), + lambda pemfile: serialization.load_pem_private_key( + pemfile.read().encode(), password=None + ), + ) class TestDSANumbers: