5151import java .util .Collections ;
5252import java .util .HashMap ;
5353import java .util .HashSet ;
54+ import java .util .IdentityHashMap ;
5455import java .util .List ;
5556import java .util .Map ;
5657import java .util .Optional ;
@@ -82,12 +83,26 @@ public class AvroSchemaConverter {
8283
8384 public static final String ADD_LIST_ELEMENT_RECORDS = "parquet.avro.add-list-element-records" ;
8485 private static final boolean ADD_LIST_ELEMENT_RECORDS_DEFAULT = true ;
// Configuration key for the recursion limit: written by setMaxRecursion and read with getInt
// when a converter is constructed from a configuration.
86+ public static final String AVRO_MAX_RECURSION = "parquet.avro.max-recursion" ;
// Limit used when the key is not set, and by the constructor that takes no configuration.
87+ private static final int AVRO_MAX_RECURSION_DEFAULT = 10 ;
8588
8689 private final boolean assumeRepeatedIsListElement ;
8790 private final boolean writeOldListStructure ;
8891 private final boolean writeParquetUUID ;
8992 private final boolean readInt96AsFixed ;
9093 private final Set <String > pathsToInt96 ;
// Threshold checked in convertField: once the same RECORD/ENUM/FIXED schema instance has been
// entered this many times on the current path, conversion throws UnsupportedOperationException.
94+ private final int maxRecursion ;
95+
96+ /**
97+ * Sets the maximum recursion depth for recursive schemas.
98+ *
99+ * @param config The hadoop configuration to be updated.
100+ * @param maxRecursion The maximum recursion depth schemas are allowed to go before terminating
101+ * with an UnsupportedOperationException instead of their actual schema.
102+ */
103+ public static void setMaxRecursion (Configuration config , int maxRecursion ) {
104+ config .setInt (AVRO_MAX_RECURSION , maxRecursion );
105+ }
91106
92107 public AvroSchemaConverter () {
93108 this (ADD_LIST_ELEMENT_RECORDS_DEFAULT , READ_INT96_AS_FIXED_DEFAULT );
@@ -106,6 +121,7 @@ public AvroSchemaConverter() {
106121 this .writeParquetUUID = WRITE_PARQUET_UUID_DEFAULT ;
107122 this .readInt96AsFixed = readInt96AsFixed ;
108123 this .pathsToInt96 = Collections .emptySet ();
124+ this .maxRecursion = AVRO_MAX_RECURSION_DEFAULT ;
109125 }
110126
111127 public AvroSchemaConverter (Configuration conf ) {
@@ -118,6 +134,7 @@ public AvroSchemaConverter(ParquetConfiguration conf) {
118134 this .writeParquetUUID = conf .getBoolean (WRITE_PARQUET_UUID , WRITE_PARQUET_UUID_DEFAULT );
119135 this .readInt96AsFixed = conf .getBoolean (READ_INT96_AS_FIXED , READ_INT96_AS_FIXED_DEFAULT );
120136 this .pathsToInt96 = new HashSet <>(Arrays .asList (conf .getStrings (WRITE_FIXED_AS_INT96 , new String [0 ])));
137+ this .maxRecursion = conf .getInt (AVRO_MAX_RECURSION , AVRO_MAX_RECURSION_DEFAULT );
121138 }
122139
123140 /**
@@ -150,16 +167,23 @@ public MessageType convert(Schema avroSchema) {
150167 if (!avroSchema .getType ().equals (Schema .Type .RECORD )) {
151168 throw new IllegalArgumentException ("Avro schema must be a record." );
152169 }
153- return new MessageType (avroSchema .getFullName (), convertFields (avroSchema .getFields (), "" ));
170+ return new MessageType (
171+ avroSchema .getFullName (),
172+ convertFields (avroSchema .getFields (), "" , new IdentityHashMap <Schema , Integer >()));
154173 }
155174
156175 private List <Type > convertFields (List <Schema .Field > fields , String schemaPath ) {
176+ return convertFields (fields , schemaPath , new IdentityHashMap <Schema , Integer >());
177+ }
178+
179+ private List <Type > convertFields (
180+ List <Schema .Field > fields , String schemaPath , IdentityHashMap <Schema , Integer > seenSchemas ) {
157181 List <Type > types = new ArrayList <Type >();
158182 for (Schema .Field field : fields ) {
159183 if (field .schema ().getType ().equals (Schema .Type .NULL )) {
160184 continue ; // Avro nulls are not encoded, unless they are null unions
161185 }
162- types .add (convertField (field , appendPath (schemaPath , field .name ())));
186+ types .add (convertField (field , appendPath (schemaPath , field .name ()), seenSchemas ));
163187 }
164188 return types ;
165189 }
@@ -168,11 +192,37 @@ private Type convertField(String fieldName, Schema schema, String schemaPath) {
168192 return convertField (fieldName , schema , Type .Repetition .REQUIRED , schemaPath );
169193 }
170194
195+ private Type convertField (
196+ String fieldName , Schema schema , String schemaPath , IdentityHashMap <Schema , Integer > seenSchemas ) {
197+ return convertField (fieldName , schema , Type .Repetition .REQUIRED , schemaPath , seenSchemas );
198+ }
199+
171200 @ SuppressWarnings ("deprecation" )
172201 private Type convertField (String fieldName , Schema schema , Type .Repetition repetition , String schemaPath ) {
173- Types .PrimitiveBuilder <PrimitiveType > builder ;
202+ return convertField (fieldName , schema , repetition , schemaPath , new IdentityHashMap <Schema , Integer >());
203+ }
204+
205+ @ SuppressWarnings ("deprecation" )
206+ private Type convertField (
207+ String fieldName ,
208+ Schema schema ,
209+ Type .Repetition repetition ,
210+ String schemaPath ,
211+ IdentityHashMap <Schema , Integer > seenSchemas ) {
174212 Schema .Type type = schema .getType ();
175213 LogicalType logicalType = schema .getLogicalType ();
214+
215+ if (type .equals (Schema .Type .RECORD ) || type .equals (Schema .Type .ENUM ) || type .equals (Schema .Type .FIXED )) {
216+ Integer depth = seenSchemas .get (schema );
217+ if (depth != null && depth >= maxRecursion ) {
218+ throw new UnsupportedOperationException ("Recursive Avro schemas are not supported by parquet-avro: "
219+ + schema .getFullName () + " (exceeded maximum recursion depth of " + maxRecursion + ")" );
220+ }
221+ seenSchemas = new IdentityHashMap <>(seenSchemas );
222+ seenSchemas .put (schema , depth == null ? 1 : depth + 1 );
223+ }
224+
225+ Types .PrimitiveBuilder <PrimitiveType > builder ;
176226 if (type .equals (Schema .Type .BOOLEAN )) {
177227 builder = Types .primitive (BOOLEAN , repetition );
178228 } else if (type .equals (Schema .Type .INT )) {
@@ -195,21 +245,24 @@ private Type convertField(String fieldName, Schema schema, Type.Repetition repet
195245 builder = Types .primitive (BINARY , repetition ).as (stringType ());
196246 }
197247 } else if (type .equals (Schema .Type .RECORD )) {
198- return new GroupType (repetition , fieldName , convertFields (schema .getFields (), schemaPath ));
248+ return new GroupType (repetition , fieldName , convertFields (schema .getFields (), schemaPath , seenSchemas ));
199249 } else if (type .equals (Schema .Type .ENUM )) {
200250 builder = Types .primitive (BINARY , repetition ).as (enumType ());
201251 } else if (type .equals (Schema .Type .ARRAY )) {
202252 if (writeOldListStructure ) {
203253 return ConversionPatterns .listType (
204- repetition , fieldName , convertField ("array" , schema .getElementType (), REPEATED , schemaPath ));
254+ repetition ,
255+ fieldName ,
256+ convertField ("array" , schema .getElementType (), REPEATED , schemaPath , seenSchemas ));
205257 } else {
206258 return ConversionPatterns .listOfElements (
207259 repetition ,
208260 fieldName ,
209- convertField (AvroWriteSupport .LIST_ELEMENT_NAME , schema .getElementType (), schemaPath ));
261+ convertField (
262+ AvroWriteSupport .LIST_ELEMENT_NAME , schema .getElementType (), schemaPath , seenSchemas ));
210263 }
211264 } else if (type .equals (Schema .Type .MAP )) {
212- Type valType = convertField ("value" , schema .getValueType (), schemaPath );
265+ Type valType = convertField ("value" , schema .getValueType (), schemaPath , seenSchemas );
213266 // avro map key type is always string
214267 return ConversionPatterns .stringKeyMapType (repetition , fieldName , valType );
215268 } else if (type .equals (Schema .Type .FIXED )) {
@@ -223,7 +276,7 @@ private Type convertField(String fieldName, Schema schema, Type.Repetition repet
223276 builder = Types .primitive (FIXED_LEN_BYTE_ARRAY , repetition ).length (schema .getFixedSize ());
224277 }
225278 } else if (type .equals (Schema .Type .UNION )) {
226- return convertUnion (fieldName , schema , repetition , schemaPath );
279+ return convertUnion (fieldName , schema , repetition , schemaPath , seenSchemas );
227280 } else {
228281 throw new UnsupportedOperationException ("Cannot convert Avro type " + type );
229282 }
@@ -246,6 +299,15 @@ private Type convertField(String fieldName, Schema schema, Type.Repetition repet
246299 }
247300
248301 private Type convertUnion (String fieldName , Schema schema , Type .Repetition repetition , String schemaPath ) {
302+ return convertUnion (fieldName , schema , repetition , schemaPath , new IdentityHashMap <Schema , Integer >());
303+ }
304+
305+ private Type convertUnion (
306+ String fieldName ,
307+ Schema schema ,
308+ Type .Repetition repetition ,
309+ String schemaPath ,
310+ IdentityHashMap <Schema , Integer > seenSchemas ) {
249311 List <Schema > nonNullSchemas = new ArrayList <Schema >(schema .getTypes ().size ());
250312 // Found any schemas in the union? Required for the edge case, where the union contains only a single type.
251313 boolean foundNullSchema = false ;
@@ -267,20 +329,31 @@ private Type convertUnion(String fieldName, Schema schema, Type.Repetition repet
267329
268330 case 1 :
269331 return foundNullSchema
270- ? convertField (fieldName , nonNullSchemas .get (0 ), repetition , schemaPath )
271- : convertUnionToGroupType (fieldName , repetition , nonNullSchemas , schemaPath );
332+ ? convertField (fieldName , nonNullSchemas .get (0 ), repetition , schemaPath , seenSchemas )
333+ : convertUnionToGroupType (fieldName , repetition , nonNullSchemas , schemaPath , seenSchemas );
272334
273335 default : // complex union type
274- return convertUnionToGroupType (fieldName , repetition , nonNullSchemas , schemaPath );
336+ return convertUnionToGroupType (fieldName , repetition , nonNullSchemas , schemaPath , seenSchemas );
275337 }
276338 }
277339
278340 private Type convertUnionToGroupType (
279341 String fieldName , Type .Repetition repetition , List <Schema > nonNullSchemas , String schemaPath ) {
342+ return convertUnionToGroupType (
343+ fieldName , repetition , nonNullSchemas , schemaPath , new IdentityHashMap <Schema , Integer >());
344+ }
345+
346+ private Type convertUnionToGroupType (
347+ String fieldName ,
348+ Type .Repetition repetition ,
349+ List <Schema > nonNullSchemas ,
350+ String schemaPath ,
351+ IdentityHashMap <Schema , Integer > seenSchemas ) {
280352 List <Type > unionTypes = new ArrayList <Type >(nonNullSchemas .size ());
281353 int index = 0 ;
282354 for (Schema childSchema : nonNullSchemas ) {
283- unionTypes .add (convertField ("member" + index ++, childSchema , Type .Repetition .OPTIONAL , schemaPath ));
355+ unionTypes .add (
356+ convertField ("member" + index ++, childSchema , Type .Repetition .OPTIONAL , schemaPath , seenSchemas ));
284357 }
285358 return new GroupType (repetition , fieldName , unionTypes );
286359 }
@@ -289,6 +362,10 @@ private Type convertField(Schema.Field field, String schemaPath) {
289362 return convertField (field .name (), field .schema (), schemaPath );
290363 }
291364
365+ private Type convertField (Schema .Field field , String schemaPath , IdentityHashMap <Schema , Integer > seenSchemas ) {
366+ return convertField (field .name (), field .schema (), schemaPath , seenSchemas );
367+ }
368+
292369 public Schema convert (MessageType parquetSchema ) {
293370 return convertFields (parquetSchema .getName (), parquetSchema .getFields (), new HashMap <>());
294371 }
0 commit comments