Skip to content

Commit 1b5ec83

Browse files
committed
Track modes in Lambda.lfunction and onwards (ocaml-flambda#33)
Extends Lambda.function_kind to track modes of curried functions. The supported modes have a heap-returning prefix, followed by a local-returning suffix, both of which may be empty. Functions which do not fit this pattern cannot be expressed as a single Lfunction, and must be explicitly curried in Lambda and Clambda. (See the new Lambda.check_lfunction for the exact invariants) Most of the changes are straightforward plumbing of the new info. Some delicate changes are needed to preserve the new invariants, in: - simplif.ml: optimisations which combine Lfunctions - translcore.ml: transl_curried_function and friends - closure.ml: optimisations for partial and over-application - cmm_helpers.ml: mode-aware variants of caml_curry Caveat: build_apply in translcore is wrong
1 parent f1e2e97 commit 1b5ec83

31 files changed

+523
-235
lines changed

asmcomp/cmm_helpers.ml

Lines changed: 60 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ let black_block_header tag sz = Nativeint.logor (block_header tag sz) caml_black
5555
let local_block_header tag sz = Nativeint.logor (block_header tag sz) caml_local
5656
let white_closure_header sz = block_header Obj.closure_tag sz
5757
let black_closure_header sz = black_block_header Obj.closure_tag sz
58+
let local_closure_header sz = local_block_header Obj.closure_tag sz
5859
let infix_header ofs = block_header Obj.infix_tag ofs
5960
let float_header = block_header Obj.double_tag (size_float / size_addr)
6061
let floatarray_header len =
@@ -76,6 +77,11 @@ let pos_arity_in_closinfo = 8 * size_addr - 8
7677
(* arity = the top 8 bits of the closinfo word *)
7778

7879
let closure_info ~arity ~startenv =
80+
let arity =
81+
match arity with
82+
| Lambda.Tupled, n -> -n
83+
| Lambda.Curried _, n -> n
84+
in
7985
assert (-128 <= arity && arity <= 127);
8086
assert (0 <= startenv && startenv < 1 lsl (pos_arity_in_closinfo - 1));
8187
Nativeint.(add (shift_left (of_int arity) pos_arity_in_closinfo)
@@ -84,7 +90,10 @@ let closure_info ~arity ~startenv =
8490

8591
let alloc_float_header dbg = Cconst_natint (float_header, dbg)
8692
let alloc_floatarray_header len dbg = Cconst_natint (floatarray_header len, dbg)
87-
let alloc_closure_header sz dbg = Cconst_natint (white_closure_header sz, dbg)
93+
let alloc_closure_header ~mode sz dbg =
94+
match (mode : Lambda.alloc_mode) with
95+
| Alloc_heap -> Cconst_natint (white_closure_header sz, dbg)
96+
| Alloc_local -> Cconst_natint (local_closure_header sz, dbg)
8897
let alloc_infix_header ofs dbg = Cconst_natint (infix_header ofs, dbg)
8998
let alloc_closure_info ~arity ~startenv dbg =
9099
Cconst_natint (closure_info ~arity ~startenv, dbg)
@@ -838,12 +847,16 @@ let make_checkbound dbg = function
838847
(* Record application and currying functions *)
839848

840849
let apply_function_sym n =
850+
assert (n > 0);
841851
Compilenv.need_apply_fun n; "caml_apply" ^ Int.to_string n
842-
let curry_function_sym n =
843-
Compilenv.need_curry_fun n;
844-
if n >= 0
845-
then "caml_curry" ^ Int.to_string n
846-
else "caml_tuplify" ^ Int.to_string (-n)
852+
let curry_function_sym ar =
853+
Compilenv.need_curry_fun ar;
854+
match ar with
855+
| Lambda.Curried {nlocal}, n ->
856+
"caml_curry" ^ Int.to_string n ^
857+
(if nlocal > 0 then "L" ^ Int.to_string nlocal else "")
858+
| Lambda.Tupled, n ->
859+
"caml_tuplify" ^ Int.to_string n
847860

848861
(* Big arrays *)
849862

@@ -1969,7 +1982,7 @@ let tuplify_function arity =
19691982
*)
19701983

19711984
let max_arity_optimized = 15
1972-
let final_curry_function arity =
1985+
let final_curry_function ~nlocal ~arity =
19731986
let dbg = placeholder_dbg in
19741987
let last_arg = V.create_local "arg" in
19751988
let last_clos = V.create_local "clos" in
@@ -1998,7 +2011,9 @@ let final_curry_function arity =
19982011
newclos (n-1))
19992012
end in
20002013
let fun_name =
2001-
"caml_curry" ^ Int.to_string arity ^ "_" ^ Int.to_string (arity-1)
2014+
"caml_curry" ^ Int.to_string arity
2015+
^ (if nlocal > 0 then "L" ^ Int.to_string nlocal else "")
2016+
^ "_" ^ Int.to_string (arity-1)
20022017
in
20032018
let fun_dbg = placeholder_fun_dbg ~human_name:fun_name in
20042019
Cfunction
@@ -2009,34 +2024,38 @@ let final_curry_function arity =
20092024
fun_dbg;
20102025
}
20112026

2012-
let rec intermediate_curry_functions arity num =
2027+
let rec intermediate_curry_functions ~nlocal ~arity num =
20132028
let dbg = placeholder_dbg in
20142029
if num = arity - 1 then
2015-
[final_curry_function arity]
2030+
[final_curry_function ~nlocal ~arity]
20162031
else begin
2017-
let name1 = "caml_curry" ^ Int.to_string arity in
2032+
let name1 = "caml_curry" ^ Int.to_string arity
2033+
^ (if nlocal > 0 then "L" ^ Int.to_string nlocal else "") in
20182034
let name2 = if num = 0 then name1 else name1 ^ "_" ^ Int.to_string num in
20192035
let arg = V.create_local "arg" and clos = V.create_local "clos" in
20202036
let fun_dbg = placeholder_fun_dbg ~human_name:name2 in
2037+
let mode : Lambda.alloc_mode =
2038+
if num >= arity - nlocal then Alloc_local else Alloc_heap in
2039+
let curried n : Clambda.arity = (Curried {nlocal=min nlocal n}, n) in
20212040
Cfunction
20222041
{fun_name = name2;
20232042
fun_args = [VP.create arg, typ_val; VP.create clos, typ_val];
20242043
fun_body =
20252044
if arity - num > 2 && arity <= max_arity_optimized then
2026-
Cop(Calloc Alloc_heap,
2027-
[alloc_closure_header 5 (dbg ());
2045+
Cop(Calloc mode,
2046+
[alloc_closure_header ~mode 5 (dbg ());
20282047
Cconst_symbol(name1 ^ "_" ^ Int.to_string (num+1), dbg ());
2029-
alloc_closure_info ~arity:(arity - num - 1)
2048+
alloc_closure_info ~arity:(curried (arity - num - 1))
20302049
~startenv:3 (dbg ());
20312050
Cconst_symbol(name1 ^ "_" ^ Int.to_string (num+1) ^ "_app",
20322051
dbg ());
20332052
Cvar arg; Cvar clos],
20342053
dbg ())
20352054
else
2036-
Cop(Calloc Alloc_heap,
2037-
[alloc_closure_header 4 (dbg ());
2055+
Cop(Calloc mode,
2056+
[alloc_closure_header ~mode 4 (dbg ());
20382057
Cconst_symbol(name1 ^ "_" ^ Int.to_string (num+1), dbg ());
2039-
alloc_closure_info ~arity:1 ~startenv:2 (dbg ());
2058+
alloc_closure_info ~arity:(curried 1) ~startenv:2 (dbg ());
20402059
Cvar arg; Cvar clos],
20412060
dbg ());
20422061
fun_codegen_options = [];
@@ -2082,19 +2101,21 @@ let rec intermediate_curry_functions arity num =
20822101
fun_dbg;
20832102
}
20842103
in
2085-
cf :: intermediate_curry_functions arity (num+1)
2104+
cf :: intermediate_curry_functions ~nlocal ~arity (num+1)
20862105
else
2087-
intermediate_curry_functions arity (num+1))
2106+
intermediate_curry_functions ~nlocal ~arity (num+1))
20882107
end
20892108

2090-
let curry_function arity =
2091-
assert(arity <> 0);
2092-
(* Functions with arity = 0 does not have a curry_function *)
2093-
if arity > 0
2094-
then intermediate_curry_functions arity 0
2095-
else [tuplify_function (-arity)]
2109+
let curry_function = function
2110+
| Lambda.Tupled, n ->
2111+
assert (n > 0); [tuplify_function n]
2112+
| Lambda.Curried {nlocal}, n ->
2113+
assert (n > 0);
2114+
intermediate_curry_functions ~nlocal ~arity:n 0
20962115

20972116
module Int = Numbers.Int
2117+
module AritySet =
2118+
Set.Make (struct type t = Clambda.arity let compare = compare end)
20982119

20992120
let default_apply = Int.Set.add 2 (Int.Set.add 3 Int.Set.empty)
21002121
(* These apply funs are always present in the main program because
@@ -2106,13 +2127,13 @@ let generic_functions shared units =
21062127
(fun (apply,send,curry) (ui : Cmx_format.unit_infos) ->
21072128
List.fold_right Int.Set.add ui.ui_apply_fun apply,
21082129
List.fold_right Int.Set.add ui.ui_send_fun send,
2109-
List.fold_right Int.Set.add ui.ui_curry_fun curry)
2110-
(Int.Set.empty,Int.Set.empty,Int.Set.empty)
2130+
List.fold_right AritySet.add ui.ui_curry_fun curry)
2131+
(Int.Set.empty,Int.Set.empty,AritySet.empty)
21112132
units in
21122133
let apply = if shared then apply else Int.Set.union apply default_apply in
21132134
let accu = Int.Set.fold (fun n accu -> apply_function n :: accu) apply [] in
21142135
let accu = Int.Set.fold (fun n accu -> send_function n :: accu) send accu in
2115-
Int.Set.fold (fun n accu -> curry_function n @ accu) curry accu
2136+
AritySet.fold (fun arity accu -> curry_function arity @ accu) curry accu
21162137

21172138
(* Primitives *)
21182139

@@ -2713,7 +2734,7 @@ let fundecls_size fundecls =
27132734
(fun (f : Clambda.ufunction) ->
27142735
let indirect_call_code_pointer_size =
27152736
match f.arity with
2716-
| 0 | 1 -> 0
2737+
| Curried _, (0 | 1) -> 0
27172738
(* arity 1 does not need an indirect call handler.
27182739
arity 0 cannot be indirect called *)
27192740
| _ -> 1
@@ -2746,30 +2767,32 @@ let emit_constant_closure ((_, global_symb) as symb) fundecls clos_vars cont =
27462767
let rec emit_others pos = function
27472768
[] -> clos_vars @ cont
27482769
| (f2 : Clambda.ufunction) :: rem ->
2749-
if f2.arity = 1 || f2.arity = 0 then
2770+
match f2.arity with
2771+
| Curried _, (0|1) as arity ->
27502772
Cint(infix_header pos) ::
27512773
(closure_symbol f2) @
27522774
Csymbol_address f2.label ::
2753-
Cint(closure_info ~arity:f2.arity ~startenv:(startenv - pos)) ::
2775+
Cint(closure_info ~arity ~startenv:(startenv - pos)) ::
27542776
emit_others (pos + 3) rem
2755-
else
2777+
| arity ->
27562778
Cint(infix_header pos) ::
27572779
(closure_symbol f2) @
27582780
Csymbol_address(curry_function_sym f2.arity) ::
2759-
Cint(closure_info ~arity:f2.arity ~startenv:(startenv - pos)) ::
2781+
Cint(closure_info ~arity ~startenv:(startenv - pos)) ::
27602782
Csymbol_address f2.label ::
27612783
emit_others (pos + 4) rem in
27622784
Cint(black_closure_header (fundecls_size fundecls
27632785
+ List.length clos_vars)) ::
27642786
cdefine_symbol symb @
27652787
(closure_symbol f1) @
2766-
if f1.arity = 1 || f1.arity = 0 then
2788+
match f1.arity with
2789+
| Curried _, (0|1) as arity ->
27672790
Csymbol_address f1.label ::
2768-
Cint(closure_info ~arity:f1.arity ~startenv) ::
2791+
Cint(closure_info ~arity ~startenv) ::
27692792
emit_others 3 remainder
2770-
else
2793+
| arity ->
27712794
Csymbol_address(curry_function_sym f1.arity) ::
2772-
Cint(closure_info ~arity:f1.arity ~startenv) ::
2795+
Cint(closure_info ~arity ~startenv) ::
27732796
Csymbol_address f1.label ::
27742797
emit_others 4 remainder
27752798

asmcomp/cmm_helpers.mli

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,15 +65,18 @@ val boxedint64_header : nativeint
6565
val boxedintnat_header : nativeint
6666

6767
(** Closure info for a closure of given arity and distance to environment *)
68-
val closure_info : arity:int -> startenv:int -> nativeint
68+
val closure_info : arity:Clambda.arity -> startenv:int -> nativeint
6969

7070
(** Wrappers *)
71+
(* FIXME: these all need mode params *)
7172
val alloc_float_header : Debuginfo.t -> expression
7273
val alloc_floatarray_header : int -> Debuginfo.t -> expression
73-
val alloc_closure_header : int -> Debuginfo.t -> expression
74+
val alloc_closure_header :
75+
mode:Lambda.alloc_mode -> int -> Debuginfo.t -> expression
7476
val alloc_infix_header : int -> Debuginfo.t -> expression
7577
val alloc_closure_info :
76-
arity:int -> startenv:int -> Debuginfo.t -> expression
78+
arity:(Lambda.function_kind * int) -> startenv:int ->
79+
Debuginfo.t -> expression
7780
val alloc_boxedint32_header : Debuginfo.t -> expression
7881
val alloc_boxedint64_header : Debuginfo.t -> expression
7982
val alloc_boxedintnat_header : Debuginfo.t -> expression
@@ -326,10 +329,9 @@ val check_bound :
326329
ensure its presence in the set of defined symbols *)
327330
val apply_function_sym : int -> string
328331

329-
(** If [n] is positive, get the symbol for the generic currying wrapper with
330-
[n] arguments, and ensure its presence in the set of defined symbols.
331-
Otherwise, do the same for the generic tuple wrapper with [-n] arguments. *)
332-
val curry_function_sym : int -> string
332+
(** Get the symbol for the generic currying or tuplifying wrapper with
333+
[n] arguments, and ensure its presence in the set of defined symbols. *)
334+
val curry_function_sym : Clambda.arity -> string
333335

334336
(** Bigarrays *)
335337

asmcomp/cmmgen.ml

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -394,14 +394,15 @@ let rec transl env e =
394394
Cmmgen_state.add_function f;
395395
let dbg = f.dbg in
396396
let without_header =
397-
if f.arity = 1 || f.arity = 0 then
397+
match f.arity with
398+
| Curried _, (1|0) as arity ->
398399
Cconst_symbol (f.label, dbg) ::
399-
alloc_closure_info ~arity:f.arity
400+
alloc_closure_info ~arity
400401
~startenv:(startenv - pos) dbg ::
401402
transl_fundecls (pos + 3) rem
402-
else
403+
| arity ->
403404
Cconst_symbol (curry_function_sym f.arity, dbg) ::
404-
alloc_closure_info ~arity:f.arity
405+
alloc_closure_info ~arity
405406
~startenv:(startenv - pos) dbg ::
406407
Cconst_symbol (f.label, dbg) ::
407408
transl_fundecls (pos + 4) rem

file_formats/cmx_format.mli

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ type unit_infos =
4141
mutable ui_defines: string list; (* Unit and sub-units implemented *)
4242
mutable ui_imports_cmi: crcs; (* Interfaces imported *)
4343
mutable ui_imports_cmx: crcs; (* Infos imported *)
44-
mutable ui_curry_fun: int list; (* Currying functions needed *)
44+
mutable ui_curry_fun: Clambda.arity list; (* Currying functions needed *)
4545
mutable ui_apply_fun: int list; (* Apply functions needed *)
4646
mutable ui_send_fun: int list; (* Send functions needed *)
4747
mutable ui_export_info: export_info;

lambda/lambda.ml

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ type local_attribute =
261261
| Never_local (* [@local never] *)
262262
| Default_local (* [@local maybe] or no [@local] attribute *)
263263

264-
type function_kind = Curried | Tupled
264+
type function_kind = Curried of {nlocal: int} | Tupled
265265

266266
type let_kind = Strict | Alias | StrictOpt | Variable
267267

@@ -317,7 +317,8 @@ and lfunction =
317317
body: lambda;
318318
attr: function_attribute; (* specified with [@inline] attribute *)
319319
loc: scoped_location;
320-
mode: alloc_mode }
320+
mode: alloc_mode;
321+
ret_mode: alloc_mode; }
321322

322323
and lambda_apply =
323324
{ ap_func : lambda;
@@ -359,6 +360,29 @@ let const_unit = const_int 0
359360

360361
let lambda_unit = Lconst const_unit
361362

363+
let check_lfunction fn =
364+
(* A curried function type with n parameters has n arrows. Of these,
365+
the first [n-nlocal] have return mode Heap, while the remainder
366+
have return mode Local, except possibly the final one.
367+
368+
That is, after supplying the first [n-nlocal] arguments, further
369+
partial applications must be locally allocated.
370+
371+
A curried function with no local parameters or returns has kind
372+
[Curried {nlocal=0}]. *)
373+
let nparams = List.length fn.params in
374+
begin match fn.mode, fn.ret_mode, fn.kind with
375+
| Alloc_heap, _, Tupled -> ()
376+
| Alloc_local, _, Tupled ->
377+
(* Tupled optimisation does not apply to local functions *)
378+
assert false
379+
| mode, ret_mode, Curried {nlocal} ->
380+
assert (0 <= nlocal);
381+
assert (nlocal <= nparams);
382+
if ret_mode = Alloc_local then assert (nlocal >= 1);
383+
if mode = Alloc_local then assert (nlocal = nparams)
384+
end
385+
362386
let default_function_attribute = {
363387
inline = Default_inline;
364388
specialise = Default_specialise;
@@ -842,8 +866,9 @@ let shallow_map f = function
842866
ap_inlined;
843867
ap_specialised;
844868
}
845-
| Lfunction { kind; params; return; body; attr; loc; mode } ->
846-
Lfunction { kind; params; return; body = f body; attr; loc; mode }
869+
| Lfunction { kind; params; return; body; attr; loc; mode; ret_mode } ->
870+
Lfunction { kind; params; return; body = f body; attr; loc;
871+
mode; ret_mode }
847872
| Llet (str, k, v, e1, e2) ->
848873
Llet (str, k, v, f e1, f e2)
849874
| Lletrec (idel, e2) ->
@@ -957,11 +982,6 @@ let merge_inline_attributes attr1 attr2 =
957982
if attr1 = attr2 then Some attr1
958983
else None
959984

960-
let function_is_curried func =
961-
match func.kind with
962-
| Curried -> true
963-
| Tupled -> false
964-
965985
let max_arity () =
966986
if !Clflags.native_code then 126 else max_int
967987
(* 126 = 127 (the maximal number of parameters supported in C--)

lambda/lambda.mli

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,10 @@ type local_attribute =
236236
| Never_local (* [@local never] *)
237237
| Default_local (* [@local maybe] or no [@local] attribute *)
238238

239-
type function_kind = Curried | Tupled
239+
type function_kind = Curried of {nlocal: int} | Tupled
240+
(* [nlocal] determines how many arguments may be partially applied
241+
before the resulting closure must be locally allocated.
242+
See [check_lfunction] for details *)
240243

241244
type let_kind = Strict | Alias | StrictOpt | Variable
242245
(* Meaning of kinds for let x = e in e':
@@ -301,7 +304,8 @@ and lfunction =
301304
body: lambda;
302305
attr: function_attribute; (* specified with [@inline] attribute *)
303306
loc : scoped_location;
304-
mode : alloc_mode }
307+
mode : alloc_mode;
308+
ret_mode : alloc_mode }
305309

306310
and lambda_apply =
307311
{ ap_func : lambda;
@@ -354,6 +358,7 @@ val make_key: lambda -> lambda option
354358
val const_unit: structured_constant
355359
val const_int : int -> structured_constant
356360
val lambda_unit: lambda
361+
val check_lfunction : lfunction -> unit
357362
val name_lambda: let_kind -> lambda -> (Ident.t -> lambda) -> lambda
358363
val name_lambda_list: lambda list -> (lambda list -> lambda) -> lambda
359364

@@ -431,8 +436,6 @@ val swap_float_comparison : float_comparison -> float_comparison
431436
val default_function_attribute : function_attribute
432437
val default_stub_attribute : function_attribute
433438

434-
val function_is_curried : lfunction -> bool
435-
436439
val max_arity : unit -> int
437440
(** Maximal number of parameters for a function, or in other words,
438441
maximal length of the [params] list of a [lfunction] record.

0 commit comments

Comments
 (0)