diff --git a/src/main/scala/com/evolution/resourcepool/ResourcePool.scala b/src/main/scala/com/evolution/resourcepool/ResourcePool.scala index d2908ac..ad9bce8 100644 --- a/src/main/scala/com/evolution/resourcepool/ResourcePool.scala +++ b/src/main/scala/com/evolution/resourcepool/ResourcePool.scala @@ -112,18 +112,19 @@ object ResourcePool { } yield { new ResourcePool[F, A] { def get = { - for { - partition <- ref.modify { a => + ref + .modify { a => val b = a + 1 ( if (b < length) b else 0, a ) } - result <- values - .apply(partition) - .get - } yield result + .flatMap { partition => + values + .apply(partition) + .get + } } } } @@ -151,7 +152,9 @@ object ResourcePool { def now = Clock[F].realTime(TimeUnit.MILLISECONDS).map { _.millis} - final case class Entry(value: A, release: F[Unit], timestamp: FiniteDuration) + final case class Entry(value: A, release: F[Unit], timestamp: FiniteDuration) { + def renew: F[Entry] = now.map { now => copy(timestamp = now) } + } sealed trait State @@ -237,116 +240,115 @@ object ResourcePool { } for { - ref <- Resource - .make { - Ref[F].of { State.empty } - } { ref => - 0.tailRecM { count => - ref - .access - .flatMap { - case (state: State.Allocated, set) => - Deferred - .apply[F, Either[Throwable, Unit]] - .flatMap { released => - - def apply(allocated: Set[Id], releasing: Set[Id], tasks: Tasks)(effect: => F[Unit]) = { - set - .apply { State.Released(allocated = allocated, releasing = releasing, tasks, released) } - .flatMap { - case true => - for { - result <- { - if (allocated.isEmpty && releasing.isEmpty) { - // the pool is empty now, we can safely release it + ref <- Resource.make { + Ref[F].of { State.empty } + } { ref => + 0.tailRecM { count => + ref + .access + .flatMap { + case (state: State.Allocated, set) => + Deferred + .apply[F, Either[Throwable, Unit]] + .flatMap { released => + + def apply(allocated: Set[Id], releasing: Set[Id], tasks: Tasks)(effect: => F[Unit]) = { + set + .apply { State.Released(allocated = allocated, releasing = releasing, tasks, released) } + .flatMap { + case true => + for { + result <- { + if (allocated.isEmpty && releasing.isEmpty) { + // the pool is empty now, we can safely release it + released + .complete(().asRight) + .handleError { _ => () } + } else { + // the pool will be released elsewhere when all resources in `allocated` or + // `releasing` get released + effect.productR { released - .complete(().asRight) - .handleError { _ => () } - } else { - // the pool will be released elsewhere when all resources in `allocated` or - // `releasing` get released - effect.productR { - released - .get - .rethrow - } + .get + .rethrow } } - } yield { - result.asRight[Int] } - case false => - (count + 1) - .asLeft[Unit] - .pure[F] - } - .uncancelable - } + } yield { + result.asRight[Int] + } + case false => + (count + 1) + .asLeft[Unit] + .pure[F] + } + .uncancelable + } - state.stage match { - case stage: State.Allocated.Stage.Free => - // glue `release` functions of all free resources together - val (entries, releasing, release) = stage - .ids - .foldLeft((state.entries, state.releasing, ().pure[F])) { - case ((entries, releasing, release), id) => - entries - .get(id) - .fold { + state.stage match { + case stage: State.Allocated.Stage.Free => + // glue `release` functions of all free resources together + val (entries, releasing, release) = stage + .ids + .foldLeft((state.entries, state.releasing, ().pure[F])) { + case ((entries, releasing, release), id) => + entries + .get(id) + .fold { + (entries, releasing, release) + } { + case Some(entry) => + (entries - id, releasing + id, release.productR(entry.release)) + case None => (entries, releasing, release) - } { - case Some(entry) => - (entries - id, releasing + id, release.productR(entry.release)) - case None => - (entries, releasing, release) - } - } + } + } + + apply( + allocated = entries.keySet, + releasing = releasing, + Queue.empty + ) { + release + } + case stage: State.Allocated.Stage.Busy => + if (discardTasksOnRelease) { apply( - allocated = entries.keySet, - releasing = releasing, + allocated = state.entries.keySet, + releasing = state.releasing, Queue.empty ) { - release + stage + .tasks + .foldMapM { task => + task + .complete(ReleasedError.asLeft) + .void + } } - - case stage: State.Allocated.Stage.Busy => - if (discardTasksOnRelease) { - apply( - allocated = state.entries.keySet, - releasing = state.releasing, - Queue.empty - ) { - stage - .tasks - .foldMapM { task => - task - .complete(ReleasedError.asLeft) - .void - } - } - } else { - apply( - allocated = state.entries.keySet, - releasing = state.releasing, - stage.tasks - ) { - ().pure[F] - } + } else { + apply( + allocated = state.entries.keySet, + releasing = state.releasing, + stage.tasks + ) { + ().pure[F] } - } + } } + } - case (state: State.Released, _) => - state - .released - .get - .rethrow - .map { _.asRight[Int] } - } - } + case (state: State.Released, _) => + state + .released + .get + .rethrow + .map { _.asRight[Int] } + } } - _ <- Concurrent[F].background { + } + _ <- Concurrent[F].background { val interval = expireAfter / 10 for { _ <- Timer[F].sleep(expireAfter) @@ -419,11 +421,149 @@ object ResourcePool { new ResourcePool[F, A] { def get = { + def entryAdd(id: Id, entry: Entry) = { + 0.tailRecM { count => + ref + .access + .flatMap { + case (state: State.Allocated, set) => + set + .apply { state.copy(entries = state.entries.updated(id, entry.some)) } + .map { + case true => ().asRight[Int] + case false => (count + 1).asLeft[Unit] + } + case (_: State.Released, _) => + () + .asRight[Int] + .pure[F] + } + } + } + + def entryRemove(id: Id, error: Throwable) = { + ref + .modify { + case state: State.Allocated => + + val entries = state.entries - id + + def stateOf(stage: State.Allocated.Stage) = { + state.copy( + entries = entries, + stage = stage) + } + + if (entries.isEmpty) { + state.stage match { + case stage: State.Allocated.Stage.Free => + ( + stateOf(stage), + ().pure[F] + ) + case stage: State.Allocated.Stage.Busy => + ( + stateOf(State.Allocated.Stage.free(List.empty)), + stage + .tasks + .foldMapM { task => + task + .complete(error.asLeft) + .void + } + ) + } + } else { + ( + stateOf(state.stage), + ().pure[F] + ) + } + + case state: State.Released => + + val allocated = state.allocated - id + + def stateOf(tasks: Tasks) = { + state.copy( + allocated = allocated, + tasks = tasks) + } + + if (allocated.isEmpty) { + ( + stateOf(Queue.empty), + state + .tasks + .foldMapM { task => + task + .complete(error.asLeft) + .void + } + .productR { + if (state.releasing.isEmpty) { + state + .released + .complete(().asRight) + .handleError { _ => () } + } else { + ().pure[F] + } + } + ) + } else { + ( + stateOf(state.tasks), + ().pure[F] + ) + } + } + .flatten + } + + def entryRelease(id: Id, release: Release) = { + for { + result <- release.attempt + result <- ref + .modify { + case state: State.Allocated => + ( + state.copy(releasing = state.releasing - id), + ().pure[F] + ) + + case state: State.Released => + val releasing = state.releasing - id + ( + state.copy(releasing = releasing), + result match { + case Right(a) => + if (releasing.isEmpty && state.allocated.isEmpty) { + // this was the last resource in a pool, + // we can release the pool itself now + state + .released + .complete(a.asRight) + .handleError { _ => () } + } else { + ().pure[F] + } + case Left(error) => + state + .released + .complete(error.asLeft) + .handleError { _ => () } + } + ) + } + .flatten + } yield result + } + def releaseOf(id: Id, entry: Entry): Release = { for { - timestamp <- now - entry <- entry.copy(timestamp = timestamp).pure[F] - result <- ref + entry <- entry.renew + result <- ref .modify { case state: State.Allocated => @@ -480,6 +620,28 @@ object ResourcePool { } yield result } + def removeTask(task: Task) = { + ref.update { + case state: State.Allocated => + state.stage match { + case _: State.Allocated.Stage.Free => + state + case stage: State.Allocated.Stage.Busy => + state.copy( + stage = stage.copy( + tasks = stage + .tasks + .filter { _ ne task })) + } + + case state: State.Released => + state.copy(tasks = + state + .tasks + .filter { _ ne task }) + } + } + 0 .tailRecM { count => ref @@ -492,8 +654,7 @@ object ResourcePool { .apply(state) .flatMap { case true => - effect - .map { _.asRight[Int] } + effect.map { _.asRight[Int] } case false => (count + 1) .asLeft[X] @@ -511,27 +672,7 @@ object ResourcePool { } { task .get - .onCancel { - ref.update { - case state: State.Allocated => - state.stage match { - case _: State.Allocated.Stage.Free => - state - case stage: State.Allocated.Stage.Busy => - state.copy( - stage = stage.copy( - tasks = stage - .tasks - .filter { _ ne task })) - } - - case state: State.Released => - state.copy(tasks = - state - .tasks - .filter { _ ne task }) - } - } + .onCancel { removeTask(task) } .rethrow .map { case (id, entry) => (entry.value, releaseOf(id, entry)) @@ -554,21 +695,18 @@ object ResourcePool { } { entry => entry.fold { IllegalStateError(s"entry is not defined, id: $id").raiseError[F, Either[Int, F[Result]]] - } { entry0 => - now.flatMap { timestamp => - val entry = entry0.copy(timestamp = timestamp) - apply { - state.copy( - stage = stage.copy(ids), - entries = state.entries.updated( - id, - entry0 - .copy(timestamp = timestamp) - .some)) - } { - (entry0.value, releaseOf(id, entry)).pure[F].pure[F] + } { entry => + entry + .renew + .flatMap { entry => + apply { + state.copy( + stage = stage.copy(ids), + entries = state.entries.updated(id, entry.some)) + } { + (entry.value, releaseOf(id, entry)).pure[F].pure[F] + } } - } } } @@ -580,153 +718,32 @@ object ResourcePool { val id = state.id apply { state.copy( - id = id + 1, + id = id + 1, entries = state.entries.updated(id, none)) } { resource .apply(id.toString) .allocated + .onCancel { entryRemove(id, CancelledError) } .attempt .flatMap { case Right((value, release)) => // resource was allocated for { - timestamp <- now - entry = Entry( - value = value, - release = { - val result = for { - result <- release.attempt - result <- ref - .modify { - case state: State.Allocated => - ( - state.copy(releasing = state.releasing - id), - ().pure[F] - ) - - case state: State.Released => - val releasing = state.releasing - id - ( - state.copy(releasing = releasing), - result match { - case Right(a) => - if (releasing.isEmpty && state.allocated.isEmpty) { - // this was the last resource in a pool, - // we can release the pool itself now - state - .released - .complete(a.asRight) - .handleError { _ => () } - } else { - ().pure[F] - } - case Left(a) => - state - .released - .complete(a.asLeft) - .handleError { _ => () } - } - ) - } - .flatten - .uncancelable - } yield result - result - .start - .void - }, - timestamp = timestamp) - _ <- ref - .access - .flatMap { - case (state: State.Allocated, set) => - set - .apply { state.copy(entries = state.entries.updated(id, entry.some)) } - .map { - case true => ().asRight[Int] - case false => (count + 1).asLeft[Unit] - } - case (_: State.Released, _) => - () - .asRight[Int] - .pure[F] - } + now <- now + entry = Entry( + value = value, + release = entryRelease(id, release) + .start + .void, + timestamp = now) + _ <- entryAdd(id, entry) } yield { (value, releaseOf(id, entry)).pure[F] } case Left(a) => // resource failed to allocate - ref - .modify { - case state: State.Allocated => - - val entries = state.entries - id - - def stateOf(stage: State.Allocated.Stage) = { - state.copy( - entries = entries, - stage = stage) - } - - if (entries.isEmpty) { - state.stage match { - case stage: State.Allocated.Stage.Free => - ( - stateOf(stage), - ().pure[F] - ) - case stage: State.Allocated.Stage.Busy => - ( - stateOf(State.Allocated.Stage.free(List.empty)), - stage - .tasks - .foldMapM { _.complete(a.asLeft) } - ) - } - } else { - ( - stateOf(stage), - ().pure[F] - ) - } - - case state: State.Released => - - val allocated = state.allocated - id - - def stateOf(tasks: Tasks) = { - state.copy( - allocated = allocated, - tasks = tasks) - } - - if (allocated.isEmpty) { - ( - stateOf(Queue.empty), - state - .tasks - .foldMapM { _.complete(a.asLeft) } - .productR { - if (state.releasing.isEmpty) { - state - .released - .complete(().asRight) - .handleError { _ => () } - } else { - ().pure[F] - } - } - ) - } else { - ( - stateOf(state.tasks), - ().pure[F] - ) - } - } - .flatten - .productR { a.raiseError[F, F[Result]] } + entryRemove(id, a).productR { a.raiseError[F, F[Result]] } } } } else { @@ -741,7 +758,7 @@ object ResourcePool { case (_: State.Released, _) => ReleasedError.raiseError[F, Either[Int, F[Result]]] - } + } } .flatten } @@ -762,6 +779,8 @@ object ResourcePool { final case object ReleasedError extends RuntimeException("released") with NoStackTrace + final case object CancelledError extends RuntimeException("cancelled") with NoStackTrace + final case class IllegalStateError(msg: String) extends RuntimeException(msg) with NoStackTrace diff --git a/src/test/scala/com/evolution/resourcepool/ResourcePoolTest.scala b/src/test/scala/com/evolution/resourcepool/ResourcePoolTest.scala index e7acde7..32cfb47 100644 --- a/src/test/scala/com/evolution/resourcepool/ResourcePoolTest.scala +++ b/src/test/scala/com/evolution/resourcepool/ResourcePoolTest.scala @@ -10,6 +10,7 @@ import org.scalatest.funsuite.AsyncFunSuite import org.scalatest.matchers.should.Matchers import scala.concurrent.TimeoutException + import scala.concurrent.duration._ import scala.util.control.NoStackTrace @@ -337,13 +338,19 @@ class ResourcePoolTest extends AsyncFunSuite with Matchers { ) } yield { for { - fiber0 <- pool.resource.use { _.pure[IO] }.start + fiber0 <- pool + .resource + .use { _.pure[IO] } + .start result <- fiber0 .join .timeout(10.millis) .attempt _ <- IO { result should matchPattern { case Left(_: TimeoutException) => } } - fiber1 <- pool.resource.use { _.pure[IO] }.start + fiber1 <- pool + .resource + .use { _.pure[IO] } + .start result <- fiber1 .join .timeout(10.millis) @@ -361,37 +368,68 @@ class ResourcePoolTest extends AsyncFunSuite with Matchers { .run() } - test("cancel `get` while allocating resource") { + ignore("cancel `get` while allocating resource") { + IO + .never + .void + .toResource + .toResourcePool( + maxSize = 1, + expireAfter = 1.day, + ) + .use { pool => + for { + fiber0 <- pool + .resource + .use { _.pure[IO] } + .start + result <- fiber0 + .join + .timeout(10.millis) + .attempt + _ <- IO { result should matchPattern { case Left(_: TimeoutException) => } } + fiber1 <- pool + .resource + .use { _.pure[IO] } + .start + _ <- fiber0.cancel + result <- fiber1 + .join + .attempt + _ <- IO { result shouldEqual ResourcePool.CancelledError.asLeft } + } yield {} + } + .run() + } + + ignore("cancel `get` while allocating resource, maxSize = 2") { val result = for { - deferred0 <- Deferred[IO, Unit].toResource - deferred1 <- Deferred[IO, Unit].toResource - pool <- deferred0 - .complete(()) - .productR { deferred1.get } + ref <- Ref[IO] + .of(IO.never.void.some) + .toResource + pool <- ref + .getAndSet(none) + .flatMap { _.foldA } .toResource .toResourcePool( - maxSize = 1, - expireAfter = 1.day) + maxSize = 2, + expireAfter = 1.day, + ) } yield { for { fiber0 <- pool .resource - .use { _ => IO.never } + .use { _.pure[IO] } .start result <- fiber0 .join .timeout(10.millis) .attempt _ <- IO { result should matchPattern { case Left(_: TimeoutException) => } } - fiber1 <- pool + _ <- pool .resource .use { _.pure[IO] } - .start - _ <- deferred0.get - fiber2 <- fiber0.cancel.start - _ <- deferred1.complete(()) - _ <- fiber1.join - _ <- fiber2.join + _ <- fiber0.cancel } yield {} } result @@ -435,6 +473,44 @@ class ResourcePoolTest extends AsyncFunSuite with Matchers { .run() } + ignore("cancel `resource` while waiting in queue") { + val result = for { + deferred0 <- Deferred[IO, Unit].toResource + pool <- () + .pure[Resource[IO, *]] + .toResourcePool( + maxSize = 1, + expireAfter = 1.day) + } yield { + for { + fiber0 <- pool + .resource + .use { _ => + deferred0 + .complete(()) + .productR { IO.never } + } + .start + _ <- deferred0.get + fiber1 <- pool + .resource + .use { _ => IO.never } + .start + result <- fiber1 + .join + .timeout(10.millis) + .attempt + _ <- IO { result should matchPattern { case Left(_: TimeoutException) => } } + _ <- fiber1.cancel + _ <- fiber0.cancel + } yield {} + } + result + .use(identity) + .run() + } + + test("cancel `resource.use") { val result = for { deferred0 <- Deferred[IO, Unit]