From a8876a4cdaaf0a1f051fefad436e08d983b402e2 Mon Sep 17 00:00:00 2001 From: William Carroll Date: Wed, 12 Oct 2022 10:22:58 -0700 Subject: feat(wpcarro/scratch): Implement "Algorithm W" I've been wanting to grok Haskell-style type inference for awhile, so instead of just watching conference talks and reading papers about it, I've decided to attempt to implement it to more readily test my understanding of it. Change-Id: I69261202a3d74d55c6e38763d7ddfec73c392465 Reviewed-on: https://cl.tvl.fyi/c/depot/+/6988 Tested-by: BuildkiteCI Reviewed-by: wpcarro Autosubmit: wpcarro --- users/wpcarro/scratch/compiler/.gitignore | 5 +- users/wpcarro/scratch/compiler/expr_parser.ml | 195 ++++++++++++++++++++++++++ users/wpcarro/scratch/compiler/inference.ml | 177 +++++++++++++++++++++++ users/wpcarro/scratch/compiler/parser.ml | 48 +++++++ users/wpcarro/scratch/compiler/tests.ml | 43 ++++++ users/wpcarro/scratch/compiler/type_parser.ml | 115 +++++++++++++++ users/wpcarro/scratch/compiler/types.ml | 76 ++++++++++ 7 files changed, 658 insertions(+), 1 deletion(-) create mode 100644 users/wpcarro/scratch/compiler/expr_parser.ml create mode 100644 users/wpcarro/scratch/compiler/inference.ml create mode 100644 users/wpcarro/scratch/compiler/parser.ml create mode 100644 users/wpcarro/scratch/compiler/tests.ml create mode 100644 users/wpcarro/scratch/compiler/type_parser.ml create mode 100644 users/wpcarro/scratch/compiler/types.ml (limited to 'users/wpcarro') diff --git a/users/wpcarro/scratch/compiler/.gitignore b/users/wpcarro/scratch/compiler/.gitignore index 25414be2a981..96261d3fc7e9 100644 --- a/users/wpcarro/scratch/compiler/.gitignore +++ b/users/wpcarro/scratch/compiler/.gitignore @@ -1,2 +1,5 @@ a.out -*.cmi \ No newline at end of file +*.cmi +*.cmo +*.cmx +*.o \ No newline at end of file diff --git a/users/wpcarro/scratch/compiler/expr_parser.ml b/users/wpcarro/scratch/compiler/expr_parser.ml new file mode 100644 index 000000000000..7d24efc80de3 --- /dev/null +++ b/users/wpcarro/scratch/compiler/expr_parser.ml @@ -0,0 +1,195 @@ +(******************************************************************************* + * CLI REPL for an s-expression Lambda Calculus. + * + * Lambda Calculus Expression Language: + * + * Helpers: + * symbol -> [-a-z]+ + * boolean -> 'true' | 'false' + * integer -> [1-9][0-9]* + * + * Core: + * expression -> funcdef + * binding -> '(' 'let' symbol expr expr ')' + * funcdef -> '(' 'fn' symbol expr ')' + * funccall -> '(' ( symbol | funcdef) expr ')' + * literal -> boolean | integer + * variable -> symbol + * + * Example Usage: + * $ ocamlopt types.ml str.cmxa inference.ml parser.ml expr_parser.ml && ./a.out + * repl> true + * tokens: [ "true" ] + * ast: ValueLiteral (LiteralBool true) + * Boolean + * repl> + * + ******************************************************************************) + +open Parser +open Inference + +let to_array (q : 'a Queue.t) : 'a array = + let result = Array.make (Queue.length q) "" in + let i = ref 0 in + Queue.iter + (fun x -> + result.(!i) <- x; + i := !i + 1) + q; + result + +type literal = LiteralBool of bool | LiteralInt of int + +let ( let* ) = Option.bind +let map = Option.map + +let tokenize (x : string) : token array = + let q = Queue.create () in + let i = ref 0 in + while !i < String.length x do + match x.[!i] with + | ' ' -> i := !i + 1 + | '(' -> + Queue.push "(" q; + i := !i + 1 + | ')' -> + Queue.push ")" q; + i := !i + 1 + | _ -> + let token = ref "" in + while !i < String.length x && not (String.contains "() " x.[!i]) do + token := !token ^ String.make 1 x.[!i]; + i := !i + 1 + done; + Queue.push !token q + done; + to_array q + +let parse_symbol (p : parser) : string option = + let* x = p#curr in + if Str.string_match (Str.regexp "[-a-z][0-9]*") x 0 then + begin + p#advance; + Some x + end + else + None + +let parse_variable (p : parser) : Types.value option = + let* x = parse_symbol p in + Some (Types.ValueVariable x) + +let parse_literal (p : parser) : Types.value option = + match p#curr with + | Some "true" -> + p#advance; + Some (ValueLiteral (LiteralBool true)) + | Some "false" -> + p#advance; + Some (ValueLiteral (LiteralBool false)) + | Some x -> + (match int_of_string_opt x with + | Some n -> + p#advance; + Some (ValueLiteral (LiteralInt n)) + | _ -> parse_variable p) + | _ -> None + +let rec parse_expression (p : parser) : Types.value option = + parse_binding p + +and parse_funccall (p : parser) : Types.value option = + match (p#curr, p#next) with + | (Some "(", Some "(") -> + p#advance; + let* f = parse_funcdef p in + let* x = parse_expression p in + p#expect ")"; + Some (Types.ValueApplication (f, x)) + | (Some "(", _) -> + p#advance; + let* f = parse_symbol p in + let* x = parse_expression p in + p#expect ")"; + Some (Types.ValueVarApplication (f, x)) + | _ -> parse_literal p + +and parse_funcdef (p : parser) : Types.value option = + match (p#curr, p#next) with + | (Some "(", Some "fn") -> + p#advance; + p#advance; + let* name = parse_symbol p in + let* body = parse_expression p in + p#expect ")"; + Some (Types.ValueFunction (name, body)) + | _ -> parse_funccall p + +and parse_binding (p : parser) : Types.value option = + match (p#curr, p#next) with + | (Some "(", Some "let") -> + p#advance; + p#advance; + let* name = parse_symbol p in + let* value = parse_expression p in + let* body = parse_expression p in + Some (Types.ValueBinder (name, value, body)) + | _ -> parse_funcdef p + +let print_tokens (xs : string array) = + xs |> Array.to_list + |> List.map (Printf.sprintf "\"%s\"") + |> String.concat ", " + |> Printf.sprintf "tokens: [ %s ]" + |> print_string |> print_newline + +let parse_language (x : string) : Types.value option = + let tokens = tokenize x in + print_tokens tokens; + parse_expression (new parser tokens) + +let rec debug (ast : Types.value) : string = + match ast with + | ValueLiteral (LiteralBool x) -> + Printf.sprintf "ValueLiteral (LiteralBool %s)" (string_of_bool x) + | ValueLiteral (LiteralInt x) -> + Printf.sprintf "ValueLiteral (LiteralInt %s)" (string_of_int x) + | ValueVariable x -> + Printf.sprintf "ValueVariable %s" x + | ValueFunction (x, body) -> + Printf.sprintf "ValueFunction (%s, %s)" x (debug body) + | ValueApplication (f, x) -> + Printf.sprintf "ValueApplication (%s, %s)" (debug f) (debug x) + | ValueVarApplication (f, x) -> + Printf.sprintf "ValueVarApplication (%s, %s)" f (debug x) + | ValueBinder (k, v, x) -> + Printf.sprintf "ValueBinder (%s, %s, %s)" k (debug v) (debug x) + +let debug_ast (ast : Types.value) : Types.value = + ast |> debug |> Printf.sprintf "ast: %s" |> print_string |> print_newline; + ast + +let main = + while true do + begin + print_string "repl> "; + let x = read_line () in + match parse_language x with + | Some ast -> + (match ast |> debug_ast |> do_infer with + | None -> + "Type-check failed" + |> print_string + |> print_newline + | Some x -> + x + |> Types.pretty + |> print_string + |> print_newline) + | None -> + "Could not parse" + |> print_string + |> print_newline + end + done diff --git a/users/wpcarro/scratch/compiler/inference.ml b/users/wpcarro/scratch/compiler/inference.ml new file mode 100644 index 000000000000..52838f85e649 --- /dev/null +++ b/users/wpcarro/scratch/compiler/inference.ml @@ -0,0 +1,177 @@ +(******************************************************************************* + * WIP implementation of the Hindley-Milner type system primarily for learning + * purposes. + * + * Wish List: + * - TODO Debug this inference (let f (fn x x) f) + ******************************************************************************) + +open Types + +(******************************************************************************* + * Library + ******************************************************************************) + +let ( let* ) = Option.bind + +let set_from_list (xs : string list) : set = + xs |> List.fold_left (fun acc x -> FromString.add x true acc) FromString.empty + +(* Map union that favors the rhs values (i.e. "last writer wins"). *) +let lww (xs : 'a FromString.t) (ys : 'a FromString.t) : 'a FromString.t = + FromString.union (fun k x y -> Some y) xs ys + +let emptyEnv : env = FromString.empty + +let rec free_type_vars (t : _type) : set = + match t with + | TypeVariable k -> FromString.singleton k true + | TypeInt -> FromString.empty + | TypeBool -> FromString.empty + | TypeArrow (a, b) -> lww (free_type_vars a) (free_type_vars b) + +let i : int ref = ref 0 + +let make_type_var () : _type = + let res = Printf.sprintf "a%d" !i in + i := !i + 1; + TypeVariable res + +exception OccursCheck + +let bind_var (k : string) (t : _type) : substitution = + if t == TypeVariable k then FromString.empty + else if FromString.exists (fun name _ -> name == k) (free_type_vars t) then + raise OccursCheck + else FromString.singleton k t + +let rec instantiate (q : quantified_type) : _type = + let (QuantifiedType (names, t)) = q in + match t with + | TypeInt -> TypeInt + | TypeBool -> TypeBool + | TypeVariable k -> + if List.exists (( == ) k) names then make_type_var () else TypeVariable k + | TypeArrow (a, b) -> + TypeArrow + (instantiate (QuantifiedType (names, a)), instantiate (QuantifiedType (names, b))) + +let quantified_type_ftvs (q : quantified_type) : set = + let (QuantifiedType (names, t)) = q in + lww (free_type_vars t) (names |> set_from_list) + +let generalize (env : env) (t : _type) : quantified_type = + let envftv = + env |> FromString.bindings + |> List.map (fun (_, v) -> quantified_type_ftvs v) + |> List.fold_left lww FromString.empty + in + let names = + lww (free_type_vars t) envftv + |> FromString.bindings + |> List.map (fun (k, _) -> k) + in + QuantifiedType (names, t) + +let rec substitute_type (s : substitution) (t : _type) : _type = + match t with + | TypeVariable k as tvar -> + (match FromString.find_opt k s with + | Some v -> substitute_type s v + | None -> tvar) + | TypeArrow (a, b) -> TypeArrow (substitute_type s a, substitute_type s b) + | TypeInt -> TypeInt + | TypeBool -> TypeBool + +let substitute_quantified_type (s : substitution) (q : quantified_type) : quantified_type = + let (QuantifiedType (names, t)) = q in + let s1 = + FromString.filter (fun k v -> List.exists (fun x -> k != x) names) s + in + QuantifiedType (names, substitute_type s1 t) + +let substitute_env (s : substitution) (env : env) : env = + FromString.map (fun q -> substitute_quantified_type s q) env + +let compose_substitutions (xs : substitution list) : substitution = + let do_compose_substitutions s1 s2 = lww s2 (FromString.map (substitute_type s2) s1) in + List.fold_left do_compose_substitutions FromString.empty xs + +let rec unify (a : _type) (b : _type) : substitution option = + match (a, b) with + | TypeInt, TypeInt -> Some FromString.empty + | TypeBool, TypeBool -> Some FromString.empty + | TypeVariable k, _ -> Some (bind_var k b) + | _, TypeVariable k -> Some (bind_var k a) + | TypeArrow (a, b), TypeArrow (c, d) -> + let* s1 = unify a c in + let* s2 = unify (substitute_type s1 b) (substitute_type s1 d) in + let s3 = compose_substitutions [s1; s2] in + s1 |> Types.debug_substitution |> Printf.sprintf "s1: %s\n" |> print_string; + s2 |> Types.debug_substitution |> Printf.sprintf "s2: %s\n" |> print_string; + s3 |> Types.debug_substitution |> Printf.sprintf "s3: %s\n" |> print_string; + Some s3 + | _ -> None + +let print_env (env : env) = + Printf.sprintf "env: %s\n" (Types.debug_env env) + |> print_string + +let print_val (x : value) = + Printf.sprintf "val: %s\n" (Types.debug_value x) + |> print_string + +let print_inference (x : inference option) = + match x with + | None -> "no inference\n" |> print_string + | Some x -> + Printf.sprintf "inf: %s\n" (Types.debug_inference x) + |> print_string + +let rec infer (env : env) (x : value) : inference option = + print_env env; + print_val x; + let res = match x with + | ValueLiteral lit -> ( + match lit with + | LiteralInt _ -> Some (Inference (FromString.empty, TypeInt)) + | LiteralBool _ -> Some (Inference (FromString.empty, TypeBool))) + | ValueVariable k -> + let* v = FromString.find_opt k env in + Some (Inference (FromString.empty, instantiate v)) + | ValueFunction (param, body) -> + let typevar = make_type_var () in + let env1 = FromString.remove param env in + let env2 = lww (FromString.singleton param (QuantifiedType ([], typevar))) env1 in + let* (Inference (s1, t1)) = infer env2 body in + Some (Inference (s1, TypeArrow (substitute_type s1 typevar, t1))) + | ValueApplication (f, x) -> + let result = make_type_var () in + let* (Inference (s1, t1)) = infer env f in + let* (Inference (s2, t2)) = infer (substitute_env s1 env) x in + let* s3 = unify (substitute_type s2 t1) (TypeArrow (t2, result)) in + Some (Inference + ( compose_substitutions [s3; s2; s1], + substitute_type s3 result )) + | ValueVarApplication (name, x) -> + let* v = FromString.find_opt name env in + let t1 = instantiate v in + let typevar = make_type_var () in + let* (Inference (s2, t2)) = infer env x in + let* s3 = unify (substitute_type s2 t1) (TypeArrow (t2, typevar)) in + Some (Inference + ( compose_substitutions [s2; s3], + substitute_type s3 typevar )) + | ValueBinder (k, v, body) -> + let* (Inference (s1, t1)) = infer env v in + let env1 = FromString.remove k env in + let tg = generalize (substitute_env s1 env) t1 in + let env2 = FromString.add k tg env1 in + let* (Inference (s2, t2)) = infer (substitute_env s1 env2) body in + Some (Inference (compose_substitutions [s1; s2], t2)) in + print_inference res; + res + +let do_infer (x : value) : _type option = + let* Inference (_, t) = infer FromString.empty x in + Some t diff --git a/users/wpcarro/scratch/compiler/parser.ml b/users/wpcarro/scratch/compiler/parser.ml new file mode 100644 index 000000000000..75cbe04a3f72 --- /dev/null +++ b/users/wpcarro/scratch/compiler/parser.ml @@ -0,0 +1,48 @@ +(******************************************************************************* + * Defines a generic parser class. + ******************************************************************************) + +exception ParseError of string + +type token = string +type state = { i : int; tokens : token array } + +let get (i : int) (xs : 'a array) : 'a option = + if i >= Array.length xs then None else Some xs.(i) + +class parser (tokens : token array) = + object (self) + val mutable tokens : token array = tokens + val mutable i = ref 0 + method print_state = Printf.sprintf "{ i = %d; }" !i + method advance = i := !i + 1 + method prev : token option = get (!i - 1) tokens + method curr : token option = get !i tokens + method next : token option = get (!i + 1) tokens + + method consume : token option = + match self#curr with + | None -> None + | Some x as res -> + self#advance; + res + + method expect (x : token) = + match self#curr with + | Some y when x = y -> self#advance + | _ -> raise (ParseError (Printf.sprintf "Expected %s" x)) + + method matches (x : token) : bool = + match self#curr with + | None -> false + | Some y -> + if x = y then + begin + self#advance; + true + end + else false + + method exhausted : bool = !i >= Array.length tokens + method state : state = { i = !i; tokens } + end diff --git a/users/wpcarro/scratch/compiler/tests.ml b/users/wpcarro/scratch/compiler/tests.ml new file mode 100644 index 000000000000..828cbd16f090 --- /dev/null +++ b/users/wpcarro/scratch/compiler/tests.ml @@ -0,0 +1,43 @@ +open Expr_parser +open Type_parser +open Inference + +type test = { input : string; expect : string; } +(* type sub_test = { s1 : string; s2 : string; s3 : string } *) + +let ( let* ) = Option.bind + +let tests = [ + { input = "((fn x x) 10)"; expect = "Integer"; }; + { input = "(let f (fn x x) f)"; expect = "a -> a"; }; +] + +(* let sub_tests = [ *) +(* { *) +(* s1 = "{b |-> b -> Int}"; *) +(* s2 = "{a: Bool, b: Int, c: Bool}"; *) +(* s3 = "{a: Bool, b: Int -> Int, c: Bool}"; *) +(* } *) +(* ] *) + +exception FailedAssertion +exception TestError + +let main = + tests + |> List.iter (fun { input; expect } -> + Printf.sprintf ":t %s == %s\n" input expect |> print_string; + match (parse_language input, parse_input expect) with + | Some ast, Some expected -> + (match do_infer ast with + | Some actual -> + if actual != expected then + begin + print_type actual; + raise FailedAssertion + end + else + print_string "Test passed.\n" + | _ -> raise TestError) + | _ -> raise TestError); + print_string "All tests pass!" diff --git a/users/wpcarro/scratch/compiler/type_parser.ml b/users/wpcarro/scratch/compiler/type_parser.ml new file mode 100644 index 000000000000..9b6c0a309309 --- /dev/null +++ b/users/wpcarro/scratch/compiler/type_parser.ml @@ -0,0 +1,115 @@ +(****************************************************************************** + * Type Expression Language: + * + * Helpers: + * symbol -> [a-z] + * + * Core: + * type -> function + * function -> ( variable | literal ) '->' type + * literal -> 'Integer' | 'Boolean' + * variable -> symbol + ******************************************************************************) + +open Types +open Parser +open Inference + +type side = LHS | RHS + +let ( let* ) = Option.bind + +let printsub (s : substitution) = + FromString.fold (fun k v acc -> Printf.sprintf "%s\"%s\" |-> %s;" acc k (pretty v)) s "" + |> Printf.sprintf "Sub { %s }" + |> print_string + |> print_newline + +let to_array (q : 'a Queue.t) : 'a array = + let result = Array.make (Queue.length q) "" in + let i = ref 0 in + Queue.iter + (fun x -> + result.(!i) <- x; + i := !i + 1) + q; + result + +let tokenize (x : string) : token array = + let q = Queue.create () in + let i = ref 0 in + while !i < String.length x do + match x.[!i] with + | ' ' -> i := !i + 1 + | _ -> + let beg = !i in + while (!i < String.length x) && (x.[!i] != ' ') do + i := !i + 1 + done; + Queue.push (String.sub x beg (!i - beg)) q + done; + to_array q + +let rec parse_type (p : parser) : _type option = + parse_function p +and parse_function (p : parser) : _type option = + match p#next with + | Some "->" -> + let* a = parse_literal p in + p#advance; + let* b = parse_type p in + Some (TypeArrow (a, b)) + | _ -> parse_literal p +and parse_literal (p : parser) : _type option = + match p#curr with + | Some "Integer" | Some "Int" -> p#advance; Some TypeInt + | Some "Boolean" | Some "Bool" -> p#advance; Some TypeBool + | Some _ -> parse_variable p + | None -> None +and parse_variable (p : parser) : _type option = + match p#curr with + | Some x when String.length x = 1 -> p#advance; Some (TypeVariable x) + | _ -> None + +let print_tokens (xs : string array) = + xs + |> Array.to_list + |> List.map (Printf.sprintf "\"%s\"") + |> String.concat ", " + |> Printf.sprintf "tokens: [ %s ]" + |> print_string |> print_newline + +let print_type (t : _type) = + t |> pretty |> Printf.sprintf "type: %s" |> print_string |> print_newline + +let parse_input (x : string) : _type option = + let tokens = tokenize x in + print_tokens tokens; + parse_type (new parser tokens) + +(* Continually prompt until user provides a parseable type expression *) +let rec read_type (arg : side) : _type = + let prompt = match arg with + | LHS -> "lhs> " + | RHS -> "rhs> " in + print_string prompt; + let x = read_line () in + match parse_input x with + | None -> + print_string "Failed to parse input.\n"; + read_type arg + | Some ast -> + print_type ast; + ast + +let main = + while true do + begin + let lhs = read_type LHS in + let rhs = read_type RHS in + match unify lhs rhs with + | None -> + Printf.printf "Cannot unify \"%s\" with \"%s\"\n" (pretty lhs) (pretty rhs) + | Some x -> printsub x + end + done diff --git a/users/wpcarro/scratch/compiler/types.ml b/users/wpcarro/scratch/compiler/types.ml new file mode 100644 index 000000000000..99f35cd5a57d --- /dev/null +++ b/users/wpcarro/scratch/compiler/types.ml @@ -0,0 +1,76 @@ +type literal = LiteralInt of int | LiteralBool of bool + +(* Lambda Calculus definition *) +type value = + | ValueLiteral of literal + | ValueVariable of string + | ValueFunction of string * value + | ValueApplication of value * value + | ValueVarApplication of string * value + | ValueBinder of string * value * value + +let rec debug_value (x : value) : string = + match x with + | ValueLiteral (LiteralInt x) -> + Printf.sprintf "Int %d" x + | ValueLiteral (LiteralBool x) -> + Printf.sprintf "Bool %b" x + | ValueVariable x -> + Printf.sprintf "Var %s" x + | ValueFunction (name, x) -> + Printf.sprintf "Fn %s %s" name (debug_value x) + | ValueApplication (f, x) -> + Printf.sprintf "App %s %s" (debug_value f) (debug_value x) + | ValueVarApplication (name, x) -> + Printf.sprintf "App %s %s" name (debug_value x) + | ValueBinder (name, x, body) -> + Printf.sprintf "Bind %s %s %s" name (debug_value x) (debug_value body) + +module FromString = Map.Make (String) + +type _type = + | TypeInt + | TypeBool + | TypeVariable of string + | TypeArrow of _type * _type + +let rec debug_type (t : _type) : string = + match t with + | TypeInt -> "Integer" + | TypeBool -> "Boolean" + | TypeVariable k -> Printf.sprintf "%s" k + | TypeArrow (a, b) -> Printf.sprintf "%s -> %s" (debug_type a) (debug_type b) + +type quantified_type = QuantifiedType of string list * _type + +let debug_quantified_type (q : quantified_type) : string = + let QuantifiedType (vars, t) = q in + if List.length vars == 0 then + Printf.sprintf "%s" (debug_type t) + else + Printf.sprintf "forall %s. %s" (String.concat "," vars) (debug_type t) + +type set = bool FromString.t +type substitution = _type FromString.t + +let debug_substitution (s : substitution) : string = + FromString.fold (fun k v acc -> Printf.sprintf "%s\"%s\" |-> %s;" acc k (debug_type v)) s "" + |> Printf.sprintf "{ %s }" + +type env = quantified_type FromString.t + +let debug_env (s : env) : string = + FromString.fold (fun k v acc -> Printf.sprintf "%s\"%s\" |-> %s;" acc k (debug_quantified_type v)) s "" + |> Printf.sprintf "{ %s }" + +type inference = Inference of substitution * _type + +let debug_inference (Inference (s, t)) = + Printf.sprintf "type: %s; sub: %s" (debug_type t) (debug_substitution s) + +let rec pretty (t : _type) : string = + match t with + | TypeInt -> "Integer" + | TypeBool -> "Boolean" + | TypeVariable k -> Printf.sprintf "%s" k + | TypeArrow (a, b) -> Printf.sprintf "%s -> %s" (pretty a) (pretty b) -- cgit 1.4.1