diff --git a/crates/emmylua_code_analysis/resources/std/builtin.lua b/crates/emmylua_code_analysis/resources/std/builtin.lua index da7557002..44abac983 100644 --- a/crates/emmylua_code_analysis/resources/std/builtin.lua +++ b/crates/emmylua_code_analysis/resources/std/builtin.lua @@ -128,6 +128,7 @@ --- built-in type for Rawget --- @alias std.RawGet unknown +--- @deprecated use `const T` as a replacement, for example `---@generic const T`. --- --- built-in type for generic template, for match integer const and true/false --- @alias std.ConstTpl unknown diff --git a/crates/emmylua_code_analysis/resources/std/global.lua b/crates/emmylua_code_analysis/resources/std/global.lua index 220ca2de9..b68dd7a52 100644 --- a/crates/emmylua_code_analysis/resources/std/global.lua +++ b/crates/emmylua_code_analysis/resources/std/global.lua @@ -277,9 +277,9 @@ function rawequal(v1, v2) end --- --- Gets the real value of `table[index]`, the `__index` metamethod. `table` --- must be a table; `index` may be any value. ---- @generic T, K +--- @generic const T, const K --- @param table T ---- @param index std.ConstTpl +--- @param index K --- @return std.RawGet function rawget(table, index) end @@ -340,8 +340,8 @@ function require(modname) end --- `index`. a negative number indexes from the end (-1 is the last argument). --- Otherwise, `index` must be the string "#", and `select` returns --- the total number of extra arguments it received. ---- @generic T, Num: integer | '#' ---- @param index std.ConstTpl +--- @generic T, const Num: integer | '#' +--- @param index Num --- @param ... T... --- @return std.Select function select(index, ...) end @@ -460,9 +460,9 @@ function xpcall(f, msgh, ...) end --- @version 5.1, JIT --- ---- @generic T, Start: integer, End: integer ---- @param i? std.ConstTpl ---- @param j? std.ConstTpl +--- @generic const T, const Start: integer, const End: integer +--- @param i? Start +--- @param j? End --- @param list T --- @return std.Unpack function unpack(list, i, j) end diff --git a/crates/emmylua_code_analysis/resources/std/table.lua b/crates/emmylua_code_analysis/resources/std/table.lua index ee60a7343..8ab495e5e 100644 --- a/crates/emmylua_code_analysis/resources/std/table.lua +++ b/crates/emmylua_code_analysis/resources/std/table.lua @@ -106,9 +106,9 @@ function table.sort(list, comp) end --- Returns the elements from the given list. This function is equivalent to --- return `list[i]`, `list[i+1]`, `···`, `list[j]` --- By default, i is 1 and j is #list. ---- @generic T, Start: integer, End: integer ---- @param i? std.ConstTpl ---- @param j? std.ConstTpl +--- @generic const T, const Start: integer, const End: integer +--- @param i? Start +--- @param j? End --- @param list T --- @return std.Unpack function table.unpack(list, i, j) end diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/file_generic_index.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/file_generic_index.rs index e3ad351d2..de39d8a0e 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/file_generic_index.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/file_generic_index.rs @@ -4,24 +4,32 @@ use rowan::{TextRange, TextSize}; use smol_str::SmolStr; use std::sync::Arc; -use crate::{GenericParam, GenericTpl, GenericTplId, LuaType}; +use crate::{GenericParam, GenericTpl, GenericTplId}; pub trait GenericIndex: std::fmt::Debug { fn add_generic_scope(&mut self, ranges: Vec, is_func: bool) -> GenericScopeId; - fn append_generic_param(&mut self, scope_id: GenericScopeId, param: GenericParam); + fn append_generic_param( + &mut self, + scope_id: GenericScopeId, + param: GenericParam, + ) -> Option; fn append_generic_params(&mut self, scope_id: GenericScopeId, params: Vec) { for param in params { - self.append_generic_param(scope_id, param); + let _ = self.append_generic_param(scope_id, param); } } - fn find_generic( - &self, - position: TextSize, - name: &str, - ) -> Option<(GenericTplId, Option, Option)>; + fn find_generic(&self, position: TextSize, name: &str) -> Option<(GenericTplId, GenericParam)>; + + fn generic_param_mut(&mut self, tpl_id: GenericTplId) -> Option<&mut GenericParam>; + + fn mark_generic_const(&mut self, tpl_id: GenericTplId) -> Option { + let param = self.generic_param_mut(tpl_id)?; + param.is_const = true; + Some(param.clone()) + } } #[derive(Debug, Clone)] @@ -63,36 +71,38 @@ impl GenericIndex for FileGenericIndex { scope_id } - fn append_generic_param(&mut self, scope_id: GenericScopeId, param: GenericParam) { + fn append_generic_param( + &mut self, + scope_id: GenericScopeId, + param: GenericParam, + ) -> Option { if let Some(scope) = self.scopes.get_mut(scope_id.id) { - scope.insert_param(param); - } - } - - fn append_generic_params(&mut self, scope_id: GenericScopeId, params: Vec) { - for param in params { - self.append_generic_param(scope_id, param); + return Some(scope.insert_param(param)); } + None } /// Find generic parameter by position and name. - /// return (GenericTplId, constraint, default) - fn find_generic( - &self, - position: TextSize, - name: &str, - ) -> Option<(GenericTplId, Option, Option)> { + fn find_generic(&self, position: TextSize, name: &str) -> Option<(GenericTplId, GenericParam)> { for scope in self.scopes.iter().rev() { if !scope.contains(position) { continue; } if let Some((id, param)) = scope.params.get(name) { - return Some(( - *id, - param.type_constraint.clone(), - param.default_type.clone(), - )); + return Some((*id, param.clone())); + } + } + + None + } + + fn generic_param_mut(&mut self, tpl_id: GenericTplId) -> Option<&mut GenericParam> { + for scope in self.scopes.iter_mut().rev() { + for (id, param) in scope.params.values_mut() { + if *id == tpl_id { + return Some(param); + } } } @@ -131,10 +141,11 @@ impl FileGenericScope { self.next_tpl_id.is_func() } - fn insert_param(&mut self, param: GenericParam) { + fn insert_param(&mut self, param: GenericParam) -> GenericTplId { let tpl_id = self.next_tpl_id; self.next_tpl_id = self.next_tpl_id.with_idx((tpl_id.get_idx() + 1) as u32); self.params.insert(param.name.to_string(), (tpl_id, param)); + tpl_id } fn contains(&self, position: TextSize) -> bool { @@ -175,18 +186,19 @@ impl ConditionalInferIndex { let tpl_id = GenericTplId::ConditionalInfer(self.next_infer_id); self.next_infer_id += 1; + let param = GenericParam::new(SmolStr::new(name), None, None, false, None); let tpl = Arc::new(GenericTpl::new( tpl_id, - SmolStr::new(name).into(), - None, - None, + param.name.clone(), + param.constraint.clone(), + param.default.clone(), + param.is_const, + param.attributes.clone(), )); let scope = &mut self.scopes[scope_idx]; scope.bindings.insert(name.to_string(), tpl.clone()); - scope - .params - .push(GenericParam::new(SmolStr::new(name), None, None, None)); + scope.params.push(param); Some(tpl) } diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/infer_type.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/infer_type.rs index fe2ff27d1..32414eede 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/infer_type.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/infer_type.rs @@ -1,11 +1,11 @@ use std::sync::Arc; use emmylua_parser::{ - LuaAst, LuaAstNode, LuaComment, LuaDocAttributeType, LuaDocBinaryType, LuaDocConditionalType, - LuaDocDescriptionOwner, LuaDocFuncType, LuaDocGenericDecl, LuaDocGenericDeclList, - LuaDocGenericType, LuaDocIndexAccessType, LuaDocMappedType, LuaDocMultiLineUnionType, - LuaDocObjectFieldKey, LuaDocObjectType, LuaDocStrTplType, LuaDocType, LuaDocUnaryType, - LuaDocVariadicType, LuaLiteralToken, LuaSyntaxKind, LuaTypeBinaryOperator, + LuaAst, LuaAstNode, LuaClosureExpr, LuaComment, LuaDocAttributeType, LuaDocBinaryType, + LuaDocConditionalType, LuaDocDescriptionOwner, LuaDocFuncType, LuaDocGenericDecl, + LuaDocGenericDeclList, LuaDocGenericType, LuaDocIndexAccessType, LuaDocMappedType, + LuaDocMultiLineUnionType, LuaDocObjectFieldKey, LuaDocObjectType, LuaDocStrTplType, LuaDocType, + LuaDocUnaryType, LuaDocVariadicType, LuaLiteralToken, LuaSyntaxKind, LuaTypeBinaryOperator, LuaTypeUnaryOperator, LuaVarExpr, NumberResult, }; use rowan::TextRange; @@ -13,8 +13,8 @@ use smol_str::SmolStr; use crate::{ AsyncState, DiagnosticCode, FileId, GenericParam, GenericTpl, InFiled, LuaAliasCallKind, - LuaArrayLen, LuaArrayType, LuaAttributeType, LuaMultiLineUnion, LuaTupleStatus, LuaTypeDeclId, - TypeOps, VariadicType, complete_type_generic_args, + LuaArrayLen, LuaArrayType, LuaAttributeType, LuaMultiLineUnion, LuaSignatureId, LuaTupleStatus, + LuaTypeDeclId, TypeOps, VariadicType, complete_type_generic_args, db_index::{ AnalyzeError, DbIndex, LuaAliasCallType, LuaConditionalType, LuaFunctionType, LuaGenericType, LuaIndexAccessKey, LuaIntersectionType, LuaMappedType, LuaObjectType, @@ -109,6 +109,70 @@ impl<'a> DocTypeAnalyzeContext<'a> { .add_type_reference(self.file_id, type_id, range); } } + + // TODO: 为`std.ConstTpl`实现的兼容性代码, 应在下一版本中移除 + fn mark_generic_const(&mut self, tpl: &GenericTpl) -> GenericTpl { + let tpl_id = tpl.get_tpl_id(); + let param = self + .generic_index + .mark_generic_const(tpl_id) + .unwrap_or_else(|| { + let mut param = tpl.get_param().clone(); + param.is_const = true; + param + }); + + if tpl_id.is_func() + && let Some(signature_id) = self.current_signature_id() + && let Some(signature) = self.db.get_signature_index_mut().get_mut(&signature_id) + { + if let Some(signature_param) = signature.generic_params.get_mut(tpl_id.get_idx()) { + signature_param.is_const = true; + } + + for overload in &mut signature.overloads { + let mut generic_params = overload.get_generic_params().to_vec(); + let mut changed = false; + for generic_param in &mut generic_params { + if generic_param.get_tpl_id() == tpl_id && !generic_param.is_const() { + *generic_param = generic_param.with_const(true); + changed = true; + } + } + + if changed { + *overload = Arc::new(LuaFunctionType::new( + overload.get_async_state(), + overload.is_colon_define(), + overload.is_variadic(), + overload.get_params().to_vec(), + overload.get_ret().clone(), + Some(generic_params), + )); + } + } + } + + GenericTpl::new( + tpl_id, + param.name, + param.constraint, + param.default, + true, + param.attributes, + ) + } + + fn current_signature_id(&self) -> Option { + let owner = self.comment.as_ref()?.get_owner()?; + let closure = match owner { + LuaAst::LuaFuncStat(func) => func.get_closure(), + LuaAst::LuaLocalFuncStat(local_func) => local_func.get_closure(), + owner => owner.descendants::().next(), + }?; + + Some(LuaSignatureId::from_closure(self.file_id, &closure)) + } } pub fn infer_type(analyzer: &mut DocTypeAnalyzeContext<'_>, node: LuaDocType) -> LuaType { @@ -256,14 +320,14 @@ fn infer_buildin_or_ref_type( return LuaType::TplRef(tpl); } - if let Some((tpl_id, constraint, default_type)) = - analyzer.generic_index.find_generic(position, name) - { + if let Some((tpl_id, param)) = analyzer.generic_index.find_generic(position, name) { return LuaType::TplRef(Arc::new(GenericTpl::new( tpl_id, - SmolStr::new(name).into(), - constraint, - default_type, + param.name, + param.constraint, + param.default, + param.is_const, + param.attributes, ))); } @@ -484,7 +548,8 @@ fn infer_special_generic_type( let first_doc_param_type = generic_type.get_generic_types()?.get_types().next()?; let first_param = infer_type(analyzer, first_doc_param_type); if let LuaType::TplRef(tpl) = first_param { - return Some(LuaType::ConstTplRef(tpl)); + let const_tpl = analyzer.mark_generic_const(&tpl); + return Some(LuaType::TplRef(Arc::new(const_tpl))); } } "Language" => { @@ -628,9 +693,11 @@ fn infer_unary_type( } fn infer_func_type(analyzer: &mut DocTypeAnalyzeContext<'_>, func: &LuaDocFuncType) -> LuaType { - if let Some(generic_list) = func.get_generic_decl_list() { - register_inline_func_generics(analyzer, func, generic_list); - } + let generic_params = if let Some(generic_list) = func.get_generic_decl_list() { + register_inline_func_generics(analyzer, func, generic_list) + } else { + Vec::new() + }; let mut params_result = Vec::new(); let mut is_variadic = false; @@ -711,6 +778,7 @@ fn infer_func_type(analyzer: &mut DocTypeAnalyzeContext<'_>, func: &LuaDocFuncTy is_variadic, params_result, return_type, + Some(generic_params), ) .into(), ) @@ -720,10 +788,11 @@ fn register_inline_func_generics( analyzer: &mut DocTypeAnalyzeContext<'_>, func: &LuaDocFuncType, generic_list: LuaDocGenericDeclList, -) { +) -> Vec { let scope_id = analyzer .generic_index .add_generic_scope(vec![func.get_range()], true); + let mut generic_params = Vec::new(); for param in generic_list.get_generic_decl() { let Some(name_token) = param.get_name_token() else { continue; @@ -733,16 +802,28 @@ fn register_inline_func_generics( .get_constraint_type() .map(|ty| infer_type(analyzer, ty)); let default_type = param.get_default_type().map(|ty| infer_type(analyzer, ty)); - analyzer.generic_index.append_generic_param( - scope_id, - GenericParam::new( - SmolStr::new(name_token.get_name_text()), - constraint, - default_type, - None, - ), + let generic_param = GenericParam::new( + SmolStr::new(name_token.get_name_text()), + constraint, + default_type, + param.has_const_modifier(), + None, ); + if let Some(tpl_id) = analyzer + .generic_index + .append_generic_param(scope_id, generic_param.clone()) + { + generic_params.push(GenericTpl::new( + tpl_id, + generic_param.name, + generic_param.constraint, + generic_param.default, + generic_param.is_const, + generic_param.attributes, + )); + } } + generic_params } fn get_colon_define(analyzer: &mut DocTypeAnalyzeContext<'_>) -> Option { @@ -963,7 +1044,13 @@ fn infer_mapped_type( let constraint = generic_decl .get_constraint_type() .map(|constraint| infer_type(analyzer, constraint)); - let param = GenericParam::new(SmolStr::new(name), constraint, None, None); + let param = GenericParam::new( + SmolStr::new(name), + constraint, + None, + generic_decl.has_const_modifier(), + None, + ); let scope_id = analyzer .generic_index @@ -972,7 +1059,7 @@ fn infer_mapped_type( .generic_index .append_generic_param(scope_id, param.clone()); let position = mapped_type.get_range().start(); - let (id, _, _) = analyzer.generic_index.find_generic(position, name)?; + let (id, _) = analyzer.generic_index.find_generic(position, name)?; let doc_type = mapped_type.get_value_type()?; let value_type = infer_type(analyzer, doc_type); diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/property_tags.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/property_tags.rs index 5e4e53a7a..8dfea9613 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/property_tags.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/property_tags.rs @@ -5,10 +5,10 @@ use crate::{ use super::{ DocAnalyzer, - tags::{find_owner_closure_or_report, get_owner_id_or_report}, + tags::{find_owner_closure_or_report, get_owner_id, get_owner_id_or_report}, }; use emmylua_parser::{ - LuaAst, LuaAstNode, LuaDocDescriptionOwner, LuaDocTagAsync, LuaDocTagDeprecated, + LuaAst, LuaAstNode, LuaDocDescriptionOwner, LuaDocTag, LuaDocTagAsync, LuaDocTagDeprecated, LuaDocTagNodiscard, LuaDocTagReadonly, LuaDocTagSource, LuaDocTagVersion, LuaDocTagVisibility, LuaExpr, }; @@ -105,15 +105,81 @@ pub fn analyze_deprecated(analyzer: &mut DocAnalyzer, tag: LuaDocTagDeprecated) let message = tag .get_description() .map(|desc| desc.get_description_text().to_string()); + + let mut type_owner_id = None; + if let Some(current_type_id) = &analyzer.current_type_id { + type_owner_id = Some(LuaSemanticDeclId::TypeDecl(current_type_id.clone())); + } else { + let file_id = analyzer.file_id; + let workspace_id = analyzer.workspace_id; + let tags = analyzer.comment.get_doc_tags(); + for tag in tags { + match tag { + LuaDocTag::Class(class) => { + if let Some(name_token) = class.get_name_token() { + let name = name_token.get_name_text().to_string(); + if let Some(decl) = analyzer.get_db().get_type_index().find_type_decl( + file_id, + &name, + Some(workspace_id), + ) { + if decl.is_class() { + type_owner_id = Some(LuaSemanticDeclId::TypeDecl(decl.get_id())); + break; + } + } + } + } + LuaDocTag::Alias(alias) => { + if let Some(name_token) = alias.get_name_token() { + let name = name_token.get_name_text().to_string(); + if let Some(decl) = analyzer.get_db().get_type_index().find_type_decl( + file_id, + &name, + Some(workspace_id), + ) { + if decl.is_alias() { + type_owner_id = Some(LuaSemanticDeclId::TypeDecl(decl.get_id())); + break; + } + } + } + } + _ => {} + } + } + } + + if let Some(type_owner_id) = type_owner_id { + add_deprecated(analyzer, type_owner_id, message.clone())?; + let mut compat_owner_id = None; + if let Some(owner) = get_owner_id(analyzer, None, true) { + if let owner @ (LuaSemanticDeclId::LuaDecl(_) | LuaSemanticDeclId::Member(_)) = owner { + compat_owner_id = Some(owner); + } + } + if let Some(compat_owner_id) = compat_owner_id { + add_deprecated(analyzer, compat_owner_id, message)?; + } + return Some(()); + } + let owner_id = get_owner_id_or_report(analyzer, &tag)?; + add_deprecated(analyzer, owner_id, message)?; + + Some(()) +} +fn add_deprecated( + analyzer: &mut DocAnalyzer, + owner_id: LuaSemanticDeclId, + message: Option, +) -> Option<()> { analyzer .type_context .db .get_property_index_mut() - .add_deprecated(analyzer.file_id, owner_id, message); - - Some(()) + .add_deprecated(analyzer.file_id, owner_id, message) } pub fn analyze_version(analyzer: &mut DocAnalyzer, version: LuaDocTagVersion) -> Option<()> { diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_def_tags.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_def_tags.rs index d63ed8290..afd82a7e5 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_def_tags.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_def_tags.rs @@ -10,15 +10,13 @@ use smol_str::SmolStr; use super::{ DocAnalyzer, infer_type::infer_type, preprocess_description, tags::find_owner_closure, }; -use crate::GenericParam; use crate::compilation::analyzer::doc::tags::report_orphan_tag; use crate::{ DbIndex, LuaTypeCache, LuaTypeDeclId, compilation::analyzer::common::bind_type, - db_index::{ - LuaDeclId, LuaGenericParamInfo, LuaMemberId, LuaSemanticDeclId, LuaSignatureId, LuaType, - }, + db_index::{LuaDeclId, LuaMemberId, LuaSemanticDeclId, LuaSignatureId, LuaType}, }; +use crate::{GenericParam, LuaFunctionType}; use std::{collections::HashSet, sync::Arc, vec}; pub fn analyze_class(analyzer: &mut DocAnalyzer, tag: LuaDocTagClass) -> Option<()> { @@ -384,8 +382,7 @@ pub fn analyze_func_generic(analyzer: &mut DocAnalyzer, tag: LuaDocTagGeneric) - let Some(name_token) = param.get_name_token() else { continue; }; - let name_text = name_token.get_name_text().to_string(); - let smol_name = SmolStr::new(name_text.as_str()); + let smol_name = SmolStr::new(name_token.get_name_text()); let type_ref = param .get_constraint_type() @@ -394,22 +391,18 @@ pub fn analyze_func_generic(analyzer: &mut DocAnalyzer, tag: LuaDocTagGeneric) - .get_default_type() .map(|type_ref| infer_type(&mut analyzer.type_context, type_ref)); - analyzer.type_context.generic_index.append_generic_param( - scope_id, - GenericParam::new( - smol_name.clone(), - type_ref.clone(), - default_type.clone(), - None, - ), - ); - - param_info.push(Arc::new(LuaGenericParamInfo::new( - name_text, + let generic_param = GenericParam::new( + smol_name, type_ref, default_type, + param.has_const_modifier(), None, - ))); + ); + analyzer + .type_context + .generic_index + .append_generic_param(scope_id, generic_param.clone()); + param_info.push(generic_param); } } @@ -420,6 +413,26 @@ pub fn analyze_func_generic(analyzer: &mut DocAnalyzer, tag: LuaDocTagGeneric) - .get_signature_index_mut() .get_or_create(signature_id); signature.generic_params = param_info; + let signature_generic_params = signature.get_function_generic_params(); + for overload in &mut signature.overloads { + let mut generic_params = signature_generic_params.clone(); + for generic_param in overload.get_generic_params() { + if !generic_params + .iter() + .any(|tpl| tpl.get_tpl_id() == generic_param.get_tpl_id()) + { + generic_params.push(generic_param.clone()); + } + } + *overload = Arc::new(LuaFunctionType::new( + overload.get_async_state(), + overload.is_colon_define(), + overload.is_variadic(), + overload.get_params().to_vec(), + overload.get_ret().clone(), + Some(generic_params), + )); + } Some(()) } diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_generic_header.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_generic_header.rs index ed8601db4..8d9048e5c 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_generic_header.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_generic_header.rs @@ -59,13 +59,14 @@ fn normalize_generic_params(db: &DbIndex, params: &[GenericParam]) -> Vec Option { + let scope = self.scopes.get_mut(scope_id.id)?; let tpl_id = scope.next_tpl_id; scope.next_tpl_id = scope.next_tpl_id.with_idx((tpl_id.get_idx() + 1) as u32); scope.params.push((tpl_id, param)); + Some(tpl_id) } - fn find_generic( - &self, - position: TextSize, - name: &str, - ) -> Option<(GenericTplId, Option, Option)> { + fn find_generic(&self, position: TextSize, name: &str) -> Option<(GenericTplId, GenericParam)> { for scope in self.scopes.iter().rev() { if !scope.contains(position) { continue; @@ -181,11 +187,17 @@ impl GenericIndex for HeaderGenericIndex { .rev() .find(|(_, param)| param.name == name) { - return Some(( - *tpl_id, - param.type_constraint.clone(), - param.default_type.clone(), - )); + return Some((*tpl_id, param.clone())); + } + } + + None + } + + fn generic_param_mut(&mut self, tpl_id: GenericTplId) -> Option<&mut GenericParam> { + for scope in self.scopes.iter_mut().rev() { + if let Some((_, param)) = scope.params.iter_mut().find(|(id, _)| *id == tpl_id) { + return Some(param); } } diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_ref_tags.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_ref_tags.rs index de0cc3320..3a1c8a821 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_ref_tags.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_ref_tags.rs @@ -4,6 +4,7 @@ use emmylua_parser::{ LuaDocTagReturnCast, LuaDocTagReturnOverload, LuaDocTagSchema, LuaDocTagSee, LuaDocTagType, LuaExpr, LuaLocalName, LuaTokenKind, LuaVarExpr, }; +use std::sync::Arc; use super::{ DocAnalyzer, @@ -12,8 +13,8 @@ use super::{ tags::{find_owner_closure, get_owner_id_or_report}, }; use crate::{ - InFiled, JsonSchemaFile, LuaOperatorMetaMethod, LuaTypeCache, LuaTypeOwner, OperatorFunction, - SignatureReturnStatus, TypeOps, + InFiled, JsonSchemaFile, LuaFunctionType, LuaOperatorMetaMethod, LuaTypeCache, LuaTypeOwner, + OperatorFunction, SignatureReturnStatus, TypeOps, compilation::analyzer::common::bind_type, db_index::{ LuaDeclId, LuaDocParamInfo, LuaDocReturnInfo, LuaDocReturnOverloadInfo, LuaMemberId, @@ -375,6 +376,23 @@ pub fn analyze_overload(analyzer: &mut DocAnalyzer, tag: LuaDocTagOverload) -> O .db .get_signature_index_mut() .get_or_create(id); + let mut generic_params = signature.get_function_generic_params(); + for generic_param in func.get_generic_params() { + if !generic_params + .iter() + .any(|tpl| tpl.get_tpl_id() == generic_param.get_tpl_id()) + { + generic_params.push(generic_param.clone()); + } + } + let func = Arc::new(LuaFunctionType::new( + func.get_async_state(), + func.is_colon_define(), + func.is_variadic(), + func.get_params().to_vec(), + func.get_ret().clone(), + Some(generic_params), + )); signature.overloads.push(func); } } diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/lua/for_range_stat.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/lua/for_range_stat.rs index d6605830b..8ece392f4 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/lua/for_range_stat.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/lua/for_range_stat.rs @@ -3,7 +3,7 @@ use emmylua_parser::{LuaAstToken, LuaExpr, LuaForRangeStat}; use crate::{ DbIndex, InferFailReason, LuaDeclId, LuaInferCache, LuaOperatorMetaMethod, LuaType, LuaTypeCache, TplContext, TypeOps, TypeSubstitutor, VariadicType, - compilation::analyzer::unresolve::UnResolveIterVar, infer_expr, instantiate_doc_function, + compilation::analyzer::unresolve::UnResolveIterVar, infer_expr, instantiate_type_generic, tpl_pattern_match_args, }; @@ -145,6 +145,12 @@ pub fn infer_for_range_iter_expr_func( return Ok(doc_function.get_variadic_ret()); }; let mut substitutor = TypeSubstitutor::new(); + let generic_tpls = doc_function + .get_generic_params() + .iter() + .map(|generic_tpl| generic_tpl.get_tpl_id()) + .collect(); + substitutor.add_need_infer_tpls(generic_tpls); let mut context = TplContext { db, cache, @@ -159,8 +165,9 @@ pub fn infer_for_range_iter_expr_func( tpl_pattern_match_args(&mut context, ¶ms, &[status_param])?; + let doc_function_type = LuaType::DocFunction(doc_function.clone()); let instantiate_func = if let LuaType::DocFunction(f) = - instantiate_doc_function(db, &doc_function, &substitutor) + instantiate_type_generic(db, &doc_function_type, &substitutor) { f } else { diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/unresolve/resolve_closure.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/unresolve/resolve_closure.rs index 120588356..c37561eb1 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/unresolve/resolve_closure.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/unresolve/resolve_closure.rs @@ -409,6 +409,7 @@ fn resolve_closure_member_type( signature.is_vararg, final_params, final_ret, + Some(signature.get_function_generic_params()), ), self_type, ) diff --git a/crates/emmylua_code_analysis/src/compilation/test/generic_test.rs b/crates/emmylua_code_analysis/src/compilation/test/generic_test.rs index 3a1b462b9..b85997659 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/generic_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/generic_test.rs @@ -764,10 +764,7 @@ mod test { .expect("Box generic params"); assert_eq!(box_params.len(), 1); assert_eq!(box_params[0].name.as_str(), "T"); - let box_default = box_params[0] - .default_type - .clone() - .expect("Box default type"); + let box_default = box_params[0].default.clone().expect("Box default type"); assert_eq!(ws.humanize_type(box_default), "string"); let optional_params = ws @@ -780,7 +777,7 @@ mod test { assert_eq!(optional_params.len(), 1); assert_eq!(optional_params[0].name.as_str(), "T"); let optional_default = optional_params[0] - .default_type + .default .clone() .expect("Optional default type"); assert_eq!(ws.humanize_type(optional_default), "number"); @@ -810,12 +807,105 @@ mod test { assert_eq!(signature.generic_params.len(), 1); assert_eq!(signature.generic_params[0].name, "T"); let default_type = signature.generic_params[0] - .default_type + .default .clone() .expect("signature default type"); assert_eq!(ws.humanize_type(default_type), "string"); } + #[test] + fn test_generic_const_metadata_storage() { + let mut ws = VirtualWorkspace::new(); + let file_id = ws.def( + r#" + ---@class Box + + ---@generic const R, S + ---@return R + local function id() + end + + ---@alias Mapper fun(value: A): B + "#, + ); + + let db = ws.analysis.compilation.get_db(); + let box_params = db + .get_type_index() + .get_generic_params(&LuaTypeDeclId::global("Box")) + .expect("Box generic params"); + assert_eq!(box_params.len(), 2); + assert_eq!(box_params[0].name.as_str(), "T"); + assert!(box_params[0].is_const); + assert_eq!(box_params[1].name.as_str(), "U"); + assert!(!box_params[1].is_const); + + let closure = ws.get_node::(file_id); + let signature_id = LuaSignatureId::from_closure(file_id, &closure); + let signature = db + .get_signature_index() + .get(&signature_id) + .expect("signature"); + assert_eq!(signature.generic_params.len(), 2); + assert_eq!(signature.generic_params[0].name.as_str(), "R"); + assert!(signature.generic_params[0].is_const); + assert_eq!(signature.generic_params[1].name.as_str(), "S"); + assert!(!signature.generic_params[1].is_const); + + let function_generic_params = signature.get_function_generic_params(); + assert!(function_generic_params[0].is_const()); + assert!(!function_generic_params[1].is_const()); + + let mapper_decl = db + .get_type_index() + .get_type_decl(&LuaTypeDeclId::global("Mapper")) + .expect("Mapper alias"); + let mapper_origin = mapper_decl.get_alias_ref().expect("Mapper alias origin"); + let LuaType::DocFunction(mapper_func) = mapper_origin else { + panic!("expected Mapper alias to be a function type"); + }; + let mapper_generic_params = mapper_func.get_generic_params(); + assert_eq!(mapper_generic_params.len(), 2); + assert_eq!(mapper_generic_params[0].get_name(), "A"); + assert!(mapper_generic_params[0].is_const()); + assert_eq!(mapper_generic_params[1].get_name(), "B"); + assert!(!mapper_generic_params[1].is_const()); + } + + #[test] + fn test_legacy_const_tpl_marks_generic_param_metadata() { + let mut ws = VirtualWorkspace::new(); + let file_id = ws.def( + r#" + ---@alias std.ConstTpl unknown + + ---@generic T + ---@param value std.ConstTpl + ---@return T + function id(value) + end + + result = id(1) + "#, + ); + + let closure = ws.get_node::(file_id); + let signature_id = LuaSignatureId::from_closure(file_id, &closure); + { + let signature = ws + .analysis + .compilation + .get_db() + .get_signature_index() + .get(&signature_id) + .expect("signature"); + assert_eq!(signature.generic_params.len(), 1); + assert!(signature.generic_params[0].is_const); + } + + assert_eq!(ws.expr_ty("result"), LuaType::IntegerConst(1)); + } + #[test] fn test_bare_generic_type_uses_default() { let mut ws = VirtualWorkspace::new(); @@ -954,7 +1044,7 @@ mod test { .get_type_index() .get_generic_params(&LuaTypeDeclId::global("B")) .expect("B generic params"); - let default_type = b_params[0].default_type.clone().expect("B default type"); + let default_type = b_params[0].default.clone().expect("B default type"); assert_eq!(ws.humanize_type(default_type), "A"); } @@ -982,7 +1072,7 @@ mod test { .get_type_index() .get_generic_params(&LuaTypeDeclId::global("B")) .expect("B generic params"); - let default_type = b_params[0].default_type.clone().expect("B default type"); + let default_type = b_params[0].default.clone().expect("B default type"); assert_eq!(ws.humanize_type(default_type), "A"); } @@ -1223,8 +1313,6 @@ mod test { r#" ---@alias std.RawGet unknown - ---@alias std.ConstTpl unknown - ---@generic T, K extends keyof T ---@param object T ---@param key K diff --git a/crates/emmylua_code_analysis/src/db_index/operators/lua_operator.rs b/crates/emmylua_code_analysis/src/db_index/operators/lua_operator.rs index de64a4739..26538fbe9 100644 --- a/crates/emmylua_code_analysis/src/db_index/operators/lua_operator.rs +++ b/crates/emmylua_code_analysis/src/db_index/operators/lua_operator.rs @@ -139,6 +139,7 @@ impl LuaOperator { ("arg0".to_string(), Some(param.clone())), ], ret.clone(), + None, ) .into(), OperatorFunction::UnOp { ret } => LuaFunctionType::new( @@ -147,6 +148,7 @@ impl LuaOperator { false, vec![("self".to_string(), Some(LuaType::SelfInfer))], ret.clone(), + None, ) .into(), OperatorFunction::Call { params, ret } => { @@ -165,8 +167,15 @@ impl LuaOperator { }) .collect(); - LuaFunctionType::new(AsyncState::None, false, is_variadic, params, ret.clone()) - .into() + LuaFunctionType::new( + AsyncState::None, + false, + is_variadic, + params, + ret.clone(), + None, + ) + .into() } OperatorFunction::Overload(func) => { LuaType::DocFunction(func.to_call_operator_func_type()) @@ -183,6 +192,7 @@ impl LuaOperator { signature.is_vararg, signature.get_type_params(), get_constructor_return_type(signature, return_mode), + Some(signature.get_function_generic_params()), ) .into(), None => LuaType::Signature(*id), diff --git a/crates/emmylua_code_analysis/src/db_index/signature/mod.rs b/crates/emmylua_code_analysis/src/db_index/signature/mod.rs index 46983b02c..25e03f603 100644 --- a/crates/emmylua_code_analysis/src/db_index/signature/mod.rs +++ b/crates/emmylua_code_analysis/src/db_index/signature/mod.rs @@ -7,8 +7,8 @@ use hashbrown::{HashMap, HashSet}; pub use async_state::AsyncState; pub use signature::{ - LuaDocParamInfo, LuaDocReturnInfo, LuaDocReturnOverloadInfo, LuaGenericParamInfo, LuaNoDiscard, - LuaSignature, LuaSignatureId, SignatureReturnStatus, + LuaDocParamInfo, LuaDocReturnInfo, LuaDocReturnOverloadInfo, LuaNoDiscard, LuaSignature, + LuaSignatureId, SignatureReturnStatus, }; use crate::FileId; diff --git a/crates/emmylua_code_analysis/src/db_index/signature/signature.rs b/crates/emmylua_code_analysis/src/db_index/signature/signature.rs index d2a6b818d..0271c3213 100644 --- a/crates/emmylua_code_analysis/src/db_index/signature/signature.rs +++ b/crates/emmylua_code_analysis/src/db_index/signature/signature.rs @@ -9,7 +9,7 @@ use rowan::TextSize; use super::return_rows; use crate::db_index::signature::async_state::AsyncState; use crate::{ - FileId, + FileId, GenericParam, GenericTpl, GenericTplId, db_index::{LuaFunctionType, LuaType}, }; use crate::{ @@ -19,7 +19,7 @@ use crate::{ #[derive(Debug)] pub struct LuaSignature { - pub generic_params: Vec>, + pub generic_params: Vec, pub overloads: Vec>, pub param_docs: HashMap, pub params: Vec, @@ -172,6 +172,7 @@ impl LuaSignature { is_vararg, params, return_type, + Some(self.get_function_generic_params()), ); Arc::new(func_type) } @@ -183,10 +184,33 @@ impl LuaSignature { } let return_type = self.get_return_type(); - let func_type = - LuaFunctionType::new(self.async_state, false, self.is_vararg, params, return_type); + let func_type = LuaFunctionType::new( + self.async_state, + false, + self.is_vararg, + params, + return_type, + Some(self.get_function_generic_params()), + ); Arc::new(func_type) } + + pub fn get_function_generic_params(&self) -> Vec { + self.generic_params + .iter() + .enumerate() + .map(|(idx, param)| { + GenericTpl::new( + GenericTplId::Func(idx as u32), + param.name.clone(), + param.constraint.clone(), + param.default.clone(), + param.is_const, + param.attributes.clone(), + ) + }) + .collect() + } } #[derive(Debug)] @@ -306,27 +330,3 @@ pub enum SignatureReturnStatus { DocResolve, InferResolve, } - -#[derive(Debug, Clone)] -pub struct LuaGenericParamInfo { - pub name: String, - pub constraint: Option, - pub default_type: Option, - pub attributes: Option>, -} - -impl LuaGenericParamInfo { - pub fn new( - name: String, - constraint: Option, - default_type: Option, - attributes: Option>, - ) -> Self { - Self { - name, - constraint, - default_type, - attributes, - } - } -} diff --git a/crates/emmylua_code_analysis/src/db_index/type/generic_param.rs b/crates/emmylua_code_analysis/src/db_index/type/generic_param.rs index 1a66b2031..1a0895af8 100644 --- a/crates/emmylua_code_analysis/src/db_index/type/generic_param.rs +++ b/crates/emmylua_code_analysis/src/db_index/type/generic_param.rs @@ -5,22 +5,25 @@ use crate::{LuaAttributeUse, LuaType}; #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct GenericParam { pub name: SmolStr, - pub type_constraint: Option, - pub default_type: Option, + pub constraint: Option, + pub default: Option, + pub is_const: bool, pub attributes: Option>, } impl GenericParam { pub fn new( name: SmolStr, - type_constraint: Option, - default_type: Option, + constraint: Option, + default: Option, + is_const: bool, attributes: Option>, ) -> Self { Self { name, - type_constraint, - default_type, + constraint, + default, + is_const, attributes, } } diff --git a/crates/emmylua_code_analysis/src/db_index/type/humanize_type.rs b/crates/emmylua_code_analysis/src/db_index/type/humanize_type.rs index f26732263..1103d689c 100644 --- a/crates/emmylua_code_analysis/src/db_index/type/humanize_type.rs +++ b/crates/emmylua_code_analysis/src/db_index/type/humanize_type.rs @@ -215,7 +215,6 @@ impl<'a> TypeHumanizer<'a> { self.level = saved; w.write_char('>') } - LuaType::ConstTplRef(const_tpl) => w.write_str(const_tpl.get_name()), LuaType::Language(s) => w.write_str(s), LuaType::Conditional(c) => self.write_conditional_type(c, w), LuaType::Never => w.write_str("never"), diff --git a/crates/emmylua_code_analysis/src/db_index/type/types/complex.rs b/crates/emmylua_code_analysis/src/db_index/type/types/complex.rs index 1e7750c37..f3e1c0bc9 100644 --- a/crates/emmylua_code_analysis/src/db_index/type/types/complex.rs +++ b/crates/emmylua_code_analysis/src/db_index/type/types/complex.rs @@ -5,7 +5,9 @@ use smol_str::SmolStr; use std::{ops::Deref, sync::Arc}; use crate::db_index::LuaMemberKey; -use crate::{AsyncState, DbIndex, InFiled, SemanticModel, first_param_may_not_self}; +use crate::{ + AsyncState, DbIndex, InFiled, LuaAttributeUse, SemanticModel, first_param_may_not_self, +}; use super::super::basic_union::{BasicTypeKind, BasicTypeUnion}; use super::super::generic_param::GenericParam; @@ -107,6 +109,7 @@ pub struct LuaFunctionType { async_state: AsyncState, is_colon_define: bool, is_variadic: bool, + generic_params: Option>, params: Vec<(String, Option)>, ret: LuaType, } @@ -118,11 +121,14 @@ impl LuaFunctionType { is_variadic: bool, params: Vec<(String, Option)>, ret: LuaType, + generic_params: Option>, ) -> Self { + let generic_params = generic_params.filter(|params| !params.is_empty()); Self { async_state, is_colon_define, is_variadic, + generic_params, params, ret, } @@ -140,6 +146,10 @@ impl LuaFunctionType { &self.params } + pub fn get_generic_params(&self) -> &[GenericTpl] { + self.generic_params.as_deref().unwrap_or(&[]) + } + pub fn get_ret(&self) -> &LuaType { &self.ret } @@ -213,6 +223,7 @@ impl LuaFunctionType { self.is_variadic, params, self.ret.clone(), + self.generic_params.clone(), )) } } @@ -745,26 +756,24 @@ impl GenericTplId { } } -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, Hash, PartialEq, Eq)] pub struct GenericTpl { tpl_id: GenericTplId, - name: ArcIntern, - constraint: Option, - default_type: Option, + param: GenericParam, } impl GenericTpl { pub fn new( tpl_id: GenericTplId, - name: ArcIntern, + name: SmolStr, constraint: Option, default_type: Option, + is_const: bool, + attributes: Option>, ) -> Self { Self { tpl_id, - name, - constraint, - default_type, + param: GenericParam::new(name, constraint, default_type, is_const, attributes), } } @@ -772,16 +781,33 @@ impl GenericTpl { self.tpl_id } + pub fn get_param(&self) -> &GenericParam { + &self.param + } + pub fn get_name(&self) -> &str { - &self.name + self.param.name.as_str() + } + + pub fn is_const(&self) -> bool { + self.param.is_const + } + + pub fn with_const(&self, is_const: bool) -> Self { + let mut param = self.param.clone(); + param.is_const = is_const; + Self { + tpl_id: self.tpl_id, + param, + } } pub fn get_constraint(&self) -> Option<&LuaType> { - self.constraint.as_ref() + self.param.constraint.as_ref() } pub fn get_default_type(&self) -> Option<&LuaType> { - self.default_type.as_ref() + self.param.default.as_ref() } } diff --git a/crates/emmylua_code_analysis/src/db_index/type/types/lua_type.rs b/crates/emmylua_code_analysis/src/db_index/type/types/lua_type.rs index 1ae5acb1d..d30e8c881 100644 --- a/crates/emmylua_code_analysis/src/db_index/type/types/lua_type.rs +++ b/crates/emmylua_code_analysis/src/db_index/type/types/lua_type.rs @@ -57,7 +57,6 @@ pub enum LuaType { Call(Arc), MultiLineUnion(Arc), TypeGuard(Arc), - ConstTplRef(Arc), Language(ArcIntern), ModuleRef(FileId), DocAttribute(Arc), @@ -110,7 +109,6 @@ impl PartialEq for LuaType { (LuaType::MultiLineUnion(a), LuaType::MultiLineUnion(b)) => a == b, (LuaType::TypeGuard(a), LuaType::TypeGuard(b)) => a == b, (LuaType::Never, LuaType::Never) => true, - (LuaType::ConstTplRef(a), LuaType::ConstTplRef(b)) => a == b, (LuaType::Language(a), LuaType::Language(b)) => a == b, (LuaType::ModuleRef(a), LuaType::ModuleRef(b)) => a == b, (LuaType::DocAttribute(a), LuaType::DocAttribute(b)) => a == b, @@ -168,12 +166,11 @@ impl Hash for LuaType { LuaType::MultiLineUnion(a) => (43, Arc::as_ptr(a)).hash(state), LuaType::TypeGuard(a) => (44, Arc::as_ptr(a)).hash(state), LuaType::Never => 45.hash(state), - LuaType::ConstTplRef(a) => (46, Arc::as_ptr(a)).hash(state), - LuaType::Language(a) => (47, a).hash(state), - LuaType::ModuleRef(a) => (48, a).hash(state), - LuaType::Conditional(a) => (49, Arc::as_ptr(a)).hash(state), - LuaType::Mapped(a) => (50, Arc::as_ptr(a)).hash(state), - LuaType::DocAttribute(a) => (51, a).hash(state), + LuaType::Language(a) => (46, a).hash(state), + LuaType::ModuleRef(a) => (47, a).hash(state), + LuaType::Conditional(a) => (48, Arc::as_ptr(a)).hash(state), + LuaType::Mapped(a) => (49, Arc::as_ptr(a)).hash(state), + LuaType::DocAttribute(a) => (50, a).hash(state), } } } diff --git a/crates/emmylua_code_analysis/src/db_index/type/types/predicates.rs b/crates/emmylua_code_analysis/src/db_index/type/types/predicates.rs index 37cfc40c9..edcb11819 100644 --- a/crates/emmylua_code_analysis/src/db_index/type/types/predicates.rs +++ b/crates/emmylua_code_analysis/src/db_index/type/types/predicates.rs @@ -261,9 +261,10 @@ impl LuaType { match ty { LuaType::TplRef(_) | LuaType::StrTplRef(_) - | LuaType::ConstTplRef(_) | LuaType::SelfInfer - | LuaType::Mapped(_) => return true, + | LuaType::Mapped(_) => { + return true; + } _ => ty.push_direct_children(&mut stack), } } diff --git a/crates/emmylua_code_analysis/src/db_index/type/types/test.rs b/crates/emmylua_code_analysis/src/db_index/type/types/test.rs index ec63747de..f0ec046f6 100644 --- a/crates/emmylua_code_analysis/src/db_index/type/types/test.rs +++ b/crates/emmylua_code_analysis/src/db_index/type/types/test.rs @@ -1,6 +1,6 @@ #[cfg(test)] mod tests { - use internment::ArcIntern; + use smol_str::SmolStr; use std::mem::ManuallyDrop; @@ -25,8 +25,10 @@ mod tests { let mut ty = LuaType::TplRef( GenericTpl::new( GenericTplId::Type(0), - ArcIntern::new(SmolStr::new("T")), + SmolStr::new("T"), + None, None, + false, None, ) .into(), diff --git a/crates/emmylua_code_analysis/src/db_index/type/types/traverse.rs b/crates/emmylua_code_analysis/src/db_index/type/types/traverse.rs index 9cb73bdcc..c40bd0cf0 100644 --- a/crates/emmylua_code_analysis/src/db_index/type/types/traverse.rs +++ b/crates/emmylua_code_analysis/src/db_index/type/types/traverse.rs @@ -49,7 +49,6 @@ pub trait LuaTypeNode { ty, LuaType::TplRef(_) | LuaType::StrTplRef(_) - | LuaType::ConstTplRef(_) | LuaType::SelfInfer | LuaType::Mapped(_) ) @@ -62,7 +61,6 @@ pub trait LuaTypeNode { ty, LuaType::TplRef(_) | LuaType::StrTplRef(_) - | LuaType::ConstTplRef(_) | LuaType::SelfInfer | LuaType::Mapped(_) ) @@ -244,10 +242,10 @@ impl LuaTypeNode for LuaConditionalType { impl LuaTypeNode for LuaMappedType { fn push_direct_children<'a>(&'a self, stack: &mut Vec<&'a LuaType>) { stack.push(&self.value); - if let Some(constraint) = self.param.1.type_constraint.as_ref() { + if let Some(constraint) = self.param.1.constraint.as_ref() { stack.push(constraint); } - if let Some(default_type) = self.param.1.default_type.as_ref() { + if let Some(default_type) = self.param.1.default.as_ref() { stack.push(default_type); } } diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/call_non_callable.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/call_non_callable.rs index 2da6fa2de..6aa335139 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/checker/call_non_callable.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/checker/call_non_callable.rs @@ -150,7 +150,7 @@ fn has_non_callable_member(db: &DbIndex, typ: &LuaType) -> bool { LuaType::Any | LuaType::Unknown | LuaType::SelfInfer | LuaType::Global | LuaType::Nil => { false } - LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) => tpl + LuaType::TplRef(tpl) => tpl .get_constraint() .is_some_and(|constraint| has_non_callable_member(db, constraint)), LuaType::StrTplRef(str_tpl) => str_tpl diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/deprecated.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/deprecated.rs index 96af17edd..e4b9f9268 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/checker/deprecated.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/checker/deprecated.rs @@ -1,4 +1,4 @@ -use emmylua_parser::{LuaAst, LuaAstNode, LuaIndexExpr, LuaNameExpr}; +use emmylua_parser::{LuaAst, LuaAstNode, LuaDocNameType, LuaIndexExpr, LuaNameExpr}; use crate::{ DiagnosticCode, LuaDeclId, LuaDeprecated, LuaMemberId, LuaSemanticDeclId, SemanticDeclLevel, @@ -22,6 +22,9 @@ impl Checker for DeprecatedChecker { LuaAst::LuaIndexExpr(index_expr) => { check_index_expr(context, semantic_model, index_expr); } + LuaAst::LuaDocNameType(name_type) => { + check_doc_name_type(context, semantic_model, name_type); + } _ => {} } } @@ -74,6 +77,32 @@ fn check_index_expr( Some(()) } +fn check_doc_name_type( + context: &mut DiagnosticContext, + semantic_model: &SemanticModel, + name_type: LuaDocNameType, +) -> Option<()> { + let semantic_decl = semantic_model.find_decl( + rowan::NodeOrToken::Node(name_type.syntax().clone()), + SemanticDeclLevel::default(), + )?; + + let LuaSemanticDeclId::TypeDecl(_) = &semantic_decl else { + return Some(()); + }; + + if let Some(deprecated_message) = get_deprecated_message(semantic_model, &semantic_decl) { + context.add_diagnostic( + DiagnosticCode::Deprecated, + name_type.get_range(), + deprecated_message, + None, + ); + } + + Some(()) +} + fn check_deprecated( context: &mut DiagnosticContext, semantic_model: &SemanticModel, @@ -87,14 +116,11 @@ fn check_deprecated( let Some(property) = property else { return; }; - if let Some(deprecated) = property.deprecated() { - let deprecated_message = match deprecated { - LuaDeprecated::Deprecated => "deprecated".to_string(), - LuaDeprecated::DeprecatedWithMessage(message) => message.to_string(), - }; + if let Some(deprecated_message) = get_deprecated_message(semantic_model, semantic_decl) { context.add_diagnostic(DiagnosticCode::Deprecated, range, deprecated_message, None); } + // 检查特性 if let Some(attribute_uses) = property.attribute_uses() { for attribute_use in attribute_uses.iter() { @@ -105,3 +131,23 @@ fn check_deprecated( } } } + +fn get_deprecated_message( + semantic_model: &SemanticModel, + semantic_decl: &LuaSemanticDeclId, +) -> Option { + let property = semantic_model + .get_db() + .get_property_index() + .get_property(semantic_decl); + let property = property?; + if let Some(deprecated) = property.deprecated() { + let deprecated_message = match deprecated { + LuaDeprecated::Deprecated => "deprecated".to_string(), + LuaDeprecated::DeprecatedWithMessage(message) => message.to_string(), + }; + return Some(deprecated_message); + } + + None +} diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/generic/generic_constraint_mismatch.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/generic/generic_constraint_mismatch.rs index cc5594eff..ce67c0cb6 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/checker/generic/generic_constraint_mismatch.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/checker/generic/generic_constraint_mismatch.rs @@ -103,7 +103,7 @@ fn check_doc_tag_class( .get_generic_params(&type_decl.get_id())?; let generic_param_types = generic_params .iter() - .map(|param| (param.type_constraint.clone(), param.default_type.clone())) + .map(|param| (param.constraint.clone(), param.default.clone())) .collect::>(); check_generic_decl_defaults( context, @@ -133,7 +133,7 @@ fn check_doc_tag_alias( .get_generic_params(&type_decl.get_id())?; let generic_param_types = generic_params .iter() - .map(|param| (param.type_constraint.clone(), param.default_type.clone())) + .map(|param| (param.constraint.clone(), param.default.clone())) .collect::>(); check_generic_decl_defaults( context, @@ -158,7 +158,7 @@ fn check_doc_tag_generic( let generic_param_types = signature .generic_params .iter() - .map(|param| (param.constraint.clone(), param.default_type.clone())) + .map(|param| (param.constraint.clone(), param.default.clone())) .collect::>(); check_generic_decl_defaults( context, @@ -467,7 +467,7 @@ fn check_variadic_default_satisfies_constraint( fn generic_tpl_id(ty: &LuaType) -> Option { match ty { - LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) => Some(tpl.get_tpl_id()), + LuaType::TplRef(tpl) => Some(tpl.get_tpl_id()), LuaType::StrTplRef(str_tpl) => Some(str_tpl.get_tpl_id()), _ => None, } @@ -475,7 +475,7 @@ fn generic_tpl_id(ty: &LuaType) -> Option { fn generic_upper_bound(ty: &LuaType) -> Option<&LuaType> { match ty { - LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) => tpl.get_constraint(), + LuaType::TplRef(tpl) => tpl.get_constraint(), LuaType::StrTplRef(str_tpl) => str_tpl.get_constraint(), _ => None, } @@ -491,7 +491,7 @@ fn instantiate_decl_default_for_check(ty: &LuaType) -> LuaType { fn instantiate_decl_type_for_check(ty: &LuaType, use_generic_upper_bound: bool) -> LuaType { match ty { - LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) => { + LuaType::TplRef(tpl) => { if use_generic_upper_bound && let Some(constraint) = tpl.get_constraint() { return instantiate_decl_default_for_check(constraint); } @@ -638,7 +638,7 @@ fn check_doc_tag_type( .take(explicit_args.len()) .enumerate() { - let extend_type = generic_params.get(i)?.type_constraint.clone()?; + let extend_type = generic_params.get(i)?.constraint.clone()?; let result = semantic_model.type_check_detail(&extend_type, param_type); if result.is_err() { add_type_check_diagnostic( @@ -702,7 +702,7 @@ fn check_param( extend_type, ); } - LuaType::TplRef(tpl_ref) | LuaType::ConstTplRef(tpl_ref) => { + LuaType::TplRef(tpl_ref) => { let extend_type = tpl_ref.get_constraint().cloned().map(|ty| { normalize_constraint_type( semantic_model.get_db(), diff --git a/crates/emmylua_code_analysis/src/diagnostic/test/deprecated_test.rs b/crates/emmylua_code_analysis/src/diagnostic/test/deprecated_test.rs new file mode 100644 index 000000000..5dfb44715 --- /dev/null +++ b/crates/emmylua_code_analysis/src/diagnostic/test/deprecated_test.rs @@ -0,0 +1,174 @@ +#[cfg(test)] +mod test { + use emmylua_parser::{LuaAstNode, LuaLocalName}; + + use crate::{DiagnosticCode, LuaDeclId, LuaSemanticDeclId, VirtualWorkspace}; + + fn assert_type_decl_deprecated(content: &str, name: &str) { + let mut ws = VirtualWorkspace::new(); + let file_id = ws.def(content); + let db = ws.analysis.compilation.get_db(); + let type_decl = db + .get_type_index() + .find_type_decl(file_id, name, db.resolve_workspace_id(file_id)) + .expect("type declaration must exist"); + let property = db + .get_property_index() + .get_property(&LuaSemanticDeclId::TypeDecl(type_decl.get_id())) + .expect("type declaration property must exist"); + + assert!(property.deprecated().is_some()); + } + + fn assert_lua_decl_deprecated(content: &str, name: &str) { + let mut ws = VirtualWorkspace::new(); + let file_id = ws.def(content); + let db = ws.analysis.compilation.get_db(); + let local_name = ws.get_node::(file_id); + assert_eq!(local_name.get_text(), name); + let decl = db + .get_decl_index() + .get_decl(&LuaDeclId::new(file_id, local_name.get_position())) + .expect("declaration must exist"); + let property = db + .get_property_index() + .get_property(&LuaSemanticDeclId::LuaDecl(decl.get_id())) + .expect("declaration property must exist"); + + assert!(property.deprecated().is_some()); + } + + #[test] + fn test_deprecated_alias_use() { + let mut ws = VirtualWorkspace::new(); + + assert!(ws.has_no_diagnostic( + DiagnosticCode::Deprecated, + r#" + ---@deprecated test + ---@alias std.ConstTpl unknown + "# + )); + } + + #[test] + fn test_deprecated_alias_no_usage_error() { + let mut ws = VirtualWorkspace::new(); + + assert!(ws.has_no_diagnostic( + DiagnosticCode::AnnotationUsageError, + r#" + ---@deprecated test + ---@alias std.ConstTpl unknown + "# + )); + } + + #[test] + fn test_deprecated_alias_attaches_to_type_decl() { + assert_type_decl_deprecated( + r#" + ---@deprecated test + ---@alias ConstTpl unknown + "#, + "ConstTpl", + ); + } + + #[test] + fn test_deprecated_alias_after_alias_attaches_to_type_decl() { + assert_type_decl_deprecated( + r#" + ---@alias ConstTpl unknown + ---@deprecated test + "#, + "ConstTpl", + ); + } + + #[test] + fn test_deprecated_class_no_usage_error() { + let mut ws = VirtualWorkspace::new(); + + assert!(ws.has_no_diagnostic( + DiagnosticCode::AnnotationUsageError, + r#" + ---@deprecated test + ---@class Foo + "# + )); + } + + #[test] + fn test_deprecated_class_attaches_to_type_decl() { + assert_type_decl_deprecated( + r#" + ---@deprecated test + ---@class Foo + local Foo = {} + "#, + "Foo", + ); + } + + #[test] + fn test_deprecated_class_usage_diagnostic() { + let mut ws = VirtualWorkspace::new(); + + assert!(!ws.has_no_diagnostic( + DiagnosticCode::Deprecated, + r#" + ---@deprecated test + ---@class Foo + local Foo = {} + + local x = Foo + "# + )); + } + + #[test] + fn test_deprecated_class_type_annotation_diagnostic() { + let mut ws = VirtualWorkspace::new(); + + assert!(!ws.has_no_diagnostic( + DiagnosticCode::Deprecated, + r#" + ---@deprecated + ---@class A + + ---@type A + local a + "# + )); + } + + #[test] + fn test_deprecated_class_param_annotation_diagnostic() { + let mut ws = VirtualWorkspace::new(); + + assert!(!ws.has_no_diagnostic( + DiagnosticCode::Deprecated, + r#" + ---@deprecated + ---@class A + + ---@param a A + local function f(a) + end + "# + )); + } + + #[test] + fn test_deprecated_class_after_class_attaches_to_decl() { + assert_lua_decl_deprecated( + r#" + ---@class Foo + ---@deprecated test + local Foo = {} + "#, + "Foo", + ); + } +} diff --git a/crates/emmylua_code_analysis/src/diagnostic/test/mod.rs b/crates/emmylua_code_analysis/src/diagnostic/test/mod.rs index 97bcd11c9..ad4e896c1 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/test/mod.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/test/mod.rs @@ -5,6 +5,7 @@ mod call_non_callable_test; mod cast_type_mismatch_test; mod check_return_count_test; mod code_style; +mod deprecated_test; mod disable_line_test; mod duplicate_field_test; mod duplicate_index_test; diff --git a/crates/emmylua_code_analysis/src/semantic/generic/call_constraint.rs b/crates/emmylua_code_analysis/src/semantic/generic/call_constraint.rs index 1ad1c18ef..9f1532f5f 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/call_constraint.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/call_constraint.rs @@ -5,11 +5,12 @@ use hashbrown::HashSet; use rowan::TextRange; use crate::{ - DbIndex, DocTypeInferContext, GenericTpl, GenericTplId, LuaFunctionType, LuaSemanticDeclId, - LuaType, LuaTypeNode, SemanticDeclLevel, SemanticModel, TypeOps, TypeSubstitutor, VariadicType, - infer_doc_type, + DbIndex, DocTypeInferContext, GenericTplId, LuaFunctionType, LuaSemanticDeclId, LuaType, + SemanticDeclLevel, SemanticModel, TypeOps, TypeSubstitutor, VariadicType, infer_doc_type, }; +use super::{TplContext, tpl_pattern_match_args}; + // 泛型约束上下文 pub struct CallConstraintContext { pub params: Vec<(String, Option)>, @@ -31,7 +32,12 @@ pub fn build_call_constraint_context( let mut params = doc_func.get_params().to_vec(); let mut args = get_arg_infos(semantic_model, call_expr)?; let mut substitutor = TypeSubstitutor::new(); - let generic_tpls = collect_func_tpl_ids(¶ms); + let generic_tpls = doc_func + .get_generic_params() + .iter() + .map(|generic_tpl| generic_tpl.get_tpl_id()) + .filter(GenericTplId::is_func) + .collect::>(); if !generic_tpls.is_empty() { substitutor.add_need_infer_tpls(generic_tpls); } @@ -65,7 +71,22 @@ pub fn build_call_constraint_context( } } - collect_generic_assignments(&mut substitutor, ¶ms, &args); + // 使用模式匹配推导泛型 + let mut cache = semantic_model.get_cache().borrow_mut(); + let mut context = TplContext { + db: semantic_model.get_db(), + cache: &mut cache, + substitutor: &mut substitutor, + call_expr: Some(call_expr.clone()), + }; + + let param_types: Vec = params + .iter() + .map(|(_, ty)| ty.clone().unwrap_or(LuaType::Unknown)) + .collect(); + let arg_types: Vec = args.iter().map(|arg| arg.check_type.clone()).collect(); + + let _ = tpl_pattern_match_args(&mut context, ¶m_types, &arg_types); Some(CallConstraintContext { params, @@ -82,205 +103,6 @@ pub fn normalize_constraint_type(db: &DbIndex, ty: LuaType) -> LuaType { } } -// 收集各个参数对应的泛型推导 -fn collect_generic_assignments( - substitutor: &mut TypeSubstitutor, - params: &[(String, Option)], - args: &[CallConstraintArg], -) { - for (idx, (_, param_type)) in params.iter().enumerate() { - let Some(param_type) = param_type else { - continue; - }; - let Some(arg) = args.get(idx) else { - continue; - }; - record_generic_assignment(param_type, &arg.check_type, substitutor); - } -} - -fn collect_func_tpl_ids(params: &[(String, Option)]) -> HashSet { - let mut generic_tpls = HashSet::new(); - for (_, param_type) in params { - let Some(param_type) = param_type else { - continue; - }; - collect_func_tpls_from_param_type(param_type, &mut generic_tpls); - } - - generic_tpls -} - -fn collect_func_tpls_from_param_type(ty: &LuaType, generic_tpls: &mut HashSet) { - collect_func_tpl_from_param_node(ty, generic_tpls); - ty.visit_nested_types(&mut |ty| { - collect_func_tpl_from_param_node(ty, generic_tpls); - }); -} - -fn collect_func_tpl_from_param_node(ty: &LuaType, generic_tpls: &mut HashSet) { - match ty { - LuaType::TplRef(generic_tpl) | LuaType::ConstTplRef(generic_tpl) => { - collect_func_tpl_with_fallback_deps(generic_tpl, generic_tpls); - } - LuaType::StrTplRef(str_tpl) => { - let tpl_id = str_tpl.get_tpl_id(); - if tpl_id.is_func() { - generic_tpls.insert(tpl_id); - if let Some(constraint) = str_tpl.get_constraint() { - let mut constraint_deps = HashSet::new(); - if collect_func_tpl_deps_from_fallback_type( - constraint, - &mut constraint_deps, - &mut HashSet::new(), - ) { - generic_tpls.extend(constraint_deps); - } - } - } - } - _ => {} - } -} - -fn collect_func_tpl_with_fallback_deps( - generic_tpl: &GenericTpl, - generic_tpls: &mut HashSet, -) { - let tpl_id = generic_tpl.get_tpl_id(); - if !tpl_id.is_func() { - return; - } - - generic_tpls.insert(tpl_id); - - let Some(fallback_type) = generic_tpl - .get_default_type() - .or(generic_tpl.get_constraint()) - else { - return; - }; - - let mut fallback_deps = HashSet::new(); - let mut visiting_fallbacks = HashSet::new(); - visiting_fallbacks.insert(tpl_id); - if collect_func_tpl_deps_from_fallback_type( - fallback_type, - &mut fallback_deps, - &mut visiting_fallbacks, - ) { - generic_tpls.extend(fallback_deps); - } -} - -fn collect_func_tpl_deps_from_fallback_type( - ty: &LuaType, - generic_tpls: &mut HashSet, - visiting_fallbacks: &mut HashSet, -) -> bool { - let mut no_fallback_cycle = - collect_func_tpl_dep_from_fallback_type(ty, generic_tpls, visiting_fallbacks); - ty.visit_nested_types(&mut |ty| { - no_fallback_cycle &= - collect_func_tpl_dep_from_fallback_type(ty, generic_tpls, visiting_fallbacks); - }); - no_fallback_cycle -} - -fn collect_func_tpl_dep_from_fallback_type( - ty: &LuaType, - generic_tpls: &mut HashSet, - visiting_fallbacks: &mut HashSet, -) -> bool { - match ty { - LuaType::TplRef(generic_tpl) | LuaType::ConstTplRef(generic_tpl) => { - collect_generic_tpl_from_fallback(generic_tpl, generic_tpls, visiting_fallbacks) - } - LuaType::StrTplRef(str_tpl) => { - let tpl_id = str_tpl.get_tpl_id(); - if !tpl_id.is_func() { - return true; - } - - if !visiting_fallbacks.insert(tpl_id) { - return false; - } - - generic_tpls.insert(tpl_id); - let no_fallback_cycle = match str_tpl.get_constraint() { - Some(constraint) => collect_func_tpl_deps_from_fallback_type( - constraint, - generic_tpls, - visiting_fallbacks, - ), - None => true, - }; - visiting_fallbacks.remove(&tpl_id); - no_fallback_cycle - } - _ => true, - } -} - -fn collect_generic_tpl_from_fallback( - generic_tpl: &GenericTpl, - generic_tpls: &mut HashSet, - visiting_fallbacks: &mut HashSet, -) -> bool { - let tpl_id = generic_tpl.get_tpl_id(); - if !tpl_id.is_func() { - return true; - } - - if !visiting_fallbacks.insert(tpl_id) { - return false; - } - - generic_tpls.insert(tpl_id); - let no_fallback_cycle = match generic_tpl - .get_default_type() - .or(generic_tpl.get_constraint()) - { - Some(fallback_type) => collect_func_tpl_deps_from_fallback_type( - fallback_type, - generic_tpls, - visiting_fallbacks, - ), - None => true, - }; - visiting_fallbacks.remove(&tpl_id); - no_fallback_cycle -} - -// 实际写入泛型替换表 -fn record_generic_assignment( - param_type: &LuaType, - arg_type: &LuaType, - substitutor: &mut TypeSubstitutor, -) { - match param_type { - LuaType::TplRef(tpl_ref) => { - if !tpl_ref.get_tpl_id().is_conditional_infer() { - substitutor.insert_type(tpl_ref.get_tpl_id(), arg_type.clone(), true); - } - } - LuaType::ConstTplRef(tpl_ref) => { - if !tpl_ref.get_tpl_id().is_conditional_infer() { - substitutor.insert_type(tpl_ref.get_tpl_id(), arg_type.clone(), false); - } - } - LuaType::StrTplRef(str_tpl_ref) => { - substitutor.insert_type(str_tpl_ref.get_tpl_id(), arg_type.clone(), true); - } - LuaType::Variadic(variadic) => { - if let Some(inner) = variadic.get_type(0) { - record_generic_assignment(inner, arg_type, substitutor); - } - } - _ => {} - } -} - // 解析冒号调用时调用者的具体类型 fn infer_call_source_type( semantic_model: &SemanticModel, @@ -350,7 +172,7 @@ fn infer_call_source_type( None } -// 推导每个实参类型 +// 推推导每个实参类型 fn get_arg_infos( semantic_model: &SemanticModel, call_expr: &LuaCallExpr, @@ -407,9 +229,7 @@ fn get_constraint_type( depth: usize, ) -> Option { match arg_type { - LuaType::TplRef(tpl_ref) | LuaType::ConstTplRef(tpl_ref) => { - tpl_ref.get_constraint().cloned() - } + LuaType::TplRef(tpl_ref) => tpl_ref.get_constraint().cloned(), LuaType::StrTplRef(str_tpl_ref) => str_tpl_ref.get_constraint().cloned(), LuaType::Union(union_type) => { if depth > 1 { @@ -453,40 +273,3 @@ fn infer_expr_list_types( } value_types } - -#[cfg(test)] -mod tests { - use std::sync::Arc; - - use hashbrown::HashSet; - use smol_str::SmolStr; - - use super::*; - - fn func_tpl(idx: u32, default_type: Option) -> Arc { - Arc::new(GenericTpl::new( - GenericTplId::Func(idx), - SmolStr::new(format!("T{}", idx)).into(), - None, - default_type, - )) - } - - #[test] - fn test_collect_func_tpl_with_fallback_deps_skips_cyclic_fallback_deps() { - let t0 = func_tpl(0, None); - let t1 = func_tpl(1, Some(LuaType::TplRef(t0.clone()))); - let t0 = GenericTpl::new( - GenericTplId::Func(0), - SmolStr::new("T0").into(), - None, - Some(LuaType::TplRef(t1)), - ); - - let mut generic_tpls = HashSet::new(); - collect_func_tpl_with_fallback_deps(&t0, &mut generic_tpls); - - assert!(generic_tpls.contains(&GenericTplId::Func(0))); - assert!(!generic_tpls.contains(&GenericTplId::Func(1))); - } -} diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs b/crates/emmylua_code_analysis/src/semantic/generic/infer_call_generic.rs similarity index 75% rename from crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs rename to crates/emmylua_code_analysis/src/semantic/generic/infer_call_generic.rs index 4a8eaff7c..544c51818 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/infer_call_generic.rs @@ -12,7 +12,6 @@ use crate::{ semantic::{ LuaInferCache, generic::{ - instantiate_type::instantiate_doc_function, tpl_context::TplContext, tpl_pattern::{ multi_param_tpl_pattern_match_multi_return, return_type_pattern_match_target_type, @@ -25,26 +24,25 @@ use crate::{ }, }; use crate::{ - GenericTpl, LuaMemberOwner, LuaSemanticDeclId, LuaTypeOwner, SemanticDeclLevel, TypeVisitTrait, + LuaMemberOwner, LuaSemanticDeclId, LuaTypeOwner, SemanticDeclLevel, TypeVisitTrait, collect_callable_overload_groups, infer_node_semantic_decl, tpl_pattern_match_args_skip_unknown, }; -use super::{TypeSubstitutor, instantiate_type_generic}; +use crate::semantic::generic::{TypeSubstitutor, instantiate_type::instantiate_type_generic}; -pub fn instantiate_func_generic( +pub fn infer_call_generic( db: &DbIndex, cache: &mut LuaInferCache, func: &LuaFunctionType, call_expr: LuaCallExpr, ) -> Result { let file_id = cache.get_file_id().clone(); - let (generic_tpls, contain_self) = collect_func_tpl_ids(func); let origin_params = func.get_params(); - let mut func_params: Vec<_> = origin_params + let mut func_params: Vec = origin_params .iter() - .map(|(name, t)| (name.clone(), t.clone().unwrap_or(LuaType::Unknown))) + .map(|(_, t)| t.clone().unwrap_or(LuaType::Unknown)) .collect(); let arg_exprs = call_expr @@ -59,7 +57,18 @@ pub fn instantiate_func_generic( substitutor: &mut substitutor, call_expr: Some(call_expr.clone()), }; - if !generic_tpls.is_empty() { + + let has_func_generic = func + .get_generic_params() + .iter() + .any(|generic_tpl| generic_tpl.get_tpl_id().is_func()); + if has_func_generic { + let generic_tpls = func + .get_generic_params() + .iter() + .map(|generic_tpl| generic_tpl.get_tpl_id()) + .filter(GenericTplId::is_func) + .collect::>(); context.substitutor.add_need_infer_tpls(generic_tpls); if let Some(type_list) = call_expr.get_call_generic_type_list() { @@ -78,11 +87,13 @@ pub fn instantiate_func_generic( } } + let contain_self = func.any_nested_type(|ty| matches!(ty, LuaType::SelfInfer)); if contain_self && let Some(self_type) = infer_self_type(db, cache, &call_expr) { substitutor.add_self_type(self_type); } - if let LuaType::DocFunction(f) = instantiate_doc_function(db, func, &substitutor) { + let func_type = LuaType::DocFunction(func.clone().into()); + if let LuaType::DocFunction(f) = instantiate_type_generic(db, &func_type, &substitutor) { Ok(f.deref().clone()) } else { Ok(func.clone()) @@ -249,16 +260,21 @@ fn instantiate_callable_from_arg_types( return None; } - let mut callable_tpls = HashSet::new(); - callable.visit_nested_types(&mut |ty| { - if let LuaType::TplRef(generic_tpl) | LuaType::ConstTplRef(generic_tpl) = ty { - callable_tpls.insert(generic_tpl.get_tpl_id()); - } - }); - if callable_tpls.is_empty() { + let has_callable_tpls = callable + .get_generic_params() + .iter() + .any(|generic_tpl| generic_tpl.get_tpl_id().is_func()); + if !has_callable_tpls { return Some(callable.clone()); } + let callable_tpls = callable + .get_generic_params() + .iter() + .map(|generic_tpl| generic_tpl.get_tpl_id()) + .filter(GenericTplId::is_func) + .collect::>(); + let callable_param_types = callable .get_params() .iter() @@ -282,14 +298,16 @@ fn instantiate_callable_from_arg_types( return None; } - let instantiated = match instantiate_doc_function(context.db, callable, &callable_substitutor) { - LuaType::DocFunction(func) => func, - _ => callable.clone(), - }; + let callable_type = LuaType::DocFunction(callable.clone()); + let instantiated = + match instantiate_type_generic(context.db, &callable_type, &callable_substitutor) { + LuaType::DocFunction(func) => func, + _ => callable.clone(), + }; let unresolved_return_tpls = { let mut tpl_ids = HashSet::new(); instantiated.get_ret().visit_type(&mut |ty| { - if let LuaType::TplRef(generic_tpl) | LuaType::ConstTplRef(generic_tpl) = ty + if let LuaType::TplRef(generic_tpl) = ty && callable_tpls.contains(&generic_tpl.get_tpl_id()) { tpl_ids.insert(generic_tpl.get_tpl_id()); @@ -314,7 +332,7 @@ fn instantiate_callable_from_arg_types( for tpl_id in callback_return_tpls { callable_substitutor.insert_type(tpl_id, LuaType::Unknown, true); } - match instantiate_doc_function(context.db, callable, &callable_substitutor) { + match instantiate_type_generic(context.db, &callable_type, &callable_substitutor) { LuaType::DocFunction(func) => Some(func), _ => None, } @@ -360,7 +378,7 @@ fn collect_callback_return_tpls( continue; }; param_func.get_ret().visit_type(&mut |ty| { - if let LuaType::TplRef(generic_tpl) | LuaType::ConstTplRef(generic_tpl) = ty { + if let LuaType::TplRef(generic_tpl) = ty { let tpl_id = generic_tpl.get_tpl_id(); if unresolved_return_tpls.contains(&tpl_id) { callback_return_tpls.insert(tpl_id); @@ -372,120 +390,19 @@ fn collect_callback_return_tpls( callback_return_tpls } -fn collect_func_tpl_ids(func: &LuaFunctionType) -> (HashSet, bool) { - let mut generic_tpls = HashSet::new(); - let mut contain_self = false; - - func.visit_nested_types(&mut |ty| match ty { - LuaType::TplRef(generic_tpl) | LuaType::ConstTplRef(generic_tpl) => { - collect_func_tpl_with_fallback_deps(generic_tpl, &mut generic_tpls); - } - LuaType::StrTplRef(str_tpl) => { - generic_tpls.insert(str_tpl.get_tpl_id()); - } - LuaType::SelfInfer => contain_self = true, - _ => {} - }); - - (generic_tpls, contain_self) -} - -fn collect_func_tpl_with_fallback_deps( - generic_tpl: &GenericTpl, - generic_tpls: &mut HashSet, -) { - let tpl_id = generic_tpl.get_tpl_id(); - if !tpl_id.is_func() { - return; - } - - generic_tpls.insert(tpl_id); - - let Some(fallback_type) = generic_tpl - .get_default_type() - .or(generic_tpl.get_constraint()) - else { - return; - }; - - // 只有提前加入的泛型才有 None 占位, fallback 展开时才能继续使用它自己的 default/constraint. - // 例如 `U = T[]` 或 `U: T[]` 中, 即使函数返回值只直接引用了 `U`, 也需要把 `T` 一并加入. - let mut fallback_deps = HashSet::new(); - let mut visiting_fallbacks = HashSet::new(); - visiting_fallbacks.insert(tpl_id); - if collect_func_tpl_deps_from_fallback_type( - fallback_type, - &mut fallback_deps, - &mut visiting_fallbacks, - ) { - generic_tpls.extend(fallback_deps); - } -} - -fn collect_func_tpl_deps_from_fallback_type( - ty: &LuaType, - generic_tpls: &mut HashSet, - visiting_fallbacks: &mut HashSet, -) -> bool { - // 返回 false 表示 fallback 依赖链里发现循环. - // visit_nested_types 只访问子节点, 所以这里先处理类型自身, 再处理嵌套类型. - let mut no_fallback_cycle = - collect_func_tpl_dep_from_fallback_type(ty, generic_tpls, visiting_fallbacks); - ty.visit_nested_types(&mut |ty| { - no_fallback_cycle &= - collect_func_tpl_dep_from_fallback_type(ty, generic_tpls, visiting_fallbacks); - }); - no_fallback_cycle -} - -fn collect_func_tpl_dep_from_fallback_type( - ty: &LuaType, - generic_tpls: &mut HashSet, - visiting_fallbacks: &mut HashSet, -) -> bool { - let (LuaType::TplRef(generic_tpl) | LuaType::ConstTplRef(generic_tpl)) = ty else { - return true; - }; - - if !generic_tpl.get_tpl_id().is_func() { - return true; - } - - let tpl_id = generic_tpl.get_tpl_id(); - if !visiting_fallbacks.insert(tpl_id) { - // 遇到 `T = U, U = T` 这类循环 fallback 时, 放弃合并本轮依赖避免递归展开. - return false; - } - - generic_tpls.insert(tpl_id); - let no_fallback_cycle = match generic_tpl - .get_default_type() - .or(generic_tpl.get_constraint()) - { - Some(fallback_type) => collect_func_tpl_deps_from_fallback_type( - fallback_type, - generic_tpls, - visiting_fallbacks, - ), - None => true, - }; - visiting_fallbacks.remove(&tpl_id); - no_fallback_cycle -} - fn infer_generic_types_from_call( db: &DbIndex, context: &mut TplContext, func: &LuaFunctionType, call_expr: &LuaCallExpr, - func_params: &mut Vec<(String, LuaType)>, + func_params: &mut Vec, arg_exprs: &[LuaExpr], ) -> Result<(), InferFailReason> { let colon_call = call_expr.is_colon_call(); let colon_define = func.is_colon_define(); match (colon_define, colon_call) { (true, false) => { - func_params.insert(0, ("self".to_string(), LuaType::Any)); + func_params.insert(0, LuaType::Any); } (false, true) => { if !func_params.is_empty() { @@ -498,7 +415,7 @@ fn infer_generic_types_from_call( let mut unresolve_tpls = vec![]; for i in 0..func_params.len() { if i >= arg_exprs.len() { - if let LuaType::Variadic(variadic) = &func_params[i].1 { + if let LuaType::Variadic(variadic) = &func_params[i] { variadic_tpl_pattern_match(context, variadic, &[])?; } break; @@ -508,14 +425,16 @@ fn infer_generic_types_from_call( break; } - let (_, func_param_type) = &func_params[i]; + let func_param_type = &func_params[i]; let call_arg_expr = &arg_exprs[i]; if !func_param_type.contains_tpl_node() { continue; } + let doc_param_func = as_doc_function_type(db, func_param_type)?; + if !func_param_type.is_variadic() - && check_expr_can_later_infer(context, func_param_type, call_arg_expr)? + && check_expr_can_later_infer_with_doc_func(doc_param_func.as_deref(), call_arg_expr) { // 如果参数不能被后续推断, 那么我们先不处理 unresolve_tpls.push((func_param_type.clone(), call_arg_expr.clone())); @@ -528,19 +447,18 @@ fn infer_generic_types_from_call( Err(e) => return Err(e), }; - if let Some(return_pattern) = - as_doc_function_type(context.db, func_param_type)?.map(|func| func.get_ret().clone()) - { + if let Some(doc_func) = &doc_param_func { + let return_pattern = doc_func.get_ret(); if let Some(inferred_return_type) = infer_callable_return_from_remaining_args(context, &arg_type, &arg_exprs[i + 1..])? { return_type_pattern_match_target_type( context, - &return_pattern, + return_pattern, &inferred_return_type, )?; } else if arg_type.is_any() || arg_type.is_unknown() { - return_type_pattern_match_target_type(context, &return_pattern, &LuaType::Unknown)?; + return_type_pattern_match_target_type(context, return_pattern, &LuaType::Unknown)?; } } @@ -555,11 +473,7 @@ fn infer_generic_types_from_call( break; } (_, LuaType::Variadic(variadic)) => { - let func_param_types = func_params[i..] - .iter() - .map(|(_, t)| t) - .cloned() - .collect::>(); + let func_param_types = func_params[i..].to_vec(); multi_param_tpl_pattern_match_multi_return(context, &func_param_types, variadic)?; break; } @@ -607,9 +521,9 @@ fn build_self_generic_arg( substitutor: &TypeSubstitutor, ) -> LuaType { let Some(arg) = generic_param - .default_type + .default .as_ref() - .or(generic_param.type_constraint.as_ref()) + .or(generic_param.constraint.as_ref()) else { return LuaType::Unknown; }; @@ -666,30 +580,23 @@ pub fn infer_self_type( None } -fn check_expr_can_later_infer( - context: &mut TplContext, - func_param_type: &LuaType, +fn check_expr_can_later_infer_with_doc_func( + doc_function: Option<&LuaFunctionType>, call_arg_expr: &LuaExpr, -) -> Result { - let Some(doc_function) = as_doc_function_type(context.db, func_param_type)? else { - return Ok(false); +) -> bool { + let Some(doc_function) = doc_function else { + return false; }; if let LuaExpr::ClosureExpr(_) = call_arg_expr { - return Ok(true); + return true; } let doc_params = doc_function.get_params(); let variadic_count = doc_params .iter() - .filter_map(|(_, t)| { - if let Some(LuaType::Variadic(_)) = t { - Some(()) - } else { - None - } - }) + .filter(|(_, t)| matches!(t, Some(LuaType::Variadic(_)))) .count(); - Ok(variadic_count > 1) + variadic_count > 1 } diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/complete_generic_args.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/complete_generic_args.rs index 8ac1a2644..576612cf3 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/complete_generic_args.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/complete_generic_args.rs @@ -1,8 +1,8 @@ use hashbrown::HashSet; use crate::{ - DbIndex, GenericParam, GenericTplId, LuaAliasCallType, LuaArrayType, LuaAttributeType, - LuaConditionalType, LuaMappedType, LuaMultiLineUnion, LuaTypeDeclId, + DbIndex, GenericParam, GenericTpl, GenericTplId, LuaAliasCallType, LuaArrayType, + LuaAttributeType, LuaConditionalType, LuaMappedType, LuaMultiLineUnion, LuaTypeDeclId, db_index::{ LuaFunctionType, LuaGenericType, LuaIntersectionType, LuaObjectType, LuaTupleType, LuaType, LuaUnionType, VariadicType, @@ -101,7 +101,7 @@ fn complete_type_generic_args_inner( continue; } - if let Some(default_type) = &generic_param.default_type { + if let Some(default_type) = &generic_param.default { if missing_required_count != 0 { continue; } @@ -295,6 +295,7 @@ fn complete_doc_function( visiting: &mut HashSet, ) -> CompletedType { let mut cycled = false; + let generic_params = complete_function_generic_params(db, func, visiting, &mut cycled); let params = func .get_params() .iter() @@ -315,6 +316,7 @@ fn complete_doc_function( func.is_variadic(), params, ret.ty, + Some(generic_params), ) .into(), ), @@ -322,6 +324,31 @@ fn complete_doc_function( ) } +fn complete_function_generic_params( + db: &DbIndex, + func: &LuaFunctionType, + visiting: &mut HashSet, + cycled: &mut bool, +) -> Vec { + func.get_generic_params() + .iter() + .map(|generic_tpl| { + let tpl_id = generic_tpl.get_tpl_id(); + let param = generic_tpl.get_param(); + let completed = complete_generic_param(db, param, visiting); + *cycled |= completed.cycled; + GenericTpl::new( + tpl_id, + completed.param.name, + completed.param.constraint, + completed.param.default, + completed.param.is_const, + completed.param.attributes, + ) + }) + .collect() +} + fn complete_object_type( db: &DbIndex, object: &LuaObjectType, @@ -529,11 +556,11 @@ fn complete_generic_param( visiting: &mut HashSet, ) -> CompletedGenericParam { let constraint = param - .type_constraint + .constraint .as_ref() .map(|ty| complete_type_generic_args_in_type_inner(db, ty, visiting)); let default_type = param - .default_type + .default .as_ref() .map(|ty| complete_type_generic_args_in_type_inner(db, ty, visiting)); let cycled = constraint.as_ref().is_some_and(|ty| ty.cycled) @@ -543,6 +570,7 @@ fn complete_generic_param( param.name.clone(), constraint.map(|ty| ty.ty), default_type.map(|ty| ty.ty), + param.is_const, param.attributes.clone(), ), cycled, diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_conditional_generic.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_conditional_generic.rs index c0c64bbdf..5ea42ce91 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_conditional_generic.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_conditional_generic.rs @@ -2,13 +2,13 @@ use hashbrown::{HashMap, HashSet}; use std::ops::Deref; use crate::{ - DbIndex, GenericTplId, LuaConditionalType, LuaTypeDeclId, LuaTypeNode, TypeOps, + DbIndex, GenericTpl, GenericTplId, LuaConditionalType, LuaTypeDeclId, LuaTypeNode, TypeOps, check_type_compact, db_index::{LuaObjectType, LuaTupleType, LuaType}, semantic::{member::find_members_with_key, type_check::check_type_compact_with_level}, }; -use super::{get_default_constructor, instantiate_type_generic_with_context}; +use super::{get_default_constructor, instantiate_type_generic_inner}; use crate::semantic::generic::type_substitutor::GenericInstantiateContext; #[derive(Debug, Clone, Copy)] @@ -80,19 +80,18 @@ fn instantiate_conditional_once( finalize_infer_assignments(infer_assignments), ) } else { - instantiate_type_generic_with_context(context, conditional.get_false_type()) + instantiate_type_generic_inner(context, conditional.get_false_type()) }; } match check_conditional_extends(context.db, &left_type, &right_type) { ConditionalCheck::True => instantiate_true_branch(context, conditional, HashMap::new()), ConditionalCheck::False => { - instantiate_type_generic_with_context(context, conditional.get_false_type()) + instantiate_type_generic_inner(context, conditional.get_false_type()) } ConditionalCheck::Both => { let true_type = instantiate_true_branch(context, conditional, HashMap::new()); - let false_type = - instantiate_type_generic_with_context(context, conditional.get_false_type()); + let false_type = instantiate_type_generic_inner(context, conditional.get_false_type()); TypeOps::Union.apply(context.db, &true_type, &false_type) } } @@ -125,9 +124,7 @@ fn instantiate_distributed_conditional( fn naked_checked_type_tpl_id(checked_type: &LuaType) -> Option { match checked_type { - LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) if tpl.get_tpl_id().is_type() => { - Some(tpl.get_tpl_id()) - } + LuaType::TplRef(tpl) if tpl.get_tpl_id().is_type() => Some(tpl.get_tpl_id()), _ => None, } } @@ -152,7 +149,7 @@ fn instantiate_true_branch( infer_assignments: HashMap, ) -> LuaType { if infer_assignments.is_empty() { - return instantiate_type_generic_with_context(context, conditional.get_true_type()); + return instantiate_type_generic_inner(context, conditional.get_true_type()); } let mut true_substitutor = context.substitutor.clone(); @@ -160,7 +157,7 @@ fn instantiate_true_branch( true_substitutor.insert_conditional_infer_type(tpl_id, ty); } let true_context = context.with_substitutor(&true_substitutor); - instantiate_type_generic_with_context(&true_context, conditional.get_true_type()) + instantiate_type_generic_inner(&true_context, conditional.get_true_type()) } fn contains_conditional_infer(ty: &LuaType) -> bool { @@ -170,7 +167,7 @@ fn contains_conditional_infer(ty: &LuaType) -> bool { fn conditional_infer_tpl_id(ty: &LuaType) -> bool { matches!( ty, - LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) + LuaType::TplRef(tpl) if tpl.get_tpl_id().is_conditional_infer() ) } @@ -257,9 +254,7 @@ fn collect_infer_assignments( variance: InferVariance, ) -> bool { match pattern { - LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) - if tpl.get_tpl_id().is_conditional_infer() => - { + LuaType::TplRef(tpl) if tpl.get_tpl_id().is_conditional_infer() => { insert_infer_assignment(db, assignments, tpl.get_tpl_id(), source, variance) } LuaType::Generic(pattern_generic) => { @@ -650,8 +645,8 @@ fn instantiate_conditional_operand( checked: bool, has_new: bool, ) -> LuaType { - let mut result = instantiate_type_generic_with_context(context, operand); - if let LuaType::TplRef(tpl_ref) | LuaType::ConstTplRef(tpl_ref) = operand { + let mut result = instantiate_type_generic_inner(context, operand); + if let LuaType::TplRef(tpl_ref) = operand { let tpl_id = tpl_ref.get_tpl_id(); if let Some(raw) = context.substitutor.get_raw_type(tpl_id) { result = raw.clone(); @@ -678,7 +673,7 @@ fn instantiate_conditional_operand( // `infer` pattern 也以模板引用表示, 必须保留下来供后续结构匹配绑定. fn actualize_unresolved_templates(ty: LuaType) -> LuaType { match ty { - LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) => { + LuaType::TplRef(tpl) => { if tpl.get_tpl_id().is_conditional_infer() { // Conditional infer 是右侧 pattern 的占位孔, 不能像普通未解模板一样抹成 unknown. LuaType::TplRef(tpl) @@ -718,6 +713,7 @@ fn actualize_unresolved_templates(ty: LuaType) -> LuaType { }) .collect(), actualize_unresolved_templates(func.get_ret().clone()), + Some(actualize_function_generic_params(&func)), ) .into(), ), @@ -818,3 +814,21 @@ fn actualize_unresolved_templates(ty: LuaType) -> LuaType { ty => ty, } } + +fn actualize_function_generic_params(func: &crate::LuaFunctionType) -> Vec { + func.get_generic_params() + .iter() + .map(|generic_tpl| { + let tpl_id = generic_tpl.get_tpl_id(); + let param = generic_tpl.get_param(); + GenericTpl::new( + tpl_id, + param.name.clone(), + param.constraint.clone().map(actualize_unresolved_templates), + param.default.clone().map(actualize_unresolved_templates), + param.is_const, + param.attributes.clone(), + ) + }) + .collect() +} diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_special_generic.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_special_generic.rs index 57fae183d..02d367f55 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_special_generic.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_special_generic.rs @@ -10,7 +10,7 @@ use crate::{ use hashbrown::HashMap; use std::{ops::Deref, vec}; -use super::{GenericInstantiateContext, TypeSubstitutor, instantiate_type_generic_with_context}; +use super::{GenericInstantiateContext, TypeSubstitutor, instantiate_type_generic_inner}; pub(super) fn instantiate_alias_call( context: &GenericInstantiateContext, @@ -19,7 +19,7 @@ pub(super) fn instantiate_alias_call( let operand_exprs = alias_call.get_operands(); let operands = operand_exprs .iter() - .map(|it| instantiate_type_generic_with_context(context, it)) + .map(|it| instantiate_type_generic_inner(context, it)) .collect::>(); match alias_call.get_call_kind() { @@ -135,9 +135,7 @@ fn resolve_literal_operand( substitutor: &TypeSubstitutor, ) -> Option { match operand { - Some(LuaType::TplRef(tpl_ref)) | Some(LuaType::ConstTplRef(tpl_ref)) => { - substitutor.get_raw_type(tpl_ref.get_tpl_id()).cloned() - } + Some(LuaType::TplRef(tpl_ref)) => substitutor.get_raw_type(tpl_ref.get_tpl_id()).cloned(), _ => None, } } diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs index 420c0842e..e71843e81 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs @@ -1,10 +1,11 @@ mod complete_generic_args; mod instantiate_conditional_generic; -mod instantiate_func_generic; mod instantiate_special_generic; use hashbrown::{HashMap, HashSet}; -use std::{ops::Deref, sync::Arc}; +use std::ops::Deref; + +use smol_str::SmolStr; use crate::{ DbIndex, GenericTpl, GenericTplId, LuaArrayType, LuaMappedType, LuaMemberKey, @@ -14,123 +15,42 @@ use crate::{ LuaFunctionType, LuaGenericType, LuaIntersectionType, LuaObjectType, LuaType, LuaUnionType, VariadicType, }, - semantic::infer::InferFailReason, }; use super::type_substitutor::{ - GenericInstantiateContext, SubstitutorValue, TypeSubstitutor, UninferredTplPolicy, + GenericInstantiateContext, SubstitutorTypeValue, SubstitutorValue, TypeSubstitutor, }; pub use complete_generic_args::{ GenericArgumentCompletion, complete_type_generic_args, complete_type_generic_args_in_type, }; -pub use instantiate_func_generic::{build_self_type, infer_self_type, instantiate_func_generic}; pub use instantiate_special_generic::get_keyof_members; -pub(crate) fn collect_callable_overload_groups( - db: &DbIndex, - callable_type: &LuaType, - groups: &mut Vec>>, -) -> Result<(), InferFailReason> { - let mut visiting_aliases = HashSet::new(); - collect_callable_overload_groups_inner(db, callable_type, groups, &mut visiting_aliases) -} - -fn collect_callable_overload_groups_inner( - db: &DbIndex, - callable_type: &LuaType, - groups: &mut Vec>>, - visiting_aliases: &mut HashSet, -) -> Result<(), InferFailReason> { - match callable_type { - LuaType::Ref(type_id) | LuaType::Def(type_id) => { - let Some(type_decl) = db.get_type_index().get_type_decl(type_id) else { - return Ok(()); - }; - if !visiting_aliases.insert(type_id.clone()) { - return Ok(()); - } - - let result = if let Some(origin_type) = type_decl.get_alias_origin(db, None) { - collect_callable_overload_groups_inner(db, &origin_type, groups, visiting_aliases) - } else { - Ok(()) - }; - visiting_aliases.remove(type_id); - result?; - } - LuaType::Generic(generic) => { - let type_id = generic.get_base_type_id(); - if !visiting_aliases.insert(type_id.clone()) { - return Ok(()); - } - let substitutor = TypeSubstitutor::from_type_array(generic.get_params().to_vec()); - let Some(type_decl) = db.get_type_index().get_type_decl(&type_id) else { - visiting_aliases.remove(&type_id); - return Ok(()); - }; - - let result = if let Some(origin_type) = - type_decl.get_alias_origin(db, Some(&substitutor)) - { - collect_callable_overload_groups_inner(db, &origin_type, groups, visiting_aliases) - } else { - Ok(()) - }; - visiting_aliases.remove(&type_id); - result?; - } - LuaType::Union(union) => { - for member in union.into_vec() { - collect_callable_overload_groups_inner(db, &member, groups, visiting_aliases)?; - } - } - LuaType::Intersection(intersection) => { - for member in intersection.get_types() { - collect_callable_overload_groups_inner(db, member, groups, visiting_aliases)?; - } - } - LuaType::DocFunction(doc_func) => groups.push(vec![doc_func.clone()]), - LuaType::Signature(sig_id) => { - let Some(signature) = db.get_signature_index().get(sig_id) else { - return Ok(()); - }; - let mut overloads = signature.overloads.to_vec(); - overloads.push(signature.to_doc_func_type()); - groups.push(overloads); - } - _ => {} - } - - Ok(()) -} - pub fn instantiate_type_generic( db: &DbIndex, ty: &LuaType, substitutor: &TypeSubstitutor, ) -> LuaType { let context = GenericInstantiateContext::new(db, substitutor); - instantiate_type_generic_with_context(&context, ty) + match ty { + LuaType::DocFunction(doc_func) => instantiate_doc_function_with_context(&context, doc_func), + _ => instantiate_type_generic_inner(&context, ty), + } } -pub(super) fn instantiate_type_generic_with_context( +pub(super) fn instantiate_type_generic_inner( context: &GenericInstantiateContext, ty: &LuaType, ) -> LuaType { match ty { LuaType::Array(array_type) => instantiate_array(context, array_type.get_base()), LuaType::Tuple(tuple) => instantiate_tuple(context, tuple), - LuaType::DocFunction(doc_func) => instantiate_doc_function_with_context( - &context.with_policy(UninferredTplPolicy::PreserveTplRef), - doc_func, - ), + LuaType::DocFunction(doc_func) => instantiate_nested_doc_function(context, doc_func), LuaType::Object(object) => instantiate_object(context, object), LuaType::Union(union) => instantiate_union(context, union), LuaType::Intersection(intersection) => instantiate_intersection(context, intersection), - LuaType::Generic(generic) => instantiate_generic_with_context(context, generic), + LuaType::Generic(generic) => instantiate_generic_type(context, generic), LuaType::TableGeneric(table_params) => instantiate_table_generic(context, table_params), LuaType::TplRef(tpl) => instantiate_tpl_ref(tpl, context), - LuaType::ConstTplRef(tpl) => instantiate_const_tpl_ref(tpl, context), LuaType::Signature(sig_id) => instantiate_signature(context, sig_id), LuaType::Call(alias_call) => { instantiate_special_generic::instantiate_alias_call(context, alias_call) @@ -144,7 +64,7 @@ pub(super) fn instantiate_type_generic_with_context( } } LuaType::TypeGuard(guard) => { - let inner = instantiate_type_generic_with_context(context, guard.deref()); + let inner = instantiate_type_generic_inner(context, guard.deref()); LuaType::TypeGuard(inner.into()) } LuaType::Conditional(conditional) => { @@ -161,7 +81,7 @@ where { types .into_iter() - .map(|ty| instantiate_type_generic_with_context(context, ty)) + .map(|ty| instantiate_type_generic_inner(context, ty)) .collect() } @@ -176,15 +96,15 @@ where .into_iter() .map(|(key, value)| { ( - instantiate_type_generic_with_context(context, key), - instantiate_type_generic_with_context(context, value), + instantiate_type_generic_inner(context, key), + instantiate_type_generic_inner(context, value), ) }) .collect() } fn instantiate_array(context: &GenericInstantiateContext, base: &LuaType) -> LuaType { - let base = instantiate_type_generic_with_context(context, base); + let base = instantiate_type_generic_inner(context, base); LuaType::Array(LuaArrayType::from_base_type(base).into()) } @@ -209,7 +129,9 @@ fn instantiate_tuple(context: &GenericInstantiateContext, tuple: &LuaTupleType) new_types.push(ty.clone().unwrap_or(LuaType::Unknown)); } } - SubstitutorValue::Type(ty) => new_types.push(ty.default().clone()), + SubstitutorValue::Type(ty) => { + new_types.push(substitutor_type_for_tpl(tpl, ty).clone()) + } SubstitutorValue::MultiBase(base) => new_types.push(base.clone()), } } else { @@ -223,21 +145,12 @@ fn instantiate_tuple(context: &GenericInstantiateContext, tuple: &LuaTupleType) break; } - let t = instantiate_type_generic_with_context(context, t); + let t = instantiate_type_generic_inner(context, t); new_types.push(t); } LuaType::Tuple(LuaTupleType::new(new_types, tuple.status).into()) } -pub fn instantiate_doc_function( - db: &DbIndex, - doc_func: &LuaFunctionType, - substitutor: &TypeSubstitutor, -) -> LuaType { - let context = GenericInstantiateContext::new(db, substitutor); - instantiate_doc_function_with_context(&context, doc_func) -} - fn instantiate_doc_function_with_context( context: &GenericInstantiateContext, doc_func: &LuaFunctionType, @@ -246,6 +159,7 @@ fn instantiate_doc_function_with_context( let tpl_ret = doc_func.get_ret(); let async_state = doc_func.get_async_state(); let colon_define = doc_func.is_colon_define(); + let generic_params = instantiate_function_generic_params(context, doc_func); let mut new_params = Vec::new(); for origin_param in tpl_func_params.iter() { @@ -266,7 +180,7 @@ fn instantiate_doc_function_with_context( new_params.push((origin_param.0.clone(), Some(ty))); } SubstitutorValue::Type(ty) => { - let resolved_type = ty.default(); + let resolved_type = substitutor_type_for_tpl(tpl, ty); // 如果参数是 `...: T...` if origin_param.0 == "..." { // 类型是 tuple, 那么我们将展开 tuple @@ -318,7 +232,7 @@ fn instantiate_doc_function_with_context( } } LuaType::Generic(generic) => { - let new_type = instantiate_generic_with_context(context, generic); + let new_type = instantiate_generic_type(context, generic); // 如果是 rest 参数且实例化后的类型是 tuple, 那么我们将展开 tuple if let LuaType::Tuple(tuple_type) = &new_type { let base_index = new_params.len(); @@ -336,13 +250,13 @@ fn instantiate_doc_function_with_context( VariadicType::Multi(_) => (), }, _ => { - let new_type = instantiate_type_generic_with_context(context, origin_param_type); + let new_type = instantiate_type_generic_inner(context, origin_param_type); new_params.push((origin_param.0.clone(), Some(new_type))); } } } - let mut inst_ret_type = instantiate_type_generic_with_context(context, tpl_ret); + let mut inst_ret_type = instantiate_type_generic_inner(context, tpl_ret); // 对于可变返回值, 如果实例化是 tuple, 那么我们将展开 tuple if let LuaType::Variadic(_) = &&tpl_ret && let LuaType::Tuple(tuple) = &inst_ret_type @@ -375,21 +289,140 @@ fn instantiate_doc_function_with_context( is_variadic, new_params, inst_ret_type, + Some(generic_params), ) .into(), ) } +fn instantiate_nested_doc_function( + context: &GenericInstantiateContext, + doc_func: &LuaFunctionType, +) -> LuaType { + let mut transferred_params = Vec::new(); + let mut transferred_tpls = HashSet::new(); + collect_pending_function_generic_params( + context, + doc_func, + &mut transferred_params, + &mut transferred_tpls, + ); + + if transferred_tpls.is_empty() { + return instantiate_doc_function_with_context(context, doc_func); + } + + let mut generic_params = doc_func.get_generic_params().to_vec(); + for generic_param in transferred_params { + if generic_params + .iter() + .any(|tpl| tpl.get_tpl_id() == generic_param.get_tpl_id()) + { + continue; + } + + generic_params.push(generic_param); + } + + let nested_substitutor = context.substitutor.without_pending_tpls(&transferred_tpls); + let nested_context = context.with_substitutor(&nested_substitutor); + let doc_func = LuaFunctionType::new( + doc_func.get_async_state(), + doc_func.is_colon_define(), + doc_func.is_variadic(), + doc_func.get_params().to_vec(), + doc_func.get_ret().clone(), + Some(generic_params), + ); + instantiate_doc_function_with_context(&nested_context, &doc_func) +} + +fn collect_pending_function_generic_params( + context: &GenericInstantiateContext, + doc_func: &LuaFunctionType, + generic_params: &mut Vec, + generic_tpls: &mut HashSet, +) { + for generic_tpl in doc_func.get_generic_params() { + let tpl_id = generic_tpl.get_tpl_id(); + if is_pending_tpl(context, tpl_id) && generic_tpls.insert(tpl_id) { + generic_params.push(generic_tpl.clone()); + } + } + + doc_func.visit_nested_types(&mut |ty| match ty { + LuaType::TplRef(tpl) => { + let tpl_id = tpl.get_tpl_id(); + if is_pending_tpl(context, tpl_id) && generic_tpls.insert(tpl_id) { + generic_params.push(tpl.as_ref().clone()); + } + } + LuaType::StrTplRef(str_tpl) => { + let tpl_id = str_tpl.get_tpl_id(); + if is_pending_tpl(context, tpl_id) && generic_tpls.insert(tpl_id) { + generic_params.push(GenericTpl::new( + tpl_id, + SmolStr::new(str_tpl.get_name()), + str_tpl.get_constraint().cloned(), + None, + false, + None, + )); + } + } + _ => {} + }); +} + +fn is_pending_tpl(context: &GenericInstantiateContext, tpl_id: GenericTplId) -> bool { + matches!( + context.substitutor.get(tpl_id), + Some(SubstitutorValue::None) + ) +} + +fn instantiate_function_generic_params( + context: &GenericInstantiateContext, + doc_func: &LuaFunctionType, +) -> Vec { + doc_func + .get_generic_params() + .iter() + .filter_map(|generic_tpl| { + let tpl_id = generic_tpl.get_tpl_id(); + let param = generic_tpl.get_param(); + // A pending entry means this generic belongs to the current instantiation boundary + // and has been finalized into the function params/return. Foreign nested generics + // are absent from the substitutor and remain owned by the nested function. + if context.substitutor.get(tpl_id).is_some() { + return None; + } + + let constraint = param + .constraint + .as_ref() + .map(|ty| instantiate_type_generic_inner(context, ty)); + let default_type = param + .default + .as_ref() + .map(|ty| instantiate_type_generic_inner(context, ty)); + Some(GenericTpl::new( + tpl_id, + param.name.clone(), + constraint, + default_type, + param.is_const, + param.attributes.clone(), + )) + }) + .collect() +} + fn instantiate_object(context: &GenericInstantiateContext, object: &LuaObjectType) -> LuaType { let new_fields = object .get_fields() .iter() - .map(|(key, field)| { - ( - key.clone(), - instantiate_type_generic_with_context(context, field), - ) - }) + .map(|(key, field)| (key.clone(), instantiate_type_generic_inner(context, field))) .collect::>(); let new_index_access = instantiate_type_pairs(context, object.get_index_access().iter()); @@ -411,16 +444,7 @@ fn instantiate_intersection( ) } -pub fn instantiate_generic( - db: &DbIndex, - generic: &LuaGenericType, - substitutor: &TypeSubstitutor, -) -> LuaType { - let context = GenericInstantiateContext::new(db, substitutor); - instantiate_generic_with_context(&context, generic) -} - -fn instantiate_generic_with_context( +fn instantiate_generic_type( context: &GenericInstantiateContext, generic: &LuaGenericType, ) -> LuaType { @@ -458,18 +482,13 @@ fn instantiate_uninferred_tpl_fallback( tpl: &GenericTpl, context: &GenericInstantiateContext, ) -> LuaType { - // 一些情况下需要保留 TplRef, 例如高阶函数调用 - if context.should_preserve_tpl_ref() && tpl.get_default_type().is_none() { - return LuaType::TplRef(tpl.clone().into()); - } - // 显式默认值优先, 然后是 extends 约束, 最后才是 unknown. if let Some(default_type) = tpl.get_default_type() { - return instantiate_type_generic_with_context(context, default_type); + return instantiate_type_generic_inner(context, default_type); } if let Some(constraint) = tpl.get_constraint() { - return instantiate_type_generic_with_context(context, constraint); + return instantiate_type_generic_inner(context, constraint); } LuaType::Unknown @@ -481,7 +500,7 @@ fn instantiate_tpl_ref(tpl: &GenericTpl, context: &GenericInstantiateContext) -> SubstitutorValue::None => { return instantiate_uninferred_tpl_fallback(tpl, context); } - SubstitutorValue::Type(ty) => return ty.default().clone(), + SubstitutorValue::Type(ty) => return substitutor_type_for_tpl(tpl, ty).clone(), SubstitutorValue::MultiTypes(types) => { return LuaType::Variadic(VariadicType::Multi(types.clone()).into()); } @@ -500,29 +519,12 @@ fn instantiate_tpl_ref(tpl: &GenericTpl, context: &GenericInstantiateContext) -> LuaType::TplRef(tpl.clone().into()) } -fn instantiate_const_tpl_ref(tpl: &GenericTpl, context: &GenericInstantiateContext) -> LuaType { - if let Some(value) = context.substitutor.get(tpl.get_tpl_id()) { - match value { - SubstitutorValue::None => { - return instantiate_uninferred_tpl_fallback(tpl, context); - } - SubstitutorValue::Type(ty) => return ty.raw().clone(), - SubstitutorValue::MultiTypes(types) => { - return LuaType::Variadic(VariadicType::Multi(types.clone()).into()); - } - SubstitutorValue::Params(params) => { - return params - .first() - .unwrap_or(&(String::new(), None)) - .1 - .clone() - .unwrap_or(LuaType::Unknown); - } - SubstitutorValue::MultiBase(base) => return base.clone(), - } +fn substitutor_type_for_tpl<'a>(tpl: &GenericTpl, value: &'a SubstitutorTypeValue) -> &'a LuaType { + if tpl.is_const() { + value.raw() + } else { + value.default() } - - LuaType::ConstTplRef(tpl.clone().into()) } fn instantiate_signature( @@ -577,7 +579,7 @@ fn instantiate_variadic_type( }; } SubstitutorValue::Type(ty) => { - let resolved_type = ty.default(); + let resolved_type = substitutor_type_for_tpl(tpl, ty); if matches!( resolved_type, LuaType::Nil | LuaType::Any | LuaType::Unknown | LuaType::Never @@ -607,7 +609,7 @@ fn instantiate_variadic_type( } } LuaType::Generic(generic) => { - return instantiate_generic_with_context(context, generic); + return instantiate_generic_type(context, generic); } _ => {} }, @@ -615,7 +617,7 @@ fn instantiate_variadic_type( if types.iter().any(LuaTypeNode::contains_tpl_node) { let mut new_types = Vec::new(); for t in types { - let t = instantiate_type_generic_with_context(context, t); + let t = instantiate_type_generic_inner(context, t); match t { LuaType::Never => {} LuaType::Variadic(variadic) => match variadic.deref() { @@ -641,9 +643,9 @@ fn instantiate_mapped_type(context: &GenericInstantiateContext, mapped: &LuaMapp let constraint = mapped .param .1 - .type_constraint + .constraint .as_ref() - .map(|ty| instantiate_type_generic_with_context(context, ty)); + .map(|ty| instantiate_type_generic_inner(context, ty)); if let Some(constraint) = constraint { let mut key_types = Vec::new(); @@ -701,7 +703,7 @@ fn instantiate_mapped_type(context: &GenericInstantiateContext, mapped: &LuaMapp } } - instantiate_type_generic_with_context(context, &mapped.value) + instantiate_type_generic_inner(context, &mapped.value) } fn instantiate_mapped_value( @@ -713,7 +715,7 @@ fn instantiate_mapped_value( let mut local_substitutor = context.substitutor.clone(); local_substitutor.insert_type(tpl_id, replacement.clone(), true); let local_context = context.with_substitutor(&local_substitutor); - let mut result = instantiate_type_generic_with_context(&local_context, &mapped.value); + let mut result = instantiate_type_generic_inner(&local_context, &mapped.value); // 根据 readonly 和 optional 属性进行处理 if mapped.is_optional { result = TypeOps::Union.apply(context.db, &result, &LuaType::Nil); diff --git a/crates/emmylua_code_analysis/src/semantic/generic/mod.rs b/crates/emmylua_code_analysis/src/semantic/generic/mod.rs index 90e34baa3..a322f181e 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/mod.rs @@ -1,4 +1,5 @@ mod call_constraint; +mod infer_call_generic; mod instantiate_type; mod test; mod tpl_context; @@ -9,118 +10,10 @@ pub use call_constraint::{ CallConstraintArg, CallConstraintContext, build_call_constraint_context, normalize_constraint_type, }; -use emmylua_parser::LuaAstNode; -use emmylua_parser::LuaExpr; -pub(crate) use instantiate_type::collect_callable_overload_groups; +pub use infer_call_generic::{build_self_type, infer_call_generic, infer_self_type}; +pub use instantiate_type::get_keyof_members; pub use instantiate_type::*; -use rowan::NodeOrToken; pub use tpl_context::TplContext; pub use tpl_pattern::tpl_pattern_match_args; pub use tpl_pattern::tpl_pattern_match_args_skip_unknown; pub use type_substitutor::TypeSubstitutor; - -use crate::DbIndex; -use crate::GenericTplId; -use crate::LuaDeclExtra; -use crate::LuaInferCache; -use crate::LuaMemberOwner; -use crate::LuaSemanticDeclId; -use crate::LuaType; -use crate::SemanticDeclLevel; -use crate::TypeOps; -use crate::infer_node_semantic_decl; -use crate::semantic::semantic_info::infer_token_semantic_decl; -pub use instantiate_type::get_keyof_members; - -pub fn get_tpl_ref_extend_type( - db: &DbIndex, - cache: &mut LuaInferCache, - arg_type: &LuaType, - arg_expr: LuaExpr, - depth: usize, -) -> Option { - match arg_type { - LuaType::TplRef(tpl_ref) | LuaType::ConstTplRef(tpl_ref) => { - if let Some(extend) = tpl_ref.get_constraint().cloned() { - return Some(extend); - } - let node_or_token = arg_expr.syntax().clone().into(); - let semantic_decl = match node_or_token { - NodeOrToken::Node(node) => { - infer_node_semantic_decl(db, cache, node, SemanticDeclLevel::default()) - } - NodeOrToken::Token(token) => { - infer_token_semantic_decl(db, cache, token, SemanticDeclLevel::default()) - } - }?; - - match tpl_ref.get_tpl_id() { - GenericTplId::Func(tpl_id) => { - if let LuaSemanticDeclId::LuaDecl(decl_id) = semantic_decl { - let decl = db.get_decl_index().get_decl(&decl_id)?; - match decl.extra { - LuaDeclExtra::Param { signature_id, .. } => { - let signature = db.get_signature_index().get(&signature_id)?; - if let Some(generic_param) = - signature.generic_params.get(tpl_id as usize) - { - return generic_param.constraint.clone(); - } - } - _ => return None, - } - } - None - } - GenericTplId::Type(tpl_id) => { - if let LuaSemanticDeclId::LuaDecl(decl_id) = semantic_decl { - let decl = db.get_decl_index().get_decl(&decl_id)?; - match decl.extra { - LuaDeclExtra::Param { - owner_member_id, .. - } => { - let owner_member_id = owner_member_id?; - let parent_owner = - db.get_member_index().get_current_owner(&owner_member_id)?; - match parent_owner { - LuaMemberOwner::Type(type_id) => { - let generic_params = - db.get_type_index().get_generic_params(type_id)?; - return generic_params - .get(tpl_id as usize)? - .type_constraint - .clone(); - } - _ => return None, - } - } - _ => return None, - } - } - None - } - GenericTplId::ConditionalInfer(_) => None, - } - } - LuaType::StrTplRef(str_tpl) => str_tpl.get_constraint().cloned(), - LuaType::Union(union_type) => { - if depth > 1 { - return None; - } - let mut result = LuaType::Never; - for union_member_type in union_type.into_vec().iter() { - let extend_type = get_tpl_ref_extend_type( - db, - cache, - union_member_type, - arg_expr.clone(), - depth + 1, - ) - .unwrap_or(union_member_type.clone()); - result = TypeOps::Union.apply(db, &result, &extend_type); - } - Some(result) - } - _ => None, - } -} diff --git a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/generic_tpl_pattern.rs b/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/generic_tpl_pattern.rs index aa556ca88..02d43c695 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/generic_tpl_pattern.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/generic_tpl_pattern.rs @@ -1,6 +1,6 @@ use crate::{ InferFailReason, InferGuard, InferGuardRef, LuaGenericType, LuaType, LuaTypeNode, TplContext, - TypeSubstitutor, instantiate_generic, instantiate_type_generic, + TypeSubstitutor, instantiate_type_generic, semantic::generic::tpl_pattern::{ TplPatternMatchResult, tpl_pattern_match, variadic_tpl_pattern_match, }, @@ -125,8 +125,9 @@ fn generic_tpl_pattern_match_inner( _ => { // 对于 @alias 类型, 我们能拿到的 target 实际上很有可能是实例化后的类型, 因此我们需要实例化后再进行匹配 let substitutor = TypeSubstitutor::new(); - let typ = instantiate_generic(context.db, source_generic, &substitutor); - if LuaType::from(source_generic.clone()) != typ { + let source_type = LuaType::from(source_generic.clone()); + let typ = instantiate_type_generic(context.db, &source_type, &substitutor); + if source_type != typ { tpl_pattern_match(context, &typ, target)?; } } diff --git a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs b/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs index 467da24e6..c21bff656 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs @@ -161,14 +161,7 @@ pub fn tpl_pattern_match( if tpl.get_tpl_id().is_func() { context .substitutor - .insert_type(tpl.get_tpl_id(), target.clone(), true); - } - } - LuaType::ConstTplRef(tpl) => { - if tpl.get_tpl_id().is_func() { - context - .substitutor - .insert_type(tpl.get_tpl_id(), target, false); + .infer_type(tpl.get_tpl_id(), target.clone(), !tpl.is_const()); } } LuaType::StrTplRef(str_tpl) => { @@ -176,7 +169,7 @@ pub fn tpl_pattern_match( let prefix = str_tpl.get_prefix(); let suffix = str_tpl.get_suffix(); let type_name = SmolStr::new(format!("{}{}{}", prefix, s, suffix)); - context.substitutor.insert_type( + context.substitutor.infer_type( str_tpl.get_tpl_id(), get_str_tpl_infer_type(&type_name), true, @@ -220,6 +213,14 @@ pub fn constant_decay(typ: LuaType) -> LuaType { } } +fn maybe_decay_type(typ: &LuaType, decay: bool) -> LuaType { + if decay { + constant_decay(typ.clone()) + } else { + typ.clone() + } +} + fn object_tpl_pattern_match( context: &mut TplContext, origin_obj: &LuaObjectType, @@ -716,7 +717,7 @@ pub(crate) fn return_type_pattern_match_target_type( let tpl_id = type_ref.get_tpl_id(); context .substitutor - .insert_type(tpl_id, target_base.clone(), true); + .infer_type(tpl_id, target_base.clone(), true); } } VariadicType::Multi(source_multi) => { @@ -727,7 +728,7 @@ pub(crate) fn return_type_pattern_match_target_type( && let LuaType::TplRef(type_ref) = base { let tpl_id = type_ref.get_tpl_id(); - context.substitutor.insert_type( + context.substitutor.infer_type( tpl_id, target_base.clone(), true, @@ -738,7 +739,7 @@ pub(crate) fn return_type_pattern_match_target_type( } LuaType::TplRef(tpl_ref) => { let tpl_id = tpl_ref.get_tpl_id(); - context.substitutor.insert_type( + context.substitutor.infer_type( tpl_id, target_base.clone(), true, @@ -781,7 +782,7 @@ fn func_varargs_tpl_pattern_match( VariadicType::Base(base) => { if let LuaType::TplRef(tpl_ref) = base { let tpl_id = tpl_ref.get_tpl_id(); - substitutor.insert_params( + substitutor.infer_params( tpl_id, target_rest_params .iter() @@ -802,13 +803,14 @@ pub fn variadic_tpl_pattern_match( target_rest_types: &[LuaType], ) -> TplPatternMatchResult { match tpl { - VariadicType::Base(base) => match base { - LuaType::TplRef(tpl_ref) => { + VariadicType::Base(base) => { + if let LuaType::TplRef(tpl_ref) = base { let tpl_id = tpl_ref.get_tpl_id(); + let decay = !tpl_ref.is_const(); match target_rest_types.len() { 0 => { // Zero varargs are an empty sequence, not one nil return slot. - context.substitutor.insert_multi_types(tpl_id, Vec::new()); + context.substitutor.infer_multi_types(tpl_id, Vec::new()); } 1 => { // If the single argument is itself a multi-return (e.g. a function call @@ -818,67 +820,46 @@ pub fn variadic_tpl_pattern_match( LuaType::Variadic(variadic) => match variadic.deref() { VariadicType::Multi(types) => match types.len() { 0 => { - context.substitutor.insert_multi_types(tpl_id, Vec::new()); + context.substitutor.infer_multi_types(tpl_id, Vec::new()); } 1 => { - context.substitutor.insert_type( + context.substitutor.infer_type( tpl_id, types[0].clone(), - true, + decay, ); } _ => { - context.substitutor.insert_multi_types( + context.substitutor.infer_multi_types( tpl_id, types .iter() - .map(|t| constant_decay(t.clone())) + .map(|t| maybe_decay_type(t, decay)) .collect(), ); } }, VariadicType::Base(base) => { - context.substitutor.insert_multi_base(tpl_id, base.clone()); + context.substitutor.infer_multi_base(tpl_id, base.clone()); } }, arg => { - context.substitutor.insert_type(tpl_id, arg.clone(), true); + context.substitutor.infer_type(tpl_id, arg.clone(), decay); } } } _ => { - context.substitutor.insert_multi_types( + context.substitutor.infer_multi_types( tpl_id, target_rest_types .iter() - .map(|t| constant_decay(t.clone())) + .map(|t| maybe_decay_type(t, decay)) .collect(), ); } } } - LuaType::ConstTplRef(tpl_ref) => { - let tpl_id = tpl_ref.get_tpl_id(); - match target_rest_types.len() { - 0 => { - context.substitutor.insert_multi_types(tpl_id, Vec::new()); - } - 1 => { - context.substitutor.insert_type( - tpl_id, - target_rest_types[0].clone(), - false, - ); - } - _ => { - context - .substitutor - .insert_multi_types(tpl_id, target_rest_types.to_vec()); - } - } - } - _ => {} - }, + } VariadicType::Multi(multi) => { for (i, ret_type) in multi.iter().enumerate() { match ret_type { @@ -893,7 +874,7 @@ pub fn variadic_tpl_pattern_match( let tpl_id = tpl_ref.get_tpl_id(); match target_rest_types.get(i) { Some(t) => { - context.substitutor.insert_type(tpl_id, t.clone(), true); + context.substitutor.infer_type(tpl_id, t.clone(), true); } None => { break; @@ -946,7 +927,7 @@ fn tuple_tpl_pattern_match( let tpl_id = tpl_ref.get_tpl_id(); context .substitutor - .insert_multi_base(tpl_id, target_array_base.get_base().clone()); + .infer_multi_base(tpl_id, target_array_base.get_base().clone()); } } VariadicType::Multi(_) => {} diff --git a/crates/emmylua_code_analysis/src/semantic/generic/type_substitutor.rs b/crates/emmylua_code_analysis/src/semantic/generic/type_substitutor.rs index 10a1733ba..1891820d5 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/type_substitutor.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/type_substitutor.rs @@ -4,19 +4,10 @@ use std::{cell::RefCell, rc::Rc}; use super::tpl_pattern::constant_decay; use crate::{DbIndex, GenericTplId, LuaSignatureId, LuaType, LuaTypeDeclId}; -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub(super) enum UninferredTplPolicy { - /// 未推断模板按 `default -> constraint -> unknown` 推断成实际类型. - Fallback, - /// 没有默认值的未推断模板仍保留为 `TplRef`, 让后续调用点继续参与参数推导. - PreserveTplRef, -} - #[derive(Debug)] pub struct GenericInstantiateContext<'a> { pub db: &'a DbIndex, pub substitutor: &'a TypeSubstitutor, - policy: UninferredTplPolicy, instantiating_signatures: Rc>>, } @@ -25,20 +16,10 @@ impl<'a> GenericInstantiateContext<'a> { Self { db, substitutor, - policy: UninferredTplPolicy::Fallback, instantiating_signatures: Rc::new(RefCell::new(HashSet::new())), } } - pub(super) fn with_policy(&self, policy: UninferredTplPolicy) -> GenericInstantiateContext<'a> { - GenericInstantiateContext { - db: self.db, - substitutor: self.substitutor, - policy, - instantiating_signatures: self.instantiating_signatures.clone(), - } - } - pub fn with_substitutor<'b>( &'b self, substitutor: &'b TypeSubstitutor, @@ -46,15 +27,10 @@ impl<'a> GenericInstantiateContext<'a> { GenericInstantiateContext { db: self.db, substitutor, - policy: self.policy, instantiating_signatures: self.instantiating_signatures.clone(), } } - pub fn should_preserve_tpl_ref(&self) -> bool { - self.policy == UninferredTplPolicy::PreserveTplRef - } - pub(super) fn enter_signature( &self, signature_id: LuaSignatureId, @@ -169,6 +145,17 @@ impl TypeSubstitutor { self.insert_type_value(tpl_id, SubstitutorTypeValue::new(replace_type, decay)); } + pub fn infer_type(&mut self, tpl_id: GenericTplId, replace_type: LuaType, decay: bool) { + if tpl_id.is_conditional_infer() || !self.can_infer_type(tpl_id) { + return; + } + + self.tpl_replace_map.insert( + tpl_id, + SubstitutorValue::Type(SubstitutorTypeValue::new(replace_type, decay)), + ); + } + pub(super) fn replace_type( &mut self, tpl_id: GenericTplId, @@ -214,6 +201,12 @@ impl TypeSubstitutor { true } + fn can_infer_type(&self, tpl_id: GenericTplId) -> bool { + self.tpl_replace_map + .get(&tpl_id) + .is_some_and(SubstitutorValue::is_none) + } + pub fn insert_params(&mut self, tpl_id: GenericTplId, params: Vec<(String, Option)>) { if tpl_id.is_conditional_infer() { return; @@ -232,6 +225,20 @@ impl TypeSubstitutor { .insert(tpl_id, SubstitutorValue::Params(params)); } + pub fn infer_params(&mut self, tpl_id: GenericTplId, params: Vec<(String, Option)>) { + if tpl_id.is_conditional_infer() || !self.can_infer_type(tpl_id) { + return; + } + + let params = params + .into_iter() + .map(|(name, ty)| (name, ty.map(into_ref_type))) + .collect(); + + self.tpl_replace_map + .insert(tpl_id, SubstitutorValue::Params(params)); + } + pub fn insert_multi_types(&mut self, tpl_id: GenericTplId, types: Vec) { if tpl_id.is_conditional_infer() { return; @@ -245,6 +252,15 @@ impl TypeSubstitutor { .insert(tpl_id, SubstitutorValue::MultiTypes(types)); } + pub fn infer_multi_types(&mut self, tpl_id: GenericTplId, types: Vec) { + if tpl_id.is_conditional_infer() || !self.can_infer_type(tpl_id) { + return; + } + + self.tpl_replace_map + .insert(tpl_id, SubstitutorValue::MultiTypes(types)); + } + pub fn insert_multi_base(&mut self, tpl_id: GenericTplId, type_base: LuaType) { if tpl_id.is_conditional_infer() { return; @@ -258,10 +274,34 @@ impl TypeSubstitutor { .insert(tpl_id, SubstitutorValue::MultiBase(type_base)); } + pub fn infer_multi_base(&mut self, tpl_id: GenericTplId, type_base: LuaType) { + if tpl_id.is_conditional_infer() || !self.can_infer_type(tpl_id) { + return; + } + + self.tpl_replace_map + .insert(tpl_id, SubstitutorValue::MultiBase(type_base)); + } + pub fn get(&self, tpl_id: GenericTplId) -> Option<&SubstitutorValue> { self.tpl_replace_map.get(&tpl_id) } + pub(super) fn without_pending_tpls(&self, tpl_ids: &HashSet) -> Self { + let mut substitutor = self.clone(); + for tpl_id in tpl_ids { + if substitutor + .tpl_replace_map + .get(tpl_id) + .is_some_and(SubstitutorValue::is_none) + { + substitutor.tpl_replace_map.remove(tpl_id); + } + } + + substitutor + } + pub fn get_raw_type(&self, tpl_id: GenericTplId) -> Option<&LuaType> { match self.tpl_replace_map.get(&tpl_id) { Some(SubstitutorValue::Type(ty)) => Some(ty.raw()), diff --git a/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs index cd1360cf2..22419c5c5 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs @@ -17,14 +17,11 @@ use crate::{ use crate::{ InferGuardRef, semantic::{ - generic::{ - TypeSubstitutor, collect_callable_overload_groups, get_tpl_ref_extend_type, - instantiate_doc_function, - }, - infer::narrow::get_type_at_call_expr_inline_cast, + generic::TypeSubstitutor, infer::narrow::get_type_at_call_expr_inline_cast, + overload_resolve::collect_callable_overload_groups, }, }; -use crate::{build_self_type, infer_self_type, instantiate_func_generic, semantic::infer_expr}; +use crate::{build_self_type, infer_call_generic, infer_self_type, semantic::infer_expr}; use infer_require::infer_require_call; use infer_setmetatable::infer_setmetatable_call; @@ -99,7 +96,7 @@ pub fn infer_call_expr_func( ), LuaType::Instance(inst) => infer_instance_type_doc_function(db, inst), LuaType::TableConst(meta_table) => infer_table_type_doc_function(db, meta_table.clone()), - LuaType::TplRef(_) | LuaType::ConstTplRef(_) | LuaType::StrTplRef(_) => infer_tpl_ref_call( + LuaType::TplRef(_) | LuaType::StrTplRef(_) => infer_tpl_ref_call( db, cache, call_expr.clone(), @@ -113,6 +110,7 @@ pub fn infer_call_expr_func( true, vec![("...".to_string(), Some(LuaType::Unknown))], LuaType::Variadic(VariadicType::Base(LuaType::Unknown).into()), + None, ))), LuaType::Intersection(intersection) => infer_intersection( db, @@ -128,6 +126,7 @@ pub fn infer_call_expr_func( true, vec![], LuaType::Any, + None, ))), LuaType::Union(union) => infer_union(db, cache, union, call_expr.clone(), args_count), _ => Err(InferFailReason::None), @@ -136,7 +135,7 @@ pub fn infer_call_expr_func( let result = if let Ok(func_ty) = result { let func_ty = match func_ty.get_ret() { LuaType::Call(_) => { - match instantiate_func_generic(db, cache, func_ty.as_ref(), call_expr.clone()) { + match infer_call_generic(db, cache, func_ty.as_ref(), call_expr.clone()) { Ok(func_ty) => Arc::new(func_ty), Err(_) => func_ty, } @@ -154,6 +153,7 @@ pub fn infer_call_expr_func( func_ty.is_variadic(), func_ty.get_params().to_vec(), new_ret, + Some(func_ty.get_generic_params().to_vec()), ) .into() }), @@ -207,9 +207,12 @@ fn infer_tpl_ref_call( infer_guard: &InferGuardRef, args_count: Option, ) -> InferCallFuncResult { - let prefix_expr = call_expr.get_prefix_expr().ok_or(InferFailReason::None)?; - let extend_type = get_tpl_ref_extend_type(db, cache, call_expr_type, prefix_expr, 0) - .ok_or(InferFailReason::None)?; + let extend_type = match call_expr_type { + LuaType::TplRef(tpl) => tpl.get_constraint().cloned(), + LuaType::StrTplRef(str_tpl) => str_tpl.get_constraint().cloned(), + _ => None, + } + .ok_or(InferFailReason::None)?; if &extend_type == call_expr_type { return Err(InferFailReason::None); } @@ -223,7 +226,7 @@ fn infer_doc_function( call_expr: LuaCallExpr, ) -> InferCallFuncResult { if func.contain_tpl() { - let result = instantiate_func_generic(db, cache, func, call_expr)?; + let result = infer_call_generic(db, cache, func, call_expr)?; return Ok(Arc::new(result)); } @@ -254,16 +257,11 @@ fn filter_callable_overloads_by_call_args( Ok(overloads .into_iter() .filter(|func| { - let mut callable_tpls = HashSet::new(); - func.visit_type(&mut |ty| match ty { - LuaType::TplRef(generic_tpl) | LuaType::ConstTplRef(generic_tpl) => { - callable_tpls.insert(generic_tpl.get_tpl_id()); - } - LuaType::StrTplRef(str_tpl) => { - callable_tpls.insert(str_tpl.get_tpl_id()); - } - _ => {} - }); + let callable_tpls = func + .get_generic_params() + .iter() + .map(|generic_tpl| generic_tpl.get_tpl_id()) + .collect::>(); if callable_tpls.is_empty() && !strict_arg_filter { return true; @@ -273,7 +271,8 @@ fn filter_callable_overloads_by_call_args( let mut substitutor = TypeSubstitutor::new(); substitutor.add_need_infer_tpls(callable_tpls); let match_func = if has_tpls { - match instantiate_doc_function(db, func, &substitutor) { + let func_type = LuaType::DocFunction(func.clone()); + match instantiate_type_generic(db, &func_type, &substitutor) { LuaType::DocFunction(doc_func) => doc_func, _ => func.clone(), } @@ -353,22 +352,21 @@ fn infer_type_doc_function( let has_generic_tpl = { let mut has_generic_tpl = false; f.visit_type(&mut |t| { - has_generic_tpl |= matches!( - t, - LuaType::TplRef(_) | LuaType::ConstTplRef(_) | LuaType::StrTplRef(_) - ); + has_generic_tpl |= matches!(t, LuaType::TplRef(_) | LuaType::StrTplRef(_)); }); has_generic_tpl }; if has_generic_tpl { - let result = instantiate_func_generic(db, cache, &f, call_expr.clone())?; + let result = infer_call_generic(db, cache, &f, call_expr.clone())?; overloads.push(Arc::new(result)); } else if f.contain_self() { let mut substitutor = TypeSubstitutor::new(); let self_type = build_self_type(db, call_expr_type); substitutor.add_self_type(self_type); - if let LuaType::DocFunction(f) = instantiate_doc_function(db, &f, &substitutor) + let func_type = LuaType::DocFunction(f.clone()); + if let LuaType::DocFunction(f) = + instantiate_type_generic(db, &func_type, &substitutor) { overloads.push(f); } @@ -612,6 +610,7 @@ fn infer_union( first_func.is_variadic(), first_func.get_params().to_vec(), LuaType::from_vec(returns), + Some(first_func.get_generic_params().to_vec()), ))) } diff --git a/crates/emmylua_code_analysis/src/semantic/infer/infer_doc_type.rs b/crates/emmylua_code_analysis/src/semantic/infer/infer_doc_type.rs index 288617c82..aee32d9f5 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/infer_doc_type.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/infer_doc_type.rs @@ -497,6 +497,7 @@ fn infer_func_type(ctx: DocTypeInferContext<'_>, func: &LuaDocFuncType) -> LuaTy is_variadic, params_result, return_type, + None, ) .into(), ) diff --git a/crates/emmylua_code_analysis/src/semantic/infer/infer_index/infer_array.rs b/crates/emmylua_code_analysis/src/semantic/infer/infer_index/infer_array.rs index 85fd20a77..aadd064da 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/infer_index/infer_array.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/infer_index/infer_array.rs @@ -68,11 +68,7 @@ pub(super) fn infer_array_member_by_key( fn key_type_matches(db: &DbIndex, expected: &LuaType, actual: &LuaType) -> bool { !matches!( actual, - LuaType::Any - | LuaType::Unknown - | LuaType::TplRef(_) - | LuaType::StrTplRef(_) - | LuaType::ConstTplRef(_) + LuaType::Any | LuaType::Unknown | LuaType::TplRef(_) | LuaType::StrTplRef(_) ) && check_type_compact(db, expected, actual).is_ok() } diff --git a/crates/emmylua_code_analysis/src/semantic/infer/infer_index/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/infer_index/mod.rs index 1f3ad09ce..b972d2f7c 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/infer_index/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/infer_index/mod.rs @@ -15,7 +15,7 @@ use crate::{ DbIndex, LuaGenericType, LuaIntersectionType, LuaMemberKey, LuaObjectType, LuaOperatorMetaMethod, LuaTupleType, LuaType, LuaTypeDeclId, LuaUnionType, }, - enum_variable_is_param, get_keyof_members, get_tpl_ref_extend_type, + enum_variable_is_param, get_keyof_members, semantic::{ InferGuard, generic::{TypeSubstitutor, instantiate_type_generic}, @@ -1183,18 +1183,9 @@ fn infer_tpl_ref_member( lookup: &MemberLookupQuery, infer_guard: &InferGuardRef, ) -> InferResult { - let extend_type = get_tpl_ref_extend_type( - db, - cache, - &LuaType::TplRef(generic.clone().into()), - lookup - .index_expr - .get_index_expr() - .ok_or(InferFailReason::None)? - .get_prefix_expr() - .ok_or(InferFailReason::None)?, - 0, - ) - .ok_or(InferFailReason::None)?; + let extend_type = generic + .get_constraint() + .cloned() + .ok_or(InferFailReason::None)?; infer_member_by_lookup(db, cache, &extend_type, lookup, infer_guard) } diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/call_flow.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/call_flow.rs index e915915b1..47dd87510 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/call_flow.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/call_flow.rs @@ -16,7 +16,7 @@ use crate::{ var_ref_id::get_var_expr_var_ref_id, }, }, - semantic::instantiate_func_generic, + semantic::infer_call_generic, }; pub fn get_type_at_call_expr( @@ -225,9 +225,9 @@ fn get_type_guard_call_info( let mut return_type = func_type.get_ret().clone(); if return_type.contain_tpl() { - let Ok(inst_func) = cache.with_no_flow(|cache| { - instantiate_func_generic(db, cache, func_type.as_ref(), call_expr) - }) else { + let Ok(inst_func) = cache + .with_no_flow(|cache| infer_call_generic(db, cache, func_type.as_ref(), call_expr)) + else { return Ok(None); }; return_type = inst_func.get_ret().clone(); diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/correlated_flow.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/correlated_flow.rs index a2d06dd5f..324498e47 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/correlated_flow.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/correlated_flow.rs @@ -7,7 +7,7 @@ use crate::{ LuaSignature, LuaType, TypeOps, semantic::{ infer::{InferResult, VarRefId, narrow::narrow_down_type, try_infer_expr_no_flow}, - instantiate_func_generic, + infer_call_generic, }, }; @@ -575,10 +575,9 @@ fn instantiate_return_rows( signature.is_vararg, signature.get_type_params(), return_type.clone(), + Some(signature.get_function_generic_params()), ); - match cache - .with_no_flow(|cache| instantiate_func_generic(db, cache, &func, call_expr.clone())) - { + match cache.with_no_flow(|cache| infer_call_generic(db, cache, &func, call_expr.clone())) { Ok(instantiated) => instantiated.get_ret().clone(), Err(_) => return_type, } diff --git a/crates/emmylua_code_analysis/src/semantic/mod.rs b/crates/emmylua_code_analysis/src/semantic/mod.rs index b9cd692f9..e98f48c25 100644 --- a/crates/emmylua_code_analysis/src/semantic/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/mod.rs @@ -58,6 +58,7 @@ pub use infer::infer_call_expr_func; pub use infer::infer_param; pub(crate) use infer::try_infer_expr_for_index; pub(crate) use infer::{infer_expr, try_infer_expr_no_flow}; +pub(crate) use overload_resolve::collect_callable_overload_groups; use overload_resolve::resolve_signature; pub use semantic_info::SemanticDeclLevel; pub use type_check::{TypeCheckFailReason, TypeCheckResult}; diff --git a/crates/emmylua_code_analysis/src/semantic/overload_resolve/collect_overloads.rs b/crates/emmylua_code_analysis/src/semantic/overload_resolve/collect_overloads.rs new file mode 100644 index 000000000..f493e8e8b --- /dev/null +++ b/crates/emmylua_code_analysis/src/semantic/overload_resolve/collect_overloads.rs @@ -0,0 +1,86 @@ +use hashbrown::HashSet; +use std::sync::Arc; + +use crate::{ + DbIndex, LuaTypeDeclId, + db_index::{LuaFunctionType, LuaType}, + semantic::{generic::TypeSubstitutor, infer::InferFailReason}, +}; + +pub(crate) fn collect_callable_overload_groups( + db: &DbIndex, + callable_type: &LuaType, + groups: &mut Vec>>, +) -> Result<(), InferFailReason> { + let mut visiting_aliases = HashSet::new(); + collect_callable_overload_groups_inner(db, callable_type, groups, &mut visiting_aliases) +} + +fn collect_callable_overload_groups_inner( + db: &DbIndex, + callable_type: &LuaType, + groups: &mut Vec>>, + visiting_aliases: &mut HashSet, +) -> Result<(), InferFailReason> { + match callable_type { + LuaType::Ref(type_id) | LuaType::Def(type_id) => { + let Some(type_decl) = db.get_type_index().get_type_decl(type_id) else { + return Ok(()); + }; + if !visiting_aliases.insert(type_id.clone()) { + return Ok(()); + } + + let result = if let Some(origin_type) = type_decl.get_alias_origin(db, None) { + collect_callable_overload_groups_inner(db, &origin_type, groups, visiting_aliases) + } else { + Ok(()) + }; + visiting_aliases.remove(type_id); + result?; + } + LuaType::Generic(generic) => { + let type_id = generic.get_base_type_id(); + if !visiting_aliases.insert(type_id.clone()) { + return Ok(()); + } + let substitutor = TypeSubstitutor::from_type_array(generic.get_params().to_vec()); + let Some(type_decl) = db.get_type_index().get_type_decl(&type_id) else { + visiting_aliases.remove(&type_id); + return Ok(()); + }; + + let result = if let Some(origin_type) = + type_decl.get_alias_origin(db, Some(&substitutor)) + { + collect_callable_overload_groups_inner(db, &origin_type, groups, visiting_aliases) + } else { + Ok(()) + }; + visiting_aliases.remove(&type_id); + result?; + } + LuaType::Union(union) => { + for member in union.into_vec() { + collect_callable_overload_groups_inner(db, &member, groups, visiting_aliases)?; + } + } + LuaType::Intersection(intersection) => { + for member in intersection.get_types() { + collect_callable_overload_groups_inner(db, member, groups, visiting_aliases)?; + } + } + LuaType::DocFunction(doc_func) => groups.push(vec![doc_func.clone()]), + LuaType::Signature(sig_id) => { + let Some(signature) = db.get_signature_index().get(sig_id) else { + return Ok(()); + }; + let mut overloads = signature.overloads.to_vec(); + overloads.push(signature.to_doc_func_type()); + groups.push(overloads); + } + _ => {} + } + + Ok(()) +} diff --git a/crates/emmylua_code_analysis/src/semantic/overload_resolve/mod.rs b/crates/emmylua_code_analysis/src/semantic/overload_resolve/mod.rs index a6447a91c..5486cd571 100644 --- a/crates/emmylua_code_analysis/src/semantic/overload_resolve/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/overload_resolve/mod.rs @@ -1,3 +1,4 @@ +mod collect_overloads; mod resolve_signature_by_args; use std::sync::Arc; @@ -8,10 +9,11 @@ use crate::db_index::{DbIndex, LuaFunctionType, LuaType}; use super::{ LuaInferCache, - generic::instantiate_func_generic, + generic::infer_call_generic, infer::{InferCallFuncResult, InferFailReason, infer_expr_list_types, try_infer_expr_no_flow}, }; +pub(crate) use collect_overloads::collect_callable_overload_groups; pub(crate) use resolve_signature_by_args::{callable_accepts_args, resolve_signature_by_args}; pub fn resolve_signature( @@ -78,7 +80,7 @@ fn resolve_signature_by_generic( ) -> InferCallFuncResult { let mut instantiate_funcs = Vec::new(); for func in overloads { - let instantiate_func = instantiate_func_generic(db, cache, &func, call_expr.clone())?; + let instantiate_func = infer_call_generic(db, cache, &func, call_expr.clone())?; instantiate_funcs.push(Arc::new(instantiate_func)); } resolve_signature_by_args( diff --git a/crates/emmylua_code_analysis/src/semantic/type_check/mod.rs b/crates/emmylua_code_analysis/src/semantic/type_check/mod.rs index 215d71842..35e737587 100644 --- a/crates/emmylua_code_analysis/src/semantic/type_check/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/type_check/mod.rs @@ -124,7 +124,6 @@ fn check_general_type_compact( | LuaType::DocBooleanConst(_) | LuaType::TplRef(_) | LuaType::StrTplRef(_) - | LuaType::ConstTplRef(_) | LuaType::Namespace(_) | LuaType::Variadic(_) | LuaType::Language(_) => { @@ -195,11 +194,7 @@ fn check_general_type_compact( fn is_like_any(ty: &LuaType) -> bool { matches!( ty, - LuaType::Any - | LuaType::Unknown - | LuaType::TplRef(_) - | LuaType::StrTplRef(_) - | LuaType::ConstTplRef(_) + LuaType::Any | LuaType::Unknown | LuaType::TplRef(_) | LuaType::StrTplRef(_) ) } diff --git a/crates/emmylua_code_analysis/src/semantic/type_check/simple_type.rs b/crates/emmylua_code_analysis/src/semantic/type_check/simple_type.rs index 6f994cb0c..123279d46 100644 --- a/crates/emmylua_code_analysis/src/semantic/type_check/simple_type.rs +++ b/crates/emmylua_code_analysis/src/semantic/type_check/simple_type.rs @@ -263,7 +263,7 @@ pub fn check_simple_type_compact( return Ok(()); } } - LuaType::TplRef(_) | LuaType::ConstTplRef(_) => return Ok(()), + LuaType::TplRef(_) => return Ok(()), LuaType::Namespace(source_namespace) => { if let LuaType::Namespace(compact_namespace) = compact_type && source_namespace == compact_namespace diff --git a/crates/emmylua_doc_cli/src/json_generator/export.rs b/crates/emmylua_doc_cli/src/json_generator/export.rs index 8e4284d5e..5ac500acc 100644 --- a/crates/emmylua_doc_cli/src/json_generator/export.rs +++ b/crates/emmylua_doc_cli/src/json_generator/export.rs @@ -194,7 +194,7 @@ fn export_generics(db: &DbIndex, type_decl_id: &LuaTypeDeclId) -> Vec { .map(|it| TypeVar { name: it.name.to_string(), base: it - .type_constraint + .constraint .as_ref() .map(|typ| render_typ(db, typ, RenderLevel::Simple)), }) diff --git a/crates/emmylua_ls/src/handlers/completion/providers/function_provider.rs b/crates/emmylua_ls/src/handlers/completion/providers/function_provider.rs index 0e9050f4b..8aec73cba 100644 --- a/crates/emmylua_ls/src/handlers/completion/providers/function_provider.rs +++ b/crates/emmylua_ls/src/handlers/completion/providers/function_provider.rs @@ -1,5 +1,5 @@ use emmylua_code_analysis::{ - DbIndex, GenericTplId, InferGuard, InferGuardRef, LuaAliasCallKind, LuaAliasCallType, + DbIndex, GenericTpl, InferGuard, InferGuardRef, LuaAliasCallKind, LuaAliasCallType, LuaDeclLocation, LuaFunctionType, LuaMember, LuaMemberKey, LuaMemberOwner, LuaMultiLineUnion, LuaSemanticDeclId, LuaStringTplType, LuaType, LuaTypeCache, LuaTypeDeclId, LuaUnionType, RenderLevel, SemanticDeclLevel, TypeSubstitutor, build_call_constraint_context, get_real_type, @@ -151,11 +151,8 @@ pub fn dispatch_type( LuaType::StrTplRef(key) => { add_str_tpl_ref_completion(builder, &key); } - LuaType::ConstTplRef(tpl) => { - return add_const_tpl_ref_completion(builder, &tpl.get_tpl_id(), infer_guard); - } LuaType::TplRef(tpl) => { - return add_tpl_ref_completion(builder, &tpl.get_tpl_id(), infer_guard); + return add_tpl_ref_completion(builder, &tpl, infer_guard); } LuaType::Call(special_call) => { add_special_call_completion(builder, &special_call); @@ -401,7 +398,7 @@ fn rebuild_keyof_alias_call( substitutor: &TypeSubstitutor, ) -> Option { let tpl = match original_type { - LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) => tpl, + LuaType::TplRef(tpl) => tpl, _ => return None, }; let constraint = tpl.get_constraint()?; @@ -809,9 +806,8 @@ fn add_str_tpl_ref_completion( let db = builder.semantic_model.get_db(); let module_index = db.get_module_index(); let types = db.get_type_index().get_all_types(); - let tpl_id = str_tpl.get_tpl_id(); // 泛型约束 - let extend_type = get_tpl_ref_extend_type(builder, &tpl_id).unwrap_or(LuaType::Any); + let extend_type = str_tpl.get_constraint().cloned().unwrap_or(LuaType::Any); let mut completion_items: Vec<_> = types .into_iter() @@ -863,16 +859,6 @@ fn add_str_tpl_ref_completion( Some(()) } -fn add_const_tpl_ref_completion( - builder: &mut CompletionBuilder, - tpl_id: &GenericTplId, - infer_guard: &InferGuardRef, -) -> Option { - // 泛型约束 - let extend_type = get_tpl_ref_extend_type(builder, tpl_id)?; - dispatch_type(builder, extend_type, infer_guard) -} - fn add_special_call_completion( builder: &mut CompletionBuilder, alias_call: &LuaAliasCallType, @@ -896,36 +882,6 @@ fn add_special_call_completion( Some(()) } -fn get_tpl_ref_extend_type(builder: &CompletionBuilder, tpl_id: &GenericTplId) -> Option { - let token = builder.trigger_token.clone(); - let mut parent_node = token.parent()?; - if LuaLiteralExpr::can_cast(parent_node.kind().into()) { - parent_node = parent_node.parent()?; - } - match parent_node.kind().into() { - LuaSyntaxKind::CallArgList => { - let call_expr = LuaCallArgList::cast(parent_node)?.get_parent::()?; - let function = builder - .semantic_model - .infer_expr(call_expr.get_prefix_expr()?.clone()) - .ok()?; - if let LuaType::Signature(signature_id) = function { - let signature = builder - .semantic_model - .get_db() - .get_signature_index() - .get(&signature_id)?; - let generic_param = signature.generic_params.get(tpl_id.get_idx()); - if let Some(generic_param) = generic_param { - return Some(generic_param.constraint.clone().unwrap_or(LuaType::Any)); - } - } - None - } - _ => None, - } -} - /// 确保所有成员均为 function 或者 nil, 然后返回 function 的联合类型, 如果非 function 则返回 None pub fn get_function_remove_nil(db: &DbIndex, typ: &LuaType) -> Option { match typ { @@ -964,9 +920,9 @@ pub fn get_function_remove_nil(db: &DbIndex, typ: &LuaType) -> Option { fn add_tpl_ref_completion( builder: &mut CompletionBuilder, - tpl_id: &GenericTplId, + tpl: &GenericTpl, infer_guard: &InferGuardRef, ) -> Option { - let extend_type = get_tpl_ref_extend_type(builder, tpl_id)?; + let extend_type = tpl.get_constraint().cloned()?; dispatch_type(builder, extend_type, infer_guard) } diff --git a/crates/emmylua_ls/src/handlers/completion/providers/member_provider.rs b/crates/emmylua_ls/src/handlers/completion/providers/member_provider.rs index 373134ddb..85a3dd299 100644 --- a/crates/emmylua_ls/src/handlers/completion/providers/member_provider.rs +++ b/crates/emmylua_ls/src/handlers/completion/providers/member_provider.rs @@ -1,6 +1,6 @@ use emmylua_code_analysis::{ DbIndex, LuaMemberInfo, LuaMemberKey, LuaSemanticDeclId, LuaType, LuaTypeDeclId, SemanticModel, - enum_variable_is_param, get_tpl_ref_extend_type, + enum_variable_is_param, }; use emmylua_parser::{LuaAstNode, LuaAstToken, LuaIndexExpr, LuaStringToken, LuaSyntaxToken}; use std::collections::HashMap; @@ -59,13 +59,7 @@ fn complete_provider(builder: &mut CompletionBuilder) -> Option<()> { .infer_expr(prefix_expr.clone()) .ok()? { - LuaType::TplRef(tpl) => get_tpl_ref_extend_type( - builder.semantic_model.get_db(), - &mut builder.semantic_model.get_cache().borrow_mut(), - &LuaType::TplRef(tpl.clone()), - prefix_expr.clone(), - 0, - )?, + LuaType::TplRef(tpl) => tpl.get_constraint().cloned()?, prefix_type => prefix_type, }; // 如果是枚举类型且为函数参数, 则不进行补全 diff --git a/crates/emmylua_ls/src/handlers/definition/goto_function.rs b/crates/emmylua_ls/src/handlers/definition/goto_function.rs index f60d9b395..adf505c09 100644 --- a/crates/emmylua_ls/src/handlers/definition/goto_function.rs +++ b/crates/emmylua_ls/src/handlers/definition/goto_function.rs @@ -1,6 +1,6 @@ use emmylua_code_analysis::{ LuaCompilation, LuaDeclId, LuaFunctionType, LuaSemanticDeclId, LuaSignature, LuaSignatureId, - LuaType, SemanticDeclLevel, SemanticModel, instantiate_func_generic, + LuaType, SemanticDeclLevel, SemanticModel, infer_call_generic, }; use emmylua_parser::{ LuaAstNode, LuaCallExpr, LuaExpr, LuaLiteralToken, LuaSyntaxToken, LuaTokenKind, @@ -291,7 +291,7 @@ pub fn compare_function_types( call_expr: &LuaCallExpr, ) -> Option { if func.contain_tpl() { - let instantiated_func = instantiate_func_generic( + let instantiated_func = infer_call_generic( semantic_model.get_db(), &mut semantic_model.get_cache().borrow_mut(), func, diff --git a/crates/emmylua_ls/src/handlers/hover/function/mod.rs b/crates/emmylua_ls/src/handlers/hover/function/mod.rs index 2402e078e..d4be62913 100644 --- a/crates/emmylua_ls/src/handlers/hover/function/mod.rs +++ b/crates/emmylua_ls/src/handlers/hover/function/mod.rs @@ -3,8 +3,8 @@ use std::{collections::HashSet, sync::Arc, vec}; use emmylua_code_analysis::{ AsyncState, DbIndex, InferGuard, LuaDocReturnInfo, LuaDocReturnOverloadInfo, LuaFunctionType, LuaMember, LuaMemberOwner, LuaSemanticDeclId, LuaSignature, LuaType, RenderLevel, - TypeSubstitutor, VariadicType, humanize_type, infer_call_expr_func, instantiate_doc_function, - instantiate_func_generic, try_extract_signature_id_from_field, + TypeSubstitutor, VariadicType, humanize_type, infer_call_expr_func, infer_call_generic, + instantiate_type_generic, try_extract_signature_id_from_field, }; use crate::handlers::hover::{ @@ -88,7 +88,7 @@ fn build_function_call_hover( let function_member = match match_semantic_decl { LuaSemanticDeclId::Member(id) => { - let member = db.get_member_index().get_member(&id)?; + let member = db.get_member_index().get_member(id)?; Some(member) } _ => None, @@ -103,8 +103,9 @@ fn build_function_call_hover( signature.is_vararg, signature.get_type_params(), signature.get_return_type(), + Some(signature.get_function_generic_params()), ); - let instantiated_signature = instantiate_func_generic( + let instantiated_signature = infer_call_generic( db, &mut builder.semantic_model.get_cache().borrow_mut(), &base_function, @@ -182,7 +183,7 @@ fn build_function_define_hover( let mut typ = typ.clone(); let function_member = match semantic_decl_id { LuaSemanticDeclId::Member(id) => { - let member = db.get_member_index().get_member(&id)?; + let member = db.get_member_index().get_member(id)?; Some(member) } _ => None, @@ -284,6 +285,7 @@ fn process_function_type( signature.is_vararg, signature.get_type_params(), signature.get_return_type(), + Some(signature.get_function_generic_params()), )); new_overloads.insert(0, fake_doc_function.clone()); let mut contents = Vec::with_capacity(new_overloads.len()); @@ -484,9 +486,10 @@ fn instantiate_call_return_overloads( signature.is_vararg, signature.get_type_params(), row_return_type, + Some(signature.get_function_generic_params()), ); let instantiated_row = - instantiate_func_generic(db, &mut cache, &row_function, call_expr.clone()) + infer_call_generic(db, &mut cache, &row_function, call_expr.clone()) .ok() .map(|func| match func.get_ret() { LuaType::Variadic(variadic) => match variadic.as_ref() { @@ -504,7 +507,6 @@ fn instantiate_call_return_overloads( }) .collect() } - fn convert_function_return_to_docs(func: &LuaFunctionType) -> Vec { match func.get_ret() { LuaType::Variadic(variadic) => match variadic.as_ref() { @@ -702,8 +704,8 @@ fn hover_instantiate_function_type( return None; } match typ { - LuaType::DocFunction(f) => { - if let LuaType::DocFunction(f) = instantiate_doc_function(db, f, substitutor) { + LuaType::DocFunction(_) => { + if let LuaType::DocFunction(f) = instantiate_type_generic(db, typ, substitutor) { Some(f) } else { None diff --git a/crates/emmylua_ls/src/handlers/test/completion_test.rs b/crates/emmylua_ls/src/handlers/test/completion_test.rs index 5ee0ad588..f90464470 100644 --- a/crates/emmylua_ls/src/handlers/test/completion_test.rs +++ b/crates/emmylua_ls/src/handlers/test/completion_test.rs @@ -2322,8 +2322,6 @@ mod tests { r#" ---@alias std.RawGet unknown - ---@alias std.ConstTpl unknown - ---@generic T, K extends keyof T ---@param object T ---@param key K diff --git a/crates/emmylua_parser/locales/app.yml b/crates/emmylua_parser/locales/app.yml index 4be4e230b..dc4086e89 100644 --- a/crates/emmylua_parser/locales/app.yml +++ b/crates/emmylua_parser/locales/app.yml @@ -555,3 +555,7 @@ unfinished long comment: en: unfinished long comment zh_CN: 未完成的长注释 zh_HK: 未完成的長註釋 +Identifier expected. '%{reserved}' is a reserved word that cannot be used here.: + en: Identifier expected. '%{reserved}' is a reserved word that cannot be used here. + zh_CN: 应为标识符。'%{reserved}' 是保留字,不能在此处使用。 + zh_HK: 應為標識符。'%{reserved}' 是保留字,不能在此處使用。 diff --git a/crates/emmylua_parser/src/grammar/doc/tag.rs b/crates/emmylua_parser/src/grammar/doc/tag.rs index 563f5acf2..1fab60933 100644 --- a/crates/emmylua_parser/src/grammar/doc/tag.rs +++ b/crates/emmylua_parser/src/grammar/doc/tag.rs @@ -151,6 +151,7 @@ pub(super) fn parse_generic_decl_list( // A = type fn parse_generic_param(p: &mut LuaDocParser) -> DocParseResult { let m = p.mark(LuaSyntaxKind::DocGenericParameter); + parse_generic_modifier(p)?; expect_token(p, LuaTokenKind::TkName)?; if p.current_token() == LuaTokenKind::TkDots { p.bump(); @@ -169,6 +170,25 @@ fn parse_generic_param(p: &mut LuaDocParser) -> DocParseResult { Ok(m.complete(p)) } +fn parse_generic_modifier(p: &mut LuaDocParser) -> Result<(), LuaParseError> { + if p.current_token() == LuaTokenKind::TkName && p.current_token_text() == "const" { + let range = p.current_token_range(); + p.set_current_token_kind(LuaTokenKind::TkDocConst); + p.bump(); + if p.current_token() != LuaTokenKind::TkName { + return Err(LuaParseError::doc_error_from( + &t!( + "Identifier expected. '%{reserved}' is a reserved word that cannot be used here.", + reserved = "const" + ), + range, + )); + } + } + + Ok(()) +} + // ---@enum A // ---@enum A : number fn parse_tag_enum(p: &mut LuaDocParser) -> DocParseResult { diff --git a/crates/emmylua_parser/src/grammar/doc/test.rs b/crates/emmylua_parser/src/grammar/doc/test.rs index 12ab4db84..d282b1ae5 100644 --- a/crates/emmylua_parser/src/grammar/doc/test.rs +++ b/crates/emmylua_parser/src/grammar/doc/test.rs @@ -1,6 +1,6 @@ #[cfg(test)] mod tests { - use crate::{LuaParser, parser::ParserConfig}; + use crate::{LuaParseErrorKind, LuaParser, parser::ParserConfig}; macro_rules! assert_ast_eq { ($lua_code:expr, $expected:expr) => { @@ -1055,6 +1055,68 @@ Syntax(Chunk)@0..92 assert_ast_eq!(code, result); } + #[test] + fn test_generic_const_modifier_doc() { + let code = "---@class A\n---@generic const R\n---@alias B const\n"; + + let result = r#" +Syntax(Chunk)@0..62 + Syntax(Block)@0..62 + Syntax(Comment)@0..61 + Token(TkDocStart)@0..4 "---@" + Syntax(DocTagClass)@4..20 + Token(TkTagClass)@4..9 "class" + Token(TkWhitespace)@9..10 " " + Token(TkName)@10..11 "A" + Syntax(DocGenericDeclareList)@11..20 + Token(TkLt)@11..12 "<" + Syntax(DocGenericParameter)@12..19 + Token(TkDocConst)@12..17 "const" + Token(TkWhitespace)@17..18 " " + Token(TkName)@18..19 "T" + Token(TkGt)@19..20 ">" + Token(TkEndOfLine)@20..21 "\n" + Token(TkDocStart)@21..25 "---@" + Syntax(DocTagGeneric)@25..40 + Token(TkTagGeneric)@25..32 "generic" + Token(TkWhitespace)@32..33 " " + Syntax(DocGenericDeclareList)@33..40 + Syntax(DocGenericParameter)@33..40 + Token(TkDocConst)@33..38 "const" + Token(TkWhitespace)@38..39 " " + Token(TkName)@39..40 "R" + Token(TkEndOfLine)@40..41 "\n" + Token(TkDocStart)@41..45 "---@" + Syntax(DocTagAlias)@45..61 + Token(TkTagAlias)@45..50 "alias" + Token(TkWhitespace)@50..51 " " + Token(TkName)@51..52 "B" + Syntax(DocGenericDeclareList)@52..55 + Token(TkLt)@52..53 "<" + Syntax(DocGenericParameter)@53..54 + Token(TkName)@53..54 "T" + Token(TkGt)@54..55 ">" + Token(TkWhitespace)@55..56 " " + Syntax(TypeName)@56..61 + Token(TkName)@56..61 "const" + Token(TkEndOfLine)@61..62 "\n" + "#; + + assert_ast_eq!(code, result); + } + + #[test] + fn test_generic_const_modifier_requires_identifier() { + let tree = LuaParser::parse("---@class A\n", ParserConfig::default()); + let errors = tree.get_errors(); + + assert!(errors.iter().any(|error| { + error.kind == LuaParseErrorKind::DocError + && error.message + == "Identifier expected. 'const' is a reserved word that cannot be used here." + })); + } + #[test] fn test_diagnostic_doc() { let code = r#" diff --git a/crates/emmylua_parser/src/kind/lua_token_kind.rs b/crates/emmylua_parser/src/kind/lua_token_kind.rs index e0a18a35c..8c96b63b7 100644 --- a/crates/emmylua_parser/src/kind/lua_token_kind.rs +++ b/crates/emmylua_parser/src/kind/lua_token_kind.rs @@ -147,6 +147,7 @@ pub enum LuaTokenKind { TkDocAs, // as TkDocIn, // in TkDocInfer, // infer + TkDocConst, // const TkDocElse, // else (for return_cast) TkDocContinue, // --- TkDocContinueOr, // ---| or ---|+ or ---|> diff --git a/crates/emmylua_parser/src/syntax/node/doc/mod.rs b/crates/emmylua_parser/src/syntax/node/doc/mod.rs index 04a5b43ad..a159b1bfe 100644 --- a/crates/emmylua_parser/src/syntax/node/doc/mod.rs +++ b/crates/emmylua_parser/src/syntax/node/doc/mod.rs @@ -365,6 +365,10 @@ impl LuaDocGenericDecl { pub fn is_variadic(&self) -> bool { self.token_by_kind(LuaTokenKind::TkDots).is_some() } + + pub fn has_const_modifier(&self) -> bool { + self.token_by_kind(LuaTokenKind::TkDocConst).is_some() + } } #[derive(Debug, Clone, PartialEq, Eq, Hash)] diff --git a/crates/emmylua_parser/src/syntax/node/doc/test.rs b/crates/emmylua_parser/src/syntax/node/doc/test.rs index 6854d7103..454e33526 100644 --- a/crates/emmylua_parser/src/syntax/node/doc/test.rs +++ b/crates/emmylua_parser/src/syntax/node/doc/test.rs @@ -206,4 +206,21 @@ mod test { "string" ); } + + #[test] + fn test_doc_generic_const_modifier_accessor() { + let tree = LuaParser::parse("---@class A\n", ParserConfig::default()); + let root = tree.get_chunk_node(); + let class = root.descendants::().next().unwrap(); + let generic_decl = class.get_generic_decl().unwrap(); + let mut params = generic_decl.get_generic_decl(); + + let const_param = params.next().unwrap(); + assert!(const_param.has_const_modifier()); + assert_eq!(const_param.get_name_token().unwrap().get_name_text(), "T"); + + let regular_param = params.next().unwrap(); + assert!(!regular_param.has_const_modifier()); + assert_eq!(regular_param.get_name_token().unwrap().get_name_text(), "U"); + } }