Skip to content

Commit 1ef3f01

Browse files
protobuf-github-bot and shaod2
authored and committed
Internal pure python fixes
PiperOrigin-RevId: 733441339
1 parent 69cca9b commit 1ef3f01

File tree

4 files changed

+86
-29
lines changed

4 files changed

+86
-29
lines changed

python/google/protobuf/internal/decoder.py

Lines changed: 75 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,10 @@ def SpecificDecoder(field_number, is_repeated, is_packed, key, new_default,
184184
clear_if_default=False):
185185
if is_packed:
186186
local_DecodeVarint = _DecodeVarint
187-
def DecodePackedField(buffer, pos, end, message, field_dict):
187+
def DecodePackedField(
188+
buffer, pos, end, message, field_dict, current_depth=0
189+
):
190+
del current_depth # unused
188191
value = field_dict.get(key)
189192
if value is None:
190193
value = field_dict.setdefault(key, new_default(message))
@@ -199,11 +202,15 @@ def DecodePackedField(buffer, pos, end, message, field_dict):
199202
del value[-1] # Discard corrupt value.
200203
raise _DecodeError('Packed element was truncated.')
201204
return pos
205+
202206
return DecodePackedField
203207
elif is_repeated:
204208
tag_bytes = encoder.TagBytes(field_number, wire_type)
205209
tag_len = len(tag_bytes)
206-
def DecodeRepeatedField(buffer, pos, end, message, field_dict):
210+
def DecodeRepeatedField(
211+
buffer, pos, end, message, field_dict, current_depth=0
212+
):
213+
del current_depth # unused
207214
value = field_dict.get(key)
208215
if value is None:
209216
value = field_dict.setdefault(key, new_default(message))
@@ -218,9 +225,12 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
218225
if new_pos > end:
219226
raise _DecodeError('Truncated message.')
220227
return new_pos
228+
221229
return DecodeRepeatedField
222230
else:
223-
def DecodeField(buffer, pos, end, message, field_dict):
231+
232+
def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
233+
del current_depth # unused
224234
(new_value, pos) = decode_value(buffer, pos)
225235
if pos > end:
226236
raise _DecodeError('Truncated message.')
@@ -229,6 +239,7 @@ def DecodeField(buffer, pos, end, message, field_dict):
229239
else:
230240
field_dict[key] = new_value
231241
return pos
242+
232243
return DecodeField
233244

234245
return SpecificDecoder
@@ -364,7 +375,9 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default,
364375
enum_type = key.enum_type
365376
if is_packed:
366377
local_DecodeVarint = _DecodeVarint
367-
def DecodePackedField(buffer, pos, end, message, field_dict):
378+
def DecodePackedField(
379+
buffer, pos, end, message, field_dict, current_depth=0
380+
):
368381
"""Decode serialized packed enum to its value and a new position.
369382
370383
Args:
@@ -377,6 +390,7 @@ def DecodePackedField(buffer, pos, end, message, field_dict):
377390
Returns:
378391
int, new position in serialized data.
379392
"""
393+
del current_depth # unused
380394
value = field_dict.get(key)
381395
if value is None:
382396
value = field_dict.setdefault(key, new_default(message))
@@ -407,11 +421,14 @@ def DecodePackedField(buffer, pos, end, message, field_dict):
407421
# pylint: enable=protected-access
408422
raise _DecodeError('Packed element was truncated.')
409423
return pos
424+
410425
return DecodePackedField
411426
elif is_repeated:
412427
tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT)
413428
tag_len = len(tag_bytes)
414-
def DecodeRepeatedField(buffer, pos, end, message, field_dict):
429+
def DecodeRepeatedField(
430+
buffer, pos, end, message, field_dict, current_depth=0
431+
):
415432
"""Decode serialized repeated enum to its value and a new position.
416433
417434
Args:
@@ -424,6 +441,7 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
424441
Returns:
425442
int, new position in serialized data.
426443
"""
444+
del current_depth # unused
427445
value = field_dict.get(key)
428446
if value is None:
429447
value = field_dict.setdefault(key, new_default(message))
@@ -446,9 +464,11 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
446464
if new_pos > end:
447465
raise _DecodeError('Truncated message.')
448466
return new_pos
467+
449468
return DecodeRepeatedField
450469
else:
451-
def DecodeField(buffer, pos, end, message, field_dict):
470+
471+
def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
452472
"""Decode serialized repeated enum to its value and a new position.
453473
454474
Args:
@@ -461,6 +481,7 @@ def DecodeField(buffer, pos, end, message, field_dict):
461481
Returns:
462482
int, new position in serialized data.
463483
"""
484+
del current_depth # unused
464485
value_start_pos = pos
465486
(enum_value, pos) = _DecodeSignedVarint32(buffer, pos)
466487
if pos > end:
@@ -480,6 +501,7 @@ def DecodeField(buffer, pos, end, message, field_dict):
480501
(tag_bytes, buffer[value_start_pos:pos].tobytes()))
481502
# pylint: enable=protected-access
482503
return pos
504+
483505
return DecodeField
484506

485507

@@ -538,7 +560,10 @@ def _ConvertToUnicode(memview):
538560
tag_bytes = encoder.TagBytes(field_number,
539561
wire_format.WIRETYPE_LENGTH_DELIMITED)
540562
tag_len = len(tag_bytes)
541-
def DecodeRepeatedField(buffer, pos, end, message, field_dict):
563+
def DecodeRepeatedField(
564+
buffer, pos, end, message, field_dict, current_depth=0
565+
):
566+
del current_depth # unused
542567
value = field_dict.get(key)
543568
if value is None:
544569
value = field_dict.setdefault(key, new_default(message))
@@ -553,9 +578,12 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
553578
if buffer[new_pos:pos] != tag_bytes or new_pos == end:
554579
# Prediction failed. Return.
555580
return new_pos
581+
556582
return DecodeRepeatedField
557583
else:
558-
def DecodeField(buffer, pos, end, message, field_dict):
584+
585+
def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
586+
del current_depth # unused
559587
(size, pos) = local_DecodeVarint(buffer, pos)
560588
new_pos = pos + size
561589
if new_pos > end:
@@ -565,6 +593,7 @@ def DecodeField(buffer, pos, end, message, field_dict):
565593
else:
566594
field_dict[key] = _ConvertToUnicode(buffer[pos:new_pos])
567595
return new_pos
596+
568597
return DecodeField
569598

570599

@@ -579,7 +608,10 @@ def BytesDecoder(field_number, is_repeated, is_packed, key, new_default,
579608
tag_bytes = encoder.TagBytes(field_number,
580609
wire_format.WIRETYPE_LENGTH_DELIMITED)
581610
tag_len = len(tag_bytes)
582-
def DecodeRepeatedField(buffer, pos, end, message, field_dict):
611+
def DecodeRepeatedField(
612+
buffer, pos, end, message, field_dict, current_depth=0
613+
):
614+
del current_depth # unused
583615
value = field_dict.get(key)
584616
if value is None:
585617
value = field_dict.setdefault(key, new_default(message))
@@ -594,9 +626,12 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
594626
if buffer[new_pos:pos] != tag_bytes or new_pos == end:
595627
# Prediction failed. Return.
596628
return new_pos
629+
597630
return DecodeRepeatedField
598631
else:
599-
def DecodeField(buffer, pos, end, message, field_dict):
632+
633+
def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
634+
del current_depth # unused
600635
(size, pos) = local_DecodeVarint(buffer, pos)
601636
new_pos = pos + size
602637
if new_pos > end:
@@ -606,6 +641,7 @@ def DecodeField(buffer, pos, end, message, field_dict):
606641
else:
607642
field_dict[key] = buffer[pos:new_pos].tobytes()
608643
return new_pos
644+
609645
return DecodeField
610646

611647

@@ -621,7 +657,9 @@ def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
621657
tag_bytes = encoder.TagBytes(field_number,
622658
wire_format.WIRETYPE_START_GROUP)
623659
tag_len = len(tag_bytes)
624-
def DecodeRepeatedField(buffer, pos, end, message, field_dict):
660+
def DecodeRepeatedField(
661+
buffer, pos, end, message, field_dict, current_depth=0
662+
):
625663
value = field_dict.get(key)
626664
if value is None:
627665
value = field_dict.setdefault(key, new_default(message))
@@ -630,7 +668,7 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
630668
if value is None:
631669
value = field_dict.setdefault(key, new_default(message))
632670
# Read sub-message.
633-
pos = value.add()._InternalParse(buffer, pos, end)
671+
pos = value.add()._InternalParse(buffer, pos, end, current_depth)
634672
# Read end tag.
635673
new_pos = pos+end_tag_len
636674
if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
@@ -640,19 +678,22 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
640678
if buffer[new_pos:pos] != tag_bytes or new_pos == end:
641679
# Prediction failed. Return.
642680
return new_pos
681+
643682
return DecodeRepeatedField
644683
else:
645-
def DecodeField(buffer, pos, end, message, field_dict):
684+
685+
def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
646686
value = field_dict.get(key)
647687
if value is None:
648688
value = field_dict.setdefault(key, new_default(message))
649689
# Read sub-message.
650-
pos = value._InternalParse(buffer, pos, end)
690+
pos = value._InternalParse(buffer, pos, end, current_depth)
651691
# Read end tag.
652692
new_pos = pos+end_tag_len
653693
if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
654694
raise _DecodeError('Missing group end tag.')
655695
return new_pos
696+
656697
return DecodeField
657698

658699

@@ -666,7 +707,9 @@ def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
666707
tag_bytes = encoder.TagBytes(field_number,
667708
wire_format.WIRETYPE_LENGTH_DELIMITED)
668709
tag_len = len(tag_bytes)
669-
def DecodeRepeatedField(buffer, pos, end, message, field_dict):
710+
def DecodeRepeatedField(
711+
buffer, pos, end, message, field_dict, current_depth=0
712+
):
670713
value = field_dict.get(key)
671714
if value is None:
672715
value = field_dict.setdefault(key, new_default(message))
@@ -677,7 +720,10 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
677720
if new_pos > end:
678721
raise _DecodeError('Truncated message.')
679722
# Read sub-message.
680-
if value.add()._InternalParse(buffer, pos, new_pos) != new_pos:
723+
if (
724+
value.add()._InternalParse(buffer, pos, new_pos, current_depth)
725+
!= new_pos
726+
):
681727
# The only reason _InternalParse would return early is if it
682728
# encountered an end-group tag.
683729
raise _DecodeError('Unexpected end-group tag.')
@@ -686,9 +732,11 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
686732
if buffer[new_pos:pos] != tag_bytes or new_pos == end:
687733
# Prediction failed. Return.
688734
return new_pos
735+
689736
return DecodeRepeatedField
690737
else:
691-
def DecodeField(buffer, pos, end, message, field_dict):
738+
739+
def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
692740
value = field_dict.get(key)
693741
if value is None:
694742
value = field_dict.setdefault(key, new_default(message))
@@ -698,11 +746,12 @@ def DecodeField(buffer, pos, end, message, field_dict):
698746
if new_pos > end:
699747
raise _DecodeError('Truncated message.')
700748
# Read sub-message.
701-
if value._InternalParse(buffer, pos, new_pos) != new_pos:
749+
if value._InternalParse(buffer, pos, new_pos, current_depth) != new_pos:
702750
# The only reason _InternalParse would return early is if it encountered
703751
# an end-group tag.
704752
raise _DecodeError('Unexpected end-group tag.')
705753
return new_pos
754+
706755
return DecodeField
707756

708757

@@ -851,7 +900,8 @@ def MapDecoder(field_descriptor, new_default, is_message_map):
851900
# Can't read _concrete_class yet; might not be initialized.
852901
message_type = field_descriptor.message_type
853902

854-
def DecodeMap(buffer, pos, end, message, field_dict):
903+
def DecodeMap(buffer, pos, end, message, field_dict, current_depth=0):
904+
del current_depth # Unused.
855905
submsg = message_type._concrete_class()
856906
value = field_dict.get(key)
857907
if value is None:
@@ -934,7 +984,7 @@ def _SkipGroup(buffer, pos, end):
934984
pos = new_pos
935985

936986

937-
def _DecodeUnknownFieldSet(buffer, pos, end_pos=None):
987+
def _DecodeUnknownFieldSet(buffer, pos, end_pos=None, current_depth=0):
938988
"""Decode UnknownFieldSet. Returns the UnknownFieldSet and new position."""
939989

940990
unknown_field_set = containers.UnknownFieldSet()
@@ -944,14 +994,16 @@ def _DecodeUnknownFieldSet(buffer, pos, end_pos=None):
944994
field_number, wire_type = wire_format.UnpackTag(tag)
945995
if wire_type == wire_format.WIRETYPE_END_GROUP:
946996
break
947-
(data, pos) = _DecodeUnknownField(buffer, pos, wire_type)
997+
(data, pos) = _DecodeUnknownField(buffer, pos, wire_type, current_depth)
948998
# pylint: disable=protected-access
949999
unknown_field_set._add(field_number, wire_type, data)
9501000

9511001
return (unknown_field_set, pos)
9521002

9531003

954-
def _DecodeUnknownField(buffer, pos, wire_type):
1004+
def _DecodeUnknownField(
1005+
buffer, pos, wire_type, current_depth=0
1006+
):
9551007
"""Decode a unknown field. Returns the UnknownField and new position."""
9561008

9571009
if wire_type == wire_format.WIRETYPE_VARINT:
@@ -965,7 +1017,7 @@ def _DecodeUnknownField(buffer, pos, wire_type):
9651017
data = buffer[pos:pos+size].tobytes()
9661018
pos += size
9671019
elif wire_type == wire_format.WIRETYPE_START_GROUP:
968-
(data, pos) = _DecodeUnknownFieldSet(buffer, pos)
1020+
(data, pos) = _DecodeUnknownFieldSet(buffer, pos, None, current_depth)
9691021
elif wire_type == wire_format.WIRETYPE_END_GROUP:
9701022
return (0, -1)
9711023
else:

python/google/protobuf/internal/message_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
cmp = lambda x, y: (x > y) - (x < y)
3131

3232
from google.protobuf.internal import api_implementation # pylint: disable=g-import-not-at-top
33+
from google.protobuf.internal import decoder
3334
from google.protobuf.internal import encoder
3435
from google.protobuf.internal import enum_type_wrapper
3536
from google.protobuf.internal import more_extensions_pb2

python/google/protobuf/internal/python_message.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1197,7 +1197,7 @@ def MergeFromString(self, serialized):
11971197
fields_by_tag = cls._fields_by_tag
11981198
message_set_decoders_by_tag = cls._message_set_decoders_by_tag
11991199

1200-
def InternalParse(self, buffer, pos, end):
1200+
def InternalParse(self, buffer, pos, end, current_depth=0):
12011201
"""Create a message from serialized bytes.
12021202
12031203
Args:
@@ -1247,10 +1247,13 @@ def InternalParse(self, buffer, pos, end):
12471247
else:
12481248
_MaybeAddDecoder(cls, field_des)
12491249
field_decoder = field_des._decoders[is_packed]
1250-
pos = field_decoder(buffer, new_pos, end, self, field_dict)
1250+
pos = field_decoder(
1251+
buffer, new_pos, end, self, field_dict, current_depth
1252+
)
12511253
if field_des.containing_oneof:
12521254
self._UpdateOneofState(field_des)
12531255
return pos
1256+
12541257
cls._InternalParse = InternalParse
12551258

12561259

python/google/protobuf/internal/self_recursive.proto

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,19 @@
55
// license that can be found in the LICENSE file or at
66
// https://developers.google.com/open-source/licenses/bsd
77

8-
syntax = "proto2";
8+
edition = "2023";
99

1010
package google.protobuf.python.internal;
1111

1212
message SelfRecursive {
13-
optional SelfRecursive sub = 1;
13+
SelfRecursive sub = 1;
14+
int32 i = 2;
1415
}
1516

1617
message IndirectRecursive {
17-
optional IntermediateRecursive intermediate = 1;
18+
IntermediateRecursive intermediate = 1;
1819
}
1920

2021
message IntermediateRecursive {
21-
optional IndirectRecursive indirect = 1;
22+
IndirectRecursive indirect = 1;
2223
}

0 commit comments

Comments
 (0)