diff --git a/src/protocols/DisconnectionOrigin.ts b/src/protocols/DisconnectionOrigin.ts index 6d0530bfb..05600c67d 100644 --- a/src/protocols/DisconnectionOrigin.ts +++ b/src/protocols/DisconnectionOrigin.ts @@ -2,4 +2,5 @@ export enum DisconnectionOrigin { WEBSOCKET_AUTH_RENEWAL = "websocket/auth-renewal", USER_CONNECTION_CLOSED = "user/connection-closed", NETWORK_ERROR = "network/error", + PAYLOAD_MAX_SIZE_EXCEEDED = "payload/max-size-exceeded", } diff --git a/src/protocols/WebSocket.ts b/src/protocols/WebSocket.ts index 8e52fcbe6..06690f73d 100644 --- a/src/protocols/WebSocket.ts +++ b/src/protocols/WebSocket.ts @@ -6,6 +6,7 @@ import { JSONObject } from "../types"; import { RequestPayload } from "../types/RequestPayload"; import HttpProtocol from "./Http"; import { DisconnectionOrigin } from "./DisconnectionOrigin"; +import { parseSize } from "../utils/parseSize"; /** * WebSocket protocol used to connect to a Kuzzle server. @@ -15,6 +16,7 @@ export default class WebSocketProtocol extends BaseProtocolRealtime { private options: any; private client: any; private lasturl: any; + private maxPayloadSize: number; private ping: any; private waitForPong: boolean; private pingIntervalId: ReturnType; @@ -77,6 +79,7 @@ export default class WebSocketProtocol extends BaseProtocolRealtime { typeof options.pingInterval === "number" ? options.pingInterval : 2000; this.client = null; this.lasturl = null; + this.maxPayloadSize = null; } /** @@ -113,11 +116,13 @@ export default class WebSocketProtocol extends BaseProtocolRealtime { }); } - this.client.onopen = () => { + this.client.onopen = async () => { this.clientConnected(); this.setupPingPong(); + await this.getMaxPayloadSize(); + return resolve(); }; @@ -238,6 +243,20 @@ export default class WebSocketProtocol extends BaseProtocolRealtime { * @param {Object} payload */ send(request: RequestPayload, options: JSONObject = {}) { + if ( + this.maxPayloadSize !== null && + Buffer.byteLength(JSON.stringify(request), "utf8") > this.maxPayloadSize + ) { + const error: any = new Error( + `Payload size exceeded the maximum allowed by the server ${this.maxPayloadSize} bytes` + ); + + this.emit("networkError", { error }); + this.clientDisconnected(DisconnectionOrigin.PAYLOAD_MAX_SIZE_EXCEEDED); + + return; + } + if (!this.client || this.client.readyState !== this.client.OPEN) { return; } @@ -342,7 +361,7 @@ export default class WebSocketProtocol extends BaseProtocolRealtime { // If we were waiting for a pong that never occured before the next ping cycle we throw an error if (this.waitForPong) { const error: any = new Error( - "Kuzzle does'nt respond to ping. Connection lost." + "Kuzzle doesn't respond to ping. Connection lost." ); error.status = 503; @@ -359,4 +378,37 @@ export default class WebSocketProtocol extends BaseProtocolRealtime { } }, this._pingInterval); } + /** + * Get the maximum payload size allowed by the server + * Stores the value in `this.maxPayloadSize` + **/ + async getMaxPayloadSize() { + return new Promise((resolve, reject) => { + const originalOnMessage = this.client.onmessage; + this.client.onmessage = (payload) => { + try { + const data = JSON.parse(payload.data || payload); + + // Check if the message corresponds to the `getMaxPayloadSize` response + if ( + data.result && + data.result.server && + data.result.server.maxRequestSize + ) { + this.maxPayloadSize = parseSize(data.result.server.maxRequestSize); + + // Restore the original `onmessage` handler + this.client.onmessage = originalOnMessage; + resolve(this.maxPayloadSize); + return; + } + } catch (error) { + reject(error); + } + }; + + // Send the request + this.send({ action: "getConfig", controller: "server" }); + }); + } } diff --git a/src/utils/parseSize.js b/src/utils/parseSize.js new file mode 100644 index 000000000..bd5fd2551 --- /dev/null +++ b/src/utils/parseSize.js @@ -0,0 +1,39 @@ +function parseSize(size) { + const units = { + b: 1, // Bytes + kb: 1024, // Kilobytes + mb: 1024 * 1024, // Megabytes + // eslint-disable-next-line sort-keys + gb: 1024 * 1024 * 1024, // Gigabytes + }; + + if (typeof size !== "string") { + throw new Error( + `Invalid size input: expected a string, got ${typeof size}` + ); + } + + // Extract numeric value and unit + const match = /^(\d+(?:\.\d+)?)(b|kb|mb|gb)?$/i.exec(size.trim()); + if (!match) { + throw new Error( + `Invalid size format: "${size}". Expected formats like "2MB", "500KB", "3GB", etc.` + ); + } + + const value = parseFloat(match[1]); // Get the numeric part + const unit = (match[2] || "b").toLowerCase(); // Default to bytes if no unit is provided + + if (!units[unit]) { + throw new Error( + `Unknown size unit: "${unit}". Allowed units are B, KB, MB, GB.` + ); + } + + // Convert to bytes + return value * units[unit]; +} + +module.exports = { + parseSize, +};