diff --git a/compiler/src/diagnostics/modules.re b/compiler/src/diagnostics/modules.re index 052f8a25e5..344764a7b0 100644 --- a/compiler/src/diagnostics/modules.re +++ b/compiler/src/diagnostics/modules.re @@ -1,16 +1,17 @@ open Grain_typed; -type export_kind = +type provide_kind = | Function | Value | Record | Enum | Abstract - | Exception; + | Exception + | Module; -type export = { +type provide = { name: string, - kind: export_kind, + kind: provide_kind, signature: string, }; diff --git a/compiler/src/language_server/completion.re b/compiler/src/language_server/completion.re new file mode 100644 index 0000000000..8d19486c8b --- /dev/null +++ b/compiler/src/language_server/completion.re @@ -0,0 +1,756 @@ +open Grain_utils; +open Grain_typed; +open Grain_diagnostics; +open Grain_parsing; +open Sourcetree; + +// This is the full enumeration of all CompletionItemKind as declared by the language server +// protocol (https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#completionItemKind), +// but not all will be used by Grain LSP +[@deriving (enum, yojson)] +type completion_item_kind = + // Since these are using ppx_deriving enum, order matters + | [@value 1] CompletionItemKindText + | CompletionItemKindMethod + | CompletionItemKindFunction + | CompletionItemKindConstructor + | CompletionItemKindField + | CompletionItemKindVariable + | CompletionItemKindClass + | CompletionItemKindInterface + | CompletionItemKindModule + | CompletionItemKindProperty + | CompletionItemKindUnit + | CompletionItemKindValue + | CompletionItemKindEnum + | CompletionItemKindKeyword + | CompletionItemKindSnippet + | CompletionItemKindColor + | CompletionItemKindFile + | CompletionItemKindReference + | CompletionItemKindFolder + | CompletionItemKindEnumMember + | CompletionItemKindConstant + | CompletionItemKindStruct + | CompletionItemKindEvent + | CompletionItemKindOperator + | CompletionItemKindTypeParameter; + +[@deriving (enum, yojson)] +type completion_trigger_kind = + // Since these are using ppx_deriving enum, order matters + | [@value 1] CompletionTriggerInvoke + | CompletionTriggerCharacter + | CompletionTriggerForIncompleteCompletions; + +[@deriving (enum, yojson)] +type insert_text_format = + // Since these are using ppx_deriving enum, order matters + | [@value 1] InsertTextFormatPlainText + | InsertTextFormatSnippet; + +let completion_item_kind_to_yojson = severity => + completion_item_kind_to_enum(severity) |> [%to_yojson: int]; +let completion_item_kind_of_yojson = json => + Result.bind(json |> [%of_yojson: int], value => { + switch (completion_item_kind_of_enum(value)) { + | Some(severity) => Ok(severity) + | None => Result.Error("Invalid enum value") + } + }); + +let completion_trigger_kind_to_yojson = kind => + completion_trigger_kind_to_enum(kind) |> [%to_yojson: int]; +let completion_trigger_kind_of_yojson = json => + Result.bind(json |> [%of_yojson: int], value => { + switch (completion_trigger_kind_of_enum(value)) { + | Some(kind) => Ok(kind) + | None => Result.Error("Invalid enum value") + } + }); + +let insert_text_format_to_yojson = value => + insert_text_format_to_enum(value) |> [%to_yojson: int]; +let insert_text_format_of_yojson = json => + Result.bind(json |> [%of_yojson: int], value => { + switch (insert_text_format_of_enum(value)) { + | Some(value) => Ok(value) + | None => Result.Error("Invalid enum value") + } + }); + +[@deriving yojson] +type completion_item = { + label: string, + kind: completion_item_kind, + detail: string, + [@key 
"insertText"] + insert_text: option(string), + [@key "insertTextFormat"] + insert_text_format, + documentation: string, +}; + +[@deriving yojson({strict: false})] +type completion_context = { + [@key "triggerKind"] + trigger_kind: completion_trigger_kind, + [@key "triggerCharacter"] [@default None] + trigger_character: option(string), +}; + +// https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#completionParams +module RequestParams = { + [@deriving yojson({strict: false})] + type t = { + [@key "textDocument"] + text_document: Protocol.text_document_identifier, + position: Protocol.position, + [@default None] + context: option(completion_context), + }; +}; + +// https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#completionList +module ResponseResult = { + [@deriving yojson] + type t = { + isIncomplete: bool, + items: list(completion_item), + }; +}; + +let send_completion = + (~id: Protocol.message_id, completions: list(completion_item)) => { + Protocol.response( + ~id, + ResponseResult.to_yojson({isIncomplete: false, items: completions}), + ); +}; + +// completions helpers +let build_completion = + ( + ~detail="", + ~documentation="", + ~insert_text_format=InsertTextFormatPlainText, + ~insert_text=?, + label: string, + kind: completion_item_kind, + ) => { + {label, kind, detail, insert_text, insert_text_format, documentation}; +}; + +// TODO: This is for debugging only +let debug_stringify_tkn = (token: Parser.token) => { + switch (token) { + | Parser.RATIONAL(c) => Printf.sprintf("Token.Rational(%s)", c) + | Parser.NUMBER_INT(c) => Printf.sprintf("Token.NUMBER_INT(%s)", c) + | Parser.NUMBER_FLOAT(c) => Printf.sprintf("Token.NUMBER_FLOAT(%s)", c) + | Parser.INT8(c) => Printf.sprintf("Token.INT8(%s)", c) + | Parser.INT16(c) => Printf.sprintf("Token.INT16(%s)", c) + | Parser.INT32(c) => Printf.sprintf("Token.INT32(%s)", c) + | Parser.INT64(c) => Printf.sprintf("Token.INT64(%s)", c) + | Parser.UINT8(c) => Printf.sprintf("Token.UINT8(%s)", c) + | Parser.UINT16(c) => Printf.sprintf("Token.UINT16(%s)", c) + | Parser.UINT32(c) => Printf.sprintf("Token.UINT32(%s)", c) + | Parser.UINT64(c) => Printf.sprintf("Token.UINT64(%s)", c) + | Parser.FLOAT32(c) => Printf.sprintf("Token.FLOAT32(%s)", c) + | Parser.FLOAT64(c) => Printf.sprintf("Token.FLOAT64(%s)", c) + | Parser.BIGINT(c) => Printf.sprintf("Token.BIGINT(%s)", c) + | Parser.WASMI32(c) => Printf.sprintf("Token.WASMI32(%s)", c) + | Parser.WASMI64(c) => Printf.sprintf("Token.WASMI64(%s)", c) + | Parser.WASMF32(c) => Printf.sprintf("Token.WASMF32(%s)", c) + | Parser.WASMF64(c) => Printf.sprintf("Token.WASMF64(%s)", c) + | Parser.LIDENT(c) => Printf.sprintf("Token.LIDENT(%s)", c) + | Parser.UIDENT(c) => Printf.sprintf("Token.UIDENT(%s)", c) + | Parser.STRING(c) => Printf.sprintf("Token.STRING(%s)", c) + | Parser.BYTES(c) => Printf.sprintf("Token.BYTES(%s)", c) + | Parser.CHAR(c) => Printf.sprintf("Token.CHAR(%s)", c) + | Parser.LBRACK => "Token.LBRACK" + | Parser.LBRACKRCARET => "Token.LBRACKRCARET" + | Parser.RBRACK => "Token.RBRACK" + | Parser.LPAREN => "Token.LPAREN" + | Parser.RPAREN => "Token.RPAREN" + | Parser.LBRACE => "Token.LBRACE" + | Parser.RBRACE => "Token.RBRACE" + | Parser.LCARET => "Token.LCARET" + | Parser.RCARET => "Token.RCARET" + | Parser.COMMA => "Token.COMMA" + | Parser.SEMI => "Token.SEMI" + | Parser.AS => "Token.AS" + | Parser.THICKARROW => "Token.THICKARROW" + | Parser.ARROW => "Token.ARROW" + | Parser.EQUAL => "Token.EQUAL" + | Parser.GETS => 
"Token.GETS" + | Parser.UNDERSCORE => "Token.UNDERSCORE" + | Parser.COLON => "Token.COLON" + | Parser.QUESTION => "Token.QUESTION" + | Parser.DOT => "Token.DOT" + | Parser.ELLIPSIS => "Token.ELLIPSIS" + | Parser.ASSERT => "Token.ASSERT" + | Parser.FAIL => "Token.FAIL" + | Parser.EXCEPTION => "Token.EXCEPTION" + | Parser.THROW => "Token.THROW" + | Parser.TRUE => "Token.TRUE" + | Parser.FALSE => "Token.FALSE" + | Parser.VOID => "Token.VOID" + | Parser.LET => "Token.LET" + | Parser.MUT => "Token.MUT" + | Parser.REC => "Token.REC" + | Parser.IF => "Token.IF" + | Parser.WHEN => "Token.WHEN" + | Parser.ELSE => "Token.ELSE" + | Parser.MATCH => "Token.MATCH" + | Parser.WHILE => "Token.WHILE" + | Parser.FOR => "Token.FOR" + | Parser.CONTINUE => "Token.CONTINUE" + | Parser.BREAK => "Token.BREAK" + | Parser.RETURN => "Token.RETURN" + | Parser.AT => "Token.AT" + | Parser.INFIX_10(c) => Printf.sprintf("Token.INFIX_10(%s)", c) + | Parser.INFIX_30(c) => Printf.sprintf("Token.INFIX_30(%s)", c) + | Parser.INFIX_40(c) => Printf.sprintf("Token.INFIX_40(%s)", c) + | Parser.INFIX_50(c) => Printf.sprintf("Token.INFIX_50(%s)", c) + | Parser.INFIX_60(c) => Printf.sprintf("Token.INFIX_60(%s)", c) + | Parser.INFIX_70(c) => Printf.sprintf("Token.INFIX_70(%s)", c) + | Parser.INFIX_80(c) => Printf.sprintf("Token.INFIX_80(%s)", c) + | Parser.INFIX_90(c) => Printf.sprintf("Token.INFIX_90(%s)", c) + | Parser.INFIX_100(c) => Printf.sprintf("Token.INFIX_100(%s)", c) + | Parser.INFIX_110(c) => Printf.sprintf("Token.INFIX_110(%s)", c) + | Parser.INFIX_120(c) => Printf.sprintf("Token.INFIX_120(%s)", c) + | Parser.PREFIX_150(c) => Printf.sprintf("Token.PREFIX_150(%s)", c) + | Parser.INFIX_ASSIGNMENT_10(c) => + Printf.sprintf("Token.INFIX_ASSIGNMENT_10(%s)", c) + | Parser.ENUM => "Token.ENUM" + | Parser.RECORD => "Token.RECORD" + | Parser.TYPE => "Token.TYPE" + | Parser.MODULE => "Token.MODULE" + | Parser.INCLUDE => "Token.INCLUDE" + | Parser.USE => "Token.USE" + | Parser.PROVIDE => "Token.PROVIDE" + | Parser.ABSTRACT => "Token.ABSTRACT" + | Parser.FOREIGN => "Token.FOREIGN" + | Parser.WASM => "Token.WASM" + | Parser.PRIMITIVE => "Token.PRIMITIVE" + | Parser.AND => "Token.AND" + | Parser.EXCEPT => "Token.EXCEPT" + | Parser.FROM => "Token.FROM" + | Parser.STAR => "Token.STAR" + | Parser.SLASH => "Token.SLASH" + | Parser.DASH => "Token.DASH" + | Parser.PIPE => "Token.PIPE" + | Parser.EOL => "Token.EOL" + | Parser.EOF => "Token.EOF" + | Parser.TRY => "Token.TRY" + | Parser.CATCH => "Token.CATCH" + | Parser.COLONCOLON => "Token.COLONCOLON" + | Parser.MACRO => "Token.MACRO" + | Parser.YIELD => "Token.YIELD" + | Parser.FUN => "Token.FUN" + }; +}; + +let debug_stringify_tkn_loc = + (token: Parser.token, start_loc: int, end_loc: int) => { + Printf.sprintf( + "Token: %s, Start: %d, End: %d", + debug_stringify_tkn(token), + start_loc, + end_loc, + ); +}; +// Completion Info + +type completion_value = + | PlainText(string) + | Snippet(string, string); + +let toplevel_keywords = [ + PlainText("exception"), + PlainText("enum"), + PlainText("record"), + PlainText("type"), + PlainText("module"), + PlainText("provide"), + PlainText("abstract"), + Snippet("from", "from \"$1\" include $0"), + Snippet("use", "use $1.{ $0 }"), +]; + +let expression_keywods = [ + PlainText("let"), + Snippet("if", "if ($1) $0"), + Snippet("match", "match ($1) {\n $0\n}"), + Snippet("while", "while ($1) {\n $0\n}"), + Snippet("for", "for ($1; $2; $3) {\n $0\n}"), +]; + +// context helpers +type lex_token = { + token: Parser.token, + start_loc: int, + end_loc: int, 
+}; + +type completable_context = + | CompletableInclude(string, bool) + | CompletableStatement(bool) + | CompletableExpressionWithReturn + | CompletableExpression + | CompletableExpressionPath(Path.t, bool) + | CompletableAs + | CompletableAfterLet + | CompletableUnknown; + +// TODO: This is only for debugging at the moment +let print_string_of_context = context => { + let ctx = + switch (context) { + | CompletableInclude(_, _) => "CompletableInclude" + | CompletableAs => "CompletableAs" + | CompletableAfterLet => "CompletableAfterLet" + | CompletableExpressionWithReturn => "CompletableExpressionWithReturn" + | CompletableExpression => "CompletableExpression" + | CompletableExpressionPath(_, _) => "CompletableExpressionPath" + | CompletableStatement(true) => "CompletableStatement(Module)" + | CompletableStatement(false) => "CompletableStatement(Statement | Value)" + | CompletableUnknown => "CompletableUnknown" + }; + Trace.log(Printf.sprintf("Context: %s", ctx)); +}; + +let convert_position_to_offset = (source: string, position: Protocol.position) => { + let (_, _, offset) = + List.fold_left( + ((line_num, col_num, offset), c) => + if (line_num == position.line && col_num == position.character) { + (line_num, col_num, offset); + } else { + switch (c) { + | '\r' => (line_num, 0, offset + 1) + | '\n' => (line_num + 1, 0, offset + 1) + | _ => (line_num, col_num + 1, offset + 1) + }; + }, + (0, 0, 0), + List.of_seq(String.to_seq(source)), + ); + offset; +}; + +let in_range = (range_start: int, range_end: int, pos: int) => { + range_start < pos && pos < range_end; +}; +let after_range = (range_end: int, pos: int) => { + range_end < pos; +}; + +let last_token_eq = (token: Parser.token, token_list: list(Parser.token)) => { + switch (token_list) { + | [tkn, ..._] => tkn == token + | [] => false + }; +}; + +let rec token_non_breaking_lst = (token_list: list(Parser.token)) => { + switch (token_list) { + | [Parser.EOF, ...rest] + | [Parser.EOL, ...rest] + | [Parser.COMMA, ...rest] => token_non_breaking_lst(rest) + | [_, ..._] => false + | [] => true + }; +}; + +let rec collect_idents = + ( + acc: option(Path.t), + last_dot: bool, + token_list: list(Parser.token), + ) => { + switch (token_list) { + | _ when !last_dot => (acc, false) + | [Parser.UIDENT(str), ...rest] + | [Parser.LIDENT(str), ...rest] => + let ident = + switch (acc) { + | Some(acc) => Path.PExternal(acc, str) + | None => Path.PIdent(Ident.create(str)) + }; + collect_idents(Some(ident), false, rest); + | [Parser.DOT, ...rest] => collect_idents(acc, true, rest) + | [_, ..._] + | [] => (acc, last_dot) + }; +}; + +let get_completion_context = (documents, uri, position: Protocol.position) => { + // Try to find the code we are completing in the original source + switch (Hashtbl.find_opt(documents, uri)) { + | None => CompletableUnknown + | Some(source_code) => + // Get Document + let offset = convert_position_to_offset(source_code, position); + Trace.log(Printf.sprintf("Offset: %d", offset)); + // Collect Tokens until offset + let lexbuf = Sedlexing.Utf8.from_string(source_code); + let lexer = Wrapped_lexer.init(lexbuf); + let token = _ => Wrapped_lexer.token(lexer); + Lexer.reset(); + let rec get_tokens = (tokens: list(lex_token)) => { + let (current_tok, start_loc, end_loc) = token(); + let current_token = { + token: current_tok, + start_loc: start_loc.pos_cnum, + end_loc: end_loc.pos_cnum, + }; + switch (current_tok) { + | _ when current_token.start_loc > offset => tokens + | Parser.EOF => [current_token, ...tokens] + | _ => get_tokens([current_token,
...tokens]) + }; + }; + let tokens = + try(get_tokens([])) { + | _ => [] + }; + List.iter( + current_token => { + Trace.log( + Printf.sprintf( + "Token(%s)", + debug_stringify_tkn_loc( + current_token.token, + current_token.start_loc, + current_token.end_loc, + ), + ), + ) + }, + tokens, + ); + // Determine Context + let rec determine_if_in_block = tokens => { + switch (tokens) { + | [{token: Parser.LBRACE, start_loc}, ..._] when start_loc < offset => + true + | [{token: Parser.RBRACE, start_loc}, ..._] when start_loc < offset => + false + | [_, ...rest] => determine_if_in_block(rest) + | [] => false + }; + }; + let in_block = determine_if_in_block(tokens); + let rec build_context = + ( + ~hit_eol: bool, + token_list: list(Parser.token), + tokens: list(lex_token), + ) => { + switch (tokens) { + // TODO: Add a state for when we are at match (XXXX) { | } <- This could be very useful also could not be + // TODO: Add a state for type XXXX = | + // TODO: Add state for use XXXX. + // Tokens that we care about + | [{token: Parser.LET}, ..._] + when !hit_eol && token_non_breaking_lst(token_list) => + CompletableAfterLet + // TODO: Reimplement the as completion + | [{token: Parser.MODULE, end_loc}, {token: Parser.INCLUDE}, ..._] + when + !hit_eol + && after_range(end_loc, offset) + && !last_token_eq(Parser.AS, token_list) => + CompletableAs + | [ + {token: Parser.STRING(str), start_loc, end_loc}, + {token: Parser.FROM}, + ..._, + ] + when in_range(start_loc, end_loc, offset) => + CompletableInclude(str, false) + // TODO: Just capture up to the . + | [{token: Parser.DOT}, {token: Parser.EOL}, ..._] + when !hit_eol && !last_token_eq(Parser.DOT, token_list) => + /* + * TODO: Support test().label on records somehow + * This is going to require using sourceTree, to get the return type of the function, (We may also be able to check the env but the problem is we don't have a complete env at this point) + * After we have a type signature it shouldn't be that hard to resolve the completions, until that point it is though. 
+ */ + // TODO: Implement path collection + let (path, expr_start) = collect_idents(None, true, token_list); + switch (path) { + | Some(path) => CompletableExpressionPath(path, expr_start) + | None => CompletableUnknown + }; + // This is the case of XXXX.X| <- You are actively writing + // TODO: Support test().label on records somehow + | [ + {token: Parser.LIDENT(_) | Parser.UIDENT(_), start_loc}, + {token: Parser.EOL}, + ..._, + ] + when !hit_eol && start_loc < offset => + // TODO: Collect the path + if (!in_block) { + CompletableStatement(false); + } else { + CompletableExpression; + } + | [{token: Parser.UIDENT(_), start_loc}, {token: Parser.EOL}, ..._] + when !hit_eol && start_loc < offset => + if (!in_block) { + CompletableStatement(true); + } else { + CompletableExpression; + } + | [{token: Parser.THICKARROW}, ..._] => + // TODO: Determine if this is a type or expression + CompletableExpression + | [{token: Parser.LPAREN}, ..._] + when token_non_breaking_lst(token_list) => + CompletableExpressionWithReturn + | [{token: Parser.EQUAL}, ..._] + when token_non_breaking_lst(token_list) => + CompletableExpressionWithReturn + | [{token: Parser.EOL, start_loc}, ...rest] when start_loc < offset => + build_context(~hit_eol=true, [Parser.EOL, ...token_list], rest) + | [] => CompletableUnknown + // Most tokens we can skip + | [tok, ...rest] => + build_context(~hit_eol, [tok.token, ...token_list], rest) + }; + }; + build_context(~hit_eol=false, [], tokens); + }; +}; + +let rec resolve_type = (type_desc: Types.type_desc) => { + switch (type_desc) { + | TTySubst({desc}) + | TTyLink({desc}) => resolve_type(desc) + | _ => type_desc + }; +}; + +let build_keyword_completions = (values: list(completion_value)) => { + List.map( + keyword => + switch (keyword) { + | PlainText(label) => + build_completion(label, CompletionItemKindKeyword) + | Snippet(label, snippet) => + build_completion( + ~insert_text_format=InsertTextFormatSnippet, + ~insert_text=snippet, + label, + CompletionItemKindKeyword, + ) + }, + values, + ); +}; + +let get_expression_completions = + (desire_non_void: bool, program: option(Typedtree.typed_program)) => { + // TODO: Consider using source tree to better infer the env + // builtins + let builtins = [ + build_completion("Ok", CompletionItemKindEnumMember), + build_completion("Err", CompletionItemKindEnumMember), + build_completion("Some", CompletionItemKindEnumMember), + build_completion("None", CompletionItemKindEnumMember), + build_completion("true", CompletionItemKindValue), + build_completion("false", CompletionItemKindValue), + build_completion("void", CompletionItemKindValue), + ]; + // values + let value_completions = + switch (program) { + | Some({env}) => + Env.fold_values( + (tag, _, decl, acc) => { + let (kind, typ) = + switch (resolve_type(decl.val_type.desc)) { + | TTyArrow(_, typ, _) => (CompletionItemKindFunction, typ.desc) + | typ => (CompletionItemKindValue, typ) + }; + switch (List.of_seq(String.to_seq(tag))) { + | [ + '$' | '&' | '*' | '/' | '+' | '-' | '=' | '>' | '<' | '^' | '|' | + '!' | + '?' 
| + '%' | + ':' | + '.', + ..._, + ] => acc + | _ => + switch (resolve_type(typ)) { + | TTyConstr(id, _, _) + when Path.same(id, Builtin_types.path_void) && desire_non_void => acc + | _ => [ + build_completion( + ~detail=Document.print_type_raw(decl.val_type), + tag, + kind, + ), + ...acc, + ] + } + }; + }, + None, + env, + [], + ) + | None => [] + }; + // modules + let module_completions = + switch (program) { + | Some({env}) => + Env.fold_modules( + (tag, _, _, acc) => { + [ + build_completion( + ~detail=Printf.sprintf("module %s", tag), + tag, + CompletionItemKindModule, + ), + ...acc, + ] + }, + None, + env, + [], + ) + | None => [] + }; + // merge them all + List.concat([builtins, value_completions, module_completions]); +}; + +let get_completions_from_context = + (context: completable_context, program: option(Typedtree.typed_program)) => { + // TODO: Consider using the sourcetree to provide some extra env context, e.g. type signatures + switch (context) { + | CompletableInclude(path, afterPath) => + if (afterPath) { + [ + // TODO: Implement completion for the include module name. Note: this is going to take some work as the module is not loaded into the env + // We are at from "path" | <- cursor is represented by | + build_completion("include", CompletionItemKindKeyword), + ]; + } else { + [ + // TODO: Add all paths in Includes + // TODO: Add all relative paths + // We are at from "|" <- cursor is represented by | + build_completion("number", CompletionItemKindFile), + ]; + } + | CompletableStatement(true) => + switch (program) { + | Some({env}) => + Env.fold_modules( + (tag, _, _, acc) => { + [ + build_completion( + ~detail=Printf.sprintf("module %s", tag), + tag, + CompletionItemKindModule, + ), + ...acc, + ] + }, + None, + env, + [], + ) + | None => [] + } + | CompletableStatement(false) => + let toplevel_completions = build_keyword_completions(toplevel_keywords); + let expression_completions = + build_keyword_completions(expression_keywords); + let value_completions = + switch (program) { + | Some({env}) => + Env.fold_values( + (tag, _, decl, acc) => { + switch (resolve_type(decl.val_type.desc)) { + | TTyArrow(_, _, _) => [ + build_completion( + ~detail=Document.print_type_raw(decl.val_type), + tag, + CompletionItemKindFunction, + ), + ...acc, + ] + | _ => acc + } + }, + None, + env, + [], + ) + | None => [] + }; + List.concat([ + toplevel_completions, + expression_completions, + value_completions, + ]); + | CompletableExpressionWithReturn => + get_expression_completions(true, program) + | CompletableExpression => + let keyword_completions = build_keyword_completions(expression_keywords); + let expression_completions = get_expression_completions(false, program); + List.concat([keyword_completions, expression_completions]); + | CompletableExpressionPath(path, expr_start) => + switch (expr_start) { + | false => + switch (program) { + | Some({env}) => + // TODO: Unclear if this is the right function to use here + [] + | None => [] + } + // TODO: Handle this + | true => + Trace.log("Unknown Behaviour of expr.XXX"); + []; + } + | CompletableAs => [build_completion("as", CompletionItemKindKeyword)] + | CompletableAfterLet => [ + build_completion("mut", CompletionItemKindKeyword), + build_completion("rec", CompletionItemKindKeyword), + ] + | CompletableUnknown => [] + }; +}; + +let process = + ( + ~id: Protocol.message_id, + ~compiled_code: Hashtbl.t(Protocol.uri, Lsp_types.code), + ~documents: Hashtbl.t(Protocol.uri, string), + params: RequestParams.t, + ) => { + let program = + switch
(Hashtbl.find_opt(compiled_code, params.text_document.uri)) { + | None => None + | Some({program}) => Some(program) + }; + let context = + get_completion_context( + documents, + params.text_document.uri, + params.position, + ); + print_string_of_context(context); + let completions = get_completions_from_context(context, program); + send_completion(~id, completions); +}; diff --git a/compiler/src/language_server/completion.rei b/compiler/src/language_server/completion.rei new file mode 100644 index 0000000000..a81d42db60 --- /dev/null +++ b/compiler/src/language_server/completion.rei @@ -0,0 +1,22 @@ +open Grain_typed; + +// https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#completionParams +module RequestParams: { + [@deriving yojson({strict: false})] + type t; +}; + +// https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#completionList +module ResponseResult: { + [@deriving yojson] + type t; +}; + +let process: + ( + ~id: Protocol.message_id, + ~compiled_code: Hashtbl.t(Protocol.uri, Lsp_types.code), + ~documents: Hashtbl.t(Protocol.uri, string), + RequestParams.t + ) => + unit; diff --git a/compiler/src/language_server/document.re b/compiler/src/language_server/document.re new file mode 100644 index 0000000000..51e1ce2906 --- /dev/null +++ b/compiler/src/language_server/document.re @@ -0,0 +1,70 @@ +open Grain; +open Compile; +open Grain_parsing; +open Grain_utils; +open Grain_typed; +open Grain_diagnostics; +open Sourcetree; + +// We need to use the "grain-type" markdown syntax to have correct coloring on hover items +let grain_type_code_block = Markdown.code_block(~syntax="grain-type"); +// Used for module hovers +let grain_code_block = Markdown.code_block(~syntax="grain"); + +let markdown_join = (a, b) => { + // Horizontal rules between code blocks render a little funky + // so we manually add linebreaks + Printf.sprintf( + "%s\n---\n

\n%s", + a, + b, + ); +}; + +let supressed_types = [Builtin_types.path_void, Builtin_types.path_bool]; + +let print_type = (env, ty) => { + let instance = grain_type_code_block(Printtyp.string_of_type_scheme(ty)); + try({ + let (path, _, decl) = Ctype.extract_concrete_typedecl(env, ty); + // Avoid showing the declaration for supressed types + if (List.exists( + supressed_type => Path.same(path, supressed_type), + supressed_types, + )) { + raise(Not_found); + }; + markdown_join( + grain_code_block( + Printtyp.string_of_type_declaration( + ~ident=Ident.create(Path.last(path)), + decl, + ), + ), + instance, + ); + }) { + | Not_found => instance + }; +}; + +let print_type_raw = ty => Printtyp.string_of_type_scheme(ty); + +let print_mod_type = (decl: Types.module_declaration) => { + let vals = Modules.get_provides(decl); + let signatures = + List.map( + (v: Modules.provide) => + switch (v.kind) { + | Function + | Value => Format.sprintf("let %s", v.signature) + | Record + | Enum + | Abstract + | Exception => v.signature + | Module => Format.sprintf("module %s", v.name) + }, + vals, + ); + String.concat("\n", signatures); +}; diff --git a/compiler/src/language_server/document.rei b/compiler/src/language_server/document.rei new file mode 100644 index 0000000000..de2930b4a6 --- /dev/null +++ b/compiler/src/language_server/document.rei @@ -0,0 +1,13 @@ +open Grain_typed; + +let grain_type_code_block: string => string; + +let grain_code_block: string => string; + +let markdown_join: (string, string) => string; + +let print_type: (Env.t, Types.type_expr) => string; + +let print_type_raw: Types.type_expr => string; + +let print_mod_type: Types.module_declaration => string; diff --git a/compiler/src/language_server/driver.re b/compiler/src/language_server/driver.re index 7a30f64a2c..8f0169a1c0 100644 --- a/compiler/src/language_server/driver.re +++ b/compiler/src/language_server/driver.re @@ -31,6 +31,9 @@ let process = msg => { | TextDocumentCodeLens(id, params) when is_initialized^ => Lenses.process(~id, ~compiled_code, ~documents, params); Reading; + | TextDocumentCompletion(id, params) when is_initialized^ => + Completion.process(~id, ~compiled_code, ~documents, params); + Reading; | Shutdown(id, params) when is_initialized^ => Shutdown.process(~id, ~compiled_code, ~documents, params); is_shutting_down := true; diff --git a/compiler/src/language_server/hover.re b/compiler/src/language_server/hover.re index a3009e6ac8..40f02837b5 100644 --- a/compiler/src/language_server/hover.re +++ b/compiler/src/language_server/hover.re @@ -28,11 +28,6 @@ module ResponseResult = { }; }; -// We need to use the "grain-type" markdown syntax to have correct coloring on hover items -let grain_type_code_block = Markdown.code_block(~syntax="grain-type"); -// Used for module hovers -let grain_code_block = Markdown.code_block(~syntax="grain"); - let send_hover = (~id: Protocol.message_id, ~range: Protocol.range, result) => { Protocol.response( ~id, @@ -46,83 +41,36 @@ let send_hover = (~id: Protocol.message_id, ~range: Protocol.range, result) => { ); }; -let markdown_join = (a, b) => { - // Horizonal rules between code blocks render a little funky - // so we manually add linebreaks - Printf.sprintf( - "%s\n---\n

\n%s", - a, - b, - ); -}; - let send_no_result = (~id: Protocol.message_id) => { Protocol.response(~id, `Null); }; -let supressed_types = [Builtin_types.path_void, Builtin_types.path_bool]; - -let print_type = (env, ty) => { - let instance = grain_type_code_block(Printtyp.string_of_type_scheme(ty)); - try({ - let (path, _, decl) = Ctype.extract_concrete_typedecl(env, ty); - // Avoid showing the declaration for supressed types - if (List.exists( - supressed_type => Path.same(path, supressed_type), - supressed_types, - )) { - raise(Not_found); - }; - markdown_join( - grain_code_block( - Printtyp.string_of_type_declaration( - ~ident=Ident.create(Path.last(path)), - decl, - ), - ), - instance, - ); - }) { - | Not_found => instance - }; -}; - let module_lens = (decl: Types.module_declaration) => { - let vals = Modules.get_provides(decl); - let signatures = - List.map( - (v: Modules.export) => - switch (v.kind) { - | Function - | Value => Format.sprintf("let %s", v.signature) - | Record - | Enum - | Abstract - | Exception => v.signature - }, - vals, - ); - grain_code_block(String.concat("\n", signatures)); + Document.grain_code_block(Document.print_mod_type(decl)); }; let value_lens = (env: Env.t, ty: Types.type_expr) => { - print_type(env, ty); + Document.print_type(env, ty); }; let pattern_lens = (p: Typedtree.pattern) => { - print_type(p.pat_env, p.pat_type); + Document.print_type(p.pat_env, p.pat_type); }; let type_lens = (ty: Typedtree.core_type) => { - grain_type_code_block(Printtyp.string_of_type_scheme(ty.ctyp_type)); + Document.grain_type_code_block( + Printtyp.string_of_type_scheme(ty.ctyp_type), + ); }; let declaration_lens = (ident: Ident.t, decl: Types.type_declaration) => { - grain_type_code_block(Printtyp.string_of_type_declaration(~ident, decl)); + Document.grain_type_code_block( + Printtyp.string_of_type_declaration(~ident, decl), + ); }; let include_lens = (env: Env.t, path: Path.t) => { - let header = grain_code_block("module " ++ Path.name(path)); + let header = Document.grain_code_block("module " ++ Path.name(path)); let decl = Env.find_module(path, None, env); let module_decl = switch (Modules.get_provides(decl)) { @@ -130,14 +78,14 @@ let include_lens = (env: Env.t, path: Path.t) => { | [] => None }; switch (module_decl) { - | Some(mod_sig) => markdown_join(header, mod_sig) + | Some(mod_sig) => Document.markdown_join(header, mod_sig) | None => header }; }; let exception_declaration_lens = (ident: Ident.t, ext: Types.extension_constructor) => { - grain_type_code_block( + Document.grain_type_code_block( Printtyp.string_of_extension_constructor(~ident, ext), ); }; diff --git a/compiler/src/language_server/initialize.re b/compiler/src/language_server/initialize.re index 57a6c851da..910ff2ba96 100644 --- a/compiler/src/language_server/initialize.re +++ b/compiler/src/language_server/initialize.re @@ -27,6 +27,14 @@ module RequestParams = { // https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#initializeResult module ResponseResult = { + [@deriving yojson] + type completion_values = { + [@key "resolveProvider"] + resolve_provider: bool, + [@key "triggerCharacters"] + trigger_characters: list(string), + }; + [@deriving yojson] type code_values = { [@key "resolveProvider"] @@ -41,6 +49,8 @@ module ResponseResult = { text_document_sync: Protocol.text_document_sync_kind, [@key "hoverProvider"] hover_provider: bool, + [@key "completionProvider"] + completion_provider: completion_values, [@key "definitionProvider"] definition_provider: 
Protocol.definition_client_capabilities, [@key "typeDefinitionProvider"] @@ -69,6 +79,10 @@ module ResponseResult = { document_formatting_provider: true, text_document_sync: Full, hover_provider: true, + completion_provider: { + resolve_provider: false, + trigger_characters: [".", ",", "(", ":", "[", "\"", " "], + }, definition_provider: { link_support: true, }, diff --git a/compiler/src/language_server/message.re b/compiler/src/language_server/message.re index cdb277f5ec..4bf39e3c1a 100644 --- a/compiler/src/language_server/message.re +++ b/compiler/src/language_server/message.re @@ -4,6 +4,7 @@ type t = | Initialize(Protocol.message_id, Initialize.RequestParams.t) | TextDocumentHover(Protocol.message_id, Hover.RequestParams.t) | TextDocumentCodeLens(Protocol.message_id, Lenses.RequestParams.t) + | TextDocumentCompletion(Protocol.message_id, Completion.RequestParams.t) | Shutdown(Protocol.message_id, Shutdown.RequestParams.t) | Exit(Protocol.message_id, Exit.RequestParams.t) | TextDocumentDidOpen(Protocol.uri, Code_file.DidOpen.RequestParams.t) @@ -37,6 +38,11 @@ let of_request = (msg: Protocol.request_message): t => { | Ok(params) => TextDocumentCodeLens(id, params) | Error(msg) => Error(msg) } + | {method: "textDocument/completion", id: Some(id), params: Some(params)} => + switch (Completion.RequestParams.of_yojson(params)) { + | Ok(params) => TextDocumentCompletion(id, params) + | Error(msg) => Error(msg) + } | {method: "shutdown", id: Some(id), params: None} => switch (Shutdown.RequestParams.of_yojson(`Null)) { | Ok(params) => Shutdown(id, params) diff --git a/compiler/src/language_server/message.rei b/compiler/src/language_server/message.rei index 1a72f48411..d143b05634 100644 --- a/compiler/src/language_server/message.rei +++ b/compiler/src/language_server/message.rei @@ -2,6 +2,7 @@ type t = | Initialize(Protocol.message_id, Initialize.RequestParams.t) | TextDocumentHover(Protocol.message_id, Hover.RequestParams.t) | TextDocumentCodeLens(Protocol.message_id, Lenses.RequestParams.t) + | TextDocumentCompletion(Protocol.message_id, Completion.RequestParams.t) | Shutdown(Protocol.message_id, Shutdown.RequestParams.t) | Exit(Protocol.message_id, Exit.RequestParams.t) | TextDocumentDidOpen(Protocol.uri, Code_file.DidOpen.RequestParams.t) diff --git a/compiler/src/parsing/wrapped_lexer.rei b/compiler/src/parsing/wrapped_lexer.rei new file mode 100644 index 0000000000..98402470a9 --- /dev/null +++ b/compiler/src/parsing/wrapped_lexer.rei @@ -0,0 +1,17 @@ +open Parser; + +type positioned('a) = ('a, Lexing.position, Lexing.position); + +type fn_ctx = + | DiscoverFunctions + | IgnoreFunctions; + +type t = { + lexbuf: Sedlexing.lexbuf, + mutable queued_tokens: list(positioned(token)), + mutable queued_exn: option(exn), + mutable fn_ctx_stack: list(fn_ctx), +}; + +let init: Sedlexing.lexbuf => t; +let token: t => (token, Lexing.position, Lexing.position);