Skip to content

Commit

Permalink
Fix incorrect behavior in lenient tagged union decoders
Browse files Browse the repository at this point in the history
  • Loading branch information
msosnicki committed Nov 8, 2024
1 parent 31e6188 commit 81b96d6
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 87 deletions.
137 changes: 50 additions & 87 deletions modules/json/src/smithy4s/json/internals/SchemaVisitorJCodec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -979,27 +979,7 @@ private[smithy4s] class SchemaVisitorJCodec(
private def taggedUnion[U](
alternatives: Vector[Alt[U, _]]
)(dispatch: Alt.Dispatcher[U]): JCodec[U] =
new JCodec[U] {
val expecting: String = "tagged-union"

override def canBeKey: Boolean = false

def jsonLabel[A](alt: Alt[U, A]): String =
alt.hints.get(JsonName) match {
case None => alt.label
case Some(x) => x.value
}

private[this] val handlerMap =
new util.HashMap[String, (Cursor, JsonReader) => U] {
def handler[A](alt: Alt[U, A]) = {
val codec = apply(alt.schema)
(cursor: Cursor, reader: JsonReader) =>
alt.inject(cursor.decode(codec, reader))
}

alternatives.foreach(alt => put(jsonLabel(alt), handler(alt)))
}
new TaggedUnionJCodec[U](alternatives)(dispatch) {

def decodeValue(cursor: Cursor, in: JsonReader): U =
if (in.isNextToken('{')) {
Expand All @@ -1020,66 +1000,73 @@ private[smithy4s] class SchemaVisitorJCodec(
}
}
} else in.decodeError("Expected JSON object")
}

val precompiler = new smithy4s.schema.Alt.Precompiler[Writer] {
def apply[A](label: String, instance: Schema[A]): Writer[A] = {
val jsonLabel =
instance.hints.get(JsonName).map(_.value).getOrElse(label)
val jcodecA = instance.compile(self)
a =>
out => {
out.writeObjectStart()
out.writeKey(jsonLabel)
jcodecA.encodeValue(a, out)
out.writeObjectEnd()
}
}
}
val writer = dispatch.compile(precompiler)
private abstract class TaggedUnionJCodec[U](alternatives: Vector[Alt[U, _]])(
dispatch: Alt.Dispatcher[U]
) extends JCodec[U] {

def encodeValue(u: U, out: JsonWriter): Unit = {
writer(u)(out)
val expecting = "tagged-union"

override def canBeKey: Boolean = false

def jsonLabel[A](alt: Alt[U, A]): String =
alt.hints.get(JsonName) match {
case None => alt.label
case Some(x) => x.value
}

def decodeKey(in: JsonReader): U =
in.decodeError("Cannot use coproducts as keys")
protected val handlerMap =
new util.HashMap[String, (Cursor, JsonReader) => U] {
def handler[A](alt: Alt[U, A]) = {
val codec = apply(alt.schema)
(cursor: Cursor, reader: JsonReader) =>
alt.inject(cursor.decode(codec, reader))
}

def encodeKey(u: U, out: JsonWriter): Unit =
out.encodeError("Cannot use coproducts as keys")
}
alternatives.foreach(alt => put(jsonLabel(alt), handler(alt)))
}

private def lenientTaggedUnion[U](
alternatives: Vector[Alt[U, _]]
)(dispatch: Alt.Dispatcher[U]): JCodec[U] =
new JCodec[U] {
val expecting: String = "tagged-union"
protected val precompiler = new smithy4s.schema.Alt.Precompiler[Writer] {
def apply[A](label: String, instance: Schema[A]): Writer[A] = {
val jsonLabel =
instance.hints.get(JsonName).map(_.value).getOrElse(label)
val jcodecA = instance.compile(self)
a =>
out => {
out.writeObjectStart()
out.writeKey(jsonLabel)
jcodecA.encodeValue(a, out)
out.writeObjectEnd()
}
}
}
protected val writer = dispatch.compile(precompiler)

override def canBeKey: Boolean = false
def encodeValue(u: U, out: JsonWriter): Unit = {
writer(u)(out)
}

def jsonLabel[A](alt: Alt[U, A]): String =
alt.hints.get(JsonName) match {
case None => alt.label
case Some(x) => x.value
}
def decodeKey(in: JsonReader): U =
in.decodeError("Cannot use coproducts as keys")

private[this] val handlerMap =
new util.HashMap[String, (Cursor, JsonReader) => U] {
def handler[A](alt: Alt[U, A]) = {
val codec = apply(alt.schema)
(cursor: Cursor, reader: JsonReader) =>
alt.inject(cursor.decode(codec, reader))
}
def encodeKey(u: U, out: JsonWriter): Unit =
out.encodeError("Cannot use coproducts as keys")

alternatives.foreach(alt => put(jsonLabel(alt), handler(alt)))
}
}

private def lenientTaggedUnion[U](
alternatives: Vector[Alt[U, _]]
)(dispatch: Alt.Dispatcher[U]): JCodec[U] =
new TaggedUnionJCodec[U](alternatives)(dispatch) {
def decodeValue(cursor: Cursor, in: JsonReader): U = {
var result: U = null.asInstanceOf[U]
if (in.isNextToken('{')) {
if (!in.isNextToken('}')) {
in.rollbackToken()
while ({
val key = in.readKeyAsString()
cursor.push(key)
val handler = handlerMap.get(key)
if (handler eq null) in.skip()
else if (in.isNextToken('n')) {
Expand All @@ -1103,31 +1090,7 @@ private[smithy4s] class SchemaVisitorJCodec(
}
} else in.decodeError("Expected JSON object")
}
val precompiler = new smithy4s.schema.Alt.Precompiler[Writer] {
def apply[A](label: String, instance: Schema[A]): Writer[A] = {
val jsonLabel =
instance.hints.get(JsonName).map(_.value).getOrElse(label)
val jcodecA = instance.compile(self)
a =>
out => {
out.writeObjectStart()
out.writeKey(jsonLabel)
jcodecA.encodeValue(a, out)
out.writeObjectEnd()
}
}
}
val writer = dispatch.compile(precompiler)

def encodeValue(u: U, out: JsonWriter): Unit = {
writer(u)(out)
}

def decodeKey(in: JsonReader): U =
in.decodeError("Cannot use coproducts as keys")

def encodeKey(u: U, out: JsonWriter): Unit =
out.encodeError("Cannot use coproducts as keys")
}

private def untaggedUnion[U](
Expand Down
38 changes: 38 additions & 0 deletions modules/json/test/src/smithy4s/json/SchemaVisitorJCodecTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,44 @@ class SchemaVisitorJCodecTests() extends FunSuite {
expect.same(readFromString[Either[Int, String]](json2), Left(1))
}

test("Lenient and regular unions have the same error messages") {
val json = """|{
| "left" : {"foo": "b"}
|}
|""".stripMargin

val schema = Schema.either(
Schema
.struct[String](
Schema.string
.required[String]("bar", identity)
)(identity),
Schema
.struct[String](
Schema.string
.required[String]("baz", identity)
)(identity)
)

val regularCodec =
JsoniterCodecCompilerImpl.defaultJsoniterCodecCompiler.fromSchema(schema)
val lenientCodec =
JsoniterCodecCompilerImpl.defaultJsoniterCodecCompiler.withLenientTaggedUnionDecoding
.fromSchema(schema)

def decodeCheck(codec: JsonCodec[Either[String, String]]) =
expect.same(
Try(
readFromString[Either[String, String]](json)(codec)
).toEither.left.map(_.getMessage),
Left("Missing required field (path: .left.bar)")
)

decodeCheck(regularCodec)
decodeCheck(lenientCodec)

}

test("Untagged union are encoded / decoded") {
val oneJ = """ {"three":"three_value"}"""
val twoJ = """ {"four":4}"""
Expand Down

0 comments on commit 81b96d6

Please sign in to comment.