From 81b96d693db9434dcffe07bfa08aac566721532d Mon Sep 17 00:00:00 2001 From: msosnicki Date: Fri, 8 Nov 2024 12:58:46 +0100 Subject: [PATCH] Fix incorrect behavior in lenient tagged union decoders --- .../json/internals/SchemaVisitorJCodec.scala | 137 +++++++----------- .../json/SchemaVisitorJCodecTests.scala | 38 +++++ 2 files changed, 88 insertions(+), 87 deletions(-) diff --git a/modules/json/src/smithy4s/json/internals/SchemaVisitorJCodec.scala b/modules/json/src/smithy4s/json/internals/SchemaVisitorJCodec.scala index 354b72e81..709dda972 100644 --- a/modules/json/src/smithy4s/json/internals/SchemaVisitorJCodec.scala +++ b/modules/json/src/smithy4s/json/internals/SchemaVisitorJCodec.scala @@ -979,27 +979,7 @@ private[smithy4s] class SchemaVisitorJCodec( private def taggedUnion[U]( alternatives: Vector[Alt[U, _]] )(dispatch: Alt.Dispatcher[U]): JCodec[U] = - new JCodec[U] { - val expecting: String = "tagged-union" - - override def canBeKey: Boolean = false - - def jsonLabel[A](alt: Alt[U, A]): String = - alt.hints.get(JsonName) match { - case None => alt.label - case Some(x) => x.value - } - - private[this] val handlerMap = - new util.HashMap[String, (Cursor, JsonReader) => U] { - def handler[A](alt: Alt[U, A]) = { - val codec = apply(alt.schema) - (cursor: Cursor, reader: JsonReader) => - alt.inject(cursor.decode(codec, reader)) - } - - alternatives.foreach(alt => put(jsonLabel(alt), handler(alt))) - } + new TaggedUnionJCodec[U](alternatives)(dispatch) { def decodeValue(cursor: Cursor, in: JsonReader): U = if (in.isNextToken('{')) { @@ -1020,59 +1000,65 @@ private[smithy4s] class SchemaVisitorJCodec( } } } else in.decodeError("Expected JSON object") + } - val precompiler = new smithy4s.schema.Alt.Precompiler[Writer] { - def apply[A](label: String, instance: Schema[A]): Writer[A] = { - val jsonLabel = - instance.hints.get(JsonName).map(_.value).getOrElse(label) - val jcodecA = instance.compile(self) - a => - out => { - out.writeObjectStart() - out.writeKey(jsonLabel) - jcodecA.encodeValue(a, out) - out.writeObjectEnd() - } - } - } - val writer = dispatch.compile(precompiler) + private abstract class TaggedUnionJCodec[U](alternatives: Vector[Alt[U, _]])( + dispatch: Alt.Dispatcher[U] + ) extends JCodec[U] { - def encodeValue(u: U, out: JsonWriter): Unit = { - writer(u)(out) + val expecting = "tagged-union" + + override def canBeKey: Boolean = false + + def jsonLabel[A](alt: Alt[U, A]): String = + alt.hints.get(JsonName) match { + case None => alt.label + case Some(x) => x.value } - def decodeKey(in: JsonReader): U = - in.decodeError("Cannot use coproducts as keys") + protected val handlerMap = + new util.HashMap[String, (Cursor, JsonReader) => U] { + def handler[A](alt: Alt[U, A]) = { + val codec = apply(alt.schema) + (cursor: Cursor, reader: JsonReader) => + alt.inject(cursor.decode(codec, reader)) + } - def encodeKey(u: U, out: JsonWriter): Unit = - out.encodeError("Cannot use coproducts as keys") - } + alternatives.foreach(alt => put(jsonLabel(alt), handler(alt))) + } - private def lenientTaggedUnion[U]( - alternatives: Vector[Alt[U, _]] - )(dispatch: Alt.Dispatcher[U]): JCodec[U] = - new JCodec[U] { - val expecting: String = "tagged-union" + protected val precompiler = new smithy4s.schema.Alt.Precompiler[Writer] { + def apply[A](label: String, instance: Schema[A]): Writer[A] = { + val jsonLabel = + instance.hints.get(JsonName).map(_.value).getOrElse(label) + val jcodecA = instance.compile(self) + a => + out => { + out.writeObjectStart() + out.writeKey(jsonLabel) + jcodecA.encodeValue(a, out) + out.writeObjectEnd() + } + } + } + protected val writer = dispatch.compile(precompiler) - override def canBeKey: Boolean = false + def encodeValue(u: U, out: JsonWriter): Unit = { + writer(u)(out) + } - def jsonLabel[A](alt: Alt[U, A]): String = - alt.hints.get(JsonName) match { - case None => alt.label - case Some(x) => x.value - } + def decodeKey(in: JsonReader): U = + in.decodeError("Cannot use coproducts as keys") - private[this] val handlerMap = - new util.HashMap[String, (Cursor, JsonReader) => U] { - def handler[A](alt: Alt[U, A]) = { - val codec = apply(alt.schema) - (cursor: Cursor, reader: JsonReader) => - alt.inject(cursor.decode(codec, reader)) - } + def encodeKey(u: U, out: JsonWriter): Unit = + out.encodeError("Cannot use coproducts as keys") - alternatives.foreach(alt => put(jsonLabel(alt), handler(alt))) - } + } + private def lenientTaggedUnion[U]( + alternatives: Vector[Alt[U, _]] + )(dispatch: Alt.Dispatcher[U]): JCodec[U] = + new TaggedUnionJCodec[U](alternatives)(dispatch) { def decodeValue(cursor: Cursor, in: JsonReader): U = { var result: U = null.asInstanceOf[U] if (in.isNextToken('{')) { @@ -1080,6 +1066,7 @@ private[smithy4s] class SchemaVisitorJCodec( in.rollbackToken() while ({ val key = in.readKeyAsString() + cursor.push(key) val handler = handlerMap.get(key) if (handler eq null) in.skip() else if (in.isNextToken('n')) { @@ -1103,31 +1090,7 @@ private[smithy4s] class SchemaVisitorJCodec( } } else in.decodeError("Expected JSON object") } - val precompiler = new smithy4s.schema.Alt.Precompiler[Writer] { - def apply[A](label: String, instance: Schema[A]): Writer[A] = { - val jsonLabel = - instance.hints.get(JsonName).map(_.value).getOrElse(label) - val jcodecA = instance.compile(self) - a => - out => { - out.writeObjectStart() - out.writeKey(jsonLabel) - jcodecA.encodeValue(a, out) - out.writeObjectEnd() - } - } - } - val writer = dispatch.compile(precompiler) - def encodeValue(u: U, out: JsonWriter): Unit = { - writer(u)(out) - } - - def decodeKey(in: JsonReader): U = - in.decodeError("Cannot use coproducts as keys") - - def encodeKey(u: U, out: JsonWriter): Unit = - out.encodeError("Cannot use coproducts as keys") } private def untaggedUnion[U]( diff --git a/modules/json/test/src/smithy4s/json/SchemaVisitorJCodecTests.scala b/modules/json/test/src/smithy4s/json/SchemaVisitorJCodecTests.scala index a32aaabdf..c599c54d4 100644 --- a/modules/json/test/src/smithy4s/json/SchemaVisitorJCodecTests.scala +++ b/modules/json/test/src/smithy4s/json/SchemaVisitorJCodecTests.scala @@ -406,6 +406,44 @@ class SchemaVisitorJCodecTests() extends FunSuite { expect.same(readFromString[Either[Int, String]](json2), Left(1)) } + test("Lenient and regular unions have the same error messages") { + val json = """|{ + | "left" : {"foo": "b"} + |} + |""".stripMargin + + val schema = Schema.either( + Schema + .struct[String]( + Schema.string + .required[String]("bar", identity) + )(identity), + Schema + .struct[String]( + Schema.string + .required[String]("baz", identity) + )(identity) + ) + + val regularCodec = + JsoniterCodecCompilerImpl.defaultJsoniterCodecCompiler.fromSchema(schema) + val lenientCodec = + JsoniterCodecCompilerImpl.defaultJsoniterCodecCompiler.withLenientTaggedUnionDecoding + .fromSchema(schema) + + def decodeCheck(codec: JsonCodec[Either[String, String]]) = + expect.same( + Try( + readFromString[Either[String, String]](json)(codec) + ).toEither.left.map(_.getMessage), + Left("Missing required field (path: .left.bar)") + ) + + decodeCheck(regularCodec) + decodeCheck(lenientCodec) + + } + test("Untagged union are encoded / decoded") { val oneJ = """ {"three":"three_value"}""" val twoJ = """ {"four":4}"""