diff --git a/tools/analysis_3d/analysis_runner.py b/tools/analysis_3d/analysis_runner.py index cc93d3adb..d970d75d2 100644 --- a/tools/analysis_3d/analysis_runner.py +++ b/tools/analysis_3d/analysis_runner.py @@ -31,6 +31,7 @@ class AnalysisRunner: def __init__( self, data_root_path: str, + dataset_version_config_root: str, config_path: str, out_path: str, max_sweeps: int = 2, @@ -41,6 +42,7 @@ def __init__( :param out_path: Path where to save output. """ self.data_root_path = data_root_path + self.dataset_version_config_root = dataset_version_config_root self.config_path = config_path self.out_path = Path(out_path) # Initialization @@ -98,7 +100,7 @@ def _get_dataset_scenario_names(self, dataset_version: str) -> Dict[str, List[st Get list of scenarios names for different splits in a dataset. :return: A dict of {split name: [scenario names in a split]}. """ - dataset_yaml_file = Path(self.config.dataset_version_config_root) / (dataset_version + ".yaml") + dataset_yaml_file = Path(self.dataset_version_config_root) / (dataset_version + ".yaml") with open(dataset_yaml_file, "r") as f: dataset_list_dict: Dict[str, List[str]] = yaml.safe_load(f) return dataset_list_dict @@ -163,7 +165,7 @@ def _extra_scenario_data( """ scenario_data = {} for scene_token_with_version in scene_tokens: - scene_token, version = scene_token_with_version.split(" ") + scene_token, version = scene_token_with_version.split("/") print_log(f"Creating scenario data for the scene: {scene_token}, version: {version}") scene_root_dir_path = get_scene_root_dir_path( root_path=self.data_root_path, diff --git a/tools/analysis_3d/run.py b/tools/analysis_3d/run.py index 208cb9eba..0cb2ac87f 100644 --- a/tools/analysis_3d/run.py +++ b/tools/analysis_3d/run.py @@ -22,6 +22,12 @@ def parse_args(): required=True, help="specify the root path of dataset", ) + parser.add_argument( + "--dataset_version_config_root", + type=str, + required=True, + help="specify the root path of dataset version config yaml files", + ) parser.add_argument( "-o", "--out_dir", @@ -40,6 +46,7 @@ def main(): print_log("Building AnalysisRunner...", logger="current") analysis_runner = AnalysisRunner( data_root_path=args.data_root_path, + dataset_version_config_root=args.dataset_version_config_root, config_path=args.config_path, out_path=args.out_dir, )