Skip to content

Commit f327623

Browse files
committed
update
1 parent 5187eb1 commit f327623

2 files changed

Lines changed: 199 additions & 12 deletions

File tree

parquet-avro/src/main/java/org/apache/parquet/avro/AvroSchemaConverter.java

Lines changed: 89 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
import java.util.Collections;
5252
import java.util.HashMap;
5353
import java.util.HashSet;
54+
import java.util.IdentityHashMap;
5455
import java.util.List;
5556
import java.util.Map;
5657
import java.util.Optional;
@@ -82,12 +83,26 @@ public class AvroSchemaConverter {
8283

8384
public static final String ADD_LIST_ELEMENT_RECORDS = "parquet.avro.add-list-element-records";
8485
private static final boolean ADD_LIST_ELEMENT_RECORDS_DEFAULT = true;
86+
public static final String AVRO_MAX_RECURSION = "parquet.avro.max-recursion";
87+
private static final int AVRO_MAX_RECURSION_DEFAULT = 10;
8588

8689
private final boolean assumeRepeatedIsListElement;
8790
private final boolean writeOldListStructure;
8891
private final boolean writeParquetUUID;
8992
private final boolean readInt96AsFixed;
9093
private final Set<String> pathsToInt96;
94+
private final int maxRecursion;
95+
96+
/**
97+
* Sets the maximum recursion depth for recursive schemas by storing it in the given
* configuration under {@link #AVRO_MAX_RECURSION} ({@code parquet.avro.max-recursion}).
98+
*
* <p>A converter created from that configuration reads the value back, falling back to a
* default of 10 when the key is not set.
*
99+
* @param config The Hadoop configuration to be updated.
100+
* @param maxRecursion The number of times the same record, enum or fixed schema may occur along
101+
* one nesting path; meeting it again after that makes the conversion fail with an
* UnsupportedOperationException instead of producing a Parquet schema.
102+
*/
103+
public static void setMaxRecursion(Configuration config, int maxRecursion) {
104+
config.setInt(AVRO_MAX_RECURSION, maxRecursion);
105+
}
91106

92107
public AvroSchemaConverter() {
93108
this(ADD_LIST_ELEMENT_RECORDS_DEFAULT, READ_INT96_AS_FIXED_DEFAULT);
@@ -106,6 +121,7 @@ public AvroSchemaConverter() {
106121
this.writeParquetUUID = WRITE_PARQUET_UUID_DEFAULT;
107122
this.readInt96AsFixed = readInt96AsFixed;
108123
this.pathsToInt96 = Collections.emptySet();
124+
this.maxRecursion = AVRO_MAX_RECURSION_DEFAULT;
109125
}
110126

111127
public AvroSchemaConverter(Configuration conf) {
@@ -118,6 +134,7 @@ public AvroSchemaConverter(ParquetConfiguration conf) {
118134
this.writeParquetUUID = conf.getBoolean(WRITE_PARQUET_UUID, WRITE_PARQUET_UUID_DEFAULT);
119135
this.readInt96AsFixed = conf.getBoolean(READ_INT96_AS_FIXED, READ_INT96_AS_FIXED_DEFAULT);
120136
this.pathsToInt96 = new HashSet<>(Arrays.asList(conf.getStrings(WRITE_FIXED_AS_INT96, new String[0])));
137+
this.maxRecursion = conf.getInt(AVRO_MAX_RECURSION, AVRO_MAX_RECURSION_DEFAULT);
121138
}
122139

123140
/**
@@ -150,16 +167,23 @@ public MessageType convert(Schema avroSchema) {
150167
if (!avroSchema.getType().equals(Schema.Type.RECORD)) {
151168
throw new IllegalArgumentException("Avro schema must be a record.");
152169
}
153-
return new MessageType(avroSchema.getFullName(), convertFields(avroSchema.getFields(), ""));
170+
return new MessageType(
171+
avroSchema.getFullName(),
172+
convertFields(avroSchema.getFields(), "", new IdentityHashMap<Schema, Integer>()));
154173
}
155174

156175
/**
 * Converts the given Avro fields without any inherited recursion state.
 *
 * <p>Delegates to the three-argument overload with a fresh, empty identity map, so recursion
 * counting starts from zero for every call made through this overload.
 *
 * @param fields the Avro fields to convert
 * @param schemaPath the path of the enclosing schema, used as the prefix for each field's path
 * @return the converted Parquet types
 */
private List<Type> convertFields(List<Schema.Field> fields, String schemaPath) {
176+
return convertFields(fields, schemaPath, new IdentityHashMap<Schema, Integer>());
177+
}
178+
179+
/**
 * Converts each Avro field into a Parquet type, skipping fields whose schema type is NULL.
 *
 * @param fields the Avro fields to convert
 * @param schemaPath the path of the enclosing schema; each field is converted under
 *     {@code appendPath(schemaPath, field.name())}
 * @param seenSchemas per-path occurrence counts of already visited schemas, keyed by object
 *     identity; handed on unchanged to the field conversion for recursion detection
 * @return the converted types, in field order, without entries for NULL-typed fields
 */
private List<Type> convertFields(
180+
List<Schema.Field> fields, String schemaPath, IdentityHashMap<Schema, Integer> seenSchemas) {
157181
List<Type> types = new ArrayList<Type>();
158182
for (Schema.Field field : fields) {
159183
if (field.schema().getType().equals(Schema.Type.NULL)) {
160184
continue; // Avro nulls are not encoded, unless they are null unions
161185
}
162-
types.add(convertField(field, appendPath(schemaPath, field.name())));
186+
types.add(convertField(field, appendPath(schemaPath, field.name()), seenSchemas));
163187
}
164188
return types;
165189
}
@@ -168,11 +192,37 @@ private Type convertField(String fieldName, Schema schema, String schemaPath) {
168192
return convertField(fieldName, schema, Type.Repetition.REQUIRED, schemaPath);
169193
}
170194

195+
/**
 * Converts a schema into a REQUIRED Parquet field while carrying the recursion state along.
 *
 * @param fieldName name of the resulting Parquet field
 * @param schema the Avro schema to convert
 * @param schemaPath path of the field within the overall schema
 * @param seenSchemas per-path occurrence counts of already visited schemas, keyed by identity
 * @return the converted Parquet type with repetition {@code REQUIRED}
 */
private Type convertField(
196+
String fieldName, Schema schema, String schemaPath, IdentityHashMap<Schema, Integer> seenSchemas) {
197+
return convertField(fieldName, schema, Type.Repetition.REQUIRED, schemaPath, seenSchemas);
198+
}
199+
171200
/**
 * Converts a schema with an explicit repetition and no inherited recursion state.
 *
 * <p>Delegates to the five-argument overload with a fresh, empty identity map, so schemas
 * visited by an outer conversion are not counted for calls made through this overload.
 *
 * @param fieldName name of the resulting Parquet field
 * @param schema the Avro schema to convert
 * @param repetition repetition of the resulting Parquet field
 * @param schemaPath path of the field within the overall schema
 * @return the converted Parquet type
 */
@SuppressWarnings("deprecation")
172201
private Type convertField(String fieldName, Schema schema, Type.Repetition repetition, String schemaPath) {
173-
Types.PrimitiveBuilder<PrimitiveType> builder;
202+
return convertField(fieldName, schema, repetition, schemaPath, new IdentityHashMap<Schema, Integer>());
203+
}
204+
205+
@SuppressWarnings("deprecation")
206+
private Type convertField(
207+
String fieldName,
208+
Schema schema,
209+
Type.Repetition repetition,
210+
String schemaPath,
211+
IdentityHashMap<Schema, Integer> seenSchemas) {
174212
Schema.Type type = schema.getType();
175213
LogicalType logicalType = schema.getLogicalType();
214+
215+
if (type.equals(Schema.Type.RECORD) || type.equals(Schema.Type.ENUM) || type.equals(Schema.Type.FIXED)) {
216+
Integer depth = seenSchemas.get(schema);
217+
if (depth != null && depth >= maxRecursion) {
218+
throw new UnsupportedOperationException("Recursive Avro schemas are not supported by parquet-avro: "
219+
+ schema.getFullName() + " (exceeded maximum recursion depth of " + maxRecursion + ")");
220+
}
221+
seenSchemas = new IdentityHashMap<>(seenSchemas);
222+
seenSchemas.put(schema, depth == null ? 1 : depth + 1);
223+
}
224+
225+
Types.PrimitiveBuilder<PrimitiveType> builder;
176226
if (type.equals(Schema.Type.BOOLEAN)) {
177227
builder = Types.primitive(BOOLEAN, repetition);
178228
} else if (type.equals(Schema.Type.INT)) {
@@ -195,21 +245,24 @@ private Type convertField(String fieldName, Schema schema, Type.Repetition repet
195245
builder = Types.primitive(BINARY, repetition).as(stringType());
196246
}
197247
} else if (type.equals(Schema.Type.RECORD)) {
198-
return new GroupType(repetition, fieldName, convertFields(schema.getFields(), schemaPath));
248+
return new GroupType(repetition, fieldName, convertFields(schema.getFields(), schemaPath, seenSchemas));
199249
} else if (type.equals(Schema.Type.ENUM)) {
200250
builder = Types.primitive(BINARY, repetition).as(enumType());
201251
} else if (type.equals(Schema.Type.ARRAY)) {
202252
if (writeOldListStructure) {
203253
return ConversionPatterns.listType(
204-
repetition, fieldName, convertField("array", schema.getElementType(), REPEATED, schemaPath));
254+
repetition,
255+
fieldName,
256+
convertField("array", schema.getElementType(), REPEATED, schemaPath, seenSchemas));
205257
} else {
206258
return ConversionPatterns.listOfElements(
207259
repetition,
208260
fieldName,
209-
convertField(AvroWriteSupport.LIST_ELEMENT_NAME, schema.getElementType(), schemaPath));
261+
convertField(
262+
AvroWriteSupport.LIST_ELEMENT_NAME, schema.getElementType(), schemaPath, seenSchemas));
210263
}
211264
} else if (type.equals(Schema.Type.MAP)) {
212-
Type valType = convertField("value", schema.getValueType(), schemaPath);
265+
Type valType = convertField("value", schema.getValueType(), schemaPath, seenSchemas);
213266
// avro map key type is always string
214267
return ConversionPatterns.stringKeyMapType(repetition, fieldName, valType);
215268
} else if (type.equals(Schema.Type.FIXED)) {
@@ -223,7 +276,7 @@ private Type convertField(String fieldName, Schema schema, Type.Repetition repet
223276
builder = Types.primitive(FIXED_LEN_BYTE_ARRAY, repetition).length(schema.getFixedSize());
224277
}
225278
} else if (type.equals(Schema.Type.UNION)) {
226-
return convertUnion(fieldName, schema, repetition, schemaPath);
279+
return convertUnion(fieldName, schema, repetition, schemaPath, seenSchemas);
227280
} else {
228281
throw new UnsupportedOperationException("Cannot convert Avro type " + type);
229282
}
@@ -246,6 +299,15 @@ private Type convertField(String fieldName, Schema schema, Type.Repetition repet
246299
}
247300

248301
/**
 * Converts an Avro union without any inherited recursion state.
 *
 * <p>Delegates to the five-argument overload with a fresh, empty identity map.
 *
 * @param fieldName name of the resulting Parquet field
 * @param schema the Avro union schema
 * @param repetition repetition of the resulting Parquet field
 * @param schemaPath path of the field within the overall schema
 * @return the converted Parquet type
 */
private Type convertUnion(String fieldName, Schema schema, Type.Repetition repetition, String schemaPath) {
302+
return convertUnion(fieldName, schema, repetition, schemaPath, new IdentityHashMap<Schema, Integer>());
303+
}
304+
305+
private Type convertUnion(
306+
String fieldName,
307+
Schema schema,
308+
Type.Repetition repetition,
309+
String schemaPath,
310+
IdentityHashMap<Schema, Integer> seenSchemas) {
249311
List<Schema> nonNullSchemas = new ArrayList<Schema>(schema.getTypes().size());
250312
// Found any schemas in the union? Required for the edge case, where the union contains only a single type.
251313
boolean foundNullSchema = false;
@@ -267,20 +329,31 @@ private Type convertUnion(String fieldName, Schema schema, Type.Repetition repet
267329

268330
case 1:
269331
return foundNullSchema
270-
? convertField(fieldName, nonNullSchemas.get(0), repetition, schemaPath)
271-
: convertUnionToGroupType(fieldName, repetition, nonNullSchemas, schemaPath);
332+
? convertField(fieldName, nonNullSchemas.get(0), repetition, schemaPath, seenSchemas)
333+
: convertUnionToGroupType(fieldName, repetition, nonNullSchemas, schemaPath, seenSchemas);
272334

273335
default: // complex union type
274-
return convertUnionToGroupType(fieldName, repetition, nonNullSchemas, schemaPath);
336+
return convertUnionToGroupType(fieldName, repetition, nonNullSchemas, schemaPath, seenSchemas);
275337
}
276338
}
277339

278340
/**
 * Converts the non-null branches of a union into a group without inherited recursion state.
 *
 * <p>Delegates to the five-argument overload with a fresh, empty identity map.
 *
 * @param fieldName name of the resulting group
 * @param repetition repetition of the resulting group
 * @param nonNullSchemas the union branches to convert, with any null branch already removed
 * @param schemaPath path of the field within the overall schema
 * @return the group type holding one member per branch
 */
private Type convertUnionToGroupType(
279341
String fieldName, Type.Repetition repetition, List<Schema> nonNullSchemas, String schemaPath) {
342+
return convertUnionToGroupType(
343+
fieldName, repetition, nonNullSchemas, schemaPath, new IdentityHashMap<Schema, Integer>());
344+
}
345+
346+
/**
 * Converts the non-null branches of a union into a Parquet group with one field per branch.
 *
 * <p>Branches are converted in list order as OPTIONAL fields named {@code member0},
 * {@code member1}, and so on. All branches are converted under the same {@code schemaPath}
 * and receive the same {@code seenSchemas} map.
 *
 * @param fieldName name of the resulting group
 * @param repetition repetition of the resulting group
 * @param nonNullSchemas the union branches to convert
 * @param schemaPath path of the field within the overall schema
 * @param seenSchemas per-path occurrence counts of already visited schemas, keyed by identity
 * @return a group type containing the converted branches
 */
private Type convertUnionToGroupType(
347+
String fieldName,
348+
Type.Repetition repetition,
349+
List<Schema> nonNullSchemas,
350+
String schemaPath,
351+
IdentityHashMap<Schema, Integer> seenSchemas) {
280352
List<Type> unionTypes = new ArrayList<Type>(nonNullSchemas.size());
281353
int index = 0;
282354
for (Schema childSchema : nonNullSchemas) {
283-
unionTypes.add(convertField("member" + index++, childSchema, Type.Repetition.OPTIONAL, schemaPath));
355+
unionTypes.add(
356+
convertField("member" + index++, childSchema, Type.Repetition.OPTIONAL, schemaPath, seenSchemas));
284357
}
285358
return new GroupType(repetition, fieldName, unionTypes);
286359
}
@@ -289,6 +362,10 @@ private Type convertField(Schema.Field field, String schemaPath) {
289362
return convertField(field.name(), field.schema(), schemaPath);
290363
}
291364

365+
/**
 * Converts an Avro field using its own name and schema, carrying the recursion state along.
 *
 * @param field the Avro field to convert
 * @param schemaPath path of the field within the overall schema
 * @param seenSchemas per-path occurrence counts of already visited schemas, keyed by identity
 * @return the converted Parquet type
 */
private Type convertField(Schema.Field field, String schemaPath, IdentityHashMap<Schema, Integer> seenSchemas) {
366+
return convertField(field.name(), field.schema(), schemaPath, seenSchemas);
367+
}
368+
292369
/**
 * Converts a Parquet message type into an Avro schema.
 *
 * <p>The result is built from the message's name and its top-level fields.
 * NOTE(review): the fresh {@code HashMap} passed as third argument is consumed by a
 * {@code convertFields} overload that is not visible here; confirm its role there.
 *
 * @param parquetSchema the Parquet schema to convert
 * @return the corresponding Avro schema
 */
public Schema convert(MessageType parquetSchema) {
293370
return convertFields(parquetSchema.getName(), parquetSchema.getFields(), new HashMap<>());
294371
}

parquet-avro/src/test/java/org/apache/parquet/avro/TestAvroSchemaConverter.java

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -965,6 +965,116 @@ public void testAvroFixed12AsParquetInt96Type() throws Exception {
965965
() -> new AvroSchemaConverter(conf).convert(schema));
966966
}
967967

968+
/**
 * A record that refers to itself through a nullable array of nullable {@code Node} items must
 * be rejected by a default-configured converter with an UnsupportedOperationException.
 */
@Test
969+
public void testRecursiveSchemaThrowsException() {
970+
String recursiveSchemaJson = "{"
971+
+ "\"type\": \"record\", \"name\": \"Node\", \"fields\": ["
972+
+ " {\"name\": \"value\", \"type\": \"int\"},"
973+
+ " {\"name\": \"children\", \"type\": ["
974+
+ " \"null\", {"
975+
+ " \"type\": \"array\", \"items\": [\"null\", \"Node\"]"
976+
+ " }"
977+
+ " ], \"default\": null}"
978+
+ "]}";
979+
980+
Schema recursiveSchema = new Schema.Parser().parse(recursiveSchemaJson);
981+
982+
assertThrows(
983+
"Recursive Avro schema should throw UnsupportedOperationException",
984+
UnsupportedOperationException.class,
985+
() -> new AvroSchemaConverter().convert(recursiveSchema));
986+
}
987+
988+
/**
 * Indirect recursion must be rejected as well: {@code ObjXX} holds a list of
 * {@code ObjStructAdd}, which holds a list of {@code ObjStructAddFld}, whose {@code ref_val}
 * field refers back to {@code ObjStructAdd}. A default-configured converter is expected to
 * throw an UnsupportedOperationException for this schema.
 */
@Test
989+
public void testRecursiveSchemaFromGitHubIssue() {
990+
String issueSchemaJson = "{"
991+
+ "\"type\": \"record\", \"name\": \"ObjXX\", \"fields\": ["
992+
+ " {\"name\": \"id\", \"type\": [\"null\", \"long\"], \"default\": null},"
993+
+ " {\"name\": \"struct_add_list\", \"type\": [\"null\", {"
994+
+ " \"type\": \"array\", \"items\": [\"null\", {"
995+
+ " \"type\": \"record\", \"name\": \"ObjStructAdd\", \"fields\": ["
996+
+ " {\"name\": \"name\", \"type\": [\"null\", \"string\"], \"default\": null},"
997+
+ " {\"name\": \"fld_list\", \"type\": [\"null\", {"
998+
+ " \"type\": \"array\", \"items\": [\"null\", {"
999+
+ " \"type\": \"record\", \"name\": \"ObjStructAddFld\", \"fields\": ["
1000+
+ " {\"name\": \"name\", \"type\": [\"null\", \"string\"], \"default\": null},"
1001+
+ " {\"name\": \"ref_val\", \"type\": [\"null\", \"ObjStructAdd\"], \"default\": null}"
1002+
+ " ]"
1003+
+ " }]"
1004+
+ " }], \"default\": null}"
1005+
+ " ]"
1006+
+ " }]"
1007+
+ " }], \"default\": null},"
1008+
+ " {\"name\": \"kafka_timestamp\", \"type\": {\"type\": \"long\", \"logicalType\": \"timestamp-millis\"}}"
1009+
+ "]}";
1010+
1011+
Schema issueSchema = new Schema.Parser().parse(issueSchemaJson);
1012+
1013+
assertThrows(
1014+
"Schema should throw UnsupportedOperationException",
1015+
UnsupportedOperationException.class,
1016+
() -> new AvroSchemaConverter().convert(issueSchema));
1017+
}
1018+
1019+
/**
 * Checks the wording of the failure for a directly self-referencing record: the message must
 * contain the fixed "Recursive Avro schemas are not supported" text, the name of the offending
 * schema, and the phrase "maximum recursion depth".
 */
@Test
1020+
public void testRecursiveSchemaErrorMessage() {
1021+
String recursiveSchemaJson = "{"
1022+
+ "\"type\": \"record\", \"name\": \"TestRecord\", \"fields\": ["
1023+
+ " {\"name\": \"self\", \"type\": [\"null\", \"TestRecord\"], \"default\": null}"
1024+
+ "]}";
1025+
1026+
Schema recursiveSchema = new Schema.Parser().parse(recursiveSchemaJson);
1027+
1028+
try {
1029+
new AvroSchemaConverter().convert(recursiveSchema);
1030+
Assert.fail("Expected UnsupportedOperationException");
1031+
} catch (UnsupportedOperationException e) {
1032+
String message = e.getMessage();
1033+
Assert.assertTrue(
1034+
"Error message should mention recursion",
1035+
message.contains("Recursive Avro schemas are not supported"));
1036+
Assert.assertTrue("Error message should mention schema name", message.contains("TestRecord"));
1037+
Assert.assertTrue(
1038+
"Error message should mention max recursion depth", message.contains("maximum recursion depth"));
1039+
}
1040+
}
1041+
1042+
@Test
1043+
public void testConfigurableMaxRecursion() {
1044+
String recursiveSchemaJson = "{"
1045+
+ "\"type\": \"record\", \"name\": \"Node\", \"fields\": ["
1046+
+ " {\"name\": \"child\", \"type\": [\"null\", \"Node\"], \"default\": null}"
1047+
+ "]}";
1048+
1049+
Schema recursiveSchema = new Schema.Parser().parse(recursiveSchemaJson);
1050+
Configuration conf = new Configuration();
1051+
1052+
AvroSchemaConverter.setMaxRecursion(conf, 1);
1053+
assertThrows(
1054+
"Should fail with max recursion 1",
1055+
UnsupportedOperationException.class,
1056+
() -> new AvroSchemaConverter(conf).convert(recursiveSchema));
1057+
1058+
AvroSchemaConverter.setMaxRecursion(conf, 5);
1059+
assertThrows(
1060+
"Should fail with max recursion 5",
1061+
UnsupportedOperationException.class,
1062+
() -> new AvroSchemaConverter(conf).convert(recursiveSchema));
1063+
}
1064+
1065+
/**
 * Guards against false positives: four distinct records nested inside one another
 * (Root, Level1, Level2, Level3) contain no self-reference, so a default-configured converter
 * must convert them and keep the root record's name as the message name.
 */
@Test
1066+
public void testDeeplyNestedNonRecursiveSchema() {
1067+
Schema level3 = record("Level3", field("value", primitive(STRING)));
1068+
Schema level2 = record("Level2", field("level3", level3));
1069+
Schema level1 = record("Level1", field("level2", level2));
1070+
Schema rootSchema = record("Root", field("level1", level1));
1071+
1072+
AvroSchemaConverter converter = new AvroSchemaConverter();
1073+
MessageType result = converter.convert(rootSchema);
1074+
Assert.assertNotNull("Non-recursive deep schema should convert successfully", result);
1075+
Assert.assertEquals("Root schema name should be preserved", "Root", result.getName());
1076+
}
1077+
9681078
/**
 * Wraps a schema in a two-branch union whose first branch is NULL and whose second branch is
 * the given schema.
 *
 * @param original the schema to make optional
 * @return the union schema {@code [null, original]}
 */
public static Schema optional(Schema original) {
9691079
return Schema.createUnion(Lists.newArrayList(Schema.create(Schema.Type.NULL), original));
9701080
}

0 commit comments

Comments
 (0)