@@ -32,23 +32,34 @@ object Macros {
32
32
val companionModuleType = typeSymbolOfB.companionModule.tree.asInstanceOf [ValDef ].tpt.tpe.asType
33
33
val companionModuleExpr = Ident (companionModule).asExpr
34
34
val mainAnnotationInstance = typeSymbolOfB.getAnnotation(mainAnnotation).getOrElse {
35
- report.throwError(
36
- s " cannot find @main annotation on ${companionModule.name}" ,
37
- typeSymbolOfB.pos.get
35
+ ' {new mainargs.main()}.asTerm // construct a default if not found.
36
+ }
37
+ val ctor = typeSymbolOfB.primaryConstructor
38
+ val ctorParams = ctor.paramSymss.flatten
39
+ // try to match the apply method with the constructor parameters, this is a good heuristic
40
+ // for if the apply method is overloaded.
41
+ val annotatedMethod = typeSymbolOfB.companionModule.memberMethod(" apply" ).filter(p =>
42
+ p.paramSymss.flatten.corresponds(ctorParams) { (p1, p2) =>
43
+ p1.name == p2.name
44
+ }
45
+ ).headOption.getOrElse {
46
+ report.errorAndAbort(
47
+ s " Cannot find apply method in companion object of ${typeReprOfB.show}" ,
48
+ typeSymbolOfB.companionModule.pos.getOrElse(Position .ofMacroExpansion)
38
49
)
39
50
}
40
- val annotatedMethod = TypeRepr .of[B ].typeSymbol.companionModule.memberMethod(" apply" ).head
41
51
companionModuleType match
42
52
case ' [bCompanion] =>
43
- val mainData = createMainData[B , Any ](
53
+ val mainData = createMainData[B , bCompanion ](
44
54
annotatedMethod,
45
55
mainAnnotationInstance,
46
56
// Somehow the `apply` method parameter annotations don't end up on
47
57
// the `apply` method parameters, but end up in the `<init>` method
48
58
// parameters, so use those for getting the annotations instead
49
59
TypeRepr .of[B ].typeSymbol.primaryConstructor.paramSymss
50
60
)
51
- ' { new ParserForClass [B ]($ { mainData }, () => $ { Ident (companionModule).asExpr }) }
61
+ val erasedMainData = ' {$mainData.asInstanceOf [MainData [B , Any ]]}
62
+ ' { new ParserForClass [B ]($erasedMainData, () => $ { Ident (companionModule).asExpr }) }
52
63
}
53
64
54
65
def createMainData [T : Type , B : Type ](using Quotes )
@@ -57,41 +68,84 @@ object Macros {
57
68
createMainData[T , B ](method, mainAnnotation, method.paramSymss)
58
69
}
59
70
71
+ private object VarargTypeRepr {
72
+ def unapply (using Quotes )(tpe : quotes.reflect.TypeRepr ): Option [quotes.reflect.TypeRepr ] = {
73
+ import quotes .reflect .*
74
+ tpe match {
75
+ case AnnotatedType (AppliedType (_, Seq (arg)), x)
76
+ if x.tpe =:= defn.RepeatedAnnot .typeRef => Some (arg)
77
+ case _ => None
78
+ }
79
+ }
80
+ }
81
+
82
+ private object AsType {
83
+ def unapply (using Quotes )(tpe : quotes.reflect.TypeRepr ): Some [Type [? ]] = {
84
+ Some (tpe.asType)
85
+ }
86
+ }
87
+
60
88
def createMainData [T : Type , B : Type ](using Quotes )
61
89
(method : quotes.reflect.Symbol ,
62
90
mainAnnotation : quotes.reflect.Term ,
63
91
annotatedParamsLists : List [List [quotes.reflect.Symbol ]]): Expr [MainData [T , B ]] = {
64
92
65
93
import quotes .reflect .*
66
94
val params = method.paramSymss.headOption.getOrElse(report.throwError(" Multiple parameter lists not supported" ))
67
- val defaultParams = getDefaultParams(method)
95
+ val defaultParams = if (params.exists(_.flags.is( Flags . HasDefault ))) getDefaultParams(method) else Map .empty
68
96
val argSigsExprs = params.zip(annotatedParamsLists.flatten).map { paramAndAnnotParam =>
69
97
val param = paramAndAnnotParam._1
70
98
val annotParam = paramAndAnnotParam._2
71
99
val paramTree = param.tree.asInstanceOf [ValDef ]
72
100
val paramTpe = paramTree.tpt.tpe
101
+ val readerTpe = paramTpe match {
102
+ case VarargTypeRepr (AsType (' [t])) => TypeRepr .of[Leftover [t]]
103
+ case _ => paramTpe
104
+ }
73
105
val arg = annotParam.getAnnotation(argAnnotation).map(_.asExprOf[mainargs.arg]).getOrElse(' { new mainargs.arg() })
74
- val paramType = paramTpe.asType
75
- paramType match
106
+ readerTpe.asType match {
76
107
case ' [t] =>
108
+ def applyAndCast (f : Expr [Any ] => Expr [Any ], arg : Expr [B ]): Expr [t] = {
109
+ f(arg) match {
110
+ case ' { $v : `t` } => v
111
+ case expr => {
112
+ // this case will be activated when the found default parameter is not of type `t`
113
+ val recoveredType =
114
+ try
115
+ expr.asExprOf[t]
116
+ catch
117
+ case err : Exception =>
118
+ report.errorAndAbort(
119
+ s """ Failed to convert default value for parameter ${param.name},
120
+ |expected type: ${paramTpe.show},
121
+ |but default value ${expr.show} is of type: ${expr.asTerm.tpe.widen.show}
122
+ |while converting type caught an exception with message: ${err.getMessage}
123
+ |There might be a bug in mainargs. """ .stripMargin,
124
+ param.pos.getOrElse(Position .ofMacroExpansion)
125
+ )
126
+ recoveredType
127
+ }
128
+ }
129
+ }
77
130
val defaultParam : Expr [Option [B => t]] = defaultParams.get(param) match {
78
- case Some (' { $v : `t`} ) => ' { Some ((( _ : B ) => $v) ) }
131
+ case Some (f ) => ' { Some ((b : B ) => $ { applyAndCast(f, ' b ) } ) }
79
132
case None => ' { None }
80
133
}
81
134
val tokensReader = Expr .summon[mainargs.TokensReader [t]].getOrElse {
82
- report.throwError (
83
- s " No mainargs.ArgReader found for parameter ${param.name}" ,
84
- param .pos.get
135
+ report.errorAndAbort (
136
+ s " No mainargs.TokensReader[ ${ Type .show[t]} ] found for parameter ${param.name} of method ${method.name} in ${method.owner.fullName }" ,
137
+ method .pos.getOrElse( Position .ofMacroExpansion)
85
138
)
86
139
}
87
140
' { (ArgSig .create[t, B ]($ { Expr (param.name) }, $ { arg }, $ { defaultParam })(using $ { tokensReader })) }
141
+ }
88
142
}
89
143
val argSigs = Expr .ofList(argSigsExprs)
90
144
91
145
val invokeRaw : Expr [(B , Seq [Any ]) => T ] = {
92
146
93
147
def callOf (methodOwner : Expr [Any ], args : Expr [Seq [Any ]]) =
94
- call(methodOwner, method, ' { Seq ($ args) } ).asExprOf[T ]
148
+ call(methodOwner, method, args).asExprOf[T ]
95
149
96
150
' { (b : B , params : Seq [Any ]) => $ { callOf(' b , ' params ) } }
97
151
}
@@ -120,37 +174,50 @@ object Macros {
120
174
private def call (using Quotes )(
121
175
methodOwner : Expr [Any ],
122
176
method : quotes.reflect.Symbol ,
123
- argss : Expr [Seq [Seq [ Any ] ]]
177
+ args : Expr [Seq [Any ]]
124
178
): Expr [_] = {
125
179
// Copy pasted from Cask.
126
180
// https://github.com/com-lihaoyi/cask/blob/65b9c8e4fd528feb71575f6e5ef7b5e2e16abbd9/cask/src-3/cask/router/Macros.scala#L106
127
181
import quotes .reflect ._
128
182
val paramss = method.paramSymss
129
183
130
184
if (paramss.isEmpty) {
131
- report.throwError (" At least one parameter list must be declared." , method.pos.get)
185
+ report.errorAndAbort (" At least one parameter list must be declared." , method.pos.get)
132
186
}
133
-
134
- val accesses : List [List [Term ]] = for (i <- paramss.indices.toList) yield {
135
- for (j <- paramss(i).indices.toList) yield {
136
- val tpe = paramss(i)(j).tree.asInstanceOf [ValDef ].tpt.tpe
137
- tpe.asType match
138
- case ' [t] => ' { $argss($ {Expr (i)})($ {Expr (j)}).asInstanceOf [t] }.asTerm
139
- }
187
+ if (paramss.sizeIs > 1 ) {
188
+ report.errorAndAbort(" Multiple parameter lists are not supported." , method.pos.get)
140
189
}
190
+ val params = paramss.head
191
+
192
+ val methodType = methodOwner.asTerm.tpe.memberType(method)
193
+
194
+ def accesses (ref : Expr [Seq [Any ]]): List [Term ] =
195
+ for (i <- params.indices.toList) yield {
196
+ val param = params(i)
197
+ val tpe = methodType.memberType(param)
198
+ val untypedRef = ' { $ref($ {Expr (i)}) }
199
+ tpe match {
200
+ case VarargTypeRepr (AsType (' [t])) =>
201
+ Typed (
202
+ ' { $untypedRef.asInstanceOf [Leftover [t]].value }.asTerm,
203
+ Inferred (AppliedType (defn.RepeatedParamClass .typeRef, List (TypeRepr .of[t])))
204
+ )
205
+ case _ => tpe.asType match
206
+ case ' [t] => ' { $untypedRef.asInstanceOf [t] }.asTerm
207
+ }
208
+ }
141
209
142
- methodOwner.asTerm.select(method).appliedToArgss (accesses).asExpr
210
+ methodOwner.asTerm.select(method).appliedToArgs (accesses(args) ).asExpr
143
211
}
144
-
145
212
146
213
/** Lookup default values for a method's parameters. */
147
- private def getDefaultParams (using Quotes )(method : quotes.reflect.Symbol ): Map [quotes.reflect.Symbol , Expr [Any ]] = {
214
+ private def getDefaultParams (using Quotes )(method : quotes.reflect.Symbol ): Map [quotes.reflect.Symbol , Expr [Any ] => Expr [ Any ] ] = {
148
215
// Copy pasted from Cask.
149
216
// https://github.com/com-lihaoyi/cask/blob/65b9c8e4fd528feb71575f6e5ef7b5e2e16abbd9/cask/src-3/cask/router/Macros.scala#L38
150
217
import quotes .reflect ._
151
218
152
219
val params = method.paramSymss.flatten
153
- val defaults = collection.mutable.Map .empty[Symbol , Expr [Any ]]
220
+ val defaults = collection.mutable.Map .empty[Symbol , Expr [Any ] => Expr [ Any ] ]
154
221
155
222
val Name = (method.name + """ \$default\$(\d+)""" ).r
156
223
val InitName = """ \$lessinit\$greater\$default\$(\d+)""" .r
@@ -159,13 +226,13 @@ object Macros {
159
226
160
227
idents.foreach{
161
228
case deff @ DefDef (Name (idx), _, _, _) =>
162
- val expr = Ref ( deff.symbol).asExpr
229
+ val expr = ( owner : Expr [ Any ]) => Select (owner.asTerm, deff.symbol).asExpr
163
230
defaults += (params(idx.toInt - 1 ) -> expr)
164
231
165
232
// The `apply` method re-uses the default param factory methods from `<init>`,
166
233
// so make sure to check if those exist too
167
234
case deff @ DefDef (InitName (idx), _, _, _) if method.name == " apply" =>
168
- val expr = Ref ( deff.symbol).asExpr
235
+ val expr = ( owner : Expr [ Any ]) => Select (owner.asTerm, deff.symbol).asExpr
169
236
defaults += (params(idx.toInt - 1 ) -> expr)
170
237
171
238
case _ =>
0 commit comments