diff --git a/python/google/protobuf/internal/json_format_test.py b/python/google/protobuf/internal/json_format_test.py
index 6a44d4c97db0a..aa56692e8e165 100755
--- a/python/google/protobuf/internal/json_format_test.py
+++ b/python/google/protobuf/internal/json_format_test.py
@@ -909,6 +909,27 @@ def testParseEnumValue(self):
         'for enum type protobuf_unittest.TestAllTypes.NestedEnum.',
         json_format.Parse, '{"optionalNestedEnum": 12345}', message)
 
+  def testParseUnknownEnumStringValueProto3(self):
+    message = json_format_proto3_pb2.TestMessage()
+    text = '{"enumValue": "UNKNOWN_STRING_VALUE"}'
+    json_format.Parse(text, message, ignore_unknown_fields=True)
+    # In proto3, there is no difference between the default value and 0.
+    self.assertEqual(message.enum_value, 0)
+
+  def testParseUnknownEnumStringValueProto2(self):
+    message = json_format_pb2.TestNumbers()
+    text = '{"a": "UNKNOWN_STRING_VALUE"}'
+    json_format.Parse(text, message, ignore_unknown_fields=True)
+    # In proto2 we can see that the field was not set at all.
+    self.assertFalse(message.HasField("a"))
+
+  def testParseUnknownEnumStringValueRepeatedProto3(self):
+    message = json_format_proto3_pb2.TestMessage()
+    text = '{"repeatedEnumValue": ["UNKNOWN_STRING_VALUE", "FOO", "BAR"]}'
+    json_format.Parse(text, message, ignore_unknown_fields=True)
+    # Only the unknown string value is dropped; two of three elements remain.
+    self.assertEqual(len(message.repeated_enum_value), 2)
+
   def testBytes(self):
     message = json_format_proto3_pb2.TestMessage()
     # Test url base64
diff --git a/python/google/protobuf/json_format.py b/python/google/protobuf/json_format.py
index 573cc0d3dec52..fd78185c32567 100644
--- a/python/google/protobuf/json_format.py
+++ b/python/google/protobuf/json_format.py
@@ -116,6 +116,50 @@ class ParseError(Error):
   """Thrown in case of parsing error."""
 
 
+class _UnknownEnumStringValueParseError(ParseError):
+  """Thrown if an unknown enum string value is encountered.
+
+  This exception never leaks outside of the module.
+  """
+
+
+class _MaybeSuppressUnknownEnumStringValueParseError(object):
+  """Context manager that optionally suppresses unknown enum string errors.
+
+  Example usage:
+
+    with _MaybeSuppressUnknownEnumStringValueParseError(True):
+      ...
+
+  If should_suppress is True, the _UnknownEnumStringValueParseError will be
+  ignored in the context body.
+
+  The motivation for the context manager is to avoid a bigger refactor that
+  would enable _ConvertScalarFieldValue to signal to the caller that the field
+  should be ignored.
+
+  We want to avoid a bigger refactor because we are maintaining a fork and we
+  want changes to be minimal to simplify merging with upstream.
+  """
+
+  def __init__(self, should_suppress):
+    self.should_suppress = should_suppress
+
+  def __enter__(self):
+    pass
+
+  def __exit__(self, exc_type, exc_value, traceback):
+    # The return value from __exit__ indicates whether an exception raised in
+    # the context body should be suppressed. We suppress
+    # _UnknownEnumStringValueParseError (and subclasses) if should_suppress is
+    # set. exc_type is None when the body completed without raising.
+    # See context manager docs:
+    # https://docs.python.org/3/reference/datamodel.html#with-statement-context-managers
+    return bool(
+        self.should_suppress and exc_type is not None and
+        issubclass(exc_type, _UnknownEnumStringValueParseError))
+
+
 def MessageToJson(
     message,
     including_default_value_fields=False,
@@ -598,8 +642,9 @@ def _ConvertFieldValuePair(self, js, message):
               if item is None:
                 raise ParseError('null is not allowed to be used as an element'
                                  ' in a repeated field.')
-              getattr(message, field.name).append(
-                  _ConvertScalarFieldValue(item, field))
+              with _MaybeSuppressUnknownEnumStringValueParseError(self.ignore_unknown_fields):
+                getattr(message, field.name).append(
+                    _ConvertScalarFieldValue(item, field))
         elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
           if field.is_extension:
             sub_message = message.Extensions[field]
@@ -608,10 +653,11 @@ def _ConvertFieldValuePair(self, js, message):
           sub_message.SetInParent()
           self.ConvertMessage(value, sub_message)
         else:
-          if field.is_extension:
-            message.Extensions[field] = _ConvertScalarFieldValue(value, field)
-          else:
-            setattr(message, field.name, _ConvertScalarFieldValue(value, field))
+          with _MaybeSuppressUnknownEnumStringValueParseError(self.ignore_unknown_fields):
+            if field.is_extension:
+              message.Extensions[field] = _ConvertScalarFieldValue(value, field)
+            else:
+              setattr(message, field.name, _ConvertScalarFieldValue(value, field))
       except ParseError as e:
         if field and field.containing_oneof is None:
           raise ParseError('Failed to parse {0} field: {1}.'.format(name, e))
@@ -698,7 +744,8 @@ def _ConvertStructMessage(self, value, message):
   def _ConvertWrapperMessage(self, value, message):
     """Convert a JSON representation into Wrapper message."""
     field = message.DESCRIPTOR.fields_by_name['value']
-    setattr(message, 'value', _ConvertScalarFieldValue(value, field))
+    with _MaybeSuppressUnknownEnumStringValueParseError(self.ignore_unknown_fields):
+      setattr(message, 'value', _ConvertScalarFieldValue(value, field))
 
   def _ConvertMapFieldValue(self, value, message, field):
     """Convert map field value for a message map field.
@@ -718,13 +765,14 @@ def _ConvertMapFieldValue(self, value, message, field):
     key_field = field.message_type.fields_by_name['key']
     value_field = field.message_type.fields_by_name['value']
     for key in value:
       key_value = _ConvertScalarFieldValue(key, key_field, True)
       if value_field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
         self.ConvertMessage(value[key], getattr(
             message, field.name)[key_value])
       else:
-        getattr(message, field.name)[key_value] = _ConvertScalarFieldValue(
-            value[key], value_field)
+        with _MaybeSuppressUnknownEnumStringValueParseError(self.ignore_unknown_fields):
+          getattr(message, field.name)[key_value] = _ConvertScalarFieldValue(
+              value[key], value_field)
 
 
 def _ConvertScalarFieldValue(value, field, require_str=False):
@@ -740,6 +788,7 @@ def _ConvertScalarFieldValue(value, field, require_str=False):
 
   Raises:
     ParseError: In case of convert problems.
+    _UnknownEnumStringValueParseError: If an unknown enum string value is encountered during parsing.
   """
   if field.cpp_type in _INT_TYPES:
     return _ConvertInteger(value)
@@ -770,7 +819,9 @@ def _ConvertScalarFieldValue(value, field, require_str=False):
         number = int(value)
         enum_value = field.enum_type.values_by_number.get(number, None)
       except ValueError:
-        raise ParseError('Invalid enum value {0} for enum type {1}.'.format(
+        # The ValueError will be raised by the conversion to int.
+        # That means that here we know that we have an unknown enum string value.
+        raise _UnknownEnumStringValueParseError('Invalid enum value {0} for enum type {1}.'.format(
             value, field.enum_type.full_name))
       if enum_value is None:
         if field.file.syntax == 'proto3':