diff --git a/airframe-http/.jvm/src/test/scala/wvlet/airframe/http/router/HttpRequestMapperTest.scala b/airframe-http/.jvm/src/test/scala/wvlet/airframe/http/router/HttpRequestMapperTest.scala index 32ffa5a74a..06e0893118 100644 --- a/airframe-http/.jvm/src/test/scala/wvlet/airframe/http/router/HttpRequestMapperTest.scala +++ b/airframe-http/.jvm/src/test/scala/wvlet/airframe/http/router/HttpRequestMapperTest.scala @@ -41,7 +41,7 @@ object HttpRequestMapperTest extends AirSpec { def rpc5(p1: Option[String]): Unit = {} def rpc6(p1: Option[NestedRequest]): Unit = {} def rpc7( - request: HttpMessage.Request, + request: Request, context: HttpContext[Request, Response, Future], req: HttpRequest[Request] ): Unit = {} @@ -64,8 +64,10 @@ object HttpRequestMapperTest extends AirSpec { def endpoint4(p1: Option[Seq[String]]): Unit = {} } - private val api = new MyApi {} - private val router = Router.add[MyApi].add[MyApi2] + private val api = new MyApi {} + private val router = Router + .add[MyApi] + .add[MyApi2] private val mockContext = HttpContext.mockContext private def mapArgs( diff --git a/airframe-surface/src/main/scala-3/wvlet/airframe/surface/CompileTimeSurfaceFactory.scala b/airframe-surface/src/main/scala-3/wvlet/airframe/surface/CompileTimeSurfaceFactory.scala index c399040ef3..68e252e6c2 100644 --- a/airframe-surface/src/main/scala-3/wvlet/airframe/surface/CompileTimeSurfaceFactory.scala +++ b/airframe-surface/src/main/scala-3/wvlet/airframe/surface/CompileTimeSurfaceFactory.scala @@ -1,5 +1,7 @@ package wvlet.airframe.surface -import scala.quoted._ +import java.util.concurrent.atomic.AtomicInteger +import scala.collection.immutable.ListMap +import scala.quoted.* private[surface] object CompileTimeSurfaceFactory { @@ -76,27 +78,36 @@ private[surface] class CompileTimeSurfaceFactory[Q <: Quotes](using quotes: Q) { surfaceOf(TypeRepr.of(using tpe)) } - private val seen = scala.collection.mutable.Set[TypeRepr]() - private val memo = scala.collection.mutable.Map[TypeRepr, Expr[Surface]]() - private val lazySurface = scala.collection.mutable.Set[TypeRepr]() + private var observedSurfaceCount = new AtomicInteger(0) + private var seen = ListMap[TypeRepr, Int]() + private val memo = scala.collection.mutable.Map[TypeRepr, Expr[Surface]]() + private val lazySurface = scala.collection.mutable.Set[TypeRepr]() - private def surfaceOf(t: TypeRepr): Expr[Surface] = { - if (surfaceToVar.contains(t)) { - // println(s"==== ${t} is already cached") - Ref(surfaceToVar(t)).asExprOf[Surface] + private def surfaceOf(t: TypeRepr, useVarRef: Boolean = true): Expr[Surface] = { + def buildLazySurface: Expr[Surface] = { + '{ LazySurface(${ clsOf(t) }, ${ Expr(fullTypeNameOf(t)) }) } + } + + if (useVarRef && surfaceToVar.contains(t)) { + if (lazySurface.contains(t)) { + buildLazySurface + } else { + Ref(surfaceToVar(t)).asExprOf[Surface] + } } else if (seen.contains(t)) { if (memo.contains(t)) { memo(t) } else { lazySurface += t - '{ LazySurface(${ clsOf(t) }, ${ Expr(fullTypeNameOf(t)) }) } + buildLazySurface } } else { - seen += t + seen += t -> observedSurfaceCount.getAndIncrement() // For debugging // println(s"[${typeNameOf(t)}]\n ${t}\nfull type name: ${fullTypeNameOf(t)}\nclass: ${t.getClass}") val generator = factory.andThen { expr => if (!lazySurface.contains(t)) { + // Generate the surface code without using the cache expr } else { // Need to cache the recursive Surface to be referenced in a LazySurface @@ -115,16 +126,12 @@ private[surface] class CompileTimeSurfaceFactory[Q <: Quotes](using quotes: Q) { } else { fullTypeNameOf(t) } + val key = Literal(StringConstant(cacheKey)).asExprOf[String] '{ - val key = ${ - Expr(cacheKey) + if (!wvlet.airframe.surface.surfaceCache.contains(${ key })) { + wvlet.airframe.surface.surfaceCache += ${ key } -> ${ expr } } - if (!wvlet.airframe.surface.surfaceCache.contains(key)) { - wvlet.airframe.surface.surfaceCache += key -> ${ - expr - } - } - wvlet.airframe.surface.surfaceCache(key) + wvlet.airframe.surface.surfaceCache.apply(${ key }) } } } @@ -386,9 +393,8 @@ private[surface] class CompileTimeSurfaceFactory[Q <: Quotes](using quotes: Q) { // args(i+1) val extracted = Select.unique(args, "apply").appliedTo(Literal(IntConstant(index))) index += 1 - // args(i+1).asInstanceOf[A] - // TODO: Cast primitive values to target types - Select.unique(extracted, "asInstanceOf").appliedToType(a.tpe) + // classOf[A].cast(args(i+1)) + clsCast(extracted, a.tpe) } Apply(prev, argExtractors.toList) } @@ -397,9 +403,7 @@ private[surface] class CompileTimeSurfaceFactory[Q <: Quotes](using quotes: Q) { } ) val expr = '{ - new wvlet.airframe.surface.ObjectFactory { - override def newInstance(args: Seq[Any]): Any = { ${ newClassFn.asExprOf[Seq[Any] => Any] }(args) } - } + ObjectFactory.newFactory(${ newClassFn.asExprOf[Seq[Any] => Any] }) } expr } @@ -442,7 +446,9 @@ private[surface] class CompileTimeSurfaceFactory[Q <: Quotes](using quotes: Q) { ), rhsFn = (sym: Symbol, paramRefs: List[Tree]) => { val strVarRef = paramRefs(1).asExprOf[String].asTerm - Select.unique(Apply(m, List(strVarRef)), "asInstanceOf").appliedToType(TypeRepr.of[Option[Any]]) + val expr = Select.unique(Apply(m, List(strVarRef)), "asInstanceOf").appliedToType(TypeRepr.of[Option[Any]]) + expr.changeOwner(sym) + } ) '{ @@ -539,7 +545,8 @@ private[surface] class CompileTimeSurfaceFactory[Q <: Quotes](using quotes: Q) { case other => other } - a.appliedTo(resolvedTypeArgs) + // Need to use the base type of the applied type to replace the type parameters + a.tycon.appliedTo(resolvedTypeArgs) case TypeRef(_, name) if typeArgTable.contains(name) => typeArgTable(name) case other => @@ -582,7 +589,7 @@ private[surface] class CompileTimeSurfaceFactory[Q <: Quotes](using quotes: Q) { clsOf(arg.tpe.dealias) } val isConstructor = t.typeSymbol.primaryConstructor == method - val constructorRef = '{ + val constructorRef: Expr[MethodRef] = '{ MethodRef( owner = ${ clsOf(t) }, name = ${ Expr(methodName) }, @@ -618,20 +625,40 @@ private[surface] class CompileTimeSurfaceFactory[Q <: Quotes](using quotes: Q) { // println(s"${paramName} ${paramIsAccessible}") val accessor: Expr[Option[Any => Any]] = if (method.isClassConstructor && paramIsAccessible) { + // MethodParameter.accessor[(owner type), (parameter type]] + val accessorMethod: Symbol = TypeRepr.of[MethodParameter.type].typeSymbol.methodMember("accessor").head + val objRef = Ref(TypeRepr.of[MethodParameter].typeSymbol.companionModule) + + def resolveType(tpe: TypeRepr): TypeRepr = tpe match { + case b: TypeBounds => + TypeRepr.of[Any] + case _ => + tpe + } + + val t1 = resolveType(t) + val t2 = resolveType(paramType) + + val typedAccessor = objRef.select(accessorMethod).appliedToTypes(List(t1, t2)) + val methodCall = typedAccessor.appliedToArgs(List(Literal(ClassOfConstant(t1)))) + val lambda = Lambda( owner = Symbol.spliceOwner, - tpe = MethodType(List("x"))(_ => List(TypeRepr.of[Any]), _ => TypeRepr.of[Any]), + tpe = MethodType(List("x"))(_ => List(t1), _ => t2), rhsFn = (sym, params) => { val x = params.head.asInstanceOf[Term] - val expr = Select.unique(Select.unique(x, "asInstanceOf").appliedToType(t), paramName) + val expr = Select.unique(x, paramName) expr.changeOwner(sym) } ) - // println(t.typeSymbol) - // println(paramType.typeSymbol.flags.show) - // println(lambda.show) - // println(lambda.show(using Printer.TreeStructure)) - '{ Some(${ lambda.asExprOf[Any => Any] }) } + val accMethod = methodCall.appliedToArgs(List(lambda)) + // println(s"=== ${accMethod.show}") + + // Generate code like : + // {{{ + // MethodParameter.accessor[t1, t2](classOf[t1]){(x:t1) => x.(field name) } + // }}} + '{ Some(${ accMethod.asExprOf[Any => Any] }) } } else { '{ None } } @@ -645,7 +672,7 @@ private[surface] class CompileTimeSurfaceFactory[Q <: Quotes](using quotes: Q) { tpe = MethodType(List("x"))(_ => List(TypeRepr.of[Any]), _ => TypeRepr.of[Any]), rhsFn = (sym, params) => { val x = params.head.asInstanceOf[Term] - val expr = Select.unique(x, "asInstanceOf").appliedToType(t).select(m) + val expr = clsCast(x, t).select(m) expr.changeOwner(sym) } ) @@ -683,35 +710,53 @@ private[surface] class CompileTimeSurfaceFactory[Q <: Quotes](using quotes: Q) { } // To reduce the byte code size, we need to memoize the generated surface bound to a variable - private val surfaceToVar = scala.collection.mutable.Map[TypeRepr, Symbol]() + private var surfaceToVar = ListMap.empty[TypeRepr, Symbol] private def methodsOf(t: TypeRepr): Expr[Seq[MethodSurface]] = { // Run just for collecting known surfaces. seen variable will be updated methodsOfInternal(t) - var count = 0 - // Bind the observed surfaces to local variables __s0, __s1, ... - seen.foreach { s => - // Update the cache so that the next call of surfaceOf method will use the local varaible reference - surfaceToVar += s -> Symbol.newVal( - Symbol.spliceOwner, - s"__s${count}", - TypeRepr.of[Surface], - Flags.EmptyFlags, - Symbol.noSymbol - ) - count += 1 - } - val surfaceDefs: List[ValDef] = surfaceToVar.map { x => - val sym = x._2 - ValDef(sym, Some(memo(x._1).asTerm)) - }.toList + // Create a var def table for replacing surfaceOf[xxx] to __s0, __s1, ... + var surfaceVarCount = 0 + seen + // Exclude primitive type surface + .toSeq + // Exclude primitive surfaces as it is already defined in Primitive object + .filterNot(x => primitiveTypeFactory.isDefinedAt(x._1)) + .sortBy(_._2) + .reverse + .map { case (tpe, order) => + // Update the cache so that the next call of surfaceOf method will use the local varaible reference + surfaceToVar += tpe -> Symbol.newVal( + Symbol.spliceOwner, + // Use alphabetically ordered variable names + f"__s${surfaceVarCount}%03X", + TypeRepr.of[Surface], + if (lazySurface.contains(tpe)) { + // If the surface itself is lazy, we need to eagerly initialize it to update the surface cache + Flags.EmptyFlags + } else { + // Use lazy val to avoid forward reference error + Flags.Lazy + }, + Symbol.noSymbol + ) + surfaceVarCount += 1 + } - // Clear method observation cache + // Clear surface cache + memo.clear() + seen = ListMap.empty seenMethodParent.clear() + val surfaceDefs: List[ValDef] = surfaceToVar.toSeq.map { case (tpe, sym) => + ValDef(sym, Some(surfaceOf(tpe, useVarRef = false).asTerm)) + }.toList + /** - * Generate a code like this: {{ val __s0 = Surface.of[A] val __s1 = Surface.of[B] ... + * Generate a code like this: + * + * {{ lazy val __s000 = Surface.of[A]; lazy val __s001 = Surface.of[B] ... * * ClassMethodSurface( .... ) }} */ @@ -758,6 +803,10 @@ private[surface] class CompileTimeSurfaceFactory[Q <: Quotes](using quotes: Q) { } } + private def clsCast(term: Term, t: TypeRepr): Term = { + Select.unique(term, "asInstanceOf").appliedToType(t) + } + private def createMethodCaller( objectType: TypeRepr, m: Symbol, @@ -779,13 +828,12 @@ private[surface] class CompileTimeSurfaceFactory[Q <: Quotes](using quotes: Q) { rhsFn = (sym, params) => { val x = params(0).asInstanceOf[Term] val args = params(1).asInstanceOf[Term] - val expr = Select.unique(x, "asInstanceOf").appliedToType(objectType).select(m) + val expr = clsCast(x, objectType).select(m) val argList = methodArgs.zipWithIndex.collect { // If the arg is implicit, no need to explicitly bind it case (arg, i) if !arg.isImplicit => - // args(i).asInstanceOf[ArgType] val extracted = Select.unique(args, "apply").appliedTo(Literal(IntConstant(i))) - Select.unique(extracted, "asInstanceOf").appliedToType(arg.tpe) + clsCast(extracted, arg.tpe) } if (argList.isEmpty) { val newExpr = m.tree match { diff --git a/airframe-surface/src/main/scala/wvlet/airframe/surface/Surfaces.scala b/airframe-surface/src/main/scala/wvlet/airframe/surface/Surfaces.scala index 10f4bff15b..7aa00594a9 100644 --- a/airframe-surface/src/main/scala/wvlet/airframe/surface/Surfaces.scala +++ b/airframe-surface/src/main/scala/wvlet/airframe/surface/Surfaces.scala @@ -62,6 +62,18 @@ trait ObjectFactory extends Serializable { def newInstance(args: Seq[Any]): Any } +object ObjectFactory { + + /** + * Used internally for creating a new ObjectFactory instance from a given generic function + * @param f + * @return + */ + def newFactory(f: Seq[Any] => Any): ObjectFactory = new ObjectFactory { + override def newInstance(args: Seq[Any]): Any = f(args) + } +} + case class MethodRef(owner: Class[_], name: String, paramTypes: Seq[Class[_]], isConstructor: Boolean) trait MethodParameter extends Parameter { @@ -75,6 +87,12 @@ trait MethodParameter extends Parameter { def getMethodArgDefaultValue(methodOwner: Any): Option[Any] = getDefaultValue } +object MethodParameter { + def accessor[A, B](cl: Class[A])(body: A => B): Any => B = { (x: Any) => + body(cl.cast(x)) + } +} + trait MethodSurface extends ParameterBase { def mod: Int def owner: Surface diff --git a/airframe-surface/src/test/scala/wvlet/airframe/surface/RecursiveMethodParamTest.scala b/airframe-surface/src/test/scala/wvlet/airframe/surface/RecursiveMethodParamTest.scala new file mode 100644 index 0000000000..e3b3bd225e --- /dev/null +++ b/airframe-surface/src/test/scala/wvlet/airframe/surface/RecursiveMethodParamTest.scala @@ -0,0 +1,31 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package wvlet.airframe.surface + +object RecursiveMethodParamTest { + case class Node(parent: Option[Node]) + + trait MyRecursiveApi { + def find(node: Node): Unit = {} + } +} + +class RecursiveMethodParamTest extends munit.FunSuite { + import RecursiveMethodParamTest._ + + // .... + test("Compile method surfaces with recursive method param") { + Surface.methodsOf[MyRecursiveApi] + } +} diff --git a/build.sbt b/build.sbt index 2d7120415e..4732137edb 100644 --- a/build.sbt +++ b/build.sbt @@ -109,6 +109,8 @@ val buildSettings = Seq[Setting[_]]( scalacOptions ++= Seq( "-feature", "-deprecation" + // Use this for debugging Macros + // "-Xcheck-macros" ) ++ { if (scalaVersion.value.startsWith("3.")) { Seq.empty