Skip to content

Commit

Permalink
ProjectFiles: detect if source is test in Layout
Browse files Browse the repository at this point in the history
  • Loading branch information
kitbellew committed Dec 27, 2023
1 parent 4ed3b0f commit a24f82d
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -89,36 +89,47 @@ object ProjectFiles {
matchesPath(file.path)
}

case class FileInfo(lang: String, isTest: Boolean)

sealed abstract class Layout {
def getLang(path: AbsoluteFile): Option[String]
def getInfo(path: AbsoluteFile): Option[FileInfo]
protected[config] def getDialectByLang(lang: String)(implicit
dialect: Dialect
): Option[NamedDialect]
final def getLang(path: AbsoluteFile): Option[String] =
getInfo(path).map(_.lang)
final def withLang(lang: String, style: ScalafmtConfig): ScalafmtConfig =
style.withDialect(getDialectByLang(lang)(style.dialect))
}

object Layout {

case object StandardConvention extends Layout {
private val phaseLabels = Seq("main", "test", "it")
private val mainLabels = Seq("main")
private val testLabels = Seq("test", "it")

override def getLang(af: AbsoluteFile): Option[String] = {
override def getInfo(af: AbsoluteFile): Option[FileInfo] = {
val parent = af.path.getParent
val depth = parent.getNameCount()
val dirs = new Array[String](depth)
for (i <- 0 until depth) dirs(i) = parent.getName(i).toString
getLang(dirs, depth)
getInfo(dirs, depth)
}

@tailrec
private def getLang(dirs: Array[String], len: Int): Option[String] = {
private def getInfo(
dirs: Array[String],
len: Int
): Option[FileInfo] = {
// src/phase/lang
val srcIdx = dirs.lastIndexOf("src", len - 3)
if (srcIdx < 0) None
else {
val langIdx = srcIdx + 2
val found = phaseLabels.contains(dirs(srcIdx + 1))
if (found) Some(dirs(langIdx)) else getLang(dirs, srcIdx)
val phase = dirs(srcIdx + 1)
def lang = dirs(srcIdx + 2)
if (mainLabels.contains(phase)) Some(FileInfo(lang, false))
else if (testLabels.contains(phase)) Some(FileInfo(lang, true))
else getInfo(dirs, srcIdx)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ case class ScalafmtConfig(
val absfile = AbsoluteFile(filename)
def onLang[A](f: (ProjectFiles.Layout, String) => A): Option[A] =
project.layout.flatMap { layout =>
layout.getLang(absfile).map { lang => f(layout, lang) }
layout.getInfo(absfile).map { x => f(layout, x.lang) }
}
expandedFileOverride.map { case (langStyles, pmStyles) =>
def langStyle = onLang { (layout, lang) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,28 @@ class StandardProjectLayoutTest extends munit.FunSuite {
}
}

Seq(
"/prj/src/main/scalaX" -> None,
"/prj/src/test/scalaX" -> None,
"/prj/src/none/scalaX" -> None,
"/prj/src/main/scalaX/foo" -> Some(false),
"/prj/src/test/scalaX/foo" -> Some(true),
"/prj/src/none/scalaX/foo" -> None,
"/prj/src/main/scalaX/src/test/scalaY" -> Some(false),
"/prj/src/test/scalaX/src/main/scalaY" -> Some(true),
"/prj/src/none/scalaX/src/main/scalaY" -> None,
"/prj/src/none/scalaX/src/test/scalaY" -> None,
"/prj/src/main/scalaX/src/test/scalaY/foo" -> Some(true),
"/prj/src/test/scalaX/src/main/scalaY/foo" -> Some(false),
"/prj/src/none/scalaX/src/main/scalaY/foo" -> Some(false),
"/prj/src/none/scalaX/src/test/scalaY/foo" -> Some(true)
).foreach { case (path, expectedTest) =>
test(s"StandardConvention.isTest($path) == $expectedTest") {
val actualTest = getInfo(AbsoluteFile(path)).map(_.isTest)
assertEquals(actualTest, expectedTest)
}
}

Seq(
(s210, "scala-2.10", None),
(s211, "scala-2.10", s210),
Expand Down

0 comments on commit a24f82d

Please sign in to comment.