From 534f9078e461a3c4346f5e95417ef558e10c8401 Mon Sep 17 00:00:00 2001 From: Alex Rocha Date: Wed, 17 Dec 2025 16:21:26 -0800 Subject: [PATCH] Generate location() accessor for each node type Each node already has location data in its C struct, but it wasn't exposed through the Rust API. This adds a generated `location()` method to every node type, making it easy to get source ranges for any part of the AST. Also reorders RBSLocation and RBSLocationList struct fields to consistently put parser first, matching the pattern used elsewhere. --- rust/ruby-rbs/build.rs | 68 ++++++++++++++++++++++++++++++++++++++-- rust/ruby-rbs/src/lib.rs | 47 ++++++++++++++++++++------- 2 files changed, 101 insertions(+), 14 deletions(-) diff --git a/rust/ruby-rbs/build.rs b/rust/ruby-rbs/build.rs index b2142fa86..950338f26 100644 --- a/rust/ruby-rbs/build.rs +++ b/rust/ruby-rbs/build.rs @@ -350,6 +350,16 @@ fn generate(config: &Config) -> Result<(), Box> { writeln!(file, " pub fn as_node(self) -> Node {{")?; writeln!(file, " Node::{}(self)", node.variant_name())?; writeln!(file, " }}")?; + writeln!(file)?; + writeln!(file, " /// Returns the location of this node.")?; + writeln!(file, " #[must_use]")?; + writeln!(file, " pub fn location(&self) -> RBSLocation {{")?; + writeln!( + file, + " RBSLocation::new(unsafe {{ (*self.pointer).base.location }})" + )?; + writeln!(file, " }}")?; + writeln!(file)?; if let Some(fields) = &node.fields { for field in fields { @@ -379,10 +389,64 @@ fn generate(config: &Config) -> Result<(), Box> { write_node_field_accessor(&mut file, field, "RBSHash")?; } "rbs_location" => { - write_node_field_accessor(&mut file, field, "RBSLocation")?; + if field.optional { + writeln!( + file, + " pub fn {}(&self) -> Option {{", + field.name + )?; + writeln!( + file, + " let ptr = unsafe {{ (*self.pointer).{} }};", + field.c_name() + )?; + writeln!(file, " if ptr.is_null() {{")?; + writeln!(file, " None")?; + writeln!(file, " }} else {{")?; + writeln!(file, " Some(RBSLocation {{ pointer: ptr }})")?; + writeln!(file, " }}")?; + writeln!(file, " }}")?; + } else { + writeln!(file, " pub fn {}(&self) -> RBSLocation {{", field.name)?; + writeln!( + file, + " RBSLocation {{ pointer: unsafe {{ (*self.pointer).{} }} }}", + field.c_name() + )?; + writeln!(file, " }}")?; + } } "rbs_location_list" => { - write_node_field_accessor(&mut file, field, "RBSLocationList")?; + if field.optional { + writeln!( + file, + " pub fn {}(&self) -> Option {{", + field.name + )?; + writeln!( + file, + " let ptr = unsafe {{ (*self.pointer).{} }};", + field.c_name() + )?; + writeln!(file, " if ptr.is_null() {{")?; + writeln!(file, " None")?; + writeln!(file, " }} else {{")?; + writeln!(file, " Some(RBSLocationList {{ pointer: ptr }})")?; + writeln!(file, " }}")?; + writeln!(file, " }}")?; + } else { + writeln!( + file, + " pub fn {}(&self) -> RBSLocationList {{", + field.name + )?; + writeln!( + file, + " RBSLocationList {{ pointer: unsafe {{ (*self.pointer).{} }} }}", + field.c_name() + )?; + writeln!(file, " }}")?; + } } "rbs_namespace" => { write_node_field_accessor(&mut file, field, "NamespaceNode")?; diff --git a/rust/ruby-rbs/src/lib.rs b/rust/ruby-rbs/src/lib.rs index 45df7145f..0dc116263 100644 --- a/rust/ruby-rbs/src/lib.rs +++ b/rust/ruby-rbs/src/lib.rs @@ -134,27 +134,24 @@ impl Iterator for RBSHashIter { pub struct RBSLocation { pointer: *const rbs_location_t, - #[allow(dead_code)] - parser: *mut rbs_parser_t, } impl RBSLocation { - pub fn new(pointer: *const rbs_location_t, parser: *mut rbs_parser_t) -> Self { - Self { pointer, parser } + pub fn new(pointer: *const rbs_location_t) -> Self { + Self { pointer } } - pub fn start_loc(&self) -> i32 { + pub fn start(&self) -> i32 { unsafe { (*self.pointer).rg.start.byte_pos } } - pub fn end_loc(&self) -> i32 { + pub fn end(&self) -> i32 { unsafe { (*self.pointer).rg.end.byte_pos } } } pub struct RBSLocationListIter { current: *mut rbs_location_list_node_t, - parser: *mut rbs_parser_t, } impl Iterator for RBSLocationListIter { @@ -165,7 +162,7 @@ impl Iterator for RBSLocationListIter { None } else { let pointer_data = unsafe { *self.current }; - let loc = RBSLocation::new(pointer_data.loc, self.parser); + let loc = RBSLocation::new(pointer_data.loc); self.current = pointer_data.next; Some(loc) } @@ -174,12 +171,11 @@ impl Iterator for RBSLocationListIter { pub struct RBSLocationList { pointer: *mut rbs_location_list, - parser: *mut rbs_parser_t, } impl RBSLocationList { - pub fn new(pointer: *mut rbs_location_list, parser: *mut rbs_parser_t) -> Self { - Self { pointer, parser } + pub fn new(pointer: *mut rbs_location_list) -> Self { + Self { pointer } } /// Returns an iterator over the locations. @@ -187,7 +183,6 @@ impl RBSLocationList { pub fn iter(&self) -> RBSLocationListIter { RBSLocationListIter { current: unsafe { (*self.pointer).head }, - parser: self.parser, } } } @@ -435,4 +430,32 @@ mod tests { visitor.visited ); } + + #[test] + fn test_node_location_ranges() { + let rbs_code = r#"type foo = 1"#; + let signature = parse(rbs_code.as_bytes()).unwrap(); + + let declaration = signature.declarations().iter().next().unwrap(); + let Node::TypeAlias(type_alias) = declaration else { + panic!("Expected TypeAlias"); + }; + + // TypeAlias spans the entire declaration + let loc = type_alias.location(); + assert_eq!(0, loc.start()); + assert_eq!(12, loc.end()); + + // The literal "1" is at position 11-12 + let Node::LiteralType(literal) = type_alias.type_() else { + panic!("Expected LiteralType"); + }; + let Node::Integer(integer) = literal.literal() else { + panic!("Expected Integer"); + }; + + let int_loc = integer.location(); + assert_eq!(11, int_loc.start()); + assert_eq!(12, int_loc.end()); + } }