Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
import java.util.zip.Deflater;
import java.util.zip.Inflater;

import org.apache.hc.core5.websocket.exceptions.WebSocketProtocolException;

public final class PerMessageDeflateExtension implements WebSocketExtension {

private static final byte[] TAIL = new byte[]{0x00, 0x00, (byte) 0xFF, (byte) 0xFF};
Expand Down Expand Up @@ -76,6 +78,14 @@ public boolean usesRsv1() {

@Override
public ByteBuffer decode(final WebSocketFrameType type, final boolean fin, final ByteBuffer payload) throws WebSocketException {
return decode(type, fin, payload, 0L);
}

@Override
public ByteBuffer decode(final WebSocketFrameType type,
final boolean fin,
final ByteBuffer payload,
final long maxOutputSize) throws WebSocketException {
if (!isDataFrame(type) && type != WebSocketFrameType.CONTINUATION) {
throw new WebSocketException("Unsupported frame type for permessage-deflate: " + type);
}
Expand All @@ -94,14 +104,23 @@ public ByteBuffer decode(final WebSocketFrameType type, final boolean fin, final
inflater.setInput(withTail);
final ByteArrayOutputStream out = new ByteArrayOutputStream(Math.max(128, input.length));
final byte[] buffer = new byte[Math.min(16384, Math.max(1024, input.length * 2))];
long produced = 0L;
try {
while (!inflater.needsInput()) {
final int count = inflater.inflate(buffer);
if (count == 0 && inflater.needsInput()) {
break;
}
// Enforce the decoded size cap during inflation, not after, so a small
// compressed payload cannot expand into a huge buffer before we react.
if (maxOutputSize > 0L && produced + count > maxOutputSize) {
throw new WebSocketProtocolException(1009, "Message too big");
}
out.write(buffer, 0, count);
produced += count;
}
} catch (final WebSocketProtocolException wspe) {
throw wspe;
} catch (final Exception ex) {
throw new WebSocketException("Unable to inflate payload", ex);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,22 @@ default ByteBuffer decode(
return payload;
}

/**
* Decode a frame payload, aborting as soon as the produced output exceeds
* {@code maxOutputSize}. A non-positive limit means no limit. Implementations
* that may expand input (e.g. permessage-deflate) MUST honour the limit during
* the expansion step, not only after it, to prevent decompression-bomb attacks.
*
* @since 5.7
*/
default ByteBuffer decode(
final WebSocketFrameType type,
final boolean fin,
final ByteBuffer payload,
final long maxOutputSize) throws WebSocketException {
return decode(type, fin, payload);
}

default ByteBuffer encode(
final WebSocketFrameType type,
final boolean fin,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,22 +112,23 @@ WebSocketFrame readFrame() throws IOException {
payload[i] = (byte) (payload[i] ^ maskKey[i % 4]);
}
ByteBuffer data = ByteBuffer.wrap(payload);
final long maxOutputSize = config.getMaxMessageSize();
if (rsv1 && rsv1Extension != null) {
data = rsv1Extension.decode(type, fin, data);
data = rsv1Extension.decode(type, fin, data, maxOutputSize);
continuationCompressed = !fin && (type == WebSocketFrameType.TEXT || type == WebSocketFrameType.BINARY);
} else if (type == WebSocketFrameType.CONTINUATION && continuationCompressed && rsv1Extension != null) {
data = rsv1Extension.decode(type, fin, data);
data = rsv1Extension.decode(type, fin, data, maxOutputSize);
if (fin) {
continuationCompressed = false;
}
} else if (type == WebSocketFrameType.CONTINUATION && fin) {
continuationCompressed = false;
}
if (rsv2 && rsv2Extension != null) {
data = rsv2Extension.decode(type, fin, data);
data = rsv2Extension.decode(type, fin, data, maxOutputSize);
}
if (rsv3 && rsv3Extension != null) {
data = rsv3Extension.decode(type, fin, data);
data = rsv3Extension.decode(type, fin, data, maxOutputSize);
}
return new WebSocketFrame(fin, false, false, false, type, data);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
import org.apache.hc.core5.websocket.WebSocketHandler;
import org.apache.hc.core5.websocket.WebSocketHandshake;
import org.apache.hc.core5.websocket.WebSocketSession;
import org.apache.hc.core5.websocket.exceptions.WebSocketProtocolException;

final class WebSocketH2ServerExchangeHandler implements AsyncServerExchangeHandler {

Expand Down Expand Up @@ -160,6 +161,13 @@ public void handleRequest(
try {
handler.onOpen(session);
new WebSocketServerProcessor(session, handler, config.getMaxMessageSize()).process();
} catch (final WebSocketProtocolException ex) {
handler.onError(session, ex);
try {
session.close(ex.closeCode, ex.getMessage());
} catch (final IOException ignore) {
// ignore
}
} catch (final Exception ex) {
handler.onError(session, ex);
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,17 @@
*/
package org.apache.hc.core5.websocket;

import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;

import java.io.ByteArrayOutputStream;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.zip.Deflater;

import org.apache.hc.core5.websocket.exceptions.WebSocketProtocolException;
import org.junit.jupiter.api.Test;

class PerMessageDeflateExtensionTest {
Expand All @@ -56,6 +60,42 @@ void decodesFragmentedMessage() throws Exception {
assertEquals("fragmented message", WebSocketSession.decodeText(ByteBuffer.wrap(joined.toByteArray())));
}

@Test
void decodeWithinLimitSucceeds() throws Exception {
final byte[] plain = "hello world hello world hello world".getBytes(StandardCharsets.UTF_8);
final byte[] compressed = deflateWithSyncFlush(plain);

final PerMessageDeflateExtension ext = new PerMessageDeflateExtension();
final ByteBuffer out = ext.decode(WebSocketFrameType.TEXT, true, ByteBuffer.wrap(compressed), plain.length + 16L);

assertArrayEquals(plain, toBytes(out));
}

@Test
void decodeInflationBombIsRejectedDuringInflate() {
final byte[] plain = new byte[64 * 1024];
Arrays.fill(plain, (byte) 'A');
final byte[] compressed = deflateWithSyncFlush(plain);

final PerMessageDeflateExtension ext = new PerMessageDeflateExtension();
final WebSocketProtocolException ex = assertThrows(WebSocketProtocolException.class,
() -> ext.decode(WebSocketFrameType.BINARY, true, ByteBuffer.wrap(compressed), 1024L));
assertEquals(1009, ex.closeCode);
assertEquals("Message too big", ex.getMessage());
}

@Test
void decodeZeroLimitMeansUnlimited() throws Exception {
final byte[] plain = new byte[8 * 1024];
Arrays.fill(plain, (byte) 'B');
final byte[] compressed = deflateWithSyncFlush(plain);

final PerMessageDeflateExtension ext = new PerMessageDeflateExtension();
final ByteBuffer out = ext.decode(WebSocketFrameType.BINARY, true, ByteBuffer.wrap(compressed), 0L);

assertArrayEquals(plain, toBytes(out));
}

private static byte[] deflateWithSyncFlush(final byte[] input) {
final Deflater deflater = new Deflater(Deflater.DEFAULT_COMPRESSION, true);
deflater.setInput(input);
Expand Down
Loading