diff --git a/crates/hir/src/display.rs b/crates/hir/src/display.rs index df0d0c43..34b5a52a 100644 --- a/crates/hir/src/display.rs +++ b/crates/hir/src/display.rs @@ -106,9 +106,6 @@ impl HirDisplay for InContainer { match self.value { DataTy::Builtin(ty_id) => match ty_id.lookup(f.db) { BuiltinDataTy::Int { kind, signing } => { - if signing { - f.write_str("signed ")?; - } match kind { IntKind::Byte => f.write_str("byte"), IntKind::ShortInt => f.write_str("shortint"), @@ -116,27 +113,45 @@ impl HirDisplay for InContainer { IntKind::LongInt => f.write_str("longint"), IntKind::Integer => f.write_str("integer"), IntKind::Time => f.write_str("time"), + }?; + if signing { + f.write_str(" signed")?; } + Ok(()) } BuiltinDataTy::Vector { kind, signing, dimensions } => { - if signing { - f.write_str("signed ")?; - } + let mut wrote_head = false; match kind { VecKind::Bit => { if !f.simplified_ty { - f.write_str("bit")? + f.write_str("bit")?; + wrote_head = true; } } VecKind::Logic => { if !f.simplified_ty { - f.write_str("logic")? + f.write_str("logic")?; + wrote_head = true; } } - VecKind::Reg => f.write_str("reg")?, + VecKind::Reg => { + f.write_str("reg")?; + wrote_head = true; + } + } + if signing { + if wrote_head { + f.write_str(" ")?; + } + f.write_str("signed")?; + wrote_head = true; } for dim in dimensions.iter().flatten() { + if wrote_head { + f.write_str(" ")?; + } self.with_value(*dim).hir_fmt(f)?; + wrote_head = true; } Ok(()) } diff --git a/crates/hir/src/hir_def/aggregate.rs b/crates/hir/src/hir_def/aggregate.rs index 8b4b9250..f12acfec 100644 --- a/crates/hir/src/hir_def/aggregate.rs +++ b/crates/hir/src/hir_def/aggregate.rs @@ -75,6 +75,7 @@ pub(crate) fn lower_struct_def( #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct StructSrc { pub node: SyntaxNodePtr, + pub name: Option, } impl IsSrc for StructSrc { @@ -89,6 +90,18 @@ impl IsSrc for StructSrc { } } +impl IsNamedSrc for StructSrc { + #[inline] + fn name_kind(&self) -> Option { + self.name.map(|name| name.kind()) + } + + #[inline] + fn name_range(&self) -> Option { + self.name.map(|name| name.range()) + } +} + impl<'a> ToAstNode<'a, ast::StructUnionType<'a>> for StructSrc { fn to_node(&self, tree: &'a syntax::SyntaxTree) -> Option> { let mut node = self.node.to_node(tree)?; @@ -101,16 +114,28 @@ impl<'a> ToAstNode<'a, ast::StructUnionType<'a>> for StructSrc { impl From> for StructSrc { fn from(node: ast::StructUnionType<'_>) -> Self { - StructSrc { node: AstNodeExt::to_ptr(&node) } + let syntax = node.syntax(); + let name = struct_name_token(node).map(|name| SyntaxTokenPtr::from_token_in(syntax, name)); + StructSrc { node: AstNodeExt::to_ptr(&node), name } } } impl<'a> FromSourceAst<'a, ast::StructUnionType<'a>> for StructSrc { fn from_source_ast(node: SourceAst>) -> Self { - StructSrc { node: AstNodeExt::to_ptr(&node.into_inner()) } + let node = node.into_inner(); + let syntax = node.syntax(); + let name = struct_name_token(node) + .and_then(|name| root_token_in(syntax, name).map(SyntaxTokenPtr::from_token)); + StructSrc { node: AstNodeExt::to_ptr(&node), name } } } +fn struct_name_token(node: ast::StructUnionType<'_>) -> Option> { + let data_type = ast::DataType::StructUnionType(node); + let typedef = data_type.syntax().parent().and_then(ast::TypedefDeclaration::cast)?; + typedef.name() +} + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum ClassMemberKind { Property, diff --git a/crates/hir/src/hir_def/expr/data_ty.rs b/crates/hir/src/hir_def/expr/data_ty.rs index e4ce7250..9d420b67 100644 --- a/crates/hir/src/hir_def/expr/data_ty.rs +++ b/crates/hir/src/hir_def/expr/data_ty.rs @@ -140,8 +140,7 @@ impl LowerExprCtx<'_> { LogicType(_) => Either::Right(VecKind::Logic), }; - let signing = Self::lower_signing(ty.signing()) - .unwrap_or(matches!(kind, Either::Left(IntKind::Time) | Either::Right(_))); + let signing = Self::lower_signing(ty.signing()).unwrap_or(matches!(kind, Either::Left(_))); let dimensions = ty.dimensions().children().map(|dim| self.lower_dimension(dim)).collect(); match kind { diff --git a/crates/hir/src/hir_def/subroutine.rs b/crates/hir/src/hir_def/subroutine.rs index e833b25c..6c1c2e7a 100644 --- a/crates/hir/src/hir_def/subroutine.rs +++ b/crates/hir/src/hir_def/subroutine.rs @@ -24,7 +24,7 @@ use super::{ impl_lower_expr, timing_control::{EventExpr, EventExprSrc}, }, - lower_ident, lower_ident_opt, + lower_ident_opt, module::{ModuleId, generate::GenerateBlockId}, stmt::{LowerStmt, Stmt, StmtId, StmtSrc, impl_lower_stmt}, typedef::{Typedef, TypedefId, TypedefSrc, lower_typedef_data_ty}, @@ -236,10 +236,10 @@ where fn lower_name(name: ast::Name) -> Option { if let Some(id) = name.as_identifier_name().and_then(|n| n.identifier()) { - return lower_ident(Some(id)); + return lower_ident_opt(Some(id)); } if let Some(select) = name.as_identifier_select_name() { - return select.identifier().and_then(|tok| lower_ident(Some(tok))); + return select.identifier().and_then(|tok| lower_ident_opt(Some(tok))); } if let Some(scoped) = name.as_scoped_name() { return lower_name(scoped.right()); diff --git a/crates/ide/src/code_action/context.rs b/crates/ide/src/code_action/context.rs index 24ac41e6..e72ebe37 100644 --- a/crates/ide/src/code_action/context.rs +++ b/crates/ide/src/code_action/context.rs @@ -26,6 +26,7 @@ impl<'a> CodeActionCtx<'a> { ) -> Option { let parsed_file = sema.parse_file(file_id); parsed_file.compilation_unit()?; + Some(Self { sema, file_id, range, diagnostics, parsed_file }) } diff --git a/crates/ide/src/code_action/handlers.rs b/crates/ide/src/code_action/handlers.rs index 3289147b..33390b07 100644 --- a/crates/ide/src/code_action/handlers.rs +++ b/crates/ide/src/code_action/handlers.rs @@ -8,13 +8,21 @@ mod add_instance_parens; mod add_missing_connections; mod add_missing_parameters; mod apply_de_morgan; +mod convert_always_block; mod convert_literal_base; +mod convert_named_port_connections; mod convert_ordered_connections; +mod convert_port_declarations; mod expand_compound_assignment; mod expand_postfix_inc_dec; +mod extract_variable; mod insert_expected_token; mod invert_if_else; +mod merge_nested_if; +mod pull_assignment_up; +mod reformat_number_literal; mod remove_empty_port_connections; +mod remove_parentheses; mod sort_named_instantiation_items; mod split_declaration_declarators; mod wrap_statement_in_begin_end; @@ -22,22 +30,31 @@ mod wrap_statement_in_begin_end; pub(crate) fn all() -> &'static [Handler] { &[ convert_literal_base::convert_literal_base, + reformat_number_literal::reformat_number_literal, add_missing_connections::add_missing_connections, add_missing_parameters::add_missing_parameters, convert_ordered_connections::convert_ordered_ports, convert_ordered_connections::convert_ordered_params, + convert_named_port_connections::convert_named_port_connection_shorthand, remove_empty_port_connections::remove_empty_port_connections, add_implicit_named_port_parens::add_implicit_named_port_parens, add_instance_parens::add_instance_parens, + convert_always_block::convert_always_block, + convert_port_declarations::convert_port_declarations, split_declaration_declarators::split_declaration_declarators, sort_named_instantiation_items::sort_named_parameter_assignments, sort_named_instantiation_items::sort_named_port_connections, add_default_case_item::add_default_case_item, invert_if_else::invert_if_else, + merge_nested_if::merge_nested_if, wrap_statement_in_begin_end::unwrap_single_statement_block, wrap_statement_in_begin_end::wrap_statement_in_begin_end, + remove_parentheses::remove_parentheses, expand_postfix_inc_dec::expand_postfix_inc_dec, expand_compound_assignment::expand_compound_assignment, + extract_variable::extract_variable, + pull_assignment_up::pull_assignment_up, + pull_assignment_up::pull_assignment_down, apply_de_morgan::apply_de_morgan, insert_expected_token::insert_expected_token, ] diff --git a/crates/ide/src/code_action/handlers/add_default_case_item.rs b/crates/ide/src/code_action/handlers/add_default_case_item.rs index 3e635346..928063ef 100644 --- a/crates/ide/src/code_action/handlers/add_default_case_item.rs +++ b/crates/ide/src/code_action/handlers/add_default_case_item.rs @@ -12,6 +12,18 @@ const ID: CodeActionId = CodeActionId { name: "add_default_case_item", kind: CodeActionKind::Generate, repair: None }; const LABEL: &str = "Add default case item"; +// Assist: add_default_case_item +// +// This adds a `default` item to a case statement that does not already have +// one. +// +// ``` +// module top; always_comb case$0 (sel) 1'b0: y = a; endcase endmodule +// ``` +// -> +// ``` +// module top; always_comb case (sel) 1'b0: y = a; default: ; endcase endmodule +// ``` pub(super) fn add_default_case_item( collector: &mut CodeActionCollector, ctx: &CodeActionCtx, diff --git a/crates/ide/src/code_action/handlers/add_implicit_named_port_parens.rs b/crates/ide/src/code_action/handlers/add_implicit_named_port_parens.rs index 31ec47aa..660ccae2 100644 --- a/crates/ide/src/code_action/handlers/add_implicit_named_port_parens.rs +++ b/crates/ide/src/code_action/handlers/add_implicit_named_port_parens.rs @@ -14,6 +14,18 @@ const ID: CodeActionId = CodeActionId { }; const LABEL: &str = "Add explicit empty port connection"; +// Assist: add_implicit_named_port_parens +// +// This makes an implicit named port connection explicit by adding empty +// parentheses. +// +// ``` +// child u(.ready$0); +// ``` +// -> +// ``` +// child u(.ready()); +// ``` pub(super) fn add_implicit_named_port_parens( collector: &mut CodeActionCollector, ctx: &CodeActionCtx, diff --git a/crates/ide/src/code_action/handlers/add_instance_parens.rs b/crates/ide/src/code_action/handlers/add_instance_parens.rs index a25aeee8..56de01d1 100644 --- a/crates/ide/src/code_action/handlers/add_instance_parens.rs +++ b/crates/ide/src/code_action/handlers/add_instance_parens.rs @@ -11,6 +11,17 @@ const ID: CodeActionId = CodeActionId { }; const LABEL: &str = "Add empty instance port list"; +// Assist: add_instance_parens +// +// This adds an empty port list to an instance that is missing one. +// +// ``` +// child u$0; +// ``` +// -> +// ``` +// child u(); +// ``` pub(super) fn add_instance_parens( collector: &mut CodeActionCollector, ctx: &CodeActionCtx, diff --git a/crates/ide/src/code_action/handlers/add_missing_connections.rs b/crates/ide/src/code_action/handlers/add_missing_connections.rs index 4bc703f4..607db3f6 100644 --- a/crates/ide/src/code_action/handlers/add_missing_connections.rs +++ b/crates/ide/src/code_action/handlers/add_missing_connections.rs @@ -1,13 +1,13 @@ use hir::{ base_db::source_db::SourceDb, container::InModule, db::HirDb, - hir_def::module::instantiation::PortConn, + hir_def::module::instantiation::PortConn, source_map::IsSrc, }; use rustc_hash::FxHashSet; use syntax::{ ast::{self, AstNode}, - has_text_range::{HasTextRange, HasTextRangeIn}, + has_text_range::HasTextRangeIn, }; -use utils::get::GetRef; +use utils::get::{Get, GetRef}; use crate::{ code_action::{ @@ -15,7 +15,7 @@ use crate::{ apply_missing_list_edit, missing_member_entry_text, port_names, remaining_ordered_port_names, }, - module_resolution::resolve_instantiation_target, + module_resolution::resolve_hir_instantiation_target, }; const ID: CodeActionId = CodeActionId { @@ -25,6 +25,18 @@ const ID: CodeActionId = CodeActionId { }; const LABEL: &str = "Fill connections"; +// Assist: add_missing_connections +// +// This fills the missing port connections for an instance from the target +// module definition. +// +// ``` +// child u($0.a(a)); +// ``` +// -> +// ``` +// child u(.a(a), .b('0)); +// ``` pub(super) fn add_missing_connections( collector: &mut CodeActionCollector, ctx: &CodeActionCtx, @@ -36,14 +48,13 @@ pub(super) fn add_missing_connections( let ast_instance = ctx.find_node_at_offset::()?; let InModule { value: instance_id, module_id } = sema.resolve_instance(file_id, ast_instance)?; - let module = db.module(module_id); + let (module, module_src_map) = db.module_with_source_map(module_id); let instance = module.get(instance_id); let open_paren = ast_instance.open_paren()?.text_range_in(ast_instance.syntax())?; let close_paren = ast_instance.close_paren()?.text_range_in(ast_instance.syntax())?; - let instantiation = ast::HierarchyInstantiation::cast(ast_instance.syntax().parent()?)?; - let target_module_id = - resolve_instantiation_target(db, ctx.file_id(), instantiation).unique()?; + let instantiation = module.get(instance.parent); + let target_module_id = resolve_hir_instantiation_target(db, ctx.file_id(), instantiation)?; let target_module = db.module(target_module_id); let is_ordered = instance @@ -83,8 +94,8 @@ pub(super) fn add_missing_connections( .collect(); let text = sema.db.file_text(ctx.file_id()); - let item_ranges = ast_instance.connections().children().filter_map(|conn| { - let range = conn.syntax().text_range()?; + let item_ranges = instance.connections.iter().filter_map(|conn_id| { + let range = module_src_map.get(*conn_id)?.range(); (!range.is_empty()).then_some(range) }); apply_missing_list_edit(builder, &text, open_paren, close_paren, item_ranges, entries); diff --git a/crates/ide/src/code_action/handlers/add_missing_parameters.rs b/crates/ide/src/code_action/handlers/add_missing_parameters.rs index f0dee54f..b9f466d7 100644 --- a/crates/ide/src/code_action/handlers/add_missing_parameters.rs +++ b/crates/ide/src/code_action/handlers/add_missing_parameters.rs @@ -1,13 +1,13 @@ use hir::{ base_db::source_db::SourceDb, container::InModule, db::HirDb, - hir_def::module::instantiation::ParamAssign, + hir_def::module::instantiation::ParamAssign, source_map::IsSrc, }; use rustc_hash::FxHashSet; use syntax::{ ast::{self, AstNode}, - has_text_range::{HasTextRange, HasTextRangeIn}, + has_text_range::HasTextRangeIn, }; -use utils::get::GetRef; +use utils::get::{Get, GetRef}; use crate::{ code_action::{ @@ -15,7 +15,7 @@ use crate::{ all_parameter_names, apply_missing_list_edit, leading_parameter_names, missing_member_entry_text, }, - module_resolution::resolve_instantiation_target, + module_resolution::resolve_hir_instantiation_target, }; const ID: CodeActionId = CodeActionId { @@ -25,6 +25,18 @@ const ID: CodeActionId = CodeActionId { }; const LABEL: &str = "Fill parameters"; +// Assist: add_missing_parameters +// +// This fills the missing parameter assignments for an instantiation from the +// target module definition. +// +// ``` +// child #($0.A(1)) u(); +// ``` +// -> +// ``` +// child #(.A(1), .B(0)) u(); +// ``` pub(super) fn add_missing_parameters( collector: &mut CodeActionCollector, ctx: &CodeActionCtx, @@ -36,15 +48,14 @@ pub(super) fn add_missing_parameters( let ast_instantiation = ctx.find_node_at_offset::()?; let InModule { value: instantiation_id, module_id } = sema.resolve_instantiation(file_id, ast_instantiation)?; - let module = db.module(module_id); + let (module, module_src_map) = db.module_with_source_map(module_id); let instantiation = module.get(instantiation_id); let params_node = ast_instantiation.parameters()?; let open_paren = params_node.open_paren()?.text_range_in(params_node.syntax())?; let close_paren = params_node.close_paren()?.text_range_in(params_node.syntax())?; - let target_module_id = - resolve_instantiation_target(db, ctx.file_id(), ast_instantiation).unique()?; + let target_module_id = resolve_hir_instantiation_target(db, ctx.file_id(), instantiation)?; let target_module = db.module(target_module_id); let is_ordered = instantiation @@ -87,8 +98,8 @@ pub(super) fn add_missing_parameters( .collect(); let text = sema.db.file_text(ctx.file_id()); - let item_ranges = params_node.parameters().children().filter_map(|assign| { - let range = assign.syntax().text_range()?; + let item_ranges = instantiation.param_assigns.iter().filter_map(|assign_id| { + let range = module_src_map.get(*assign_id)?.range(); (!range.is_empty()).then_some(range) }); apply_missing_list_edit(builder, &text, open_paren, close_paren, item_ranges, entries); diff --git a/crates/ide/src/code_action/handlers/apply_de_morgan.rs b/crates/ide/src/code_action/handlers/apply_de_morgan.rs index 465edb75..5f5d8ff3 100644 --- a/crates/ide/src/code_action/handlers/apply_de_morgan.rs +++ b/crates/ide/src/code_action/handlers/apply_de_morgan.rs @@ -16,6 +16,18 @@ const FACTOR_ID: CodeActionId = CodeActionId { name: "factor_de_morgan", kind: CodeActionKind::RefactorRewrite, repair: None }; const FACTOR_LABEL: &str = "Factor De Morgan's law"; +// Assist: apply_de_morgan +// +// This applies or factors De Morgan's law for boolean expressions and if +// conditions. +// +// ``` +// assign y = $0!(a && b); +// ``` +// -> +// ``` +// assign y = !a || !b; +// ``` pub(super) fn apply_de_morgan( collector: &mut CodeActionCollector, ctx: &CodeActionCtx, diff --git a/crates/ide/src/code_action/handlers/convert_always_block.rs b/crates/ide/src/code_action/handlers/convert_always_block.rs new file mode 100644 index 00000000..36a713c4 --- /dev/null +++ b/crates/ide/src/code_action/handlers/convert_always_block.rs @@ -0,0 +1,153 @@ +use std::ops::Range; + +use hir::base_db::source_db::SourceDb; +use syntax::{ + ast::{self, AstNode}, + has_text_range::{HasTextRange, HasTextRangeIn}, +}; + +use crate::code_action::{CodeActionCollector, CodeActionCtx, CodeActionId, CodeActionKind}; + +const ALWAYS_TO_COMB_ID: CodeActionId = CodeActionId { + name: "convert_always_to_always_comb", + kind: CodeActionKind::RefactorRewrite, + repair: None, +}; +const ALWAYS_TO_COMB_LABEL: &str = "Convert to always_comb"; + +const ALWAYS_TO_FF_ID: CodeActionId = CodeActionId { + name: "convert_always_to_always_ff", + kind: CodeActionKind::RefactorRewrite, + repair: None, +}; +const ALWAYS_TO_FF_LABEL: &str = "Convert to always_ff"; + +const ALWAYS_COMB_TO_ALWAYS_ID: CodeActionId = CodeActionId { + name: "convert_always_comb_to_always", + kind: CodeActionKind::RefactorRewrite, + repair: None, +}; +const ALWAYS_COMB_TO_ALWAYS_LABEL: &str = "Convert to always @(*)"; + +const ALWAYS_FF_TO_ALWAYS_ID: CodeActionId = CodeActionId { + name: "convert_always_ff_to_always", + kind: CodeActionKind::RefactorRewrite, + repair: None, +}; +const ALWAYS_FF_TO_ALWAYS_LABEL: &str = "Convert to always @(...)"; + +// Assist: convert_always_block +// +// This converts compatible procedural blocks between `always`, `always_comb`, +// and `always_ff`. +// +// ``` +// always$0 @(*) begin y = a; end +// ``` +// -> +// ``` +// always_comb begin y = a; end +// ``` +pub(super) fn convert_always_block( + collector: &mut CodeActionCollector, + ctx: &CodeActionCtx, +) -> Option<()> { + let proc = ctx.find_node_at_offset::()?; + let keyword = proc.keyword()?.text_range_in(proc.syntax())?; + let target = proc.syntax().text_range()?; + let mut allowed_ranges = vec![keyword]; + + match proc { + ast::ProceduralBlock::AlwaysBlock(_) => { + let timing_stmt = proc.statement().as_timing_control_statement()?; + let timing = timing_stmt.timing_control(); + allowed_ranges.push(timing.syntax().text_range()?); + if !range_intersects_any(ctx.range(), &allowed_ranges) { + return None; + } + + if timing.as_implicit_event_control().is_some() { + let stmt_range = timing_stmt.statement().syntax().text_range()?; + let text = ctx.sema().db.file_text(ctx.file_id()); + let stmt_text = text.get(Range::from(stmt_range))?; + collector.add(ALWAYS_TO_COMB_ID, ALWAYS_TO_COMB_LABEL, target, |builder| { + builder.replace(keyword, "always_comb"); + builder + .replace(timing_stmt.syntax().text_range().unwrap(), stmt_text.to_owned()); + }); + } + + if edge_sensitive_timing_control(timing) { + collector.add(ALWAYS_TO_FF_ID, ALWAYS_TO_FF_LABEL, target, |builder| { + builder.replace(keyword, "always_ff"); + }); + } + + Some(()) + } + ast::ProceduralBlock::AlwaysCombBlock(_) => { + if !range_intersects_any(ctx.range(), &allowed_ranges) { + return None; + } + + collector.add( + ALWAYS_COMB_TO_ALWAYS_ID, + ALWAYS_COMB_TO_ALWAYS_LABEL, + target, + |builder| { + builder.replace(keyword, "always"); + builder.insert(keyword.end(), " @(*)"); + }, + ) + } + ast::ProceduralBlock::AlwaysFFBlock(_) => { + let timing_stmt = proc.statement().as_timing_control_statement()?; + allowed_ranges.push(timing_stmt.timing_control().syntax().text_range()?); + if !range_intersects_any(ctx.range(), &allowed_ranges) { + return None; + } + + if !edge_sensitive_timing_control(timing_stmt.timing_control()) { + return None; + } + + collector.add(ALWAYS_FF_TO_ALWAYS_ID, ALWAYS_FF_TO_ALWAYS_LABEL, target, |builder| { + builder.replace(keyword, "always"); + }) + } + _ => None, + } +} + +fn range_intersects_any( + range: utils::text_edit::TextRange, + allowed_ranges: &[utils::text_edit::TextRange], +) -> bool { + allowed_ranges.iter().any(|allowed| range_intersects(range, *allowed)) +} + +fn range_intersects(lhs: utils::text_edit::TextRange, rhs: utils::text_edit::TextRange) -> bool { + if lhs.is_empty() { + rhs.contains(lhs.start()) + } else { + lhs.start() < rhs.end() && rhs.start() < lhs.end() + } +} + +fn edge_sensitive_timing_control(timing: ast::TimingControl<'_>) -> bool { + timing + .as_event_control_with_expression() + .is_some_and(|control| edge_sensitive_event_expr(control.expr())) +} + +fn edge_sensitive_event_expr(expr: ast::EventExpression<'_>) -> bool { + match expr { + ast::EventExpression::ParenthesizedEventExpression(expr) => { + edge_sensitive_event_expr(expr.expr()) + } + ast::EventExpression::BinaryEventExpression(expr) => { + edge_sensitive_event_expr(expr.left()) && edge_sensitive_event_expr(expr.right()) + } + ast::EventExpression::SignalEventExpression(expr) => expr.edge().is_some(), + } +} diff --git a/crates/ide/src/code_action/handlers/convert_literal_base.rs b/crates/ide/src/code_action/handlers/convert_literal_base.rs index 55cc9420..51b442bf 100644 --- a/crates/ide/src/code_action/handlers/convert_literal_base.rs +++ b/crates/ide/src/code_action/handlers/convert_literal_base.rs @@ -13,6 +13,18 @@ const ACTION_ID: CodeActionId = CodeActionId { repair: None, }; +// Assist: convert_literal_base +// +// This converts an integer literal between binary, octal, decimal, and +// hexadecimal notation. +// +// ``` +// localparam int value = 8'h0f$0; +// ``` +// -> +// ``` +// localparam int value = 8'b1111; +// ``` pub(super) fn convert_literal_base( collector: &mut CodeActionCollector, ctx: &CodeActionCtx, diff --git a/crates/ide/src/code_action/handlers/convert_named_port_connections.rs b/crates/ide/src/code_action/handlers/convert_named_port_connections.rs new file mode 100644 index 00000000..18435ae3 --- /dev/null +++ b/crates/ide/src/code_action/handlers/convert_named_port_connections.rs @@ -0,0 +1,131 @@ +use syntax::{ + ast::{self, AstNode}, + has_text_range::{HasTextRange, HasTextRangeIn}, +}; +use utils::text_edit::TextRange; + +use crate::code_action::{CodeActionCollector, CodeActionCtx, CodeActionId, CodeActionKind}; + +const EXPAND_ID: CodeActionId = CodeActionId { + name: "expand_named_port_connection_shorthand", + kind: CodeActionKind::RefactorRewrite, + repair: None, +}; +const EXPAND_LABEL: &str = "Expand named port shorthand"; + +const COLLAPSE_ID: CodeActionId = CodeActionId { + name: "collapse_named_port_connection_shorthand", + kind: CodeActionKind::RefactorRewrite, + repair: None, +}; +const COLLAPSE_LABEL: &str = "Collapse named port to shorthand"; + +// Assist: convert_named_port_connection_shorthand +// +// This expands named port connection shorthand, or collapses same-name +// connections to shorthand. +// +// ``` +// child u(.ready$0); +// ``` +// -> +// ``` +// child u(.ready(ready)); +// ``` +pub(super) fn convert_named_port_connection_shorthand( + collector: &mut CodeActionCollector, + ctx: &CodeActionCtx, +) -> Option<()> { + expand_named_port_connection_shorthand(collector, ctx) + .or(collapse_named_port_connection_shorthand(collector, ctx)) +} + +fn expand_named_port_connection_shorthand( + collector: &mut CodeActionCollector, + ctx: &CodeActionCtx, +) -> Option<()> { + ctx.find_node_at_offset::()?; + let instance = ctx.find_node_at_offset::()?; + let conns = named_port_connections(instance)?; + let edits = conns + .iter() + .filter(|conn| conn.open_paren().is_none()) + .map(|conn| { + let name = conn.name()?; + Some((name.text_range_in(conn.syntax())?.end(), name.value_text().to_string())) + }) + .collect::>>()?; + if edits.is_empty() { + return None; + } + + let target = instance.syntax().text_range()?; + + collector.add(EXPAND_ID, EXPAND_LABEL, target, |builder| { + for (insert_offset, name) in edits { + builder.insert(insert_offset, format!("({name})")); + } + }) +} + +fn collapse_named_port_connection_shorthand( + collector: &mut CodeActionCollector, + ctx: &CodeActionCtx, +) -> Option<()> { + ctx.find_node_at_offset::()?; + let instance = ctx.find_node_at_offset::()?; + let conns = named_port_connections(instance)?; + let edits = conns + .iter() + .filter_map(|conn| collapsible_named_port_connection_range(*conn)) + .collect::>(); + if edits.is_empty() { + return None; + } + + let target = instance.syntax().text_range()?; + collector.add(COLLAPSE_ID, COLLAPSE_LABEL, target, |builder| { + for remove_range in edits { + builder.delete(remove_range); + } + }) +} + +fn named_port_connections( + instance: ast::HierarchicalInstance<'_>, +) -> Option>> { + let conns = instance + .connections() + .children() + .map(|conn| conn.as_named_port_connection()) + .collect::>>()?; + (!conns.is_empty()).then_some(conns) +} + +fn collapsible_named_port_connection_range( + conn: ast::NamedPortConnection<'_>, +) -> Option { + let conn_name = conn.name()?; + let port_name = conn_name.value_text().to_string(); + + let expr = conn.expr()?.as_simple_property_expr()?.expr().as_simple_sequence_expr()?.expr(); + + use ast::{Expression, Name}; + let actual = match expr { + Expression::Name(Name::IdentifierName(ident)) => ident.identifier()?, + Expression::Name(Name::IdentifierSelectName(ident)) + if ident.selectors().children().next().is_none() => + { + ident.identifier()? + } + _ => return None, + }; + if actual.value_text().to_string() != port_name { + return None; + } + + Some(TextRange::new( + conn_name.text_range_in(conn.syntax())?.end(), + conn.close_paren()?.text_range_in(conn.syntax())?.end(), + )) +} diff --git a/crates/ide/src/code_action/handlers/convert_ordered_connections.rs b/crates/ide/src/code_action/handlers/convert_ordered_connections.rs index df2fbb29..52eb0ae3 100644 --- a/crates/ide/src/code_action/handlers/convert_ordered_connections.rs +++ b/crates/ide/src/code_action/handlers/convert_ordered_connections.rs @@ -1,18 +1,22 @@ use std::ops::Range; -use hir::{base_db::source_db::SourceDb, db::HirDb}; -use itertools::Itertools; -use syntax::{ - ast::{self, AstNode}, - has_text_range::HasTextRange, +use hir::{ + base_db::source_db::SourceDb, + container::InModule, + db::HirDb, + hir_def::module::instantiation::{ParamAssign, PortConn}, + source_map::IsSrc, }; +use itertools::Itertools; +use syntax::ast; +use utils::get::{Get, GetRef}; use crate::{ code_action::{ CodeActionCollector, CodeActionCtx, CodeActionId, CodeActionKind, RepairKind, leading_parameter_names, port_names, }, - module_resolution::resolve_instantiation_target, + module_resolution::resolve_hir_instantiation_target, }; const PORTS_ID: CodeActionId = CodeActionId { @@ -29,6 +33,18 @@ const PARAMS_ID: CodeActionId = CodeActionId { }; const PARAMS_LABEL: &str = "Convert ordered parameter assignments to named assignments"; +// Assist: convert_ordered_ports +// +// This converts ordered port connections to named port connections using the +// target module's port order. +// +// ``` +// child u($0a, b); +// ``` +// -> +// ``` +// child u(.a(a), .b(b)); +// ``` pub(super) fn convert_ordered_ports( collector: &mut CodeActionCollector, ctx: &CodeActionCtx, @@ -37,21 +53,25 @@ pub(super) fn convert_ordered_ports( let db = sema.db; let text = db.file_text(ctx.file_id()); let ast_instance = ctx.find_node_at_offset::()?; - let instantiation = ast::HierarchyInstantiation::cast(ast_instance.syntax().parent()?)?; - let target_module_id = - resolve_instantiation_target(db, ctx.file_id(), instantiation).unique()?; - let target_module = db.module(target_module_id); - let port_names = port_names(&target_module); + let InModule { value: instance_id, module_id } = + sema.resolve_instance(ctx.file_id().into(), ast_instance)?; + let (module, module_src_map) = db.module_with_source_map(module_id); + let instantiation = module.get(module.get(instance_id).parent); + let target_module_id = resolve_hir_instantiation_target(db, ctx.file_id(), instantiation)?; + let port_names = port_names(&db.module(target_module_id)); - let replacements = ast_instance - .connections() - .children() + let replacements = module + .get(instance_id) + .connections + .iter() .enumerate() - .filter_map(|(idx, conn)| { - let ordered = conn.as_ordered_port_connection()?; + .filter_map(|(idx, conn_id)| { + let PortConn::Ordered(expr_id) = module.get(*conn_id) else { + return None; + }; let name = port_names.get(idx)?; - let expr = ordered.expr().syntax().text_range()?; - let range = ordered.syntax().text_range()?; + let expr = module_src_map.get(*expr_id)?.range(); + let range = module_src_map.get(*conn_id)?.range(); Some((range, format!(".{name}({})", text.get(Range::from(expr))?))) }) .collect_vec(); @@ -69,6 +89,18 @@ pub(super) fn convert_ordered_ports( Some(()) } +// Assist: convert_ordered_params +// +// This converts ordered parameter assignments to named parameter assignments +// using the target module's parameter order. +// +// ``` +// child #($01, 2) u(); +// ``` +// -> +// ``` +// child #(.A(1), .B(2)) u(); +// ``` pub(super) fn convert_ordered_params( collector: &mut CodeActionCollector, ctx: &CodeActionCtx, @@ -77,21 +109,25 @@ pub(super) fn convert_ordered_params( let db = sema.db; let text = db.file_text(ctx.file_id()); let ast_instantiation = ctx.find_node_at_offset::()?; - let target_module_id = - resolve_instantiation_target(db, ctx.file_id(), ast_instantiation).unique()?; + let InModule { value: instantiation_id, module_id } = + sema.resolve_instantiation(ctx.file_id().into(), ast_instantiation)?; + let (module, module_src_map) = db.module_with_source_map(module_id); + let instantiation = module.get(instantiation_id); + let target_module_id = resolve_hir_instantiation_target(db, ctx.file_id(), instantiation)?; let target_module = db.module(target_module_id); let param_names = leading_parameter_names(&target_module); - let replacements = ast_instantiation - .parameters()? - .parameters() - .children() + let replacements = instantiation + .param_assigns + .iter() .enumerate() - .filter_map(|(idx, assign)| { - let ordered = assign.as_ordered_param_assignment()?; + .filter_map(|(idx, assign_id)| { + let ParamAssign::Ordered(expr_id) = module.get(*assign_id) else { + return None; + }; let name = param_names.get(idx)?; - let expr = ordered.expr().syntax().text_range()?; - let range = ordered.syntax().text_range()?; + let expr = module_src_map.get(*expr_id)?.range(); + let range = module_src_map.get(*assign_id)?.range(); Some((range, format!(".{name}({})", text.get(Range::from(expr))?))) }) .collect_vec(); diff --git a/crates/ide/src/code_action/handlers/convert_port_declarations.rs b/crates/ide/src/code_action/handlers/convert_port_declarations.rs new file mode 100644 index 00000000..a159330d --- /dev/null +++ b/crates/ide/src/code_action/handlers/convert_port_declarations.rs @@ -0,0 +1,409 @@ +use std::ops::Range; + +use hir::{ + base_db::source_db::SourceDb, + container::{InContainer, InModule}, + db::HirDb, + display::HirDisplay, + hir_def::{ + Ident, + declaration::DeclarationSrc, + expr::declarator::{DeclId, DeclaratorParent}, + module::{ + Module, ModuleId, ModuleSourceMap, + port::{PortDecl, PortDeclSrc, Ports}, + }, + }, + scope::{ModuleEntry, ModuleScope, NonAnsiPortEntry}, + source_map::IsSrc, +}; +use itertools::Itertools; +use syntax::{ + ast::{self, AstNode}, + has_text_range::{HasTextRange, HasTextRangeIn}, +}; +use utils::{ + get::{Get, GetRef}, + text_edit::TextRange, +}; + +use crate::code_action::{ + CodeActionCollector, CodeActionCtx, CodeActionId, CodeActionKind, line_indent, +}; + +const ANSI_TO_NON_ANSI_ID: CodeActionId = CodeActionId { + name: "convert_ansi_ports_to_non_ansi", + kind: CodeActionKind::RefactorRewrite, + repair: None, +}; +const ANSI_TO_NON_ANSI_LABEL: &str = "Convert ANSI port declarations to non-ANSI"; + +const NON_ANSI_TO_ANSI_ID: CodeActionId = CodeActionId { + name: "convert_non_ansi_ports_to_ansi", + kind: CodeActionKind::RefactorRewrite, + repair: None, +}; +const NON_ANSI_TO_ANSI_LABEL: &str = "Convert non-ANSI port declarations to ANSI"; + +// Assist: convert_port_declarations +// +// This converts module ports between ANSI declarations and non-ANSI +// declarations. +// +// ``` +// module top($0input a, output logic b); endmodule +// ``` +// -> +// ``` +// module top(a, b); input a; output logic b; endmodule +// ``` +pub(super) fn convert_port_declarations( + collector: &mut CodeActionCollector, + ctx: &CodeActionCtx, +) -> Option<()> { + convert_ansi_ports_to_non_ansi(collector, ctx) + .or(convert_non_ansi_ports_to_ansi(collector, ctx)) +} + +fn convert_ansi_ports_to_non_ansi( + collector: &mut CodeActionCollector, + ctx: &CodeActionCtx, +) -> Option<()> { + let ast_module = ctx.find_node_at_offset::()?; + let port_list = ast_module.header().ports()?.as_ansi_port_list()?; + + let module_id = ctx.sema().module_to_def(ctx.file_id().into(), ast_module)?; + let (module, module_src_map) = ctx.sema().db.module_with_source_map(module_id); + let Ports::Ansi(port_decls) = &module.ports else { + return None; + }; + + let mut port_names = Vec::with_capacity(port_decls.len()); + let mut port_items = Vec::with_capacity(port_decls.len()); + for (port_id, port_decl) in port_decls.iter() { + let src = module_src_map.port_srcs.get(port_id)?; + let PortDeclSrc::ImplicitAnsiPort(_) = src else { + return None; + }; + + let name = port_decl_declared_name(&module, port_decl)?; + port_names.push(name); + port_items.push((port_decl, src)); + } + + if port_names.is_empty() { + return None; + } + + let open_paren = port_list.open_paren()?.text_range_in(port_list.syntax())?; + let close_paren = port_list.close_paren()?.text_range_in(port_list.syntax())?; + if !port_list_trigger_range(open_paren, close_paren)?.contains_range(ctx.range()) { + return None; + } + + let body_range = module_body_range(ast_module)?; + let text = ctx.sema().db.file_text(ctx.file_id()); + let generated_members = port_items + .iter() + .map(|(port_decl, src)| { + render_ansi_port_declaration(ctx, module_id, port_decl, *src, &text) + }) + .collect::>>()?; + let port_list_replacement = render_port_list(&text, open_paren, close_paren, &port_names)?; + let body_replacement = + render_module_body(&text, ast_module, body_range, &generated_members, &[])?; + let target = port_list.syntax().text_range()?; + + collector.add(ANSI_TO_NON_ANSI_ID, ANSI_TO_NON_ANSI_LABEL, target, |builder| { + builder.replace(target, port_list_replacement); + builder.replace(body_range, body_replacement); + }) +} + +fn convert_non_ansi_ports_to_ansi( + collector: &mut CodeActionCollector, + ctx: &CodeActionCtx, +) -> Option<()> { + let ast_module = ctx.find_node_at_offset::()?; + let port_list = ast_module.header().ports()?.as_non_ansi_port_list()?; + + let module_id = ctx.sema().module_to_def(ctx.file_id().into(), ast_module)?; + let (module, module_src_map) = ctx.sema().db.module_with_source_map(module_id); + let Ports::NonAnsi { ports, refs, .. } = &module.ports else { + return None; + }; + + let mut port_names = Vec::new(); + for (_, port) in ports.iter() { + let mut ref_ids = port.refs.clone()?; + let ref_id = ref_ids.next()?; + if ref_ids.next().is_some() { + return None; + } + + let port_ref = &refs[ref_id]; + if port_ref.select.is_some() { + return None; + } + + let ident = port_ref.ident.as_ref()?; + if port.label.as_ref() != Some(ident) { + return None; + } + port_names.push(ident.clone()); + } + if port_names.is_empty() { + return None; + } + + let open_paren = port_list.open_paren()?.text_range_in(port_list.syntax())?; + let close_paren = port_list.close_paren()?.text_range_in(port_list.syntax())?; + if !port_list_trigger_range(open_paren, close_paren)?.contains_range(ctx.range()) { + return None; + } + + let body_range = module_body_range(ast_module)?; + let text = ctx.sema().db.file_text(ctx.file_id()); + let module_scope = ctx.sema().db.module_scope(module_id); + let port_replacements = port_names + .iter() + .map(|name| { + non_ansi_port_replacement(ctx, &module, &module_src_map, &module_scope, name, &text) + }) + .collect::>>()?; + let ansi_items = port_replacements + .iter() + .map(|replacement| replacement.ansi_item.clone()) + .collect::>(); + let removed_ranges = port_replacements + .into_iter() + .flat_map(|replacement| replacement.remove_ranges) + .collect::>(); + let port_list_replacement = render_port_list(&text, open_paren, close_paren, &ansi_items)?; + let body_replacement = render_module_body(&text, ast_module, body_range, &[], &removed_ranges)?; + let target = port_list.syntax().text_range()?; + + collector.add(NON_ANSI_TO_ANSI_ID, NON_ANSI_TO_ANSI_LABEL, target, |builder| { + builder.replace(target, port_list_replacement); + builder.replace(body_range, body_replacement); + }) +} + +fn port_list_trigger_range(open: TextRange, close: TextRange) -> Option { + (open.end() <= close.start()).then(|| TextRange::new(open.end(), close.start())) +} + +fn port_decl_declared_name(module: &Module, port_decl: &PortDecl) -> Option { + let decl_id = single_port_decl_id(port_decl)?; + Some(module.get(decl_id).name.as_ref()?.to_string()) +} + +fn single_port_decl_id(port_decl: &PortDecl) -> Option { + let mut decls = port_decl.decls.clone(); + let decl_id = decls.next()?; + if decls.next().is_some() { + return None; + } + Some(decl_id) +} + +struct NonAnsiPortReplacement { + ansi_item: String, + remove_ranges: Vec, +} + +fn non_ansi_port_replacement( + ctx: &CodeActionCtx, + module: &Module, + module_src_map: &ModuleSourceMap, + module_scope: &ModuleScope, + name: &Ident, + text: &str, +) -> Option { + let ModuleEntry::NonAnsiPortEntry(NonAnsiPortEntry { + port_decl: Some(port_decl), + data_decl, + .. + }) = module_scope.get(name)? + else { + return None; + }; + let DeclaratorParent::PortDeclId(port_decl_id) = module.get(port_decl).parent else { + return None; + }; + let port_decl = module.get(port_decl_id); + if port_decl_declared_name(module, port_decl).as_deref() != Some(name.as_str()) { + return None; + } + + let port_src = module_src_map.port_srcs.get(port_decl_id)?; + let PortDeclSrc::PortDeclaration(_) = port_src else { + return None; + }; + let port_range = port_src.range(); + + if let Some(data_decl) = data_decl { + let data_range = data_decl_range_for_name(module, module_src_map, data_decl, name)?; + let direction = port_decl.header.dir().display_source(ctx.sema().db).ok()?; + let data_decl = declaration_text_without_semicolon(text, data_range)?; + return Some(NonAnsiPortReplacement { + ansi_item: format!("{direction} {data_decl}"), + remove_ranges: vec![port_range, data_range], + }); + } + + Some(NonAnsiPortReplacement { + ansi_item: declaration_text_without_semicolon(text, port_range)?, + remove_ranges: vec![port_range], + }) +} + +fn data_decl_range_for_name( + module: &Module, + module_src_map: &ModuleSourceMap, + decl_id: DeclId, + name: &Ident, +) -> Option { + let decl = module.get(decl_id); + if decl.name.as_ref() != Some(name) { + return None; + } + + let DeclaratorParent::DeclarationId(declaration_id) = decl.parent else { + return None; + }; + let declaration = module.get(declaration_id); + let mut decls = declaration.decls(); + let single_decl_id = decls.next()?; + if single_decl_id != decl_id || decls.next().is_some() { + return None; + } + + let src = module_src_map.declaration_srcs.get(declaration_id)?; + match src { + DeclarationSrc::DataDeclaration(_) | DeclarationSrc::NetDeclaration(_) => Some(src.range()), + _ => None, + } +} + +fn render_ansi_port_declaration( + ctx: &CodeActionCtx, + module_id: ModuleId, + port_decl: &PortDecl, + src: PortDeclSrc, + text: &str, +) -> Option { + let source = text.get(Range::from(src.range()))?; + if source + .split_ascii_whitespace() + .next() + .is_some_and(|word| matches!(word, "input" | "output" | "inout" | "ref")) + { + return Some(format!("{source};")); + } + + let decl_id = single_port_decl_id(port_decl)?; + let header = InModule::new(module_id, port_decl.header).display_source(ctx.sema().db).ok()?; + let decl = InContainer::new(module_id.into(), decl_id).display_signature(ctx.sema().db).ok()?; + + if header.is_empty() { Some(format!("{decl};")) } else { Some(format!("{header} {decl};")) } +} + +fn declaration_text_without_semicolon(text: &str, range: TextRange) -> Option { + Some(text.get(Range::from(range))?.strip_suffix(';')?.to_owned()) +} + +fn module_body_range(module: ast::ModuleDeclaration<'_>) -> Option { + let header = module.header(); + Some(TextRange::new( + header.semi()?.text_range_in(header.syntax())?.end(), + module.endmodule()?.text_range_in(module.syntax())?.start(), + )) +} + +fn render_port_list( + text: &str, + open: TextRange, + close: TextRange, + items: &[String], +) -> Option { + let content = text.get(usize::from(open.end())..usize::from(close.start()))?; + if content.contains('\n') { + let close_indent = line_indent(text, close.start()); + let item_indent = format!("{close_indent} "); + let rendered = items + .iter() + .enumerate() + .map(|(idx, item)| { + let suffix = if idx + 1 == items.len() { "" } else { "," }; + format!("{item_indent}{item}{suffix}") + }) + .collect::>() + .join("\n"); + Some(format!("(\n{rendered}\n{close_indent})")) + } else { + Some(format!("({})", items.join(", "))) + } +} + +fn render_module_body( + text: &str, + module: ast::ModuleDeclaration<'_>, + body_range: TextRange, + prefix_items: &[String], + remove_ranges: &[TextRange], +) -> Option { + let mut items = prefix_items.to_vec(); + let mut body = text.get(Range::from(body_range))?.to_owned(); + remove_ranges_from_body(&mut body, body_range, remove_ranges)?; + let body = body.trim(); + if !body.is_empty() { + items.push(body.to_owned()); + } + + let endmodule = module.endmodule()?.text_range_in(module.syntax())?; + let module_indent = line_indent(text, endmodule.start()); + if items.is_empty() { + return Some(format!("\n{module_indent}")); + } + + let item_indent = format!("{module_indent} "); + let rendered = items + .into_iter() + .map(|item| indent_block(&item, &item_indent)) + .collect::>() + .join("\n"); + Some(format!("\n{rendered}\n{module_indent}")) +} + +fn remove_ranges_from_body( + body: &mut String, + body_range: TextRange, + remove_ranges: &[TextRange], +) -> Option<()> { + let body_start = usize::from(body_range.start()); + let body_end = usize::from(body_range.end()); + let mut ranges = remove_ranges + .iter() + .filter(|range| body_range.contains_range(**range)) + .map(|range| { + Some(( + usize::from(range.start()).checked_sub(body_start)?, + usize::from(range.end()).checked_sub(body_start)?, + )) + }) + .collect::>>()?; + + ranges.sort_by_key(|(start, _)| *start); + for (start, end) in ranges.into_iter().rev() { + if start > end || body_start + end > body_end { + return None; + } + body.replace_range(start..end, ""); + } + Some(()) +} + +fn indent_block(text: &str, indent: &str) -> String { + text.lines().map(|line| format!("{indent}{line}")).join("\n") +} diff --git a/crates/ide/src/code_action/handlers/expand_compound_assignment.rs b/crates/ide/src/code_action/handlers/expand_compound_assignment.rs index e2275ee8..61e2feae 100644 --- a/crates/ide/src/code_action/handlers/expand_compound_assignment.rs +++ b/crates/ide/src/code_action/handlers/expand_compound_assignment.rs @@ -21,6 +21,18 @@ const COLLAPSE_ID: CodeActionId = CodeActionId { }; const COLLAPSE_LABEL: &str = "Collapse compound assignment"; +// Assist: expand_compound_assignment +// +// This expands compound assignments, or collapses simple self-assignments into +// compound assignments. +// +// ``` +// always_comb a $0+= b; +// ``` +// -> +// ``` +// always_comb a = a + b; +// ``` pub(super) fn expand_compound_assignment( collector: &mut CodeActionCollector, ctx: &CodeActionCtx, diff --git a/crates/ide/src/code_action/handlers/expand_postfix_inc_dec.rs b/crates/ide/src/code_action/handlers/expand_postfix_inc_dec.rs index 46619c77..307ee17a 100644 --- a/crates/ide/src/code_action/handlers/expand_postfix_inc_dec.rs +++ b/crates/ide/src/code_action/handlers/expand_postfix_inc_dec.rs @@ -78,6 +78,18 @@ const ASSIGNMENT_TO_PREFIX_ID: CodeActionId = CodeActionId { }; const ASSIGNMENT_TO_PREFIX_LABEL: &str = "Convert assignment to prefix expression"; +// Assist: expand_postfix_inc_dec +// +// This converts between postfix, prefix, compound assignment, and expanded +// assignment forms of increment/decrement expressions. +// +// ``` +// always_ff @(posedge clk) count$0++; +// ``` +// -> +// ``` +// always_ff @(posedge clk) count = count + 1; +// ``` pub(super) fn expand_postfix_inc_dec( collector: &mut CodeActionCollector, ctx: &CodeActionCtx, diff --git a/crates/ide/src/code_action/handlers/extract_variable.rs b/crates/ide/src/code_action/handlers/extract_variable.rs new file mode 100644 index 00000000..1b260627 --- /dev/null +++ b/crates/ide/src/code_action/handlers/extract_variable.rs @@ -0,0 +1,232 @@ +use std::ops::Range; + +use hir::{ + base_db::source_db::SourceDb, + container::InContainer, + display::HirDisplay, + type_infer::{BuiltinTy, Ty, type_of_expr, type_of_path_resolution}, +}; +use syntax::{ + SyntaxAncestors, SyntaxKind, TokenKind, WalkEvent, + ast::{self, AstNode}, + has_text_range::HasTextRange, +}; +use utils::{ + get::GetRef, + text_edit::{TextRange, TextSize}, +}; + +use crate::code_action::{ + CodeActionCollector, CodeActionCtx, CodeActionId, CodeActionKind, line_indent, +}; + +const ID: CodeActionId = + CodeActionId { name: "extract_variable", kind: CodeActionKind::RefactorExtract, repair: None }; + +// Assist: extract_variable +// +// This extracts a selected expression into a new local variable or continuous +// net declaration. +// +// ``` +// always_comb begin y = $0a + b$0; end +// ``` +// -> +// ``` +// always_comb begin logic value = a + b; +// y = value; end +// ``` +pub(super) fn extract_variable( + collector: &mut CodeActionCollector, + ctx: &CodeActionCtx, +) -> Option<()> { + let text = ctx.sema().db.file_text(ctx.file_id()); + let expr = selected_expression(ctx, &text)?; + let expr_range = expr.syntax().text_range()?; + let target = extract_target(&text, expr)?; + let expr_text = text.get(Range::from(expr_range))?.trim().to_owned(); + let name = fresh_variable_name(&text, "value"); + + collector.add(ID, "Extract into variable", expr_range, |builder| { + let ty_text = extracted_variable_type(ctx, expr).unwrap_or_else(|| "logic".to_owned()); + let declaration = target.declaration(&ty_text, &name, &expr_text); + builder.insert(target.insert_offset, declaration); + builder.replace(expr_range, name); + }) +} + +struct ExtractTarget { + insert_offset: TextSize, + indent: String, + declaration_style: DeclarationStyle, +} + +impl ExtractTarget { + fn declaration(&self, ty_text: &str, name: &str, expr_text: &str) -> String { + match self.declaration_style { + DeclarationStyle::Local => { + format!("{}{ty_text} {name} = {expr_text};\n", self.indent) + } + DeclarationStyle::ContinuousNet => { + format!("{}wire {ty_text} {name} = {expr_text};\n", self.indent) + } + } + } +} + +enum DeclarationStyle { + Local, + ContinuousNet, +} + +fn extract_target(text: &str, expr: ast::Expression<'_>) -> Option { + if let Some(stmt) = + SyntaxAncestors::start_from(expr.syntax()).find_map(ast::ExpressionStatement::cast) + && stmt.syntax().parent().and_then(ast::BlockStatement::cast).is_some() + { + let stmt_range = stmt.syntax().text_range()?; + return Some(ExtractTarget { + insert_offset: stmt_range.start(), + indent: line_indent(text, stmt_range.start()), + declaration_style: DeclarationStyle::Local, + }); + } + + let assign = + SyntaxAncestors::start_from(expr.syntax()).find_map(ast::ContinuousAssign::cast)?; + expression_is_assignment_rhs(expr)?; + let assign_range = assign.syntax().text_range()?; + Some(ExtractTarget { + insert_offset: assign_range.start(), + indent: line_indent(text, assign_range.start()), + declaration_style: DeclarationStyle::ContinuousNet, + }) +} + +fn expression_is_assignment_rhs(expr: ast::Expression<'_>) -> Option<()> { + assignment_expression_containing_rhs(expr) + .filter(|binary| { + binary.operator_token().is_some_and(|token| token.kind() == TokenKind::EQUALS) + }) + .map(|_| ()) +} + +fn assignment_expression_containing_rhs( + expr: ast::Expression<'_>, +) -> Option> { + let expr_range = expr.syntax().text_range()?; + SyntaxAncestors::start_from(expr.syntax()).filter_map(ast::BinaryExpression::cast).find( + |binary| { + is_assignment_expression(binary.syntax().kind()) + && binary + .right() + .syntax() + .text_range() + .is_some_and(|range| range.contains_range(expr_range)) + }, + ) +} + +fn is_assignment_expression(kind: SyntaxKind) -> bool { + matches!( + kind, + SyntaxKind::ASSIGNMENT_EXPRESSION + | SyntaxKind::NONBLOCKING_ASSIGNMENT_EXPRESSION + | SyntaxKind::ADD_ASSIGNMENT_EXPRESSION + | SyntaxKind::SUBTRACT_ASSIGNMENT_EXPRESSION + | SyntaxKind::MULTIPLY_ASSIGNMENT_EXPRESSION + | SyntaxKind::DIVIDE_ASSIGNMENT_EXPRESSION + | SyntaxKind::MOD_ASSIGNMENT_EXPRESSION + | SyntaxKind::AND_ASSIGNMENT_EXPRESSION + | SyntaxKind::OR_ASSIGNMENT_EXPRESSION + | SyntaxKind::XOR_ASSIGNMENT_EXPRESSION + | SyntaxKind::LOGICAL_LEFT_SHIFT_ASSIGNMENT_EXPRESSION + | SyntaxKind::LOGICAL_RIGHT_SHIFT_ASSIGNMENT_EXPRESSION + | SyntaxKind::ARITHMETIC_LEFT_SHIFT_ASSIGNMENT_EXPRESSION + | SyntaxKind::ARITHMETIC_RIGHT_SHIFT_ASSIGNMENT_EXPRESSION + ) +} + +fn selected_expression<'a>(ctx: &'a CodeActionCtx<'_>, text: &str) -> Option> { + let range = trim_range(text, ctx.range())?; + if range.is_empty() { + return None; + } + + ctx.syntax().node_preorder().find_map(|event| match event { + WalkEvent::Enter(node) => { + let expr = ast::Expression::cast(node)?; + (expr.syntax().text_range()? == range).then_some(expr) + } + WalkEvent::Leave(_) => None, + }) +} + +fn trim_range(text: &str, range: TextRange) -> Option { + let selected = text.get(Range::::from(range))?; + let trimmed_start = selected.trim_start(); + let trimmed = trimmed_start.trim_end(); + + let leading = selected.len() - trimmed_start.len(); + let trailing = trimmed_start.len() - trimmed.len(); + Some(TextRange::new( + range.start() + TextSize::from(leading as u32), + range.end() - TextSize::from(trailing as u32), + )) +} + +fn extracted_variable_type(ctx: &CodeActionCtx<'_>, expr: ast::Expression<'_>) -> Option { + let ty = type_of_expr(ctx.sema().db, ctx.sema().resolve_expr(ctx.file_id().into(), expr)?).ty; + render_ty(ctx, &ty) + .or_else(|| expected_type_for_assignment_rhs(ctx, expr).and_then(|ty| render_ty(ctx, &ty))) +} + +fn expected_type_for_assignment_rhs( + ctx: &CodeActionCtx<'_>, + expr: ast::Expression<'_>, +) -> Option { + let assignment = assignment_expression_containing_rhs(expr)?; + let res = ctx + .sema() + .expr_to_def(ctx.sema().resolve_expr(ctx.file_id().into(), assignment.left())?)?; + Some(type_of_path_resolution(ctx.sema().db, res).ty) +} + +fn render_ty(ctx: &CodeActionCtx<'_>, ty: &Ty) -> Option { + match ty { + Ty::Builtin(BuiltinTy::Data { id, container }) => { + InContainer::new(*container, hir::hir_def::expr::data_ty::DataTy::Builtin(*id)) + .display_source(ctx.sema().db) + .ok() + } + Ty::Alias { typedef, .. } => { + let container = typedef.cont_id.to_container(ctx.sema().db); + container.get(typedef.value).name.as_ref().map(ToString::to_string) + } + Ty::Struct(struct_ref) => { + let container = struct_ref.cont_id.to_container(ctx.sema().db); + container.get(struct_ref.value).name.as_ref().map(ToString::to_string) + } + Ty::Unknown + | Ty::Error + | Ty::Void + | Ty::Module(_) + | Ty::GenerateBlock(_) + | Ty::Block(_) => None, + } +} + +fn fresh_variable_name(text: &str, base: &str) -> String { + if !text.contains(base) { + return base.to_owned(); + } + + let mut idx = 1usize; + loop { + let candidate = format!("{base}_{idx}"); + if !text.contains(&candidate) { + return candidate; + } + idx += 1; + } +} diff --git a/crates/ide/src/code_action/handlers/insert_expected_token.rs b/crates/ide/src/code_action/handlers/insert_expected_token.rs index 3d213b55..64be8952 100644 --- a/crates/ide/src/code_action/handlers/insert_expected_token.rs +++ b/crates/ide/src/code_action/handlers/insert_expected_token.rs @@ -11,6 +11,17 @@ const ID: CodeActionId = CodeActionId { repair: Some(RepairKind::InsertExpectedToken), }; +// Assist: insert_expected_token +// +// This inserts a token that the parser expected at the diagnostic location. +// +// ``` +// module top$0 endmodule +// ``` +// -> +// ``` +// module top; endmodule +// ``` pub(super) fn insert_expected_token( collector: &mut CodeActionCollector, ctx: &CodeActionCtx, diff --git a/crates/ide/src/code_action/handlers/invert_if_else.rs b/crates/ide/src/code_action/handlers/invert_if_else.rs index 0dfb2873..de3876ea 100644 --- a/crates/ide/src/code_action/handlers/invert_if_else.rs +++ b/crates/ide/src/code_action/handlers/invert_if_else.rs @@ -12,6 +12,18 @@ const ID: CodeActionId = CodeActionId { name: "invert_if_else", kind: CodeActionKind::RefactorRewrite, repair: None }; const LABEL: &str = "Invert if/else"; +// Assist: invert_if_else +// +// This swaps the then and else branches of an if statement and negates the +// condition. +// +// ``` +// always_comb if$0 (ready) y = a; else y = b; +// ``` +// -> +// ``` +// always_comb if (!(ready)) y = b; else y = a; +// ``` pub(super) fn invert_if_else( collector: &mut CodeActionCollector, ctx: &CodeActionCtx, diff --git a/crates/ide/src/code_action/handlers/merge_nested_if.rs b/crates/ide/src/code_action/handlers/merge_nested_if.rs new file mode 100644 index 00000000..061b8787 --- /dev/null +++ b/crates/ide/src/code_action/handlers/merge_nested_if.rs @@ -0,0 +1,138 @@ +use std::{borrow::Cow, ops::Range}; + +use hir::base_db::source_db::SourceDb; +use syntax::{ + ast::{self, AstNode}, + has_text_range::HasTextRange, +}; +use utils::text_edit::TextRange; + +use crate::code_action::{CodeActionCollector, CodeActionCtx, CodeActionId, CodeActionKind}; + +const ID: CodeActionId = + CodeActionId { name: "merge_nested_if", kind: CodeActionKind::RefactorRewrite, repair: None }; + +// Assist: merge_nested_if +// +// This merges nested if statements without else branches into one if statement +// with a combined condition. +// +// ``` +// always_comb if$0 (a) begin if (b) y = 1; end +// ``` +// -> +// ``` +// always_comb if (a && b) y = 1; +// ``` +pub(super) fn merge_nested_if( + collector: &mut CodeActionCollector, + ctx: &CodeActionCtx, +) -> Option<()> { + let current_if = ctx.find_node_at_offset::()?; + if !in_if_head(current_if, ctx.range()) || current_if.else_clause().is_some() { + return None; + } + + let outer_if = outermost_mergeable_if(current_if); + let chain = nested_if_chain(outer_if); + if chain.len() < 2 { + return None; + } + + let innermost_if = *chain.last()?; + let innermost_body_stmt = single_statement_body(innermost_if.statement())?; + + let text = ctx.sema().db.file_text(ctx.file_id()); + let predicates = chain + .iter() + .map(|if_stmt| { + let range = if_stmt.predicate().syntax().text_range()?; + let predicate = text.get(Range::from(range))?.trim(); + if predicate.contains("||") || predicate.contains('?') { + Some(Cow::Owned(format!("({predicate})"))) + } else { + Some(Cow::Borrowed(predicate)) + } + }) + .collect::>>()?; + + let outer_pred_range = outer_if.predicate().syntax().text_range()?; + let outer_body_range = outer_if.statement().syntax().text_range()?; + + let innermost_body_range = innermost_body_stmt.syntax().text_range()?; + let innermost_body = text.get(Range::from(innermost_body_range))?.trim().to_owned(); + + collector.add(ID, "Merge nested if", outer_if.syntax().text_range()?, |builder| { + let merged_predicate = predicates.join(" && "); + builder.replace(outer_pred_range, merged_predicate); + builder.replace(outer_body_range, innermost_body); + }) +} + +fn in_if_head(if_stmt: ast::ConditionalStatement<'_>, range: TextRange) -> bool { + let Some(if_range) = if_stmt.syntax().text_range() else { return false }; + let Some(pred_range) = if_stmt.predicate().syntax().text_range() else { return false }; + TextRange::new(if_range.start(), pred_range.end()).contains_range(range) +} + +fn outermost_mergeable_if<'a>( + mut if_stmt: ast::ConditionalStatement<'a>, +) -> ast::ConditionalStatement<'a> { + while let Some(parent_if) = parent_conditional_statement(if_stmt) { + if parent_if.else_clause().is_some() { + break; + } + let Some(body) = single_statement_body(parent_if.statement()) else { break }; + let Some(body_stmt) = body.as_conditional_statement() else { break }; + if body_stmt.syntax() != if_stmt.syntax() { + break; + } + if_stmt = parent_if; + } + + if_stmt +} + +fn parent_conditional_statement<'a>( + if_stmt: ast::ConditionalStatement<'a>, +) -> Option> { + let mut parent = if_stmt.syntax().parent(); + while let Some(node) = parent { + if let Some(parent_if) = ast::ConditionalStatement::cast(node) { + return Some(parent_if); + } + parent = node.parent(); + } + None +} + +fn nested_if_chain<'a>( + outer_if: ast::ConditionalStatement<'a>, +) -> Vec> { + let mut chain = vec![outer_if]; + let mut current_if = outer_if; + while let Some(body) = single_statement_body(current_if.statement()) { + let Some(nested_if) = body.as_conditional_statement() else { + break; + }; + if nested_if.else_clause().is_some() { + break; + } + chain.push(nested_if); + current_if = nested_if; + } + chain +} + +fn single_statement_body(stmt: ast::Statement<'_>) -> Option> { + let Some(block) = stmt.as_block_statement() else { + return Some(stmt); + }; + + let mut items = block.items().children(); + let item = items.next()?; + if items.next().is_some() { + return None; + } + ast::Statement::cast(item.syntax()) +} diff --git a/crates/ide/src/code_action/handlers/pull_assignment_up.rs b/crates/ide/src/code_action/handlers/pull_assignment_up.rs new file mode 100644 index 00000000..feb203a5 --- /dev/null +++ b/crates/ide/src/code_action/handlers/pull_assignment_up.rs @@ -0,0 +1,152 @@ +use std::{borrow::Cow, ops::Range}; + +use hir::base_db::source_db::SourceDb; +use syntax::{ + TokenKind, + ast::{self, AstNode}, + has_text_range::HasTextRange, +}; + +use crate::code_action::{CodeActionCollector, CodeActionCtx, CodeActionId, CodeActionKind}; + +const ID: CodeActionId = CodeActionId { + name: "pull_assignment_up", + kind: CodeActionKind::RefactorRewrite, + repair: None, +}; +const DOWN_ID: CodeActionId = CodeActionId { + name: "pull_assignment_down", + kind: CodeActionKind::RefactorRewrite, + repair: None, +}; + +// Assist: pull_assignment_up +// +// This pulls matching assignments out of an if/else chain into a single ternary +// assignment. +// +// ``` +// always_comb if$0 (a) y = 1; else y = 0; +// ``` +// -> +// ``` +// always_comb y = a ? 1 : 0; +// ``` +pub(super) fn pull_assignment_up( + collector: &mut CodeActionCollector, + ctx: &CodeActionCtx, +) -> Option<()> { + let mut conditional = ctx.find_node_at_offset::()?; + while let Some(parent_if) = conditional + .syntax() + .parent() + .and_then(|node| ast::ElseClause::cast(node)?.syntax().parent()) + .and_then(ast::ConditionalStatement::cast) + { + conditional = parent_if; + } + + let text = ctx.sema().db.file_text(ctx.file_id()); + let (lhs, expr) = conditional_assignment_expression(conditional, &text)?; + + collector.add(ID, "Pull assignment up", conditional.syntax().text_range()?, |builder| { + let replacement = format!("{} = {};", lhs.trim(), expr); + builder.replace(conditional.syntax().text_range().unwrap(), replacement); + }) +} + +// Assist: pull_assignment_down +// +// This expands a ternary assignment into an if/else assignment chain. +// +// ``` +// always_comb $0y = a ? 1 : 0; +// ``` +// -> +// ``` +// always_comb if (a) y = 1; else y = 0; +// ``` +pub(super) fn pull_assignment_down( + collector: &mut CodeActionCollector, + ctx: &CodeActionCtx, +) -> Option<()> { + let assignment = ctx.find_node_at_offset::()?; + if assignment.operator_token()?.kind() != TokenKind::EQUALS { + return None; + } + + let conditional = assignment.right().as_conditional_expression()?; + let stmt = syntax::SyntaxAncestors::start_from(assignment.syntax()) + .find_map(ast::ExpressionStatement::cast)?; + let text = ctx.sema().db.file_text(ctx.file_id()); + let lhs = text.get(Range::from(assignment.left().syntax().text_range()?))?.trim(); + let replacement = conditional_assignment_statement(conditional, lhs, &text)?; + + collector.add(DOWN_ID, "Pull assignment down", stmt.syntax().text_range()?, |builder| { + builder.replace(stmt.syntax().text_range().unwrap(), replacement); + }) +} + +fn conditional_assignment_expression<'a>( + conditional: ast::ConditionalStatement<'_>, + text: &'a str, +) -> Option<(&'a str, String)> { + let (lhs, then_rhs) = assignment_rhs_text(conditional.statement(), text)?; + + let else_syntax = conditional.else_clause()?.clause().syntax(); + let (else_lhs, else_expr) = if let Some(nested) = ast::ConditionalStatement::cast(else_syntax) { + conditional_assignment_expression(nested, text)? + } else { + let else_stmt = ast::Statement::cast(else_syntax)?; + let (lhs, expr) = assignment_rhs_text(else_stmt, text)?; + (lhs, expr.to_owned()) + }; + + if else_lhs != lhs { + return None; + } + + let predicate: Cow<'a, str> = { + let predicate = + text.get(Range::from(conditional.predicate().syntax().text_range()?))?.trim(); + + if predicate.contains('?') { format!("({predicate})").into() } else { predicate.into() } + }; + Some((lhs, format!("{predicate} ? {then_rhs} : {else_expr}"))) +} + +fn assignment_rhs_text<'a>(stmt: ast::Statement<'_>, text: &'a str) -> Option<(&'a str, &'a str)> { + if let Some(block) = stmt.as_block_statement() { + let item = block.items().only_children()?; + let stmt = ast::Statement::cast(item.syntax())?; + return assignment_rhs_text(stmt, text); + } + + let assignment = stmt.as_expression_statement()?.expr().as_binary_expression()?; + if assignment.operator_token()?.kind() != TokenKind::EQUALS { + return None; + } + + let lhs = text.get(Range::from(assignment.left().syntax().text_range()?))?.trim(); + let rhs = text.get(Range::from(assignment.right().syntax().text_range()?))?.trim(); + Some((lhs, rhs)) +} + +fn conditional_assignment_statement( + conditional: ast::ConditionalExpression<'_>, + lhs: &str, + text: &str, +) -> Option { + let predicate = text.get(Range::from(conditional.predicate().syntax().text_range()?))?.trim(); + let then_expr = expr_text(conditional.left(), text)?; + let else_expr = if let Some(nested) = conditional.right().as_conditional_expression() { + conditional_assignment_statement(nested, lhs, text)? + } else { + format!("{lhs} = {};", expr_text(conditional.right(), text)?) + }; + Some(format!("if ({predicate}) {lhs} = {then_expr}; else {else_expr}")) +} + +fn expr_text<'a>(expr: ast::Expression<'_>, text: &'a str) -> Option<&'a str> { + text.get(Range::from(expr.syntax().text_range()?)).map(str::trim) +} diff --git a/crates/ide/src/code_action/handlers/reformat_number_literal.rs b/crates/ide/src/code_action/handlers/reformat_number_literal.rs new file mode 100644 index 00000000..167477db --- /dev/null +++ b/crates/ide/src/code_action/handlers/reformat_number_literal.rs @@ -0,0 +1,108 @@ +use std::ops::Range; + +use hir::base_db::source_db::SourceDb; +use syntax::{ + ast::{self, AstNode}, + has_text_range::HasTextRange, +}; +use utils::text_edit::TextRange; + +use crate::code_action::{CodeActionCollector, CodeActionCtx, CodeActionId, CodeActionKind}; + +const ID: CodeActionId = CodeActionId { + name: "reformat_number_literal", + kind: CodeActionKind::RefactorInline, + repair: None, +}; +const MIN_NUMBER_OF_DIGITS_TO_FORMAT: usize = 5; + +// Assist: reformat_number_literal +// +// This adds digit separators to long integer literals or removes existing digit +// separators. +// +// ``` +// localparam int value = 10000$0; +// ``` +// -> +// ``` +// localparam int value = 10_000; +// ``` +pub(super) fn reformat_number_literal( + collector: &mut CodeActionCollector, + ctx: &CodeActionCtx, +) -> Option<()> { + let text = ctx.sema().db.file_text(ctx.file_id()); + let (raw, prefix, digits, group_size, range) = selected_integer_literal(ctx, &text)?; + + if digits.contains('_') { + let replacement = raw.replace('_', ""); + return collector.add(ID, "Remove digit separators", range, |builder| { + builder.replace(range, replacement); + }); + } + + if digits.chars().count() < MIN_NUMBER_OF_DIGITS_TO_FORMAT { + return None; + } + + let replacement = format!("{}{}", prefix, add_group_separators(digits, group_size)); + let label = format!("Convert {raw} to {replacement}"); + collector.add(ID, label, range, |builder| { + builder.replace(range, replacement); + }) +} + +fn selected_integer_literal<'a>( + ctx: &CodeActionCtx<'_>, + text: &'a str, +) -> Option<(&'a str, &'a str, &'a str, usize, TextRange)> { + if let Some(expr) = ctx.find_node_at_offset::() { + let range = expr.syntax().text_range()?; + let raw = text.get(Range::from(range))?; + return parse_based_literal(raw, range); + } + + let literal = ctx.find_node_at_offset::()?; + let ast::LiteralExpression::IntegerLiteralExpression(integer) = literal else { + return None; + }; + let range = integer.text_range()?; + let raw = text.get(Range::from(range))?; + Some((raw, "", raw, 3, range)) +} + +fn parse_based_literal( + raw: &str, + range: TextRange, +) -> Option<(&str, &str, &str, usize, TextRange)> { + let apostrophe = raw.find('\'')?; + let after_quote = raw.get(apostrophe + 1..)?; + let (sign_len, rest) = match after_quote.as_bytes().first().copied() { + Some(b's' | b'S') => (1usize, after_quote.get(1..)?), + _ => (0usize, after_quote), + }; + let base = rest.as_bytes().first().copied()?; + let group_size = match base.to_ascii_lowercase() { + b'b' => 4, + b'o' => 3, + b'd' => 3, + b'h' => 4, + _ => return None, + }; + let digits_start = apostrophe + 1 + sign_len + 1; + let digits = raw.get(digits_start..)?; + Some((raw, raw.get(..digits_start)?, digits, group_size, range)) +} + +fn add_group_separators(digits: &str, group_size: usize) -> String { + let clean: Vec = digits.chars().filter(|ch| *ch != '_').collect(); + let mut buf = String::with_capacity(clean.len() + clean.len() / group_size); + for (idx, ch) in clean.iter().rev().enumerate() { + if idx != 0 && idx % group_size == 0 { + buf.push('_'); + } + buf.push(*ch); + } + buf.chars().rev().collect() +} diff --git a/crates/ide/src/code_action/handlers/remove_empty_port_connections.rs b/crates/ide/src/code_action/handlers/remove_empty_port_connections.rs index 29a31812..82a2c5f0 100644 --- a/crates/ide/src/code_action/handlers/remove_empty_port_connections.rs +++ b/crates/ide/src/code_action/handlers/remove_empty_port_connections.rs @@ -17,6 +17,17 @@ const ID: CodeActionId = CodeActionId { }; const LABEL: &str = "Remove empty port connections"; +// Assist: remove_empty_port_connections +// +// This removes empty ordered port connections from an instance port list. +// +// ``` +// child u(a, $0, b); +// ``` +// -> +// ``` +// child u(a, b); +// ``` pub(super) fn remove_empty_port_connections( collector: &mut CodeActionCollector, ctx: &CodeActionCtx, diff --git a/crates/ide/src/code_action/handlers/remove_parentheses.rs b/crates/ide/src/code_action/handlers/remove_parentheses.rs new file mode 100644 index 00000000..e6e431da --- /dev/null +++ b/crates/ide/src/code_action/handlers/remove_parentheses.rs @@ -0,0 +1,160 @@ +use std::{cmp::Ordering, ops::Range}; + +use hir::base_db::source_db::SourceDb; +use syntax::{ + SyntaxKind, TokenKind, + ast::{self, AstNode}, + has_text_range::{HasTextRange, HasTextRangeIn}, +}; + +use crate::code_action::{CodeActionCollector, CodeActionCtx, CodeActionId, CodeActionKind}; + +const ID: CodeActionId = CodeActionId { + name: "remove_parentheses", + kind: CodeActionKind::RefactorRewrite, + repair: None, +}; + +// Assist: remove_parentheses +// +// This removes parentheses when they are redundant for the surrounding +// expression. +// +// ``` +// assign y = $0(a + b) + c; +// ``` +// -> +// ``` +// assign y = a + b + c; +// ``` +pub(super) fn remove_parentheses( + collector: &mut CodeActionCollector, + ctx: &CodeActionCtx, +) -> Option<()> { + let parens = ctx.find_node_at_offset::()?; + let range = parens.syntax().text_range()?; + let left = parens.open_paren()?.text_range_in(parens.syntax())?; + let right = parens.close_paren()?.text_range_in(parens.syntax())?; + if !left.contains_range(ctx.range()) && !right.contains_range(ctx.range()) { + return None; + } + + let expr = parens.expression(); + let parent = parens.syntax().parent()?; + if parentheses_are_required(parens, expr, parent) { + return None; + } + + let expr_range = expr.syntax().text_range()?; + let text = ctx.sema().db.file_text(ctx.file_id()); + let inner = text.get(Range::from(expr_range))?.to_owned(); + collector.add(ID, "Remove redundant parentheses", range, |builder| { + builder.replace(range, inner); + }) +} + +fn parentheses_are_required( + parens: ast::ParenthesizedExpression<'_>, + expr: ast::Expression<'_>, + parent: syntax::SyntaxNode<'_>, +) -> bool { + if ast::ParenthesizedExpression::cast(parent).is_some() { + return false; + } + + if matches!(parent.kind(), SyntaxKind::MEMBER_ACCESS_EXPRESSION | SyntaxKind::SCOPED_NAME) { + return true; + } + + let Some(parent_binary) = ast::BinaryExpression::cast(parent) else { + return ast::Expression::cast(parent).is_some_and(|_| { + expr.as_binary_expression().is_some() || expr.as_conditional_expression().is_some() + }); + }; + let Some(child_binary) = expr.as_binary_expression() else { + return false; + }; + + let (Some(parent_prec), Some(child_prec)) = + (binary_precedence(parent_binary), binary_precedence(child_binary)) + else { + return true; + }; + + match child_prec.cmp(&parent_prec) { + Ordering::Greater => false, + Ordering::Less => true, + Ordering::Equal => { + let same_associative_op = parent_binary + .operator_token() + .zip(child_binary.operator_token()) + .is_some_and(|(parent_op, child_op)| { + parent_op.kind() == child_op.kind() + && associative_binary_operator(parent_op.kind()) + }); + !(parent_binary.left().syntax() == parens.syntax() && same_associative_op) + } + } +} + +fn associative_binary_operator(kind: TokenKind) -> bool { + matches!( + kind, + TokenKind::PLUS + | TokenKind::STAR + | TokenKind::DOUBLE_AND + | TokenKind::DOUBLE_OR + | TokenKind::AND + | TokenKind::OR + | TokenKind::XOR + | TokenKind::TILDE_XOR + | TokenKind::XOR_TILDE + ) +} + +fn binary_precedence(expr: ast::BinaryExpression<'_>) -> Option { + let kind = expr.operator_token()?.kind(); + Some(match kind { + TokenKind::DOUBLE_STAR => 12, + TokenKind::STAR | TokenKind::SLASH | TokenKind::PERCENT => 11, + TokenKind::PLUS | TokenKind::MINUS => 10, + TokenKind::LEFT_SHIFT + | TokenKind::RIGHT_SHIFT + | TokenKind::TRIPLE_LEFT_SHIFT + | TokenKind::TRIPLE_RIGHT_SHIFT => 9, + TokenKind::LESS_THAN_EQUALS + if expr.syntax().kind() == SyntaxKind::NONBLOCKING_ASSIGNMENT_EXPRESSION => + { + 1 + } + TokenKind::GREATER_THAN + | TokenKind::GREATER_THAN_EQUALS + | TokenKind::LESS_THAN + | TokenKind::LESS_THAN_EQUALS => 8, + TokenKind::DOUBLE_EQUALS + | TokenKind::EXCLAMATION_EQUALS + | TokenKind::TRIPLE_EQUALS + | TokenKind::EXCLAMATION_DOUBLE_EQUALS + | TokenKind::DOUBLE_EQUALS_QUESTION + | TokenKind::EXCLAMATION_EQUALS_QUESTION => 7, + TokenKind::AND => 6, + TokenKind::XOR | TokenKind::TILDE_XOR | TokenKind::XOR_TILDE => 5, + TokenKind::OR => 4, + TokenKind::DOUBLE_AND => 3, + TokenKind::DOUBLE_OR => 2, + TokenKind::EQUALS + | TokenKind::PLUS_EQUAL + | TokenKind::MINUS_EQUAL + | TokenKind::STAR_EQUAL + | TokenKind::SLASH_EQUAL + | TokenKind::PERCENT_EQUAL + | TokenKind::AND_EQUAL + | TokenKind::OR_EQUAL + | TokenKind::XOR_EQUAL + | TokenKind::LEFT_SHIFT_EQUAL + | TokenKind::RIGHT_SHIFT_EQUAL + | TokenKind::TRIPLE_LEFT_SHIFT_EQUAL + | TokenKind::TRIPLE_RIGHT_SHIFT_EQUAL => 1, + _ => return None, + }) +} diff --git a/crates/ide/src/code_action/handlers/sort_named_instantiation_items.rs b/crates/ide/src/code_action/handlers/sort_named_instantiation_items.rs index 917f7699..9ee0c222 100644 --- a/crates/ide/src/code_action/handlers/sort_named_instantiation_items.rs +++ b/crates/ide/src/code_action/handlers/sort_named_instantiation_items.rs @@ -1,21 +1,29 @@ use std::ops::Range; -use hir::{base_db::source_db::SourceDb, db::HirDb}; +use hir::{ + base_db::source_db::SourceDb, + container::InModule, + db::HirDb, + hir_def::module::instantiation::{ParamAssign, PortConn}, + source_map::IsSrc, +}; use itertools::Itertools; use rustc_hash::FxHashMap; -use smol_str::ToSmolStr; use syntax::{ ast::{self, AstNode}, - has_text_range::{HasTextRange, HasTextRangeIn}, + has_text_range::HasTextRangeIn, +}; +use utils::{ + get::{Get, GetRef}, + text_edit::TextRange, }; -use utils::text_edit::TextRange; use crate::{ code_action::{ CodeActionCollector, CodeActionCtx, CodeActionId, CodeActionKind, all_parameter_names, line_indent, port_names, }, - module_resolution::resolve_instantiation_target, + module_resolution::resolve_hir_instantiation_target, }; const SORT_NAMED_PARAMETER_ASSIGNMENTS_ID: CodeActionId = CodeActionId { @@ -25,29 +33,46 @@ const SORT_NAMED_PARAMETER_ASSIGNMENTS_ID: CodeActionId = CodeActionId { }; const SORT_NAMED_PARAMETER_ASSIGNMENTS_LABEL: &str = "Sort named parameter assignments"; +// Assist: sort_named_parameter_assignments +// +// This sorts named parameter assignments to match the target module's parameter +// declaration order. +// +// ``` +// child #(.B(2), $0.A(1)) u(); +// ``` +// -> +// ``` +// child #(.A(1), .B(2)) u(); +// ``` pub(super) fn sort_named_parameter_assignments( collector: &mut CodeActionCollector, ctx: &CodeActionCtx, ) -> Option<()> { - let instantiation = ctx.find_node_at_offset::()?; - let params = instantiation.parameters()?; + let ast_instantiation = ctx.find_node_at_offset::()?; + let params = ast_instantiation.parameters()?; let open = params.open_paren()?.text_range_in(params.syntax())?; let close = params.close_paren()?.text_range_in(params.syntax())?; - let db = ctx.sema().db; - let target_module_id = - resolve_instantiation_target(db, ctx.file_id(), instantiation).unique()?; + let sema = ctx.sema(); + let db = sema.db; + let InModule { value: instantiation_id, module_id } = + sema.resolve_instantiation(ctx.file_id().into(), ast_instantiation)?; + let (module, module_src_map) = db.module_with_source_map(module_id); + let instantiation = module.get(instantiation_id); + let target_module_id = resolve_hir_instantiation_target(db, ctx.file_id(), instantiation)?; let parameter_order = all_parameter_names(&db.module(target_module_id)); let parameter_order_map: FxHashMap<_, _> = parameter_order.iter().enumerate().map(|(index, name)| (name.as_ref(), index)).collect(); - let text = ctx.sema().db.file_text(ctx.file_id()); + let text = sema.db.file_text(ctx.file_id()); let mut items = Vec::new(); - for assign in params.parameters().children() { - let named = assign.as_named_param_assignment()?; - let name = named.name()?.value_text().to_smolstr(); + for assign_id in instantiation.param_assigns.iter() { + let ParamAssign::Named(Some(name), _) = module.get(*assign_id) else { + return None; + }; let order = *parameter_order_map.get(name.as_str())?; - let range = assign.syntax().text_range()?; + let range = module_src_map.get(*assign_id)?.range(); items.push((order, text.get(Range::from(range))?, range)); } @@ -69,30 +94,46 @@ const SORT_NAMED_PORT_CONNECTIONS_ID: CodeActionId = CodeActionId { }; const SORT_NAMED_PORT_CONNECTIONS_LABEL: &str = "Sort named port connections"; +// Assist: sort_named_port_connections +// +// This sorts named port connections to match the target module's port +// declaration order. +// +// ``` +// child u(.b(b), $0.a(a)); +// ``` +// -> +// ``` +// child u(.a(a), .b(b)); +// ``` pub(super) fn sort_named_port_connections( collector: &mut CodeActionCollector, ctx: &CodeActionCtx, ) -> Option<()> { - let instance = ctx.find_node_at_offset::()?; - let instantiation = ast::HierarchyInstantiation::cast(instance.syntax().parent()?)?; - let open = instance.open_paren()?.text_range_in(instance.syntax())?; - let close = instance.close_paren()?.text_range_in(instance.syntax())?; - - let db = ctx.sema().db; - let target_module_id = - resolve_instantiation_target(db, ctx.file_id(), instantiation).unique()?; - let target_module = db.module(target_module_id); - let port_order = port_names(&target_module); + let ast_instance = ctx.find_node_at_offset::()?; + let open = ast_instance.open_paren()?.text_range_in(ast_instance.syntax())?; + let close = ast_instance.close_paren()?.text_range_in(ast_instance.syntax())?; + + let sema = ctx.sema(); + let db = sema.db; + let InModule { value: instance_id, module_id } = + sema.resolve_instance(ctx.file_id().into(), ast_instance)?; + let (module, module_src_map) = db.module_with_source_map(module_id); + let instance = module.get(instance_id); + let instantiation = module.get(instance.parent); + let target_module_id = resolve_hir_instantiation_target(db, ctx.file_id(), instantiation)?; + let port_order = port_names(&db.module(target_module_id)); let port_order_map: FxHashMap<_, _> = port_order.iter().enumerate().map(|(index, name)| (name.as_ref(), index)).collect(); - let text = ctx.sema().db.file_text(ctx.file_id()); + let text = sema.db.file_text(ctx.file_id()); let mut items = Vec::new(); - for conn in instance.connections().children() { - let named = conn.as_named_port_connection()?; - let name = named.name()?.value_text().to_smolstr(); + for conn_id in instance.connections.iter() { + let PortConn::Named(Some(name), _) = module.get(*conn_id) else { + return None; + }; let order = *port_order_map.get(name.as_str())?; - let range = conn.syntax().text_range()?; + let range = module_src_map.get(*conn_id)?.range(); items.push((order, text.get(Range::from(range))?, range)); } diff --git a/crates/ide/src/code_action/handlers/split_declaration_declarators.rs b/crates/ide/src/code_action/handlers/split_declaration_declarators.rs index 2963bc71..78fb4a10 100644 --- a/crates/ide/src/code_action/handlers/split_declaration_declarators.rs +++ b/crates/ide/src/code_action/handlers/split_declaration_declarators.rs @@ -19,6 +19,19 @@ const ID: CodeActionId = CodeActionId { }; const LABEL: &str = "Split declaration"; +// Assist: split_declaration_declarators +// +// This splits a declaration with multiple declarators into one declaration per +// declarator. +// +// ``` +// logic $0a, b; +// ``` +// -> +// ``` +// logic a; +// logic b; +// ``` pub(super) fn split_declaration_declarators( collector: &mut CodeActionCollector, ctx: &CodeActionCtx, diff --git a/crates/ide/src/code_action/handlers/wrap_statement_in_begin_end.rs b/crates/ide/src/code_action/handlers/wrap_statement_in_begin_end.rs index 0c134a1b..4bfe2402 100644 --- a/crates/ide/src/code_action/handlers/wrap_statement_in_begin_end.rs +++ b/crates/ide/src/code_action/handlers/wrap_statement_in_begin_end.rs @@ -18,6 +18,17 @@ const WRAP_ID: CodeActionId = CodeActionId { }; const WRAP_LABEL: &str = "Wrap statement in begin/end"; +// Assist: wrap_statement_in_begin_end +// +// This wraps a control-flow body statement in a `begin`/`end` block. +// +// ``` +// always_comb if (a) $0y = 1; +// ``` +// -> +// ``` +// always_comb if (a) begin y = 1; end +// ``` pub(super) fn wrap_statement_in_begin_end( collector: &mut CodeActionCollector, ctx: &CodeActionCtx, @@ -47,6 +58,17 @@ const UNWRAP_ID: CodeActionId = CodeActionId { }; const UNWRAP_LABEL: &str = "Unwrap single-statement begin/end"; +// Assist: unwrap_single_statement_block +// +// This unwraps a `begin`/`end` block that contains exactly one statement. +// +// ``` +// always_comb if (a) $0begin y = 1; end +// ``` +// -> +// ``` +// always_comb if (a) y = 1; +// ``` pub(super) fn unwrap_single_statement_block( collector: &mut CodeActionCollector, ctx: &CodeActionCtx, diff --git a/crates/ide/src/code_action/tests.rs b/crates/ide/src/code_action/tests.rs index 2f7161ad..1bd3f82b 100644 --- a/crates/ide/src/code_action/tests.rs +++ b/crates/ide/src/code_action/tests.rs @@ -13,6 +13,11 @@ fn db_with_file(text: &str) -> (RootDb, FileId, TextSize) { let marker = "/*caret*/"; let offset = text.find(marker).expect("missing caret marker"); let text = text.replace(marker, ""); + let (db, file_id) = db_with_text(&text); + (db, file_id, TextSize::from(offset as u32)) +} + +fn db_with_text(text: &str) -> (RootDb, FileId) { let file_id = FileId(0); let mut file_set = FileSet::default(); file_set.insert(file_id, VfsPath::new_virtual_path("/test.sv".to_owned())); @@ -21,12 +26,12 @@ fn db_with_file(text: &str) -> (RootDb, FileId, TextSize) { change.set_roots(vec![SourceRoot::new_local(file_set)]); change.add_changed_file(ChangedFile { file_id, - change_kind: ChangeKind::Create(Arc::from(text.as_str()), LineEnding::Unix), + change_kind: ChangeKind::Create(Arc::from(text), LineEnding::Unix), }); let mut db = RootDb::new(None); db.apply_change(change); - (db, file_id, TextSize::from(offset as u32)) + (db, file_id) } fn apply_action(text: &str, repair: RepairKind) -> Option { @@ -90,6 +95,57 @@ fn apply_action_without_diagnostics_by( Some(text) } +fn apply_action_without_diagnostics_with_selection( + text: &str, + action_name: &str, +) -> Option { + apply_action_without_diagnostics_with_selection_by(text, |action| action.id.name == action_name) +} + +fn apply_action_without_diagnostics_with_selection_by( + text: &str, + pred: impl Fn(&CodeAction) -> bool, +) -> Option { + let (mut text, range) = text_with_selection_range(text); + let (db, file_id) = db_with_text(&text); + let actions = code_action( + &db, + file_id, + range, + CodeActionDiagnostics::default(), + CodeActionResolveStrategy::All, + ); + let action = actions.into_iter().find(pred)?; + let edit = action.source_change?.text_edits.remove(&file_id)?; + edit.apply(&mut text); + Some(text) +} + +fn action_labels_without_diagnostics_with_selection(text: &str) -> Vec { + let (text, range) = text_with_selection_range(text); + let (db, file_id) = db_with_text(&text); + code_action( + &db, + file_id, + range, + CodeActionDiagnostics::default(), + CodeActionResolveStrategy::All, + ) + .into_iter() + .map(|action| action.label) + .collect() +} + +fn text_with_selection_range(text: &str) -> (String, TextRange) { + let marker = "/*selection*/"; + let start = text.find(marker).expect("missing selection start marker"); + let text = text.replacen(marker, "", 1); + let end = text.find(marker).expect("missing selection end marker"); + let text = text.replacen(marker, "", 1); + let range = TextRange::new(TextSize::from(start as u32), TextSize::from(end as u32)); + (text, range) +} + fn diagnostic_for_repair(repair: RepairKind) -> CodeActionDiagnostic { match repair { RepairKind::MissingConnection => CodeActionDiagnostic { @@ -317,6 +373,53 @@ fn literal_base_is_not_available_for_string_literals() { assert!(!labels.iter().any(|label| label.starts_with("Convert literal to "))); } +#[test] +fn reformat_number_literal_adds_decimal_separators() { + let text = "module top; localparam int value = /*caret*/10000; endmodule\n"; + let fixed = apply_action_without_diagnostics_with_label( + text, + "reformat_number_literal", + "Convert 10000 to 10_000", + ) + .unwrap(); + + assert_eq!(fixed, "module top; localparam int value = 10_000; endmodule\n"); +} + +#[test] +fn reformat_number_literal_removes_separators() { + let text = "module top; localparam int value = /*caret*/10_000; endmodule\n"; + let fixed = apply_action_without_diagnostics_with_label( + text, + "reformat_number_literal", + "Remove digit separators", + ) + .unwrap(); + + assert_eq!(fixed, "module top; localparam int value = 10000; endmodule\n"); +} + +#[test] +fn reformat_number_literal_formats_hex_literals() { + let text = "module top; localparam int value = /*caret*/'hff0000; endmodule\n"; + let fixed = apply_action_without_diagnostics_with_label( + text, + "reformat_number_literal", + "Convert 'hff0000 to 'hff_0000", + ) + .unwrap(); + + assert_eq!(fixed, "module top; localparam int value = 'hff_0000; endmodule\n"); +} + +#[test] +fn reformat_number_literal_requires_enough_digits() { + let labels = action_labels_without_diagnostics( + "module top; localparam int value = /*caret*/999; endmodule\n", + ); + assert!(!labels.iter().any(|label| label.starts_with("Convert 999 to "))); +} + #[test] fn missing_connection_repair_fills_named_connections() { let text = "module child(input a, input b); endmodule\nmodule top; child u(/*caret*/.a()); endmodule\n"; @@ -524,6 +627,139 @@ fn implicit_named_port_repair_is_available_without_diagnostics() { assert_eq!(fixed, "module child(input a); endmodule\nmodule top; child u(.a()); endmodule\n"); } +#[test] +fn named_port_shorthand_expands() { + let text = + "module child(input a); endmodule\nmodule top; logic a; child u(/*caret*/.a); endmodule\n"; + let fixed = + apply_action_without_diagnostics(text, "expand_named_port_connection_shorthand").unwrap(); + assert_eq!( + fixed, + "module child(input a); endmodule\nmodule top; logic a; child u(.a(a)); endmodule\n" + ); +} + +#[test] +fn named_port_shorthand_expands_all_named_connections_in_instance() { + let text = "module child(input a, b); endmodule\nmodule top; logic a, b; child u(/*caret*/.a, .b); endmodule\n"; + let fixed = + apply_action_without_diagnostics(text, "expand_named_port_connection_shorthand").unwrap(); + assert_eq!( + fixed, + "module child(input a, b); endmodule\nmodule top; logic a, b; child u(.a(a), .b(b)); endmodule\n" + ); +} + +#[test] +fn named_port_shorthand_collapses() { + let text = "module child(input a); endmodule\nmodule top; logic a; child u(/*caret*/.a(a)); endmodule\n"; + let fixed = + apply_action_without_diagnostics(text, "collapse_named_port_connection_shorthand").unwrap(); + assert_eq!( + fixed, + "module child(input a); endmodule\nmodule top; logic a; child u(.a); endmodule\n" + ); +} + +#[test] +fn named_port_shorthand_collapses_all_named_connections_in_instance() { + let text = "module child(input a, b); endmodule\nmodule top; logic a, b; child u(/*caret*/.a(a), .b(b)); endmodule\n"; + let fixed = + apply_action_without_diagnostics(text, "collapse_named_port_connection_shorthand").unwrap(); + assert_eq!( + fixed, + "module child(input a, b); endmodule\nmodule top; logic a, b; child u(.a, .b); endmodule\n" + ); +} + +#[test] +fn named_port_shorthand_collapses_matching_connections_in_instance() { + let text = "module child(input a, b, c); endmodule\nmodule top; logic sw1, b, gate_out; child u(/*caret*/.a(sw1), .c(c), .b(gate_out)); endmodule\n"; + let fixed = + apply_action_without_diagnostics(text, "collapse_named_port_connection_shorthand").unwrap(); + assert_eq!( + fixed, + "module child(input a, b, c); endmodule\nmodule top; logic sw1, b, gate_out; child u(.a(sw1), .c, .b(gate_out)); endmodule\n" + ); +} + +#[test] +fn named_port_shorthand_collapse_requires_at_least_one_same_name() { + let labels = action_labels_without_diagnostics( + "module child(input a); endmodule\nmodule top; logic b; child u(/*caret*/.a(b)); endmodule\n", + ); + assert!(!labels.iter().any(|label| label == "Collapse named port to shorthand")); +} + +#[test] +fn named_port_shorthand_requires_all_connections_named() { + let labels = action_labels_without_diagnostics( + "module child(input a, b); endmodule\nmodule top; logic a, b; child u(/*caret*/.a, b); endmodule\n", + ); + assert!(!labels.iter().any(|label| label == "Expand named port shorthand")); +} + +#[test] +fn convert_always_star_to_always_comb() { + let text = "module top; logic a, y; /*caret*/always @(*) begin y = a; end endmodule\n"; + let fixed = apply_action_without_diagnostics(text, "convert_always_to_always_comb").unwrap(); + assert_eq!(fixed, "module top; logic a, y; always_comb begin y = a; end endmodule\n"); +} + +#[test] +fn convert_always_comb_to_always_star() { + let text = "module top; logic a, y; /*caret*/always_comb begin y = a; end endmodule\n"; + let fixed = apply_action_without_diagnostics(text, "convert_always_comb_to_always").unwrap(); + assert_eq!(fixed, "module top; logic a, y; always @(*) begin y = a; end endmodule\n"); +} + +#[test] +fn convert_always_posedge_to_always_ff() { + let text = "module top; logic clk, d, q; /*caret*/always @(posedge clk) q <= d; endmodule\n"; + let fixed = apply_action_without_diagnostics(text, "convert_always_to_always_ff").unwrap(); + assert_eq!(fixed, "module top; logic clk, d, q; always_ff @(posedge clk) q <= d; endmodule\n"); +} + +#[test] +fn convert_always_event_list_to_always_ff() { + let text = "module top; logic clk, d, q; always @(/*caret*/posedge clk) q <= d; endmodule\n"; + let fixed = apply_action_without_diagnostics(text, "convert_always_to_always_ff").unwrap(); + assert_eq!(fixed, "module top; logic clk, d, q; always_ff @(posedge clk) q <= d; endmodule\n"); +} + +#[test] +fn convert_always_ff_to_plain_always() { + let text = "module top; logic clk, d, q; /*caret*/always_ff @(posedge clk) q <= d; endmodule\n"; + let fixed = apply_action_without_diagnostics(text, "convert_always_ff_to_always").unwrap(); + assert_eq!(fixed, "module top; logic clk, d, q; always @(posedge clk) q <= d; endmodule\n"); +} + +#[test] +fn convert_always_block_requires_caret_on_keyword_or_event_list() { + let labels = action_labels_without_diagnostics( + "module top; logic a, y; always @(*) begin /*caret*/y = a; end endmodule\n", + ); + assert!(!labels.iter().any(|label| label == "Convert to always_comb")); + + let labels = action_labels_without_diagnostics( + "module top; logic a, y; always_comb begin /*caret*/y = a; end endmodule\n", + ); + assert!(!labels.iter().any(|label| label == "Convert to always @(*)")); + + let labels = action_labels_without_diagnostics( + "module top; logic clk, d, q; always_ff @(posedge clk) /*caret*/q <= d; endmodule\n", + ); + assert!(!labels.iter().any(|label| label == "Convert to always @(...)")); +} + +#[test] +fn convert_always_to_always_ff_requires_edge_sensitivity() { + let labels = action_labels_without_diagnostics( + "module top; logic clk, d, q; /*caret*/always @(clk) q <= d; endmodule\n", + ); + assert!(!labels.iter().any(|label| label == "Convert to always_ff")); +} + #[test] fn instance_missing_parens_repair_adds_port_list() { let text = "module child; endmodule\nmodule top; child u/*caret*/; endmodule\n"; @@ -538,6 +774,82 @@ fn instance_missing_parens_repair_requires_diagnostics() { assert!(!labels.iter().any(|label| label == "Add empty instance port list")); } +#[test] +fn convert_ansi_ports_to_non_ansi() { + let text = "module top(/*caret*/input a, output logic b);\nassign b = a;\nendmodule\n"; + let fixed = apply_action_without_diagnostics(text, "convert_ansi_ports_to_non_ansi").unwrap(); + assert_eq!( + fixed, + "module top(a, b);\n input a;\n output logic b;\n assign b = a;\nendmodule\n" + ); +} + +#[test] +fn convert_ansi_ports_to_non_ansi_uses_inherited_header() { + let text = "module top(/*caret*/input a, b);\nassign b = a;\nendmodule\n"; + let fixed = apply_action_without_diagnostics(text, "convert_ansi_ports_to_non_ansi").unwrap(); + assert_eq!( + fixed, + "module top(a, b);\n input a;\n input wire logic b;\n assign b = a;\nendmodule\n" + ); +} + +#[test] +fn convert_non_ansi_ports_to_ansi() { + let text = + "module top(/*caret*/a, b);\ninput wire a;\noutput logic b;\nassign b = a;\nendmodule\n"; + let fixed = apply_action_without_diagnostics(text, "convert_non_ansi_ports_to_ansi").unwrap(); + assert_eq!(fixed, "module top(input wire a, output logic b);\n assign b = a;\nendmodule\n"); +} + +#[test] +fn convert_non_ansi_ports_to_ansi_merges_data_declaration() { + let text = "module top (\n /*caret*/c,\n led0\n);\n input wire c;\n output led0;\n reg led0;\n\nendmodule\n"; + let fixed = apply_action_without_diagnostics(text, "convert_non_ansi_ports_to_ansi").unwrap(); + assert_eq!(fixed, "module top (\n input wire c,\n output reg led0\n);\nendmodule\n"); +} + +#[test] +fn convert_ansi_ports_to_non_ansi_preserves_body_comments() { + let text = + "module top(/*caret*/input a, output logic b);\n// keep this\nassign b = a;\nendmodule\n"; + let fixed = apply_action_without_diagnostics(text, "convert_ansi_ports_to_non_ansi").unwrap(); + assert!(fixed.contains("// keep this"), "{fixed}"); + assert!(fixed.contains("assign b = a;"), "{fixed}"); +} + +#[test] +fn convert_non_ansi_ports_to_ansi_preserves_body_comments() { + let text = "module top(/*caret*/a, b);\n// keep first\ninput wire a;\n// keep second\noutput logic b;\nassign b = a;\nendmodule\n"; + let fixed = apply_action_without_diagnostics(text, "convert_non_ansi_ports_to_ansi").unwrap(); + assert!(fixed.contains("// keep first"), "{fixed}"); + assert!(fixed.contains("// keep second"), "{fixed}"); + assert!(fixed.contains("assign b = a;"), "{fixed}"); +} + +#[test] +fn convert_port_declarations_requires_caret_in_port_list() { + let labels = action_labels_without_diagnostics( + "module /*caret*/top(input a, output logic b);\nassign b = a;\nendmodule\n", + ); + assert!(!labels.iter().any(|label| label == "Convert ANSI port declarations to non-ANSI")); + + let labels = action_labels_without_diagnostics( + "module top(input a, output logic b);\n/*caret*/assign b = a;\nendmodule\n", + ); + assert!(!labels.iter().any(|label| label == "Convert ANSI port declarations to non-ANSI")); + + let labels = action_labels_without_diagnostics( + "module /*caret*/top(a, b);\ninput wire a;\noutput logic b;\nassign b = a;\nendmodule\n", + ); + assert!(!labels.iter().any(|label| label == "Convert non-ANSI port declarations to ANSI")); + + let labels = action_labels_without_diagnostics( + "module top(a, b);\ninput wire a;\n/*caret*/output logic b;\nassign b = a;\nendmodule\n", + ); + assert!(!labels.iter().any(|label| label == "Convert non-ANSI port declarations to ANSI")); +} + #[test] fn split_declaration_declarators_splits_data_declaration() { let text = "module top; /*caret*/logic [3:0] a, b = 4'h0; endmodule\n"; @@ -620,6 +932,232 @@ fn invert_if_else_swaps_branches_and_negates_condition() { assert_eq!(fixed, "module top; always_comb if (!(a)) y = 0; else y = 1; endmodule\n"); } +#[test] +fn remove_parentheses_removes_redundant_binary_parens() { + let text = "module top; assign y = /*caret*/(a + b) + c; endmodule\n"; + let fixed = apply_action_without_diagnostics(text, "remove_parentheses").unwrap(); + assert_eq!(fixed, "module top; assign y = a + b + c; endmodule\n"); +} + +#[test] +fn remove_parentheses_keeps_required_parens() { + let labels = action_labels_without_diagnostics( + "module top; assign y = /*caret*/(a + b) * c; endmodule\n", + ); + assert!(!labels.iter().any(|label| label == "Remove redundant parentheses")); +} + +#[test] +fn remove_parentheses_requires_cursor_on_paren() { + let labels = action_labels_without_diagnostics( + "module top; assign y = (a /*caret*/+ b) + c; endmodule\n", + ); + assert!(!labels.iter().any(|label| label == "Remove redundant parentheses")); +} + +#[test] +fn merge_nested_if_merges_simple_nested_if() { + let text = "module top; always_comb if (/*caret*/a) begin if (b) y = 1; end endmodule\n"; + let fixed = apply_action_without_diagnostics(text, "merge_nested_if").unwrap(); + assert_eq!(fixed, "module top; always_comb if (a && b) y = 1; endmodule\n"); +} + +#[test] +fn merge_nested_if_wraps_or_conditions() { + let text = + "module top; always_comb if (/*caret*/a || b) begin if (c || d) y = 1; end endmodule\n"; + let fixed = apply_action_without_diagnostics(text, "merge_nested_if").unwrap(); + assert_eq!(fixed, "module top; always_comb if ((a || b) && (c || d)) y = 1; endmodule\n"); +} + +#[test] +fn merge_nested_if_merges_multiple_nested_levels() { + let text = "module top; always_comb if (/*caret*/a) begin if (b) begin if (c) y = 1; end end endmodule\n"; + let fixed = apply_action_without_diagnostics(text, "merge_nested_if").unwrap(); + assert_eq!(fixed, "module top; always_comb if (a && b && c) y = 1; endmodule\n"); +} + +#[test] +fn merge_nested_if_triggers_from_middle_nested_level() { + let text = "module top; always_comb if (a) begin if (/*caret*/b) begin if (c) y = 1; end end endmodule\n"; + let fixed = apply_action_without_diagnostics(text, "merge_nested_if").unwrap(); + assert_eq!(fixed, "module top; always_comb if (a && b && c) y = 1; endmodule\n"); +} + +#[test] +fn merge_nested_if_triggers_from_innermost_nested_level() { + let text = "module top; always_comb if (a) begin if (b) begin if (/*caret*/c) y = 1; end end endmodule\n"; + let fixed = apply_action_without_diagnostics(text, "merge_nested_if").unwrap(); + assert_eq!(fixed, "module top; always_comb if (a && b && c) y = 1; endmodule\n"); +} + +#[test] +fn merge_nested_if_merges_mixed_block_and_unbraced_levels() { + let text = "module top; always_comb if (a) begin if (/*caret*/b) if (c) y = 1; end endmodule\n"; + let fixed = apply_action_without_diagnostics(text, "merge_nested_if").unwrap(); + assert_eq!(fixed, "module top; always_comb if (a && b && c) y = 1; endmodule\n"); +} + +#[test] +fn merge_nested_if_requires_no_else_branches() { + let labels = action_labels_without_diagnostics( + "module top; always_comb if (/*caret*/a) begin if (b) y = 1; else y = 0; end endmodule\n", + ); + assert!(!labels.iter().any(|label| label == "Merge nested if")); +} + +#[test] +fn merge_nested_if_rejects_block_with_declarations() { + let labels = action_labels_without_diagnostics( + "module top; always_comb if (/*caret*/a) begin logic tmp; if (b) y = tmp; end endmodule\n", + ); + assert!(!labels.iter().any(|label| label == "Merge nested if")); +} + +#[test] +fn extract_variable_inserts_local_before_statement() { + let text = "module top; always_comb begin y = /*selection*/a + b/*selection*/; end endmodule\n"; + let fixed = apply_action_without_diagnostics_with_selection(text, "extract_variable").unwrap(); + assert_eq!( + fixed, + "module top; always_comb begin logic value = a + b;\ny = value; end endmodule\n" + ); +} + +#[test] +fn extract_variable_allows_selection_padding() { + let text = + "module top; always_comb begin y =/*selection*/ a + b /*selection*/; end endmodule\n"; + let fixed = apply_action_without_diagnostics_with_selection(text, "extract_variable").unwrap(); + assert_eq!( + fixed, + "module top; always_comb begin logic value = a + b;\ny = value ; end endmodule\n" + ); +} + +#[test] +fn extract_variable_uses_assignment_lhs_type() { + let text = "module top; logic [7:0] y, a, b; always_comb begin y = /*selection*/a + b/*selection*/; end endmodule\n"; + let fixed = apply_action_without_diagnostics_with_selection(text, "extract_variable").unwrap(); + assert_eq!( + fixed, + "module top; logic [7:0] y, a, b; always_comb begin logic [7:0] value = a + b;\ny = value; end endmodule\n" + ); +} + +#[test] +fn extract_variable_from_continuous_assign() { + let text = "module top; assign y = /*selection*/a + b/*selection*/; endmodule\n"; + let fixed = apply_action_without_diagnostics_with_selection(text, "extract_variable").unwrap(); + assert_eq!(fixed, "module top; wire logic value = a + b;\nassign y = value; endmodule\n"); +} + +#[test] +fn extract_variable_uses_continuous_assign_lhs_type() { + let text = + "module top; logic [7:0] y, a, b; assign y = /*selection*/a + b/*selection*/; endmodule\n"; + let fixed = apply_action_without_diagnostics_with_selection(text, "extract_variable").unwrap(); + assert_eq!( + fixed, + "module top; logic [7:0] y, a, b; wire logic [7:0] value = a + b;\nassign y = value; endmodule\n" + ); +} + +#[test] +fn extract_variable_requires_selection() { + let labels = action_labels_without_diagnostics( + "module top; always_comb begin y = a /*caret*/+ b; end endmodule\n", + ); + assert!(!labels.iter().any(|label| label == "Extract into variable")); +} + +#[test] +fn extract_variable_requires_complete_expression_selection() { + let labels = action_labels_without_diagnostics_with_selection( + "module top; always_comb begin y = a /*selection*/+/*selection*/ b; end endmodule\n", + ); + assert!(!labels.iter().any(|label| label == "Extract into variable")); +} + +#[test] +fn extract_variable_rejects_continuous_assign_lhs() { + let labels = action_labels_without_diagnostics_with_selection( + "module top; assign /*selection*/y/*selection*/ = a + b; endmodule\n", + ); + assert!(!labels.iter().any(|label| label == "Extract into variable")); +} + +#[test] +fn extract_variable_requires_block_scope() { + let labels = action_labels_without_diagnostics_with_selection( + "module top; always_comb if (a) y = /*selection*/b + c/*selection*/; endmodule\n", + ); + assert!(!labels.iter().any(|label| label == "Extract into variable")); +} + +#[test] +fn pull_assignment_up_converts_if_else_assignment_to_ternary() { + let text = "module top; always_comb /*caret*/if (a) y = 1; else y = 0; endmodule\n"; + let fixed = apply_action_without_diagnostics(text, "pull_assignment_up").unwrap(); + assert_eq!(fixed, "module top; always_comb y = a ? 1 : 0; endmodule\n"); +} + +#[test] +fn pull_assignment_up_converts_else_if_chain_to_nested_ternary() { + let text = + "module top; always_comb if (/*caret*/a) y = 1; else if (b) y = 2; else y = 3; endmodule\n"; + let fixed = apply_action_without_diagnostics(text, "pull_assignment_up").unwrap(); + assert_eq!(fixed, "module top; always_comb y = a ? 1 : b ? 2 : 3; endmodule\n"); +} + +#[test] +fn pull_assignment_up_triggers_from_else_if_chain_body() { + let text = + "module top; always_comb if (a) y = 1; else if (b) /*caret*/y = 2; else y = 3; endmodule\n"; + let fixed = apply_action_without_diagnostics(text, "pull_assignment_up").unwrap(); + assert_eq!(fixed, "module top; always_comb y = a ? 1 : b ? 2 : 3; endmodule\n"); +} + +#[test] +fn pull_assignment_up_wraps_conditional_predicate() { + let text = "module top; always_comb if (a ? b : c) /*caret*/y = 1; else y = 0; endmodule\n"; + let fixed = apply_action_without_diagnostics(text, "pull_assignment_up").unwrap(); + assert_eq!(fixed, "module top; always_comb y = (a ? b : c) ? 1 : 0; endmodule\n"); +} + +#[test] +fn pull_assignment_up_requires_single_assignment_branches() { + let labels = action_labels_without_diagnostics( + "module top; always_comb if (a) begin /*caret*/y = 1; z = 0; end else y = 2; endmodule\n", + ); + assert!(!labels.iter().any(|label| label == "Pull assignment up")); +} + +#[test] +fn pull_assignment_up_rejects_block_with_declarations() { + let labels = action_labels_without_diagnostics( + "module top; always_comb if (a) begin logic tmp; /*caret*/y = tmp; end else y = 0; endmodule\n", + ); + assert!(!labels.iter().any(|label| label == "Pull assignment up")); +} + +#[test] +fn pull_assignment_down_converts_ternary_assignment_to_if_else() { + let text = "module top; always_comb /*caret*/y = a ? 1 : 0; endmodule\n"; + let fixed = apply_action_without_diagnostics(text, "pull_assignment_down").unwrap(); + assert_eq!(fixed, "module top; always_comb if (a) y = 1; else y = 0; endmodule\n"); +} + +#[test] +fn pull_assignment_down_converts_nested_ternary_to_else_if_chain() { + let text = "module top; always_comb /*caret*/y = a ? 1 : b ? 2 : 3; endmodule\n"; + let fixed = apply_action_without_diagnostics(text, "pull_assignment_down").unwrap(); + assert_eq!( + fixed, + "module top; always_comb if (a) y = 1; else if (b) y = 2; else y = 3; endmodule\n" + ); +} + #[test] fn unwrap_single_statement_block_unwraps_single_statement() { let text = "module top; always_comb if (a) /*caret*/begin y = 1; end endmodule\n"; diff --git a/crates/ide/src/document_symbols.rs b/crates/ide/src/document_symbols.rs index 3a251c54..92dfb220 100644 --- a/crates/ide/src/document_symbols.rs +++ b/crates/ide/src/document_symbols.rs @@ -7,6 +7,7 @@ use hir::{ file::HirFileId, hir_def::{ DEFAULT_NAME, + aggregate::{StructDef, StructId, StructKind, StructSrc}, block::{BlockId, BlockInfo, BlockItem, BlockSrc, LocalBlockId}, declaration::{Declaration, DeclarationId, DeclarationSrc}, expr::declarator::{DeclId, Declarator, DeclaratorSrc, DeclsRange}, @@ -224,9 +225,7 @@ pub(crate) fn document_symbols(db: &dyn HirDb, file_id: FileId) -> Vec { build_subroutine(&mut collector, subroutine_id, file, src_map) } - FileItem::StructId(_) => { - // TODO: implement document symbols for these items - } + FileItem::StructId(struct_id) => build_struct(&mut collector, struct_id, file, src_map), FileItem::ConfigDeclId(config_id) => { build_config_decl(&mut collector, config_id, file, src_map) } @@ -328,9 +327,7 @@ fn collect_module_items( ModuleItem::SubroutineId(subroutine_id) => { build_subroutine(collector, subroutine_id, module, src_map) } - ModuleItem::StructId(_) => { - // TODO: implement document symbols for these items - } + ModuleItem::StructId(struct_id) => build_struct(collector, struct_id, module, src_map), } } collector.pop(); @@ -365,9 +362,7 @@ fn collect_block_items( BlockItem::TypedefId(typedef_id) => { build_typedef(collector, typedef_id, block, src_map) } - BlockItem::StructId(_) => { - // TODO: implement document symbols for these items - } + BlockItem::StructId(struct_id) => build_struct(collector, struct_id, block, src_map), } } collector.pop(); @@ -489,6 +484,7 @@ fn build_generate_region( + GetRef + GetRef + GetRef + + GetRef + GetRef, SrcMap: Get> + Get> @@ -497,6 +493,7 @@ fn build_generate_region( + Get> + Get> + Get> + + Get> + Get>, { let hir = arena.get(generate_region_id); @@ -527,7 +524,9 @@ fn build_generate_region( let proc = arena.get(proc_id); build_stmt(db, collector, proc.stmt, arena, src_map); } - GenerateItem::StructId(_) => {} + GenerateItem::StructId(struct_id) => { + build_struct(collector, struct_id, arena, src_map); + } GenerateItem::SubroutineId(subroutine_id) => { build_subroutine(collector, subroutine_id, arena, src_map); } @@ -578,14 +577,43 @@ fn build_generate_block( } } } - GenerateBlockItem::ContAssignId(_) - | GenerateBlockItem::DefParamId(_) - | GenerateBlockItem::StructId(_) => {} + GenerateBlockItem::ContAssignId(_) | GenerateBlockItem::DefParamId(_) => {} + GenerateBlockItem::StructId(struct_id) => { + build_struct(collector, struct_id, generate_block, src_map); + } } } collector.pop(); } +#[inline] +fn build_struct( + collector: &mut SymbolCollecter, + struct_id: Idx, + arena: &Arn, + src_map: &SrcMap, +) where + Arn: GetRef, + SrcMap: Get>, +{ + let hir = arena.get(struct_id); + let Some(src) = src_map.get(struct_id) else { + return; + }; + + let name = hir.name.clone().or_else(|| Some(struct_kind_name(hir.kind))); + collector.push_symbol_with_kind(&name, src, SymbolKind::Struct); + collector.pop(); +} + +#[inline] +fn struct_kind_name(kind: StructKind) -> SmolStr { + match kind { + StructKind::Struct => SmolStr::new_static("struct"), + StructKind::Union => SmolStr::new_static("union"), + } +} + #[inline] fn build_specify_block( collector: &mut SymbolCollecter, @@ -666,7 +694,11 @@ fn build_typedef( let Some(src) = src_map.get(typedef_id) else { return; }; - collector.push_symbol_with_kind(&hir.name, src, SymbolKind::Typedef); + let kind = match hir.ty { + Some(hir::hir_def::expr::data_ty::DataTy::Struct(_)) => SymbolKind::Struct, + _ => SymbolKind::Typedef, + }; + collector.push_symbol_with_kind(&hir.name, src, kind); collector.pop(); } diff --git a/crates/ide/src/lib.rs b/crates/ide/src/lib.rs index cdac0865..a2f12299 100644 --- a/crates/ide/src/lib.rs +++ b/crates/ide/src/lib.rs @@ -58,6 +58,7 @@ pub enum SymbolKind { Genvar, Specparam, Typedef, + Struct, Instance, Block, Stmt, diff --git a/crates/ide/src/module_resolution.rs b/crates/ide/src/module_resolution.rs index e535f235..721c599e 100644 --- a/crates/ide/src/module_resolution.rs +++ b/crates/ide/src/module_resolution.rs @@ -5,8 +5,11 @@ use hir::{ container::InModule, db::HirDb, hir_def::{ - Ident, declaration::Declaration, expr::declarator::DeclaratorParent, lower_ident_opt, - module::ModuleId, + Ident, + declaration::Declaration, + expr::declarator::DeclaratorParent, + lower_ident_opt, + module::{ModuleId, instantiation::Instantiation}, }, scope::{ModuleEntry, ScopeResolution}, semantics::pathres::PathResolution, @@ -55,6 +58,14 @@ pub(crate) fn resolve_instantiation_target( resolve_module_name(db, from_file, &name) } +pub(crate) fn resolve_hir_instantiation_target( + db: &RootDb, + from_file: FileId, + instantiation: &Instantiation, +) -> Option { + resolve_module_name(db, from_file, instantiation.module_name.as_ref()?).unique() +} + pub(crate) fn resolve_module_name( db: &RootDb, from_file: FileId, diff --git a/crates/ide/src/rename.rs b/crates/ide/src/rename.rs index 04cef1b4..948700af 100644 --- a/crates/ide/src/rename.rs +++ b/crates/ide/src/rename.rs @@ -1,6 +1,4 @@ -use hir::{ - base_db::source_db::SourceDb, container::InFile, hir_def::lower_ident, semantics::Semantics, -}; +use hir::{base_db::source_db::SourceDb, container::InFile, semantics::Semantics}; use nohash_hasher::IntMap; use rustc_hash::FxHashMap; use smol_str::SmolStr; @@ -427,7 +425,7 @@ fn check_same_name_conn( DefinitionClass::PortConnShorthand { port, .. } => port, DefinitionClass::Ambiguous(_) => return None, }; - let port_name = lower_ident(Some(name_token))?; + let port_name = name_token.value_text().to_string(); let expr = conn.expr()?.as_simple_property_expr()?.expr().as_simple_sequence_expr()?.expr(); let actual_token = match expr { Expression::Name(Name::IdentifierName(ident)) => ident.identifier()?, @@ -438,7 +436,7 @@ fn check_same_name_conn( } _ => return None, }; - if lower_ident(Some(actual_token))?.as_str() != port_name.as_str() { + if actual_token.value_text().to_string() != port_name { return None; } let actual_token = SyntaxTokenWithParent { parent: expr.syntax(), tok: actual_token }; @@ -548,7 +546,7 @@ fn edits_from_refs( && conn_data_range(port_conn).is_some_and(|r| r == range) && let Some(port_name) = port_conn .name() - .filter(|n| lower_ident(Some(*n)).is_some_and(|name| name == new_name)) { + .filter(|n| n.value_text().to_string() == new_name) { // .new(data) => .new let Some(start) = port_name.text_range_in(port_conn.syntax()).map(|range| range.start()) else { diff --git a/crates/ide/src/verilog_2005.rs b/crates/ide/src/verilog_2005.rs index 0ad63130..07b58e8a 100644 --- a/crates/ide/src/verilog_2005.rs +++ b/crates/ide/src/verilog_2005.rs @@ -2091,3 +2091,39 @@ fn verilog_2005_lsp_snapshots() { assert_snapshot!("verilog_2005_lsp_snapshots", report); } + +#[test] +fn document_symbols_include_typedef_structs_and_nested_generate_structs() { + let text = r#" +module top; + typedef struct packed { + logic ready; + } packet_t; + + generate + if (1) begin : g + typedef union packed { + logic [7:0] raw; + logic flag; + } state_t; + end + endgenerate +endmodule +"#; + let (host, file_id) = setup(text); + let analysis = host.make_analysis(); + + let symbols = analysis.document_symbol(file_id).unwrap(); + let mut lines = Vec::new(); + collect_symbol_lines(&symbols, 0, &mut lines); + let dump = lines.join("\n"); + + assert!( + dump.contains("packet_t Struct"), + "typedef struct should surface as a struct symbol: {dump}" + ); + assert!( + dump.contains("state_t Struct"), + "nested generate typedef union should surface as a struct symbol: {dump}" + ); +} diff --git a/crates/slang/bindings/rust/ast.rs b/crates/slang/bindings/rust/ast.rs old mode 100644 new mode 100755 index fdad1b27..5e8328e8 --- a/crates/slang/bindings/rust/ast.rs +++ b/crates/slang/bindings/rust/ast.rs @@ -51,6 +51,15 @@ impl<'a, T: AstNode<'a>> SyntaxList<'a, T> { pub fn children(&self) -> impl Iterator + 'a { SyntaxChildren::new(self.syntax).map(|elem| T::cast(elem.as_node().unwrap()).unwrap()) } + + pub fn only_children(&self) -> Option { + let mut children = SyntaxChildren::new(self.syntax); + let first = children.next()?; + if children.next().is_some() { + return None; + } + T::cast(first.as_node().unwrap()) + } } impl<'a, T: AstNode<'a>> AstNode<'a> for SyntaxList<'a, T> { @@ -79,6 +88,15 @@ impl<'a, T: AstNode<'a>> SeparatedList<'a, T> { .step_by(2) .map(|elem| T::cast(elem.as_node().unwrap()).unwrap()) } + + pub fn only_children(&self) -> Option { + let mut children = SyntaxChildren::new(self.syntax); + let first = children.next()?; + if children.next().is_some() { + return None; + } + T::cast(first.as_node().unwrap()) + } } impl<'a, T: AstNode<'a>> AstNode<'a> for SeparatedList<'a, T> { diff --git a/src/i18n.rs b/src/i18n.rs index 28fac9cc..66b6efc8 100644 --- a/src/i18n.rs +++ b/src/i18n.rs @@ -95,10 +95,32 @@ pub(crate) mod keys { "code_action.sort_named_port_connections"; pub(crate) const CODE_ACTION_ADD_DEFAULT_CASE_ITEM: &str = "code_action.add_default_case_item"; pub(crate) const CODE_ACTION_INVERT_IF_ELSE: &str = "code_action.invert_if_else"; + pub(crate) const CODE_ACTION_EXTRACT_VARIABLE: &str = "code_action.extract_variable"; + pub(crate) const CODE_ACTION_REMOVE_REDUNDANT_PARENTHESES: &str = + "code_action.remove_redundant_parentheses"; pub(crate) const CODE_ACTION_UNWRAP_SINGLE_STATEMENT_BLOCK: &str = "code_action.unwrap_single_statement_block"; pub(crate) const CODE_ACTION_WRAP_STATEMENT_IN_BEGIN_END: &str = "code_action.wrap_statement_in_begin_end"; + pub(crate) const CODE_ACTION_EXPAND_NAMED_PORT_CONNECTION_SHORTHAND: &str = + "code_action.expand_named_port_connection_shorthand"; + pub(crate) const CODE_ACTION_COLLAPSE_NAMED_PORT_CONNECTION_SHORTHAND: &str = + "code_action.collapse_named_port_connection_shorthand"; + pub(crate) const CODE_ACTION_CONVERT_ANSI_PORTS_TO_NON_ANSI: &str = + "code_action.convert_ansi_ports_to_non_ansi"; + pub(crate) const CODE_ACTION_CONVERT_NON_ANSI_PORTS_TO_ANSI: &str = + "code_action.convert_non_ansi_ports_to_ansi"; + pub(crate) const CODE_ACTION_CONVERT_ALWAYS_TO_ALWAYS_COMB: &str = + "code_action.convert_always_to_always_comb"; + pub(crate) const CODE_ACTION_CONVERT_ALWAYS_TO_ALWAYS_FF: &str = + "code_action.convert_always_to_always_ff"; + pub(crate) const CODE_ACTION_CONVERT_ALWAYS_COMB_TO_ALWAYS: &str = + "code_action.convert_always_comb_to_always"; + pub(crate) const CODE_ACTION_CONVERT_ALWAYS_FF_TO_ALWAYS: &str = + "code_action.convert_always_ff_to_always"; + pub(crate) const CODE_ACTION_MERGE_NESTED_IF: &str = "code_action.merge_nested_if"; + pub(crate) const CODE_ACTION_PULL_ASSIGNMENT_UP: &str = "code_action.pull_assignment_up"; + pub(crate) const CODE_ACTION_PULL_ASSIGNMENT_DOWN: &str = "code_action.pull_assignment_down"; pub(crate) const CODE_ACTION_EXPAND_POSTFIX_INC_DEC: &str = "code_action.expand_postfix_inc_dec"; pub(crate) const CODE_ACTION_EXPAND_PREFIX_INC_DEC: &str = "code_action.expand_prefix_inc_dec"; @@ -124,6 +146,8 @@ pub(crate) mod keys { "code_action.collapse_compound_assignment"; pub(crate) const CODE_ACTION_APPLY_DE_MORGAN: &str = "code_action.apply_de_morgan"; pub(crate) const CODE_ACTION_FACTOR_DE_MORGAN: &str = "code_action.factor_de_morgan"; + pub(crate) const CODE_ACTION_REMOVE_DIGIT_SEPARATORS: &str = + "code_action.remove_digit_separators"; pub(crate) const CODE_ACTION_INSERT_MISSING_TOKEN: &str = "code_action.insert_missing_token"; pub(crate) const CODE_ACTION_CONVERT_LITERAL_TO_BINARY: &str = "code_action.convert_literal_to_binary"; diff --git a/src/i18n/en.toml b/src/i18n/en.toml index fd942584..d695beb7 100644 --- a/src/i18n/en.toml +++ b/src/i18n/en.toml @@ -53,8 +53,21 @@ sort_named_parameter_assignments = "Sort named parameter assignments" sort_named_port_connections = "Sort named port connections" add_default_case_item = "Add default case item" invert_if_else = "Invert if/else" +extract_variable = "Extract into variable" +remove_redundant_parentheses = "Remove redundant parentheses" unwrap_single_statement_block = "Unwrap single-statement begin/end" wrap_statement_in_begin_end = "Wrap statement in begin/end" +expand_named_port_connection_shorthand = "Expand named port shorthand" +collapse_named_port_connection_shorthand = "Collapse named port to shorthand" +convert_ansi_ports_to_non_ansi = "Convert ANSI port declarations to non-ANSI" +convert_non_ansi_ports_to_ansi = "Convert non-ANSI port declarations to ANSI" +convert_always_to_always_comb = "Convert to always_comb" +convert_always_to_always_ff = "Convert to always_ff" +convert_always_comb_to_always = "Convert to always @(*)" +convert_always_ff_to_always = "Convert to always @(...)" +merge_nested_if = "Merge nested if" +pull_assignment_up = "Pull assignment up" +pull_assignment_down = "Pull assignment down" expand_postfix_inc_dec = "Expand postfix expression" expand_prefix_inc_dec = "Expand prefix expression" convert_postfix_to_prefix_inc_dec = "Convert postfix to prefix expression" @@ -69,6 +82,7 @@ expand_compound_assignment = "Expand compound assignment" collapse_compound_assignment = "Collapse compound assignment" apply_de_morgan = "Apply De Morgan's law" factor_de_morgan = "Factor De Morgan's law" +remove_digit_separators = "Remove digit separators" insert_missing_token = "Insert missing '{token}'" convert_literal_to_binary = "Convert literal to binary" convert_literal_to_octal = "Convert literal to octal" diff --git a/src/i18n/zh-CN.toml b/src/i18n/zh-CN.toml index c25d855b..30e33df6 100644 --- a/src/i18n/zh-CN.toml +++ b/src/i18n/zh-CN.toml @@ -53,8 +53,21 @@ sort_named_parameter_assignments = "排序命名参数赋值" sort_named_port_connections = "排序命名端口连接" add_default_case_item = "添加 default case 分支项" invert_if_else = "反转 if/else" +extract_variable = "提取为变量" +remove_redundant_parentheses = "移除冗余括号" unwrap_single_statement_block = "展开单语句 begin/end" wrap_statement_in_begin_end = "用 begin/end 包裹语句" +expand_named_port_connection_shorthand = "展开命名端口简写" +collapse_named_port_connection_shorthand = "折叠命名端口为简写" +convert_ansi_ports_to_non_ansi = "将 ANSI 端口声明转换为非 ANSI" +convert_non_ansi_ports_to_ansi = "将非 ANSI 端口声明转换为 ANSI" +convert_always_to_always_comb = "转换为 always_comb" +convert_always_to_always_ff = "转换为 always_ff" +convert_always_comb_to_always = "转换为 always @(*)" +convert_always_ff_to_always = "转换为 always @(...)" +merge_nested_if = "合并嵌套 if" +pull_assignment_up = "转换为嵌套三元表达式 ?:" +pull_assignment_down = "转换为 if/else 赋值" expand_postfix_inc_dec = "展开后缀表达式" expand_prefix_inc_dec = "展开前缀表达式" convert_postfix_to_prefix_inc_dec = "将后缀表达式转换为前缀表达式" @@ -67,8 +80,9 @@ convert_assignment_to_postfix_inc_dec = "将赋值转换为后缀表达式" convert_assignment_to_prefix_inc_dec = "将赋值转换为前缀表达式" expand_compound_assignment = "展开复合赋值" collapse_compound_assignment = "折叠复合赋值" -apply_de_morgan = "应用德摩根律" -factor_de_morgan = "提取德摩根律" +apply_de_morgan = "应用德摩根律,将取反操作分配到内层" +factor_de_morgan = "应用德摩根律,将内层取反提取到外层" +remove_digit_separators = "移除数字分隔符" insert_missing_token = "插入缺失的 '{token}'" convert_literal_to_binary = "将字面量转换为二进制" convert_literal_to_octal = "将字面量转换为八进制" diff --git a/src/lsp_ext/to_proto.rs b/src/lsp_ext/to_proto.rs index 1b5e53ee..98fac21a 100644 --- a/src/lsp_ext/to_proto.rs +++ b/src/lsp_ext/to_proto.rs @@ -266,6 +266,7 @@ fn symbol_kind(symbol_kind: SymbolKind) -> lsp_types::SymbolKind { SymbolKind::Genvar => LspSymbolKind::VARIABLE, SymbolKind::Specparam => LspSymbolKind::TYPE_PARAMETER, SymbolKind::Typedef => LspSymbolKind::TYPE_PARAMETER, + SymbolKind::Struct => LspSymbolKind::STRUCT, SymbolKind::Instance => LspSymbolKind::OBJECT, SymbolKind::Block => LspSymbolKind::NAMESPACE, SymbolKind::Stmt => LspSymbolKind::NAMESPACE, @@ -945,8 +946,25 @@ fn code_action_title_key(id: &str, label: &str) -> Option<&'static str> { "sort_named_port_connections" => keys::CODE_ACTION_SORT_NAMED_PORT_CONNECTIONS, "add_default_case_item" => keys::CODE_ACTION_ADD_DEFAULT_CASE_ITEM, "invert_if_else" => keys::CODE_ACTION_INVERT_IF_ELSE, + "extract_variable" => keys::CODE_ACTION_EXTRACT_VARIABLE, + "remove_parentheses" => keys::CODE_ACTION_REMOVE_REDUNDANT_PARENTHESES, "unwrap_single_statement_block" => keys::CODE_ACTION_UNWRAP_SINGLE_STATEMENT_BLOCK, "wrap_statement_in_begin_end" => keys::CODE_ACTION_WRAP_STATEMENT_IN_BEGIN_END, + "expand_named_port_connection_shorthand" => { + keys::CODE_ACTION_EXPAND_NAMED_PORT_CONNECTION_SHORTHAND + } + "collapse_named_port_connection_shorthand" => { + keys::CODE_ACTION_COLLAPSE_NAMED_PORT_CONNECTION_SHORTHAND + } + "convert_ansi_ports_to_non_ansi" => keys::CODE_ACTION_CONVERT_ANSI_PORTS_TO_NON_ANSI, + "convert_non_ansi_ports_to_ansi" => keys::CODE_ACTION_CONVERT_NON_ANSI_PORTS_TO_ANSI, + "convert_always_to_always_comb" => keys::CODE_ACTION_CONVERT_ALWAYS_TO_ALWAYS_COMB, + "convert_always_to_always_ff" => keys::CODE_ACTION_CONVERT_ALWAYS_TO_ALWAYS_FF, + "convert_always_comb_to_always" => keys::CODE_ACTION_CONVERT_ALWAYS_COMB_TO_ALWAYS, + "convert_always_ff_to_always" => keys::CODE_ACTION_CONVERT_ALWAYS_FF_TO_ALWAYS, + "merge_nested_if" => keys::CODE_ACTION_MERGE_NESTED_IF, + "pull_assignment_up" => keys::CODE_ACTION_PULL_ASSIGNMENT_UP, + "pull_assignment_down" => keys::CODE_ACTION_PULL_ASSIGNMENT_DOWN, "expand_postfix_inc_dec" => keys::CODE_ACTION_EXPAND_POSTFIX_INC_DEC, "expand_prefix_inc_dec" => keys::CODE_ACTION_EXPAND_PREFIX_INC_DEC, "convert_postfix_to_prefix_inc_dec" => keys::CODE_ACTION_CONVERT_POSTFIX_TO_PREFIX_INC_DEC, @@ -973,6 +991,9 @@ fn code_action_title_key(id: &str, label: &str) -> Option<&'static str> { "collapse_compound_assignment" => keys::CODE_ACTION_COLLAPSE_COMPOUND_ASSIGNMENT, "apply_de_morgan" => keys::CODE_ACTION_APPLY_DE_MORGAN, "factor_de_morgan" => keys::CODE_ACTION_FACTOR_DE_MORGAN, + "reformat_number_literal" if label == "Remove digit separators" => { + keys::CODE_ACTION_REMOVE_DIGIT_SEPARATORS + } "convert_literal_base" => match label { "Convert literal to binary" => keys::CODE_ACTION_CONVERT_LITERAL_TO_BINARY, "Convert literal to octal" => keys::CODE_ACTION_CONVERT_LITERAL_TO_OCTAL, diff --git a/src/tests.rs b/src/tests.rs index ac8ab1f2..e9f14354 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -673,11 +673,20 @@ fn goto_definition_response_uris(response: GotoDefinitionResponse) -> Vec { fn position_of(text: &str, needle: &str) -> Position { let offset = text.find(needle).unwrap_or_else(|| panic!("missing {needle:?}")); + position_at_offset(text, offset) +} + +fn position_at_offset(text: &str, offset: usize) -> Position { let line = text[..offset].bytes().filter(|byte| *byte == b'\n').count() as u32; let line_start = text[..offset].rfind('\n').map(|idx| idx + 1).unwrap_or(0); Position { line, character: (offset - line_start) as u32 } } +fn range_of(text: &str, needle: &str) -> Range { + let start = text.find(needle).unwrap_or_else(|| panic!("missing {needle:?}")); + Range::new(position_at_offset(text, start), position_at_offset(text, start + needle.len())) +} + fn code_action_client_caps() -> ClientCapabilities { ClientCapabilities { text_document: Some(TextDocumentClientCapabilities { @@ -688,6 +697,7 @@ fn code_action_client_caps() -> ClientCapabilities { CodeActionKind::EMPTY, CodeActionKind::QUICKFIX, CodeActionKind::REFACTOR, + CodeActionKind::REFACTOR_EXTRACT, CodeActionKind::REFACTOR_REWRITE, ] .into_iter() @@ -751,6 +761,48 @@ fn request_code_actions( unreachable!("codeAction retries should either return or panic") } +fn request_code_actions_with_range( + client: &Connection, + uri: Url, + range: Range, + context: CodeActionContext, + request_id: i32, +) -> Vec { + const CONTENT_MODIFIED_RETRIES: i32 = 5; + + for attempt in 0..=CONTENT_MODIFIED_RETRIES { + let request_id = lsp_server::RequestId::from(request_id + attempt); + client + .sender + .send(Message::Request(Request::new( + request_id.clone(), + CodeActionRequest::METHOD.to_string(), + CodeActionParams { + text_document: TextDocumentIdentifier { uri: uri.clone() }, + range, + context: context.clone(), + work_done_progress_params: WorkDoneProgressParams::default(), + partial_result_params: Default::default(), + }, + ))) + .unwrap(); + + let response = recv_raw_response(client, request_id, "codeAction"); + if response.error.is_none() { + return serde_json::from_value(response.result.unwrap_or(serde_json::Value::Null)) + .unwrap_or_else(|err| panic!("failed to decode codeAction response: {err}")); + } + + if is_content_modified(&response) && attempt < CONTENT_MODIFIED_RETRIES { + continue; + } + + panic!("codeAction returned error: {:?}", response.error); + } + + unreachable!("codeAction retries should either return or panic") +} + fn is_content_modified(response: &lsp_server::Response) -> bool { response .error @@ -1158,6 +1210,77 @@ endmodule shutdown_test_server(&client, server_thread); } +#[test] +fn code_action_request_returns_extract_variable_for_selected_expression() { + let text = "\ +module top; + always_comb begin + y = a + b; + end +endmodule +"; + let (_temp_dir, client, server_thread, uri) = + setup_diagnostics_test(code_action_client_caps(), UserConfig::default(), text); + + let actions = request_code_actions_with_range( + &client, + uri, + range_of(text, "a + b"), + CodeActionContext { + diagnostics: Vec::new(), + only: Some(vec![CodeActionKind::REFACTOR_EXTRACT]), + trigger_kind: None, + }, + 201, + ); + let titles = code_action_titles(&actions); + + assert!( + titles.iter().any(|title| title == "Extract into variable"), + "expected extract variable refactor, got {titles:?}" + ); + + shutdown_test_server(&client, server_thread); +} + +#[test] +fn code_action_request_returns_extract_variable_for_selected_continuous_assign_rhs() { + let text = "\ +module top ( + c, + led0 +); + input wire c; + output led0; + reg led0; + + assign led0 = c * 2 + c; +endmodule +"; + let (_temp_dir, client, server_thread, uri) = + setup_diagnostics_test(code_action_client_caps(), UserConfig::default(), text); + + let actions = request_code_actions_with_range( + &client, + uri, + range_of(text, "c * 2 + c"), + CodeActionContext { + diagnostics: Vec::new(), + only: Some(vec![CodeActionKind::REFACTOR_EXTRACT]), + trigger_kind: None, + }, + 202, + ); + let titles = code_action_titles(&actions); + + assert!( + titles.iter().any(|title| title == "Extract into variable"), + "expected extract variable refactor, got {titles:?}" + ); + + shutdown_test_server(&client, server_thread); +} + #[test] fn code_action_request_uses_server_diagnostics_when_client_diagnostic_has_no_data() { let text = "\