
fix(event_source): fix decode headers with signed bytes #6878

Open · wants to merge 1 commit into develop
16 changes: 16 additions & 0 deletions aws_lambda_powertools/shared/functions.py
@@ -291,3 +291,19 @@ def sanitize_xray_segment_name(name: str) -> str:
def get_tracer_id() -> str | None:
xray_trace_id = os.getenv(constants.XRAY_TRACE_ID_ENV)
return xray_trace_id.split(";")[0].replace("Root=", "") if xray_trace_id else None


def decode_header_bytes(byte_list: list[int]) -> bytes:
    """
    Decode a list of byte values that might be signed.

    JVM-based Kafka producers serialize header bytes as signed values
    (-128 to 127), while Python's bytes() only accepts 0 to 255, so any
    negative values must be masked back into the unsigned range.
    """
    has_negative = any(b < 0 for b in byte_list)

    if not has_negative:
        # All values already fit the unsigned range; construct directly
        return bytes(byte_list)
    # Mask signed bytes into the unsigned (0-255) range
    unsigned_bytes = [(b & 0xFF) for b in byte_list]
    return bytes(unsigned_bytes)
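For context, a minimal usage sketch of the new helper (not part of the diff); the negative values are the signed form that JVM-based producers emit for bytes above 0x7F:

```python
from aws_lambda_powertools.shared.functions import decode_header_bytes

# [-61, -85] is the signed-byte form of 0xC3 0xAB, the UTF-8 encoding of "ë"
signed_header = [104, 101, 108, 108, 111, -61, -85]  # "hello" followed by "ë"

assert decode_header_bytes(signed_header) == b"hello\xc3\xab"
assert decode_header_bytes(signed_header).decode("utf-8") == "helloë"

# Plain bytes() rejects negatives, which was the original bug:
# bytes([-61]) -> ValueError: bytes must be in range(0, 256)
```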
aws_lambda_powertools/utilities/data_classes/kafka_event.py
@@ -4,6 +4,7 @@
from functools import cached_property
from typing import TYPE_CHECKING, Any

from aws_lambda_powertools.shared.functions import decode_header_bytes
from aws_lambda_powertools.utilities.data_classes.common import CaseInsensitiveDict, DictWrapper

if TYPE_CHECKING:
@@ -110,7 +111,7 @@ def headers(self) -> list[dict[str, list[int]]]:
@cached_property
def decoded_headers(self) -> dict[str, bytes]:
"""Decodes the headers as a single dictionary."""
- return CaseInsensitiveDict((k, bytes(v)) for chunk in self.headers for k, v in chunk.items())
+ return CaseInsensitiveDict((k, decode_header_bytes(v)) for chunk in self.headers for k, v in chunk.items())


class KafkaEventBase(DictWrapper):
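A short sketch of how the fix surfaces through the data class (illustrative only; the record dict is abbreviated to the single key that decoded_headers reads):

```python
from aws_lambda_powertools.utilities.data_classes.kafka_event import KafkaEventRecord

# Abbreviated record: only "headers" is needed to exercise decoded_headers
record = KafkaEventRecord({"headers": [{"headerKey": [104, 105, -61, -85]}]})

# Previously this raised ValueError on the negative values;
# decode_header_bytes now masks them into the 0-255 range first
assert record.decoded_headers["headerKey"] == b"hi\xc3\xab"
```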
5 changes: 4 additions & 1 deletion aws_lambda_powertools/utilities/kafka/consumer_records.py
@@ -4,6 +4,7 @@
from functools import cached_property
from typing import TYPE_CHECKING, Any

from aws_lambda_powertools.shared.functions import decode_header_bytes
from aws_lambda_powertools.utilities.data_classes.common import CaseInsensitiveDict
from aws_lambda_powertools.utilities.data_classes.kafka_event import KafkaEventBase, KafkaEventRecordBase
from aws_lambda_powertools.utilities.kafka.deserializer.deserializer import get_deserializer
@@ -115,7 +116,9 @@ def original_headers(self) -> list[dict[str, list[int]]]:
@cached_property
def headers(self) -> dict[str, bytes]:
"""Decodes the headers as a single dictionary."""
- return CaseInsensitiveDict((k, bytes(v)) for chunk in self.original_headers for k, v in chunk.items())
+ return CaseInsensitiveDict(
+     (k, decode_header_bytes(v)) for chunk in self.original_headers for k, v in chunk.items()
+ )


class ConsumerRecords(KafkaEventBase):
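The consumer utility gets the same treatment. A rough sketch, assuming ConsumerRecords can be built directly from the raw event dict like the data class above (in a handler it is normally constructed for you by the kafka_consumer decorator):

```python
from aws_lambda_powertools.utilities.kafka.consumer_records import ConsumerRecords

event = {
    "records": {
        "mytopic-0": [
            # "avro" followed by the signed-byte form of "ë"
            {"headers": [{"content-type": [97, 118, 114, 111, -61, -85]}]},
        ],
    },
}

record = next(ConsumerRecords(event).records)
assert record.headers["content-type"] == b"avro\xc3\xab"
```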
10 changes: 4 additions & 6 deletions aws_lambda_powertools/utilities/parser/models/kafka.py
@@ -3,7 +3,7 @@

from pydantic import BaseModel, field_validator

from aws_lambda_powertools.shared.functions import base64_decode, bytes_to_string
from aws_lambda_powertools.shared.functions import base64_decode, bytes_to_string, decode_header_bytes

SERVERS_DELIMITER = ","

@@ -28,9 +28,7 @@ class KafkaRecordModel(BaseModel):
# key is optional; only decode if not None
@field_validator("key", mode="before")
def decode_key(cls, value):
- if value is not None:
-     return base64_decode(value)
- return value
+ return base64_decode(value) if value is not None else value

@field_validator("value", mode="before")
def data_base64_decode(cls, value):
@@ -41,7 +39,7 @@ def data_base64_decode(cls, value):
def decode_headers_list(cls, value):
for header in value:
for key, values in header.items():
- header[key] = bytes(values)
+ header[key] = decode_header_bytes(values)
return value


@@ -51,7 +49,7 @@ class KafkaBaseEventModel(BaseModel):

@field_validator("bootstrapServers", mode="before")
def split_servers(cls, value):
- return None if not value else value.split(SERVERS_DELIMITER)
+ return value.split(SERVERS_DELIMITER) if value else None


class KafkaSelfManagedEventModel(KafkaBaseEventModel):
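At the parser level the behavior can be sketched against the fixture added below (parse and KafkaMskEventModel are existing parser entry points; the fixture path assumes the repository layout shown in this diff):

```python
import json

from aws_lambda_powertools.utilities.parser import parse
from aws_lambda_powertools.utilities.parser.models import KafkaMskEventModel

with open("tests/events/kafkaEventMsk.json") as f:
    raw_event = json.load(f)

parsed = parse(event=raw_event, model=KafkaMskEventModel)

# The last record carries the signed-byte header added by this PR
record = parsed.records["mytopic-0"][-1]
assert record.headers[0]["headerKey"] == b"hello-world-\xc3\xab"
```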
22 changes: 22 additions & 0 deletions tests/events/kafkaEventMsk.json
@@ -104,6 +104,28 @@
"dataFormat": "AVRO",
"schemaId": "1234"
}
},
{
"topic":"mymessage-with-unsigned",
"partition":0,
"offset":15,
"timestamp":1545084650987,
"timestampType":"CREATE_TIME",
"key": null,
"value":"eyJrZXkiOiJ2YWx1ZSJ9",
"headers":[
{
"headerKey":[104, 101, 108, 108, 111, 45, 119, 111, 114, 108, 100, 45, -61, -85]
}
],
"valueSchemaMetadata": {
"dataFormat": "AVRO",
"schemaId": "1234"
},
"keySchemaMetadata": {
"dataFormat": "AVRO",
"schemaId": "1234"
}
}
]
}
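For the record, the signed tail [-61, -85] is 0xC3 0xAB, the UTF-8 encoding of "ë", so this fixture exercises exactly the bytes a JVM producer would emit:

```python
signed = [104, 101, 108, 108, 111, 45, 119, 111, 114, 108, 100, 45, -61, -85]
assert bytes(b & 0xFF for b in signed).decode("utf-8") == "hello-world-ë"
```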
@@ -21,7 +21,7 @@ def test_kafka_msk_event():
assert parsed_event.decoded_bootstrap_servers == bootstrap_servers_list

records = list(parsed_event.records)
- assert len(records) == 3
+ assert len(records) == 4
record = records[0]
raw_record = raw_event["records"]["mytopic-0"][0]
assert record.topic == raw_record["topic"]
@@ -40,9 +40,10 @@ def test_kafka_msk_event():
assert record.value_schema_metadata.schema_id == raw_record["valueSchemaMetadata"]["schemaId"]

assert parsed_event.record == records[0]
- for i in range(1, 3):
+ for i in range(1, 4):
record = records[i]
assert record.key is None
assert record.decoded_headers is not None


def test_kafka_self_managed_event():
@@ -90,5 +91,5 @@ def test_kafka_record_property_with_stopiteration_error():
# WHEN calling record property thrice
# THEN raise StopIteration
with pytest.raises(StopIteration):
- for _ in range(4):
+ for _ in range(5):
assert parsed_event.record.topic is not None
6 changes: 3 additions & 3 deletions tests/unit/parser/_pydantic/test_kafka.py
@@ -17,7 +17,7 @@ def test_kafka_msk_event_with_envelope():
)
for i in range(3):
assert parsed_event[i].key == "value"
- assert len(parsed_event) == 3
+ assert len(parsed_event) == 4


def test_kafka_self_managed_event_with_envelope():
@@ -70,7 +70,7 @@ def test_kafka_msk_event():
assert parsed_event.eventSourceArn == raw_event["eventSourceArn"]

records = list(parsed_event.records["mytopic-0"])
- assert len(records) == 3
+ assert len(records) == 4
record: KafkaRecordModel = records[0]
raw_record = raw_event["records"]["mytopic-0"][0]
assert record.topic == raw_record["topic"]
@@ -88,6 +88,6 @@ def test_kafka_msk_event():
assert record.keySchemaMetadata.schemaId == "1234"
assert record.valueSchemaMetadata.dataFormat == "AVRO"
assert record.valueSchemaMetadata.schemaId == "1234"
- for i in range(1, 3):
+ for i in range(1, 4):
record: KafkaRecordModel = records[i]
assert record.key is None