diff --git a/ktor-network/jvm/src/io/ktor/network/selector/SelectorManager.kt b/ktor-network/jvm/src/io/ktor/network/selector/SelectorManager.kt index b813ae36dd0..fad8c7aab29 100644 --- a/ktor-network/jvm/src/io/ktor/network/selector/SelectorManager.kt +++ b/ktor-network/jvm/src/io/ktor/network/selector/SelectorManager.kt @@ -77,5 +77,9 @@ public actual enum class SelectInterest(public val flag: Int) { public val flags: IntArray = values().map { it.flag }.toIntArray() public val size: Int = values().size + + private val byFlag = values().associateBy { it.flag } + + public fun byValue(value: Int): SelectInterest = byFlag[value] ?: error("Unknown SelectInterest value: $value") } } diff --git a/ktor-network/jvm/src/io/ktor/network/selector/eventgroup/Attachment.kt b/ktor-network/jvm/src/io/ktor/network/selector/eventgroup/Attachment.kt new file mode 100644 index 00000000000..1e7e33d14ff --- /dev/null +++ b/ktor-network/jvm/src/io/ktor/network/selector/eventgroup/Attachment.kt @@ -0,0 +1,63 @@ +/* + * Copyright 2014-2024 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package io.ktor.network.selector.eventgroup + +import kotlinx.coroutines.suspendCancellableCoroutine +import java.nio.channels.SelectionKey + +internal inline val SelectionKey.attachment get() = attachment() as Attachment + +/** + * Attachment for SelectionKey + * It contains task for each interest and allows to run them and resume the continuation + */ +internal class Attachment { + private var acceptTask: Task? = null + private var readTask: Task? = null + private var writeTask: Task? = null + private var connectTask: Task? = null + + suspend fun runTask(interest: Int, task: suspend () -> T): T { + return suspendCancellableCoroutine { + @Suppress("UNCHECKED_CAST") + setContinuationByInterest(interest, Task(it.toResumableCancellable(), task) as Task) + } + } + + suspend fun runTaskAndResumeContinuation(key: SelectionKey) { + when { + key.isAcceptable -> acceptTask.runAndResume(SelectionKey.OP_ACCEPT) + key.isReadable -> readTask.runAndResume(SelectionKey.OP_READ) + key.isWritable -> writeTask.runAndResume(SelectionKey.OP_WRITE) + key.isConnectable -> connectTask.runAndResume(SelectionKey.OP_CONNECT) + } + } + + private suspend fun Task?.runAndResume(interest: Int) { + val task = this ?: return + setContinuationByInterest(interest, null) + task.runAndResume() + } + + private fun setContinuationByInterest(interest: Int, task: Task?) { + when (interest) { + SelectionKey.OP_ACCEPT -> acceptTask = task + SelectionKey.OP_READ -> readTask = task + SelectionKey.OP_WRITE -> writeTask = task + SelectionKey.OP_CONNECT -> connectTask = task + } + } + + fun cancel(cause: Throwable? = null) { + acceptTask.cancel(cause) + readTask.cancel(cause) + writeTask.cancel(cause) + connectTask.cancel(cause) + } + + private fun Task<*>?.cancel(cause: Throwable? = null) { + this?.continuation?.cancel(cause) + } +} diff --git a/ktor-network/jvm/src/io/ktor/network/selector/eventgroup/Connection.kt b/ktor-network/jvm/src/io/ktor/network/selector/eventgroup/Connection.kt new file mode 100644 index 00000000000..b1f35af8db3 --- /dev/null +++ b/ktor-network/jvm/src/io/ktor/network/selector/eventgroup/Connection.kt @@ -0,0 +1,22 @@ +/* + * Copyright 2014-2024 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package io.ktor.network.selector.eventgroup + +import java.nio.channels.SocketChannel + +/** + * Allows to perform read and write operations on the socket channel, + * which will be submitted as tasks to the event loop and will be suspended until + * they will be executed in the context of the event loop + */ +internal interface Connection { + val channel: SocketChannel + + suspend fun performRead(body: suspend (SocketChannel) -> T): T + + suspend fun performWrite(body: suspend (SocketChannel) -> T): T + + fun close() +} diff --git a/ktor-network/jvm/src/io/ktor/network/selector/eventgroup/EventGroupContext.kt b/ktor-network/jvm/src/io/ktor/network/selector/eventgroup/EventGroupContext.kt new file mode 100644 index 00000000000..0e763038224 --- /dev/null +++ b/ktor-network/jvm/src/io/ktor/network/selector/eventgroup/EventGroupContext.kt @@ -0,0 +1,28 @@ +/* + * Copyright 2014-2024 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package io.ktor.network.selector.eventgroup + +import io.ktor.utils.io.* +import kotlin.coroutines.* + +private val MAX_THREADS by lazy { + Runtime.getRuntime().availableProcessors() + .minus(2) + .coerceAtLeast(1) +} + +@InternalAPI +public class EventGroupContext( + public val parallelism: Int, +) : CoroutineContext.Element { + override val key: CoroutineContext.Key<*> = Key + + public companion object Key : CoroutineContext.Key +} + +@InternalAPI +internal fun CoroutineContext.eventGroupParallelism(): Int { + return get(EventGroupContext.Key)?.parallelism ?: MAX_THREADS +} diff --git a/ktor-network/jvm/src/io/ktor/network/selector/eventgroup/EventGroupSelectorManager.kt b/ktor-network/jvm/src/io/ktor/network/selector/eventgroup/EventGroupSelectorManager.kt new file mode 100644 index 00000000000..168cbd38a4e --- /dev/null +++ b/ktor-network/jvm/src/io/ktor/network/selector/eventgroup/EventGroupSelectorManager.kt @@ -0,0 +1,109 @@ +/* + * Copyright 2014-2024 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package io.ktor.network.selector.eventgroup + +import io.ktor.network.selector.* +import io.ktor.utils.io.* +import kotlinx.atomicfu.* +import kotlinx.coroutines.* +import java.nio.channels.* +import java.nio.channels.spi.* +import kotlin.coroutines.* + +@OptIn(InternalAPI::class) +public class EventGroupSelectorManager(context: CoroutineContext) : SelectorManager { + public val group: EventGroup = EventGroup(context.eventGroupParallelism()) + + override val coroutineContext: CoroutineContext = context + CoroutineName("eventgroup") + + override val provider: SelectorProvider = SelectorProvider.provider() + + override fun notifyClosed(selectable: Selectable) { + // whatever + } + + override suspend fun select(selectable: Selectable, interest: SelectInterest) { + error("no select in eventgroup") + } + + override fun close() { + group.close() + } +} + +public class EventGroup(private val maxLoops: Int) { + private val acceptLoop = Eventloop() + private val loopIndex = atomic(0) + private val loops = mutableListOf() + + init { + acceptLoop.run() + + repeat(maxLoops - 1) { + val next = Eventloop().apply { run() } + loops.add(next) + } + } + + private fun registerAcceptKey(channel: Selectable) = acceptLoop.runOnLoop { + acceptLoop.addInterest(channel, SelectionKey.OP_ACCEPT) + } + + internal fun registerChannel(channel: ServerSocketChannel): RegisteredServerChannel { + val selectableChannel = SelectableBase(channel) + val key = registerAcceptKey(selectableChannel) + + return RegisteredServerChannelImpl(channel, key) + } + + private inner class RegisteredServerChannelImpl( + override val channel: ServerSocketChannel, + private val key: CompletableDeferred, + ) : RegisteredServerChannel { + override suspend fun acceptConnection(configure: (SocketChannel) -> Unit): ConnectionImpl { + val result = key.await().attachment.runTask(SelectionKey.OP_ACCEPT) { + channel.accept().apply { + configureBlocking(false) + configure(this) + } + } + + val nextLoopIndex = loopIndex.getAndIncrement() % (maxLoops - 1) + + return ConnectionImpl(result, loops[nextLoopIndex]) + } + } + + private class ConnectionImpl( + override val channel: SocketChannel, + val loop: Eventloop, + ) : Connection { + private val selectable = SelectableBase(channel) + + override suspend fun performRead(body: suspend (SocketChannel) -> T): T { + return runTask(SelectionKey.OP_READ) { body(channel) } + } + + override suspend fun performWrite(body: suspend (SocketChannel) -> T): T { + return runTask(SelectionKey.OP_WRITE) { body(channel) } + } + + override fun close() { + channel.close() + } + + private suspend fun runTask(interest: Int, body: suspend () -> T): T { + val key = loop.addInterest(selectable, interest) + return key.attachment.runTask(interest, body).also { + loop.deleteInterest(selectable, interest) + } + } + } + + public fun close() { + acceptLoop.close(null) + loops.forEach { it.close(null) } + } +} diff --git a/ktor-network/jvm/src/io/ktor/network/selector/eventgroup/Eventloop.kt b/ktor-network/jvm/src/io/ktor/network/selector/eventgroup/Eventloop.kt new file mode 100644 index 00000000000..015f95a9472 --- /dev/null +++ b/ktor-network/jvm/src/io/ktor/network/selector/eventgroup/Eventloop.kt @@ -0,0 +1,104 @@ +/* + * Copyright 2014-2024 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package io.ktor.network.selector.eventgroup + +import io.ktor.network.selector.* +import kotlinx.coroutines.* +import java.nio.channels.* + +internal class Eventloop { + val scope = newThreadContext(nThreads = 1).wrapInScope() + + fun run(): Job { + return scope.launch { runLoop() } + } + + private val taskQueue = ArrayDeque>() + + private val selector = Selector.open() + + fun close(cause: Throwable?) { + taskQueue.forEach { it.continuation.cancel(cause) } + selector.close() + } + + private suspend fun runLoop() { + while (true) { + runAllPendingTasks() + + val n = selector.select(SELECTOR_TIMEOUT_MILLIS) + yield() + + if (n == 0) { + continue + } + + val selectionKeys = selector.selectedKeys().iterator() + while (selectionKeys.hasNext()) { + val key = selectionKeys.next() + selectionKeys.remove() + + try { + if (!key.isValid) continue + key.attachment.runTaskAndResumeContinuation(key) + } catch (e: Throwable) { + key.channel().close() + key.attachment.cancel(e) + } + } + } + } + + private suspend fun runAllPendingTasks() { + repeat(taskQueue.size) { + taskQueue.removeFirst().runAndResume() + } + } + + internal fun runOnLoop(body: suspend () -> T): CompletableDeferred { + val result = CompletableDeferred() + taskQueue.addLast(Task(result.toResumableCancellable(), body)) + return result + } + + internal fun addInterest(selectable: Selectable, interest: Int): SelectionKey { + val channel = selectable.channel + val key = channel.keyFor(selector) + selectable.interestOp(SelectInterest.byValue(interest), true) + val ops = selectable.interestedOps + + if (key == null) { + if (ops != 0) { + channel.register(selector, ops, Attachment()) + } + } else { + if (key.interestOps() != ops) { + key.interestOps(ops) + } + } + return key + } + + internal fun deleteInterest(selectable: Selectable, interest: Int) { + val channel = selectable.channel + val key = channel.keyFor(selector) + selectable.interestOp(SelectInterest.byValue(interest), false) + val ops = selectable.interestedOps + + if (key == null) { + if (ops != 0) { + channel.register(selector, ops, Attachment()) + } + } else { + if (key.interestOps() != ops) { + key.interestOps(ops) + } + } + } + + companion object { + private const val SELECTOR_TIMEOUT_MILLIS = 20L + } +} diff --git a/ktor-network/jvm/src/io/ktor/network/selector/eventgroup/RegisteredServerChannel.kt b/ktor-network/jvm/src/io/ktor/network/selector/eventgroup/RegisteredServerChannel.kt new file mode 100644 index 00000000000..a584e6c61cb --- /dev/null +++ b/ktor-network/jvm/src/io/ktor/network/selector/eventgroup/RegisteredServerChannel.kt @@ -0,0 +1,21 @@ +/* + * Copyright 2014-2024 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package io.ktor.network.selector.eventgroup + +import java.net.Socket +import java.nio.channels.ServerSocketChannel +import java.nio.channels.SocketChannel + +/** + * Represents a server channel registered to an event loop with OP_ACCEPT interest + */ +internal interface RegisteredServerChannel { + val channel: ServerSocketChannel + + /** + * Allows to accept connections on the server socket channel + */ + suspend fun acceptConnection(configure: (SocketChannel) -> Unit = {}): Connection +} diff --git a/ktor-network/jvm/src/io/ktor/network/selector/eventgroup/ServerConnectionBasedSocket.kt b/ktor-network/jvm/src/io/ktor/network/selector/eventgroup/ServerConnectionBasedSocket.kt new file mode 100644 index 00000000000..ce945acd02a --- /dev/null +++ b/ktor-network/jvm/src/io/ktor/network/selector/eventgroup/ServerConnectionBasedSocket.kt @@ -0,0 +1,427 @@ +/* + * Copyright 2014-2024 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package io.ktor.network.selector.eventgroup + +import io.ktor.network.selector.* +import io.ktor.network.sockets.* +import io.ktor.network.util.* +import io.ktor.utils.io.* +import io.ktor.utils.io.ByteChannel +import io.ktor.utils.io.pool.* +import kotlinx.coroutines.* +import java.nio.* +import java.nio.channels.* +import java.util.concurrent.atomic.* +import kotlin.coroutines.* + +internal class ServerConnectionBasedSocket( + connection: Connection, + selector: SelectorManager, + socketOptions: SocketOptions.TCPClientSocketOptions? = null +) : TSocketImpl(connection, connection.channel, selector, pool = null, socketOptions = socketOptions), + Socket { + init { + require(!channel.isBlocking) { "Channel need to be configured as non-blocking." } + } + + override val localAddress: SocketAddress + get() { + val localAddress = if (java7NetworkApisAvailable) { + channel.localAddress + } else { + channel.socket().localSocketAddress + } + return localAddress?.toSocketAddress() + ?: throw IllegalStateException("Channel is not yet bound") + } + + override val remoteAddress: SocketAddress + get() { + val remoteAddress = if (java7NetworkApisAvailable) { + channel.remoteAddress + } else { + channel.socket().remoteSocketAddress + } + return remoteAddress?.toSocketAddress() + ?: throw IllegalStateException("Channel is not yet connected") + } +} + +internal abstract class TSocketImpl( + val connection: Connection, + override val channel: S, + val selector: SelectorManager, + val pool: ObjectPool?, + private val socketOptions: SocketOptions.TCPClientSocketOptions? = null +) : ReadWriteSocket, SelectableBase(channel), CoroutineScope + where S : java.nio.channels.ByteChannel, S : SelectableChannel { + + private val closeFlag = AtomicBoolean() + + @Suppress("DEPRECATION") + private val readerJob = AtomicReference() + + @Suppress("DEPRECATION") + private val writerJob = AtomicReference() + + override val socketContext: CompletableJob = Job() + + override val coroutineContext: CoroutineContext + get() = socketContext + + // NOTE: it is important here to use different versions of attachForReadingImpl + // because it is not always valid to use channel's internal buffer for NIO read/write: + // at least UDP datagram reading MUST use bigger byte buffer otherwise datagram could be truncated + // that will cause broken data + // however it is not the case for attachForWriting this is why we use direct writing in any case + + @Suppress("DEPRECATION") + final override fun attachForReading(channel: io.ktor.utils.io.ByteChannel): WriterJob { + return attachFor("reading", channel, writerJob) { + if (pool != null) { + attachForReadingImplE(channel, connection, pool, socketOptions) + } else { + attachForReadingDirectImplE(channel, connection, socketOptions) + } + } + } + + @Suppress("DEPRECATION") + final override fun attachForWriting(channel: io.ktor.utils.io.ByteChannel): ReaderJob { + return attachFor("writing", channel, readerJob) { + attachForWritingDirectImplE(channel, connection, socketOptions) + } + } + + override fun dispose() { + close() + } + + override fun close() { + if (!closeFlag.compareAndSet(false, true)) return + + readerJob.get()?.channel?.close() + writerJob.get()?.cancel() + checkChannels() + } + + @Suppress("DEPRECATION") + private fun attachFor( + name: String, + channel: io.ktor.utils.io.ByteChannel, + ref: AtomicReference, + producer: () -> J + ): J { + if (closeFlag.get()) { + val e = ClosedChannelException() + channel.close(e) + throw e + } + + val j = producer() + + if (!ref.compareAndSet(null, j)) { + val e = IllegalStateException("$name channel has already been set") + j.cancel() + throw e + } + if (closeFlag.get()) { + val e = ClosedChannelException() + j.cancel() + channel.close(e) + throw e + } + + channel.attachJob(j) + + j.invokeOnCompletion { + checkChannels() + } + + return j + } + + private fun actualClose(): Throwable? { + return try { + channel.close() + super.close() + null + } catch (cause: Throwable) { + cause + } finally { + selector.notifyClosed(this) + } + } + + private fun checkChannels() { + if (closeFlag.get() && readerJob.completedOrNotStarted && writerJob.completedOrNotStarted) { + val e1 = readerJob.exception + val e2 = writerJob.exception + val e3 = actualClose() + + val combined = combine(combine(e1, e2), e3) + + if (combined == null) socketContext.complete() else socketContext.completeExceptionally(combined) + } + } + + private fun combine(e1: Throwable?, e2: Throwable?): Throwable? = when { + e1 == null -> e2 + e2 == null -> e1 + e1 === e2 -> e1 + else -> { + e1.addSuppressed(e2) + e1 + } + } + + private val AtomicReference.completedOrNotStarted: Boolean + get() = get().let { it == null || it.isCompleted } + + @OptIn(InternalCoroutinesApi::class) + private val AtomicReference.exception: Throwable? + get() = get()?.takeIf { it.isCancelled } + ?.getCancellationException()?.cause // TODO it should be completable deferred or provide its own exception +} + +@Suppress("DEPRECATION") +internal fun CoroutineScope.attachForReadingImplE( + channel: ByteChannel, + connection: Connection, + pool: ObjectPool, + socketOptions: SocketOptions.TCPClientSocketOptions? = null +): WriterJob { + val buffer = pool.borrow() + return writer(Dispatchers.Unconfined + CoroutineName("cio-from-nio-reader"), channel) { + try { + val timeout = if (socketOptions?.socketTimeout != null) { + createTimeout("reading", socketOptions.socketTimeout) { + channel.close(SocketTimeoutException()) + } + } else { + null + } + + while (true) { + var rc = 0 + + timeout.withTimeout { + do { + rc = connection.readToE(buffer) + if (rc == 0) { + channel.flush() + } + } while (rc == 0) + } + + if (rc == -1) { + channel.close() + break + } else { + buffer.flip() + channel.writeFully(buffer) + buffer.clear() + } + } + timeout?.finish() + } finally { + pool.recycle(buffer) + try { + if (java7NetworkApisAvailable) { + connection.channel.shutdownInput() + } else { + connection.channel.socket().shutdownInput() + } + } catch (ignore: ClosedChannelException) { + } + } + } +} + +@Suppress("DEPRECATION") +internal fun CoroutineScope.attachForReadingDirectImplE( + channel: ByteChannel, + connection: Connection, + socketOptions: SocketOptions.TCPClientSocketOptions? = null +): WriterJob = writer(Dispatchers.Unconfined + CoroutineName("cio-from-nio-reader"), channel) { + try { + val timeout = if (socketOptions?.socketTimeout != null) { + createTimeout("reading-direct", socketOptions.socketTimeout) { + channel.close(SocketTimeoutException()) + } + } else { + null + } + + while (!channel.isClosedForWrite) { + timeout.withTimeout { + val rc = channel.readFromE(connection) + + if (rc == -1) { + channel.close() + return@withTimeout + } + + if (rc > 0) return@withTimeout + + channel.flush() + + while (true) { + if (channel.readFromE(connection) != 0) break + } + } + } + + timeout?.finish() + channel.closedCause?.let { throw it } + channel.close() + } finally { + try { + if (java7NetworkApisAvailable) { + connection.channel.shutdownInput() + } else { + connection.channel.socket().shutdownInput() + } + } catch (ignore: ClosedChannelException) { + } + } +} + +private suspend fun ByteWriteChannel.readFromE(connection: Connection): Int { + var count = 0 + connection.performRead { channel -> + write { buffer -> + count = channel.read(buffer) + } + } + + return count +} + +private suspend fun Connection.readToE(receivedRequest: ByteBuffer): Int { + return performRead { + it.read(receivedRequest) + } +} + + + +//_-------------------------------------------------------- +@Suppress("DEPRECATION") +internal fun CoroutineScope.attachForWritingImplE( + channel: ByteChannel, + nioChannel: WritableByteChannel, + selectable: Selectable, + selector: SelectorManager, + pool: ObjectPool, + socketOptions: SocketOptions.TCPClientSocketOptions? = null +): ReaderJob { + val buffer = pool.borrow() + + return reader(Dispatchers.Unconfined + CoroutineName("cio-to-nio-writer"), channel) { + try { + val timeout = if (socketOptions?.socketTimeout != null) { + createTimeout("writing", socketOptions.socketTimeout) { + channel.close(SocketTimeoutException()) + } + } else { + null + } + + while (true) { + buffer.clear() + if (channel.readAvailable(buffer) == -1) { + break + } + buffer.flip() + + while (buffer.hasRemaining()) { + var rc: Int + + timeout.withTimeout { + do { + rc = nioChannel.write(buffer) + if (rc == 0) { + selectable.interestOp(SelectInterest.WRITE, true) + selector.select(selectable, SelectInterest.WRITE) + } + } while (buffer.hasRemaining() && rc == 0) + } + + selectable.interestOp(SelectInterest.WRITE, false) + } + } + timeout?.finish() + } finally { + pool.recycle(buffer) + if (nioChannel is SocketChannel) { + try { + if (java7NetworkApisAvailable) { + nioChannel.shutdownOutput() + } else { + nioChannel.socket().shutdownOutput() + } + } catch (ignore: ClosedChannelException) { + } + } + } + } +} + +@Suppress("DEPRECATION") +internal fun CoroutineScope.attachForWritingDirectImplE( + channel: ByteChannel, + connection: Connection, + socketOptions: SocketOptions.TCPClientSocketOptions? = null +): ReaderJob = reader(Dispatchers.Unconfined + CoroutineName("cio-to-nio-writer"), channel) { + try { + @Suppress("DEPRECATION") + channel.lookAheadSuspend { + val timeout = if (socketOptions?.socketTimeout != null) { + createTimeout("writing-direct", socketOptions.socketTimeout) { + channel.close(SocketTimeoutException()) + } + } else { + null + } + + while (true) { + val buffer = request(0, 1) + if (buffer == null) { + if (!awaitAtLeast(1)) break + continue + } + + while (buffer.hasRemaining()) { + var rc = 0 + + timeout.withTimeout { + do { + rc = connection.writeFromE(buffer) + } while (buffer.hasRemaining() && rc == 0) + } + + consumed(rc) + } + } + timeout?.finish() + } + } finally { + try { + if (java7NetworkApisAvailable) { + connection.channel.shutdownOutput() + } else { + connection.channel.socket().shutdownOutput() + } + } catch (ignore: ClosedChannelException) { + } + } +} + +private suspend fun Connection.writeFromE(receivedRequest: ByteBuffer): Int { + return performWrite { + it.write(receivedRequest) + } +} diff --git a/ktor-network/jvm/src/io/ktor/network/selector/eventgroup/Task.kt b/ktor-network/jvm/src/io/ktor/network/selector/eventgroup/Task.kt new file mode 100644 index 00000000000..bdbfef3702b --- /dev/null +++ b/ktor-network/jvm/src/io/ktor/network/selector/eventgroup/Task.kt @@ -0,0 +1,62 @@ +/* + * Copyright 2014-2024 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package io.ktor.network.selector.eventgroup + +import kotlinx.coroutines.* +import kotlinx.coroutines.CancellationException +import kotlin.coroutines.cancellation.* +import kotlin.coroutines.resume + +/** + * A task for the event loop + * + * It contains a runnable, which perform an i/o operation, + * and a continuation, that will be resumed with the result + */ +internal data class Task( + val continuation: ResumableCancellable, + val runnable: suspend () -> T, +) { + suspend fun runAndResume() { + try { + val result = runnable.invoke() + continuation.resume(result) + } catch (e: Throwable) { + continuation.cancel(e) + } + } +} + +internal fun CancellableContinuation.toResumableCancellable(): ResumableCancellable { + return object : ResumableCancellable { + override fun resume(value: T) { + this@toResumableCancellable.resume(value) + } + + override fun cancel(cause: Throwable?) { + this@toResumableCancellable.cancel(cause) + } + } +} + +internal fun CompletableDeferred.toResumableCancellable(): ResumableCancellable { + return object : ResumableCancellable { + override fun resume(value: T) { + this@toResumableCancellable.complete(value) + } + + override fun cancel(cause: Throwable?) { + val realCause = if (cause is CancellationException) cause else CancellationException(cause) + this@toResumableCancellable.cancel(realCause) + } + } +} + + +internal interface ResumableCancellable { + fun resume(value: T) + + fun cancel(cause: Throwable?) +} diff --git a/ktor-network/jvm/src/io/ktor/network/selector/eventgroup/utils.kt b/ktor-network/jvm/src/io/ktor/network/selector/eventgroup/utils.kt new file mode 100644 index 00000000000..38cef22c179 --- /dev/null +++ b/ktor-network/jvm/src/io/ktor/network/selector/eventgroup/utils.kt @@ -0,0 +1,20 @@ +/* + * Copyright 2014-2024 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package io.ktor.network.selector.eventgroup + +import kotlinx.coroutines.* +import java.util.concurrent.* +import kotlin.coroutines.* + +internal fun newThreadContext(nThreads: Int): CoroutineContext { + val pool = when (nThreads) { + 1 -> Executors.newSingleThreadExecutor() + else -> Executors.newFixedThreadPool(nThreads) + } + + return pool.asCoroutineDispatcher() +} + +internal fun CoroutineContext.wrapInScope() = CoroutineScope(this) diff --git a/ktor-network/jvm/src/io/ktor/network/sockets/ConnectUtilsJvm.kt b/ktor-network/jvm/src/io/ktor/network/sockets/ConnectUtilsJvm.kt index 95827ae803f..9506716084c 100644 --- a/ktor-network/jvm/src/io/ktor/network/sockets/ConnectUtilsJvm.kt +++ b/ktor-network/jvm/src/io/ktor/network/sockets/ConnectUtilsJvm.kt @@ -5,6 +5,7 @@ package io.ktor.network.sockets import io.ktor.network.selector.* +import io.ktor.network.selector.eventgroup.EventGroupSelectorManager import java.net.* import java.nio.channels.* import java.nio.channels.spi.* @@ -27,10 +28,14 @@ internal actual fun bind( localAddress: SocketAddress?, socketOptions: SocketOptions.AcceptorOptions ): ServerSocket = selector.buildOrClose({ openServerSocketChannelFor(localAddress) }) { + require(selector is EventGroupSelectorManager) + if (localAddress is InetSocketAddress) assignOptions(socketOptions) nonBlocking() - ServerSocketImpl(this, selector).apply { + val registered = selector.group.registerChannel(this) + + ServerSocketImpl(registered, selector).apply { if (java7NetworkApisAvailable) { channel.bind(localAddress?.toJavaAddress(), socketOptions.backlogSize) } else { diff --git a/ktor-network/jvm/src/io/ktor/network/sockets/ServerSocketImpl.kt b/ktor-network/jvm/src/io/ktor/network/sockets/ServerSocketImpl.kt index b2cfa22987b..80af5bca740 100644 --- a/ktor-network/jvm/src/io/ktor/network/sockets/ServerSocketImpl.kt +++ b/ktor-network/jvm/src/io/ktor/network/sockets/ServerSocketImpl.kt @@ -5,15 +5,19 @@ package io.ktor.network.sockets import io.ktor.network.selector.* +import io.ktor.network.selector.eventgroup.* +import io.ktor.network.selector.eventgroup.Connection import kotlinx.coroutines.* import java.net.* import java.nio.channels.* @Suppress("BlockingMethodInNonBlockingContext") internal class ServerSocketImpl( - override val channel: ServerSocketChannel, - val selector: SelectorManager -) : ServerSocket, Selectable by SelectableBase(channel) { + private val registeredServerChannel: RegisteredServerChannel, + private val selector: SelectorManager, +) : ServerSocket, Selectable by SelectableBase(registeredServerChannel.channel) { + override val channel: ServerSocketChannel get() = registeredServerChannel.channel + init { require(!channel.isBlocking) { "Channel need to be configured as non-blocking." } } @@ -31,29 +35,19 @@ internal class ServerSocketImpl( } override suspend fun accept(): Socket { - channel.accept()?.let { return accepted(it) } - return acceptSuspend() - } - - private suspend fun acceptSuspend(): Socket { - while (true) { - interestOp(SelectInterest.ACCEPT, true) - selector.select(this, SelectInterest.ACCEPT) - channel.accept()?.let { return accepted(it) } - } + return registeredServerChannel.acceptConnection { nioChannel -> + if (localAddress is InetSocketAddress) { + if (java7NetworkApisAvailable) { + nioChannel.setOption(StandardSocketOptions.TCP_NODELAY, true) + } else { + nioChannel.socket().tcpNoDelay = true + } + } + }.toSocket() } - private fun accepted(nioChannel: SocketChannel): Socket { - interestOp(SelectInterest.ACCEPT, false) - nioChannel.configureBlocking(false) - if (localAddress is InetSocketAddress) { - if (java7NetworkApisAvailable) { - nioChannel.setOption(StandardSocketOptions.TCP_NODELAY, true) - } else { - nioChannel.socket().tcpNoDelay = true - } - } - return SocketImpl(nioChannel, selector) + private fun Connection.toSocket(): Socket { + return ServerConnectionBasedSocket(this, selector) } override fun close() {