Skip to content

Commit

Permalink
Merge pull request #2 from SaadAissa/decodeFields
Browse files Browse the repository at this point in the history
Decode fields
  • Loading branch information
SaadAissa authored Jun 25, 2024
2 parents 01d744e + ec47589 commit df373c1
Show file tree
Hide file tree
Showing 17 changed files with 736 additions and 43 deletions.
104 changes: 86 additions & 18 deletions src/main/scala/ch/epfl/scala/decoder/BinaryDecoder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import tastyquery.jdk.ClasspathLoaders

import java.nio.file.Path
import scala.util.matching.Regex
import tastyquery.Exceptions.NonMethodReferenceException

object BinaryDecoder:
def apply(classEntries: Seq[Path])(using ThrowOrWarn): BinaryDecoder =
Expand Down Expand Up @@ -107,6 +108,75 @@ class BinaryDecoder(using Context, ThrowOrWarn):
candidates.singleOrThrow(method)
end decode

def decode(field: binary.Field): DecodedField =
val decodedClass = decode(field.declaringClass)
decode(decodedClass, field)

def decode(decodedClass: DecodedClass, field: binary.Field): DecodedField =
def tryDecode(f: PartialFunction[binary.Field, Seq[DecodedField]]): Seq[DecodedField] =
f.applyOrElse(field, _ => Seq.empty[DecodedField])

extension (xs: Seq[DecodedField])
def orTryDecode(f: PartialFunction[binary.Field, Seq[DecodedField]]): Seq[DecodedField] =
if xs.nonEmpty then xs else f.applyOrElse(field, _ => Seq.empty[DecodedField])
val decodedFields =
tryDecode {
case Patterns.LazyVal(name) =>
for
owner <- decodedClass.classSymbol.toSeq ++ decodedClass.linearization.filter(_.isTrait)
sym <- owner.declarations.collect {
case sym: TermSymbol if sym.nameStr == name && sym.isModuleOrLazyVal => sym
}
yield DecodedField.ValDef(decodedClass, sym)
case Patterns.Module() =>
decodedClass.classSymbol.flatMap(_.moduleValue).map(DecodedField.ModuleVal(decodedClass, _)).toSeq
case Patterns.Offset(nbr) =>
Seq(DecodedField.LazyValOffset(decodedClass, nbr, defn.LongType))
case Patterns.OuterField() =>
decodedClass.symbolOpt
.flatMap(_.outerClass)
.map(outerClass => DecodedField.Outer(decodedClass, outerClass.selfType))
.toSeq
case Patterns.SerialVersionUID() =>
Seq(DecodedField.SerialVersionUID(decodedClass, defn.LongType))
case Patterns.LazyValBitmap(name) =>
Seq(DecodedField.LazyValBitmap(decodedClass, defn.BooleanType, name))
case Patterns.AnyValCapture() =>
for
classSym <- decodedClass.symbolOpt.toSeq
outerClass <- classSym.outerClass.toSeq
if outerClass.isSubClass(defn.AnyValClass)
sym <- outerClass.declarations.collect {
case sym: TermSymbol if sym.isVal && !sym.isMethod => sym
}
yield DecodedField.Capture(decodedClass, sym)
case Patterns.Capture(names) =>
decodedClass.symbolOpt.toSeq
.flatMap(CaptureCollector.collectCaptures)
.filter { captureSym =>
names.exists {
case Patterns.LazyVal(name) => name == captureSym.nameStr
case name => name == captureSym.nameStr
}
}
.map(DecodedField.Capture(decodedClass, _))

case _ if field.isStatic && decodedClass.isJava =>
for
owner <- decodedClass.companionClassSymbol.toSeq
sym <- owner.declarations.collect { case sym: TermSymbol if sym.nameStr == field.name => sym }
yield DecodedField.ValDef(decodedClass, sym)
}.orTryDecode { case _ =>
for
owner <- withCompanionIfExtendsJavaLangEnum(decodedClass) ++ decodedClass.linearization.filter(_.isTrait)
sym <- owner.declarations.collect {
case sym: TermSymbol if matchTargetName(field, sym) && !sym.isMethod => sym
}
yield DecodedField.ValDef(decodedClass, sym)
}
decodedFields.singleOrThrow(field)
end decode

private def reduceAmbiguityOnClasses(syms: Seq[DecodedClass]): Seq[DecodedClass] =
if syms.size > 1 then
val reduced = syms.filterNot(sym => syms.exists(enclose(sym, _)))
Expand Down Expand Up @@ -476,13 +546,8 @@ class BinaryDecoder(using Context, ThrowOrWarn):
.map(target => DecodedMethod.TraitStaticForwarder(decode(decodedClass, target)))

private def decodeOuter(decodedClass: DecodedClass): Option[DecodedMethod.OuterAccessor] =
def outerClass(sym: Symbol): Option[ClassSymbol] =
sym.owner match
case null => None
case owner if owner.isClass => Some(owner.asClass)
case owner => outerClass(owner)
decodedClass.symbolOpt
.flatMap(outerClass)
.flatMap(_.outerClass)
.map(outerClass => DecodedMethod.OuterAccessor(decodedClass, outerClass.thisType))

private def decodeTraitInitializer(
Expand Down Expand Up @@ -616,11 +681,18 @@ class BinaryDecoder(using Context, ThrowOrWarn):
DecodedMethod.MixinForwarder(decodedClass, staticForwarder.target)
}

private def withCompanionIfExtendsAnyVal(cls: ClassSymbol): Seq[ClassSymbol] =
cls.companionClass match
case Some(companionClass) if companionClass.isSubClass(defn.AnyValClass) =>
Seq(cls, companionClass)
case _ => Seq(cls)
private def withCompanionIfExtendsAnyVal(decodedClass: DecodedClass): Seq[Symbol] = decodedClass match
case classDef: DecodedClass.ClassDef =>
Seq(classDef.symbol) ++ classDef.symbol.companionClass.filter(_.isSubClass(defn.AnyValClass))
case _: DecodedClass.SyntheticCompanionClass => Seq.empty
case anonFun: DecodedClass.SAMOrPartialFunction => Seq(anonFun.symbol)
case inlined: DecodedClass.InlinedClass => withCompanionIfExtendsAnyVal(inlined.underlying)

private def withCompanionIfExtendsJavaLangEnum(decodedClass: DecodedClass): Seq[ClassSymbol] =
decodedClass.classSymbol.toSeq.flatMap { cls =>
if cls.isSubClass(defn.javaLangEnumClass) then Seq(cls) ++ cls.companionClass
else Seq(cls)
}

private def decodeAdaptedAnonFun(decodedClass: DecodedClass, method: binary.Method): Seq[DecodedMethod] =
if method.instructions.nonEmpty then
Expand Down Expand Up @@ -786,13 +858,6 @@ class BinaryDecoder(using Context, ThrowOrWarn):
private def collectLiftedTrees[S](decodedClass: DecodedClass, method: binary.Method)(
matcher: PartialFunction[LiftedTree[?], LiftedTree[S]]
): Seq[LiftedTree[S]] =
def withCompanionIfExtendsAnyVal(decodedClass: DecodedClass): Seq[Symbol] = decodedClass match
case classDef: DecodedClass.ClassDef =>
Seq(classDef.symbol) ++ classDef.symbol.companionClass.filter(_.isSubClass(defn.AnyValClass))
case _: DecodedClass.SyntheticCompanionClass => Seq.empty
case anonFun: DecodedClass.SAMOrPartialFunction => Seq(anonFun.symbol)
case inlined: DecodedClass.InlinedClass => withCompanionIfExtendsAnyVal(inlined.underlying)

val owners = withCompanionIfExtendsAnyVal(decodedClass)
val sourceLines =
if owners.size == 2 && method.allParameters.exists(p => p.name.matches("\\$this\\$\\d+")) then
Expand Down Expand Up @@ -823,6 +888,9 @@ class BinaryDecoder(using Context, ThrowOrWarn):
private def matchTargetName(method: binary.Method, symbol: TermSymbol): Boolean =
method.unexpandedDecodedNames.map(_.stripSuffix("$")).contains(symbol.targetNameStr)

private def matchTargetName(field: binary.Field, symbol: TermSymbol): Boolean =
field.unexpandedDecodedNames.map(_.stripSuffix("$")).contains(symbol.targetNameStr)

private case class SourceParams(
declaredParamNames: Seq[UnsignedTermName],
declaredParamTypes: Seq[Type],
Expand Down
32 changes: 32 additions & 0 deletions src/main/scala/ch/epfl/scala/decoder/DecodedSymbol.scala
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,35 @@ object DecodedMethod:
override def toString: String =
if underlying.isInstanceOf[InlinedMethod] then underlying.toString
else s"$underlying (inlined)"

sealed trait DecodedField extends DecodedSymbol:
def owner: DecodedClass
override def symbolOpt: Option[TermSymbol] = None
def declaredType: TypeOrMethodic

object DecodedField:
final class ValDef(val owner: DecodedClass, val symbol: TermSymbol) extends DecodedField:
def declaredType: TypeOrMethodic = symbol.declaredType
override def symbolOpt: Option[TermSymbol] = Some(symbol)
override def toString: String = s"ValDef($owner, ${symbol.showBasic})"

final class ModuleVal(val owner: DecodedClass, val symbol: TermSymbol) extends DecodedField:
def declaredType: TypeOrMethodic = symbol.declaredType
override def symbolOpt: Option[TermSymbol] = Some(symbol)
override def toString: String = s"ModuleVal($owner, ${symbol.showBasic})"

final class LazyValOffset(val owner: DecodedClass, val index: Int, val declaredType: Type) extends DecodedField:
override def toString: String = s"LazyValOffset($owner, $index)"

final class Outer(val owner: DecodedClass, val declaredType: Type) extends DecodedField:
override def toString: String = s"Outer($owner, ${declaredType.showBasic})"

final class SerialVersionUID(val owner: DecodedClass, val declaredType: Type) extends DecodedField:
override def toString: String = s"SerialVersionUID($owner)"

final class Capture(val owner: DecodedClass, val symbol: TermSymbol) extends DecodedField:
def declaredType: TypeOrMethodic = symbol.declaredType
override def toString: String = s"Capture($owner, ${symbol.showBasic})"

final class LazyValBitmap(val owner: DecodedClass, val declaredType: Type, val name: String) extends DecodedField:
override def toString: String = s"LazyValBitmap($owner, , ${declaredType.showBasic})"
19 changes: 19 additions & 0 deletions src/main/scala/ch/epfl/scala/decoder/StackTraceFormatter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ import tastyquery.Types.*
import scala.annotation.tailrec

class StackTraceFormatter(using ThrowOrWarn):
def format(field: DecodedField): String =
val typeAscription = field.declaredType match
case tpe: Type => ": " + format(tpe)
case tpe => format(tpe)
formatOwner(field).dot(formatName(field)) + typeAscription

def format(cls: DecodedClass): String =
cls match
case cls: DecodedClass.ClassDef => formatQualifiedName(cls.symbol)
Expand Down Expand Up @@ -60,6 +66,19 @@ class StackTraceFormatter(using ThrowOrWarn):
case method: DecodedMethod.SAMOrPartialFunctionConstructor => format(method.owner)
case method: DecodedMethod.InlinedMethod => formatOwner(method.underlying)

private def formatOwner(field: DecodedField): String =
format(field.owner)

private def formatName(field: DecodedField): String =
field match
case field: DecodedField.ValDef => formatName(field.symbol)
case field: DecodedField.ModuleVal => ""
case field: DecodedField.LazyValOffset => "<offset " + field.index + ">"
case field: DecodedField.Outer => "<outer>"
case field: DecodedField.SerialVersionUID => "<serialVersionUID>"
case field: DecodedField.Capture => formatName(field.symbol).dot("<capture>")
case field: DecodedField.LazyValBitmap => field.name.dot("<lazy val bitmap>")

private def formatName(method: DecodedMethod): String =
method match
case method: DecodedMethod.ValOrDefDef => formatName(method.symbol)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ trait ClassType extends Type:
def declaredField(name: String): Option[Field]
def declaredMethod(name: String, descriptor: String): Option[Method]
def declaredMethods: Seq[Method]
def declaredFields: Seq[Field]
def classLoader: BinaryClassLoader

def isObject = name.endsWith("$")
Expand Down
1 change: 1 addition & 0 deletions src/main/scala/ch/epfl/scala/decoder/binary/Field.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ package ch.epfl.scala.decoder.binary
trait Field extends Symbol:
def declaringClass: ClassType
def `type`: Type
def isStatic: Boolean
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package ch.epfl.scala.decoder.internal

import tastyquery.Trees.*
import scala.collection.mutable
import tastyquery.Symbols.*
import tastyquery.Traversers.*
import tastyquery.Contexts.*
import tastyquery.SourcePosition
import tastyquery.Types.*
import tastyquery.Traversers
import ch.epfl.scala.decoder.ThrowOrWarn
import scala.languageFeature.postfixOps

object CaptureCollector:
def collectCaptures(cls: ClassSymbol | TermSymbol)(using Context, ThrowOrWarn): Set[TermSymbol] =
val collector = CaptureCollector(cls)
collector.traverse(cls.tree)
collector.capture.toSet

class CaptureCollector(cls: ClassSymbol | TermSymbol)(using Context, ThrowOrWarn) extends TreeTraverser:
val capture: mutable.Set[TermSymbol] = mutable.Set.empty
val alreadySeen: mutable.Set[Symbol] = mutable.Set.empty

def loopCollect(symbol: Symbol)(collect: => Unit): Unit =
if !alreadySeen.contains(symbol) then
alreadySeen += symbol
collect
override def traverse(tree: Tree): Unit =
tree match
case _: TypeTree => ()
case ident: Ident =>
for sym <- ident.safeSymbol.collect { case sym: TermSymbol => sym } do
// check that sym is local
// and check that no owners of sym is cls
if !alreadySeen.contains(sym) then
if sym.isLocal then
if !ownersIsCls(sym) then capture += sym
if sym.isMethod || sym.isLazyVal then loopCollect(sym)(sym.tree.foreach(traverse))
else if sym.isModuleVal then loopCollect(sym)(sym.moduleClass.flatMap(_.tree).foreach(traverse))
case _ => super.traverse(tree)

def ownersIsCls(sym: Symbol): Boolean =
sym.owner match
case owner: Symbol =>
if owner == cls then true
else ownersIsCls(owner)
case null => false
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class Definitions(using ctx: Context):
val PartialFunctionClass = scalaPackage.getDecl(typeName("PartialFunction")).get.asClass
val AbstractPartialFunctionClass = scalaRuntimePackage.getDecl(typeName("AbstractPartialFunction")).get.asClass
val SerializableClass = javaIoPackage.getDecl(typeName("Serializable")).get.asClass
val javaLangEnumClass = javaLangPackage.getDecl(typeName("Enum")).get.asClass

val SerializedLambdaType: Type = TypeRef(javaLangInvokePackage.packageRef, typeName("SerializedLambda"))
val DeserializeLambdaType = MethodType(List(SimpleName("arg0")), List(SerializedLambdaType), ObjectType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import tastyquery.Contexts.*
import tastyquery.SourcePosition
import tastyquery.Types.*
import tastyquery.Traversers
import tastyquery.Exceptions.NonMethodReferenceException
import ch.epfl.scala.decoder.ThrowOrWarn

/**
Expand Down
39 changes: 39 additions & 0 deletions src/main/scala/ch/epfl/scala/decoder/internal/Patterns.scala
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,45 @@ object Patterns:
"(.+)\\$i\\d+".r.unapplySeq(xs(0)).map(_(0)).getOrElse(xs(0))
}

object LazyVal:
def unapply(field: binary.Field): Option[String] = unapply(field.decodedName)

def unapply(name: String): Option[String] =
"""(.*)\$lzy\d+""".r.unapplySeq(name).map(xs => xs(0).stripSuffix("$"))

object Module:
def unapply(field: binary.Field): Boolean = field.name == "MODULE$"

object Offset:
def unapply(field: binary.Field): Option[Int] =
"""OFFSET\$(?:_m_)?(\d+)""".r.unapplySeq(field.name).map(xs => xs(0).toInt)

object OuterField:
def unapply(field: binary.Field): Boolean = field.name == "$outer"

object SerialVersionUID:
def unapply(field: binary.Field): Boolean = field.name == "serialVersionUID"

object AnyValCapture:
def unapply(field: binary.Field): Boolean =
field.name.matches("\\$this\\$\\d+")

object Capture:
def unapply(field: binary.Field): Option[Seq[String]] =
field.extractFromDecodedNames("(.+)\\$\\d+".r)(xs => xs(0))

object LazyValBitmap:
def unapply(field: binary.Field): Option[String] =
"(.+)bitmap\\$\\d+".r.unapplySeq(field.decodedName).map(xs => xs(0))

extension (field: binary.Field)
private def extractFromDecodedNames[T](regex: Regex)(extract: List[String] => T): Option[Seq[T]] =
val extracted = field.unexpandedDecodedNames
.flatMap(regex.unapplySeq)
.map(extract)
.distinct
if extracted.nonEmpty then Some(extracted) else None

extension (method: binary.Method)
private def extractFromDecodedNames[T](regex: Regex)(extract: List[String] => T): Option[Seq[T]] =
val extracted = method.unexpandedDecodedNames
Expand Down
17 changes: 17 additions & 0 deletions src/main/scala/ch/epfl/scala/decoder/internal/extensions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ extension (symbol: Symbol)
def isInline = symbol.isTerm && symbol.asTerm.isInline
def nameStr: String = symbol.name.toString

def outerClass: Option[ClassSymbol] =
symbol.owner match
case null => None
case owner: ClassSymbol => Some(owner)
case owner => owner.outerClass

def showBasic =
val span = symbol.tree.map(_.pos) match
case Some(pos) if pos.isFullyDefined =>
Expand Down Expand Up @@ -307,3 +313,14 @@ extension (method: DecodedMethod)
case _: DecodedMethod.SAMOrPartialFunctionConstructor => true
case method: DecodedMethod.InlinedMethod => method.underlying.isGenerated
case _ => false

extension (field: DecodedField)
def isGenerated: Boolean =
field match
case field: DecodedField.ValDef => false
case field: DecodedField.ModuleVal => true
case field: DecodedField.LazyValOffset => true
case field: DecodedField.Outer => true
case field: DecodedField.SerialVersionUID => true
case field: DecodedField.Capture => true
case field: DecodedField.LazyValBitmap => true
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,6 @@ class JavaReflectClass(cls: Class[?], extraInfo: ExtraClassInfo, override val cl
val methodInfo = extraInfo.getMethodInfo(sig)
JavaReflectConstructor(c, sig, methodInfo, classLoader)
}

override def declaredFields: Seq[binary.Field] =
cls.getDeclaredFields().map(f => JavaReflectField(f, classLoader))
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package ch.epfl.scala.decoder.javareflect
import ch.epfl.scala.decoder.binary

import java.lang.reflect.Field
import java.lang.reflect.Modifier

class JavaReflectField(field: Field, loader: JavaReflectLoader) extends binary.Field:
override def name: String = field.getName
Expand All @@ -11,5 +12,9 @@ class JavaReflectField(field: Field, loader: JavaReflectLoader) extends binary.F

override def declaringClass: binary.ClassType = loader.loadClass(field.getDeclaringClass)

override def isStatic: Boolean = Modifier.isStatic(field.getModifiers)

override def `type`: binary.Type =
loader.loadClass(field.getType)

override def toString: String = field.toString
Loading

0 comments on commit df373c1

Please sign in to comment.