Skip to content

Commit db75e9a

Browse files
committed
very wip: inference: allow semi-concrete interpret to perform recursive inference
fix #48679
1 parent 7eb9615 commit db75e9a

File tree

12 files changed

+254
-169
lines changed

12 files changed

+254
-169
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 99 additions & 91 deletions
Large diffs are not rendered by default.

base/compiler/compiler.jl

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ using Core.Intrinsics, Core.IR
66

77
import Core: print, println, show, write, unsafe_write, stdout, stderr,
88
_apply_iterate, svec, apply_type, Builtin, IntrinsicFunction,
9-
MethodInstance, CodeInstance, MethodMatch, PartialOpaque,
9+
MethodInstance, CodeInstance, MethodTable, MethodMatch, PartialOpaque,
1010
TypeofVararg
1111

1212
const getproperty = Core.getfield
@@ -154,6 +154,25 @@ include("compiler/ssair/ir.jl")
154154
include("compiler/abstractlattice.jl")
155155

156156
include("compiler/inferenceresult.jl")
157+
158+
# TODO define the interface for this abstract type
159+
abstract type AbsIntState end
160+
function frame_instance end
161+
function frame_module(sv::AbsIntState)
162+
mi = frame_instance(sv)
163+
def = mi.def
164+
isa(def, Module) && return def
165+
return def.module
166+
end
167+
function frame_parent end
168+
function frame_cached end
169+
function frame_src end
170+
function callers_in_cycle end
171+
# function recur_state end
172+
# pclimitations(sv::AbsIntState) = recur_state(sv).pclimitations
173+
# limitations(sv::AbsIntState) = recur_state(sv).limitations
174+
# callers_in_cycle(sv::AbsIntState) = recur_state(sv).callers_in_cycle
175+
157176
include("compiler/inferencestate.jl")
158177

159178
include("compiler/typeutils.jl")

base/compiler/inferencestate.jl

Lines changed: 34 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,18 @@ function append!(bsbmp::BitSetBoundedMinPrioritySet, itr)
7878
end
7979
end
8080

81-
mutable struct InferenceState
81+
struct AbsIntRecursionState
82+
pclimitations::IdSet{AbsIntState} # causes of precision restrictions (LimitedAccuracy) on currpc ssavalue
83+
limitations::IdSet{AbsIntState} # causes of precision restrictions (LimitedAccuracy) on return
84+
callers_in_cycle::Vector{AbsIntState}
85+
end
86+
function AbsIntRecursionState()
87+
return AbsIntRecursionState(IdSet{AbsIntState}(),
88+
IdSet{AbsIntState}(),
89+
Vector{AbsIntState}())
90+
end
91+
92+
mutable struct InferenceState <: AbsIntState
8293
#= information about this method instance =#
8394
linfo::MethodInstance
8495
world::UInt
@@ -197,25 +208,26 @@ mutable struct InferenceState
197208
end
198209
end
199210

211+
frame_instance(sv::InferenceState) = sv.linfo
212+
frame_parent(sv::InferenceState) = sv.parent
213+
frame_cached(sv::InferenceState) = sv.cached
214+
frame_src(sv::InferenceState) = sv.src
215+
callers_in_cycle(sv::InferenceState) = sv.callers_in_cycle
200216
Effects(state::InferenceState) = state.ipo_effects
201217

202218
function merge_effects!(::AbstractInterpreter, caller::InferenceState, effects::Effects)
203219
caller.ipo_effects = merge_effects(caller.ipo_effects, effects)
204220
end
205221

206-
merge_effects!(interp::AbstractInterpreter, caller::InferenceState, callee::InferenceState) =
207-
merge_effects!(interp, caller, Effects(callee))
208-
merge_effects!(interp::AbstractInterpreter, caller::IRCode, effects::Effects) = nothing
209-
210-
is_effect_overridden(sv::InferenceState, effect::Symbol) = is_effect_overridden(sv.linfo, effect)
222+
is_effect_overridden(sv::AbsIntState, effect::Symbol) = is_effect_overridden(frame_instance(sv), effect)
211223
function is_effect_overridden(linfo::MethodInstance, effect::Symbol)
212224
def = linfo.def
213225
return isa(def, Method) && is_effect_overridden(def, effect)
214226
end
215227
is_effect_overridden(method::Method, effect::Symbol) = is_effect_overridden(decode_effects_override(method.purity), effect)
216228
is_effect_overridden(override::EffectsOverride, effect::Symbol) = getfield(override, effect)
217229

218-
add_remark!(::AbstractInterpreter, sv::Union{InferenceState, IRCode}, remark) = return
230+
add_remark!(::AbstractInterpreter, ::AbsIntState, remark) = return
219231

220232
struct InferenceLoopState
221233
sig
@@ -226,13 +238,13 @@ struct InferenceLoopState
226238
end
227239
end
228240

229-
function bail_out_toplevel_call(::AbstractInterpreter, state::InferenceLoopState, sv::Union{InferenceState, IRCode})
230-
return isa(sv, InferenceState) && sv.restrict_abstract_call_sites && !isdispatchtuple(state.sig)
241+
function bail_out_toplevel_call(::AbstractInterpreter, state::InferenceLoopState, sv::InferenceState)
242+
return sv.restrict_abstract_call_sites && !isdispatchtuple(state.sig)
231243
end
232-
function bail_out_call(::AbstractInterpreter, state::InferenceLoopState, sv::Union{InferenceState, IRCode})
244+
function bail_out_call(::AbstractInterpreter, state::InferenceLoopState, ::InferenceState)
233245
return state.rt === Any && !is_foldable(state.effects)
234246
end
235-
function bail_out_apply(::AbstractInterpreter, state::InferenceLoopState, sv::Union{InferenceState, IRCode})
247+
function bail_out_apply(::AbstractInterpreter, state::InferenceLoopState, ::InferenceState)
236248
return state.rt === Any
237249
end
238250

@@ -351,21 +363,21 @@ end
351363
children before their parents (i.e. ascending the tree from the given
352364
InferenceState). Note that cycles may be visited in any order.
353365
"""
354-
struct InfStackUnwind
355-
inf::InferenceState
366+
struct InfStackUnwind{SV<:AbsIntState}
367+
inf::SV
356368
end
357369
iterate(unw::InfStackUnwind) = (unw.inf, (unw.inf, 0))
358-
function iterate(unw::InfStackUnwind, (infstate, cyclei)::Tuple{InferenceState, Int})
370+
function iterate(unw::InfStackUnwind{SV}, (infstate, cyclei)::Tuple{SV, Int}) where SV<:AbsIntState
359371
# iterate through the cycle before walking to the parent
360-
if cyclei < length(infstate.callers_in_cycle)
372+
if cyclei < length(callers_in_cycle(infstate))
361373
cyclei += 1
362-
infstate = infstate.callers_in_cycle[cyclei]
374+
infstate = callers_in_cycle(infstate)[cyclei]
363375
else
364376
cyclei = 0
365-
infstate = infstate.parent
377+
infstate = frame_parent(infstate)
366378
end
367379
infstate === nothing && return nothing
368-
(infstate::InferenceState, (infstate, cyclei))
380+
(infstate, (infstate, cyclei))
369381
end
370382

371383
function InferenceState(result::InferenceResult, cache::Symbol, interp::AbstractInterpreter)
@@ -504,7 +516,7 @@ function sptypes_from_meth_instance(linfo::MethodInstance)
504516
return sptypes
505517
end
506518

507-
_topmod(sv::InferenceState) = _topmod(sv.mod)
519+
_topmod(sv::InferenceState) = _topmod(frame_module(sv))
508520

509521
# work towards converging the valid age range for sv
510522
function update_valid_age!(sv::InferenceState, worlds::WorldRange)
@@ -548,10 +560,10 @@ function add_cycle_backedge!(caller::InferenceState, frame::InferenceState, curr
548560
end
549561

550562
# temporarily accumulate our edges to later add as backedges in the callee
551-
function add_backedge!(caller::InferenceState, li::MethodInstance)
563+
function add_backedge!(caller::InferenceState, mi::MethodInstance)
552564
edges = get_stmt_edges!(caller)
553565
if edges !== nothing
554-
push!(edges, li)
566+
push!(edges, mi)
555567
end
556568
return nothing
557569
end
@@ -565,7 +577,7 @@ function add_invoke_backedge!(caller::InferenceState, @nospecialize(invokesig::T
565577
end
566578

567579
# used to temporarily accumulate our no method errors to later add as backedges in the callee method table
568-
function add_mt_backedge!(caller::InferenceState, mt::Core.MethodTable, @nospecialize(typ))
580+
function add_mt_backedge!(caller::InferenceState, mt::MethodTable, @nospecialize(typ))
569581
edges = get_stmt_edges!(caller)
570582
if edges !== nothing
571583
push!(edges, mt, typ)

base/compiler/methodtable.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ external table, e.g., to override existing method.
3939
"""
4040
struct OverlayMethodTable <: MethodTableView
4141
world::UInt
42-
mt::Core.MethodTable
42+
mt::MethodTable
4343
end
4444

4545
struct MethodMatchKey
@@ -98,7 +98,7 @@ function findall(@nospecialize(sig::Type), table::OverlayMethodTable; limit::Int
9898
!isempty(result))
9999
end
100100

101-
function _findall(@nospecialize(sig::Type), mt::Union{Nothing,Core.MethodTable}, world::UInt, limit::Int)
101+
function _findall(@nospecialize(sig::Type), mt::Union{Nothing,MethodTable}, world::UInt, limit::Int)
102102
_min_val = RefValue{UInt}(typemin(UInt))
103103
_max_val = RefValue{UInt}(typemax(UInt))
104104
_ambig = RefValue{Int32}(0)
@@ -155,7 +155,7 @@ function findsup(@nospecialize(sig::Type), table::OverlayMethodTable)
155155
false)
156156
end
157157

158-
function _findsup(@nospecialize(sig::Type), mt::Union{Nothing,Core.MethodTable}, world::UInt)
158+
function _findsup(@nospecialize(sig::Type), mt::Union{Nothing,MethodTable}, world::UInt)
159159
min_valid = RefValue{UInt}(typemin(UInt))
160160
max_valid = RefValue{UInt}(typemax(UInt))
161161
match = ccall(:jl_gf_invoke_lookup_worlds, Any, (Any, Any, UInt, Ptr{Csize_t}, Ptr{Csize_t}),

base/compiler/optimize.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -126,9 +126,9 @@ struct InliningState{Interp<:AbstractInterpreter}
126126
world::UInt
127127
interp::Interp
128128
end
129-
function InliningState(frame::InferenceState, params::OptimizationParams, interp::AbstractInterpreter)
130-
et = EdgeTracker(frame.stmt_edges[1]::Vector{Any}, frame.valid_worlds)
131-
return InliningState(params, et, frame.world, interp)
129+
function InliningState(sv::InferenceState, params::OptimizationParams, interp::AbstractInterpreter)
130+
et = EdgeTracker(sv.stmt_edges[1]::Vector{Any}, sv.valid_worlds)
131+
return InliningState(params, et, sv.world, interp)
132132
end
133133
function InliningState(params::OptimizationParams, interp::AbstractInterpreter)
134134
return InliningState(params, nothing, get_world_counter(interp), interp)
@@ -151,12 +151,12 @@ mutable struct OptimizationState{Interp<:AbstractInterpreter}
151151
cfg::Union{Nothing,CFG}
152152
insert_coverage::Bool
153153
end
154-
function OptimizationState(frame::InferenceState, params::OptimizationParams,
154+
function OptimizationState(sv::InferenceState, params::OptimizationParams,
155155
interp::AbstractInterpreter, recompute_cfg::Bool=true)
156-
inlining = InliningState(frame, params, interp)
157-
cfg = recompute_cfg ? nothing : frame.cfg
158-
return OptimizationState(frame.linfo, frame.src, nothing, frame.stmt_info, frame.mod,
159-
frame.sptypes, frame.slottypes, inlining, cfg, frame.insert_coverage)
156+
inlining = InliningState(sv, params, interp)
157+
cfg = recompute_cfg ? nothing : sv.cfg
158+
return OptimizationState(sv.linfo, sv.src, nothing, sv.stmt_info, frame_module(sv),
159+
sv.sptypes, sv.slottypes, inlining, cfg, sv.insert_coverage)
160160
end
161161
function OptimizationState(linfo::MethodInstance, src::CodeInfo, params::OptimizationParams,
162162
interp::AbstractInterpreter)
@@ -387,9 +387,9 @@ function argextype(
387387
return Const(x)
388388
end
389389
end
390+
abstract_eval_ssavalue(s::SSAValue, src::CodeInfo) = abstract_eval_ssavalue(s, src.ssavaluetypes::Vector{Any})
390391
abstract_eval_ssavalue(s::SSAValue, src::Union{IRCode,IncrementalCompact}) = types(src)[s]
391392

392-
393393
"""
394394
finish(interp::AbstractInterpreter, opt::OptimizationState,
395395
params::OptimizationParams, ir::IRCode, caller::InferenceResult)

base/compiler/ssair/irinterp.jl

Lines changed: 46 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -94,14 +94,17 @@ function getindex(tpdum::TwoPhaseDefUseMap, idx::Int)
9494
return TwoPhaseVectorView(tpdum.data, nelems, range)
9595
end
9696

97-
struct IRInterpretationState
97+
# TODO add `result::InferenceResult` & `parent::InferenceState` for this
98+
struct IRInterpretationState <: AbsIntState
9899
ir::IRCode
99100
mi::MethodInstance
100101
world::UInt
101102
argtypes_refined::Vector{Bool}
103+
sptypes::Vector{VarState}
102104
tpdum::TwoPhaseDefUseMap
103105
ssa_refined::BitSet
104106
lazydomtree::LazyDomtree
107+
callers_in_cycle::Vector{InferenceState}
105108
function IRInterpretationState(interp::AbstractInterpreter,
106109
ir::IRCode, mi::MethodInstance, world::UInt, argtypes::Vector{Any})
107110
argtypes = va_process_argtypes(optimizer_lattice(interp), argtypes, mi)
@@ -114,10 +117,40 @@ struct IRInterpretationState
114117
tpdum = TwoPhaseDefUseMap(length(ir.stmts))
115118
ssa_refined = BitSet()
116119
lazydomtree = LazyDomtree(ir)
117-
return new(ir, mi, world, argtypes_refined, tpdum, ssa_refined, lazydomtree)
120+
callers_in_cycle = Vector{InferenceState}()
121+
return new(ir, mi, world, argtypes_refined, ir.sptypes, tpdum, ssa_refined, lazydomtree, callers_in_cycle)
118122
end
119123
end
120124

125+
frame_instance(sv::IRInterpretationState) = sv.mi
126+
frame_parent(sv::IRInterpretationState) = nothing
127+
frame_cached(sv::IRInterpretationState) = false
128+
frame_src(sv::IRInterpretationState) = retrieve_code_info(sv.mi) # TODO optimize
129+
callers_in_cycle(sv::IRInterpretationState) = sv.callers_in_cycle
130+
# TODO
131+
merge_effects!(::AbstractInterpreter, ::IRInterpretationState, ::Effects) = return
132+
get_max_methods(::IRInterpretationState, ::AbstractInterpreter) = 3
133+
get_max_methods(@nospecialize(f), ::IRInterpretationState, ::AbstractInterpreter) = 3
134+
ssa_def_slot(@nospecialize(arg), ::IRInterpretationState) = nothing
135+
function bail_out_toplevel_call(::AbstractInterpreter, ::InferenceLoopState, ::IRInterpretationState)
136+
return false
137+
end
138+
function bail_out_call(::AbstractInterpreter, state::InferenceLoopState, ::IRInterpretationState)
139+
return state.rt === Any && !is_foldable(state.effects)
140+
end
141+
function bail_out_apply(::AbstractInterpreter, @nospecialize(rt), ::IRInterpretationState)
142+
return rt === Any
143+
end
144+
should_infer_this_call(::AbstractInterpreter, ::IRInterpretationState) = true
145+
const_prop_enabled(::AbstractInterpreter, ::IRInterpretationState, match::MethodMatch) = false
146+
147+
# TODO
148+
update_valid_age!(::IRInterpretationState, ::WorldRange) = return
149+
update_valid_age!(::InferenceState, ::IRInterpretationState) = return
150+
add_backedge!(::IRInterpretationState, ::MethodInstance) = return
151+
add_invoke_backedge!(::IRInterpretationState, @nospecialize(invokesig::Type), ::MethodInstance) = return
152+
add_mt_backedge!(::IRInterpretationState, ::MethodTable, @nospecialize(typ)) = return
153+
121154
function codeinst_to_ir(interp::AbstractInterpreter, code::CodeInstance)
122155
src = @atomic :monotonic code.inferred
123156
mi = code.def
@@ -129,13 +162,13 @@ function codeinst_to_ir(interp::AbstractInterpreter, code::CodeInstance)
129162
return inflate_ir(src, mi)
130163
end
131164

132-
function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
133-
arginfo::ArgInfo, si::StmtInfo, @nospecialize(atype),
134-
sv::IRCode, max_methods::Int)
135-
return CallMeta(Any, Effects(), NoCallInfo())
165+
function from_interconditional(::AbstractLattice,
166+
typ, ::IRInterpretationState, ::ArgInfo, maybecondinfo)
167+
@nospecialize typ maybecondinfo
168+
return widenconditional(typ)
136169
end
137170

138-
function collect_limitations!(@nospecialize(typ), ::IRCode)
171+
function collect_limitations!(@nospecialize(typ), ::IRInterpretationState)
139172
@assert !isa(typ, LimitedAccuracy) "semi-concrete eval on recursive call graph"
140173
return typ
141174
end
@@ -147,7 +180,7 @@ function concrete_eval_invoke(interp::AbstractInterpreter,
147180
if code === nothing
148181
return Pair{Any, Bool}(nothing, false)
149182
end
150-
argtypes = collect_argtypes(interp, inst.args[2:end], nothing, irsv.ir)
183+
argtypes = collect_argtypes(interp, inst.args[2:end], nothing, irsv)
151184
argtypes === nothing && return Pair{Any, Bool}(Union{}, false)
152185
effects = decode_effects(code.ipo_purity_bits)
153186
if is_foldable(effects) && is_all_const_arg(argtypes, #=start=#1)
@@ -169,8 +202,10 @@ function concrete_eval_invoke(interp::AbstractInterpreter,
169202
return Pair{Any, Bool}(nothing, is_nothrow(effects))
170203
end
171204

205+
abstract_eval_ssavalue(s::SSAValue, sv::IRInterpretationState) = abstract_eval_ssavalue(s, sv.ir)
206+
172207
function abstract_eval_phi_stmt(interp::AbstractInterpreter, phi::PhiNode, ::Int, irsv::IRInterpretationState)
173-
return abstract_eval_phi(interp, phi, nothing, irsv.ir)
208+
return abstract_eval_phi(interp, phi, nothing, irsv)
174209
end
175210

176211
function propagate_control_effects!(interp::AbstractInterpreter, idx::Int, stmt::GotoIfNot,
@@ -237,7 +272,7 @@ function reprocess_instruction!(interp::AbstractInterpreter,
237272
if isa(inst, Expr)
238273
head = inst.head
239274
if head === :call || head === :foreigncall || head === :new || head === :splatnew
240-
(; rt, effects) = abstract_eval_statement_expr(interp, inst, nothing, ir, irsv.mi)
275+
(; rt, effects) = abstract_eval_statement_expr(interp, inst, nothing, irsv)
241276
# All other effects already guaranteed effect free by construction
242277
if is_nothrow(effects)
243278
ir.stmts[idx][:flag] |= IR_FLAG_NOTHROW
@@ -261,7 +296,6 @@ function reprocess_instruction!(interp::AbstractInterpreter,
261296
head === :gc_preserve_end
262297
return false
263298
else
264-
ccall(:jl_, Cvoid, (Any,), inst)
265299
error("reprocess_instruction!: unhandled expression found")
266300
end
267301
elseif isa(inst, PhiNode)
@@ -276,8 +310,7 @@ function reprocess_instruction!(interp::AbstractInterpreter,
276310
elseif isa(inst, GlobalRef)
277311
# GlobalRef is not refinable
278312
else
279-
ccall(:jl_, Cvoid, (Any,), inst)
280-
error()
313+
error("reprocess_instruction!: unhandled instruction found")
281314
end
282315
if rt !== nothing && !(optimizer_lattice(interp), typ, rt)
283316
ir.stmts[idx][:type] = rt

base/compiler/tfuncs.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1945,7 +1945,7 @@ function array_elmtype(@nospecialize ary)
19451945
return Any
19461946
end
19471947

1948-
@nospecs function _opaque_closure_tfunc(𝕃::AbstractLattice, arg, lb, ub, source, env::Vector{Any}, linfo::MethodInstance)
1948+
@nospecs function opaque_closure_tfunc(𝕃::AbstractLattice, arg, lb, ub, source, env::Vector{Any}, linfo::MethodInstance)
19491949
argt, argt_exact = instanceof_tfunc(arg)
19501950
lbt, lb_exact = instanceof_tfunc(lb)
19511951
if !lb_exact
@@ -2307,7 +2307,7 @@ function builtin_nothrow(𝕃::AbstractLattice, @nospecialize(f), argtypes::Vect
23072307
end
23082308

23092309
function builtin_tfunction(interp::AbstractInterpreter, @nospecialize(f), argtypes::Vector{Any},
2310-
sv::Union{InferenceState,IRCode,Nothing})
2310+
sv::Union{AbsIntState, Nothing})
23112311
𝕃ᵢ = typeinf_lattice(interp)
23122312
if f === tuple
23132313
return tuple_tfunc(𝕃ᵢ, argtypes)
@@ -2478,7 +2478,7 @@ end
24782478
# TODO: this function is a very buggy and poor model of the return_type function
24792479
# since abstract_call_gf_by_type is a very inaccurate model of _method and of typeinf_type,
24802480
# while this assumes that it is an absolutely precise and accurate and exact model of both
2481-
function return_type_tfunc(interp::AbstractInterpreter, argtypes::Vector{Any}, si::StmtInfo, sv::Union{InferenceState, IRCode})
2481+
function return_type_tfunc(interp::AbstractInterpreter, argtypes::Vector{Any}, si::StmtInfo, sv::AbsIntState)
24822482
if length(argtypes) == 3
24832483
tt = widenslotwrapper(argtypes[3])
24842484
if isa(tt, Const) || (isType(tt) && !has_free_typevars(tt))
@@ -2605,7 +2605,7 @@ function _hasmethod_tfunc(interp::AbstractInterpreter, argtypes::Vector{Any}, sv
26052605
types = rewrap_unionall(Tuple{ft, unwrapped.parameters...}, types)::Type
26062606
end
26072607
mt = ccall(:jl_method_table_for, Any, (Any,), types)
2608-
if !isa(mt, Core.MethodTable)
2608+
if !isa(mt, MethodTable)
26092609
return CallMeta(Bool, EFFECTS_THROWS, NoCallInfo())
26102610
end
26112611
match, valid_worlds, overlayed = findsup(types, method_table(interp))

0 commit comments

Comments
 (0)