Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 115 additions & 48 deletions src/rust/src/backend/dsa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
// 2.0, and the BSD License. See the LICENSE file in the root of this repository
// for complete details.

use std::cmp::Ordering;

use pyo3::types::PyAnyMethods;

use crate::backend::utils;
Expand Down Expand Up @@ -38,18 +40,22 @@ struct DsaParameters {

pub(crate) fn private_key_from_pkey(
pkey: &openssl::pkey::PKeyRef<openssl::pkey::Private>,
) -> DsaPrivateKey {
DsaPrivateKey {
) -> CryptographyResult<DsaPrivateKey> {
let dsa = pkey.dsa()?;
check_dsa_private_key(dsa.p(), dsa.q(), dsa.g(), dsa.pub_key(), dsa.priv_key())?;
Ok(DsaPrivateKey {
pkey: pkey.to_owned(),
}
})
}

pub(crate) fn public_key_from_pkey(
pkey: &openssl::pkey::PKeyRef<openssl::pkey::Public>,
) -> DsaPublicKey {
DsaPublicKey {
) -> CryptographyResult<DsaPublicKey> {
let dsa = pkey.dsa()?;
check_dsa_public_key(dsa.p(), dsa.q(), dsa.g(), dsa.pub_key())?;
Ok(DsaPublicKey {
pkey: pkey.to_owned(),
}
})
}

#[pyo3::pyfunction]
Expand Down Expand Up @@ -270,67 +276,128 @@ fn check_dsa_parameters(
py: pyo3::Python<'_>,
parameters: &DsaParameterNumbers,
) -> CryptographyResult<()> {
if ![1024, 2048, 3072, 4096].contains(
&parameters
.p
.bind(py)
.call_method0("bit_length")?
.extract::<usize>()?,
) {
return Err(CryptographyError::from(
pyo3::exceptions::PyValueError::new_err(
"p must be exactly 1024, 2048, 3072, or 4096 bits long",
),
let p = py_int_to_signed_bn(py, parameters.p.bind(py))?;
let q = py_int_to_signed_bn(py, parameters.q.bind(py))?;
let g = py_int_to_signed_bn(py, parameters.g.bind(py))?;

check_dsa_parameters_bignums(&p, &q, &g)
}

fn dsa_value_error(message: &'static str) -> CryptographyError {
CryptographyError::from(pyo3::exceptions::PyValueError::new_err(message))
}

fn py_int_to_signed_bn(
py: pyo3::Python<'_>,
value: &pyo3::Bound<'_, pyo3::PyAny>,
) -> CryptographyResult<openssl::bn::BigNum> {
let negative = value.lt(0)?;
let magnitude = value.call_method0(pyo3::intern!(py, "__abs__"))?;
let mut bn = utils::py_int_to_bn(py, &magnitude)?;
bn.set_negative(negative);
Ok(bn)
}

fn check_dsa_parameters_bignums(
p: &openssl::bn::BigNumRef,
q: &openssl::bn::BigNumRef,
g: &openssl::bn::BigNumRef,
) -> CryptographyResult<()> {
if p.is_negative() || ![1024, 2048, 3072, 4096].contains(&(p.num_bits() as usize)) {
return Err(dsa_value_error(
"p must be exactly 1024, 2048, 3072, or 4096 bits long",
));
}

if ![160, 224, 256].contains(
&parameters
.q
.bind(py)
.call_method0("bit_length")?
.extract::<usize>()?,
) {
return Err(CryptographyError::from(
pyo3::exceptions::PyValueError::new_err("q must be exactly 160, 224, or 256 bits long"),
if q.is_negative() || ![160, 224, 256].contains(&(q.num_bits() as usize)) {
return Err(dsa_value_error(
"q must be exactly 160, 224, or 256 bits long",
));
}

if parameters.g.bind(py).le(1)? || parameters.g.bind(py).ge(parameters.p.bind(py))? {
return Err(CryptographyError::from(
pyo3::exceptions::PyValueError::new_err("g, p don't satisfy 1 < g < p."),
));
let one = openssl::bn::BigNum::from_u32(1)?;
if g.cmp(&one) != Ordering::Greater || g.cmp(p) != Ordering::Less {
return Err(dsa_value_error("g, p don't satisfy 1 < g < p."));
}

Ok(())
}

fn check_dsa_private_numbers(
py: pyo3::Python<'_>,
numbers: &DsaPrivateNumbers,
fn check_dsa_public_key(
p: &openssl::bn::BigNumRef,
q: &openssl::bn::BigNumRef,
g: &openssl::bn::BigNumRef,
y: &openssl::bn::BigNumRef,
) -> CryptographyResult<()> {
let params = numbers.public_numbers.get().parameter_numbers.get();
check_dsa_parameters(py, params)?;
check_dsa_parameters_bignums(p, q, g)?;

if numbers.x.bind(py).le(0)? || numbers.x.bind(py).ge(params.q.bind(py))? {
return Err(CryptographyError::from(
pyo3::exceptions::PyValueError::new_err("x must be > 0 and < q."),
));
let one = openssl::bn::BigNum::from_u32(1)?;
if y.cmp(&one) != Ordering::Greater || y.cmp(p) != Ordering::Less {
return Err(dsa_value_error("y must be > 1 and < p."));
}

if (**numbers.public_numbers.get().y.bind(py)).ne(params
.g
.bind(py)
.pow(numbers.x.bind(py), Some(params.p.bind(py)))?)?
{
return Err(CryptographyError::from(
pyo3::exceptions::PyValueError::new_err("y must be equal to (g ** x % p)."),
));
let mut ctx = openssl::bn::BigNumContext::new()?;
let mut result = openssl::bn::BigNum::new()?;
result.mod_exp(y, q, p, &mut ctx)?;
if result.cmp(&one) != Ordering::Equal {
return Err(dsa_value_error("y ** q mod p must be 1."));
}

Ok(())
}

fn check_dsa_private_key(
p: &openssl::bn::BigNumRef,
q: &openssl::bn::BigNumRef,
g: &openssl::bn::BigNumRef,
y: &openssl::bn::BigNumRef,
x: &openssl::bn::BigNumRef,
) -> CryptographyResult<()> {
check_dsa_public_key(p, q, g, y)?;

let zero = openssl::bn::BigNum::from_u32(0)?;
if x.cmp(&zero) != Ordering::Greater || x.cmp(q) != Ordering::Less {
return Err(dsa_value_error("x must be > 0 and < q."));
}

let mut ctx = openssl::bn::BigNumContext::new()?;
let mut expected_y = openssl::bn::BigNum::new()?;
expected_y.mod_exp(g, x, p, &mut ctx)?;
if y.cmp(&expected_y) != Ordering::Equal {
return Err(dsa_value_error("y must be equal to (g ** x % p)."));
}

Ok(())
}

fn check_dsa_public_numbers(
py: pyo3::Python<'_>,
numbers: &DsaPublicNumbers,
) -> CryptographyResult<()> {
let params = numbers.parameter_numbers.get();
let p = py_int_to_signed_bn(py, params.p.bind(py))?;
let q = py_int_to_signed_bn(py, params.q.bind(py))?;
let g = py_int_to_signed_bn(py, params.g.bind(py))?;
let y = py_int_to_signed_bn(py, numbers.y.bind(py))?;

check_dsa_public_key(&p, &q, &g, &y)
}

fn check_dsa_private_numbers(
py: pyo3::Python<'_>,
numbers: &DsaPrivateNumbers,
) -> CryptographyResult<()> {
let public_numbers = numbers.public_numbers.get();
let params = public_numbers.parameter_numbers.get();
let p = py_int_to_signed_bn(py, params.p.bind(py))?;
let q = py_int_to_signed_bn(py, params.q.bind(py))?;
let g = py_int_to_signed_bn(py, params.g.bind(py))?;
let y = py_int_to_signed_bn(py, public_numbers.y.bind(py))?;
let x = py_int_to_signed_bn(py, numbers.x.bind(py))?;

check_dsa_private_key(&p, &q, &g, &y, &x)
}

#[pyo3::pyclass(
frozen,
module = "cryptography.hazmat.primitives.asymmetric.dsa",
Expand Down Expand Up @@ -440,7 +507,7 @@ impl DsaPublicNumbers {

let parameter_numbers = self.parameter_numbers.get();

check_dsa_parameters(py, parameter_numbers)?;
check_dsa_public_numbers(py, self)?;

let dsa = openssl::dsa::Dsa::from_public_components(
utils::py_int_to_bn(py, parameter_numbers.p.bind(py))?,
Expand Down
4 changes: 2 additions & 2 deletions src/rust/src/backend/keys.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ fn private_key_from_pkey<'p>(
openssl::pkey::Id::ED448 => Ok(crate::backend::ed448::private_key_from_pkey(pkey)
.into_pyobject(py)?
.into_any()),
openssl::pkey::Id::DSA => Ok(crate::backend::dsa::private_key_from_pkey(pkey)
openssl::pkey::Id::DSA => Ok(crate::backend::dsa::private_key_from_pkey(pkey)?
.into_pyobject(py)?
.into_any()),
openssl::pkey::Id::DH => Ok(crate::backend::dh::private_key_from_pkey(pkey)
Expand Down Expand Up @@ -363,7 +363,7 @@ fn public_key_from_pkey<'p>(
.into_pyobject(py)?
.into_any()),

openssl::pkey::Id::DSA => Ok(crate::backend::dsa::public_key_from_pkey(pkey)
openssl::pkey::Id::DSA => Ok(crate::backend::dsa::public_key_from_pkey(pkey)?
.into_pyobject(py)?
.into_any()),
openssl::pkey::Id::DH => Ok(crate::backend::dh::public_key_from_pkey(pkey)
Expand Down
74 changes: 59 additions & 15 deletions tests/hazmat/primitives/test_dsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,24 @@ def test_invalid_parameters_values(self, p, q, g, backend):
DSA_KEY_1024.public_numbers.y,
2**200,
),
(
DSA_KEY_1024.public_numbers.parameter_numbers.p,
DSA_KEY_1024.public_numbers.parameter_numbers.q,
DSA_KEY_1024.public_numbers.parameter_numbers.g,
-1,
DSA_KEY_1024.x,
),
(
DSA_KEY_1024.public_numbers.parameter_numbers.p,
DSA_KEY_1024.public_numbers.parameter_numbers.q,
DSA_KEY_1024.public_numbers.parameter_numbers.g,
pow(
DSA_KEY_1024.public_numbers.parameter_numbers.g,
DSA_KEY_1024.x + 1,
DSA_KEY_1024.public_numbers.parameter_numbers.p,
),
DSA_KEY_1024.x,
),
(
DSA_KEY_1024.public_numbers.parameter_numbers.p,
DSA_KEY_1024.public_numbers.parameter_numbers.q,
Expand Down Expand Up @@ -352,6 +370,30 @@ def test_invalid_dsa_private_key_arguments(self, p, q, g, y, x, backend):
2**1200,
DSA_KEY_1024.public_numbers.y,
),
(
DSA_KEY_1024.public_numbers.parameter_numbers.p,
DSA_KEY_1024.public_numbers.parameter_numbers.q,
DSA_KEY_1024.public_numbers.parameter_numbers.g,
-1,
),
(
DSA_KEY_1024.public_numbers.parameter_numbers.p,
DSA_KEY_1024.public_numbers.parameter_numbers.q,
DSA_KEY_1024.public_numbers.parameter_numbers.g,
1,
),
(
DSA_KEY_1024.public_numbers.parameter_numbers.p,
DSA_KEY_1024.public_numbers.parameter_numbers.q,
DSA_KEY_1024.public_numbers.parameter_numbers.g,
DSA_KEY_1024.public_numbers.parameter_numbers.p + 1,
),
(
DSA_KEY_1024.public_numbers.parameter_numbers.p,
DSA_KEY_1024.public_numbers.parameter_numbers.q,
DSA_KEY_1024.public_numbers.parameter_numbers.g,
DSA_KEY_1024.public_numbers.parameter_numbers.p - 1,
),
],
)
def test_invalid_dsa_public_key_arguments(self, p, q, g, y, backend):
Expand Down Expand Up @@ -478,12 +520,17 @@ def test_dsa_verification(self, backend, subtests):
backend, algorithm, vector["p"], vector["q"], vector["g"]
)

public_key = dsa.DSAPublicNumbers(
parameter_numbers=dsa.DSAParameterNumbers(
vector["p"], vector["q"], vector["g"]
),
y=vector["y"],
).public_key(backend)
try:
public_key = dsa.DSAPublicNumbers(
parameter_numbers=dsa.DSAParameterNumbers(
vector["p"], vector["q"], vector["g"]
),
y=vector["y"],
).public_key(backend)
except ValueError:
assert vector["result"] == "F"
continue

sig = encode_dss_signature(vector["r"], vector["s"])

if vector["result"] == "F":
Expand Down Expand Up @@ -610,16 +657,13 @@ def test_prehashed_digest_mismatch(self, backend):
skip_message="Requires OpenSSL 3.0.9+, LibreSSL, BoringSSL, or AWS-LC",
)
def test_nilpotent(self):
key = load_vectors_from_file(
os.path.join("asymmetric", "DSA", "custom", "nilpotent.pem"),
lambda pemfile: serialization.load_pem_private_key(
pemfile.read().encode(), password=None
),
)
assert isinstance(key, dsa.DSAPrivateKey)

with pytest.raises(ValueError):
key.sign(b"anything", hashes.SHA256())
load_vectors_from_file(
os.path.join("asymmetric", "DSA", "custom", "nilpotent.pem"),
lambda pemfile: serialization.load_pem_private_key(
pemfile.read().encode(), password=None
),
)


class TestDSANumbers:
Expand Down
Loading