Skip to content

Commit

Permalink
Improve Scala 3 code generation
Browse files Browse the repository at this point in the history
  • Loading branch information
eed3si9n committed Aug 24, 2024
1 parent e242798 commit 9c71e25
Show file tree
Hide file tree
Showing 18 changed files with 238 additions and 104 deletions.
1 change: 1 addition & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ lazy val scalaxbPlugin = (project in file("sbt-scalaxb"))
case "2.12" => "1.5.8" // set minimum sbt version
}
}
scriptedSbt := sbtVersion.value
scriptedLaunchOpts := { scriptedLaunchOpts.value ++
Seq("-Xmx1024M", "-Dplugin.version=" + version.value)
}
Expand Down
10 changes: 5 additions & 5 deletions cli/src/main/resources/scalaxb.scala.template
Original file line number Diff line number Diff line change
Expand Up @@ -624,7 +624,7 @@ object DataRecord extends XMLStandardTypes {
}

case class ElemName(namespace: Option[String], name: String) {
var node: scala.xml.Node = _
var node: scala.xml.Node = null
def text = node.text
def nil = Helper.isNil(node)
def nilOption: Option[ElemName] = if (nil) None else Some(this)
Expand Down Expand Up @@ -799,11 +799,11 @@ trait ElemNameParser[A] extends AnyElemNameParser with XMLFormat[A] with CanWrit
}

private def parserErrorMsg(msg: String, next: scala.util.parsing.input.Reader[Elem], stack: List[ElemName]): String =
if (msg contains "parser error ") msg
if (msg.contains("parser error ")) msg
else "parser error \"" + msg + "\" while parsing " + stack.reverse.mkString("/", "/", "/") + next.pos.longString

private def parserErrorMsg(msg: String, node: scala.xml.Node): String =
if (msg contains "parser error ") msg
if (msg.contains("parser error ")) msg
else "parser error \"" + msg + "\" while parsing " + node.toString

def parser(node: scala.xml.Node, stack: List[ElemName]): Parser[A]
Expand Down Expand Up @@ -963,11 +963,11 @@ object Helper {
}

def splitQName(value: String, scope: scala.xml.NamespaceBinding): (Option[String], String) =
if (value startsWith "{") {
if (value.startsWith("{")) {
val qname = javax.xml.namespace.QName.valueOf(value)
(nullOrEmpty(qname.getNamespaceURI), qname.getLocalPart)
}
else if (value contains ':') {
else if (value.contains(':')) {
val prefix = value.dropRight(value.length - value.indexOf(':'))
val localPart = value.drop(value.indexOf(':') + 1)
(nullOrEmpty(scope.getURI(prefix)), localPart)
Expand Down
21 changes: 21 additions & 0 deletions cli/src/main/scala/scalaxb/compiler/Config.scala
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,25 @@ case class Config(items: Map[String, ConfigEntry]) {
def enumNameMaxLength: Int = (get[EnumNameMaxLength] getOrElse defaultEnumNameMaxLength).value
def useLists: Boolean = values contains UseLists
def jaxbPackage = get[JaxbPackage] getOrElse defaultJaxbPackage
def targetScalaVersion: String = get[TargetScalaVersion].getOrElse(defaultTargetScalaVersion).value
def targetScalaPartialVersion: Option[(Long, Long)] =
partialVersion(targetScalaVersion)
private val longPattern = """\d{1,19}"""
private val PartialVersion = raw"""($longPattern)\.($longPattern)(?:\..+)?""".r
private def partialVersion(s: String): Option[(Long, Long)] =
s match {
case PartialVersion(major, minor) => Some((major.toLong, minor.toLong))
case _ => None
}
def isScala3Plus: Boolean = targetScalaPartialVersion match {
case Some((major, _)) => major >= 3
case _ => false
}
def isScala3_4Plus: Boolean = targetScalaPartialVersion match {
case Some((3, minor)) => minor >= 4
case Some((major, _)) => major > 3
case _ => false
}

private def get[A <: ConfigEntry: ClassTag]: Option[A] =
items.get(implicitly[ClassTag[A]].runtimeClass.getName).asInstanceOf[Option[A]]
Expand Down Expand Up @@ -118,6 +137,7 @@ object Config {
val defaultSymbolEncodingStrategy = SymbolEncoding.Legacy151
val defaultEnumNameMaxLength = EnumNameMaxLength(50)
val defaultJaxbPackage = JaxbPackage.Javax
val defaultTargetScalaVersion = TargetScalaVersion("2.13.14")

val default = Config(
Vector(defaultPackageNames, defaultOpOutputWrapperPostfix, defaultOutdir,
Expand Down Expand Up @@ -169,6 +189,7 @@ object ConfigEntry {
case class EnumNameMaxLength(value: Int) extends ConfigEntry
case object UseLists extends ConfigEntry
case object GenerateMapK extends ConfigEntry
case class TargetScalaVersion(value: String) extends ConfigEntry

sealed abstract class HttpClientStyle extends ConfigEntry with Product with Serializable {
final override def name: String = classOf[HttpClientStyle].getName
Expand Down
94 changes: 71 additions & 23 deletions cli/src/main/scala/scalaxb/compiler/Module.scala
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ trait Module {
val output = evTo.newInstance(pkg, part + ".scala")
val out = evTo.toWriter(output)
try {
printNodes(snippet.definition, out)
printNodes(snippet.definition, out, config)
} finally {
out.flush()
out.close()
Expand All @@ -336,7 +336,7 @@ trait Module {
}))
val protocolNodes = generateProtocol(Snippet(snippets.toSeq: _*), cs.context, config2)
try {
printNodes(protocolNodes, out)
printNodes(protocolNodes, out, config2)
} finally {
out.flush()
out.close()
Expand All @@ -355,19 +355,65 @@ trait Module {
else Nil)
}

def generateFromResource[To](packageName: Option[String], fileName: String, resourcePath: String, substitution: Option[(String, String)] = None)
(implicit evTo: CanBeWriter[To]) = {
def substituteMany(subs: (String, String)*): String => String =
(s: String) => subs.toSeq.foldLeft(s)((acc, x) => acc.replaceAll(x._1, x._2))

private val scala3VarArgSub = """\:\s_\*""" -> "*"
private val scala3With = """scalaxb\.(\S+)\swith\sscalaxb\.(\S+)""" -> """scalaxb.$1 & scalaxb.$2"""

def printNodes(nodes: Seq[Node], out: PrintWriter, config: Config): Unit = {
val subs = ListBuffer.empty[(String, String)]
if (config.isScala3_4Plus) {
subs += scala3VarArgSub
subs += scala3With
printNodes(nodes, out, substituteMany(subs.toList: _*))
} else {
printNodes(nodes, out)
}
}

def generateFromResource[To](
packageName: Option[String],
fileName: String,
resourcePath: String,
config: Config,
subs0: (String, String)*
)(implicit evTo: CanBeWriter[To]): To = {
val output = implicitly[CanBeWriter[To]].newInstance(packageName, fileName)
val out = implicitly[CanBeWriter[To]].toWriter(output)
try {
printFromResource(resourcePath, out, substitution)
val subs = ListBuffer.empty[(String, String)]
subs ++= subs0
if (config.isScala3_4Plus) {
subs += scala3VarArgSub
}
val transform = substituteMany(subs: _*)
printFromResource(resourcePath, out, transform)
} finally {
out.flush()
out.close()
}
output
}

def generateBaseRuntimeFiles[To](cntxt: Context, config: Config)(implicit evTo: CanBeWriter[To]): List[To] = {
val subs = ListBuffer.empty[(String, String)]
subs += "%%JAXB_PACKAGE%%" -> config.jaxbPackage.packageName
if (config.isScala3Plus) {
subs += """CanWriteXML\[_\]""" -> "CanWriteXML[?]"
}

List(
generateFromResource[To](
Some("scalaxb"),
"scalaxb.scala",
"/scalaxb.scala.template",
config,
subs: _*,
),
)
}

def generateRuntimeFiles[To](context: Context, config: Config)(implicit evTo: CanBeWriter[To]): List[To]

// returns a seq of package name, snippet, and file name part tuple
Expand Down Expand Up @@ -432,17 +478,17 @@ trait Module {
def parse(location: URI, in: Reader): Schema
= parse(toImportable(location, readerToRawSchema(in)), buildContext)

def printNodes(nodes: Seq[Node], out: PrintWriter): Unit = {
import scala.xml._
def printNodes(nodes: Seq[Node], out: PrintWriter, transform: String => String = identity): Unit = {
import scala.xml.{ transform => _, _ }

def printNode(n: Node): Unit = n match {
case Text(s) => out.print(s)
case Text(s) => out.print(transform(s))
case EntityRef("lt") => out.print('<')
case EntityRef("gt") => out.print('>')
case EntityRef("amp") => out.print('&')
case atom: Atom[_] => out.print(atom.text)
case atom: Atom[_] => out.print(transform(atom.text))
case elem: Elem =>
printNodes(elem.child, out)
printNodes(elem.child, out, transform)
if (elem.text != "") {
if (elem.text.contains(newline)) out.println("")
out.println("")
Expand All @@ -454,25 +500,27 @@ trait Module {
for (node <- nodes) { printNode(node) }
}

def printFromResource(source: String, out: PrintWriter, substitution: Option[(String, String)] = None): Unit = {
def printFromResource(source: String, out: PrintWriter, transform: String => String = identity): Unit = {
val in = getClass.getResourceAsStream(source)
val reader = new java.io.BufferedReader(new java.io.InputStreamReader(in))
var line: Option[String] = None
line = Option[String](reader.readLine)
while (line != None) {
(line, substitution) match {
case (Some(l), Some((target, replacement))) => out.println(l.replace(target, replacement))
case (Some(l), None) => out.println(l)
case _ => // do nothing
}
try {
val reader = new java.io.BufferedReader(new java.io.InputStreamReader(in))
var line: Option[String] = None
line = Option[String](reader.readLine)
while (line != None) {
line match {
case Some(l) => out.println(transform(l))
case _ => // do nothing
}
line = Option[String](reader.readLine)
}
} finally {
in.close()
}
in.close
out.flush
}

def copyFileFromResource(source: String, dest: File, substitution: Option[(String, String)] = None) =
printFromResource(source, new java.io.PrintWriter(new java.io.FileWriter(dest)), substitution)
def copyFileFromResource(source: String, dest: File, transform: String => String = identity) =
printFromResource(source, new java.io.PrintWriter(new java.io.FileWriter(dest)), transform)

def mergeSnippets(snippets: Seq[Snippet]) =
Snippet(snippets flatMap {_.definition},
Expand Down
67 changes: 34 additions & 33 deletions cli/src/main/scala/scalaxb/compiler/wsdl11/Driver.scala
Original file line number Diff line number Diff line change
Expand Up @@ -198,11 +198,12 @@ class Driver extends Module { driver =>
}
}

def generateDispatchFromResource[To](style: HttpClientStyle, filePrefix: String)(implicit evTo: CanBeWriter[To]): List[To] = {
def generateDispatchFromResource[To](style: HttpClientStyle, filePrefix: String, config: Config)(implicit evTo: CanBeWriter[To]): List[To] = {
def gen(asyncSuffix: String) = generateFromResource[To](
Some("scalaxb"),
s"httpclients_dispatch${asyncSuffix}.scala",
s"${filePrefix}${asyncSuffix}.scala.template"
s"${filePrefix}${asyncSuffix}.scala.template",
config,
)
style match {
case HttpClientStyle.Sync => gen("") :: Nil
Expand All @@ -212,83 +213,83 @@ class Driver extends Module { driver =>
}

def generateRuntimeFiles[To](cntxt: Context, config: Config)(implicit evTo: CanBeWriter[To]): List[To] =
generateBaseRuntimeFiles[To](cntxt, config) ++
List(
generateFromResource[To](Some("scalaxb"), "scalaxb.scala", "/scalaxb.scala.template", Some("%%JAXB_PACKAGE%%" -> config.jaxbPackage.packageName)),
(config.httpClientStyle match {
case HttpClientStyle.Sync => generateFromResource[To](Some("scalaxb"), "httpclients.scala", "/httpclients.scala.template")
case HttpClientStyle.Future => generateFromResource[To](Some("scalaxb"), "httpclients_async.scala", "/httpclients_async.scala.template")
case HttpClientStyle.Tagless => generateFromResource[To](Some("scalaxb"), "httpclients_tagless_final.scala", "/httpclients_tagless_final.scala.template")
case HttpClientStyle.Sync => generateFromResource[To](Some("scalaxb"), "httpclients.scala", "/httpclients.scala.template", config)
case HttpClientStyle.Future => generateFromResource[To](Some("scalaxb"), "httpclients_async.scala", "/httpclients_async.scala.template", config)
case HttpClientStyle.Tagless => generateFromResource[To](Some("scalaxb"), "httpclients_tagless_final.scala", "/httpclients_tagless_final.scala.template", config)
})) ++
(if (config.generateDispatchAs) List(generateFromResource[To](Some("dispatch.as"), "dispatch_as_scalaxb.scala",
"/dispatch_as_scalaxb.scala.template"))
"/dispatch_as_scalaxb.scala.template", config))
else Nil) ++
(if (config.generateVisitor) List(generateFromResource[To](Some("scalaxb"), "Visitor.scala", "/visitor.scala.template"))
(if (config.generateVisitor) List(generateFromResource[To](Some("scalaxb"), "Visitor.scala", "/visitor.scala.template", config))
else Nil) ++
(if (config.generateDispatchClient) (config.dispatchVersion, config.httpClientStyle) match {
case (VersionPattern(0, minor, _), style) if minor < 10 =>
generateDispatchFromResource(style, "/httpclients_dispatch_classic")
generateDispatchFromResource(style, "/httpclients_dispatch_classic", config)

case (VersionPattern(0, 10 | 11, 0), style) =>
generateDispatchFromResource(style, "/httpclients_dispatch0100")
generateDispatchFromResource(style, "/httpclients_dispatch0100", config)

case (VersionPattern(0, 11, 1 | 2), style) =>
generateDispatchFromResource(style, "/httpclients_dispatch0111")
generateDispatchFromResource(style, "/httpclients_dispatch0111", config)

case (VersionPattern(0, 11, 3 | 4), style) =>
generateDispatchFromResource(style, "/httpclients_dispatch0113")
generateDispatchFromResource(style, "/httpclients_dispatch0113", config)

case (VersionPattern(0, 12, 0 | 1), style) => // 0.12.1 does not have artifact in maven central
// 0.12.[0, 1] is using same template as 0.11.3+
generateDispatchFromResource(style, "/httpclients_dispatch0113")
generateDispatchFromResource(style, "/httpclients_dispatch0113", config)

case (VersionPattern(0, 12, _), style) =>
generateDispatchFromResource(style, "/httpclients_dispatch0122")
generateDispatchFromResource(style, "/httpclients_dispatch0122", config)

case (VersionPattern(0, 13, _), style) =>
generateDispatchFromResource(style, "/httpclients_dispatch0130")
generateDispatchFromResource(style, "/httpclients_dispatch0130", config)

case (VersionPattern(0, 14, _), style) =>
// Same as 0.13.x
generateDispatchFromResource(style, "/httpclients_dispatch0130")
generateDispatchFromResource(style, "/httpclients_dispatch0130", config)
case (VersionPattern(1, 0 | 1, _), style) =>
// Same as 0.13.x
generateDispatchFromResource(style, "/httpclients_dispatch0130")
generateDispatchFromResource(style, "/httpclients_dispatch0130", config)
} else Nil) ++
(if (config.generateGigahorseClient) (config.gigahorseVersion, config.httpClientStyle) match {
case (VersionPattern(x, y, _), HttpClientStyle.Sync) if (x.toInt == 0) && (y.toInt <= 5) =>
generateFromResource[To](Some("scalaxb"), "httpclients_gigahorse.scala",
"/httpclients_gigahorse02.scala.template", Some("%%BACKEND%%" -> config.gigahorseBackend)) :: Nil
"/httpclients_gigahorse02.scala.template", config, "%%BACKEND%%" -> config.gigahorseBackend) :: Nil
case (VersionPattern(x, y, _), HttpClientStyle.Future) if (x.toInt == 0) && (y.toInt <= 5) =>
generateFromResource[To](Some("scalaxb"), "httpclients_gigahorse_async.scala",
"/httpclients_gigahorse02_async.scala.template", Some("%%BACKEND%%" -> config.gigahorseBackend)) :: Nil
case _ => Nil
"/httpclients_gigahorse02_async.scala.template", config, "%%BACKEND%%" -> config.gigahorseBackend) :: Nil
case _ => Nil
} else Nil) ++
(if (config.generateHttp4sClient && config.httpClientStyle == HttpClientStyle.Tagless) config.http4sVersion match {
case VersionPattern(0,21, _) => List(generateFromResource[To](Some("scalaxb"), "httpclients_http4s.scala", "/httpclients_http4s_0_21.scala.template"))
case VersionPattern(0,22, _) => List(generateFromResource[To](Some("scalaxb"), "httpclients_http4s.scala", "/httpclients_http4s_0_22.scala.template"))
case VersionPattern(0,23, _) => List(generateFromResource[To](Some("scalaxb"), "httpclients_http4s.scala", "/httpclients_http4s_0_23.scala.template"))
case VersionPattern(0,21, _) => List(generateFromResource[To](Some("scalaxb"), "httpclients_http4s.scala", "/httpclients_http4s_0_21.scala.template", config))
case VersionPattern(0,22, _) => List(generateFromResource[To](Some("scalaxb"), "httpclients_http4s.scala", "/httpclients_http4s_0_22.scala.template", config))
case VersionPattern(0,23, _) => List(generateFromResource[To](Some("scalaxb"), "httpclients_http4s.scala", "/httpclients_http4s_0_23.scala.template", config))
case _ => sys.error(s"Unsupported http4s version ${config.http4sVersion}"); Nil
}
else Nil) ++
(if (cntxt.soap11) List(
(config.httpClientStyle match {
case HttpClientStyle.Sync => generateFromResource[To](Some("scalaxb"), "soap11.scala", "/soap11.scala.template")
case HttpClientStyle.Future => generateFromResource[To](Some("scalaxb"), "soap11_async.scala", "/soap11_async.scala.template")
case HttpClientStyle.Tagless => generateFromResource[To](Some("scalaxb"), "soap11_tagless.scala", "/soap11_tagless.scala.template")
case HttpClientStyle.Sync => generateFromResource[To](Some("scalaxb"), "soap11.scala", "/soap11.scala.template", config)
case HttpClientStyle.Future => generateFromResource[To](Some("scalaxb"), "soap11_async.scala", "/soap11_async.scala.template", config)
case HttpClientStyle.Tagless => generateFromResource[To](Some("scalaxb"), "soap11_tagless.scala", "/soap11_tagless.scala.template", config)
}),
generateFromResource[To](Some("soapenvelope11"), "soapenvelope11.scala",
"/soapenvelope11.scala.template"),
"/soapenvelope11.scala.template", config),
generateFromResource[To](Some("soapenvelope11"), "soapenvelope11_xmlprotocol.scala",
"/soapenvelope11_xmlprotocol.scala.template"))
"/soapenvelope11_xmlprotocol.scala.template", config))
else Nil) ++
(if (cntxt.soap12) List(
(config.httpClientStyle match {
case HttpClientStyle.Sync => generateFromResource[To](Some("scalaxb"), "soap12.scala", "/soap12.scala.template")
case HttpClientStyle.Future => generateFromResource[To](Some("scalaxb"), "soap12_async.scala", "/soap12_async.scala.template")
case HttpClientStyle.Tagless => generateFromResource[To](Some("scalaxb"), "soap12_tagless.scala", "/soap12_tagless.scala.template")
case HttpClientStyle.Sync => generateFromResource[To](Some("scalaxb"), "soap12.scala", "/soap12.scala.template", config)
case HttpClientStyle.Future => generateFromResource[To](Some("scalaxb"), "soap12_async.scala", "/soap12_async.scala.template", config)
case HttpClientStyle.Tagless => generateFromResource[To](Some("scalaxb"), "soap12_tagless.scala", "/soap12_tagless.scala.template", config)
}),
generateFromResource[To](Some("soapenvelope12"), "soapenvelope12.scala", "/soapenvelope12.scala.template"),
generateFromResource[To](Some("soapenvelope12"), "soapenvelope12_xmlprotocol.scala", "/soapenvelope12_xmlprotocol.scala.template"))
generateFromResource[To](Some("soapenvelope12"), "soapenvelope12.scala", "/soapenvelope12.scala.template", config),
generateFromResource[To](Some("soapenvelope12"), "soapenvelope12_xmlprotocol.scala", "/soapenvelope12_xmlprotocol.scala.template", config))
else Nil)
}

Expand Down
Loading

0 comments on commit 9c71e25

Please sign in to comment.