From 0bac3e09afd748e485e1b702dd94580798221789 Mon Sep 17 00:00:00 2001 From: Kieran Wallbanks Date: Fri, 13 Oct 2023 17:20:28 +0100 Subject: [PATCH] better logging and coroutine handling --- .../kyori/adventure/webui/jvm/Application.kt | 35 ++-- .../webui/jvm/minimessage/SocketTest.kt | 164 ---------------- .../preview/ServerStatusPreviewManager.kt | 182 ++++++++++++++++++ 3 files changed, 200 insertions(+), 181 deletions(-) delete mode 100644 src/jvmMain/kotlin/net/kyori/adventure/webui/jvm/minimessage/SocketTest.kt create mode 100644 src/jvmMain/kotlin/net/kyori/adventure/webui/jvm/minimessage/preview/ServerStatusPreviewManager.kt diff --git a/src/jvmMain/kotlin/net/kyori/adventure/webui/jvm/Application.kt b/src/jvmMain/kotlin/net/kyori/adventure/webui/jvm/Application.kt index 4ebea5b..0b0bd43 100644 --- a/src/jvmMain/kotlin/net/kyori/adventure/webui/jvm/Application.kt +++ b/src/jvmMain/kotlin/net/kyori/adventure/webui/jvm/Application.kt @@ -1,21 +1,21 @@ package net.kyori.adventure.webui.jvm -import io.ktor.http.* -import io.ktor.http.content.* -import io.ktor.network.selector.* -import io.ktor.network.sockets.* -import io.ktor.server.application.* -import io.ktor.server.plugins.cachingheaders.* -import io.ktor.server.plugins.compression.* -import io.ktor.server.routing.* -import io.ktor.server.websocket.* -import io.ktor.utils.io.* -import io.ktor.websocket.* -import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.launch -import kotlinx.coroutines.runBlocking -import net.kyori.adventure.webui.jvm.minimessage.SocketTest -import okhttp3.internal.and +import io.ktor.http.CacheControl +import io.ktor.http.ContentType +import io.ktor.http.content.CachingOptions +import io.ktor.server.application.Application +import io.ktor.server.application.install +import io.ktor.server.application.log +import io.ktor.server.plugins.cachingheaders.CachingHeaders +import io.ktor.server.plugins.compression.Compression +import io.ktor.server.plugins.compression.deflate +import io.ktor.server.plugins.compression.gzip +import io.ktor.server.routing.routing +import io.ktor.server.websocket.WebSockets +import io.ktor.server.websocket.pingPeriod +import io.ktor.server.websocket.timeout +import io.ktor.websocket.WebSocketDeflateExtension +import net.kyori.adventure.webui.jvm.minimessage.preview.ServerStatusPreviewManager import java.time.Duration public fun Application.main() { @@ -50,7 +50,8 @@ public fun Application.main() { } } - SocketTest().main() + // Initialise the server status preview manager. + ServerStatusPreviewManager(this) } diff --git a/src/jvmMain/kotlin/net/kyori/adventure/webui/jvm/minimessage/SocketTest.kt b/src/jvmMain/kotlin/net/kyori/adventure/webui/jvm/minimessage/SocketTest.kt deleted file mode 100644 index a54e898..0000000 --- a/src/jvmMain/kotlin/net/kyori/adventure/webui/jvm/minimessage/SocketTest.kt +++ /dev/null @@ -1,164 +0,0 @@ -package net.kyori.adventure.webui.jvm.minimessage - -import io.ktor.network.selector.* -import io.ktor.network.sockets.* -import io.ktor.utils.io.* -import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.launch -import kotlinx.coroutines.runBlocking -import kotlinx.coroutines.withContext -import net.kyori.adventure.text.minimessage.MiniMessage -import net.kyori.adventure.text.serializer.gson.GsonComponentSerializer -import net.kyori.adventure.text.serializer.legacy.LegacyComponentSerializer -import okhttp3.internal.and -import java.io.ByteArrayOutputStream -import java.io.DataOutputStream -import kotlin.text.Charsets.UTF_8 - -public class SocketTest { - - public fun main() { - // TODO - // 1. make this non blocking somehow, idk how kotlin works - // 2. add api/ui to store into some cache - // 3. parse server address to get stuff from the cache and return that in the status response - runBlocking { - val serverSocket = aSocket(SelectorManager(Dispatchers.IO)).tcp().bind("127.0.0.1", 9002) - println("Server is listening at ${serverSocket.localAddress}") - while (true) { - val socket = serverSocket.accept() - println("Accepted ${socket.remoteAddress}") - launch { - try { - val receiveChannel = socket.openReadChannel() - val sendChannel = socket.openWriteChannel(autoFlush = true) - - // handshake - val handshakePacket = receiveChannel.readMcPacket() - val protocolVersion = handshakePacket.readVarInt() - val serverAddress = handshakePacket.readUtf8String() - val serverPort = handshakePacket.readShort() - val nextState = handshakePacket.readVarInt() - - if (nextState != 1) { - // send kick - sendChannel.writeMcPacket(0) { - it.writeString( - GsonComponentSerializer.gson() - .serialize(MiniMessage.miniMessage().deserialize("You cant join here!")) - ) - } - } else { - // send status response - sendChannel.writeMcPacket(0) { - it.writeString( - """{ - "version": { - "name": "${ - LegacyComponentSerializer.legacySection() - .serialize(MiniMessage.miniMessage().deserialize("MiniMessage")) - }", - "protocol": 762 - }, - "description": ${ - GsonComponentSerializer.gson().serialize( - MiniMessage.miniMessage().deserialize("MiniMessage is cool!") - ) - } - }""".trimIndent() - ) - } - } - - sendChannel.close() - } catch (e: Exception) { - println(e) - } - - socket.close() - return@launch - } - } - } - } -} - -public suspend fun ByteWriteChannel.writeMcPacket(packetId: Int, consumer: (packet: DataOutputStream) -> Unit) { - val stream = ByteArrayOutputStream() - val packet = DataOutputStream(stream) - - consumer.invoke(packet) - - val data = stream.toByteArray() - writeVarInt(data.size + 1) - writeVarInt(packetId) - writeFully(data) -} - -public fun DataOutputStream.writeString(string: String) { - val bytes = string.toByteArray(UTF_8) - writeVarInt(bytes.size) - write(bytes) -} - -public fun DataOutputStream.writeVarInt(int: Int) { - var value = int - while (true) { - if ((value and 0x7F.inv()) == 0) { - writeByte(value) - return - } - - writeByte((value and 0x7F) or 0x80) - - value = value ushr 7 - } -} - -public suspend fun ByteWriteChannel.writeVarInt(int: Int) { - var value = int - while (true) { - if ((value and 0x7F.inv()) == 0) { - writeByte(value) - return - } - - writeByte((value and 0x7F) or 0x80) - - value = value ushr 7 - } -} - -public suspend fun ByteReadChannel.readMcPacket(): ByteReadChannel { - val length = readVarInt() - val packetId = readVarInt() - val data = ByteArray(length) - readFully(data, 0, length) - return ByteReadChannel(data) -} - -public suspend fun ByteReadChannel.readVarInt(): Int { - var value = 0 - var position = 0 - var currentByte: Byte - - while (true) { - currentByte = readByte() - value = value or ((currentByte and 0x7F) shl position) - - if ((currentByte and 0x80) == 0) break - - position += 7 - - if (position >= 32) throw RuntimeException("VarInt is too big") - } - - return value -} - -public suspend fun ByteReadChannel.readUtf8String(): String { - val length = readVarInt() - val data = ByteArray(length) - readFully(data, 0, length) - return String(data) -} diff --git a/src/jvmMain/kotlin/net/kyori/adventure/webui/jvm/minimessage/preview/ServerStatusPreviewManager.kt b/src/jvmMain/kotlin/net/kyori/adventure/webui/jvm/minimessage/preview/ServerStatusPreviewManager.kt new file mode 100644 index 0000000..20fc4f4 --- /dev/null +++ b/src/jvmMain/kotlin/net/kyori/adventure/webui/jvm/minimessage/preview/ServerStatusPreviewManager.kt @@ -0,0 +1,182 @@ +package net.kyori.adventure.webui.jvm.minimessage.preview + +import io.ktor.network.selector.SelectorManager +import io.ktor.network.sockets.aSocket +import io.ktor.network.sockets.openReadChannel +import io.ktor.network.sockets.openWriteChannel +import io.ktor.server.application.Application +import io.ktor.utils.io.ByteReadChannel +import io.ktor.utils.io.ByteWriteChannel +import io.ktor.utils.io.close +import io.ktor.utils.io.writeByte +import io.ktor.utils.io.writeFully +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.ensureActive +import kotlinx.coroutines.job +import kotlinx.coroutines.launch +import net.kyori.adventure.text.minimessage.MiniMessage +import net.kyori.adventure.text.serializer.gson.GsonComponentSerializer +import net.kyori.adventure.text.serializer.legacy.LegacyComponentSerializer +import okhttp3.internal.and +import org.slf4j.LoggerFactory +import java.io.ByteArrayOutputStream +import java.io.DataOutputStream +import kotlin.coroutines.CoroutineContext + +/** Manager class for previewing server status. */ +public class ServerStatusPreviewManager( + application: Application, +) : CoroutineScope { + + private val logger = LoggerFactory.getLogger(ServerStatusPreviewManager::class.java) + private val managerJob = SupervisorJob(application.coroutineContext.job) + override val coroutineContext: CoroutineContext = application.coroutineContext + managerJob + + init { + launch { + // Initialise the socket. + val serverSocket = aSocket(SelectorManager(Dispatchers.IO)).tcp().bind("127.0.0.1", 9002) + logger.info("Listening for pings at ${serverSocket.localAddress}") + + while (true) { + // Ensure we are active so that the socket is properly closed when the application ends. + ensureActive() + + val socket = serverSocket.accept() + logger.debug("Accepted socket connection from {}", socket.remoteAddress) + + launch { + try { + val receiveChannel = socket.openReadChannel() + val sendChannel = socket.openWriteChannel(autoFlush = true) + + // handshake + val handshakePacket = receiveChannel.readMcPacket() + val protocolVersion = handshakePacket.readVarInt() + val serverAddress = handshakePacket.readUtf8String() + val serverPort = handshakePacket.readShort() + val nextState = handshakePacket.readVarInt() + + if (nextState != 1) { + // send kick + sendChannel.writeMcPacket(0) { + it.writeString( + GsonComponentSerializer.gson() + .serialize(MiniMessage.miniMessage().deserialize("You cant join here!")) + ) + } + } else { + // send status response + sendChannel.writeMcPacket(0) { + it.writeString( + """{ + "version": { + "name": "${ + LegacyComponentSerializer.legacySection() + .serialize(MiniMessage.miniMessage().deserialize("MiniMessage")) + }", + "protocol": $protocolVersion + }, + "description": ${ + GsonComponentSerializer.gson().serialize( + MiniMessage.miniMessage().deserialize("MiniMessage is cool!") + ) + } + }""".trimIndent() + ) + } + } + + sendChannel.close() + } catch (e: Exception) { + logger.error("An unknown error occurred whilst responding to a ping from ${socket.remoteAddress}", e) + } + + socket.close() + } + } + } + } + + private suspend fun ByteWriteChannel.writeMcPacket(packetId: Int, consumer: (packet: DataOutputStream) -> Unit) { + val stream = ByteArrayOutputStream() + val packet = DataOutputStream(stream) + + consumer.invoke(packet) + + val data = stream.toByteArray() + writeVarInt(data.size + 1) + writeVarInt(packetId) + writeFully(data) + } + + private fun DataOutputStream.writeString(string: String) { + val bytes = string.toByteArray(Charsets.UTF_8) + writeVarInt(bytes.size) + write(bytes) + } + + private fun DataOutputStream.writeVarInt(int: Int) { + var value = int + while (true) { + if ((value and 0x7F.inv()) == 0) { + writeByte(value) + return + } + + writeByte((value and 0x7F) or 0x80) + + value = value ushr 7 + } + } + + private suspend fun ByteWriteChannel.writeVarInt(int: Int) { + var value = int + while (true) { + if ((value and 0x7F.inv()) == 0) { + writeByte(value) + return + } + + writeByte((value and 0x7F) or 0x80) + + value = value ushr 7 + } + } + + private suspend fun ByteReadChannel.readMcPacket(): ByteReadChannel { + val length = readVarInt() + val packetId = readVarInt() + val data = ByteArray(length) + readFully(data, 0, length) + return ByteReadChannel(data) + } + + private suspend fun ByteReadChannel.readVarInt(): Int { + var value = 0 + var position = 0 + var currentByte: Byte + + while (true) { + currentByte = readByte() + value = value or ((currentByte and 0x7F) shl position) + + if ((currentByte and 0x80) == 0) break + + position += 7 + + if (position >= 32) throw RuntimeException("VarInt is too big") + } + + return value + } + + private suspend fun ByteReadChannel.readUtf8String(): String { + val length = readVarInt() + val data = ByteArray(length) + readFully(data, 0, length) + return String(data) + } +}