diff --git a/rust/ruby-rbs/build.rs b/rust/ruby-rbs/build.rs index b2142fa86..950338f26 100644 --- a/rust/ruby-rbs/build.rs +++ b/rust/ruby-rbs/build.rs @@ -350,6 +350,16 @@ fn generate(config: &Config) -> Result<(), Box> { writeln!(file, " pub fn as_node(self) -> Node {{")?; writeln!(file, " Node::{}(self)", node.variant_name())?; writeln!(file, " }}")?; + writeln!(file)?; + writeln!(file, " /// Returns the location of this node.")?; + writeln!(file, " #[must_use]")?; + writeln!(file, " pub fn location(&self) -> RBSLocation {{")?; + writeln!( + file, + " RBSLocation::new(unsafe {{ (*self.pointer).base.location }})" + )?; + writeln!(file, " }}")?; + writeln!(file)?; if let Some(fields) = &node.fields { for field in fields { @@ -379,10 +389,64 @@ fn generate(config: &Config) -> Result<(), Box> { write_node_field_accessor(&mut file, field, "RBSHash")?; } "rbs_location" => { - write_node_field_accessor(&mut file, field, "RBSLocation")?; + if field.optional { + writeln!( + file, + " pub fn {}(&self) -> Option {{", + field.name + )?; + writeln!( + file, + " let ptr = unsafe {{ (*self.pointer).{} }};", + field.c_name() + )?; + writeln!(file, " if ptr.is_null() {{")?; + writeln!(file, " None")?; + writeln!(file, " }} else {{")?; + writeln!(file, " Some(RBSLocation {{ pointer: ptr }})")?; + writeln!(file, " }}")?; + writeln!(file, " }}")?; + } else { + writeln!(file, " pub fn {}(&self) -> RBSLocation {{", field.name)?; + writeln!( + file, + " RBSLocation {{ pointer: unsafe {{ (*self.pointer).{} }} }}", + field.c_name() + )?; + writeln!(file, " }}")?; + } } "rbs_location_list" => { - write_node_field_accessor(&mut file, field, "RBSLocationList")?; + if field.optional { + writeln!( + file, + " pub fn {}(&self) -> Option {{", + field.name + )?; + writeln!( + file, + " let ptr = unsafe {{ (*self.pointer).{} }};", + field.c_name() + )?; + writeln!(file, " if ptr.is_null() {{")?; + writeln!(file, " None")?; + writeln!(file, " }} else {{")?; + writeln!(file, " Some(RBSLocationList {{ pointer: ptr }})")?; + writeln!(file, " }}")?; + writeln!(file, " }}")?; + } else { + writeln!( + file, + " pub fn {}(&self) -> RBSLocationList {{", + field.name + )?; + writeln!( + file, + " RBSLocationList {{ pointer: unsafe {{ (*self.pointer).{} }} }}", + field.c_name() + )?; + writeln!(file, " }}")?; + } } "rbs_namespace" => { write_node_field_accessor(&mut file, field, "NamespaceNode")?; diff --git a/rust/ruby-rbs/src/lib.rs b/rust/ruby-rbs/src/lib.rs index 45df7145f..0dc116263 100644 --- a/rust/ruby-rbs/src/lib.rs +++ b/rust/ruby-rbs/src/lib.rs @@ -134,27 +134,24 @@ impl Iterator for RBSHashIter { pub struct RBSLocation { pointer: *const rbs_location_t, - #[allow(dead_code)] - parser: *mut rbs_parser_t, } impl RBSLocation { - pub fn new(pointer: *const rbs_location_t, parser: *mut rbs_parser_t) -> Self { - Self { pointer, parser } + pub fn new(pointer: *const rbs_location_t) -> Self { + Self { pointer } } - pub fn start_loc(&self) -> i32 { + pub fn start(&self) -> i32 { unsafe { (*self.pointer).rg.start.byte_pos } } - pub fn end_loc(&self) -> i32 { + pub fn end(&self) -> i32 { unsafe { (*self.pointer).rg.end.byte_pos } } } pub struct RBSLocationListIter { current: *mut rbs_location_list_node_t, - parser: *mut rbs_parser_t, } impl Iterator for RBSLocationListIter { @@ -165,7 +162,7 @@ impl Iterator for RBSLocationListIter { None } else { let pointer_data = unsafe { *self.current }; - let loc = RBSLocation::new(pointer_data.loc, self.parser); + let loc = RBSLocation::new(pointer_data.loc); self.current = pointer_data.next; Some(loc) } @@ -174,12 +171,11 @@ impl Iterator for RBSLocationListIter { pub struct RBSLocationList { pointer: *mut rbs_location_list, - parser: *mut rbs_parser_t, } impl RBSLocationList { - pub fn new(pointer: *mut rbs_location_list, parser: *mut rbs_parser_t) -> Self { - Self { pointer, parser } + pub fn new(pointer: *mut rbs_location_list) -> Self { + Self { pointer } } /// Returns an iterator over the locations. @@ -187,7 +183,6 @@ impl RBSLocationList { pub fn iter(&self) -> RBSLocationListIter { RBSLocationListIter { current: unsafe { (*self.pointer).head }, - parser: self.parser, } } } @@ -435,4 +430,32 @@ mod tests { visitor.visited ); } + + #[test] + fn test_node_location_ranges() { + let rbs_code = r#"type foo = 1"#; + let signature = parse(rbs_code.as_bytes()).unwrap(); + + let declaration = signature.declarations().iter().next().unwrap(); + let Node::TypeAlias(type_alias) = declaration else { + panic!("Expected TypeAlias"); + }; + + // TypeAlias spans the entire declaration + let loc = type_alias.location(); + assert_eq!(0, loc.start()); + assert_eq!(12, loc.end()); + + // The literal "1" is at position 11-12 + let Node::LiteralType(literal) = type_alias.type_() else { + panic!("Expected LiteralType"); + }; + let Node::Integer(integer) = literal.literal() else { + panic!("Expected Integer"); + }; + + let int_loc = integer.location(); + assert_eq!(11, int_loc.start()); + assert_eq!(12, int_loc.end()); + } }