Skip to content

Commit 91f2281

Browse files
committed
Hook mode variable solving into Btype.snapshot/backtrack
1 parent 54e4b09 commit 91f2281

File tree

2 files changed

+100
-24
lines changed

2 files changed

+100
-24
lines changed

testsuite/tests/typing-local/local.ml

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1774,3 +1774,40 @@ Line 1, characters 10-11:
17741774
^
17751775
Error: This value escapes its region
17761776
|}]
1777+
1778+
(* Example of backtracking after mode error *)
1779+
let f g n =
1780+
let a = local_ [n+1] in
1781+
let () = g a in
1782+
()
1783+
let z : (int list -> unit) -> int -> unit = f
1784+
[%%expect{|
1785+
val f : (local_ int list -> unit) -> int -> unit = <fun>
1786+
Line 5, characters 44-45:
1787+
5 | let z : (int list -> unit) -> int -> unit = f
1788+
^
1789+
Error: This expression has type (local_ int list -> unit) -> int -> unit
1790+
but an expression was expected of type
1791+
(int list -> unit) -> int -> unit
1792+
Type local_ int list -> unit is not compatible with type
1793+
int list -> unit
1794+
|}]
1795+
1796+
module M = struct
1797+
let f g n =
1798+
let a = local_ [n+1] in
1799+
let () = g a in
1800+
()
1801+
let z : (int list -> unit) -> int -> unit = f
1802+
end
1803+
[%%expect{|
1804+
Line 6, characters 46-47:
1805+
6 | let z : (int list -> unit) -> int -> unit = f
1806+
^
1807+
Error: This expression has type
1808+
(local_ int list -> local_ unit) -> int -> unit
1809+
but an expression was expected of type
1810+
(int list -> unit) -> int -> unit
1811+
Type local_ int list -> local_ unit is not compatible with type
1812+
int list -> unit
1813+
|}]

typing/btype.ml

Lines changed: 63 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,9 @@ type change =
7878
| Ccommu of commutable ref * commutable
7979
| Cuniv of type_expr option ref * type_expr option
8080
| Ctypeset of TypeSet.t ref * TypeSet.t
81+
| Cmode_upper of alloc_mode_var * alloc_mode_const
82+
| Cmode_lower of alloc_mode_var * alloc_mode_const
83+
| Cmode_vlower of alloc_mode_var * alloc_mode_var list
8184

8285
type changes =
8386
Change of change * changes ref
@@ -93,6 +96,19 @@ let log_change ch =
9396
r := Change (ch, r');
9497
Weak.set !trail 0 (Some r')
9598

99+
let log_changes chead ctail =
100+
if chead = Unchanged then (assert (!ctail = Unchanged))
101+
else match Weak.get !trail 0 with None -> ()
102+
| Some r ->
103+
r := chead;
104+
Weak.set !trail 0 (Some ctail)
105+
106+
let append_change ctail ch =
107+
assert (!(!ctail) = Unchanged);
108+
let r' = ref Unchanged in
109+
(!ctail) := Change (ch, r');
110+
ctail := r'
111+
96112
(**** Representative of a type ****)
97113

98114
let rec field_kind_repr =
@@ -718,6 +734,9 @@ let undo_change = function
718734
| Ccommu (r, v) -> r := v
719735
| Cuniv (r, v) -> r := v
720736
| Ctypeset (r, v) -> r := v
737+
| Cmode_upper (v, u) -> v.upper <- u
738+
| Cmode_lower (v, l) -> v.lower <- l
739+
| Cmode_vlower (v, vs) -> v.vlower <- vs
721740

722741
type snapshot = changes ref * int
723742
let last_snapshot = s_ref 0
@@ -864,58 +883,78 @@ module Alloc_mode = struct
864883
else Printf.fprintf ppf "v%d" i);
865884
Printf.fprintf ppf "[%a%a]" pp_c v.lower pp_c v.upper
866885
*)
867-
let submode_cv m v =
886+
887+
let set_lower ~log v lower =
888+
append_change log (Cmode_lower (v, v.lower));
889+
v.lower <- lower
890+
891+
let set_upper ~log v upper =
892+
append_change log (Cmode_upper (v, v.upper));
893+
v.upper <- upper
894+
895+
let set_vlower ~log v vlower =
896+
append_change log (Cmode_vlower (v, v.vlower));
897+
v.vlower <- vlower
898+
899+
let submode_cv ~log m v =
868900
(* Printf.printf " %a <= %a\n" pp_c m pp_v v; *)
869901
if le_const m v.lower then ()
870902
else if not (le_const m v.upper) then raise NotSubmode
871903
else begin
872904
let m = join_const v.lower m in
873-
v.lower <- m;
874-
if m = v.upper then v.vlower <- []
905+
set_lower ~log v m;
906+
if m = v.upper then set_vlower ~log v []
875907
end
876908

877-
let rec submode_vc v m =
909+
let rec submode_vc ~log v m =
878910
(* Printf.printf " %a <= %a\n" pp_v v pp_c m; *)
879911
if le_const v.upper m then ()
880912
else if not (le_const v.lower m) then raise NotSubmode
881913
else begin
882914
let m = meet_const v.upper m in
883-
v.upper <- m;
915+
set_upper ~log v m;
884916
v.vlower |> List.iter (fun a ->
885917
(* a <= v <= m *)
886-
submode_vc a m;
887-
v.lower <- join_const v.lower a.lower;
918+
submode_vc ~log a m;
919+
set_lower ~log v (join_const v.lower a.lower);
888920
);
889-
if v.lower = m then v.vlower <- []
921+
if v.lower = m then set_vlower ~log v []
890922
end
891923

892-
let submode_vv a b =
924+
let submode_vv ~log a b =
893925
(* Printf.printf " %a <= %a\n" pp_v a pp_v b; *)
894926
if le_const a.upper b.lower then ()
895927
else if List.memq a b.vlower then ()
896928
else begin
897-
submode_vc a b.upper;
898-
b.vlower <- a :: b.vlower;
899-
submode_cv a.lower b;
929+
submode_vc ~log a b.upper;
930+
set_vlower ~log b (a :: b.vlower);
931+
submode_cv ~log a.lower b;
900932
end
901933

902934
let submode a b =
935+
let log_head = ref Unchanged in
936+
let log = ref log_head in
903937
match
904938
match a, b with
905939
| Amode a, Amode b ->
906940
if not (le_const a b) then raise NotSubmode
907941
| Amodevar v, Amode c ->
908942
(* Printf.printf "%a <= %a\n" pp_v v pp_c c; *)
909-
submode_vc v c
943+
submode_vc ~log v c
910944
| Amode c, Amodevar v ->
911945
(* Printf.printf "%a <= %a\n" pp_c c pp_v v; *)
912-
submode_cv c v
946+
submode_cv ~log c v
913947
| Amodevar a, Amodevar b ->
914948
(* Printf.printf "%a <= %a\n" pp_v a pp_v b; *)
915-
submode_vv a b
949+
submode_vv ~log a b
916950
with
917-
| () -> Ok ()
918-
| exception NotSubmode -> Error ()
951+
| () ->
952+
log_changes !log_head !log;
953+
Ok ()
954+
| exception NotSubmode ->
955+
let backlog = rev_log [] !log_head in
956+
List.iter undo_change backlog;
957+
Error ()
919958

920959
let submode_exn t1 t2 =
921960
match submode t1 t2 with
@@ -946,7 +985,7 @@ module Alloc_mode = struct
946985
if all_equal v rest then v
947986
else begin
948987
let v = fresh () in
949-
List.iter (fun v' -> submode_vv v' v) vars;
988+
List.iter (fun v' -> submode_exn (Amodevar v') (Amodevar v)) vars;
950989
v
951990
end
952991
in
@@ -963,7 +1002,7 @@ module Alloc_mode = struct
9631002
let constrain_upper = function
9641003
| Amode m -> m
9651004
| Amodevar v ->
966-
submode_cv v.upper v;
1005+
submode_exn (Amode v.upper) (Amodevar v);
9671006
v.upper
9681007

9691008
let compress_vlower v =
@@ -977,7 +1016,7 @@ module Alloc_mode = struct
9771016
trans_low v'
9781017
end
9791018
and trans_low v' =
980-
submode_cv v'.lower v;
1019+
submode_exn (Amode v'.lower) (Amodevar v);
9811020
List.iter trans v'.vlower
9821021
in
9831022
List.iter trans_low v.vlower
@@ -986,16 +1025,16 @@ module Alloc_mode = struct
9861025
| Amode m -> m
9871026
| Amodevar v ->
9881027
compress_vlower v;
989-
submode_vc v v.lower;
1028+
submode_exn (Amodevar v) (Amode v.lower);
9901029
v.lower
9911030

9921031
let newvar () = Amodevar (fresh ())
9931032

9941033
let check_const = function
9951034
| Amode m -> Some m
996-
| Amodevar v when v.lower = v.upper ->
997-
Some v.lower
998-
| Amodevar _ -> None
1035+
| Amodevar v ->
1036+
compress_vlower v;
1037+
if v.lower = v.upper then Some v.lower else None
9991038

10001039
let print_const ppf = function
10011040
| Global -> Format.fprintf ppf "Global"

0 commit comments

Comments
 (0)