Skip to content

fix: update header validation raised errors #45

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ Unreleased
- Use ``ECKey.binding.register_curve`` to register new supported curves.
- Use ``UnsupportedAlgorithmError`` instead of ``ValueError`` in JWS/JWE registry.
- Use ``MissingKeyTypeError`` and ``InvalidKeyIdError`` for errors in JWK.
- Use ``UnsupportedHeaderError``, ``MissingHeaderError``, and ``MissingCritHeaderError`` for header validation.
- Respect RFC6749 character set in error descriptions.

1.0.4
Expand Down
28 changes: 28 additions & 0 deletions src/joserfc/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,30 @@ class UnsupportedAlgorithmError(JoseError):
error = "unsupported_algorithm"


class UnsupportedHeaderError(JoseError):
error = "unsupported_header"


class MissingHeaderError(JoseError):
"""This error happens when the required header does not exist."""

error = "missing_header"

def __init__(self, key: str):
description = f"Missing '{key}' value in header"
super(MissingHeaderError, self).__init__(description=description)


class MissingCritHeaderError(JoseError):
"""This error happens when the critical header does not exist."""

error = "missing_crit_header"

def __init__(self, key: str):
description = f"Missing critical '{key}' value in header"
super(MissingCritHeaderError, self).__init__(description=description)


class MissingEncryptionError(JoseError):
error = "missing_encryption"
description = "Missing 'enc' value in header"
Expand Down Expand Up @@ -104,6 +128,10 @@ class InvalidCEKLengthError(JoseError):
error = "invalid_cek_length"
description = "Invalid 'cek' length"

def __init__(self, cek_size: int):
description = f"A key of size {cek_size} bits MUST be used"
super(InvalidCEKLengthError, self).__init__(description=description)


class InvalidClaimError(JoseError):
error = "invalid_claim"
Expand Down
11 changes: 8 additions & 3 deletions src/joserfc/registry.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
from __future__ import annotations
from typing import Any, Dict, Callable, Union
from .errors import (
MissingHeaderError,
MissingCritHeaderError,
UnsupportedHeaderError,
)

Header = Dict[str, Any]

Expand Down Expand Up @@ -174,13 +179,13 @@ def check_supported_header(registry: HeaderRegistryDict, header: Header) -> None
allowed_keys = set(registry.keys())
unsupported_keys = set(header.keys()) - allowed_keys
if unsupported_keys:
raise ValueError(f"Unsupported {unsupported_keys} in header")
raise UnsupportedHeaderError(f"Unsupported {unsupported_keys} in header")


def validate_registry_header(registry: HeaderRegistryDict, header: Header, check_required: bool = True) -> None:
for key, reg in registry.items():
if check_required and reg.required and key not in header:
raise ValueError(f"Required '{key}' is missing in header")
raise MissingHeaderError(key)
if key in header:
try:
reg.validate(header[key])
Expand All @@ -193,4 +198,4 @@ def check_crit_header(header: Header) -> None:
if "crit" in header:
for k in header["crit"]:
if k not in header:
raise ValueError(f"'{k}' is a critical header")
raise MissingCritHeaderError(k)
4 changes: 2 additions & 2 deletions src/joserfc/rfc7516/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def _perform_decrypt(obj: EncryptionData, registry: JWERegistry) -> None:

cek = cek_set.pop()
if len(cek) * 8 != enc.cek_size: # pragma: no cover
raise InvalidCEKLengthError(f"A key of size {enc.cek_size} bits MUST be used")
raise InvalidCEKLengthError(enc.cek_size)

aad = json_b64encode(obj.protected)
if isinstance(obj, BaseJSONEncryption) and obj.aad:
Expand Down Expand Up @@ -181,7 +181,7 @@ def __pre_encrypt_direct_mode(alg: JWEAlgModel, enc: JWEEncModel, recipient: Rec
# let the CEK be the agreed upon key.
cek = alg.encrypt_agreed_upon_key(enc, recipient)
if len(cek) * 8 != enc.cek_size: # pragma: no cover
raise InvalidCEKLengthError(f"A key of size {enc.cek_size} bits MUST be used")
raise InvalidCEKLengthError(enc.cek_size)
else:
# 6. When Direct Encryption is employed, let the CEK be the shared
# symmetric key.
Expand Down
5 changes: 3 additions & 2 deletions tests/jwe/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
InvalidKeyLengthError,
DecodeError,
UnsupportedAlgorithmError,
UnsupportedHeaderError,
)
from tests.base import load_key

Expand Down Expand Up @@ -88,7 +89,7 @@ def test_invalid_key_length(self):
def test_extra_header(self):
key = OctKey.generate_key(256)
protected = {"alg": "dir", "enc": "A128CBC-HS256", "custom": "hi"}
self.assertRaises(ValueError, jwe.encrypt_compact, protected, b"i", key)
self.assertRaises(UnsupportedHeaderError, jwe.encrypt_compact, protected, b"i", key)

registry = jwe.JWERegistry(strict_check_header=False)
jwe.encrypt_compact(protected, b"i", key, registry=registry)
Expand All @@ -99,6 +100,6 @@ def test_extra_header(self):
def test_strict_check_header_with_more_header_registry(self):
key = load_key("ec-p256-private.pem")
protected = {"alg": "ECDH-ES", "enc": "A128CBC-HS256", "custom": "hi"}
self.assertRaises(ValueError, jwe.encrypt_compact, protected, b"i", key)
self.assertRaises(UnsupportedHeaderError, jwe.encrypt_compact, protected, b"i", key)
registry = jwe.JWERegistry(strict_check_header=False)
jwe.encrypt_compact(protected, b"i", key, registry=registry)
3 changes: 2 additions & 1 deletion tests/jws/test_compact.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
DecodeError,
MissingAlgorithmError,
UnsupportedAlgorithmError,
UnsupportedHeaderError,
)


Expand Down Expand Up @@ -68,7 +69,7 @@ def test_with_key_set(self):
def test_strict_check_header(self):
header = {"alg": "HS256", "custom": "hi"}
key = OctKey.import_key("secret")
self.assertRaises(ValueError, serialize_compact, header, b"hi", key)
self.assertRaises(UnsupportedHeaderError, serialize_compact, header, b"hi", key)

registry = JWSRegistry(strict_check_header=False)
serialize_compact(header, b"hi", key, registry=registry)
9 changes: 6 additions & 3 deletions tests/jws/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
UnsupportedKeyAlgorithmError,
UnsupportedKeyOperationError,
InvalidKeyTypeError,
MissingHeaderError,
MissingCritHeaderError,
UnsupportedHeaderError,
)
from joserfc.util import urlsafe_b64encode
from tests.base import load_key
Expand All @@ -17,7 +20,7 @@ class TestJWSErrors(TestCase):
def test_without_alg(self):
key = OctKey.import_key("secret")
# missing alg
self.assertRaises(ValueError, jws.serialize_compact, {"kid": "123"}, "i", key)
self.assertRaises(MissingHeaderError, jws.serialize_compact, {"kid": "123"}, "i", key)

def test_none_alg(self):
header = {"alg": "none"}
Expand Down Expand Up @@ -80,7 +83,7 @@ def test_crit_header(self):
key = OctKey.import_key("secret")
# missing kid header
self.assertRaises(
ValueError,
MissingCritHeaderError,
jws.serialize_compact,
header,
"i",
Expand All @@ -94,7 +97,7 @@ def test_extra_header(self):
header = {"alg": "HS256", "extra": "hi"}
key = OctKey.import_key("secret")
self.assertRaises(
ValueError,
UnsupportedHeaderError,
jws.serialize_compact,
header,
"i",
Expand Down
3 changes: 2 additions & 1 deletion tests/jwt/test_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from joserfc.errors import (
InvalidPayloadError,
MissingClaimError,
UnsupportedHeaderError,
)


Expand Down Expand Up @@ -52,7 +53,7 @@ def test_using_registry(self):
registry=jwe.JWERegistry(),
)
self.assertRaises(
ValueError,
UnsupportedHeaderError,
jwt.encode,
{"alg": "A128KW", "enc": "A128GCM"},
{"sub": "a"},
Expand Down