1
+ import * as ts from "typescript" ;
2
+ import { NodeBuilderFlags } from "typescript" ;
3
+ import { map } from "../compiler/lang-utils" ;
4
+ import { SymbolTracker } from "../compiler/types" ;
5
+
6
+ const declarationEmitNodeBuilderFlags =
7
+ NodeBuilderFlags . MultilineObjectLiterals |
8
+ NodeBuilderFlags . WriteClassExpressionAsTypeLiteral |
9
+ NodeBuilderFlags . UseTypeOfFunction |
10
+ NodeBuilderFlags . UseStructuralFallback |
11
+ NodeBuilderFlags . AllowEmptyTuple |
12
+ NodeBuilderFlags . GenerateNamesForShadowedTypeParams |
13
+ NodeBuilderFlags . NoTruncation ;
14
+
15
+
16
+ // Define a transformer function
17
+ export function addTypeAnnotationTransformer ( program : ts . Program , moduleResolutionHost ?: ts . ModuleResolutionHost ) {
18
+ function tryGetReturnType (
19
+ typeChecker : ts . TypeChecker ,
20
+ node : ts . SignatureDeclaration
21
+ ) : ts . Type | undefined {
22
+ const signature = typeChecker . getSignatureFromDeclaration ( node ) ;
23
+ if ( signature ) {
24
+ return typeChecker . getReturnTypeOfSignature ( signature ) ;
25
+ }
26
+ }
27
+
28
+ function isVarConst ( node : ts . VariableDeclaration | ts . VariableDeclarationList ) : boolean {
29
+ return ! ! ( ts . getCombinedNodeFlags ( node ) & ts . NodeFlags . Const ) ;
30
+ }
31
+
32
+ function isDeclarationReadonly ( declaration : ts . Declaration ) : boolean {
33
+ return ! ! ( ts . getCombinedModifierFlags ( declaration ) & ts . ModifierFlags . Readonly && ! ts . isParameterPropertyDeclaration ( declaration , declaration . parent ) ) ;
34
+ }
35
+
36
+ function isLiteralConstDeclaration ( node : ts . VariableDeclaration | ts . PropertyDeclaration | ts . PropertySignature | ts . ParameterDeclaration ) : boolean {
37
+ if ( isDeclarationReadonly ( node ) || ts . isVariableDeclaration ( node ) && isVarConst ( node ) ) {
38
+ // TODO: Make sure this is a valid approximation for literal types
39
+ return ! node . type && 'initializer' in node && ! ! node . initializer && ts . isLiteralExpression ( node . initializer ) ;
40
+ // Original TS version
41
+ // return isFreshLiteralType(getTypeOfSymbol(getSymbolOfNode(node)));
42
+ }
43
+ return false ;
44
+ }
45
+
46
+ const typeChecker = program . getTypeChecker ( ) ;
47
+
48
+ return ( context : ts . TransformationContext ) => {
49
+ let hasError = false ;
50
+ let reportError = ( ) => {
51
+ hasError = true ;
52
+ }
53
+ const symbolTracker : SymbolTracker | undefined = ! moduleResolutionHost ? undefined : {
54
+ trackSymbol ( ) { return false ; } ,
55
+ reportInaccessibleThisError : reportError ,
56
+ reportInaccessibleUniqueSymbolError : reportError ,
57
+ reportCyclicStructureError : reportError ,
58
+ reportPrivateInBaseOfClassExpression : reportError ,
59
+ reportLikelyUnsafeImportRequiredError : reportError ,
60
+ reportTruncationError : reportError ,
61
+ moduleResolverHost : moduleResolutionHost as any ,
62
+ trackReferencedAmbientModule ( ) { } ,
63
+ trackExternalModuleSymbolOfImportTypeNode ( ) { } ,
64
+ reportNonlocalAugmentation ( ) { } ,
65
+ reportNonSerializableProperty ( ) { } ,
66
+ reportImportTypeNodeResolutionModeOverride ( ) { } ,
67
+ } ;
68
+
69
+ function typeToTypeNode ( type : ts . Type , enclosingDeclaration : ts . Node ) {
70
+ const typeNode = typeChecker . typeToTypeNode (
71
+ type ,
72
+ enclosingDeclaration ,
73
+ declarationEmitNodeBuilderFlags ,
74
+ // @ts -expect-error Use undocumented parameters
75
+ symbolTracker ,
76
+ )
77
+ if ( hasError ) {
78
+ hasError = false ;
79
+ return undefined ;
80
+ }
81
+
82
+ return typeNode ;
83
+ }
84
+ // Return a visitor function
85
+ return ( rootNode : ts . Node ) => {
86
+ function updateTypesInNodeArray < T extends ts . Node > ( nodeArray : ts . NodeArray < T > ) : ts . NodeArray < T >
87
+ function updateTypesInNodeArray < T extends ts . Node > ( nodeArray : ts . NodeArray < T > | undefined ) : ts . NodeArray < T > | undefined
88
+ function updateTypesInNodeArray < T extends ts . Node > ( nodeArray : ts . NodeArray < T > | undefined ) {
89
+ if ( nodeArray === undefined ) return undefined ;
90
+ return ts . factory . createNodeArray (
91
+ nodeArray . map ( param => {
92
+ return visit ( param ) as ts . ParameterDeclaration ;
93
+ } )
94
+ )
95
+ }
96
+
97
+ // Define a visitor function
98
+ function visit ( node : ts . Node ) : ts . Node | ts . Node [ ] {
99
+ if ( ts . isParameter ( node ) && ! node . type ) {
100
+ const type = typeChecker . getTypeAtLocation ( node ) ;
101
+ if ( type ) {
102
+ const typeNode = typeToTypeNode ( type , node ) ;
103
+ return ts . factory . updateParameterDeclaration (
104
+ node ,
105
+ node . modifiers ,
106
+ node . dotDotDotToken ,
107
+ node . name ,
108
+ node . questionToken ,
109
+ typeNode ,
110
+ node . initializer
111
+ )
112
+ }
113
+ }
114
+ // Check if node is a variable declaration
115
+ if ( ts . isVariableDeclaration ( node ) && ! node . type && ! isLiteralConstDeclaration ( node ) ) {
116
+ const type = typeChecker . getTypeAtLocation ( node ) ;
117
+ const typeNode = typeToTypeNode ( type , node )
118
+ return ts . factory . updateVariableDeclaration (
119
+ node ,
120
+ node . name ,
121
+ undefined ,
122
+ typeNode ,
123
+ node . initializer
124
+ ) ;
125
+ }
126
+
127
+ if ( ts . isFunctionDeclaration ( node ) && ! node . type ) {
128
+ const type = tryGetReturnType ( typeChecker , node ) ;
129
+ if ( type ) {
130
+
131
+ const typeNode = typeToTypeNode ( type , node ) ;
132
+ return ts . factory . updateFunctionDeclaration (
133
+ node ,
134
+ node . modifiers ,
135
+ node . asteriskToken ,
136
+ node . name ,
137
+ updateTypesInNodeArray ( node . typeParameters ) ,
138
+ updateTypesInNodeArray ( node . parameters ) ,
139
+ typeNode ,
140
+ node . body
141
+ )
142
+ }
143
+ }
144
+ if ( ts . isPropertySignature ( node ) && ! node . type && ! isLiteralConstDeclaration ( node ) ) {
145
+ const type = typeChecker . getTypeAtLocation ( node ) ;
146
+ const typeNode = typeToTypeNode ( type , node ) ;
147
+ return ts . factory . updatePropertySignature (
148
+ node ,
149
+ node . modifiers ,
150
+ node . name ,
151
+ node . questionToken ,
152
+ typeNode ,
153
+ ) ;
154
+ }
155
+ if ( ts . isPropertyDeclaration ( node ) && ! node . type && ! isLiteralConstDeclaration ( node ) ) {
156
+ const type = typeChecker . getTypeAtLocation ( node ) ;
157
+ const typeNode = typeToTypeNode ( type , node ) ;
158
+ return ts . factory . updatePropertyDeclaration (
159
+ node ,
160
+ node . modifiers ,
161
+ node . name ,
162
+ node . questionToken ?? node . exclamationToken ,
163
+ typeNode ,
164
+ node . initializer
165
+ ) ;
166
+ }
167
+ if ( ts . isMethodSignature ( node ) && ! node . type ) {
168
+ const type = tryGetReturnType ( typeChecker , node ) ;
169
+ if ( type ) {
170
+
171
+ const typeNode = typeToTypeNode ( type , node ) ;
172
+ return ts . factory . updateMethodSignature (
173
+ node ,
174
+ node . modifiers ,
175
+ node . name ,
176
+ node . questionToken ,
177
+ updateTypesInNodeArray ( node . typeParameters ) ,
178
+ updateTypesInNodeArray ( node . parameters ) ,
179
+ typeNode ,
180
+ ) ;
181
+ }
182
+ }
183
+ if ( ts . isCallSignatureDeclaration ( node ) ) {
184
+ const type = tryGetReturnType ( typeChecker , node ) ;
185
+ if ( type ) {
186
+ const typeNode = typeToTypeNode ( type , node ) ;
187
+ return ts . factory . updateCallSignature (
188
+ node ,
189
+ updateTypesInNodeArray ( node . typeParameters ) ,
190
+ updateTypesInNodeArray ( node . parameters ) ,
191
+ typeNode ,
192
+ )
193
+ }
194
+ }
195
+ if ( ts . isMethodDeclaration ( node ) && ! node . type ) {
196
+ const type = tryGetReturnType ( typeChecker , node ) ;
197
+ if ( type ) {
198
+
199
+ const typeNode = typeToTypeNode ( type , node ) ;
200
+ return ts . factory . updateMethodDeclaration (
201
+ node ,
202
+ node . modifiers ,
203
+ node . asteriskToken ,
204
+ node . name ,
205
+ node . questionToken ,
206
+ updateTypesInNodeArray ( node . typeParameters ) ,
207
+ updateTypesInNodeArray ( node . parameters ) ,
208
+ typeNode ,
209
+ node . body ,
210
+ ) ;
211
+ }
212
+ }
213
+ if ( ts . isGetAccessorDeclaration ( node ) && ! node . type ) {
214
+ const type = tryGetReturnType ( typeChecker , node ) ;
215
+ if ( type ) {
216
+ const typeNode = typeToTypeNode ( type , node ) ;
217
+ return ts . factory . updateGetAccessorDeclaration (
218
+ node ,
219
+ node . modifiers ,
220
+ node . name ,
221
+ updateTypesInNodeArray ( node . parameters ) ,
222
+ typeNode ,
223
+ node . body ,
224
+ ) ;
225
+ }
226
+ }
227
+ if ( ts . isSetAccessorDeclaration ( node ) && ! node . parameters [ 0 ] ?. type ) {
228
+ return ts . factory . updateSetAccessorDeclaration (
229
+ node ,
230
+ node . modifiers ,
231
+ node . name ,
232
+ updateTypesInNodeArray ( node . parameters ) ,
233
+ node . body ,
234
+ ) ;
235
+ }
236
+ if ( ts . isConstructorDeclaration ( node ) ) {
237
+ return ts . factory . updateConstructorDeclaration (
238
+ node ,
239
+ node . modifiers ,
240
+ updateTypesInNodeArray ( node . parameters ) ,
241
+ node . body ,
242
+ )
243
+ }
244
+ if ( ts . isConstructSignatureDeclaration ( node ) ) {
245
+ const type = tryGetReturnType ( typeChecker , node ) ;
246
+ if ( type ) {
247
+ const typeNode = typeToTypeNode ( type , node ) ;
248
+ return ts . factory . updateConstructSignature (
249
+ node ,
250
+ updateTypesInNodeArray ( node . typeParameters ) ,
251
+ updateTypesInNodeArray ( node . parameters ) ,
252
+ typeNode ,
253
+ )
254
+ }
255
+ }
256
+ if ( ts . isExportAssignment ( node ) && node . expression . kind !== ts . SyntaxKind . Identifier ) {
257
+ const type = typeChecker . getTypeAtLocation ( node . expression ) ;
258
+ if ( type ) {
259
+ const typeNode = typeToTypeNode ( type , node ) ;
260
+ const newId = ts . factory . createIdentifier ( "_default" ) ;
261
+ const varDecl = ts . factory . createVariableDeclaration ( newId , /*exclamationToken*/ undefined , typeNode , /*initializer*/ undefined ) ;
262
+ const statement = ts . factory . createVariableStatement (
263
+ [ ] ,
264
+ ts . factory . createVariableDeclarationList ( [ varDecl ] , ts . NodeFlags . Const )
265
+ ) ;
266
+ return [ statement , ts . factory . updateExportAssignment ( node , node . modifiers , newId ) ] ;
267
+ }
268
+ }
269
+ // Otherwise, visit each child node recursively
270
+ return ts . visitEachChild ( node , visit , context ) ;
271
+ }
272
+ // Start visiting from root node
273
+ return ts . visitNode ( rootNode , visit ) ! ;
274
+ } ;
275
+ } ;
276
+ }
0 commit comments