From 28d81b76770523ad26de679e64e328511a8c9ac8 Mon Sep 17 00:00:00 2001 From: vanjaftn <92813097+vanjaftn@users.noreply.github.com> Date: Sat, 2 Dec 2023 15:16:22 +0100 Subject: [PATCH] Refactor AggregationResponse (#368) --- .../zio/elasticsearch/HttpExecutorSpec.scala | 6 +- .../response/AggregationResponse.scala | 324 +++++++----------- 2 files changed, 121 insertions(+), 209 deletions(-) diff --git a/modules/integration/src/test/scala/zio/elasticsearch/HttpExecutorSpec.scala b/modules/integration/src/test/scala/zio/elasticsearch/HttpExecutorSpec.scala index 0b5adc254..a99b97cf4 100644 --- a/modules/integration/src/test/scala/zio/elasticsearch/HttpExecutorSpec.scala +++ b/modules/integration/src/test/scala/zio/elasticsearch/HttpExecutorSpec.scala @@ -467,7 +467,7 @@ object HttpExecutorSpec extends IntegrationSpec { Executor.execute(ElasticRequest.createIndex(firstSearchIndex)), Executor.execute(ElasticRequest.deleteIndex(firstSearchIndex)).orDie ), - test("aggregate using terms aggregation") { + test("aggregate using terms aggregation with max aggregation as a sub aggregation") { checkOnce(genDocumentId, genTestDocument, genDocumentId, genTestDocument) { (firstDocumentId, firstDocument, secondDocumentId, secondDocument) => for { @@ -481,7 +481,9 @@ object HttpExecutorSpec extends IntegrationSpec { .refreshTrue ) aggregation = - termsAggregation(name = "aggregationString", field = TestDocument.stringField.keyword) + termsAggregation(name = "aggregationString", field = TestDocument.stringField.keyword).withSubAgg( + maxAggregation("subAggregation", TestDocument.intField) + ) aggsRes <- Executor .execute(ElasticRequest.aggregate(selectors = firstSearchIndex, aggregation = aggregation)) diff --git a/modules/library/src/main/scala/zio/elasticsearch/executor/response/AggregationResponse.scala b/modules/library/src/main/scala/zio/elasticsearch/executor/response/AggregationResponse.scala index 0066add8c..12497006e 100644 --- 
a/modules/library/src/main/scala/zio/elasticsearch/executor/response/AggregationResponse.scala +++ b/modules/library/src/main/scala/zio/elasticsearch/executor/response/AggregationResponse.scala @@ -118,6 +118,112 @@ private[elasticsearch] object AvgAggregationResponse { implicit val decoder: JsonDecoder[AvgAggregationResponse] = DeriveJsonDecoder.gen[AvgAggregationResponse] } +private[elasticsearch] case class BucketDecoder(fields: Chunk[(String, Json)]) extends JsonDecoderOps { + val allFields: Map[String, Any] = fields.flatMap { case (field, data) => + field match { + case "key" => + Some(field -> data.toString.replaceAll("\"", "")) + case "doc_count" => + Some(field -> data.unsafeAs[Int]) + case _ => + val objFields = data.unsafeAs[Obj].fields.toMap + + (field: @unchecked) match { + case str if str.contains("weighted_avg#") => + Some(field -> WeightedAvgAggregationResponse(value = objFields("value").unsafeAs[Double])) + case str if str.contains("avg#") => + Some(field -> AvgAggregationResponse(value = objFields("value").unsafeAs[Double])) + case str if str.contains("cardinality#") => + Some(field -> CardinalityAggregationResponse(value = objFields("value").unsafeAs[Int])) + case str if str.contains("extended_stats#") => + Some( + field -> ExtendedStatsAggregationResponse( + count = objFields("count").unsafeAs[Int], + min = objFields("min").unsafeAs[Double], + max = objFields("max").unsafeAs[Double], + avg = objFields("avg").unsafeAs[Double], + sum = objFields("sum").unsafeAs[Double], + sumOfSquares = objFields("sum_of_squares").unsafeAs[Double], + variance = objFields("variance").unsafeAs[Double], + variancePopulation = objFields("variance_population").unsafeAs[Double], + varianceSampling = objFields("variance_sampling").unsafeAs[Double], + stdDeviation = objFields("std_deviation").unsafeAs[Double], + stdDeviationPopulation = objFields("std_deviation_population").unsafeAs[Double], + stdDeviationSampling = objFields("std_deviation_sampling").unsafeAs[Double], + 
stdDeviationBoundsResponse = objFields("std_deviation_bounds").unsafeAs[StdDeviationBoundsResponse]( + StdDeviationBoundsResponse.decoder + ) + ) + ) + case str if str.contains("filter#") => + Some(field -> data.unsafeAs[FilterAggregationResponse](FilterAggregationResponse.decoder)) + case str if str.contains("max#") => + Some(field -> MaxAggregationResponse(value = objFields("value").unsafeAs[Double])) + case str if str.contains("min#") => + Some(field -> MinAggregationResponse(value = objFields("value").unsafeAs[Double])) + case str if str.contains("missing#") => + Some(field -> MissingAggregationResponse(docCount = objFields("doc_count").unsafeAs[Int])) + case str if str.contains("percentile_ranks#") => + Some( + field -> PercentileRanksAggregationResponse(values = objFields("values").unsafeAs[Map[String, Double]]) + ) + case str if str.contains("percentiles#") => + Some(field -> PercentilesAggregationResponse(values = objFields("values").unsafeAs[Map[String, Double]])) + case str if str.contains("stats#") => + Some( + field -> StatsAggregationResponse( + count = objFields("count").unsafeAs[Int], + min = objFields("min").unsafeAs[Double], + max = objFields("max").unsafeAs[Double], + avg = objFields("avg").unsafeAs[Double], + sum = objFields("sum").unsafeAs[Double] + ) + ) + case str if str.contains("sum#") => + Some(field -> SumAggregationResponse(value = objFields("value").unsafeAs[Double])) + case str if str.contains("terms#") => + Some(field -> data.unsafeAs[TermsAggregationResponse](TermsAggregationResponse.decoder)) + case str if str.contains("value_count#") => + Some(field -> ValueCountAggregationResponse(value = objFields("value").unsafeAs[Int])) + } + } + }.toMap + + val subAggs: Map[String, AggregationResponse] = allFields.collect { + case (field, data) if field != "doc_count" && field != "key" => + (field: @unchecked) match { + case str if str.contains("weighted_avg#") => + (field.split("#")(1), data.asInstanceOf[WeightedAvgAggregationResponse]) +
case str if str.contains("avg#") => + (field.split("#")(1), data.asInstanceOf[AvgAggregationResponse]) + case str if str.contains("cardinality#") => + (field.split("#")(1), data.asInstanceOf[CardinalityAggregationResponse]) + case str if str.contains("extended_stats#") => + (field.split("#")(1), data.asInstanceOf[ExtendedStatsAggregationResponse]) + case str if str.contains("filter#") => + (field.split("#")(1), data.asInstanceOf[FilterAggregationResponse]) + case str if str.contains("max#") => + (field.split("#")(1), data.asInstanceOf[MaxAggregationResponse]) + case str if str.contains("min#") => + (field.split("#")(1), data.asInstanceOf[MinAggregationResponse]) + case str if str.contains("missing#") => + (field.split("#")(1), data.asInstanceOf[MissingAggregationResponse]) + case str if str.contains("percentile_ranks#") => + (field.split("#")(1), data.asInstanceOf[PercentileRanksAggregationResponse]) + case str if str.contains("percentiles#") => + (field.split("#")(1), data.asInstanceOf[PercentilesAggregationResponse]) + case str if str.contains("stats#") => + (field.split("#")(1), data.asInstanceOf[StatsAggregationResponse]) + case str if str.contains("sum#") => + (field.split("#")(1), data.asInstanceOf[SumAggregationResponse]) + case str if str.contains("terms#") => + (field.split("#")(1), data.asInstanceOf[TermsAggregationResponse]) + case str if str.contains("value_count#") => + (field.split("#")(1), data.asInstanceOf[ValueCountAggregationResponse]) + } + } +} + private[elasticsearch] final case class CardinalityAggregationResponse(value: Int) extends AggregationResponse private[elasticsearch] object CardinalityAggregationResponse { @@ -159,110 +265,13 @@ private[elasticsearch] final case class FilterAggregationResponse( subAggregations: Option[Map[String, AggregationResponse]] = None ) extends AggregationResponse -private[elasticsearch] object FilterAggregationResponse extends JsonDecoderOps { +private[elasticsearch] object FilterAggregationResponse { implicit 
val decoder: JsonDecoder[FilterAggregationResponse] = Obj.decoder.mapOrFail { case Obj(fields) => - val allFields = fields.flatMap { case (field, data) => - field match { - case "doc_count" => - Some(field -> data.unsafeAs[Int]) - case _ => - val objFields = data.unsafeAs[Obj].fields.toMap - - (field: @unchecked) match { - case str if str.contains("weighted_avg#") => - Some(field -> WeightedAvgAggregationResponse(value = objFields("value").unsafeAs[Double])) - case str if str.contains("avg#") => - Some(field -> AvgAggregationResponse(value = objFields("value").unsafeAs[Double])) - case str if str.contains("cardinality#") => - Some(field -> CardinalityAggregationResponse(value = objFields("value").unsafeAs[Int])) - case str if str.contains("extended_stats#") => - Some( - field -> ExtendedStatsAggregationResponse( - count = objFields("count").unsafeAs[Int], - min = objFields("min").unsafeAs[Double], - max = objFields("max").unsafeAs[Double], - avg = objFields("avg").unsafeAs[Double], - sum = objFields("sum").unsafeAs[Double], - sumOfSquares = objFields("sum_of_squares").unsafeAs[Double], - variance = objFields("variance").unsafeAs[Double], - variancePopulation = objFields("variance_population").unsafeAs[Double], - varianceSampling = objFields("variance_sampling").unsafeAs[Double], - stdDeviation = objFields("std_deviation").unsafeAs[Double], - stdDeviationPopulation = objFields("std_deviation_population").unsafeAs[Double], - stdDeviationSampling = objFields("std_deviation_sampling").unsafeAs[Double], - stdDeviationBoundsResponse = objFields("std_deviation_sampling").unsafeAs[StdDeviationBoundsResponse]( - StdDeviationBoundsResponse.decoder - ) - ) - ) - case str if str.contains("filter#") => - Some(field -> data.unsafeAs[FilterAggregationResponse](FilterAggregationResponse.decoder)) - case str if str.contains("max#") => - Some(field -> MaxAggregationResponse(value = objFields("value").unsafeAs[Double])) - case str if str.contains("min#") => - Some(field -> 
MinAggregationResponse(value = objFields("value").unsafeAs[Double])) - case str if str.contains("missing#") => - Some(field -> MissingAggregationResponse(docCount = objFields("doc_count").unsafeAs[Int])) - case str if str.contains("percentile_ranks#") => - Some( - field -> PercentileRanksAggregationResponse(values = objFields("values").unsafeAs[Map[String, Double]]) - ) - case str if str.contains("percentiles#") => - Some(field -> PercentilesAggregationResponse(values = objFields("values").unsafeAs[Map[String, Double]])) - case str if str.contains("stats#") => - Some( - field -> StatsAggregationResponse( - count = objFields("count").unsafeAs[Int], - min = objFields("min").unsafeAs[Double], - max = objFields("max").unsafeAs[Double], - avg = objFields("avg").unsafeAs[Double], - sum = objFields("sum").unsafeAs[Double] - ) - ) - case str if str.contains("sum#") => - Some(field -> SumAggregationResponse(value = objFields("value").unsafeAs[Double])) - case str if str.contains("terms#") => - Some(field -> data.unsafeAs[TermsAggregationResponse](TermsAggregationResponse.decoder)) - case str if str.contains("value_count#") => - Some(field -> ValueCountAggregationResponse(value = objFields("value").unsafeAs[Int])) - } - } - }.toMap + val bucketDecoder = BucketDecoder(fields) + val allFields = bucketDecoder.allFields + val docCount = allFields("doc_count").asInstanceOf[Int] + val subAggs = bucketDecoder.subAggs - val docCount = allFields("doc_count").asInstanceOf[Int] - val subAggs = allFields.collect { - case (field, data) if field != "doc_count" => - (field: @unchecked) match { - case str if str.contains("weighted_avg#") => - (field.split("#")(1), data.asInstanceOf[WeightedAvgAggregationResponse]) - case str if str.contains("avg#") => - (field.split("#")(1), data.asInstanceOf[AvgAggregationResponse]) - case str if str.contains("cardinality#") => - (field.split("#")(1), data.asInstanceOf[CardinalityAggregationResponse]) - case str if str.contains("extended_stats#") => - 
(field.split("#")(1), data.asInstanceOf[ExtendedStatsAggregationResponse]) - case str if str.contains("filter#") => - (field.split("#")(1), data.asInstanceOf[FilterAggregationResponse]) - case str if str.contains("max#") => - (field.split("#")(1), data.asInstanceOf[MaxAggregationResponse]) - case str if str.contains("min#") => - (field.split("#")(1), data.asInstanceOf[MinAggregationResponse]) - case str if str.contains("missing#") => - (field.split("#")(1), data.asInstanceOf[MissingAggregationResponse]) - case str if str.contains("percentile_ranks#") => - (field.split("#")(1), data.asInstanceOf[PercentileRanksAggregationResponse]) - case str if str.contains("percentiles#") => - (field.split("#")(1), data.asInstanceOf[PercentilesAggregationResponse]) - case str if str.contains("stats#") => - (field.split("#")(1), data.asInstanceOf[StatsAggregationResponse]) - case str if str.contains("sum#") => - (field.split("#")(1), data.asInstanceOf[SumAggregationResponse]) - case str if str.contains("terms#") => - (field.split("#")(1), data.asInstanceOf[TermsAggregationResponse]) - case str if str.contains("value_count#") => - (field.split("#")(1), data.asInstanceOf[ValueCountAggregationResponse]) - } - } Right(FilterAggregationResponse.apply(docCount, Option(subAggs).filter(_.nonEmpty))) } } @@ -366,113 +375,14 @@ private[elasticsearch] final case class TermsAggregationBucket( subAggregations: Option[Map[String, AggregationResponse]] = None ) extends AggregationBucket -private[elasticsearch] object TermsAggregationBucket extends JsonDecoderOps { +private[elasticsearch] object TermsAggregationBucket { implicit val decoder: JsonDecoder[TermsAggregationBucket] = Obj.decoder.mapOrFail { case Obj(fields) => - val allFields = fields.flatMap { case (field, data) => - field match { - case "key" => - Some(field -> data.toString.replaceAll("\"", "")) - case "doc_count" => - Some(field -> data.unsafeAs[Int]) - case _ => - val objFields = data.unsafeAs[Obj].fields.toMap - - (field: 
@unchecked) match { - case str if str.contains("weighted_avg#") => - Some(field -> WeightedAvgAggregationResponse(value = objFields("value").unsafeAs[Double])) - case str if str.contains("avg#") => - Some(field -> AvgAggregationResponse(value = objFields("value").unsafeAs[Double])) - case str if str.contains("cardinality#") => - Some(field -> CardinalityAggregationResponse(value = objFields("value").unsafeAs[Int])) - case str if str.contains("extended_stats#") => - Some( - field -> ExtendedStatsAggregationResponse( - count = objFields("count").unsafeAs[Int], - min = objFields("min").unsafeAs[Double], - max = objFields("max").unsafeAs[Double], - avg = objFields("avg").unsafeAs[Double], - sum = objFields("sum").unsafeAs[Double], - sumOfSquares = objFields("sum_of_squares").unsafeAs[Double], - variance = objFields("variance").unsafeAs[Double], - variancePopulation = objFields("variance_population").unsafeAs[Double], - varianceSampling = objFields("variance_sampling").unsafeAs[Double], - stdDeviation = objFields("std_deviation").unsafeAs[Double], - stdDeviationPopulation = objFields("std_deviation_population").unsafeAs[Double], - stdDeviationSampling = objFields("std_deviation_sampling").unsafeAs[Double], - stdDeviationBoundsResponse = objFields("std_deviation_sampling").unsafeAs[StdDeviationBoundsResponse]( - StdDeviationBoundsResponse.decoder - ) - ) - ) - case str if str.contains("filter#") => - Some(field -> data.unsafeAs[FilterAggregationResponse](FilterAggregationResponse.decoder)) - case str if str.contains("max#") => - Some(field -> MaxAggregationResponse(value = objFields("value").unsafeAs[Double])) - case str if str.contains("min#") => - Some(field -> MinAggregationResponse(value = objFields("value").unsafeAs[Double])) - case str if str.contains("missing#") => - Some(field -> MissingAggregationResponse(docCount = objFields("doc_count").unsafeAs[Int])) - case str if str.contains("percentile_ranks#") => - Some( - field -> 
PercentileRanksAggregationResponse(values = objFields("values").unsafeAs[Map[String, Double]]) - ) - case str if str.contains("percentiles#") => - Some(field -> PercentilesAggregationResponse(values = objFields("values").unsafeAs[Map[String, Double]])) - case str if str.contains("stats#") => - Some( - field -> StatsAggregationResponse( - count = objFields("count").unsafeAs[Int], - min = objFields("min").unsafeAs[Double], - max = objFields("max").unsafeAs[Double], - avg = objFields("avg").unsafeAs[Double], - sum = objFields("sum").unsafeAs[Double] - ) - ) - case str if str.contains("sum#") => - Some(field -> SumAggregationResponse(value = objFields("value").unsafeAs[Double])) - case str if str.contains("terms#") => - Some(field -> data.unsafeAs[TermsAggregationResponse](TermsAggregationResponse.decoder)) - case str if str.contains("value_count#") => - Some(field -> ValueCountAggregationResponse(value = objFields("value").unsafeAs[Int])) - } - } - }.toMap + val bucketDecoder = BucketDecoder(fields) + val allFields = bucketDecoder.allFields + val docCount = allFields("doc_count").asInstanceOf[Int] + val key = allFields("key").asInstanceOf[String] + val subAggs = bucketDecoder.subAggs - val key = allFields("key").asInstanceOf[String] - val docCount = allFields("doc_count").asInstanceOf[Int] - val subAggs = allFields.collect { - case (field, data) if field != "key" && field != "doc_count" => - (field: @unchecked) match { - case str if str.contains("weighted_avg#") => - (field.split("#")(1), data.asInstanceOf[WeightedAvgAggregationResponse]) - case str if str.contains("avg#") => - (field.split("#")(1), data.asInstanceOf[AvgAggregationResponse]) - case str if str.contains("cardinality#") => - (field.split("#")(1), data.asInstanceOf[CardinalityAggregationResponse]) - case str if str.contains("extended_stats#") => - (field.split("#")(1), data.asInstanceOf[ExtendedStatsAggregationResponse]) - case str if str.contains("filter#") => - (field.split("#")(1), 
data.asInstanceOf[FilterAggregationResponse]) - case str if str.contains("max#") => - (field.split("#")(1), data.asInstanceOf[MaxAggregationResponse]) - case str if str.contains("min#") => - (field.split("#")(1), data.asInstanceOf[MinAggregationResponse]) - case str if str.contains("missing#") => - (field.split("#")(1), data.asInstanceOf[MissingAggregationResponse]) - case str if str.contains("percentile_ranks#") => - (field.split("#")(1), data.asInstanceOf[PercentileRanksAggregationResponse]) - case str if str.contains("percentiles#") => - (field.split("#")(1), data.asInstanceOf[PercentilesAggregationResponse]) - case str if str.contains("stats#") => - (field.split("#")(1), data.asInstanceOf[StatsAggregationResponse]) - case str if str.contains("sum#") => - (field.split("#")(1), data.asInstanceOf[SumAggregationResponse]) - case str if str.contains("terms#") => - (field.split("#")(1), data.asInstanceOf[TermsAggregationResponse]) - case str if str.contains("value_count#") => - (field.split("#")(1), data.asInstanceOf[ValueCountAggregationResponse]) - } - } Right(TermsAggregationBucket.apply(key, docCount, Option(subAggs).filter(_.nonEmpty))) } }