diff --git a/src/main/java/org/apache/commons/io/serialization/ValidatingObjectInputStream.java b/src/main/java/org/apache/commons/io/serialization/ValidatingObjectInputStream.java index 495d02677b6..d3adb95cd16 100644 --- a/src/main/java/org/apache/commons/io/serialization/ValidatingObjectInputStream.java +++ b/src/main/java/org/apache/commons/io/serialization/ValidatingObjectInputStream.java @@ -510,7 +510,11 @@ public ValidatingObjectInputStream reject(final String... patterns) { @Override protected Class resolveClass(final ObjectStreamClass osc) throws IOException, ClassNotFoundException { checkClassName(osc.getName()); - return super.resolveClass(osc); + final Class result = super.resolveClass(osc); + for (final Class interfaceName : result.getInterfaces()) { + checkClassName(interfaceName.getName()); + } + return result; } /** diff --git a/src/test/java/org/apache/commons/io/serialization/ValidatingObjectInputStreamInterfaceTest.java b/src/test/java/org/apache/commons/io/serialization/ValidatingObjectInputStreamInterfaceTest.java new file mode 100644 index 00000000000..db4e09a6a6d --- /dev/null +++ b/src/test/java/org/apache/commons/io/serialization/ValidatingObjectInputStreamInterfaceTest.java @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.commons.io.serialization; + +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.io.IOException; +import java.io.InvalidClassException; +import java.io.Serializable; + +import org.apache.commons.lang3.SerializationUtils; +import org.junit.jupiter.api.Test; + +/** + * Tests {@link ValidatingObjectInputStream}. + */ +class ValidatingObjectInputStreamInterfaceTest { + + public interface IFoo extends Serializable { + + void foo(); + } + + public static class MockObject implements IFoo { + + private static final long serialVersionUID = 1L; + + @Override + public void foo() { + // empty + } + } + + abstract static class AbtractFoo implements IFoo { + + private static final long serialVersionUID = 1L; + + } + + static class FooImpl extends AbtractFoo { + + private static final long serialVersionUID = 1L; + + @Override + public void foo() { + // empty + } + + } + + @Test + void testAcceptAll() throws IOException, ClassNotFoundException { + final MockObject object = new MockObject(); + final byte[] serialized = SerializationUtils.serialize(object); + final Class ifaceClass = IFoo.class; + // @formatter:off + try (ValidatingObjectInputStream vois = ValidatingObjectInputStream.builder() + .setByteArray(serialized) + .accept("*") + .get()) { + // @formatter:on + assertInstanceOf(ifaceClass, vois.readObject()); + } + } + + @Test + void testAcceptAbstractClass() throws IOException, ClassNotFoundException { + final FooImpl object = new FooImpl(); + final byte[] serialized = SerializationUtils.serialize(object); + final Class ifaceClass = IFoo.class; + // @formatter:off + try (ValidatingObjectInputStream vois = ValidatingObjectInputStream.builder() + .setByteArray(serialized) + .accept(IFoo.class) + .accept(AbtractFoo.class) + .accept(FooImpl.class) + .get()) { + // @formatter:on + assertInstanceOf(ifaceClass, vois.readObject()); + } + } + + @Test + void testAcceptInterface() throws IOException, ClassNotFoundException { + final MockObject object = new MockObject(); + final byte[] serialized = SerializationUtils.serialize(object); + final Class ifaceClass = IFoo.class; + // @formatter:off + try (ValidatingObjectInputStream vois = ValidatingObjectInputStream.builder() + .setByteArray(serialized) + .accept(ifaceClass) // not a feature + .get()) { + // @formatter:on + // not a feature + // assertInstanceOf(ifaceClass, vois.readObject()); + assertThrows(InvalidClassException.class, vois::readObject); + } + } + + @Test + void testRejectInterface() throws IOException, ClassNotFoundException { + final MockObject object = new MockObject(); + final byte[] serialized = SerializationUtils.serialize(object); + final Class ifaceClass = IFoo.class; + // @formatter:off + try (ValidatingObjectInputStream vois = ValidatingObjectInputStream.builder() + .setByteArray(serialized) + .accept("*") + .reject(ifaceClass) + .get()) { + // @formatter:on + assertThrows(InvalidClassException.class, vois::readObject); + } + } +}