Skip to content

Commit

Permalink
Rewrite the pyOpenSSL implementation in terms of the cryptography one
Browse files Browse the repository at this point in the history
  • Loading branch information
alex committed Oct 25, 2024
1 parent 134004e commit 84250de
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 102 deletions.
89 changes: 3 additions & 86 deletions src/service_identity/pyopenssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,31 +5,17 @@
from __future__ import annotations

import contextlib
import warnings

from typing import Sequence

from pyasn1.codec.der.decoder import decode
from pyasn1.type.char import IA5String
from pyasn1.type.univ import ObjectIdentifier
from pyasn1_modules.rfc2459 import GeneralNames

from .exceptions import CertificateError
from .cryptography import extract_patterns
from .hazmat import (
DNS_ID,
CertificatePattern,
DNSPattern,
IPAddress_ID,
IPAddressPattern,
SRVPattern,
URIPattern,
verify_service_identity,
)


with contextlib.suppress(ImportError):
# We only use it for docstrings -- `if TYPE_CHECKING`` does not work.
from OpenSSL.crypto import X509
from OpenSSL.SSL import Connection


Expand Down Expand Up @@ -62,7 +48,7 @@ def verify_hostname(connection: Connection, hostname: str) -> None:
"""
verify_service_identity(
cert_patterns=extract_patterns(
connection.get_peer_certificate() # type:ignore[arg-type]
connection.get_peer_certificate().to_cryptography() # type:ignore[union-attr]
),
obligatory_ids=[DNS_ID(hostname)],
optional_ids=[],
Expand Down Expand Up @@ -98,77 +84,8 @@ def verify_ip_address(connection: Connection, ip_address: str) -> None:
"""
verify_service_identity(
cert_patterns=extract_patterns(
connection.get_peer_certificate() # type:ignore[arg-type]
connection.get_peer_certificate().to_cryptography() # type:ignore[union-attr]
),
obligatory_ids=[IPAddress_ID(ip_address)],
optional_ids=[],
)


ID_ON_DNS_SRV = ObjectIdentifier("1.3.6.1.5.5.7.8.7") # id_on_dnsSRV


def extract_patterns(cert: X509) -> Sequence[CertificatePattern]:
"""
Extract all valid ID patterns from a certificate for service verification.
Args:
cert: The certificate to be dissected.
Returns:
List of IDs.
.. versionchanged:: 23.1.0
``commonName`` is not used as a fallback anymore.
"""
ids: list[CertificatePattern] = []
for i in range(cert.get_extension_count()):
ext = cert.get_extension(i)
if ext.get_short_name() == b"subjectAltName":
names, _ = decode(ext.get_data(), asn1Spec=GeneralNames())
for n in names:
name_string = n.getName()
if name_string == "dNSName":
ids.append(
DNSPattern.from_bytes(n.getComponent().asOctets())
)
elif name_string == "iPAddress":
ids.append(
IPAddressPattern.from_bytes(
n.getComponent().asOctets()
)
)
elif name_string == "uniformResourceIdentifier":
ids.append(
URIPattern.from_bytes(n.getComponent().asOctets())
)
elif name_string == "otherName":
comp = n.getComponent()
oid = comp.getComponentByPosition(0)
if oid == ID_ON_DNS_SRV:
srv, _ = decode(comp.getComponentByPosition(1))
if isinstance(srv, IA5String):
ids.append(SRVPattern.from_bytes(srv.asOctets()))
else: # pragma: no cover
msg = "Unexpected certificate content."
raise CertificateError(msg)
else: # pragma: no cover
pass
else: # pragma: no cover
pass

return ids


def extract_ids(cert: X509) -> Sequence[CertificatePattern]:
"""
Deprecated and never public API. Use :func:`extract_patterns` instead.
.. deprecated:: 23.1.0
"""
warnings.warn(
category=DeprecationWarning,
message="`extract_ids()` is deprecated, please use `extract_patterns()`.",
stacklevel=2,
)
return extract_patterns(cert)
16 changes: 0 additions & 16 deletions tests/test_pyopenssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
URIPattern,
)
from service_identity.pyopenssl import (
extract_ids,
extract_patterns,
verify_hostname,
verify_ip_address,
Expand Down Expand Up @@ -144,18 +143,3 @@ def test_ip(self):
IPAddressPattern(pattern=ipaddress.IPv4Address("2.2.2.2")),
IPAddressPattern(pattern=ipaddress.IPv6Address("2a00:1c38::53")),
] == rv

def test_extract_ids_deprecated(self):
"""
`extract_ids` raises a DeprecationWarning with correct stacklevel.
"""
with pytest.deprecated_call() as wr:
extract_ids(CERT_EVERYTHING)

w = wr.pop()

assert (
"`extract_ids()` is deprecated, please use `extract_patterns()`."
== w.message.args[0]
)
assert __file__ == w.filename

0 comments on commit 84250de

Please sign in to comment.