Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -35,7 +35,7 @@
import javax.net.ssl.SSLSession;
import javax.net.ssl.StandardConstants;
import java.io.IOException;
import java.nio.Buffer;
import java.nio.BufferUnderflowException;
import java.nio.ByteBuffer;
import java.nio.channels.ByteChannel;
import java.nio.channels.Channel;
@@ -47,7 +47,9 @@
import java.util.function.Consumer;
import java.util.function.Function;

/** A server-side {@link TlsChannel}. */
/**
* A server-side {@link TlsChannel}.
*/
public class ServerTlsChannel implements TlsChannel {

private static final Logger LOGGER = Loggers.getLogger("connection.tls");
@@ -83,7 +85,7 @@ public SSLContext getSslContext(SniReader sniReader) throws IOException, EofExce
throw new TlsChannelCallbackException("SNI callback failed", e);
}
return chosenContext.orElseThrow(
() -> new SSLHandshakeException("No ssl context available for received SNI: " + nameOpt));
() -> new SSLHandshakeException("No ssl context available for received SNI: " + nameOpt));
}
}

@@ -111,12 +113,13 @@ private static SSLEngine defaultSSLEngineFactory(SSLContext sslContext) {
return engine;
}

/** Builder of {@link ServerTlsChannel} */
/**
* Builder of {@link ServerTlsChannel}
*/
public static class Builder extends TlsChannelBuilder<Builder> {

private final SslContextStrategy internalSslContextFactory;
private Function<SSLContext, SSLEngine> sslEngineFactory =
ServerTlsChannel::defaultSSLEngineFactory;
private Function<SSLContext, SSLEngine> sslEngineFactory = ServerTlsChannel::defaultSSLEngineFactory;

private Builder(ByteChannel underlying, SSLContext sslContext) {
super(underlying);
@@ -140,20 +143,20 @@ public Builder withEngineFactory(Function<SSLContext, SSLEngine> sslEngineFactor

public ServerTlsChannel build() {
return new ServerTlsChannel(
underlying,
internalSslContextFactory,
sslEngineFactory,
sessionInitCallback,
runTasks,
plainBufferAllocator,
encryptedBufferAllocator,
releaseBuffers,
waitForCloseConfirmation);
underlying,
internalSslContextFactory,
sslEngineFactory,
sessionInitCallback,
runTasks,
plainBufferAllocator,
encryptedBufferAllocator,
releaseBuffers,
waitForCloseConfirmation);
}
}

/**
* Create a new {@link Builder}, configured with a underlying {@link Channel} and a fixed {@link
* Create a new {@link Builder}, configured with an underlying {@link Channel} and a fixed {@link
* SSLContext}, which will be used to create the {@link SSLEngine}.
*
* @param underlying a reference to the underlying {@link ByteChannel}
@@ -165,16 +168,16 @@ public static Builder newBuilder(ByteChannel underlying, SSLContext sslContext)
}

/**
* Create a new {@link Builder}, configured with a underlying {@link Channel} and a custom {@link
* Create a new {@link Builder}, configured with an underlying {@link Channel} and a custom {@link
* SSLContext} factory, which will be used to create the context (in turn used to create the
* {@link SSLEngine}, as a function of the SNI received at the TLS connection start.
* {@link SSLEngine}), as a function of the SNI received at the TLS connection start.
*
* <p><b>Implementation note:</b><br>
* Due to limitations of {@link SSLEngine}, configuring a {@link ServerTlsChannel} to select the
* {@link SSLContext} based on the SNI value implies parsing the first TLS frame (ClientHello)
* independently of the SSLEngine.
*
* @param underlying a reference to the underlying {@link ByteChannel}
* @param underlying a reference to the underlying {@link ByteChannel}
* @param sslContextFactory a function from an optional SNI to the {@link SSLContext} to be used
* @return the new builder
* @see <a href="https://tools.ietf.org/html/rfc6066#section-3">Server Name Indication</a>
@@ -203,15 +206,15 @@ public static Builder newBuilder(ByteChannel underlying, SniSslContextFactory ss

// @formatter:off
private ServerTlsChannel(
ByteChannel underlying,
SslContextStrategy internalSslContextFactory,
Function<SSLContext, SSLEngine> engineFactory,
Consumer<SSLSession> sessionInitCallback,
boolean runTasks,
BufferAllocator plainBufAllocator,
BufferAllocator encryptedBufAllocator,
boolean releaseBuffers,
boolean waitForCloseConfirmation) {
ByteChannel underlying,
SslContextStrategy internalSslContextFactory,
Function<SSLContext, SSLEngine> engineFactory,
Consumer<SSLSession> sessionInitCallback,
boolean runTasks,
BufferAllocator plainBufAllocator,
BufferAllocator encryptedBufAllocator,
boolean releaseBuffers,
boolean waitForCloseConfirmation) {
this.underlying = underlying;
this.sslContextStrategy = internalSslContextFactory;
this.engineFactory = engineFactory;
@@ -221,8 +224,7 @@ private ServerTlsChannel(
this.encryptedBufAllocator = new TrackingAllocator(encryptedBufAllocator);
this.releaseBuffers = releaseBuffers;
this.waitForCloseConfirmation = waitForCloseConfirmation;
inEncrypted =
new BufferHolder(
inEncrypted = new BufferHolder(
"inEncrypted",
Optional.empty(),
encryptedBufAllocator,
@@ -242,7 +244,7 @@ public ByteChannel getUnderlying() {
/**
* Return the used {@link SSLContext}.
*
* @return if context if present, of null if the TLS connection as not been initializer, or the
* @return context if present, or null if the TLS connection as not been initializer, or the
* SNI not received yet.
*/
public SSLContext getSslContext() {
@@ -347,8 +349,12 @@ public void handshake() throws IOException {

@Override
public void close() throws IOException {
if (impl != null) impl.close();
if (inEncrypted != null) inEncrypted.dispose();
if (impl != null) {
impl.close();
}
if (inEncrypted != null) {
inEncrypted.dispose();
}
underlying.close();
}

@@ -370,8 +376,7 @@ private void initEngine() throws IOException, EofException {
LOGGER.trace("client threw exception in SSLEngine factory", e);
throw new TlsChannelCallbackException("SSLEngine creation callback failed", e);
}
impl =
new TlsChannelImpl(
impl = new TlsChannelImpl(
underlying,
underlying,
engine,
@@ -393,41 +398,33 @@ private void initEngine() throws IOException, EofException {
private Optional<SNIServerName> getServerNameIndication() throws IOException, EofException {
inEncrypted.prepare();
try {
int recordHeaderSize = readRecordHeaderSize();
while (inEncrypted.buffer.position() < recordHeaderSize) {
if (!inEncrypted.buffer.hasRemaining()) {
inEncrypted.enlarge();
// loop finishes using return statements
while (true) {
try {
inEncrypted.buffer.flip();
try {
Map<Integer, SNIServerName> serverNames = TlsExplorer.exploreTlsRecord(inEncrypted.buffer);
SNIServerName hostName = serverNames.get(StandardConstants.SNI_HOST_NAME);
if (hostName instanceof SNIHostName) {
return Optional.of(hostName);
} else {
return Optional.empty();
}
} finally {
inEncrypted.buffer.compact();
}
} catch (BufferUnderflowException e) {
if (!inEncrypted.buffer.hasRemaining()) {
inEncrypted.enlarge();
}
TlsChannelImpl.callChannelRead(underlying, inEncrypted.buffer); // IO block
}
TlsChannelImpl.readFromChannel(underlying, inEncrypted.buffer); // IO block
}
((Buffer) inEncrypted.buffer).flip();
Map<Integer, SNIServerName> serverNames = TlsExplorer.explore(inEncrypted.buffer);
inEncrypted.buffer.compact();
SNIServerName hostName = serverNames.get(StandardConstants.SNI_HOST_NAME);
if (hostName != null && hostName instanceof SNIHostName) {
SNIHostName sniHostName = (SNIHostName) hostName;
return Optional.of(sniHostName);
} else {
return Optional.empty();
}
} finally {
inEncrypted.release();
}
}

private int readRecordHeaderSize() throws IOException, EofException {
while (inEncrypted.buffer.position() < TlsExplorer.RECORD_HEADER_SIZE) {
if (!inEncrypted.buffer.hasRemaining()) {
throw new IllegalStateException("inEncrypted too small");
}
TlsChannelImpl.readFromChannel(underlying, inEncrypted.buffer); // IO block
}
((Buffer) inEncrypted.buffer).flip();
int recordHeaderSize = TlsExplorer.getRequiredSize(inEncrypted.buffer);
inEncrypted.buffer.compact();
return recordHeaderSize;
}

@Override
public boolean shutdown() throws IOException {
return impl != null && impl.shutdown();
@@ -442,4 +439,4 @@ public boolean shutdownReceived() {
public boolean shutdownSent() {
return impl != null && impl.shutdownSent();
}
}
}
Original file line number Diff line number Diff line change
@@ -29,11 +29,19 @@ public class ByteBufferSet {
public final int length;

public ByteBufferSet(ByteBuffer[] array, int offset, int length) {
if (array == null) throw new NullPointerException();
if (array.length < offset) throw new IndexOutOfBoundsException();
if (array.length < offset + length) throw new IndexOutOfBoundsException();
if (array == null) {
throw new NullPointerException();
}
if (array.length < offset) {
throw new IndexOutOfBoundsException();
}
if (array.length < offset + length) {
throw new IndexOutOfBoundsException();
}
for (int i = offset; i < offset + length; i++) {
if (array[i] == null) throw new NullPointerException();
if (array[i] == null) {
throw new NullPointerException();
}
}
this.array = array;
this.offset = offset;
@@ -56,10 +64,20 @@ public long remaining() {
return ret;
}

public long position() {
long ret = 0;
for (int i = offset; i < offset + length; i++) {
ret += array[i].position();
}
return ret;
}

public int putRemaining(ByteBuffer from) {
int totalBytes = 0;
for (int i = offset; i < offset + length; i++) {
if (!from.hasRemaining()) break;
if (!from.hasRemaining()) {
break;
}
ByteBuffer dstBuffer = array[i];
int bytes = Math.min(from.remaining(), dstBuffer.remaining());
ByteBufferUtil.copy(from, dstBuffer, bytes);
@@ -78,7 +96,9 @@ public ByteBufferSet put(ByteBuffer from, int length) {
int totalBytes = 0;
for (int i = offset; i < offset + this.length; i++) {
int pending = length - totalBytes;
if (pending == 0) break;
if (pending == 0) {
break;
}
int bytes = Math.min(pending, (int) remaining());
ByteBuffer dstBuffer = array[i];
ByteBufferUtil.copy(from, dstBuffer, bytes);
@@ -90,7 +110,9 @@ public ByteBufferSet put(ByteBuffer from, int length) {
public int getRemaining(ByteBuffer dst) {
int totalBytes = 0;
for (int i = offset; i < offset + length; i++) {
if (!dst.hasRemaining()) break;
if (!dst.hasRemaining()) {
break;
}
ByteBuffer srcBuffer = array[i];
int bytes = Math.min(dst.remaining(), srcBuffer.remaining());
ByteBufferUtil.copy(srcBuffer, dst, bytes);
@@ -109,7 +131,9 @@ public ByteBufferSet get(ByteBuffer dst, int length) {
int totalBytes = 0;
for (int i = offset; i < offset + this.length; i++) {
int pending = length - totalBytes;
if (pending == 0) break;
if (pending == 0) {
break;
}
ByteBuffer srcBuffer = array[i];
int bytes = Math.min(pending, srcBuffer.remaining());
ByteBufferUtil.copy(srcBuffer, dst, bytes);
@@ -124,19 +148,15 @@ public boolean hasRemaining() {

public boolean isReadOnly() {
for (int i = offset; i < offset + length; i++) {
if (array[i].isReadOnly()) return true;
if (array[i].isReadOnly()) {
return true;
}
}
return false;
}

@Override
public String toString() {
return "ByteBufferSet[array="
+ Arrays.toString(array)
+ ", offset="
+ offset
+ ", length="
+ length
+ "]";
return "ByteBufferSet[" + Arrays.toString(array) + ":" + offset + ":" + length + "]";
}
}
}
Original file line number Diff line number Diff line change
@@ -36,7 +36,6 @@
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLSession;
import java.io.IOException;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.channels.ByteChannel;
import java.nio.channels.ClosedChannelException;
@@ -48,44 +47,27 @@
import java.util.function.Consumer;

import static java.lang.String.format;
import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING;

public class TlsChannelImpl implements ByteChannel {

private static final Logger LOGGER = Loggers.getLogger("connection.tls");

public static final int buffersInitialSize = 4096;

/** Official TLS max data size is 2^14 = 16k. Use 1024 more to account for the overhead */
/**
* Official TLS max data size is 2^14 = 16k. Use 1024 more to account for the overhead
*/
public static final int maxTlsPacketSize = 17 * 1024;

private static class UnwrapResult {
public final int bytesProduced;
public final HandshakeStatus lastHandshakeStatus;
public final boolean wasClosed;

public UnwrapResult(int bytesProduced, HandshakeStatus lastHandshakeStatus, boolean wasClosed) {
this.bytesProduced = bytesProduced;
this.lastHandshakeStatus = lastHandshakeStatus;
this.wasClosed = wasClosed;
}
}

private static class WrapResult {
public final int bytesConsumed;
public final HandshakeStatus lastHandshakeStatus;

public WrapResult(int bytesConsumed, HandshakeStatus lastHandshakeStatus) {
this.bytesConsumed = bytesConsumed;
this.lastHandshakeStatus = lastHandshakeStatus;
}
}

/** Used to signal EOF conditions from the underlying channel */
/**
* Used to signal EOF conditions from the underlying channel
*/
public static class EofException extends Exception {
private static final long serialVersionUID = -3859156713994602991L;

/** For efficiency, override this method to do nothing. */
/**
* For efficiency, override this method to do nothing.
*/
@Override
public Throwable fillInStackTrace() {
return this;
@@ -95,7 +77,7 @@ public Throwable fillInStackTrace() {
private final ReadableByteChannel readChannel;
private final WritableByteChannel writeChannel;
private final SSLEngine engine;
private BufferHolder inEncrypted;
private final BufferHolder inEncrypted;
private final Consumer<SSLSession> initSessionCallback;

private final boolean runTasks;
@@ -105,47 +87,42 @@ public Throwable fillInStackTrace() {

// @formatter:off
public TlsChannelImpl(
ReadableByteChannel readChannel,
WritableByteChannel writeChannel,
SSLEngine engine,
Optional<BufferHolder> inEncrypted,
Consumer<SSLSession> initSessionCallback,
boolean runTasks,
TrackingAllocator plainBufAllocator,
TrackingAllocator encryptedBufAllocator,
boolean releaseBuffers,
boolean waitForCloseConfirmation) {
ReadableByteChannel readChannel,
WritableByteChannel writeChannel,
SSLEngine engine,
Optional<BufferHolder> inEncrypted,
Consumer<SSLSession> initSessionCallback,
boolean runTasks,
TrackingAllocator plainBufAllocator,
TrackingAllocator encryptedBufAllocator,
boolean releaseBuffers,
boolean waitForCloseConfirmation) {
// @formatter:on
this.readChannel = readChannel;
this.writeChannel = writeChannel;
this.engine = engine;
this.inEncrypted =
inEncrypted.orElseGet(
() ->
new BufferHolder(
"inEncrypted",
Optional.empty(),
encryptedBufAllocator,
buffersInitialSize,
maxTlsPacketSize,
false /* plainData */,
releaseBuffers));
this.inEncrypted = inEncrypted.orElseGet(() -> new BufferHolder(
"inEncrypted",
Optional.empty(),
encryptedBufAllocator,
buffersInitialSize,
maxTlsPacketSize,
false /* plainData */,
releaseBuffers));
this.initSessionCallback = initSessionCallback;
this.runTasks = runTasks;
this.plainBufAllocator = plainBufAllocator;
this.encryptedBufAllocator = encryptedBufAllocator;
this.waitForCloseConfirmation = waitForCloseConfirmation;
inPlain =
new BufferHolder(
inPlain = new BufferHolder(
"inPlain",
Optional.empty(),
plainBufAllocator,
buffersInitialSize,
maxTlsPacketSize,
true /* plainData */,
releaseBuffers);
outEncrypted =
new BufferHolder(
outEncrypted = new BufferHolder(
"outEncrypted",
Optional.empty(),
encryptedBufAllocator,
@@ -159,7 +136,9 @@ public TlsChannelImpl(
private final Lock readLock = new ReentrantLock();
private final Lock writeLock = new ReentrantLock();

private volatile boolean negotiated = false;
private boolean handshakeStarted = false;

private volatile boolean handshakeCompleted = false;

/**
* Whether a IOException was received from the underlying channel or from the {@link SSLEngine}.
@@ -172,11 +151,27 @@ public TlsChannelImpl(
/** Whether a close_notify was already received. */
private volatile boolean shutdownReceived = false;

// decrypted data from inEncrypted
private BufferHolder inPlain;
/**
* Decrypted data from inEncrypted
*/
private final BufferHolder inPlain;

// contains data encrypted to send to the underlying channel
private BufferHolder outEncrypted;
/**
* Contains data encrypted to send to the underlying channel
*/
private final BufferHolder outEncrypted;

/**
* Reference to the current read buffer supplied by the client this field is only valid during a
* read operation. This field is used instead of {@link #inPlain} in order to avoid copying
* returned bytes when possible.
*/
private ByteBufferSet suppliedInPlain;

/**
* Bytes produced by the current read operation
*/
private int bytesToReturn;

/**
* Handshake wrap() method calls need a buffer to read from, even when they actually do not read
@@ -185,8 +180,7 @@ public TlsChannelImpl(
* <p>Note: standard SSLEngine is happy with no buffers, the empty buffer is here to make this
* work with Netty's OpenSSL's wrapper.
*/
private final ByteBufferSet dummyOut =
new ByteBufferSet(new ByteBuffer[] {ByteBuffer.allocate(0)});
private final ByteBufferSet dummyOut = new ByteBufferSet(new ByteBuffer[] {ByteBuffer.allocate(0)});

public Consumer<SSLSession> getSessionInitCallback() {
return initSessionCallback;
@@ -204,45 +198,53 @@ public TrackingAllocator getEncryptedBufferAllocator() {

public long read(ByteBufferSet dest) throws IOException {
checkReadBuffer(dest);
if (!dest.hasRemaining()) return 0;
if (!dest.hasRemaining()) {
return 0;
}
handshake();
readLock.lock();
try {
if (invalid || shutdownSent) {
throw new ClosedChannelException();
}
HandshakeStatus handshakeStatus = engine.getHandshakeStatus();
int bytesToReturn = inPlain.nullOrEmpty() ? 0 :inPlain.buffer.position();

long originalDestPosition = dest.position();
suppliedInPlain = dest;
bytesToReturn = inPlain.nullOrEmpty() ? 0 : inPlain.buffer.position();

while (true) {

// return bytes are soon as we have them
if (bytesToReturn > 0) {
if (inPlain.nullOrEmpty()) {
// if there is not in internal buffer, that means that the bytes must be in the supplied
// buffer
Util.assertTrue(dest.position() == originalDestPosition + bytesToReturn);
return bytesToReturn;
} else {
Util.assertTrue(inPlain.buffer.position() == bytesToReturn);
return transferPendingPlain(dest);
}
}

if (shutdownReceived) {
return -1;
}
Util.assertTrue(inPlain.nullOrEmpty());
switch (handshakeStatus) {
switch (engine.getHandshakeStatus()) {
case NEED_UNWRAP:
case NEED_WRAP:
bytesToReturn = handshake(Optional.of(dest), Optional.of(handshakeStatus));
handshakeStatus = NOT_HANDSHAKING;
writeAndHandshake();
break;
case NOT_HANDSHAKING:
case FINISHED:
UnwrapResult res = readAndUnwrap(Optional.of(dest));
if (res.wasClosed) {
readAndUnwrap();
if (shutdownReceived) {
return -1;
}
bytesToReturn = res.bytesProduced;
handshakeStatus = res.lastHandshakeStatus;
break;
case NEED_TASK:
handleTask();
handshakeStatus = engine.getHandshakeStatus();
break;
default:
// Unsupported stage eg: NEED_UNWRAP_AGAIN
@@ -252,20 +254,28 @@ public long read(ByteBufferSet dest) throws IOException {
} catch (EofException e) {
return -1;
} finally {
bytesToReturn = 0;
suppliedInPlain = null;
readLock.unlock();
}
}

private void handleTask() throws NeedsTaskException {
Runnable task = engine.getDelegatedTask();
if (runTasks) {
engine.getDelegatedTask().run();
LOGGER.trace("delegating in task: " + task);
task.run();
} else {
throw new NeedsTaskException(engine.getDelegatedTask());
if (LOGGER.isTraceEnabled()) {
LOGGER.trace("task needed, throwing exception: " + task);
}
throw new NeedsTaskException(task);
}
}

/** Copies bytes from the internal input plain buffer to the supplied buffer. */
private int transferPendingPlain(ByteBufferSet dstBuffers) {
((Buffer) inPlain.buffer).flip(); // will read
inPlain.buffer.flip(); // will read
int bytes = dstBuffers.putRemaining(inPlain.buffer);
inPlain.buffer.compact(); // will write
boolean disposed = inPlain.release();
@@ -275,48 +285,62 @@ private int transferPendingPlain(ByteBufferSet dstBuffers) {
return bytes;
}

private UnwrapResult unwrapLoop(Optional<ByteBufferSet> dest, HandshakeStatus originalStatus)
throws SSLException {
ByteBufferSet effDest =
dest.orElseGet(
() -> {
inPlain.prepare();
return new ByteBufferSet(inPlain.buffer);
});
private SSLEngineResult unwrapLoop() throws SSLException {
ByteBufferSet effDest;
if (suppliedInPlain != null) {
effDest = suppliedInPlain;
} else {
inPlain.prepare();
effDest = new ByteBufferSet(inPlain.buffer);
}

while (true) {
Util.assertTrue(inPlain.nullOrEmpty());
SSLEngineResult result = callEngineUnwrap(effDest);
HandshakeStatus status = engine.getHandshakeStatus();

/*
* Note that data can be returned even in case of overflow, in that
* case, just return the data.
*/
if (result.bytesProduced() > 0
|| result.getStatus() == Status.BUFFER_UNDERFLOW
|| result.getStatus() == Status.CLOSED
|| result.getHandshakeStatus() != originalStatus) {
boolean wasClosed = result.getStatus() == Status.CLOSED;
return new UnwrapResult(result.bytesProduced(), result.getHandshakeStatus(), wasClosed);
if (result.bytesProduced() > 0) {
return result;
}
if (result.getStatus() == Status.CLOSED) {
return result;
}
if (result.getStatus() == Status.BUFFER_UNDERFLOW) {
return result;
}

if (result.getHandshakeStatus() == HandshakeStatus.FINISHED
|| status == HandshakeStatus.NEED_TASK
|| status == HandshakeStatus.NEED_WRAP) {
return result;
}
if (result.getStatus() == Status.BUFFER_OVERFLOW) {
if (dest.isPresent() && effDest == dest.get()) {
if (effDest == suppliedInPlain) {
/*
* The client-supplier buffer is not big enough. Use the
* internal inPlain buffer, also ensure that it is bigger
* internal inPlain buffer. Also ensure that it is bigger
* than the too-small supplied one.
*/
inPlain.prepare();
ensureInPlainCapacity(Math.min(((int) dest.get().remaining()) * 2, maxTlsPacketSize));
if (inPlain.buffer.capacity() <= suppliedInPlain.remaining()) {
inPlain.enlarge();
}
} else {
inPlain.enlarge();
}

// inPlain changed, re-create the wrapper
effDest = new ByteBufferSet(inPlain.buffer);
}
}
}

private SSLEngineResult callEngineUnwrap(ByteBufferSet dest) throws SSLException {
((Buffer) inEncrypted.buffer).flip();
inEncrypted.buffer.flip();
try {
SSLEngineResult result =
engine.unwrap(inEncrypted.buffer, dest.array, dest.offset, dest.length);
@@ -339,9 +363,9 @@ private SSLEngineResult callEngineUnwrap(ByteBufferSet dest) throws SSLException
}
}

private int readFromChannel() throws IOException, EofException {
private void readFromChannel() throws IOException, EofException {
try {
return readFromChannel(readChannel, inEncrypted.buffer);
callChannelRead(readChannel, inEncrypted.buffer);
} catch (WouldBlockException e) {
throw e;
} catch (IOException e) {
@@ -350,8 +374,8 @@ private int readFromChannel() throws IOException, EofException {
}
}

public static int readFromChannel(ReadableByteChannel readChannel, ByteBuffer buffer)
throws IOException, EofException {
public static void callChannelRead(ReadableByteChannel readChannel, ByteBuffer buffer)
throws IOException, EofException {
Util.assertTrue(buffer.hasRemaining());
LOGGER.trace("Reading from channel");
int c = readChannel.read(buffer); // IO block
@@ -364,7 +388,6 @@ public static int readFromChannel(ReadableByteChannel readChannel, ByteBuffer bu
if (c == 0) {
throw new NeedsReadException();
}
return c;
}

// write
@@ -389,27 +412,33 @@ public long write(ByteBufferSet source) throws IOException {

private long wrapAndWrite(ByteBufferSet source) throws IOException {
long bytesToConsume = source.remaining();
long bytesConsumed = 0;
outEncrypted.prepare();
try {
while (true) {
writeToChannel();
if (bytesConsumed == bytesToConsume) return bytesToConsume;
WrapResult res = wrapLoop(source);
bytesConsumed += res.bytesConsumed;
writeToChannel(); // IO block
if (source.remaining() == 0) {
return bytesToConsume;
}
SSLEngineResult result = wrapLoop(source);
if (result.getStatus() == Status.CLOSED) {
return bytesToConsume - source.remaining();
}
}
} finally {
outEncrypted.release();
}
}

private WrapResult wrapLoop(ByteBufferSet source) throws SSLException {
/**
* Returns last {@link HandshakeStatus} of the loop
*/
private SSLEngineResult wrapLoop(ByteBufferSet source) throws SSLException {
while (true) {
SSLEngineResult result = callEngineWrap(source);
switch (result.getStatus()) {
case OK:
case CLOSED:
return new WrapResult(result.bytesConsumed(), result.getHandshakeStatus());
return result;
case BUFFER_OVERFLOW:
Util.assertTrue(result.bytesConsumed() == 0);
outEncrypted.enlarge();
@@ -439,26 +468,14 @@ private SSLEngineResult callEngineWrap(ByteBufferSet source) throws SSLException
}
}

private void ensureInPlainCapacity(int newCapacity) {
if (inPlain.buffer.capacity() < newCapacity) {
if (LOGGER.isTraceEnabled()) {
LOGGER.trace(format(
"inPlain buffer too small, increasing from %s to %s",
inPlain.buffer.capacity(),
newCapacity));
}
inPlain.resize(newCapacity);
}
}

private void writeToChannel() throws IOException {
if (outEncrypted.buffer.position() == 0) {
return;
}
((Buffer) outEncrypted.buffer).flip();
outEncrypted.buffer.flip();
try {
try {
writeToChannel(writeChannel, outEncrypted.buffer);
callChannelWrite(writeChannel, outEncrypted.buffer); // IO block
} catch (WouldBlockException e) {
throw e;
} catch (IOException e) {
@@ -470,13 +487,12 @@ private void writeToChannel() throws IOException {
}
}

private static void writeToChannel(WritableByteChannel channel, ByteBuffer src)
throws IOException {
private static void callChannelWrite(WritableByteChannel channel, ByteBuffer src) throws IOException {
while (src.hasRemaining()) {
if (LOGGER.isTraceEnabled()) {
LOGGER.trace("Writing to channel: " + src);
}
int c = channel.write(src);
int c = channel.write(src); // IO block
if (c == 0) {
/*
* If no bytesProduced were written, it means that the socket is
@@ -485,7 +501,7 @@ private static void writeToChannel(WritableByteChannel channel, ByteBuffer src)
throw new NeedsWriteException();
}
// blocking SocketChannels can write less than all the bytesProduced
// just before an error the loop forces the exception
// just before an error, the loop forces the exception
}
}

@@ -526,42 +542,57 @@ public void handshake() throws IOException {
}

private void doHandshake(boolean force) throws IOException, EofException {
if (!force && negotiated) return;
if (!force && handshakeCompleted) {
return;
}
initLock.lock();
try {
if (invalid || shutdownSent) throw new ClosedChannelException();
if (force || !negotiated) {
engine.beginHandshake();
LOGGER.trace("Called engine.beginHandshake()");
handshake(Optional.empty(), Optional.empty());
if (invalid || shutdownSent) {
throw new ClosedChannelException();
}
if (force || !handshakeCompleted) {

if (!handshakeStarted) {
LOGGER.trace("Called engine.beginHandshake()");
engine.beginHandshake();

// Some engines that do not support renegotiations may be sensitive to calling
// SSLEngine.beginHandshake() more than once. This guard prevents that.
// See: https://github.com/marianobarrios/tls-channel/issues/197
handshakeStarted = true;
}

writeAndHandshake();

if (engine.getSession().getProtocol().startsWith("DTLS")) {
throw new IllegalArgumentException("DTLS not supported");
}

handshakeCompleted = true;

// call client code
try {
initSessionCallback.accept(engine.getSession());
} catch (Exception e) {
LOGGER.trace("client code threw exception in session initialization callback", e);
throw new TlsChannelCallbackException("session initialization callback failed", e);
}
negotiated = true;
}
} finally {
initLock.unlock();
}
}

private int handshake(Optional<ByteBufferSet> dest, Optional<HandshakeStatus> handshakeStatus)
throws IOException, EofException {
private void writeAndHandshake() throws IOException, EofException {
readLock.lock();
try {
writeLock.lock();
try {
if (invalid || shutdownSent) {
throw new ClosedChannelException();
}
Util.assertTrue(inPlain.nullOrEmpty());
outEncrypted.prepare();
try {
writeToChannel(); // IO block
return handshakeLoop(dest, handshakeStatus);
handshakeLoop();
} finally {
outEncrypted.release();
}
@@ -573,58 +604,56 @@ private int handshake(Optional<ByteBufferSet> dest, Optional<HandshakeStatus> ha
}
}

private int handshakeLoop(Optional<ByteBufferSet> dest, Optional<HandshakeStatus> handshakeStatus)
throws IOException, EofException {
private void handshakeLoop() throws IOException, EofException {
Util.assertTrue(inPlain.nullOrEmpty());
HandshakeStatus status = handshakeStatus.orElseGet(() -> engine.getHandshakeStatus());
while (true) {
switch (status) {
switch (engine.getHandshakeStatus()) {
case NEED_WRAP:
Util.assertTrue(outEncrypted.nullOrEmpty());
WrapResult wrapResult = wrapLoop(dummyOut);
status = wrapResult.lastHandshakeStatus;
wrapLoop(dummyOut);
writeToChannel(); // IO block
break;
case NEED_UNWRAP:
UnwrapResult res = readAndUnwrap(dest);
status = res.lastHandshakeStatus;
if (res.bytesProduced > 0) return res.bytesProduced;
readAndUnwrap();
if (bytesToReturn > 0) {
return;
}
break;
case NOT_HANDSHAKING:
/*
* This should not really happen using SSLEngine, because
* handshaking ends with a FINISHED status. However, we accept
* this value to permit the use of a pass-through stub engine
* with no encryption.
*/
return 0;
return;
case NEED_TASK:
handleTask();
status = engine.getHandshakeStatus();
break;
case FINISHED:
return 0;
// this status is never returned by SSLEngine.getHandshakeStatus()
throw new IllegalStateException();
default:
// Unsupported stage eg: NEED_UNWRAP_AGAIN
return 0;
throw new IllegalStateException();
}
}
}

private UnwrapResult readAndUnwrap(Optional<ByteBufferSet> dest)
throws IOException, EofException {
private void readAndUnwrap() throws IOException, EofException {
// Save status before operation: use it to stop when status changes
HandshakeStatus orig = engine.getHandshakeStatus();
inEncrypted.prepare();
try {
while (true) {
Util.assertTrue(inPlain.nullOrEmpty());
UnwrapResult res = unwrapLoop(dest, orig);
if (res.bytesProduced > 0 || res.lastHandshakeStatus != orig || res.wasClosed) {
if (res.wasClosed) {
shutdownReceived = true;
}
return res;
SSLEngineResult result = unwrapLoop();
HandshakeStatus status = engine.getHandshakeStatus();
if (result.bytesProduced() > 0) {
bytesToReturn = result.bytesProduced();
return;
}
if (result.getStatus() == Status.CLOSED) {
shutdownReceived = true;
return;
}
if (result.getHandshakeStatus() == HandshakeStatus.FINISHED
|| status == HandshakeStatus.NEED_TASK
|| status == HandshakeStatus.NEED_WRAP) {
return;
}
if (!inEncrypted.buffer.hasRemaining()) {
inEncrypted.enlarge();
@@ -720,7 +749,7 @@ public boolean shutdown() throws IOException {
if (!shutdownReceived) {
try {
// IO block
readAndUnwrap(Optional.empty());
readAndUnwrap();
Util.assertTrue(shutdownReceived);
} catch (EofException e) {
throw new ClosedChannelException();
@@ -737,18 +766,9 @@ public boolean shutdown() throws IOException {
}

private void freeBuffers() {
if (inEncrypted != null) {
inEncrypted.dispose();
inEncrypted = null;
}
if (inPlain != null) {
inPlain.dispose();
inPlain = null;
}
if (outEncrypted != null) {
outEncrypted.dispose();
outEncrypted = null;
}
inEncrypted.dispose();
inPlain.dispose();
outEncrypted.dispose();
}

public boolean isOpen() {
Original file line number Diff line number Diff line change
@@ -29,158 +29,178 @@
import java.util.HashMap;
import java.util.Map;

/*
* Implement basic TLS parsing, just to read the SNI (as this is not done by
* {@link SSLEngine}.
*/
//** Implement basic TLS parsing, just to read the SNI. */
public final class TlsExplorer {

private TlsExplorer() {}

/** The header size of TLS/SSL records. */
public static final int RECORD_HEADER_SIZE = 5;

/**
* Returns the required number of bytesProduced in the {@code source} {@link ByteBuffer} necessary
* to explore SSL/TLS connection.
/*
* struct {
* uint8 major;
* uint8 minor;
* } ProtocolVersion;
*
* <p>This method tries to parse as few bytesProduced as possible from {@code source} byte buffer
* to get the length of an SSL/TLS record.
* enum {
* change_cipher_spec(20),
* alert(21),
* handshake(22),
* application_data(23),
* (255)
* } ContentType;
*
* @param source source buffer
* @return the required size
* struct {
* ContentType type;
* ProtocolVersion version;
* uint16 length;
* opaque fragment[TLSPlaintext.length];
* } TLSPlaintext;
*/
public static int getRequiredSize(ByteBuffer source) {
if (source.remaining() < RECORD_HEADER_SIZE) throw new BufferUnderflowException();
((Buffer) source).mark();
try {
byte firstByte = source.get();
source.get(); // second byte discarded
byte thirdByte = source.get();
if ((firstByte & 0x80) != 0 && thirdByte == 0x01) {
// looks like a V2ClientHello
return RECORD_HEADER_SIZE; // Only need the header fields
} else {
return (((source.get() & 0xFF) << 8) | (source.get() & 0xFF)) + 5;
}
} finally {
((Buffer) source).reset();
}
}
/** Explores a TLS record in search to the SNI. This method does not consume buffer. */
public static Map<Integer, SNIServerName> exploreTlsRecord(ByteBuffer input) throws SSLProtocolException {

public static Map<Integer, SNIServerName> explore(ByteBuffer source) throws SSLProtocolException {
if (source.remaining() < RECORD_HEADER_SIZE) throw new BufferUnderflowException();
((Buffer) source).mark();
input.mark();
try {
byte firstByte = source.get();
ignore(source, 1); // ignore second byte
byte thirdByte = source.get();
if ((firstByte & 0x80) != 0 && thirdByte == 0x01) {
// looks like a V2ClientHello
return new HashMap<>();
} else if (firstByte == 22) {
byte firstByte = input.get();
if (firstByte != 22) {
// 22: handshake record
return exploreTLSRecord(source, firstByte);
} else {
throw new SSLProtocolException("Not handshake record");
throw new SSLProtocolException("Not a handshake record");
}

ignore(input, 2); // ignore version

// Is there enough data for a full record?
int recordLength = getInt16(input);
if (recordLength > input.remaining()) {
throw new BufferUnderflowException();
}

return exploreHandshake(input, recordLength);
} finally {
((Buffer) source).reset();
input.reset();
}
}

/*
* struct { uint8 major; uint8 minor; } ProtocolVersion;
* enum {
* hello_request(0),
* client_hello(1),
* server_hello(2),
* certificate(11),
* server_key_exchange (12),
* certificate_request(13),
* server_hello_done(14),
* certificate_verify(15),
* client_key_exchange(16),
* finished(20),
* (255)
* } HandshakeType;
*
* enum { change_cipher_spec(20), alert(21), handshake(22),
* application_data(23), (255) } ContentType;
*
* struct { ContentType type; ProtocolVersion version; uint16 length; opaque
* fragment[TLSPlaintext.length]; } TLSPlaintext;
*/
private static Map<Integer, SNIServerName> exploreTLSRecord(ByteBuffer input, byte firstByte)
throws SSLProtocolException {
// Is it a handshake message?
if (firstByte != 22) // 22: handshake record
throw new SSLProtocolException("Not handshake record");
// Is there enough data for a full record?
int recordLength = getInt16(input);
if (recordLength > input.remaining()) throw new BufferUnderflowException();
return exploreHandshake(input, recordLength);
}

/*
* enum { hello_request(0), client_hello(1), server_hello(2),
* certificate(11), server_key_exchange (12), certificate_request(13),
* server_hello_done(14), certificate_verify(15), client_key_exchange(16),
* finished(20) (255) } HandshakeType;
*
* struct { HandshakeType msg_type; uint24 length; select (HandshakeType) {
* case hello_request: HelloRequest; case client_hello: ClientHello; case
* server_hello: ServerHello; case certificate: Certificate; case
* server_key_exchange: ServerKeyExchange; case certificate_request:
* CertificateRequest; case server_hello_done: ServerHelloDone; case
* certificate_verify: CertificateVerify; case client_key_exchange:
* ClientKeyExchange; case finished: Finished; } body; } Handshake;
* struct {
* HandshakeType msg_type;
* uint24 length;
* select (HandshakeType) {
* case hello_request: HelloRequest;
* case client_hello: ClientHello;
* case server_hello: ServerHello;
* case certificate: Certificate;
* case server_key_exchange: ServerKeyExchange;
* case certificate_request: CertificateRequest;
* case server_hello_done: ServerHelloDone;
* case certificate_verify: CertificateVerify;
* case client_key_exchange: ClientKeyExchange;
* case finished: Finished;
* } body;
* } Handshake;
*/
private static Map<Integer, SNIServerName> exploreHandshake(ByteBuffer input, int recordLength)
throws SSLProtocolException {
// What is the handshake type?
throws SSLProtocolException {
byte handshakeType = input.get();
if (handshakeType != 0x01) // 0x01: client_hello message
throw new SSLProtocolException("Not initial handshaking");
if (handshakeType != 0x01) {
// 0x01: client_hello message
throw new SSLProtocolException("Not an initial handshaking");
}

// What is the handshake body length?
int handshakeLength = getInt24(input);

// Theoretically, a single handshake message might span multiple
// records, but in practice this does not occur.
if (handshakeLength > recordLength - 4) // 4: handshake header size
throw new SSLProtocolException("Handshake message spans multiple records");
((Buffer) input).limit(handshakeLength + input.position());
if (handshakeLength > recordLength - 4) {
// 4: handshake header size
throw new SSLProtocolException("Handshake message spans multiple records");
}
input.limit(handshakeLength + input.position());

return exploreClientHello(input);
}

/*
* struct { uint32 gmt_unix_time; opaque random_bytes[28]; } Random;
* struct {
* uint32 gmt_unix_time;
* opaque random_bytes[28];
* } Random;
*
* opaque SessionID<0..32>;
*
* uint8 CipherSuite[2];
*
* enum { null(0), (255) } CompressionMethod;
* enum {
* null(0),
* (255)
* } CompressionMethod;
*
* struct { ProtocolVersion client_version; Random random; SessionID
* session_id; CipherSuite cipher_suites<2..2^16-2>; CompressionMethod
* compression_methods<1..2^8-1>; select (extensions_present) { case false:
* struct {}; case true: Extension extensions<0..2^16-1>; }; } ClientHello;
* struct {
* ProtocolVersion client_version;
* Random random;
* SessionID session_id;
* CipherSuite cipher_suites<2..2^16-2>;
* CompressionMethod compression_methods<1..2^8-1>;
* select (extensions_present) {
* case false: struct {};
* case true: Extension extensions<0..2^16-1>;
* };
* } ClientHello;
*/
private static Map<Integer, SNIServerName> exploreClientHello(ByteBuffer input)
throws SSLProtocolException {
private static Map<Integer, SNIServerName> exploreClientHello(ByteBuffer input) throws SSLProtocolException {
ignore(input, 2); // ignore version
ignore(input, 32); // ignore random; 32: the length of Random
ignoreByteVector8(input); // ignore session id
ignoreByteVector16(input); // ignore cipher_suites
ignoreByteVector8(input); // ignore compression methods
if (input.remaining() > 0) return exploreExtensions(input);
else return new HashMap<>();
if (input.hasRemaining()) {
return exploreExtensions(input);
} else {
return new HashMap<>();
}
}

/*
* struct { ExtensionType extension_type; opaque extension_data<0..2^16-1>;
* struct {
* ExtensionType extension_type;
* opaque extension_data<0..2^16-1>;
* } Extension;
*
* enum { server_name(0), max_fragment_length(1), client_certificate_url(2),
* trusted_ca_keys(3), truncated_hmac(4), status_request(5), (65535) }
* enum {
* server_name(0),
* max_fragment_length(1),
* client_certificate_url(2),
* trusted_ca_keys(3),
* truncated_hmac(4),
* status_request(5),
* (65535)
* }
* ExtensionType;
*/
private static Map<Integer, SNIServerName> exploreExtensions(ByteBuffer input)
throws SSLProtocolException {
private static Map<Integer, SNIServerName> exploreExtensions(ByteBuffer input) throws SSLProtocolException {
int length = getInt16(input); // length of extensions
while (length > 0) {
int extType = getInt16(input); // extension type
int extLen = getInt16(input); // length of extension data
if (extType == 0x00) { // 0x00: type of server name indication
if (extType == 0x00) {
// 0x00: type of server name indication
return exploreSNIExt(input, extLen);
} else { // ignore other extensions
} else {
// ignore other extensions
ignore(input, extLen);
}
length -= extLen + 4;
@@ -189,49 +209,65 @@ private static Map<Integer, SNIServerName> exploreExtensions(ByteBuffer input)
}

/*
* struct { NameType name_type; select (name_type) { case host_name:
* HostName; } name; } ServerName;
* struct {
* NameType name_type;
* select (name_type) {
* case host_name: HostName;
* } name;
* } ServerName;
*
* enum { host_name(0), (255) } NameType;
* enum {
* host_name(0),
* (255)
* } NameType;
*
* opaque HostName<1..2^16-1>;
*
* struct { ServerName server_name_list<1..2^16-1> } ServerNameList;
* struct {
* ServerName server_name_list<1..2^16-1>
* } ServerNameList;
*/
private static Map<Integer, SNIServerName> exploreSNIExt(ByteBuffer input, int extLen)
throws SSLProtocolException {
private static Map<Integer, SNIServerName> exploreSNIExt(ByteBuffer input, int extLen) throws SSLProtocolException {
Map<Integer, SNIServerName> sniMap = new HashMap<>();
int remains = extLen;
if (extLen >= 2) { // "server_name" extension in ClientHello
if (extLen >= 2) {
// "server_name" extension in ClientHello
int listLen = getInt16(input); // length of server_name_list
if (listLen == 0 || listLen + 2 != extLen)
if (listLen == 0 || listLen + 2 != extLen) {
throw new SSLProtocolException("Invalid server name indication extension");
}
remains -= 2; // 2: the length field of server_name_list
while (remains > 0) {
int code = getInt8(input); // name_type
int snLen = getInt16(input); // length field of server name
if (snLen > remains)
if (snLen > remains) {
throw new SSLProtocolException("Not enough data to fill declared vector size");
}
byte[] encoded = new byte[snLen];
input.get(encoded);
SNIServerName serverName;
if (code == StandardConstants.SNI_HOST_NAME) {
if (encoded.length == 0)
if (encoded.length == 0) {
throw new SSLProtocolException("Empty HostName in server name indication");
}
serverName = new SNIHostName(encoded);
} else {
serverName = new UnknownServerName(code, encoded);
}
// check for duplicated server name type
if (sniMap.put(serverName.getType(), serverName) != null)
if (sniMap.put(serverName.getType(), serverName) != null) {
throw new SSLProtocolException("Duplicated server name of type " + serverName.getType());
}
remains -= encoded.length + 3; // NameType: 1 byte; HostName;
// length: 2 bytesProduced
}
} else if (extLen == 0) { // "server_name" extension in ServerHello
} else if (extLen == 0) {
// "server_name" extension in ServerHello
throw new SSLProtocolException("Not server name indication extension in client");
}
if (remains != 0) throw new SSLProtocolException("Invalid server name indication extension");
if (remains != 0) {
throw new SSLProtocolException("Invalid server name indication extension");
}
return sniMap;
}

@@ -257,7 +293,10 @@ private static void ignoreByteVector16(ByteBuffer input) {

private static void ignore(ByteBuffer input, int length) {
if (length != 0) {
((Buffer) input).position(input.position() + length);
if (input.remaining() < length) {
throw new BufferUnderflowException();
}
input.position(input.position() + length);
}
}

@@ -267,4 +306,4 @@ private static class UnknownServerName extends SNIServerName {
super(code, encoded);
}
}
}
}