File: //proc/self/root/lib/python3/dist-packages/josepy/jwk.py
"""JSON Web Key."""
import abc
import json
import logging
import math
from typing import (
    Any,
    Callable,
    Dict,
    Mapping,
    Optional,
    Sequence,
    Tuple,
    Type,
    Union,
)
import cryptography.exceptions
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import ec, rsa
import josepy.util
from josepy import errors, json_util, util
# Module-level logger. JWK.load uses it to record, at DEBUG level, why
# asymmetric key loading failed before falling back to a symmetric JWKOct.
logger = logging.getLogger(__name__)
class JWK(json_util.TypedJSONObjectWithFields, metaclass=abc.ABCMeta):
    """JSON Web Key.

    Abstract base for concrete key types. Subclasses are registered with
    ``@JWK.register`` and selected by the ``kty`` member; ``load`` searches
    ``TYPES`` for the subclass whose ``cryptography_key_types`` match a
    deserialized key.
    """

    type_field_name = "kty"
    TYPES: Dict[str, Type["JWK"]] = {}
    cryptography_key_types: Tuple[Type[Any], ...] = ()
    """Subclasses should override."""
    required: Sequence[str] = NotImplemented
    """Required members of public key's representation as defined by JWK/JWA."""
    _thumbprint_json_dumps_params: Dict[str, Union[Optional[int], Sequence[str], bool]] = {
        # "no whitespace or line breaks before or after any syntactic
        # elements"
        "indent": None,
        "separators": (",", ":"),
        # "members ordered lexicographically by the Unicode [UNICODE]
        # code points of the member names"
        "sort_keys": True,
    }
    key: Any

    def thumbprint(
        self, hash_function: Callable[[], hashes.HashAlgorithm] = hashes.SHA256
    ) -> bytes:
        """Compute JWK Thumbprint.

        https://tools.ietf.org/html/rfc7638

        Only the members listed in ``required`` are hashed. They are
        serialized as compact JSON with lexicographically sorted keys.

        :param hash_function: Zero-argument factory returning the
            ``cryptography`` hash algorithm to use (SHA-256 by default).
        :returns: Raw digest bytes.
        :rtype: bytes
        """
        digest = hashes.Hash(hash_function(), backend=default_backend())
        digest.update(
            json.dumps(
                {k: v for k, v in self.to_json().items() if k in self.required},
                **self._thumbprint_json_dumps_params,  # type: ignore[arg-type]
            ).encode()
        )
        return digest.finalize()

    @abc.abstractmethod
    def public_key(self) -> "JWK":  # pragma: no cover
        """Generate JWK with public key.

        For symmetric cryptosystems, this would return ``self``.
        """
        raise NotImplementedError()

    @classmethod
    def _load_cryptography_key(
        cls, data: bytes, password: Optional[bytes] = None, backend: Optional[Any] = None
    ) -> Any:
        """Deserialize ``data`` into a ``cryptography`` key object.

        Tries the PEM and DER private-key loaders first, then the PEM and
        DER public-key loaders. The first successful result is returned.

        :raises errors.Error: if every loader fails. The message contains
            the exception raised by each loader.
        """
        backend = default_backend() if backend is None else backend
        exceptions: Dict[str, Exception] = {}

        # private key?
        loader_private: Any
        for loader_private in (
            serialization.load_pem_private_key,
            serialization.load_der_private_key,
        ):
            try:
                return loader_private(data, password, backend)
            # TypeError signals a password mismatch: a password was given for
            # an unencrypted key, or none was given for an encrypted one.
            except (ValueError, TypeError, cryptography.exceptions.UnsupportedAlgorithm) as error:
                exceptions[str(loader_private)] = error

        # public key?
        loader_public: Any
        for loader_public in (serialization.load_pem_public_key, serialization.load_der_public_key):
            try:
                return loader_public(data, backend)
            except (ValueError, cryptography.exceptions.UnsupportedAlgorithm) as error:
                exceptions[str(loader_public)] = error

        # no luck
        raise errors.Error("Unable to deserialize key: {0}".format(exceptions))

    @classmethod
    def load(
        cls, data: bytes, password: Optional[bytes] = None, backend: Optional[Any] = None
    ) -> "JWK":
        """Load serialized key as JWK.

        If ``data`` cannot be parsed as an asymmetric key, it is treated as
        raw symmetric key material and wrapped in `JWKOct`. This happens
        regardless of which class ``load`` was called on.

        :param bytes data: Public or private key serialized as PEM or DER.
        :param bytes password: Optional password.
        :param backend: A `.PEMSerializationBackend` and
            `.DERSerializationBackend` provider.

        :raises errors.Error: if the loaded asymmetric key does not match
            ``cls`` (when called on a concrete subclass), or if no registered
            JWK type supports it.

        :returns: JWK of an appropriate type.
        :rtype: `JWK`
        """
        try:
            key = cls._load_cryptography_key(data, password, backend)
        except errors.Error as error:
            logger.debug("Loading symmetric key, asymmetric failed: %s", error)
            return JWKOct(key=data)

        # When called on a concrete subclass, refuse keys of another type.
        # ``cls`` (not ``cls.__class__``, which is the metaclass) names the
        # requested JWK class in the message.
        if cls.typ is not NotImplemented and not isinstance(key, cls.cryptography_key_types):
            raise errors.Error("Unable to deserialize {0} into {1}".format(key.__class__, cls))

        for jwk_cls in cls.TYPES.values():
            if isinstance(key, jwk_cls.cryptography_key_types):
                return jwk_cls(key=key)
        raise errors.Error("Unsupported algorithm: {0}".format(key.__class__))
@JWK.register
class JWKOct(JWK):
    """Symmetric JWK.

    The key is held as raw bytes and serialized as the base64url-encoded
    ``k`` member.
    """

    typ = "oct"
    __slots__ = ("key",)
    required = ("k", JWK.type_field_name)
    key: bytes

    def fields_to_partial_json(self) -> Dict[str, str]:
        """Serialize the secret bytes into the ``k`` member."""
        # TODO: An "alg" member SHOULD also be present to identify the
        # algorithm intended to be used with the key, unless the
        # application uses another means or convention to determine
        # the algorithm used.
        encoded_secret = json_util.encode_b64jose(self.key)
        return {"k": encoded_secret}

    @classmethod
    def fields_from_json(cls, jobj: Mapping[str, Any]) -> "JWKOct":
        """Rebuild the key from the base64url-encoded ``k`` member of ``jobj``."""
        secret = json_util.decode_b64jose(jobj["k"])
        return cls(key=secret)

    def public_key(self) -> "JWKOct":
        """Return ``self``; a symmetric key has no separate public half."""
        return self
@JWK.register
class JWKRSA(JWK):
    """RSA JWK.

    :ivar key: :class:`~cryptography.hazmat.primitives.asymmetric.rsa.RSAPrivateKey`
        or :class:`~cryptography.hazmat.primitives.asymmetric.rsa.RSAPublicKey` wrapped
        in :class:`~josepy.util.ComparableRSAKey`
    """

    typ = "RSA"
    cryptography_key_types = (rsa.RSAPublicKey, rsa.RSAPrivateKey)
    __slots__ = ("key",)
    required = ("e", JWK.type_field_name, "n")

    key: josepy.util.ComparableRSAKey

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Normalise ``key`` to the ComparableRSAKey wrapper declared by the
        # ``key`` annotation, unless the caller already wrapped it.
        if "key" in kwargs and not isinstance(kwargs["key"], util.ComparableRSAKey):
            kwargs["key"] = util.ComparableRSAKey(kwargs["key"])
        super().__init__(*args, **kwargs)

    @classmethod
    def _encode_param(cls, data: int) -> str:
        """Encode Base64urlUInt.

        :param int data: Non-negative integer to encode.
        :returns: base64url encoding of the minimal big-endian byte
            representation of ``data``. At least one octet is emitted, so
            ``0`` encodes as a single zero octet.
        :rtype: str
        """
        length = max(data.bit_length(), 8)  # decoding 0: force one octet
        length = math.ceil(length / 8)
        return json_util.encode_b64jose(data.to_bytes(byteorder="big", length=length))

    @classmethod
    def _decode_param(cls, data: str) -> int:
        """Decode Base64urlUInt.

        :raises errors.DeserializationError: if ``data`` decodes to an empty
            byte string, or if decoding raises ``ValueError``.
        """
        try:
            binary = json_util.decode_b64jose(data)
            if not binary:
                raise errors.DeserializationError()
            return int.from_bytes(binary, byteorder="big")
        except ValueError:  # presumably malformed base64 from decode_b64jose
            raise errors.DeserializationError()

    def public_key(self) -> "JWKRSA":
        """Return a JWK wrapping only the public half of this key."""
        return type(self)(key=self.key.public_key())

    @classmethod
    def fields_from_json(cls, jobj: Mapping[str, Any]) -> "JWKRSA":
        """Build an RSA JWK from its JSON members.

        ``n`` and ``e`` are always read. If ``d`` is absent, a public key is
        returned. If ``d`` is present, the CRT members ``p``, ``q``, ``dp``,
        ``dq`` and ``qi`` are decoded when any of them (or ``oth``) appears.
        Otherwise they are recovered from ``n``, ``e`` and ``d``.

        :raises errors.Error: if only some of the CRT members are present.
        """
        n, e = (cls._decode_param(jobj[x]) for x in ("n", "e"))
        public_numbers = rsa.RSAPublicNumbers(e=e, n=n)

        # public key
        if "d" not in jobj:
            return cls(key=public_numbers.public_key(default_backend()))

        # private key
        d = cls._decode_param(jobj["d"])
        crt_names = ("p", "q", "dp", "dq", "qi")
        if "oth" in jobj or any(name in jobj for name in crt_names):
            # "If the producer includes any of the other private
            # key parameters, then all of the others MUST be
            # present, with the exception of "oth", which MUST
            # only be present when more than two prime factors
            # were used."
            missing = [name for name in crt_names if jobj.get(name) is None]
            if missing:
                # Report only the member names. The values are private key
                # material and must not leak into exception messages or logs.
                raise errors.Error("Some private parameters are missing: {0}".format(missing))
            p, q, dp, dq, qi = (cls._decode_param(str(jobj[name])) for name in crt_names)
            # TODO: check for oth
        else:
            # cryptography>=0.8
            p, q = rsa.rsa_recover_prime_factors(n, e, d)
            dp = rsa.rsa_crt_dmp1(d, p)
            dq = rsa.rsa_crt_dmq1(d, q)
            qi = rsa.rsa_crt_iqmp(p, q)

        key = rsa.RSAPrivateNumbers(p, q, d, dp, dq, qi, public_numbers).private_key(
            default_backend()
        )
        return cls(key=key)

    def fields_to_partial_json(self) -> Dict[str, Any]:
        """Serialize the key's numbers as Base64urlUInt JSON members.

        A public key yields ``n`` and ``e``. A private key additionally
        yields ``d`` and the CRT members ``p``, ``q``, ``dp``, ``dq``, ``qi``.
        """
        if isinstance(self.key._wrapped, rsa.RSAPublicKey):
            numbers = self.key.public_numbers()
            params = {
                "n": numbers.n,
                "e": numbers.e,
            }
        else:  # rsa.RSAPrivateKey
            private = self.key.private_numbers()
            public = self.key.public_key().public_numbers()
            params = {
                "n": public.n,
                "e": public.e,
                "d": private.d,
                "p": private.p,
                "q": private.q,
                "dp": private.dmp1,
                "dq": private.dmq1,
                "qi": private.iqmp,
            }
        return {key: self._encode_param(value) for key, value in params.items()}
@JWK.register
class JWKEC(JWK):
    """EC JWK.

    :ivar key: :class:`~cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey`
        or :class:`~cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePublicKey` wrapped
        in :class:`~josepy.util.ComparableECKey`
    """

    typ = "EC"
    __slots__ = ("key",)
    cryptography_key_types = (ec.EllipticCurvePublicKey, ec.EllipticCurvePrivateKey)
    required = ("crv", JWK.type_field_name, "x", "y")

    key: josepy.util.ComparableECKey

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Normalise ``key`` to the ComparableECKey wrapper declared by the
        # ``key`` annotation, unless the caller already wrapped it.
        if "key" in kwargs and not isinstance(kwargs["key"], util.ComparableECKey):
            kwargs["key"] = util.ComparableECKey(kwargs["key"])
        super().__init__(*args, **kwargs)

    @classmethod
    def _encode_param(cls, data: int, length: int) -> str:
        """Encode Base64urlUInt as a fixed-length big-endian octet string.

        :param int data: Non-negative integer to encode.
        :param int length: Exact number of octets to emit (the curve's
            coordinate size, see ``expected_length_for_curve``).
        :rtype: str
        """
        return json_util.encode_b64jose(data.to_bytes(byteorder="big", length=length))

    @classmethod
    def _decode_param(cls, data: str, name: str, valid_length: int) -> int:
        """Decode Base64urlUInt, requiring exactly ``valid_length`` octets.

        :param str data: base64url-encoded member value.
        :param str name: Member name, used in the error message.
        :param int valid_length: Required decoded length in octets.
        :raises errors.DeserializationError: on a length mismatch, or if
            decoding raises ``ValueError``.
        """
        try:
            binary = json_util.decode_b64jose(data)
            if len(binary) != valid_length:
                raise errors.DeserializationError(
                    f'Expected parameter "{name}" to be {valid_length} bytes '
                    f"after base64-decoding; got {len(binary)} bytes instead"
                )
            return int.from_bytes(binary, byteorder="big")
        except ValueError:  # presumably malformed base64 from decode_b64jose
            raise errors.DeserializationError()

    @classmethod
    def _curve_name_to_crv(cls, curve_name: str) -> str:
        """Map a ``cryptography`` curve name to its JWA ``crv`` value."""
        if curve_name == "secp256r1":
            return "P-256"
        if curve_name == "secp384r1":
            return "P-384"
        if curve_name == "secp521r1":
            return "P-521"
        raise errors.SerializationError(f"Unsupported curve: {curve_name}")

    @classmethod
    def _crv_to_curve(cls, crv: str) -> ec.EllipticCurve:
        """Map a JWA ``crv`` value to a ``cryptography`` curve instance."""
        # crv is case-sensitive
        if crv == "P-256":
            return ec.SECP256R1()
        if crv == "P-384":
            return ec.SECP384R1()
        if crv == "P-521":
            return ec.SECP521R1()
        raise errors.DeserializationError(f"Unsupported crv: {crv!r}")

    @classmethod
    def expected_length_for_curve(cls, curve: ec.EllipticCurve) -> int:
        """Return the coordinate/scalar size in octets for ``curve``.

        :raises ValueError: for curves other than P-256, P-384 and P-521.
        """
        if isinstance(curve, ec.SECP256R1):
            return 32
        elif isinstance(curve, ec.SECP384R1):
            return 48
        elif isinstance(curve, ec.SECP521R1):
            return 66
        raise ValueError(f"Unexpected curve: {curve}")

    def fields_to_partial_json(self) -> Dict[str, Any]:
        """Serialize the key as ``x``, ``y`` (plus ``d`` if private) and ``crv``.

        :raises errors.SerializationError: if the wrapped key is neither an
            EC public nor an EC private key, or its curve has no JWA name.
        """
        params: Dict[str, int] = {}
        if isinstance(self.key._wrapped, ec.EllipticCurvePublicKey):
            public = self.key.public_numbers()
        elif isinstance(self.key._wrapped, ec.EllipticCurvePrivateKey):
            private = self.key.private_numbers()
            public = self.key.public_key().public_numbers()
            params["d"] = private.private_value
        else:
            raise errors.SerializationError(
                "Supplied key is neither of type EllipticCurvePublicKey "
                "nor EllipticCurvePrivateKey"
            )
        params["x"] = public.x
        params["y"] = public.y

        # Every member is padded to the same curve-specific length, so compute it once.
        length = self.expected_length_for_curve(public.curve)
        json_params: Dict[str, Any] = {
            key: self._encode_param(value, length) for key, value in params.items()
        }
        json_params["crv"] = self._curve_name_to_crv(public.curve.name)
        return json_params

    @classmethod
    def fields_from_json(cls, jobj: Mapping[str, Any]) -> "JWKEC":
        """Build an EC JWK from ``crv``, ``x``, ``y`` and optionally ``d``.

        Each coordinate (and ``d``) must decode to exactly the curve's
        expected length. If ``d`` is absent, a public key is returned.
        """
        curve = cls._crv_to_curve(jobj["crv"])
        expected_length = cls.expected_length_for_curve(curve)
        x, y = (cls._decode_param(jobj[n], n, expected_length) for n in ("x", "y"))
        public_numbers = ec.EllipticCurvePublicNumbers(x=x, y=y, curve=curve)

        # public key
        if "d" not in jobj:
            return cls(key=public_numbers.public_key(default_backend()))

        # private key
        d = cls._decode_param(jobj["d"], "d", expected_length)
        key = ec.EllipticCurvePrivateNumbers(d, public_numbers).private_key(default_backend())
        return cls(key=key)

    def public_key(self) -> "JWKEC":
        """Return a JWK wrapping only the public half of this key."""
        # Prefer the wrapper's public_key() when it exists; otherwise rebuild
        # the public key from the key's public numbers.
        if hasattr(self.key, "public_key"):
            key = self.key.public_key()
        else:
            key = self.key.public_numbers().public_key(default_backend())
        return type(self)(key=key)