Skip to content

Commit 37d64b8

Browse files
authored
Various improvements for Scala 3 macro to match Scala 2 implementation (#148)
These issues were found while porting Mill to Scala 3.
1 parent 4100113 commit 37d64b8

File tree

4 files changed

+167
-31
lines changed

4 files changed

+167
-31
lines changed

example/optseq/src/Main.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
package example.optseq
2-
import mainargs.{main, arg, ParserForMethods, ArgReader}
2+
import mainargs.{main, arg, ParserForMethods, TokensReader}
33

44
object Main {
55
@main

mainargs/src-3/Macros.scala

Lines changed: 96 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -32,23 +32,34 @@ object Macros {
3232
val companionModuleType = typeSymbolOfB.companionModule.tree.asInstanceOf[ValDef].tpt.tpe.asType
3333
val companionModuleExpr = Ident(companionModule).asExpr
3434
val mainAnnotationInstance = typeSymbolOfB.getAnnotation(mainAnnotation).getOrElse {
35-
report.throwError(
36-
s"cannot find @main annotation on ${companionModule.name}",
37-
typeSymbolOfB.pos.get
35+
'{new mainargs.main()}.asTerm // construct a default if not found.
36+
}
37+
val ctor = typeSymbolOfB.primaryConstructor
38+
val ctorParams = ctor.paramSymss.flatten
39+
// try to match the apply method with the constructor parameters, this is a good heuristic
40+
// for if the apply method is overloaded.
41+
val annotatedMethod = typeSymbolOfB.companionModule.memberMethod("apply").filter(p =>
42+
p.paramSymss.flatten.corresponds(ctorParams) { (p1, p2) =>
43+
p1.name == p2.name
44+
}
45+
).headOption.getOrElse {
46+
report.errorAndAbort(
47+
s"Cannot find apply method in companion object of ${typeReprOfB.show}",
48+
typeSymbolOfB.companionModule.pos.getOrElse(Position.ofMacroExpansion)
3849
)
3950
}
40-
val annotatedMethod = TypeRepr.of[B].typeSymbol.companionModule.memberMethod("apply").head
4151
companionModuleType match
4252
case '[bCompanion] =>
43-
val mainData = createMainData[B, Any](
53+
val mainData = createMainData[B, bCompanion](
4454
annotatedMethod,
4555
mainAnnotationInstance,
4656
// Somehow the `apply` method parameter annotations don't end up on
4757
// the `apply` method parameters, but end up in the `<init>` method
4858
// parameters, so use those for getting the annotations instead
4959
TypeRepr.of[B].typeSymbol.primaryConstructor.paramSymss
5060
)
51-
'{ new ParserForClass[B](${ mainData }, () => ${ Ident(companionModule).asExpr }) }
61+
val erasedMainData = '{$mainData.asInstanceOf[MainData[B, Any]]}
62+
'{ new ParserForClass[B]($erasedMainData, () => ${ Ident(companionModule).asExpr }) }
5263
}
5364

5465
def createMainData[T: Type, B: Type](using Quotes)
@@ -57,41 +68,84 @@ object Macros {
5768
createMainData[T, B](method, mainAnnotation, method.paramSymss)
5869
}
5970

71+
private object VarargTypeRepr {
72+
def unapply(using Quotes)(tpe: quotes.reflect.TypeRepr): Option[quotes.reflect.TypeRepr] = {
73+
import quotes.reflect.*
74+
tpe match {
75+
case AnnotatedType(AppliedType(_, Seq(arg)), x)
76+
if x.tpe =:= defn.RepeatedAnnot.typeRef => Some(arg)
77+
case _ => None
78+
}
79+
}
80+
}
81+
82+
private object AsType {
83+
def unapply(using Quotes)(tpe: quotes.reflect.TypeRepr): Some[Type[?]] = {
84+
Some(tpe.asType)
85+
}
86+
}
87+
6088
def createMainData[T: Type, B: Type](using Quotes)
6189
(method: quotes.reflect.Symbol,
6290
mainAnnotation: quotes.reflect.Term,
6391
annotatedParamsLists: List[List[quotes.reflect.Symbol]]): Expr[MainData[T, B]] = {
6492

6593
import quotes.reflect.*
6694
val params = method.paramSymss.headOption.getOrElse(report.throwError("Multiple parameter lists not supported"))
67-
val defaultParams = getDefaultParams(method)
95+
val defaultParams = if (params.exists(_.flags.is(Flags.HasDefault))) getDefaultParams(method) else Map.empty
6896
val argSigsExprs = params.zip(annotatedParamsLists.flatten).map { paramAndAnnotParam =>
6997
val param = paramAndAnnotParam._1
7098
val annotParam = paramAndAnnotParam._2
7199
val paramTree = param.tree.asInstanceOf[ValDef]
72100
val paramTpe = paramTree.tpt.tpe
101+
val readerTpe = paramTpe match {
102+
case VarargTypeRepr(AsType('[t])) => TypeRepr.of[Leftover[t]]
103+
case _ => paramTpe
104+
}
73105
val arg = annotParam.getAnnotation(argAnnotation).map(_.asExprOf[mainargs.arg]).getOrElse('{ new mainargs.arg() })
74-
val paramType = paramTpe.asType
75-
paramType match
106+
readerTpe.asType match {
76107
case '[t] =>
108+
def applyAndCast(f: Expr[Any] => Expr[Any], arg: Expr[B]): Expr[t] = {
109+
f(arg) match {
110+
case '{ $v: `t` } => v
111+
case expr => {
112+
// this case will be activated when the found default parameter is not of type `t`
113+
val recoveredType =
114+
try
115+
expr.asExprOf[t]
116+
catch
117+
case err: Exception =>
118+
report.errorAndAbort(
119+
s"""Failed to convert default value for parameter ${param.name},
120+
|expected type: ${paramTpe.show},
121+
|but default value ${expr.show} is of type: ${expr.asTerm.tpe.widen.show}
122+
|while converting type caught an exception with message: ${err.getMessage}
123+
|There might be a bug in mainargs.""".stripMargin,
124+
param.pos.getOrElse(Position.ofMacroExpansion)
125+
)
126+
recoveredType
127+
}
128+
}
129+
}
77130
val defaultParam: Expr[Option[B => t]] = defaultParams.get(param) match {
78-
case Some('{ $v: `t`}) => '{ Some(((_: B) => $v)) }
131+
case Some(f) => '{ Some((b: B) => ${ applyAndCast(f, 'b) }) }
79132
case None => '{ None }
80133
}
81134
val tokensReader = Expr.summon[mainargs.TokensReader[t]].getOrElse {
82-
report.throwError(
83-
s"No mainargs.ArgReader found for parameter ${param.name}",
84-
param.pos.get
135+
report.errorAndAbort(
136+
s"No mainargs.TokensReader[${Type.show[t]}] found for parameter ${param.name} of method ${method.name} in ${method.owner.fullName}",
137+
method.pos.getOrElse(Position.ofMacroExpansion)
85138
)
86139
}
87140
'{ (ArgSig.create[t, B](${ Expr(param.name) }, ${ arg }, ${ defaultParam })(using ${ tokensReader })) }
141+
}
88142
}
89143
val argSigs = Expr.ofList(argSigsExprs)
90144

91145
val invokeRaw: Expr[(B, Seq[Any]) => T] = {
92146

93147
def callOf(methodOwner: Expr[Any], args: Expr[Seq[Any]]) =
94-
call(methodOwner, method, '{ Seq($args) }).asExprOf[T]
148+
call(methodOwner, method, args).asExprOf[T]
95149

96150
'{ (b: B, params: Seq[Any]) => ${ callOf('b, 'params) } }
97151
}
@@ -120,37 +174,50 @@ object Macros {
120174
private def call(using Quotes)(
121175
methodOwner: Expr[Any],
122176
method: quotes.reflect.Symbol,
123-
argss: Expr[Seq[Seq[Any]]]
177+
args: Expr[Seq[Any]]
124178
): Expr[_] = {
125179
// Copy pasted from Cask.
126180
// https://github.com/com-lihaoyi/cask/blob/65b9c8e4fd528feb71575f6e5ef7b5e2e16abbd9/cask/src-3/cask/router/Macros.scala#L106
127181
import quotes.reflect._
128182
val paramss = method.paramSymss
129183

130184
if (paramss.isEmpty) {
131-
report.throwError("At least one parameter list must be declared.", method.pos.get)
185+
report.errorAndAbort("At least one parameter list must be declared.", method.pos.get)
132186
}
133-
134-
val accesses: List[List[Term]] = for (i <- paramss.indices.toList) yield {
135-
for (j <- paramss(i).indices.toList) yield {
136-
val tpe = paramss(i)(j).tree.asInstanceOf[ValDef].tpt.tpe
137-
tpe.asType match
138-
case '[t] => '{ $argss(${Expr(i)})(${Expr(j)}).asInstanceOf[t] }.asTerm
139-
}
187+
if (paramss.sizeIs > 1) {
188+
report.errorAndAbort("Multiple parameter lists are not supported.", method.pos.get)
140189
}
190+
val params = paramss.head
191+
192+
val methodType = methodOwner.asTerm.tpe.memberType(method)
193+
194+
def accesses(ref: Expr[Seq[Any]]): List[Term] =
195+
for (i <- params.indices.toList) yield {
196+
val param = params(i)
197+
val tpe = methodType.memberType(param)
198+
val untypedRef = '{ $ref(${Expr(i)}) }
199+
tpe match {
200+
case VarargTypeRepr(AsType('[t])) =>
201+
Typed(
202+
'{ $untypedRef.asInstanceOf[Leftover[t]].value }.asTerm,
203+
Inferred(AppliedType(defn.RepeatedParamClass.typeRef, List(TypeRepr.of[t])))
204+
)
205+
case _ => tpe.asType match
206+
case '[t] => '{ $untypedRef.asInstanceOf[t] }.asTerm
207+
}
208+
}
141209

142-
methodOwner.asTerm.select(method).appliedToArgss(accesses).asExpr
210+
methodOwner.asTerm.select(method).appliedToArgs(accesses(args)).asExpr
143211
}
144-
145212

146213
/** Lookup default values for a method's parameters. */
147-
private def getDefaultParams(using Quotes)(method: quotes.reflect.Symbol): Map[quotes.reflect.Symbol, Expr[Any]] = {
214+
private def getDefaultParams(using Quotes)(method: quotes.reflect.Symbol): Map[quotes.reflect.Symbol, Expr[Any] => Expr[Any]] = {
148215
// Copy pasted from Cask.
149216
// https://github.com/com-lihaoyi/cask/blob/65b9c8e4fd528feb71575f6e5ef7b5e2e16abbd9/cask/src-3/cask/router/Macros.scala#L38
150217
import quotes.reflect._
151218

152219
val params = method.paramSymss.flatten
153-
val defaults = collection.mutable.Map.empty[Symbol, Expr[Any]]
220+
val defaults = collection.mutable.Map.empty[Symbol, Expr[Any] => Expr[Any]]
154221

155222
val Name = (method.name + """\$default\$(\d+)""").r
156223
val InitName = """\$lessinit\$greater\$default\$(\d+)""".r
@@ -159,13 +226,13 @@ object Macros {
159226

160227
idents.foreach{
161228
case deff @ DefDef(Name(idx), _, _, _) =>
162-
val expr = Ref(deff.symbol).asExpr
229+
val expr = (owner: Expr[Any]) => Select(owner.asTerm, deff.symbol).asExpr
163230
defaults += (params(idx.toInt - 1) -> expr)
164231

165232
// The `apply` method re-uses the default param factory methods from `<init>`,
166233
// so make sure to check if those exist too
167234
case deff @ DefDef(InitName(idx), _, _, _) if method.name == "apply" =>
168-
val expr = Ref(deff.symbol).asExpr
235+
val expr = (owner: Expr[Any]) => Select(owner.asTerm, deff.symbol).asExpr
169236
defaults += (params(idx.toInt - 1) -> expr)
170237

171238
case _ =>

mainargs/test/src/ClassTests.scala

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,56 @@ object ClassTests extends TestSuite {
1212
@main
1313
case class Qux(moo: String, b: Bar)
1414

15+
case class Cli(@arg(short = 'd') debug: Flag)
16+
17+
@main
18+
class Compat(
19+
@arg(short = 'h') val home: String,
20+
@arg(short = 's') val silent: Flag,
21+
val leftoverArgs: Leftover[String]
22+
) {
23+
override def equals(obj: Any): Boolean =
24+
obj match {
25+
case c: Compat =>
26+
home == c.home && silent == c.silent && leftoverArgs == c.leftoverArgs
27+
case _ => false
28+
}
29+
}
30+
object Compat {
31+
def apply(
32+
home: String = "/home",
33+
silent: Flag = Flag(),
34+
leftoverArgs: Leftover[String] = Leftover()
35+
) = new Compat(home, silent, leftoverArgs)
36+
37+
@deprecated("bin-compat shim", "0.1.0")
38+
private[mainargs] def apply(
39+
home: String,
40+
silent: Flag,
41+
noDefaultPredef: Flag,
42+
leftoverArgs: Leftover[String]
43+
) = new Compat(home, silent, leftoverArgs)
44+
}
45+
1546
implicit val fooParser: ParserForClass[Foo] = ParserForClass[Foo]
1647
implicit val barParser: ParserForClass[Bar] = ParserForClass[Bar]
1748
implicit val quxParser: ParserForClass[Qux] = ParserForClass[Qux]
49+
implicit val cliParser: ParserForClass[Cli] = ParserForClass[Cli]
50+
implicit val compatParser: ParserForClass[Compat] = ParserForClass[Compat]
51+
52+
class PathWrap {
53+
@main
54+
case class Foo(x: Int = 23, y: Int = 47)
55+
56+
object Main {
57+
@main
58+
def run(bar: Bar, bool: Boolean = false) = {
59+
s"${bar.w.value} ${bar.f.x} ${bar.f.y} ${bar.zzzz} $bool"
60+
}
61+
}
62+
63+
implicit val fooParser: ParserForClass[Foo] = ParserForClass[Foo]
64+
}
1865

1966
object Main {
2067
@main
@@ -161,5 +208,27 @@ object ClassTests extends TestSuite {
161208
Seq("-x", "1", "-y", "2", "-z", "hello")
162209
) ==> "false 1 2 hello false"
163210
}
211+
test("mill-compat") {
212+
test("apply-overload-class") {
213+
compatParser.constructOrThrow(Seq("foo")) ==> Compat(
214+
home = "/home",
215+
silent = Flag(false),
216+
leftoverArgs = Leftover("foo")
217+
)
218+
}
219+
test("no-main-on-class") {
220+
cliParser.constructOrThrow(Seq("-d")) ==> Cli(Flag(true))
221+
}
222+
test("path-dependent-default") {
223+
val p = new PathWrap
224+
p.fooParser.constructOrThrow(Seq()) ==> p.Foo(23, 47)
225+
}
226+
test("path-dependent-default-method") {
227+
val p = new PathWrap
228+
ParserForMethods(p.Main).runOrThrow(
229+
Seq("-x", "1", "-y", "2", "-z", "hello")
230+
) ==> "false 1 2 hello false"
231+
}
232+
}
164233
}
165234
}

mainargs/test/src-2/VarargsOldTests.scala renamed to mainargs/test/src/VarargsOldTests.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ object VarargsOldTests extends VarargsBaseTests {
99

1010
@main
1111
def mixedVariadic(@arg(short = 'f') first: Int, args: String*) =
12-
first + args.mkString
12+
first.toString + args.mkString
1313
}
1414

1515
val check = new Checker(ParserForMethods(Base), allowPositional = true)

0 commit comments

Comments
 (0)