Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions jdk/src/share/classes/sun/security/ssl/SSLSocketImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,8 @@ public void startHandshake() throws IOException {
if (!conContext.isNegotiated) {
readHandshakeRecord();
}
} catch (InterruptedIOException iioe) {
handleException(iioe);
} catch (IOException ioe) {
throw conContext.fatal(Alert.HANDSHAKE_FAILURE,
"Couldn't kickstart handshaking", ioe);
Expand Down Expand Up @@ -1313,12 +1315,11 @@ private int readHandshakeRecord() throws IOException {
}
} catch (SSLException ssle) {
throw ssle;
} catch (InterruptedIOException iioe) {
// don't change exception in case of timeouts or interrupts
throw iioe;
} catch (IOException ioe) {
if (!(ioe instanceof SSLException)) {
throw new SSLException("readHandshakeRecord", ioe);
} else {
throw ioe;
}
throw new SSLException("readHandshakeRecord", ioe);
}
}

Expand Down Expand Up @@ -1379,6 +1380,9 @@ private ByteBuffer readApplicationRecord(
}
} catch (SSLException ssle) {
throw ssle;
} catch (InterruptedIOException iioe) {
// don't change exception in case of timeouts or interrupts
throw iioe;
} catch (IOException ioe) {
if (!(ioe instanceof SSLException)) {
throw new SSLException("readApplicationRecord", ioe);
Expand Down
216 changes: 112 additions & 104 deletions jdk/src/share/classes/sun/security/ssl/SSLSocketInputRecord.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
/*
* Copyright (c) 1996, 2020, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2020, Azul Systems, Inc. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
Expand All @@ -26,6 +27,7 @@
package sun.security.ssl;

import java.io.EOFException;
import java.io.InterruptedIOException;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
Expand All @@ -47,37 +49,31 @@
final class SSLSocketInputRecord extends InputRecord implements SSLRecord {
private InputStream is = null;
private OutputStream os = null;
private final byte[] temporary = new byte[1024];
private final byte[] header = new byte[headerSize];
private int headerOff = 0;
// Cache for incomplete record body.
private ByteBuffer recordBody = ByteBuffer.allocate(1024);

private boolean formatVerified = false; // SSLv2 ruled out?

// Cache for incomplete handshake messages.
private ByteBuffer handshakeBuffer = null;

private boolean hasHeader = false; // Had read the record header

SSLSocketInputRecord(HandshakeHash handshakeHash) {
super(handshakeHash, SSLReadCipher.nullTlsReadCipher());
}

@Override
int bytesInCompletePacket() throws IOException {
if (!hasHeader) {
// read exactly one record
try {
int really = read(is, temporary, 0, headerSize);
if (really < 0) {
// EOF: peer shut down incorrectly
return -1;
}
} catch (EOFException eofe) {
// The caller will handle EOF.
return -1;
}
hasHeader = true;
// read header
try {
readHeader();
} catch (EOFException eofe) {
// The caller will handle EOF.
return -1;
}

byte byteZero = temporary[0];
byte byteZero = header[0];
int len = 0;

/*
Expand All @@ -93,9 +89,9 @@ int bytesInCompletePacket() throws IOException {
* Last sanity check that it's not a wild record
*/
if (!ProtocolVersion.isNegotiable(
temporary[1], temporary[2], false)) {
header[1], header[2], false)) {
throw new SSLException("Unrecognized record version " +
ProtocolVersion.nameOf(temporary[1], temporary[2]) +
ProtocolVersion.nameOf(header[1], header[2]) +
" , plaintext connection?");
}

Expand All @@ -109,8 +105,8 @@ int bytesInCompletePacket() throws IOException {
/*
* One of the SSLv3/TLS message types.
*/
len = ((temporary[3] & 0xFF) << 8) +
(temporary[4] & 0xFF) + headerSize;
len = ((header[3] & 0xFF) << 8) +
(header[4] & 0xFF) + headerSize;
} else {
/*
* Must be SSLv2 or something unknown.
Expand All @@ -121,11 +117,11 @@ int bytesInCompletePacket() throws IOException {
*/
boolean isShort = ((byteZero & 0x80) != 0);

if (isShort && ((temporary[2] == 1) || (temporary[2] == 4))) {
if (isShort && ((header[2] == 1) || (header[2] == 4))) {
if (!ProtocolVersion.isNegotiable(
temporary[3], temporary[4], false)) {
header[3], header[4], false)) {
throw new SSLException("Unrecognized record version " +
ProtocolVersion.nameOf(temporary[3], temporary[4]) +
ProtocolVersion.nameOf(header[3], header[4]) +
" , plaintext connection?");
}

Expand All @@ -138,9 +134,9 @@ int bytesInCompletePacket() throws IOException {
//
// int mask = (isShort ? 0x7F : 0x3F);
// len = ((byteZero & mask) << 8) +
// (temporary[1] & 0xFF) + (isShort ? 2 : 3);
// (header[1] & 0xFF) + (isShort ? 2 : 3);
//
len = ((byteZero & 0x7F) << 8) + (temporary[1] & 0xFF) + 2;
len = ((byteZero & 0x7F) << 8) + (header[1] & 0xFF) + 2;
} else {
// Gobblygook!
throw new SSLException(
Expand All @@ -160,34 +156,41 @@ Plaintext[] decode(ByteBuffer[] srcs, int srcsOffset,
return null;
}

if (!hasHeader) {
// read exactly one record
int really = read(is, temporary, 0, headerSize);
if (really < 0) {
throw new EOFException("SSL peer shut down incorrectly");
}
hasHeader = true;
}
// read header
readHeader();

Plaintext plaintext = null;
if (!formatVerified) {
formatVerified = true;
Plaintext[] plaintext = null;
boolean cleanInBuffer = true;
try {
if (!formatVerified) {
formatVerified = true;

/*
* The first record must either be a handshake record or an
* alert message. If it's not, it is either invalid or an
* SSLv2 message.
*/
if ((temporary[0] != ContentType.HANDSHAKE.id) &&
(temporary[0] != ContentType.ALERT.id)) {
hasHeader = false;
return handleUnknownRecord(temporary);
/*
* The first record must either be a handshake record or an
* alert message. If it's not, it is either invalid or an
* SSLv2 message.
*/
if ((header[0] != ContentType.HANDSHAKE.id) &&
(header[0] != ContentType.ALERT.id)) {
plaintext = handleUnknownRecord();
}
}
}

// The record header should has consumed.
hasHeader = false;
return decodeInputRecord(temporary);
// The record header should has consumed.
if (plaintext == null) {
plaintext = decodeInputRecord();
}
} catch(InterruptedIOException e) {
// do not clean header and recordBody in case of Socket Timeout
cleanInBuffer = false;
throw e;
} finally {
if (cleanInBuffer) {
headerOff = 0;
recordBody.clear();
}
}
return plaintext;
}

@Override
Expand All @@ -200,9 +203,7 @@ void setDeliverStream(OutputStream outputStream) {
this.os = outputStream;
}

// Note that destination may be null
private Plaintext[] decodeInputRecord(
byte[] header) throws IOException, BadPaddingException {
private Plaintext[] decodeInputRecord() throws IOException, BadPaddingException {
byte contentType = header[0]; // pos: 0
byte majorVersion = header[1]; // pos: 1
byte minorVersion = header[2]; // pos: 2
Expand All @@ -227,30 +228,27 @@ private Plaintext[] decodeInputRecord(
}

//
// Read a complete record.
// Read a complete record and store in the recordBody
// recordBody is used to cache incoming record and restore in case of
// read operation timedout
//
ByteBuffer destination = ByteBuffer.allocate(headerSize + contentLen);
int dstPos = destination.position();
destination.put(temporary, 0, headerSize);
while (contentLen > 0) {
int howmuch = Math.min(temporary.length, contentLen);
int really = read(is, temporary, 0, howmuch);
if (really < 0) {
throw new EOFException("SSL peer shut down incorrectly");
if (recordBody.position() == 0) {
if (recordBody.capacity() < contentLen) {
recordBody = ByteBuffer.allocate(contentLen);
}

destination.put(temporary, 0, howmuch);
contentLen -= howmuch;
recordBody.limit(contentLen);
} else {
contentLen = recordBody.remaining();
}
destination.flip();
destination.position(dstPos + headerSize);
readFully(contentLen);
recordBody.flip();

if (SSLLogger.isOn && SSLLogger.isOn("record")) {
SSLLogger.fine(
"READ: " +
ProtocolVersion.nameOf(majorVersion, minorVersion) +
" " + ContentType.nameOf(contentType) + ", length = " +
destination.remaining());
recordBody.remaining());
}

//
Expand All @@ -259,7 +257,7 @@ private Plaintext[] decodeInputRecord(
ByteBuffer fragment;
try {
Plaintext plaintext =
readCipher.decrypt(contentType, destination, null);
readCipher.decrypt(contentType, recordBody, null);
fragment = plaintext.fragment;
contentType = plaintext.contentType;
} catch (BadPaddingException bpe) {
Expand Down Expand Up @@ -368,8 +366,7 @@ private Plaintext[] decodeInputRecord(
};
}

private Plaintext[] handleUnknownRecord(
byte[] header) throws IOException, BadPaddingException {
private Plaintext[] handleUnknownRecord() throws IOException, BadPaddingException {
byte firstByte = header[0];
byte thirdByte = header[2];

Expand Down Expand Up @@ -411,32 +408,29 @@ private Plaintext[] handleUnknownRecord(
}

int msgLen = ((header[0] & 0x7F) << 8) | (header[1] & 0xFF);

ByteBuffer destination = ByteBuffer.allocate(headerSize + msgLen);
destination.put(temporary, 0, headerSize);
msgLen -= 3; // had read 3 bytes of content as header
while (msgLen > 0) {
int howmuch = Math.min(temporary.length, msgLen);
int really = read(is, temporary, 0, howmuch);
if (really < 0) {
throw new EOFException("SSL peer shut down incorrectly");
if (recordBody.position() == 0) {
if (recordBody.capacity() < (headerSize + msgLen)) {
recordBody = ByteBuffer.allocate(headerSize + msgLen);
}

destination.put(temporary, 0, howmuch);
msgLen -= howmuch;
recordBody.limit(headerSize + msgLen);
recordBody.put(header, 0, headerSize);
} else {
msgLen = recordBody.remaining();
}
destination.flip();
msgLen -= 3; // had read 3 bytes of content as header
readFully(msgLen);
recordBody.flip();

/*
* If we can map this into a V3 ClientHello, read and
* hash the rest of the V2 handshake, turn it into a
* V3 ClientHello message, and pass it up.
*/
destination.position(2); // exclude the header
handshakeHash.receive(destination);
destination.position(0);
recordBody.position(2); // exclude the header
handshakeHash.receive(recordBody);
recordBody.position(0);

ByteBuffer converted = convertToClientHello(destination);
ByteBuffer converted = convertToClientHello(recordBody);

if (SSLLogger.isOn && SSLLogger.isOn("packet")) {
SSLLogger.fine(
Expand All @@ -456,28 +450,42 @@ private Plaintext[] handleUnknownRecord(
}
}

// Read the exact bytes of data, otherwise, return -1.
private static int read(InputStream is,
byte[] buffer, int offset, int len) throws IOException {
int n = 0;
while (n < len) {
int readLen = is.read(buffer, offset + n, len - n);
if (readLen < 0) {
if (SSLLogger.isOn && SSLLogger.isOn("packet")) {
SSLLogger.fine("Raw read: EOF");
}
return -1;
// Read the exact bytes of data, otherwise, throw IOException.
private int readFully(int len) throws IOException {
int end = len + recordBody.position();
int off = recordBody.position();
try {
while (off < end) {
off += read(is, recordBody.array(), off, end - off);
}
} finally {
recordBody.position(off);
}
return len;
}

// Read SSE record header, otherwise, throw IOException.
private int readHeader() throws IOException {
while (headerOff < headerSize) {
headerOff += read(is, header, headerOff, headerSize - headerOff);
}
return headerSize;
}

private static int read(InputStream is, byte[] buf, int off, int len) throws IOException {
int readLen = is.read(buf, off, len);
if (readLen < 0) {
if (SSLLogger.isOn && SSLLogger.isOn("packet")) {
ByteBuffer bb = ByteBuffer.wrap(buffer, offset + n, readLen);
SSLLogger.fine("Raw read", bb);
SSLLogger.fine("Raw read: EOF");
}

n += readLen;
throw new EOFException("SSL peer shut down incorrectly");
}

return n;
if (SSLLogger.isOn && SSLLogger.isOn("packet")) {
ByteBuffer bb = ByteBuffer.wrap(buf, off, readLen);
SSLLogger.fine("Raw read", bb);
}
return readLen;
}

// Try to use up the input stream without impact the performance too much.
Expand Down
Loading