From 2ebf3f4e068fe069b7b949d74f7f272a148bd596 Mon Sep 17 00:00:00 2001
From: Sydney Lister
Date: Thu, 22 Jan 2026 15:29:11 -0500
Subject: [PATCH 1/6] Support baseline-only execution in Scenario

This change allows scenarios to be initialized with an empty strategies
list when include_baseline=True. Previously, empty strategies would raise
a ValueError even when baseline was requested.

Changes:
- Add allow_empty parameter to prepare_scenario_strategies() in
  ScenarioStrategy. When True and an empty sequence is explicitly
  provided, returns an empty list instead of raising ValueError.
- Update Scenario.initialize_async() to pass allow_empty=include_baseline
  so baseline-only execution is allowed when baseline is requested.
- Add _create_standalone_baseline() method to Scenario that creates a
  baseline attack directly from dataset_config when no other atomic
  attacks exist to derive from.
- Add unit tests for baseline-only execution scenarios.

This enables use cases where users want to run only baseline attacks
without any additional attack strategies, such as for establishing
baseline metrics before applying attack techniques.
--- pyrit/scenario/core/scenario.py | 60 +++++++- pyrit/scenario/core/scenario_strategy.py | 12 +- tests/unit/scenarios/test_scenario.py | 187 +++++++++++++++++++++++ 3 files changed, 256 insertions(+), 3 deletions(-) diff --git a/pyrit/scenario/core/scenario.py b/pyrit/scenario/core/scenario.py index fdead33e9..1e038ff92 100644 --- a/pyrit/scenario/core/scenario.py +++ b/pyrit/scenario/core/scenario.py @@ -228,14 +228,22 @@ async def initialize_async( self._memory_labels = memory_labels or {} # Prepare scenario strategies using the stored configuration + # Allow empty strategies when include_baseline is True (baseline-only execution) self._scenario_composites = self._strategy_class.prepare_scenario_strategies( - scenario_strategies, default_aggregate=self.get_default_strategy() + scenario_strategies, + default_aggregate=self.get_default_strategy(), + allow_empty=self._include_baseline, ) self._atomic_attacks = await self._get_atomic_attacks_async() if self._include_baseline: - baseline_attack = self._get_baseline_from_first_attack() + if self._atomic_attacks: + # Derive baseline from first attack + baseline_attack = self._get_baseline_from_first_attack() + else: + # No atomic attacks - create standalone baseline from dataset + baseline_attack = self._create_standalone_baseline() self._atomic_attacks.insert(0, baseline_attack) # Store original objectives for each atomic attack (before any mutations during execution) @@ -323,6 +331,54 @@ def _get_baseline_from_first_attack(self) -> AtomicAttack: memory_labels=self._memory_labels, ) + def _create_standalone_baseline(self) -> AtomicAttack: + """ + Create a standalone baseline AtomicAttack when no other atomic attacks exist. + + This method is used for baseline-only execution where no attack strategies are specified + but include_baseline=True. It creates the baseline directly from the dataset configuration + and scenario-level settings. + + Returns: + AtomicAttack: The baseline AtomicAttack instance. 
+ + Raises: + ValueError: If objective_target, dataset_config, or objective_scorer is not set. + """ + if not self._objective_target: + raise ValueError("Objective target is required to create standalone baseline attack.") + + if not self._dataset_config: + raise ValueError("Dataset config is required to create standalone baseline attack.") + + if not self._objective_scorer: + raise ValueError("Objective scorer is required to create standalone baseline attack.") + + # Get seed groups from the dataset configuration + seed_groups = self._dataset_config.get_all_seed_attack_groups() + + if not seed_groups or len(seed_groups) == 0: + raise ValueError("Dataset config must have seed groups to create baseline.") + + # Import here to avoid circular imports + from pyrit.executor.attack.core.attack_config import AttackScoringConfig + + # Create scoring config from the scenario's objective scorer + attack_scoring_config = AttackScoringConfig(objective_scorer=self._objective_scorer) + + # Create baseline attack with no converters + attack = PromptSendingAttack( + objective_target=self._objective_target, + attack_scoring_config=attack_scoring_config, + ) + + return AtomicAttack( + atomic_attack_name="baseline", + attack=attack, + seed_groups=seed_groups, + memory_labels=self._memory_labels, + ) + def _raise_dataset_exception(self) -> None: error_msg = textwrap.dedent( f""" diff --git a/pyrit/scenario/core/scenario_strategy.py b/pyrit/scenario/core/scenario_strategy.py index 362be2c56..964c73a3e 100644 --- a/pyrit/scenario/core/scenario_strategy.py +++ b/pyrit/scenario/core/scenario_strategy.py @@ -197,6 +197,7 @@ def prepare_scenario_strategies( strategies: Sequence[T | "ScenarioCompositeStrategy"] | None = None, *, default_aggregate: T | None = None, + allow_empty: bool = False, ) -> List["ScenarioCompositeStrategy"]: """ Prepare and normalize scenario strategies for use in a scenario. 
@@ -213,16 +214,22 @@ def prepare_scenario_strategies( strategies (Sequence[T | ScenarioCompositeStrategy] | None): The strategies to prepare. Can be a mix of bare strategy enums and composite strategies. If None, uses default_aggregate to determine defaults. + If an empty sequence, behavior depends on allow_empty parameter. default_aggregate (T | None): The aggregate strategy to use when strategies is None. Common values: MyStrategy.ALL, MyStrategy.EASY. If None when strategies is None, raises ValueError. + allow_empty (bool): If True, allows an empty strategies list to be returned when + an empty sequence is explicitly provided. This is useful for baseline-only + execution where no attack strategies are needed. Defaults to False. Returns: List[ScenarioCompositeStrategy]: Normalized list of composite strategies ready for use. + May be empty if allow_empty=True and an empty sequence was provided. Raises: ValueError: If strategies is None and default_aggregate is None, or if compositions - are invalid according to validate_composition(). + are invalid according to validate_composition(), or if strategies is empty + and allow_empty is False. """ # Handle None input with default aggregate if strategies is None: @@ -251,7 +258,10 @@ def prepare_scenario_strategies( # For now, skip to allow flexibility pass + # Allow empty list if explicitly requested (for baseline-only execution) if not composite_strategies: + if allow_empty and strategies is not None and len(strategies) == 0: + return [] raise ValueError( f"No valid {cls.__name__} strategies provided. " f"Provide at least one {cls.__name__} enum or ScenarioCompositeStrategy." 
diff --git a/tests/unit/scenarios/test_scenario.py b/tests/unit/scenarios/test_scenario.py index 266e85530..330eca58a 100644 --- a/tests/unit/scenarios/test_scenario.py +++ b/tests/unit/scenarios/test_scenario.py @@ -601,3 +601,190 @@ def test_scenario_identifier_with_init_data(self): identifier = ScenarioIdentifier(name="TestScenario", scenario_version=1, init_data=init_data) assert identifier.init_data == init_data + + +def create_mock_truefalse_scorer(): + """Create a mock TrueFalseScorer for testing baseline-only execution.""" + from pyrit.score import TrueFalseScorer + + mock_scorer = MagicMock(spec=TrueFalseScorer) + mock_scorer.get_identifier.return_value = {"__type__": "MockTrueFalseScorer", "__module__": "test"} + mock_scorer.get_scorer_metrics.return_value = None + # Make isinstance check work + mock_scorer.__class__ = TrueFalseScorer + return mock_scorer + + +class ConcreteScenarioWithTrueFalseScorer(Scenario): + """Concrete implementation of Scenario for testing baseline-only execution.""" + + def __init__(self, atomic_attacks_to_return=None, **kwargs): + # Add required strategy_class if not provided + + class TestStrategy(ScenarioStrategy): + TEST = ("test", {"concrete"}) + ALL = ("all", {"all"}) + + @classmethod + def get_aggregate_tags(cls) -> set[str]: + return {"all"} + + kwargs.setdefault("strategy_class", TestStrategy) + + # Use TrueFalseScorer mock if not provided + if "objective_scorer" not in kwargs: + kwargs["objective_scorer"] = create_mock_truefalse_scorer() + + super().__init__(**kwargs) + self._atomic_attacks_to_return = atomic_attacks_to_return or [] + + @classmethod + def get_strategy_class(cls): + """Return a mock strategy class for testing.""" + + from pyrit.scenario.core.scenario_strategy import ScenarioStrategy + + class TestStrategy(ScenarioStrategy): + TEST = ("test", {"concrete"}) + ALL = ("all", {"all"}) + + @classmethod + def get_aggregate_tags(cls) -> set[str]: + return {"all"} + + return TestStrategy + + @classmethod + def 
get_default_strategy(cls): + """Return the default strategy for testing.""" + return cls.get_strategy_class().ALL + + @classmethod + def default_dataset_config(cls) -> DatasetConfiguration: + """Return the default dataset configuration for testing.""" + return DatasetConfiguration() + + async def _get_atomic_attacks_async(self): + return self._atomic_attacks_to_return + + +@pytest.mark.usefixtures("patch_central_database") +class TestScenarioBaselineOnlyExecution: + """Tests for baseline-only execution (empty strategies with include_baseline=True).""" + + @pytest.mark.asyncio + async def test_initialize_async_with_empty_strategies_and_baseline(self, mock_objective_target): + """Test that baseline-only execution works when include_baseline=True and strategies is empty.""" + from pyrit.models import SeedAttackGroup, SeedObjective + + # Create a scenario with include_default_baseline=True and TrueFalseScorer + scenario = ConcreteScenarioWithTrueFalseScorer( + name="Baseline Only Test", + version=1, + include_default_baseline=True, + ) + + # Create a mock dataset config with seed groups + mock_dataset_config = MagicMock(spec=DatasetConfiguration) + mock_dataset_config.get_all_seed_attack_groups.return_value = [ + SeedAttackGroup(seeds=[SeedObjective(value="test objective 1")]), + SeedAttackGroup(seeds=[SeedObjective(value="test objective 2")]), + ] + + # Initialize with empty strategies + await scenario.initialize_async( + objective_target=mock_objective_target, + scenario_strategies=[], # Empty list - baseline only + dataset_config=mock_dataset_config, + ) + + # Should have exactly one attack - the baseline + assert scenario.atomic_attack_count == 1 + assert scenario._atomic_attacks[0].atomic_attack_name == "baseline" + + @pytest.mark.asyncio + async def test_baseline_only_execution_runs_successfully(self, mock_objective_target, sample_attack_results): + """Test that baseline-only scenario can run successfully.""" + from pyrit.models import SeedAttackGroup, 
SeedObjective + + # Create a scenario with include_default_baseline=True and TrueFalseScorer + scenario = ConcreteScenarioWithTrueFalseScorer( + name="Baseline Only Test", + version=1, + include_default_baseline=True, + ) + + # Create a mock dataset config with seed groups + mock_dataset_config = MagicMock(spec=DatasetConfiguration) + mock_dataset_config.get_all_seed_attack_groups.return_value = [ + SeedAttackGroup(seeds=[SeedObjective(value="test objective 1")]), + ] + + # Initialize with empty strategies + await scenario.initialize_async( + objective_target=mock_objective_target, + scenario_strategies=[], # Empty list - baseline only + dataset_config=mock_dataset_config, + ) + + # Mock the baseline attack's run_async + scenario._atomic_attacks[0].run_async = create_mock_run_async([sample_attack_results[0]]) + + # Run the scenario + result = await scenario.run_async() + + # Verify the result + assert isinstance(result, ScenarioResult) + assert "baseline" in result.attack_results + assert len(result.attack_results["baseline"]) == 1 + + @pytest.mark.asyncio + async def test_empty_strategies_without_baseline_raises_error(self, mock_objective_target): + """Test that empty strategies without include_baseline raises ValueError.""" + scenario = ConcreteScenario( + name="No Baseline Test", + version=1, + include_default_baseline=False, # No baseline + ) + + mock_dataset_config = MagicMock(spec=DatasetConfiguration) + + # Should raise ValueError because empty strategies without baseline is not allowed + with pytest.raises(ValueError, match="No valid .* strategies provided"): + await scenario.initialize_async( + objective_target=mock_objective_target, + scenario_strategies=[], # Empty list without baseline + dataset_config=mock_dataset_config, + ) + + @pytest.mark.asyncio + async def test_standalone_baseline_uses_dataset_config_seeds(self, mock_objective_target): + """Test that standalone baseline uses seed groups from dataset_config.""" + from pyrit.models import 
SeedAttackGroup, SeedObjective + + scenario = ConcreteScenarioWithTrueFalseScorer( + name="Baseline Seeds Test", + version=1, + include_default_baseline=True, + ) + + # Create specific seed groups to verify they're used + expected_seeds = [ + SeedAttackGroup(seeds=[SeedObjective(value="objective_a")]), + SeedAttackGroup(seeds=[SeedObjective(value="objective_b")]), + SeedAttackGroup(seeds=[SeedObjective(value="objective_c")]), + ] + + mock_dataset_config = MagicMock(spec=DatasetConfiguration) + mock_dataset_config.get_all_seed_attack_groups.return_value = expected_seeds + + await scenario.initialize_async( + objective_target=mock_objective_target, + scenario_strategies=[], + dataset_config=mock_dataset_config, + ) + + # Verify the baseline attack has the expected seed groups + baseline_attack = scenario._atomic_attacks[0] + assert baseline_attack.atomic_attack_name == "baseline" + assert baseline_attack.seed_groups == expected_seeds From 65e0b02423e28841b0a07b13cbd4eee3c6e97eac Mon Sep 17 00:00:00 2001 From: Sydney Lister Date: Thu, 22 Jan 2026 15:47:49 -0500 Subject: [PATCH 2/6] Fix mypy type error: cast objective_scorer to TrueFalseScorer --- pyrit/scenario/core/scenario.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pyrit/scenario/core/scenario.py b/pyrit/scenario/core/scenario.py index 1e038ff92..bd5f09dbd 100644 --- a/pyrit/scenario/core/scenario.py +++ b/pyrit/scenario/core/scenario.py @@ -361,10 +361,15 @@ def _create_standalone_baseline(self) -> AtomicAttack: raise ValueError("Dataset config must have seed groups to create baseline.") # Import here to avoid circular imports + from typing import cast from pyrit.executor.attack.core.attack_config import AttackScoringConfig + from pyrit.score import TrueFalseScorer # Create scoring config from the scenario's objective scorer - attack_scoring_config = AttackScoringConfig(objective_scorer=self._objective_scorer) + # Note: Scenarios require TrueFalseScorer for attack scoring + 
attack_scoring_config = AttackScoringConfig( + objective_scorer=cast(TrueFalseScorer, self._objective_scorer) + ) # Create baseline attack with no converters attack = PromptSendingAttack( From 73e516331578296cbbf5c32cd812e6b2fe5dfbb5 Mon Sep 17 00:00:00 2001 From: Sydney Lister Date: Thu, 22 Jan 2026 16:32:29 -0500 Subject: [PATCH 3/6] Fix formatting with black --- pyrit/scenario/core/scenario.py | 149 ++++++++++++++----- pyrit/scenario/core/scenario_strategy.py | 48 +++++-- tests/unit/scenarios/test_scenario.py | 174 ++++++++++++++++++----- 3 files changed, 283 insertions(+), 88 deletions(-) diff --git a/pyrit/scenario/core/scenario.py b/pyrit/scenario/core/scenario.py index bd5f09dbd..220167677 100644 --- a/pyrit/scenario/core/scenario.py +++ b/pyrit/scenario/core/scenario.py @@ -79,7 +79,9 @@ def __init__( with whitespace normalized for display. """ # Use the class docstring with normalized whitespace as description - description = " ".join(self.__class__.__doc__.split()) if self.__class__.__doc__ else "" + description = ( + " ".join(self.__class__.__doc__.split()) if self.__class__.__doc__ else "" + ) self._identifier = ScenarioIdentifier( name=type(self).__name__, scenario_version=version, description=description @@ -101,7 +103,9 @@ def __init__( self._name = name self._memory = CentralMemory.get_memory_instance() self._atomic_attacks: List[AtomicAttack] = [] - self._scenario_result_id: Optional[str] = str(scenario_result_id) if scenario_result_id else None + self._scenario_result_id: Optional[str] = ( + str(scenario_result_id) if scenario_result_id else None + ) self._result_lock = asyncio.Lock() self._include_baseline = include_default_baseline @@ -173,7 +177,9 @@ async def initialize_async( self, *, objective_target: PromptTarget = REQUIRED_VALUE, # type: ignore - scenario_strategies: Optional[Sequence[ScenarioStrategy | ScenarioCompositeStrategy]] = None, + scenario_strategies: Optional[ + Sequence[ScenarioStrategy | ScenarioCompositeStrategy] + ] = None, 
dataset_config: Optional[DatasetConfiguration] = None, max_concurrency: int = 10, max_retries: int = 0, @@ -222,7 +228,9 @@ async def initialize_async( self._objective_target = objective_target self._objective_target_identifier = objective_target.get_identifier() self._dataset_config_provided = dataset_config is not None - self._dataset_config = dataset_config if dataset_config else self.default_dataset_config() + self._dataset_config = ( + dataset_config if dataset_config else self.default_dataset_config() + ) self._max_concurrency = max_concurrency self._max_retries = max_retries self._memory_labels = memory_labels or {} @@ -248,12 +256,15 @@ async def initialize_async( # Store original objectives for each atomic attack (before any mutations during execution) self._original_objectives_map = { - atomic_attack.atomic_attack_name: tuple(atomic_attack.objectives) for atomic_attack in self._atomic_attacks + atomic_attack.atomic_attack_name: tuple(atomic_attack.objectives) + for atomic_attack in self._atomic_attacks } # Check if we're resuming an existing scenario if self._scenario_result_id: - existing_results = self._memory.get_scenario_results(scenario_result_ids=[self._scenario_result_id]) + existing_results = self._memory.get_scenario_results( + scenario_result_ids=[self._scenario_result_id] + ) if existing_results: existing_result = existing_results[0] @@ -272,7 +283,8 @@ async def initialize_async( # Create new scenario result attack_results: Dict[str, List[AttackResult]] = { - atomic_attack.atomic_attack_name: [] for atomic_attack in self._atomic_attacks + atomic_attack.atomic_attack_name: [] + for atomic_attack in self._atomic_attacks } result = ScenarioResult( @@ -310,13 +322,17 @@ def _get_baseline_from_first_attack(self) -> AtomicAttack: objective_target = first_attack._attack.get_objective_target() if not seed_groups or len(seed_groups) == 0: - raise ValueError("First atomic attack must have seed_groups to create baseline.") + raise ValueError( + "First 
atomic attack must have seed_groups to create baseline." + ) if not objective_target: raise ValueError("Objective target is required to create baseline attack.") if not attack_scoring_config: - raise ValueError("Attack scoring config is required to create baseline attack.") + raise ValueError( + "Attack scoring config is required to create baseline attack." + ) # Create baseline attack with no converters attack = PromptSendingAttack( @@ -346,13 +362,19 @@ def _create_standalone_baseline(self) -> AtomicAttack: ValueError: If objective_target, dataset_config, or objective_scorer is not set. """ if not self._objective_target: - raise ValueError("Objective target is required to create standalone baseline attack.") + raise ValueError( + "Objective target is required to create standalone baseline attack." + ) if not self._dataset_config: - raise ValueError("Dataset config is required to create standalone baseline attack.") + raise ValueError( + "Dataset config is required to create standalone baseline attack." + ) if not self._objective_scorer: - raise ValueError("Objective scorer is required to create standalone baseline attack.") + raise ValueError( + "Objective scorer is required to create standalone baseline attack." + ) # Get seed groups from the dataset configuration seed_groups = self._dataset_config.get_all_seed_attack_groups() @@ -433,7 +455,9 @@ def _validate_stored_scenario(self, *, stored_result: ScenarioResult) -> bool: ) return True - def _get_completed_objectives_for_attack(self, *, atomic_attack_name: str) -> Set[str]: + def _get_completed_objectives_for_attack( + self, *, atomic_attack_name: str + ) -> Set[str]: """ Get the set of objectives that have already been completed for a specific atomic attack. 
@@ -450,14 +474,17 @@ def _get_completed_objectives_for_attack(self, *, atomic_attack_name: str) -> Se try: # Retrieve the scenario result from memory - scenario_results = self._memory.get_scenario_results(scenario_result_ids=[self._scenario_result_id]) + scenario_results = self._memory.get_scenario_results( + scenario_result_ids=[self._scenario_result_id] + ) if scenario_results: scenario_result = scenario_results[0] # Get completed objectives for this atomic attack name if atomic_attack_name in scenario_result.attack_results: completed_objectives = { - result.objective for result in scenario_result.attack_results[atomic_attack_name] + result.objective + for result in scenario_result.attack_results[atomic_attack_name] } except Exception as e: logger.warning( @@ -489,10 +516,14 @@ async def _get_remaining_atomic_attacks_async(self) -> List[AtomicAttack]: ) # Get ORIGINAL objectives (before any mutations) from stored map - original_objectives = self._original_objectives_map.get(atomic_attack.atomic_attack_name, ()) + original_objectives = self._original_objectives_map.get( + atomic_attack.atomic_attack_name, () + ) # Calculate remaining objectives - remaining_objectives = [obj for obj in original_objectives if obj not in completed_objectives] + remaining_objectives = [ + obj for obj in original_objectives if obj not in completed_objectives + ] if remaining_objectives: # If there are remaining objectives, update the atomic attack @@ -502,7 +533,9 @@ async def _get_remaining_atomic_attacks_async(self) -> List[AtomicAttack]: f"{len(remaining_objectives)}/{len(original_objectives)} objectives remaining" ) # Update the objectives for this atomic attack to only include remaining ones - atomic_attack.filter_seed_groups_by_objectives(remaining_objectives=remaining_objectives) + atomic_attack.filter_seed_groups_by_objectives( + remaining_objectives=remaining_objectives + ) remaining_attacks.append(atomic_attack) else: @@ -525,7 +558,9 @@ async def 
_update_scenario_result_async( attack_results (List[AttackResult]): The list of new attack results to add. """ if not self._scenario_result_id: - logger.warning("Cannot update scenario result: no scenario result ID available") + logger.warning( + "Cannot update scenario result: no scenario result ID available" + ) return async with self._result_lock: @@ -589,7 +624,9 @@ async def run_async(self) -> ScenarioResult: ) if not self._scenario_result_id: - raise ValueError("Scenario not properly initialized. Call await scenario.initialize_async() first.") + raise ValueError( + "Scenario not properly initialized. Call await scenario.initialize_async() first." + ) # Type narrowing: create local variable that type checker knows is non-None scenario_result_id: str = self._scenario_result_id @@ -604,8 +641,14 @@ async def run_async(self) -> ScenarioResult: last_exception = e # Get current scenario to check number of tries - scenario_results = self._memory.get_scenario_results(scenario_result_ids=[scenario_result_id]) - current_tries = scenario_results[0].number_tries if scenario_results else retry_attempt + 1 + scenario_results = self._memory.get_scenario_results( + scenario_result_ids=[scenario_result_id] + ) + current_tries = ( + scenario_results[0].number_tries + if scenario_results + else retry_attempt + 1 + ) # Check if we have more retries available remaining_retries = self._max_retries - retry_attempt @@ -630,7 +673,9 @@ async def run_async(self) -> ScenarioResult: # This should never be reached, but just in case if last_exception: raise last_exception - raise RuntimeError(f"Scenario '{self._name}' completed unexpectedly without result") + raise RuntimeError( + f"Scenario '{self._name}' completed unexpectedly without result" + ) async def _execute_scenario_async(self) -> ScenarioResult: """ @@ -648,7 +693,9 @@ async def _execute_scenario_async(self) -> ScenarioResult: ValueError: If a lookup for a scenario for a given ID fails. 
ValueError: If atomic attack execution fails. """ - logger.info(f"Starting scenario '{self._name}' execution with {len(self._atomic_attacks)} atomic attacks") + logger.info( + f"Starting scenario '{self._name}' execution with {len(self._atomic_attacks)} atomic attacks" + ) # Type narrowing: _scenario_result_id is guaranteed to be non-None at this point # (verified in run_async before calling this method) @@ -656,13 +703,17 @@ async def _execute_scenario_async(self) -> ScenarioResult: scenario_result_id: str = self._scenario_result_id # Increment number_tries at the start of each run - scenario_results = self._memory.get_scenario_results(scenario_result_ids=[scenario_result_id]) + scenario_results = self._memory.get_scenario_results( + scenario_result_ids=[scenario_result_id] + ) if scenario_results: current_scenario = scenario_results[0] current_scenario.number_tries += 1 entry = ScenarioResultEntry(entry=current_scenario) self._memory._update_entry(entry) - logger.info(f"Scenario '{self._name}' attempt #{current_scenario.number_tries}") + logger.info( + f"Scenario '{self._name}' attempt #{current_scenario.number_tries}" + ) else: raise ValueError(f"Scenario result with ID {scenario_result_id} not found") @@ -670,17 +721,23 @@ async def _execute_scenario_async(self) -> ScenarioResult: remaining_attacks = await self._get_remaining_atomic_attacks_async() if not remaining_attacks: - logger.info(f"Scenario '{self._name}' has no remaining objectives to execute") + logger.info( + f"Scenario '{self._name}' has no remaining objectives to execute" + ) # Mark scenario as completed self._memory.update_scenario_run_state( scenario_result_id=scenario_result_id, scenario_run_state="COMPLETED" ) # Retrieve and return the current scenario result - scenario_results = self._memory.get_scenario_results(scenario_result_ids=[scenario_result_id]) + scenario_results = self._memory.get_scenario_results( + scenario_result_ids=[scenario_result_id] + ) if scenario_results: return 
scenario_results[0] else: - raise ValueError(f"Scenario result with ID {scenario_result_id} not found") + raise ValueError( + f"Scenario result with ID {scenario_result_id} not found" + ) logger.info( f"Scenario '{self._name}' has {len(remaining_attacks)} atomic attacks " @@ -688,7 +745,9 @@ async def _execute_scenario_async(self) -> ScenarioResult: ) # Mark scenario as in progress - self._memory.update_scenario_run_state(scenario_result_id=scenario_result_id, scenario_run_state="IN_PROGRESS") + self._memory.update_scenario_run_state( + scenario_result_id=scenario_result_id, scenario_run_state="IN_PROGRESS" + ) # Calculate starting index based on completed attacks completed_count = len(self._atomic_attacks) - len(remaining_attacks) @@ -711,7 +770,8 @@ async def _execute_scenario_async(self) -> ScenarioResult: try: atomic_results = await atomic_attack.run_async( - max_concurrency=self._max_concurrency, return_partial_on_failure=True + max_concurrency=self._max_concurrency, + return_partial_on_failure=True, ) # Always save completed results, even if some objectives didn't complete @@ -734,11 +794,14 @@ async def _execute_scenario_async(self) -> ScenarioResult: # Log details of each incomplete objective for obj, exc in atomic_results.incomplete_objectives: - logger.error(f" Incomplete objective '{obj[:50]}...': {str(exc)}") + logger.error( + f" Incomplete objective '{obj[:50]}...': {str(exc)}" + ) # Mark scenario as failed self._memory.update_scenario_run_state( - scenario_result_id=scenario_result_id, scenario_run_state="FAILED" + scenario_result_id=scenario_result_id, + scenario_run_state="FAILED", ) # Raise exception with detailed information @@ -761,10 +824,16 @@ async def _execute_scenario_async(self) -> ScenarioResult: ) # Mark scenario as failed if not already done - scenario_results = self._memory.get_scenario_results(scenario_result_ids=[scenario_result_id]) - if scenario_results and scenario_results[0].scenario_run_state != "FAILED": + scenario_results = 
self._memory.get_scenario_results( + scenario_result_ids=[scenario_result_id] + ) + if ( + scenario_results + and scenario_results[0].scenario_run_state != "FAILED" + ): self._memory.update_scenario_run_state( - scenario_result_id=scenario_result_id, scenario_run_state="FAILED" + scenario_result_id=scenario_result_id, + scenario_run_state="FAILED", ) raise @@ -777,9 +846,13 @@ async def _execute_scenario_async(self) -> ScenarioResult: ) # Retrieve and return final scenario result - scenario_results = self._memory.get_scenario_results(scenario_result_ids=[scenario_result_id]) + scenario_results = self._memory.get_scenario_results( + scenario_result_ids=[scenario_result_id] + ) if not scenario_results: - raise ValueError(f"Scenario result with ID {self._scenario_result_id} not found") + raise ValueError( + f"Scenario result with ID {self._scenario_result_id} not found" + ) return scenario_results[0] diff --git a/pyrit/scenario/core/scenario_strategy.py b/pyrit/scenario/core/scenario_strategy.py index 964c73a3e..580de5b37 100644 --- a/pyrit/scenario/core/scenario_strategy.py +++ b/pyrit/scenario/core/scenario_strategy.py @@ -108,7 +108,11 @@ def get_strategies_by_tag(cls: type[T], tag: str) -> Set[T]: any aggregate markers. 
""" aggregate_tags = cls.get_aggregate_tags() - return {strategy for strategy in cls if tag in strategy.tags and strategy.value not in aggregate_tags} + return { + strategy + for strategy in cls + if tag in strategy.tags and strategy.value not in aggregate_tags + } @classmethod def get_all_strategies(cls: type[T]) -> list[T]: @@ -173,12 +177,17 @@ def normalize_strategies(cls: type[T], strategies: Set[T]) -> Set[T]: # Find aggregate tags in the input and expand them aggregate_tags = cls.get_aggregate_tags() aggregates_to_expand = { - tag for strategy in strategies if strategy.value in aggregate_tags for tag in strategy.tags + tag + for strategy in strategies + if strategy.value in aggregate_tags + for tag in strategy.tags } for aggregate_tag in aggregates_to_expand: # Remove the aggregate marker itself - aggregate_marker = next((s for s in normalized_strategies if s.value == aggregate_tag), None) + aggregate_marker = next( + (s for s in normalized_strategies if s.value == aggregate_tag), None + ) if aggregate_marker: normalized_strategies.remove(aggregate_marker) @@ -242,7 +251,10 @@ def prepare_scenario_strategies( # Expand the default aggregate into concrete strategies expanded = cls.normalize_strategies({default_aggregate}) # Wrap each in a ScenarioCompositeStrategy - composite_strategies = [ScenarioCompositeStrategy(strategies=[strategy]) for strategy in expanded] + composite_strategies = [ + ScenarioCompositeStrategy(strategies=[strategy]) + for strategy in expanded + ] else: # Process the provided strategies composite_strategies = [] @@ -252,7 +264,9 @@ def prepare_scenario_strategies( composite_strategies.append(item) elif isinstance(item, cls): # Bare strategy enum - wrap it in a composite - composite_strategies.append(ScenarioCompositeStrategy(strategies=[item])) + composite_strategies.append( + ScenarioCompositeStrategy(strategies=[item]) + ) else: # Not our strategy type - skip or could raise error # For now, skip to allow flexibility @@ -268,7 +282,9 @@ 
def prepare_scenario_strategies( ) # Normalize compositions (expands aggregates, validates compositions) - normalized = ScenarioCompositeStrategy.normalize_compositions(composite_strategies, strategy_type=cls) + normalized = ScenarioCompositeStrategy.normalize_compositions( + composite_strategies, strategy_type=cls + ) return normalized @@ -425,7 +441,9 @@ def extract_single_strategy_values( ValueError: If any composite contains multiple strategies. """ # Check that all composites are single-strategy - multi_strategy_composites = [comp for comp in composites if not comp.is_single_strategy] + multi_strategy_composites = [ + comp for comp in composites if not comp.is_single_strategy + ] if multi_strategy_composites: composite_names = [comp.name for comp in multi_strategy_composites] raise ValueError( @@ -528,14 +546,20 @@ def normalize_compositions( raise ValueError("Empty compositions are not allowed") # Filter to only strategies of the specified type - typed_strategies = [s for s in composite.strategies if isinstance(s, strategy_type)] + typed_strategies = [ + s for s in composite.strategies if isinstance(s, strategy_type) + ] if not typed_strategies: # No strategies of this type - skip continue # Check if composition contains any aggregates - aggregates_in_composition = [s for s in typed_strategies if s.value in aggregate_tags] - concretes_in_composition = [s for s in typed_strategies if s.value not in aggregate_tags] + aggregates_in_composition = [ + s for s in typed_strategies if s.value in aggregate_tags + ] + concretes_in_composition = [ + s for s in typed_strategies if s.value not in aggregate_tags + ] # Error if mixing aggregates with concrete strategies if aggregates_in_composition and concretes_in_composition: @@ -559,7 +583,9 @@ def normalize_compositions( expanded = strategy_type.normalize_strategies({aggregate}) # Each expanded strategy becomes its own composition for strategy in expanded: - 
normalized_compositions.append(ScenarioCompositeStrategy(strategies=[strategy])) + normalized_compositions.append( + ScenarioCompositeStrategy(strategies=[strategy]) + ) else: # Concrete composition - validate and preserve as-is strategy_type.validate_composition(typed_strategies) diff --git a/tests/unit/scenarios/test_scenario.py b/tests/unit/scenarios/test_scenario.py index 330eca58a..63df939da 100644 --- a/tests/unit/scenarios/test_scenario.py +++ b/tests/unit/scenarios/test_scenario.py @@ -27,7 +27,9 @@ def create_mock_run_async(attack_results): async def mock_run_async(*args, **kwargs): # Save results to memory (mimics what real attacks do) save_attack_results_to_memory(attack_results) - return AttackExecutorResult(completed_results=attack_results, incomplete_objectives=[]) + return AttackExecutorResult( + completed_results=attack_results, incomplete_objectives=[] + ) return AsyncMock(side_effect=mock_run_async) @@ -35,7 +37,10 @@ async def mock_run_async(*args, **kwargs): def create_mock_scorer(): """Create a mock scorer for testing ScenarioResult.""" mock_scorer = MagicMock(spec=Scorer) - mock_scorer.get_identifier.return_value = {"__type__": "MockScorer", "__module__": "test"} + mock_scorer.get_identifier.return_value = { + "__type__": "MockScorer", + "__module__": "test", + } mock_scorer.get_scorer_metrics.return_value = None return mock_scorer @@ -70,7 +75,10 @@ def mock_atomic_attacks(): def mock_objective_target(): """Create a mock objective target for testing.""" target = MagicMock() - target.get_identifier.return_value = {"__type__": "MockTarget", "__module__": "test"} + target.get_identifier.return_value = { + "__type__": "MockTarget", + "__module__": "test", + } return target @@ -81,7 +89,11 @@ def sample_attack_results(): AttackResult( conversation_id=f"conv-{i}", objective=f"objective{i}", - attack_identifier={"__type__": "TestAttack", "__module__": "test", "id": str(i)}, + attack_identifier={ + "__type__": "TestAttack", + "__module__": "test", + 
"id": str(i), + }, outcome=AttackOutcome.SUCCESS, executed_turns=1, ) @@ -111,7 +123,10 @@ def get_aggregate_tags(cls) -> set[str]: # Add a mock scorer if not provided if "objective_scorer" not in kwargs: mock_scorer = MagicMock(spec=Scorer) - mock_scorer.get_identifier.return_value = {"__type__": "MockScorer", "__module__": "test"} + mock_scorer.get_identifier.return_value = { + "__type__": "MockScorer", + "__module__": "test", + } mock_scorer.get_scorer_metrics.return_value = None kwargs["objective_scorer"] = mock_scorer @@ -196,7 +211,9 @@ class TestScenarioInitialization2: """Tests for Scenario initialize_async method.""" @pytest.mark.asyncio - async def test_initialize_async_populates_atomic_attacks(self, mock_atomic_attacks, mock_objective_target): + async def test_initialize_async_populates_atomic_attacks( + self, mock_atomic_attacks, mock_objective_target + ): """Test that initialize_async populates atomic attacks.""" scenario = ConcreteScenario( name="Test Scenario", @@ -222,7 +239,10 @@ async def test_initialize_async_sets_objective_target(self, mock_objective_targe await scenario.initialize_async(objective_target=mock_objective_target) assert scenario._objective_target == mock_objective_target - assert scenario._objective_target_identifier == {"__type__": "MockTarget", "__module__": "test"} + assert scenario._objective_target_identifier == { + "__type__": "MockTarget", + "__module__": "test", + } @pytest.mark.asyncio async def test_initialize_async_requires_objective_target(self): @@ -243,7 +263,9 @@ async def test_initialize_async_sets_max_retries(self, mock_objective_target): version=1, ) - await scenario.initialize_async(objective_target=mock_objective_target, max_retries=3) + await scenario.initialize_async( + objective_target=mock_objective_target, max_retries=3 + ) assert scenario._max_retries == 3 @@ -255,7 +277,9 @@ async def test_initialize_async_sets_max_concurrency(self, mock_objective_target version=1, ) - await 
scenario.initialize_async(objective_target=mock_objective_target, max_concurrency=5) + await scenario.initialize_async( + objective_target=mock_objective_target, max_concurrency=5 + ) assert scenario._max_concurrency == 5 @@ -268,7 +292,9 @@ async def test_initialize_async_sets_memory_labels(self, mock_objective_target): version=1, ) - await scenario.initialize_async(objective_target=mock_objective_target, memory_labels=labels) + await scenario.initialize_async( + objective_target=mock_objective_target, memory_labels=labels + ) assert scenario._memory_labels == labels @@ -292,7 +318,9 @@ class TestScenarioExecution: """Tests for Scenario execution methods.""" @pytest.mark.asyncio - async def test_run_async_executes_all_runs(self, mock_atomic_attacks, sample_attack_results, mock_objective_target): + async def test_run_async_executes_all_runs( + self, mock_atomic_attacks, sample_attack_results, mock_objective_target + ): """Test that run_async executes all atomic attacks sequentially.""" # Configure each run to return different results for i, run in enumerate(mock_atomic_attacks): @@ -313,7 +341,9 @@ async def test_run_async_executes_all_runs(self, mock_atomic_attacks, sample_att # Verify all runs were executed with correct concurrency assert len(result.attack_results) == 3 for run in mock_atomic_attacks: - run.run_async.assert_called_once_with(max_concurrency=10, return_partial_on_failure=True) + run.run_async.assert_called_once_with( + max_concurrency=10, return_partial_on_failure=True + ) # Verify results are aggregated correctly by atomic attack name assert "attack_run_1" in result.attack_results @@ -336,13 +366,17 @@ async def test_run_async_with_custom_concurrency( version=1, atomic_attacks_to_return=mock_atomic_attacks, ) - await scenario.initialize_async(objective_target=mock_objective_target, max_concurrency=5) + await scenario.initialize_async( + objective_target=mock_objective_target, max_concurrency=5 + ) result = await scenario.run_async() # Verify 
max_concurrency was passed to each run for run in mock_atomic_attacks: - run.run_async.assert_called_once_with(max_concurrency=5, return_partial_on_failure=True) + run.run_async.assert_called_once_with( + max_concurrency=5, return_partial_on_failure=True + ) # Verify result structure assert isinstance(result, ScenarioResult) @@ -354,9 +388,15 @@ async def test_run_async_aggregates_multiple_results( ): """Test that results from multiple atomic attacks are properly aggregated.""" # Configure runs to return different numbers of results - mock_atomic_attacks[0].run_async = create_mock_run_async(sample_attack_results[0:2]) - mock_atomic_attacks[1].run_async = create_mock_run_async(sample_attack_results[2:4]) - mock_atomic_attacks[2].run_async = create_mock_run_async(sample_attack_results[4:5]) + mock_atomic_attacks[0].run_async = create_mock_run_async( + sample_attack_results[0:2] + ) + mock_atomic_attacks[1].run_async = create_mock_run_async( + sample_attack_results[2:4] + ) + mock_atomic_attacks[2].run_async = create_mock_run_async( + sample_attack_results[4:5] + ) scenario = ConcreteScenario( name="Test Scenario", @@ -375,11 +415,19 @@ async def test_run_async_aggregates_multiple_results( assert len(result.attack_results["attack_run_3"]) == 1 @pytest.mark.asyncio - async def test_run_async_stops_on_error(self, mock_atomic_attacks, sample_attack_results, mock_objective_target): + async def test_run_async_stops_on_error( + self, mock_atomic_attacks, sample_attack_results, mock_objective_target + ): """Test that execution stops when an atomic attack fails.""" - mock_atomic_attacks[0].run_async = create_mock_run_async([sample_attack_results[0]]) - mock_atomic_attacks[1].run_async = AsyncMock(side_effect=Exception("Test error")) - mock_atomic_attacks[2].run_async = create_mock_run_async([sample_attack_results[2]]) + mock_atomic_attacks[0].run_async = create_mock_run_async( + [sample_attack_results[0]] + ) + mock_atomic_attacks[1].run_async = AsyncMock( + 
side_effect=Exception("Test error") + ) + mock_atomic_attacks[2].run_async = create_mock_run_async( + [sample_attack_results[2]] + ) scenario = ConcreteScenario( name="Test Scenario", @@ -406,7 +454,9 @@ async def test_run_async_fails_without_initialization(self, mock_objective_targe version=1, ) - with pytest.raises(ValueError, match="Cannot run scenario with no atomic attacks"): + with pytest.raises( + ValueError, match="Cannot run scenario with no atomic attacks" + ): await scenario.run_async() @pytest.mark.asyncio @@ -431,7 +481,11 @@ async def test_run_async_returns_scenario_result_with_identifier( assert result.scenario_identifier.name == "ConcreteScenario" assert result.scenario_identifier.version == 5 assert result.scenario_identifier.pyrit_version is not None - assert result.get_strategies_used() == ["attack_run_1", "attack_run_2", "attack_run_3"] + assert result.get_strategies_used() == [ + "attack_run_1", + "attack_run_2", + "attack_run_3", + ] @pytest.mark.usefixtures("patch_central_database") @@ -448,7 +502,9 @@ def test_name_property(self, mock_objective_target): assert scenario.name == "My Test Scenario" @pytest.mark.asyncio - async def test_atomic_attack_count_property(self, mock_atomic_attacks, mock_objective_target): + async def test_atomic_attack_count_property( + self, mock_atomic_attacks, mock_objective_target + ): """Test that atomic_attack_count returns the correct count.""" scenario = ConcreteScenario( name="Test Scenario", @@ -463,7 +519,9 @@ async def test_atomic_attack_count_property(self, mock_atomic_attacks, mock_obje assert scenario.atomic_attack_count == 3 @pytest.mark.asyncio - async def test_atomic_attack_count_with_different_sizes(self, mock_objective_target): + async def test_atomic_attack_count_with_different_sizes( + self, mock_objective_target + ): """Test atomic_attack_count with different numbers of atomic attacks.""" # Create mock attack strategy mock_attack = MagicMock() @@ -511,8 +569,14 @@ def 
test_scenario_result_initialization(self, sample_attack_results): mock_scorer = create_mock_scorer() result = ScenarioResult( scenario_identifier=identifier, - objective_target_identifier={"__type__": "TestTarget", "__module__": "test"}, - attack_results={"base64": sample_attack_results[:3], "rot13": sample_attack_results[3:]}, + objective_target_identifier={ + "__type__": "TestTarget", + "__module__": "test", + }, + attack_results={ + "base64": sample_attack_results[:3], + "rot13": sample_attack_results[3:], + }, objective_scorer=mock_scorer, ) @@ -528,7 +592,10 @@ def test_scenario_result_with_empty_results(self): mock_scorer = create_mock_scorer() result = ScenarioResult( scenario_identifier=identifier, - objective_target_identifier={"__type__": "TestTarget", "__module__": "test"}, + objective_target_identifier={ + "__type__": "TestTarget", + "__module__": "test", + }, attack_results={"base64": []}, objective_scorer=mock_scorer, ) @@ -544,7 +611,10 @@ def test_scenario_result_objective_achieved_rate(self, sample_attack_results): # All successful result = ScenarioResult( scenario_identifier=identifier, - objective_target_identifier={"__type__": "TestTarget", "__module__": "test"}, + objective_target_identifier={ + "__type__": "TestTarget", + "__module__": "test", + }, attack_results={"base64": sample_attack_results}, objective_scorer=mock_scorer, ) @@ -555,21 +625,32 @@ def test_scenario_result_objective_achieved_rate(self, sample_attack_results): AttackResult( conversation_id="conv-fail", objective="objective", - attack_identifier={"__type__": "TestAttack", "__module__": "test", "id": "1"}, + attack_identifier={ + "__type__": "TestAttack", + "__module__": "test", + "id": "1", + }, outcome=AttackOutcome.FAILURE, executed_turns=1, ), AttackResult( conversation_id="conv-fail2", objective="objective", - attack_identifier={"__type__": "TestAttack", "__module__": "test", "id": "2"}, + attack_identifier={ + "__type__": "TestAttack", + "__module__": "test", + "id": "2", 
+ }, outcome=AttackOutcome.FAILURE, executed_turns=1, ), ] result2 = ScenarioResult( scenario_identifier=identifier, - objective_target_identifier={"__type__": "TestTarget", "__module__": "test"}, + objective_target_identifier={ + "__type__": "TestTarget", + "__module__": "test", + }, attack_results={"base64": mixed_results}, objective_scorer=mock_scorer, ) @@ -598,7 +679,9 @@ def test_scenario_identifier_with_custom_pyrit_version(self): def test_scenario_identifier_with_init_data(self): """Test ScenarioIdentifier with init_data.""" init_data = {"param1": "value1", "param2": 42} - identifier = ScenarioIdentifier(name="TestScenario", scenario_version=1, init_data=init_data) + identifier = ScenarioIdentifier( + name="TestScenario", scenario_version=1, init_data=init_data + ) assert identifier.init_data == init_data @@ -608,7 +691,10 @@ def create_mock_truefalse_scorer(): from pyrit.score import TrueFalseScorer mock_scorer = MagicMock(spec=TrueFalseScorer) - mock_scorer.get_identifier.return_value = {"__type__": "MockTrueFalseScorer", "__module__": "test"} + mock_scorer.get_identifier.return_value = { + "__type__": "MockTrueFalseScorer", + "__module__": "test", + } mock_scorer.get_scorer_metrics.return_value = None # Make isinstance check work mock_scorer.__class__ = TrueFalseScorer @@ -673,7 +759,9 @@ class TestScenarioBaselineOnlyExecution: """Tests for baseline-only execution (empty strategies with include_baseline=True).""" @pytest.mark.asyncio - async def test_initialize_async_with_empty_strategies_and_baseline(self, mock_objective_target): + async def test_initialize_async_with_empty_strategies_and_baseline( + self, mock_objective_target + ): """Test that baseline-only execution works when include_baseline=True and strategies is empty.""" from pyrit.models import SeedAttackGroup, SeedObjective @@ -703,7 +791,9 @@ async def test_initialize_async_with_empty_strategies_and_baseline(self, mock_ob assert scenario._atomic_attacks[0].atomic_attack_name == "baseline" 
@pytest.mark.asyncio - async def test_baseline_only_execution_runs_successfully(self, mock_objective_target, sample_attack_results): + async def test_baseline_only_execution_runs_successfully( + self, mock_objective_target, sample_attack_results + ): """Test that baseline-only scenario can run successfully.""" from pyrit.models import SeedAttackGroup, SeedObjective @@ -728,7 +818,9 @@ async def test_baseline_only_execution_runs_successfully(self, mock_objective_ta ) # Mock the baseline attack's run_async - scenario._atomic_attacks[0].run_async = create_mock_run_async([sample_attack_results[0]]) + scenario._atomic_attacks[0].run_async = create_mock_run_async( + [sample_attack_results[0]] + ) # Run the scenario result = await scenario.run_async() @@ -739,7 +831,9 @@ async def test_baseline_only_execution_runs_successfully(self, mock_objective_ta assert len(result.attack_results["baseline"]) == 1 @pytest.mark.asyncio - async def test_empty_strategies_without_baseline_raises_error(self, mock_objective_target): + async def test_empty_strategies_without_baseline_raises_error( + self, mock_objective_target + ): """Test that empty strategies without include_baseline raises ValueError.""" scenario = ConcreteScenario( name="No Baseline Test", @@ -758,7 +852,9 @@ async def test_empty_strategies_without_baseline_raises_error(self, mock_objecti ) @pytest.mark.asyncio - async def test_standalone_baseline_uses_dataset_config_seeds(self, mock_objective_target): + async def test_standalone_baseline_uses_dataset_config_seeds( + self, mock_objective_target + ): """Test that standalone baseline uses seed groups from dataset_config.""" from pyrit.models import SeedAttackGroup, SeedObjective From 18b6f960684d97af8f5a3ce38bb027348bd8584c Mon Sep 17 00:00:00 2001 From: Sydney Lister Date: Thu, 22 Jan 2026 16:49:04 -0500 Subject: [PATCH 4/6] Apply ruff formatting fixes --- pyrit/scenario/core/scenario.py | 145 ++++++----------------- pyrit/scenario/core/scenario_strategy.py | 48 ++------ 
tests/unit/scenarios/test_scenario.py | 100 ++++------------ 3 files changed, 73 insertions(+), 220 deletions(-) diff --git a/pyrit/scenario/core/scenario.py b/pyrit/scenario/core/scenario.py index 220167677..a4a3dccfc 100644 --- a/pyrit/scenario/core/scenario.py +++ b/pyrit/scenario/core/scenario.py @@ -79,9 +79,7 @@ def __init__( with whitespace normalized for display. """ # Use the class docstring with normalized whitespace as description - description = ( - " ".join(self.__class__.__doc__.split()) if self.__class__.__doc__ else "" - ) + description = " ".join(self.__class__.__doc__.split()) if self.__class__.__doc__ else "" self._identifier = ScenarioIdentifier( name=type(self).__name__, scenario_version=version, description=description @@ -103,9 +101,7 @@ def __init__( self._name = name self._memory = CentralMemory.get_memory_instance() self._atomic_attacks: List[AtomicAttack] = [] - self._scenario_result_id: Optional[str] = ( - str(scenario_result_id) if scenario_result_id else None - ) + self._scenario_result_id: Optional[str] = str(scenario_result_id) if scenario_result_id else None self._result_lock = asyncio.Lock() self._include_baseline = include_default_baseline @@ -177,9 +173,7 @@ async def initialize_async( self, *, objective_target: PromptTarget = REQUIRED_VALUE, # type: ignore - scenario_strategies: Optional[ - Sequence[ScenarioStrategy | ScenarioCompositeStrategy] - ] = None, + scenario_strategies: Optional[Sequence[ScenarioStrategy | ScenarioCompositeStrategy]] = None, dataset_config: Optional[DatasetConfiguration] = None, max_concurrency: int = 10, max_retries: int = 0, @@ -228,9 +222,7 @@ async def initialize_async( self._objective_target = objective_target self._objective_target_identifier = objective_target.get_identifier() self._dataset_config_provided = dataset_config is not None - self._dataset_config = ( - dataset_config if dataset_config else self.default_dataset_config() - ) + self._dataset_config = dataset_config if dataset_config else 
self.default_dataset_config() self._max_concurrency = max_concurrency self._max_retries = max_retries self._memory_labels = memory_labels or {} @@ -256,15 +248,12 @@ async def initialize_async( # Store original objectives for each atomic attack (before any mutations during execution) self._original_objectives_map = { - atomic_attack.atomic_attack_name: tuple(atomic_attack.objectives) - for atomic_attack in self._atomic_attacks + atomic_attack.atomic_attack_name: tuple(atomic_attack.objectives) for atomic_attack in self._atomic_attacks } # Check if we're resuming an existing scenario if self._scenario_result_id: - existing_results = self._memory.get_scenario_results( - scenario_result_ids=[self._scenario_result_id] - ) + existing_results = self._memory.get_scenario_results(scenario_result_ids=[self._scenario_result_id]) if existing_results: existing_result = existing_results[0] @@ -283,8 +272,7 @@ async def initialize_async( # Create new scenario result attack_results: Dict[str, List[AttackResult]] = { - atomic_attack.atomic_attack_name: [] - for atomic_attack in self._atomic_attacks + atomic_attack.atomic_attack_name: [] for atomic_attack in self._atomic_attacks } result = ScenarioResult( @@ -322,17 +310,13 @@ def _get_baseline_from_first_attack(self) -> AtomicAttack: objective_target = first_attack._attack.get_objective_target() if not seed_groups or len(seed_groups) == 0: - raise ValueError( - "First atomic attack must have seed_groups to create baseline." - ) + raise ValueError("First atomic attack must have seed_groups to create baseline.") if not objective_target: raise ValueError("Objective target is required to create baseline attack.") if not attack_scoring_config: - raise ValueError( - "Attack scoring config is required to create baseline attack." 
- ) + raise ValueError("Attack scoring config is required to create baseline attack.") # Create baseline attack with no converters attack = PromptSendingAttack( @@ -362,19 +346,13 @@ def _create_standalone_baseline(self) -> AtomicAttack: ValueError: If objective_target, dataset_config, or objective_scorer is not set. """ if not self._objective_target: - raise ValueError( - "Objective target is required to create standalone baseline attack." - ) + raise ValueError("Objective target is required to create standalone baseline attack.") if not self._dataset_config: - raise ValueError( - "Dataset config is required to create standalone baseline attack." - ) + raise ValueError("Dataset config is required to create standalone baseline attack.") if not self._objective_scorer: - raise ValueError( - "Objective scorer is required to create standalone baseline attack." - ) + raise ValueError("Objective scorer is required to create standalone baseline attack.") # Get seed groups from the dataset configuration seed_groups = self._dataset_config.get_all_seed_attack_groups() @@ -384,14 +362,13 @@ def _create_standalone_baseline(self) -> AtomicAttack: # Import here to avoid circular imports from typing import cast + from pyrit.executor.attack.core.attack_config import AttackScoringConfig from pyrit.score import TrueFalseScorer # Create scoring config from the scenario's objective scorer # Note: Scenarios require TrueFalseScorer for attack scoring - attack_scoring_config = AttackScoringConfig( - objective_scorer=cast(TrueFalseScorer, self._objective_scorer) - ) + attack_scoring_config = AttackScoringConfig(objective_scorer=cast(TrueFalseScorer, self._objective_scorer)) # Create baseline attack with no converters attack = PromptSendingAttack( @@ -455,9 +432,7 @@ def _validate_stored_scenario(self, *, stored_result: ScenarioResult) -> bool: ) return True - def _get_completed_objectives_for_attack( - self, *, atomic_attack_name: str - ) -> Set[str]: + def 
_get_completed_objectives_for_attack(self, *, atomic_attack_name: str) -> Set[str]: """ Get the set of objectives that have already been completed for a specific atomic attack. @@ -474,17 +449,14 @@ def _get_completed_objectives_for_attack( try: # Retrieve the scenario result from memory - scenario_results = self._memory.get_scenario_results( - scenario_result_ids=[self._scenario_result_id] - ) + scenario_results = self._memory.get_scenario_results(scenario_result_ids=[self._scenario_result_id]) if scenario_results: scenario_result = scenario_results[0] # Get completed objectives for this atomic attack name if atomic_attack_name in scenario_result.attack_results: completed_objectives = { - result.objective - for result in scenario_result.attack_results[atomic_attack_name] + result.objective for result in scenario_result.attack_results[atomic_attack_name] } except Exception as e: logger.warning( @@ -516,14 +488,10 @@ async def _get_remaining_atomic_attacks_async(self) -> List[AtomicAttack]: ) # Get ORIGINAL objectives (before any mutations) from stored map - original_objectives = self._original_objectives_map.get( - atomic_attack.atomic_attack_name, () - ) + original_objectives = self._original_objectives_map.get(atomic_attack.atomic_attack_name, ()) # Calculate remaining objectives - remaining_objectives = [ - obj for obj in original_objectives if obj not in completed_objectives - ] + remaining_objectives = [obj for obj in original_objectives if obj not in completed_objectives] if remaining_objectives: # If there are remaining objectives, update the atomic attack @@ -533,9 +501,7 @@ async def _get_remaining_atomic_attacks_async(self) -> List[AtomicAttack]: f"{len(remaining_objectives)}/{len(original_objectives)} objectives remaining" ) # Update the objectives for this atomic attack to only include remaining ones - atomic_attack.filter_seed_groups_by_objectives( - remaining_objectives=remaining_objectives - ) + 
atomic_attack.filter_seed_groups_by_objectives(remaining_objectives=remaining_objectives) remaining_attacks.append(atomic_attack) else: @@ -558,9 +524,7 @@ async def _update_scenario_result_async( attack_results (List[AttackResult]): The list of new attack results to add. """ if not self._scenario_result_id: - logger.warning( - "Cannot update scenario result: no scenario result ID available" - ) + logger.warning("Cannot update scenario result: no scenario result ID available") return async with self._result_lock: @@ -624,9 +588,7 @@ async def run_async(self) -> ScenarioResult: ) if not self._scenario_result_id: - raise ValueError( - "Scenario not properly initialized. Call await scenario.initialize_async() first." - ) + raise ValueError("Scenario not properly initialized. Call await scenario.initialize_async() first.") # Type narrowing: create local variable that type checker knows is non-None scenario_result_id: str = self._scenario_result_id @@ -641,14 +603,8 @@ async def run_async(self) -> ScenarioResult: last_exception = e # Get current scenario to check number of tries - scenario_results = self._memory.get_scenario_results( - scenario_result_ids=[scenario_result_id] - ) - current_tries = ( - scenario_results[0].number_tries - if scenario_results - else retry_attempt + 1 - ) + scenario_results = self._memory.get_scenario_results(scenario_result_ids=[scenario_result_id]) + current_tries = scenario_results[0].number_tries if scenario_results else retry_attempt + 1 # Check if we have more retries available remaining_retries = self._max_retries - retry_attempt @@ -673,9 +629,7 @@ async def run_async(self) -> ScenarioResult: # This should never be reached, but just in case if last_exception: raise last_exception - raise RuntimeError( - f"Scenario '{self._name}' completed unexpectedly without result" - ) + raise RuntimeError(f"Scenario '{self._name}' completed unexpectedly without result") async def _execute_scenario_async(self) -> ScenarioResult: """ @@ -693,9 
+647,7 @@ async def _execute_scenario_async(self) -> ScenarioResult: ValueError: If a lookup for a scenario for a given ID fails. ValueError: If atomic attack execution fails. """ - logger.info( - f"Starting scenario '{self._name}' execution with {len(self._atomic_attacks)} atomic attacks" - ) + logger.info(f"Starting scenario '{self._name}' execution with {len(self._atomic_attacks)} atomic attacks") # Type narrowing: _scenario_result_id is guaranteed to be non-None at this point # (verified in run_async before calling this method) @@ -703,17 +655,13 @@ async def _execute_scenario_async(self) -> ScenarioResult: scenario_result_id: str = self._scenario_result_id # Increment number_tries at the start of each run - scenario_results = self._memory.get_scenario_results( - scenario_result_ids=[scenario_result_id] - ) + scenario_results = self._memory.get_scenario_results(scenario_result_ids=[scenario_result_id]) if scenario_results: current_scenario = scenario_results[0] current_scenario.number_tries += 1 entry = ScenarioResultEntry(entry=current_scenario) self._memory._update_entry(entry) - logger.info( - f"Scenario '{self._name}' attempt #{current_scenario.number_tries}" - ) + logger.info(f"Scenario '{self._name}' attempt #{current_scenario.number_tries}") else: raise ValueError(f"Scenario result with ID {scenario_result_id} not found") @@ -721,23 +669,17 @@ async def _execute_scenario_async(self) -> ScenarioResult: remaining_attacks = await self._get_remaining_atomic_attacks_async() if not remaining_attacks: - logger.info( - f"Scenario '{self._name}' has no remaining objectives to execute" - ) + logger.info(f"Scenario '{self._name}' has no remaining objectives to execute") # Mark scenario as completed self._memory.update_scenario_run_state( scenario_result_id=scenario_result_id, scenario_run_state="COMPLETED" ) # Retrieve and return the current scenario result - scenario_results = self._memory.get_scenario_results( - scenario_result_ids=[scenario_result_id] - ) + 
scenario_results = self._memory.get_scenario_results(scenario_result_ids=[scenario_result_id]) if scenario_results: return scenario_results[0] else: - raise ValueError( - f"Scenario result with ID {scenario_result_id} not found" - ) + raise ValueError(f"Scenario result with ID {scenario_result_id} not found") logger.info( f"Scenario '{self._name}' has {len(remaining_attacks)} atomic attacks " @@ -745,9 +687,7 @@ async def _execute_scenario_async(self) -> ScenarioResult: ) # Mark scenario as in progress - self._memory.update_scenario_run_state( - scenario_result_id=scenario_result_id, scenario_run_state="IN_PROGRESS" - ) + self._memory.update_scenario_run_state(scenario_result_id=scenario_result_id, scenario_run_state="IN_PROGRESS") # Calculate starting index based on completed attacks completed_count = len(self._atomic_attacks) - len(remaining_attacks) @@ -794,9 +734,7 @@ async def _execute_scenario_async(self) -> ScenarioResult: # Log details of each incomplete objective for obj, exc in atomic_results.incomplete_objectives: - logger.error( - f" Incomplete objective '{obj[:50]}...': {str(exc)}" - ) + logger.error(f" Incomplete objective '{obj[:50]}...': {str(exc)}") # Mark scenario as failed self._memory.update_scenario_run_state( @@ -824,13 +762,8 @@ async def _execute_scenario_async(self) -> ScenarioResult: ) # Mark scenario as failed if not already done - scenario_results = self._memory.get_scenario_results( - scenario_result_ids=[scenario_result_id] - ) - if ( - scenario_results - and scenario_results[0].scenario_run_state != "FAILED" - ): + scenario_results = self._memory.get_scenario_results(scenario_result_ids=[scenario_result_id]) + if scenario_results and scenario_results[0].scenario_run_state != "FAILED": self._memory.update_scenario_run_state( scenario_result_id=scenario_result_id, scenario_run_state="FAILED", @@ -846,13 +779,9 @@ async def _execute_scenario_async(self) -> ScenarioResult: ) # Retrieve and return final scenario result - scenario_results = 
self._memory.get_scenario_results( - scenario_result_ids=[scenario_result_id] - ) + scenario_results = self._memory.get_scenario_results(scenario_result_ids=[scenario_result_id]) if not scenario_results: - raise ValueError( - f"Scenario result with ID {self._scenario_result_id} not found" - ) + raise ValueError(f"Scenario result with ID {self._scenario_result_id} not found") return scenario_results[0] diff --git a/pyrit/scenario/core/scenario_strategy.py b/pyrit/scenario/core/scenario_strategy.py index 580de5b37..964c73a3e 100644 --- a/pyrit/scenario/core/scenario_strategy.py +++ b/pyrit/scenario/core/scenario_strategy.py @@ -108,11 +108,7 @@ def get_strategies_by_tag(cls: type[T], tag: str) -> Set[T]: any aggregate markers. """ aggregate_tags = cls.get_aggregate_tags() - return { - strategy - for strategy in cls - if tag in strategy.tags and strategy.value not in aggregate_tags - } + return {strategy for strategy in cls if tag in strategy.tags and strategy.value not in aggregate_tags} @classmethod def get_all_strategies(cls: type[T]) -> list[T]: @@ -177,17 +173,12 @@ def normalize_strategies(cls: type[T], strategies: Set[T]) -> Set[T]: # Find aggregate tags in the input and expand them aggregate_tags = cls.get_aggregate_tags() aggregates_to_expand = { - tag - for strategy in strategies - if strategy.value in aggregate_tags - for tag in strategy.tags + tag for strategy in strategies if strategy.value in aggregate_tags for tag in strategy.tags } for aggregate_tag in aggregates_to_expand: # Remove the aggregate marker itself - aggregate_marker = next( - (s for s in normalized_strategies if s.value == aggregate_tag), None - ) + aggregate_marker = next((s for s in normalized_strategies if s.value == aggregate_tag), None) if aggregate_marker: normalized_strategies.remove(aggregate_marker) @@ -251,10 +242,7 @@ def prepare_scenario_strategies( # Expand the default aggregate into concrete strategies expanded = cls.normalize_strategies({default_aggregate}) # Wrap each in a 
ScenarioCompositeStrategy - composite_strategies = [ - ScenarioCompositeStrategy(strategies=[strategy]) - for strategy in expanded - ] + composite_strategies = [ScenarioCompositeStrategy(strategies=[strategy]) for strategy in expanded] else: # Process the provided strategies composite_strategies = [] @@ -264,9 +252,7 @@ def prepare_scenario_strategies( composite_strategies.append(item) elif isinstance(item, cls): # Bare strategy enum - wrap it in a composite - composite_strategies.append( - ScenarioCompositeStrategy(strategies=[item]) - ) + composite_strategies.append(ScenarioCompositeStrategy(strategies=[item])) else: # Not our strategy type - skip or could raise error # For now, skip to allow flexibility @@ -282,9 +268,7 @@ def prepare_scenario_strategies( ) # Normalize compositions (expands aggregates, validates compositions) - normalized = ScenarioCompositeStrategy.normalize_compositions( - composite_strategies, strategy_type=cls - ) + normalized = ScenarioCompositeStrategy.normalize_compositions(composite_strategies, strategy_type=cls) return normalized @@ -441,9 +425,7 @@ def extract_single_strategy_values( ValueError: If any composite contains multiple strategies. 
""" # Check that all composites are single-strategy - multi_strategy_composites = [ - comp for comp in composites if not comp.is_single_strategy - ] + multi_strategy_composites = [comp for comp in composites if not comp.is_single_strategy] if multi_strategy_composites: composite_names = [comp.name for comp in multi_strategy_composites] raise ValueError( @@ -546,20 +528,14 @@ def normalize_compositions( raise ValueError("Empty compositions are not allowed") # Filter to only strategies of the specified type - typed_strategies = [ - s for s in composite.strategies if isinstance(s, strategy_type) - ] + typed_strategies = [s for s in composite.strategies if isinstance(s, strategy_type)] if not typed_strategies: # No strategies of this type - skip continue # Check if composition contains any aggregates - aggregates_in_composition = [ - s for s in typed_strategies if s.value in aggregate_tags - ] - concretes_in_composition = [ - s for s in typed_strategies if s.value not in aggregate_tags - ] + aggregates_in_composition = [s for s in typed_strategies if s.value in aggregate_tags] + concretes_in_composition = [s for s in typed_strategies if s.value not in aggregate_tags] # Error if mixing aggregates with concrete strategies if aggregates_in_composition and concretes_in_composition: @@ -583,9 +559,7 @@ def normalize_compositions( expanded = strategy_type.normalize_strategies({aggregate}) # Each expanded strategy becomes its own composition for strategy in expanded: - normalized_compositions.append( - ScenarioCompositeStrategy(strategies=[strategy]) - ) + normalized_compositions.append(ScenarioCompositeStrategy(strategies=[strategy])) else: # Concrete composition - validate and preserve as-is strategy_type.validate_composition(typed_strategies) diff --git a/tests/unit/scenarios/test_scenario.py b/tests/unit/scenarios/test_scenario.py index 63df939da..3d0ff7823 100644 --- a/tests/unit/scenarios/test_scenario.py +++ b/tests/unit/scenarios/test_scenario.py @@ -27,9 +27,7 @@ def 
create_mock_run_async(attack_results): async def mock_run_async(*args, **kwargs): # Save results to memory (mimics what real attacks do) save_attack_results_to_memory(attack_results) - return AttackExecutorResult( - completed_results=attack_results, incomplete_objectives=[] - ) + return AttackExecutorResult(completed_results=attack_results, incomplete_objectives=[]) return AsyncMock(side_effect=mock_run_async) @@ -211,9 +209,7 @@ class TestScenarioInitialization2: """Tests for Scenario initialize_async method.""" @pytest.mark.asyncio - async def test_initialize_async_populates_atomic_attacks( - self, mock_atomic_attacks, mock_objective_target - ): + async def test_initialize_async_populates_atomic_attacks(self, mock_atomic_attacks, mock_objective_target): """Test that initialize_async populates atomic attacks.""" scenario = ConcreteScenario( name="Test Scenario", @@ -263,9 +259,7 @@ async def test_initialize_async_sets_max_retries(self, mock_objective_target): version=1, ) - await scenario.initialize_async( - objective_target=mock_objective_target, max_retries=3 - ) + await scenario.initialize_async(objective_target=mock_objective_target, max_retries=3) assert scenario._max_retries == 3 @@ -277,9 +271,7 @@ async def test_initialize_async_sets_max_concurrency(self, mock_objective_target version=1, ) - await scenario.initialize_async( - objective_target=mock_objective_target, max_concurrency=5 - ) + await scenario.initialize_async(objective_target=mock_objective_target, max_concurrency=5) assert scenario._max_concurrency == 5 @@ -292,9 +284,7 @@ async def test_initialize_async_sets_memory_labels(self, mock_objective_target): version=1, ) - await scenario.initialize_async( - objective_target=mock_objective_target, memory_labels=labels - ) + await scenario.initialize_async(objective_target=mock_objective_target, memory_labels=labels) assert scenario._memory_labels == labels @@ -318,9 +308,7 @@ class TestScenarioExecution: """Tests for Scenario execution methods.""" 
@pytest.mark.asyncio - async def test_run_async_executes_all_runs( - self, mock_atomic_attacks, sample_attack_results, mock_objective_target - ): + async def test_run_async_executes_all_runs(self, mock_atomic_attacks, sample_attack_results, mock_objective_target): """Test that run_async executes all atomic attacks sequentially.""" # Configure each run to return different results for i, run in enumerate(mock_atomic_attacks): @@ -341,9 +329,7 @@ async def test_run_async_executes_all_runs( # Verify all runs were executed with correct concurrency assert len(result.attack_results) == 3 for run in mock_atomic_attacks: - run.run_async.assert_called_once_with( - max_concurrency=10, return_partial_on_failure=True - ) + run.run_async.assert_called_once_with(max_concurrency=10, return_partial_on_failure=True) # Verify results are aggregated correctly by atomic attack name assert "attack_run_1" in result.attack_results @@ -366,17 +352,13 @@ async def test_run_async_with_custom_concurrency( version=1, atomic_attacks_to_return=mock_atomic_attacks, ) - await scenario.initialize_async( - objective_target=mock_objective_target, max_concurrency=5 - ) + await scenario.initialize_async(objective_target=mock_objective_target, max_concurrency=5) result = await scenario.run_async() # Verify max_concurrency was passed to each run for run in mock_atomic_attacks: - run.run_async.assert_called_once_with( - max_concurrency=5, return_partial_on_failure=True - ) + run.run_async.assert_called_once_with(max_concurrency=5, return_partial_on_failure=True) # Verify result structure assert isinstance(result, ScenarioResult) @@ -388,15 +370,9 @@ async def test_run_async_aggregates_multiple_results( ): """Test that results from multiple atomic attacks are properly aggregated.""" # Configure runs to return different numbers of results - mock_atomic_attacks[0].run_async = create_mock_run_async( - sample_attack_results[0:2] - ) - mock_atomic_attacks[1].run_async = create_mock_run_async( - 
sample_attack_results[2:4] - ) - mock_atomic_attacks[2].run_async = create_mock_run_async( - sample_attack_results[4:5] - ) + mock_atomic_attacks[0].run_async = create_mock_run_async(sample_attack_results[0:2]) + mock_atomic_attacks[1].run_async = create_mock_run_async(sample_attack_results[2:4]) + mock_atomic_attacks[2].run_async = create_mock_run_async(sample_attack_results[4:5]) scenario = ConcreteScenario( name="Test Scenario", @@ -415,19 +391,11 @@ async def test_run_async_aggregates_multiple_results( assert len(result.attack_results["attack_run_3"]) == 1 @pytest.mark.asyncio - async def test_run_async_stops_on_error( - self, mock_atomic_attacks, sample_attack_results, mock_objective_target - ): + async def test_run_async_stops_on_error(self, mock_atomic_attacks, sample_attack_results, mock_objective_target): """Test that execution stops when an atomic attack fails.""" - mock_atomic_attacks[0].run_async = create_mock_run_async( - [sample_attack_results[0]] - ) - mock_atomic_attacks[1].run_async = AsyncMock( - side_effect=Exception("Test error") - ) - mock_atomic_attacks[2].run_async = create_mock_run_async( - [sample_attack_results[2]] - ) + mock_atomic_attacks[0].run_async = create_mock_run_async([sample_attack_results[0]]) + mock_atomic_attacks[1].run_async = AsyncMock(side_effect=Exception("Test error")) + mock_atomic_attacks[2].run_async = create_mock_run_async([sample_attack_results[2]]) scenario = ConcreteScenario( name="Test Scenario", @@ -454,9 +422,7 @@ async def test_run_async_fails_without_initialization(self, mock_objective_targe version=1, ) - with pytest.raises( - ValueError, match="Cannot run scenario with no atomic attacks" - ): + with pytest.raises(ValueError, match="Cannot run scenario with no atomic attacks"): await scenario.run_async() @pytest.mark.asyncio @@ -502,9 +468,7 @@ def test_name_property(self, mock_objective_target): assert scenario.name == "My Test Scenario" @pytest.mark.asyncio - async def test_atomic_attack_count_property( - 
self, mock_atomic_attacks, mock_objective_target - ): + async def test_atomic_attack_count_property(self, mock_atomic_attacks, mock_objective_target): """Test that atomic_attack_count returns the correct count.""" scenario = ConcreteScenario( name="Test Scenario", @@ -519,9 +483,7 @@ async def test_atomic_attack_count_property( assert scenario.atomic_attack_count == 3 @pytest.mark.asyncio - async def test_atomic_attack_count_with_different_sizes( - self, mock_objective_target - ): + async def test_atomic_attack_count_with_different_sizes(self, mock_objective_target): """Test atomic_attack_count with different numbers of atomic attacks.""" # Create mock attack strategy mock_attack = MagicMock() @@ -679,9 +641,7 @@ def test_scenario_identifier_with_custom_pyrit_version(self): def test_scenario_identifier_with_init_data(self): """Test ScenarioIdentifier with init_data.""" init_data = {"param1": "value1", "param2": 42} - identifier = ScenarioIdentifier( - name="TestScenario", scenario_version=1, init_data=init_data - ) + identifier = ScenarioIdentifier(name="TestScenario", scenario_version=1, init_data=init_data) assert identifier.init_data == init_data @@ -759,9 +719,7 @@ class TestScenarioBaselineOnlyExecution: """Tests for baseline-only execution (empty strategies with include_baseline=True).""" @pytest.mark.asyncio - async def test_initialize_async_with_empty_strategies_and_baseline( - self, mock_objective_target - ): + async def test_initialize_async_with_empty_strategies_and_baseline(self, mock_objective_target): """Test that baseline-only execution works when include_baseline=True and strategies is empty.""" from pyrit.models import SeedAttackGroup, SeedObjective @@ -791,9 +749,7 @@ async def test_initialize_async_with_empty_strategies_and_baseline( assert scenario._atomic_attacks[0].atomic_attack_name == "baseline" @pytest.mark.asyncio - async def test_baseline_only_execution_runs_successfully( - self, mock_objective_target, sample_attack_results - ): + async 
def test_baseline_only_execution_runs_successfully(self, mock_objective_target, sample_attack_results): """Test that baseline-only scenario can run successfully.""" from pyrit.models import SeedAttackGroup, SeedObjective @@ -818,9 +774,7 @@ async def test_baseline_only_execution_runs_successfully( ) # Mock the baseline attack's run_async - scenario._atomic_attacks[0].run_async = create_mock_run_async( - [sample_attack_results[0]] - ) + scenario._atomic_attacks[0].run_async = create_mock_run_async([sample_attack_results[0]]) # Run the scenario result = await scenario.run_async() @@ -831,9 +785,7 @@ async def test_baseline_only_execution_runs_successfully( assert len(result.attack_results["baseline"]) == 1 @pytest.mark.asyncio - async def test_empty_strategies_without_baseline_raises_error( - self, mock_objective_target - ): + async def test_empty_strategies_without_baseline_raises_error(self, mock_objective_target): """Test that empty strategies without include_baseline raises ValueError.""" scenario = ConcreteScenario( name="No Baseline Test", @@ -852,9 +804,7 @@ async def test_empty_strategies_without_baseline_raises_error( ) @pytest.mark.asyncio - async def test_standalone_baseline_uses_dataset_config_seeds( - self, mock_objective_target - ): + async def test_standalone_baseline_uses_dataset_config_seeds(self, mock_objective_target): """Test that standalone baseline uses seed groups from dataset_config.""" from pyrit.models import SeedAttackGroup, SeedObjective From 0bcde47e477be58912af1076aa2e1274da268d8d Mon Sep 17 00:00:00 2001 From: Sydney Lister Date: Thu, 22 Jan 2026 17:55:41 -0500 Subject: [PATCH 5/6] Refactor baseline support per PR feedback: remove allow_empty param and consolidate methods --- pyrit/scenario/core/scenario.py | 238 ++++++++++++++--------- pyrit/scenario/core/scenario_strategy.py | 63 ++++-- tests/unit/scenarios/test_scenario.py | 122 ++++++++---- 3 files changed, 274 insertions(+), 149 deletions(-) diff --git 
a/pyrit/scenario/core/scenario.py b/pyrit/scenario/core/scenario.py index a4a3dccfc..7cc2b1f96 100644 --- a/pyrit/scenario/core/scenario.py +++ b/pyrit/scenario/core/scenario.py @@ -79,7 +79,9 @@ def __init__( with whitespace normalized for display. """ # Use the class docstring with normalized whitespace as description - description = " ".join(self.__class__.__doc__.split()) if self.__class__.__doc__ else "" + description = ( + " ".join(self.__class__.__doc__.split()) if self.__class__.__doc__ else "" + ) self._identifier = ScenarioIdentifier( name=type(self).__name__, scenario_version=version, description=description @@ -101,7 +103,9 @@ def __init__( self._name = name self._memory = CentralMemory.get_memory_instance() self._atomic_attacks: List[AtomicAttack] = [] - self._scenario_result_id: Optional[str] = str(scenario_result_id) if scenario_result_id else None + self._scenario_result_id: Optional[str] = ( + str(scenario_result_id) if scenario_result_id else None + ) self._result_lock = asyncio.Lock() self._include_baseline = include_default_baseline @@ -173,7 +177,9 @@ async def initialize_async( self, *, objective_target: PromptTarget = REQUIRED_VALUE, # type: ignore - scenario_strategies: Optional[Sequence[ScenarioStrategy | ScenarioCompositeStrategy]] = None, + scenario_strategies: Optional[ + Sequence[ScenarioStrategy | ScenarioCompositeStrategy] + ] = None, dataset_config: Optional[DatasetConfiguration] = None, max_concurrency: int = 10, max_retries: int = 0, @@ -222,7 +228,9 @@ async def initialize_async( self._objective_target = objective_target self._objective_target_identifier = objective_target.get_identifier() self._dataset_config_provided = dataset_config is not None - self._dataset_config = dataset_config if dataset_config else self.default_dataset_config() + self._dataset_config = ( + dataset_config if dataset_config else self.default_dataset_config() + ) self._max_concurrency = max_concurrency self._max_retries = max_retries self._memory_labels = 
memory_labels or {} @@ -232,28 +240,25 @@ async def initialize_async( self._scenario_composites = self._strategy_class.prepare_scenario_strategies( scenario_strategies, default_aggregate=self.get_default_strategy(), - allow_empty=self._include_baseline, ) self._atomic_attacks = await self._get_atomic_attacks_async() if self._include_baseline: - if self._atomic_attacks: - # Derive baseline from first attack - baseline_attack = self._get_baseline_from_first_attack() - else: - # No atomic attacks - create standalone baseline from dataset - baseline_attack = self._create_standalone_baseline() + baseline_attack = self._get_baseline() self._atomic_attacks.insert(0, baseline_attack) # Store original objectives for each atomic attack (before any mutations during execution) self._original_objectives_map = { - atomic_attack.atomic_attack_name: tuple(atomic_attack.objectives) for atomic_attack in self._atomic_attacks + atomic_attack.atomic_attack_name: tuple(atomic_attack.objectives) + for atomic_attack in self._atomic_attacks } # Check if we're resuming an existing scenario if self._scenario_result_id: - existing_results = self._memory.get_scenario_results(scenario_result_ids=[self._scenario_result_id]) + existing_results = self._memory.get_scenario_results( + scenario_result_ids=[self._scenario_result_id] + ) if existing_results: existing_result = existing_results[0] @@ -272,7 +277,8 @@ async def initialize_async( # Create new scenario result attack_results: Dict[str, List[AttackResult]] = { - atomic_attack.atomic_attack_name: [] for atomic_attack in self._atomic_attacks + atomic_attack.atomic_attack_name: [] + for atomic_attack in self._atomic_attacks } result = ScenarioResult( @@ -289,34 +295,21 @@ async def initialize_async( self._scenario_result_id = str(result.id) logger.info(f"Created new scenario result with ID: {self._scenario_result_id}") - def _get_baseline_from_first_attack(self) -> AtomicAttack: + def _get_baseline(self) -> AtomicAttack: """ Get a baseline 
AtomicAttack, which simply sends all the objectives without any modifications. + If other atomic attacks exist, derives baseline data from the first attack. + Otherwise, creates a standalone baseline from the dataset configuration and scenario settings. + Returns: AtomicAttack: The baseline AtomicAttack instance. Raises: - ValueError: If no atomic attacks are available to derive baseline from. + ValueError: If required data (seed_groups, objective_target, attack_scoring_config) + is not available. """ - if not self._atomic_attacks or len(self._atomic_attacks) == 0: - raise ValueError("No atomic attacks available to derive baseline from.") - - first_attack = self._atomic_attacks[0] - - # Copy seed_groups, scoring, target from the first attack - seed_groups = first_attack.seed_groups - attack_scoring_config = first_attack._attack.get_attack_scoring_config() - objective_target = first_attack._attack.get_objective_target() - - if not seed_groups or len(seed_groups) == 0: - raise ValueError("First atomic attack must have seed_groups to create baseline.") - - if not objective_target: - raise ValueError("Objective target is required to create baseline attack.") - - if not attack_scoring_config: - raise ValueError("Attack scoring config is required to create baseline attack.") + seed_groups, attack_scoring_config, objective_target = self._get_baseline_data() # Create baseline attack with no converters attack = PromptSendingAttack( @@ -331,57 +324,64 @@ def _get_baseline_from_first_attack(self) -> AtomicAttack: memory_labels=self._memory_labels, ) - def _create_standalone_baseline(self) -> AtomicAttack: + def _get_baseline_data(self): """ - Create a standalone baseline AtomicAttack when no other atomic attacks exist. + Get the data needed to create a baseline attack. - This method is used for baseline-only execution where no attack strategies are specified - but include_baseline=True. 
It creates the baseline directly from the dataset configuration - and scenario-level settings. + Returns either the first attack's data or the scenario-level data + depending on whether other atomic attacks exist. Returns: - AtomicAttack: The baseline AtomicAttack instance. + Tuple containing (seed_groups, attack_scoring_config, objective_target) Raises: - ValueError: If objective_target, dataset_config, or objective_scorer is not set. + ValueError: If required data is not available. """ - if not self._objective_target: - raise ValueError("Objective target is required to create standalone baseline attack.") - - if not self._dataset_config: - raise ValueError("Dataset config is required to create standalone baseline attack.") - - if not self._objective_scorer: - raise ValueError("Objective scorer is required to create standalone baseline attack.") - - # Get seed groups from the dataset configuration - seed_groups = self._dataset_config.get_all_seed_attack_groups() + if self._atomic_attacks and len(self._atomic_attacks) > 0: + # Derive from first attack + first_attack = self._atomic_attacks[0] + seed_groups = first_attack.seed_groups + attack_scoring_config = first_attack._attack.get_attack_scoring_config() + objective_target = first_attack._attack.get_objective_target() + else: + # Create from scenario-level settings + if not self._objective_target: + raise ValueError( + "Objective target is required to create baseline attack." + ) + if not self._dataset_config: + raise ValueError( + "Dataset config is required to create baseline attack." + ) + if not self._objective_scorer: + raise ValueError( + "Objective scorer is required to create baseline attack." 
+ ) - if not seed_groups or len(seed_groups) == 0: - raise ValueError("Dataset config must have seed groups to create baseline.") + seed_groups = self._dataset_config.get_all_seed_attack_groups() + objective_target = self._objective_target - # Import here to avoid circular imports - from typing import cast + # Import here to avoid circular imports + from typing import cast - from pyrit.executor.attack.core.attack_config import AttackScoringConfig - from pyrit.score import TrueFalseScorer + from pyrit.executor.attack.core.attack_config import AttackScoringConfig + from pyrit.score import TrueFalseScorer - # Create scoring config from the scenario's objective scorer - # Note: Scenarios require TrueFalseScorer for attack scoring - attack_scoring_config = AttackScoringConfig(objective_scorer=cast(TrueFalseScorer, self._objective_scorer)) + attack_scoring_config = AttackScoringConfig( + objective_scorer=cast(TrueFalseScorer, self._objective_scorer) + ) - # Create baseline attack with no converters - attack = PromptSendingAttack( - objective_target=self._objective_target, - attack_scoring_config=attack_scoring_config, - ) + # Validate required data + if not seed_groups or len(seed_groups) == 0: + raise ValueError("Seed groups are required to create baseline attack.") + if not objective_target: + raise ValueError("Objective target is required to create baseline attack.") + if not attack_scoring_config: + raise ValueError( + "Attack scoring config is required to create baseline attack." 
+ ) - return AtomicAttack( - atomic_attack_name="baseline", - attack=attack, - seed_groups=seed_groups, - memory_labels=self._memory_labels, - ) + return seed_groups, attack_scoring_config, objective_target def _raise_dataset_exception(self) -> None: error_msg = textwrap.dedent( @@ -432,7 +432,9 @@ def _validate_stored_scenario(self, *, stored_result: ScenarioResult) -> bool: ) return True - def _get_completed_objectives_for_attack(self, *, atomic_attack_name: str) -> Set[str]: + def _get_completed_objectives_for_attack( + self, *, atomic_attack_name: str + ) -> Set[str]: """ Get the set of objectives that have already been completed for a specific atomic attack. @@ -449,14 +451,17 @@ def _get_completed_objectives_for_attack(self, *, atomic_attack_name: str) -> Se try: # Retrieve the scenario result from memory - scenario_results = self._memory.get_scenario_results(scenario_result_ids=[self._scenario_result_id]) + scenario_results = self._memory.get_scenario_results( + scenario_result_ids=[self._scenario_result_id] + ) if scenario_results: scenario_result = scenario_results[0] # Get completed objectives for this atomic attack name if atomic_attack_name in scenario_result.attack_results: completed_objectives = { - result.objective for result in scenario_result.attack_results[atomic_attack_name] + result.objective + for result in scenario_result.attack_results[atomic_attack_name] } except Exception as e: logger.warning( @@ -488,10 +493,14 @@ async def _get_remaining_atomic_attacks_async(self) -> List[AtomicAttack]: ) # Get ORIGINAL objectives (before any mutations) from stored map - original_objectives = self._original_objectives_map.get(atomic_attack.atomic_attack_name, ()) + original_objectives = self._original_objectives_map.get( + atomic_attack.atomic_attack_name, () + ) # Calculate remaining objectives - remaining_objectives = [obj for obj in original_objectives if obj not in completed_objectives] + remaining_objectives = [ + obj for obj in original_objectives 
if obj not in completed_objectives + ] if remaining_objectives: # If there are remaining objectives, update the atomic attack @@ -501,7 +510,9 @@ async def _get_remaining_atomic_attacks_async(self) -> List[AtomicAttack]: f"{len(remaining_objectives)}/{len(original_objectives)} objectives remaining" ) # Update the objectives for this atomic attack to only include remaining ones - atomic_attack.filter_seed_groups_by_objectives(remaining_objectives=remaining_objectives) + atomic_attack.filter_seed_groups_by_objectives( + remaining_objectives=remaining_objectives + ) remaining_attacks.append(atomic_attack) else: @@ -524,7 +535,9 @@ async def _update_scenario_result_async( attack_results (List[AttackResult]): The list of new attack results to add. """ if not self._scenario_result_id: - logger.warning("Cannot update scenario result: no scenario result ID available") + logger.warning( + "Cannot update scenario result: no scenario result ID available" + ) return async with self._result_lock: @@ -588,7 +601,9 @@ async def run_async(self) -> ScenarioResult: ) if not self._scenario_result_id: - raise ValueError("Scenario not properly initialized. Call await scenario.initialize_async() first.") + raise ValueError( + "Scenario not properly initialized. Call await scenario.initialize_async() first." 
+ ) # Type narrowing: create local variable that type checker knows is non-None scenario_result_id: str = self._scenario_result_id @@ -603,8 +618,14 @@ async def run_async(self) -> ScenarioResult: last_exception = e # Get current scenario to check number of tries - scenario_results = self._memory.get_scenario_results(scenario_result_ids=[scenario_result_id]) - current_tries = scenario_results[0].number_tries if scenario_results else retry_attempt + 1 + scenario_results = self._memory.get_scenario_results( + scenario_result_ids=[scenario_result_id] + ) + current_tries = ( + scenario_results[0].number_tries + if scenario_results + else retry_attempt + 1 + ) # Check if we have more retries available remaining_retries = self._max_retries - retry_attempt @@ -629,7 +650,9 @@ async def run_async(self) -> ScenarioResult: # This should never be reached, but just in case if last_exception: raise last_exception - raise RuntimeError(f"Scenario '{self._name}' completed unexpectedly without result") + raise RuntimeError( + f"Scenario '{self._name}' completed unexpectedly without result" + ) async def _execute_scenario_async(self) -> ScenarioResult: """ @@ -647,7 +670,9 @@ async def _execute_scenario_async(self) -> ScenarioResult: ValueError: If a lookup for a scenario for a given ID fails. ValueError: If atomic attack execution fails. 
""" - logger.info(f"Starting scenario '{self._name}' execution with {len(self._atomic_attacks)} atomic attacks") + logger.info( + f"Starting scenario '{self._name}' execution with {len(self._atomic_attacks)} atomic attacks" + ) # Type narrowing: _scenario_result_id is guaranteed to be non-None at this point # (verified in run_async before calling this method) @@ -655,13 +680,17 @@ async def _execute_scenario_async(self) -> ScenarioResult: scenario_result_id: str = self._scenario_result_id # Increment number_tries at the start of each run - scenario_results = self._memory.get_scenario_results(scenario_result_ids=[scenario_result_id]) + scenario_results = self._memory.get_scenario_results( + scenario_result_ids=[scenario_result_id] + ) if scenario_results: current_scenario = scenario_results[0] current_scenario.number_tries += 1 entry = ScenarioResultEntry(entry=current_scenario) self._memory._update_entry(entry) - logger.info(f"Scenario '{self._name}' attempt #{current_scenario.number_tries}") + logger.info( + f"Scenario '{self._name}' attempt #{current_scenario.number_tries}" + ) else: raise ValueError(f"Scenario result with ID {scenario_result_id} not found") @@ -669,17 +698,23 @@ async def _execute_scenario_async(self) -> ScenarioResult: remaining_attacks = await self._get_remaining_atomic_attacks_async() if not remaining_attacks: - logger.info(f"Scenario '{self._name}' has no remaining objectives to execute") + logger.info( + f"Scenario '{self._name}' has no remaining objectives to execute" + ) # Mark scenario as completed self._memory.update_scenario_run_state( scenario_result_id=scenario_result_id, scenario_run_state="COMPLETED" ) # Retrieve and return the current scenario result - scenario_results = self._memory.get_scenario_results(scenario_result_ids=[scenario_result_id]) + scenario_results = self._memory.get_scenario_results( + scenario_result_ids=[scenario_result_id] + ) if scenario_results: return scenario_results[0] else: - raise ValueError(f"Scenario 
result with ID {scenario_result_id} not found") + raise ValueError( + f"Scenario result with ID {scenario_result_id} not found" + ) logger.info( f"Scenario '{self._name}' has {len(remaining_attacks)} atomic attacks " @@ -687,7 +722,9 @@ async def _execute_scenario_async(self) -> ScenarioResult: ) # Mark scenario as in progress - self._memory.update_scenario_run_state(scenario_result_id=scenario_result_id, scenario_run_state="IN_PROGRESS") + self._memory.update_scenario_run_state( + scenario_result_id=scenario_result_id, scenario_run_state="IN_PROGRESS" + ) # Calculate starting index based on completed attacks completed_count = len(self._atomic_attacks) - len(remaining_attacks) @@ -734,7 +771,9 @@ async def _execute_scenario_async(self) -> ScenarioResult: # Log details of each incomplete objective for obj, exc in atomic_results.incomplete_objectives: - logger.error(f" Incomplete objective '{obj[:50]}...': {str(exc)}") + logger.error( + f" Incomplete objective '{obj[:50]}...': {str(exc)}" + ) # Mark scenario as failed self._memory.update_scenario_run_state( @@ -762,8 +801,13 @@ async def _execute_scenario_async(self) -> ScenarioResult: ) # Mark scenario as failed if not already done - scenario_results = self._memory.get_scenario_results(scenario_result_ids=[scenario_result_id]) - if scenario_results and scenario_results[0].scenario_run_state != "FAILED": + scenario_results = self._memory.get_scenario_results( + scenario_result_ids=[scenario_result_id] + ) + if ( + scenario_results + and scenario_results[0].scenario_run_state != "FAILED" + ): self._memory.update_scenario_run_state( scenario_result_id=scenario_result_id, scenario_run_state="FAILED", @@ -779,9 +823,13 @@ async def _execute_scenario_async(self) -> ScenarioResult: ) # Retrieve and return final scenario result - scenario_results = self._memory.get_scenario_results(scenario_result_ids=[scenario_result_id]) + scenario_results = self._memory.get_scenario_results( + scenario_result_ids=[scenario_result_id] + ) 
if not scenario_results: - raise ValueError(f"Scenario result with ID {self._scenario_result_id} not found") + raise ValueError( + f"Scenario result with ID {self._scenario_result_id} not found" + ) return scenario_results[0] diff --git a/pyrit/scenario/core/scenario_strategy.py b/pyrit/scenario/core/scenario_strategy.py index 964c73a3e..769d9ea88 100644 --- a/pyrit/scenario/core/scenario_strategy.py +++ b/pyrit/scenario/core/scenario_strategy.py @@ -108,7 +108,11 @@ def get_strategies_by_tag(cls: type[T], tag: str) -> Set[T]: any aggregate markers. """ aggregate_tags = cls.get_aggregate_tags() - return {strategy for strategy in cls if tag in strategy.tags and strategy.value not in aggregate_tags} + return { + strategy + for strategy in cls + if tag in strategy.tags and strategy.value not in aggregate_tags + } @classmethod def get_all_strategies(cls: type[T]) -> list[T]: @@ -173,12 +177,17 @@ def normalize_strategies(cls: type[T], strategies: Set[T]) -> Set[T]: # Find aggregate tags in the input and expand them aggregate_tags = cls.get_aggregate_tags() aggregates_to_expand = { - tag for strategy in strategies if strategy.value in aggregate_tags for tag in strategy.tags + tag + for strategy in strategies + if strategy.value in aggregate_tags + for tag in strategy.tags } for aggregate_tag in aggregates_to_expand: # Remove the aggregate marker itself - aggregate_marker = next((s for s in normalized_strategies if s.value == aggregate_tag), None) + aggregate_marker = next( + (s for s in normalized_strategies if s.value == aggregate_tag), None + ) if aggregate_marker: normalized_strategies.remove(aggregate_marker) @@ -197,7 +206,6 @@ def prepare_scenario_strategies( strategies: Sequence[T | "ScenarioCompositeStrategy"] | None = None, *, default_aggregate: T | None = None, - allow_empty: bool = False, ) -> List["ScenarioCompositeStrategy"]: """ Prepare and normalize scenario strategies for use in a scenario. 
@@ -214,22 +222,18 @@ def prepare_scenario_strategies( strategies (Sequence[T | ScenarioCompositeStrategy] | None): The strategies to prepare. Can be a mix of bare strategy enums and composite strategies. If None, uses default_aggregate to determine defaults. - If an empty sequence, behavior depends on allow_empty parameter. + If an empty sequence, returns an empty list (useful for baseline-only execution). default_aggregate (T | None): The aggregate strategy to use when strategies is None. Common values: MyStrategy.ALL, MyStrategy.EASY. If None when strategies is None, raises ValueError. - allow_empty (bool): If True, allows an empty strategies list to be returned when - an empty sequence is explicitly provided. This is useful for baseline-only - execution where no attack strategies are needed. Defaults to False. Returns: List[ScenarioCompositeStrategy]: Normalized list of composite strategies ready for use. - May be empty if allow_empty=True and an empty sequence was provided. + May be empty if an empty sequence was explicitly provided. Raises: ValueError: If strategies is None and default_aggregate is None, or if compositions - are invalid according to validate_composition(), or if strategies is empty - and allow_empty is False. + are invalid according to validate_composition(). 
""" # Handle None input with default aggregate if strategies is None: @@ -242,7 +246,10 @@ def prepare_scenario_strategies( # Expand the default aggregate into concrete strategies expanded = cls.normalize_strategies({default_aggregate}) # Wrap each in a ScenarioCompositeStrategy - composite_strategies = [ScenarioCompositeStrategy(strategies=[strategy]) for strategy in expanded] + composite_strategies = [ + ScenarioCompositeStrategy(strategies=[strategy]) + for strategy in expanded + ] else: # Process the provided strategies composite_strategies = [] @@ -252,15 +259,17 @@ def prepare_scenario_strategies( composite_strategies.append(item) elif isinstance(item, cls): # Bare strategy enum - wrap it in a composite - composite_strategies.append(ScenarioCompositeStrategy(strategies=[item])) + composite_strategies.append( + ScenarioCompositeStrategy(strategies=[item]) + ) else: # Not our strategy type - skip or could raise error # For now, skip to allow flexibility pass - # Allow empty list if explicitly requested (for baseline-only execution) + # Allow empty list if explicitly provided (for baseline-only execution) if not composite_strategies: - if allow_empty and strategies is not None and len(strategies) == 0: + if strategies is not None and len(strategies) == 0: return [] raise ValueError( f"No valid {cls.__name__} strategies provided. " @@ -268,7 +277,9 @@ def prepare_scenario_strategies( ) # Normalize compositions (expands aggregates, validates compositions) - normalized = ScenarioCompositeStrategy.normalize_compositions(composite_strategies, strategy_type=cls) + normalized = ScenarioCompositeStrategy.normalize_compositions( + composite_strategies, strategy_type=cls + ) return normalized @@ -425,7 +436,9 @@ def extract_single_strategy_values( ValueError: If any composite contains multiple strategies. 
""" # Check that all composites are single-strategy - multi_strategy_composites = [comp for comp in composites if not comp.is_single_strategy] + multi_strategy_composites = [ + comp for comp in composites if not comp.is_single_strategy + ] if multi_strategy_composites: composite_names = [comp.name for comp in multi_strategy_composites] raise ValueError( @@ -528,14 +541,20 @@ def normalize_compositions( raise ValueError("Empty compositions are not allowed") # Filter to only strategies of the specified type - typed_strategies = [s for s in composite.strategies if isinstance(s, strategy_type)] + typed_strategies = [ + s for s in composite.strategies if isinstance(s, strategy_type) + ] if not typed_strategies: # No strategies of this type - skip continue # Check if composition contains any aggregates - aggregates_in_composition = [s for s in typed_strategies if s.value in aggregate_tags] - concretes_in_composition = [s for s in typed_strategies if s.value not in aggregate_tags] + aggregates_in_composition = [ + s for s in typed_strategies if s.value in aggregate_tags + ] + concretes_in_composition = [ + s for s in typed_strategies if s.value not in aggregate_tags + ] # Error if mixing aggregates with concrete strategies if aggregates_in_composition and concretes_in_composition: @@ -559,7 +578,9 @@ def normalize_compositions( expanded = strategy_type.normalize_strategies({aggregate}) # Each expanded strategy becomes its own composition for strategy in expanded: - normalized_compositions.append(ScenarioCompositeStrategy(strategies=[strategy])) + normalized_compositions.append( + ScenarioCompositeStrategy(strategies=[strategy]) + ) else: # Concrete composition - validate and preserve as-is strategy_type.validate_composition(typed_strategies) diff --git a/tests/unit/scenarios/test_scenario.py b/tests/unit/scenarios/test_scenario.py index 3d0ff7823..fe87135bd 100644 --- a/tests/unit/scenarios/test_scenario.py +++ b/tests/unit/scenarios/test_scenario.py @@ -27,7 +27,9 @@ def 
create_mock_run_async(attack_results): async def mock_run_async(*args, **kwargs): # Save results to memory (mimics what real attacks do) save_attack_results_to_memory(attack_results) - return AttackExecutorResult(completed_results=attack_results, incomplete_objectives=[]) + return AttackExecutorResult( + completed_results=attack_results, incomplete_objectives=[] + ) return AsyncMock(side_effect=mock_run_async) @@ -209,7 +211,9 @@ class TestScenarioInitialization2: """Tests for Scenario initialize_async method.""" @pytest.mark.asyncio - async def test_initialize_async_populates_atomic_attacks(self, mock_atomic_attacks, mock_objective_target): + async def test_initialize_async_populates_atomic_attacks( + self, mock_atomic_attacks, mock_objective_target + ): """Test that initialize_async populates atomic attacks.""" scenario = ConcreteScenario( name="Test Scenario", @@ -259,7 +263,9 @@ async def test_initialize_async_sets_max_retries(self, mock_objective_target): version=1, ) - await scenario.initialize_async(objective_target=mock_objective_target, max_retries=3) + await scenario.initialize_async( + objective_target=mock_objective_target, max_retries=3 + ) assert scenario._max_retries == 3 @@ -271,7 +277,9 @@ async def test_initialize_async_sets_max_concurrency(self, mock_objective_target version=1, ) - await scenario.initialize_async(objective_target=mock_objective_target, max_concurrency=5) + await scenario.initialize_async( + objective_target=mock_objective_target, max_concurrency=5 + ) assert scenario._max_concurrency == 5 @@ -284,7 +292,9 @@ async def test_initialize_async_sets_memory_labels(self, mock_objective_target): version=1, ) - await scenario.initialize_async(objective_target=mock_objective_target, memory_labels=labels) + await scenario.initialize_async( + objective_target=mock_objective_target, memory_labels=labels + ) assert scenario._memory_labels == labels @@ -308,7 +318,9 @@ class TestScenarioExecution: """Tests for Scenario execution methods.""" 
@pytest.mark.asyncio - async def test_run_async_executes_all_runs(self, mock_atomic_attacks, sample_attack_results, mock_objective_target): + async def test_run_async_executes_all_runs( + self, mock_atomic_attacks, sample_attack_results, mock_objective_target + ): """Test that run_async executes all atomic attacks sequentially.""" # Configure each run to return different results for i, run in enumerate(mock_atomic_attacks): @@ -329,7 +341,9 @@ async def test_run_async_executes_all_runs(self, mock_atomic_attacks, sample_att # Verify all runs were executed with correct concurrency assert len(result.attack_results) == 3 for run in mock_atomic_attacks: - run.run_async.assert_called_once_with(max_concurrency=10, return_partial_on_failure=True) + run.run_async.assert_called_once_with( + max_concurrency=10, return_partial_on_failure=True + ) # Verify results are aggregated correctly by atomic attack name assert "attack_run_1" in result.attack_results @@ -352,13 +366,17 @@ async def test_run_async_with_custom_concurrency( version=1, atomic_attacks_to_return=mock_atomic_attacks, ) - await scenario.initialize_async(objective_target=mock_objective_target, max_concurrency=5) + await scenario.initialize_async( + objective_target=mock_objective_target, max_concurrency=5 + ) result = await scenario.run_async() # Verify max_concurrency was passed to each run for run in mock_atomic_attacks: - run.run_async.assert_called_once_with(max_concurrency=5, return_partial_on_failure=True) + run.run_async.assert_called_once_with( + max_concurrency=5, return_partial_on_failure=True + ) # Verify result structure assert isinstance(result, ScenarioResult) @@ -370,9 +388,15 @@ async def test_run_async_aggregates_multiple_results( ): """Test that results from multiple atomic attacks are properly aggregated.""" # Configure runs to return different numbers of results - mock_atomic_attacks[0].run_async = create_mock_run_async(sample_attack_results[0:2]) - mock_atomic_attacks[1].run_async = 
create_mock_run_async(sample_attack_results[2:4]) - mock_atomic_attacks[2].run_async = create_mock_run_async(sample_attack_results[4:5]) + mock_atomic_attacks[0].run_async = create_mock_run_async( + sample_attack_results[0:2] + ) + mock_atomic_attacks[1].run_async = create_mock_run_async( + sample_attack_results[2:4] + ) + mock_atomic_attacks[2].run_async = create_mock_run_async( + sample_attack_results[4:5] + ) scenario = ConcreteScenario( name="Test Scenario", @@ -391,11 +415,19 @@ async def test_run_async_aggregates_multiple_results( assert len(result.attack_results["attack_run_3"]) == 1 @pytest.mark.asyncio - async def test_run_async_stops_on_error(self, mock_atomic_attacks, sample_attack_results, mock_objective_target): + async def test_run_async_stops_on_error( + self, mock_atomic_attacks, sample_attack_results, mock_objective_target + ): """Test that execution stops when an atomic attack fails.""" - mock_atomic_attacks[0].run_async = create_mock_run_async([sample_attack_results[0]]) - mock_atomic_attacks[1].run_async = AsyncMock(side_effect=Exception("Test error")) - mock_atomic_attacks[2].run_async = create_mock_run_async([sample_attack_results[2]]) + mock_atomic_attacks[0].run_async = create_mock_run_async( + [sample_attack_results[0]] + ) + mock_atomic_attacks[1].run_async = AsyncMock( + side_effect=Exception("Test error") + ) + mock_atomic_attacks[2].run_async = create_mock_run_async( + [sample_attack_results[2]] + ) scenario = ConcreteScenario( name="Test Scenario", @@ -422,7 +454,9 @@ async def test_run_async_fails_without_initialization(self, mock_objective_targe version=1, ) - with pytest.raises(ValueError, match="Cannot run scenario with no atomic attacks"): + with pytest.raises( + ValueError, match="Cannot run scenario with no atomic attacks" + ): await scenario.run_async() @pytest.mark.asyncio @@ -468,7 +502,9 @@ def test_name_property(self, mock_objective_target): assert scenario.name == "My Test Scenario" @pytest.mark.asyncio - async def 
test_atomic_attack_count_property(self, mock_atomic_attacks, mock_objective_target): + async def test_atomic_attack_count_property( + self, mock_atomic_attacks, mock_objective_target + ): """Test that atomic_attack_count returns the correct count.""" scenario = ConcreteScenario( name="Test Scenario", @@ -483,7 +519,9 @@ async def test_atomic_attack_count_property(self, mock_atomic_attacks, mock_obje assert scenario.atomic_attack_count == 3 @pytest.mark.asyncio - async def test_atomic_attack_count_with_different_sizes(self, mock_objective_target): + async def test_atomic_attack_count_with_different_sizes( + self, mock_objective_target + ): """Test atomic_attack_count with different numbers of atomic attacks.""" # Create mock attack strategy mock_attack = MagicMock() @@ -641,7 +679,9 @@ def test_scenario_identifier_with_custom_pyrit_version(self): def test_scenario_identifier_with_init_data(self): """Test ScenarioIdentifier with init_data.""" init_data = {"param1": "value1", "param2": 42} - identifier = ScenarioIdentifier(name="TestScenario", scenario_version=1, init_data=init_data) + identifier = ScenarioIdentifier( + name="TestScenario", scenario_version=1, init_data=init_data + ) assert identifier.init_data == init_data @@ -719,7 +759,9 @@ class TestScenarioBaselineOnlyExecution: """Tests for baseline-only execution (empty strategies with include_baseline=True).""" @pytest.mark.asyncio - async def test_initialize_async_with_empty_strategies_and_baseline(self, mock_objective_target): + async def test_initialize_async_with_empty_strategies_and_baseline( + self, mock_objective_target + ): """Test that baseline-only execution works when include_baseline=True and strategies is empty.""" from pyrit.models import SeedAttackGroup, SeedObjective @@ -749,7 +791,9 @@ async def test_initialize_async_with_empty_strategies_and_baseline(self, mock_ob assert scenario._atomic_attacks[0].atomic_attack_name == "baseline" @pytest.mark.asyncio - async def 
test_baseline_only_execution_runs_successfully(self, mock_objective_target, sample_attack_results): + async def test_baseline_only_execution_runs_successfully( + self, mock_objective_target, sample_attack_results + ): """Test that baseline-only scenario can run successfully.""" from pyrit.models import SeedAttackGroup, SeedObjective @@ -774,7 +818,9 @@ async def test_baseline_only_execution_runs_successfully(self, mock_objective_ta ) # Mock the baseline attack's run_async - scenario._atomic_attacks[0].run_async = create_mock_run_async([sample_attack_results[0]]) + scenario._atomic_attacks[0].run_async = create_mock_run_async( + [sample_attack_results[0]] + ) # Run the scenario result = await scenario.run_async() @@ -785,8 +831,10 @@ async def test_baseline_only_execution_runs_successfully(self, mock_objective_ta assert len(result.attack_results["baseline"]) == 1 @pytest.mark.asyncio - async def test_empty_strategies_without_baseline_raises_error(self, mock_objective_target): - """Test that empty strategies without include_baseline raises ValueError.""" + async def test_empty_strategies_without_baseline_allows_initialization( + self, mock_objective_target + ): + """Test that empty strategies without include_baseline allows initialization but fails at run time.""" scenario = ConcreteScenario( name="No Baseline Test", version=1, @@ -795,16 +843,24 @@ async def test_empty_strategies_without_baseline_raises_error(self, mock_objecti mock_dataset_config = MagicMock(spec=DatasetConfiguration) - # Should raise ValueError because empty strategies without baseline is not allowed - with pytest.raises(ValueError, match="No valid .* strategies provided"): - await scenario.initialize_async( - objective_target=mock_objective_target, - scenario_strategies=[], # Empty list without baseline - dataset_config=mock_dataset_config, - ) + # Empty strategies are now always allowed during initialization + # (no allow_empty parameter required) + await scenario.initialize_async( + 
objective_target=mock_objective_target, + scenario_strategies=[], # Empty list without baseline + dataset_config=mock_dataset_config, + ) + + # But running should fail because there are no atomic attacks + with pytest.raises( + ValueError, match="Cannot run scenario with no atomic attacks" + ): + await scenario.run_async() @pytest.mark.asyncio - async def test_standalone_baseline_uses_dataset_config_seeds(self, mock_objective_target): + async def test_standalone_baseline_uses_dataset_config_seeds( + self, mock_objective_target + ): """Test that standalone baseline uses seed groups from dataset_config.""" from pyrit.models import SeedAttackGroup, SeedObjective From ba78d861d11a84173c6d3cd1c0a67d85a749e7ef Mon Sep 17 00:00:00 2001 From: Sydney Lister Date: Thu, 22 Jan 2026 22:52:21 -0500 Subject: [PATCH 6/6] fix: Add type annotation and apply ruff formatting - Add return type annotation to _get_baseline_data() method - Apply ruff formatting to scenario files --- pyrit/scenario/core/scenario.py | 148 +++++++---------------- pyrit/scenario/core/scenario_strategy.py | 48 ++------ tests/unit/scenarios/test_scenario.py | 104 ++++------------ 3 files changed, 78 insertions(+), 222 deletions(-) diff --git a/pyrit/scenario/core/scenario.py b/pyrit/scenario/core/scenario.py index 7cc2b1f96..773fc3454 100644 --- a/pyrit/scenario/core/scenario.py +++ b/pyrit/scenario/core/scenario.py @@ -13,7 +13,7 @@ import textwrap import uuid from abc import ABC, abstractmethod -from typing import Dict, List, Optional, Sequence, Set, Type, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple, Type, Union from tqdm.auto import tqdm @@ -32,6 +32,10 @@ ) from pyrit.score import Scorer +if TYPE_CHECKING: + from pyrit.executor.attack.core.attack_config import AttackScoringConfig + from pyrit.models import SeedAttackGroup + logger = logging.getLogger(__name__) @@ -79,9 +83,7 @@ def __init__( with whitespace normalized for display. 
""" # Use the class docstring with normalized whitespace as description - description = ( - " ".join(self.__class__.__doc__.split()) if self.__class__.__doc__ else "" - ) + description = " ".join(self.__class__.__doc__.split()) if self.__class__.__doc__ else "" self._identifier = ScenarioIdentifier( name=type(self).__name__, scenario_version=version, description=description @@ -103,9 +105,7 @@ def __init__( self._name = name self._memory = CentralMemory.get_memory_instance() self._atomic_attacks: List[AtomicAttack] = [] - self._scenario_result_id: Optional[str] = ( - str(scenario_result_id) if scenario_result_id else None - ) + self._scenario_result_id: Optional[str] = str(scenario_result_id) if scenario_result_id else None self._result_lock = asyncio.Lock() self._include_baseline = include_default_baseline @@ -177,9 +177,7 @@ async def initialize_async( self, *, objective_target: PromptTarget = REQUIRED_VALUE, # type: ignore - scenario_strategies: Optional[ - Sequence[ScenarioStrategy | ScenarioCompositeStrategy] - ] = None, + scenario_strategies: Optional[Sequence[ScenarioStrategy | ScenarioCompositeStrategy]] = None, dataset_config: Optional[DatasetConfiguration] = None, max_concurrency: int = 10, max_retries: int = 0, @@ -228,9 +226,7 @@ async def initialize_async( self._objective_target = objective_target self._objective_target_identifier = objective_target.get_identifier() self._dataset_config_provided = dataset_config is not None - self._dataset_config = ( - dataset_config if dataset_config else self.default_dataset_config() - ) + self._dataset_config = dataset_config if dataset_config else self.default_dataset_config() self._max_concurrency = max_concurrency self._max_retries = max_retries self._memory_labels = memory_labels or {} @@ -250,15 +246,12 @@ async def initialize_async( # Store original objectives for each atomic attack (before any mutations during execution) self._original_objectives_map = { - atomic_attack.atomic_attack_name: 
tuple(atomic_attack.objectives) - for atomic_attack in self._atomic_attacks + atomic_attack.atomic_attack_name: tuple(atomic_attack.objectives) for atomic_attack in self._atomic_attacks } # Check if we're resuming an existing scenario if self._scenario_result_id: - existing_results = self._memory.get_scenario_results( - scenario_result_ids=[self._scenario_result_id] - ) + existing_results = self._memory.get_scenario_results(scenario_result_ids=[self._scenario_result_id]) if existing_results: existing_result = existing_results[0] @@ -277,8 +270,7 @@ async def initialize_async( # Create new scenario result attack_results: Dict[str, List[AttackResult]] = { - atomic_attack.atomic_attack_name: [] - for atomic_attack in self._atomic_attacks + atomic_attack.atomic_attack_name: [] for atomic_attack in self._atomic_attacks } result = ScenarioResult( @@ -324,7 +316,7 @@ def _get_baseline(self) -> AtomicAttack: memory_labels=self._memory_labels, ) - def _get_baseline_data(self): + def _get_baseline_data(self) -> Tuple[List["SeedAttackGroup"], "AttackScoringConfig", PromptTarget]: """ Get the data needed to create a baseline attack. @@ -346,17 +338,11 @@ def _get_baseline_data(self): else: # Create from scenario-level settings if not self._objective_target: - raise ValueError( - "Objective target is required to create baseline attack." - ) + raise ValueError("Objective target is required to create baseline attack.") if not self._dataset_config: - raise ValueError( - "Dataset config is required to create baseline attack." - ) + raise ValueError("Dataset config is required to create baseline attack.") if not self._objective_scorer: - raise ValueError( - "Objective scorer is required to create baseline attack." 
- ) + raise ValueError("Objective scorer is required to create baseline attack.") seed_groups = self._dataset_config.get_all_seed_attack_groups() objective_target = self._objective_target @@ -367,9 +353,7 @@ def _get_baseline_data(self): from pyrit.executor.attack.core.attack_config import AttackScoringConfig from pyrit.score import TrueFalseScorer - attack_scoring_config = AttackScoringConfig( - objective_scorer=cast(TrueFalseScorer, self._objective_scorer) - ) + attack_scoring_config = AttackScoringConfig(objective_scorer=cast(TrueFalseScorer, self._objective_scorer)) # Validate required data if not seed_groups or len(seed_groups) == 0: @@ -377,9 +361,7 @@ def _get_baseline_data(self): if not objective_target: raise ValueError("Objective target is required to create baseline attack.") if not attack_scoring_config: - raise ValueError( - "Attack scoring config is required to create baseline attack." - ) + raise ValueError("Attack scoring config is required to create baseline attack.") return seed_groups, attack_scoring_config, objective_target @@ -432,9 +414,7 @@ def _validate_stored_scenario(self, *, stored_result: ScenarioResult) -> bool: ) return True - def _get_completed_objectives_for_attack( - self, *, atomic_attack_name: str - ) -> Set[str]: + def _get_completed_objectives_for_attack(self, *, atomic_attack_name: str) -> Set[str]: """ Get the set of objectives that have already been completed for a specific atomic attack. 
@@ -451,17 +431,14 @@ def _get_completed_objectives_for_attack( try: # Retrieve the scenario result from memory - scenario_results = self._memory.get_scenario_results( - scenario_result_ids=[self._scenario_result_id] - ) + scenario_results = self._memory.get_scenario_results(scenario_result_ids=[self._scenario_result_id]) if scenario_results: scenario_result = scenario_results[0] # Get completed objectives for this atomic attack name if atomic_attack_name in scenario_result.attack_results: completed_objectives = { - result.objective - for result in scenario_result.attack_results[atomic_attack_name] + result.objective for result in scenario_result.attack_results[atomic_attack_name] } except Exception as e: logger.warning( @@ -493,14 +470,10 @@ async def _get_remaining_atomic_attacks_async(self) -> List[AtomicAttack]: ) # Get ORIGINAL objectives (before any mutations) from stored map - original_objectives = self._original_objectives_map.get( - atomic_attack.atomic_attack_name, () - ) + original_objectives = self._original_objectives_map.get(atomic_attack.atomic_attack_name, ()) # Calculate remaining objectives - remaining_objectives = [ - obj for obj in original_objectives if obj not in completed_objectives - ] + remaining_objectives = [obj for obj in original_objectives if obj not in completed_objectives] if remaining_objectives: # If there are remaining objectives, update the atomic attack @@ -510,9 +483,7 @@ async def _get_remaining_atomic_attacks_async(self) -> List[AtomicAttack]: f"{len(remaining_objectives)}/{len(original_objectives)} objectives remaining" ) # Update the objectives for this atomic attack to only include remaining ones - atomic_attack.filter_seed_groups_by_objectives( - remaining_objectives=remaining_objectives - ) + atomic_attack.filter_seed_groups_by_objectives(remaining_objectives=remaining_objectives) remaining_attacks.append(atomic_attack) else: @@ -535,9 +506,7 @@ async def _update_scenario_result_async( attack_results 
(List[AttackResult]): The list of new attack results to add. """ if not self._scenario_result_id: - logger.warning( - "Cannot update scenario result: no scenario result ID available" - ) + logger.warning("Cannot update scenario result: no scenario result ID available") return async with self._result_lock: @@ -601,9 +570,7 @@ async def run_async(self) -> ScenarioResult: ) if not self._scenario_result_id: - raise ValueError( - "Scenario not properly initialized. Call await scenario.initialize_async() first." - ) + raise ValueError("Scenario not properly initialized. Call await scenario.initialize_async() first.") # Type narrowing: create local variable that type checker knows is non-None scenario_result_id: str = self._scenario_result_id @@ -618,14 +585,8 @@ async def run_async(self) -> ScenarioResult: last_exception = e # Get current scenario to check number of tries - scenario_results = self._memory.get_scenario_results( - scenario_result_ids=[scenario_result_id] - ) - current_tries = ( - scenario_results[0].number_tries - if scenario_results - else retry_attempt + 1 - ) + scenario_results = self._memory.get_scenario_results(scenario_result_ids=[scenario_result_id]) + current_tries = scenario_results[0].number_tries if scenario_results else retry_attempt + 1 # Check if we have more retries available remaining_retries = self._max_retries - retry_attempt @@ -650,9 +611,7 @@ async def run_async(self) -> ScenarioResult: # This should never be reached, but just in case if last_exception: raise last_exception - raise RuntimeError( - f"Scenario '{self._name}' completed unexpectedly without result" - ) + raise RuntimeError(f"Scenario '{self._name}' completed unexpectedly without result") async def _execute_scenario_async(self) -> ScenarioResult: """ @@ -670,9 +629,7 @@ async def _execute_scenario_async(self) -> ScenarioResult: ValueError: If a lookup for a scenario for a given ID fails. ValueError: If atomic attack execution fails. 
""" - logger.info( - f"Starting scenario '{self._name}' execution with {len(self._atomic_attacks)} atomic attacks" - ) + logger.info(f"Starting scenario '{self._name}' execution with {len(self._atomic_attacks)} atomic attacks") # Type narrowing: _scenario_result_id is guaranteed to be non-None at this point # (verified in run_async before calling this method) @@ -680,17 +637,13 @@ async def _execute_scenario_async(self) -> ScenarioResult: scenario_result_id: str = self._scenario_result_id # Increment number_tries at the start of each run - scenario_results = self._memory.get_scenario_results( - scenario_result_ids=[scenario_result_id] - ) + scenario_results = self._memory.get_scenario_results(scenario_result_ids=[scenario_result_id]) if scenario_results: current_scenario = scenario_results[0] current_scenario.number_tries += 1 entry = ScenarioResultEntry(entry=current_scenario) self._memory._update_entry(entry) - logger.info( - f"Scenario '{self._name}' attempt #{current_scenario.number_tries}" - ) + logger.info(f"Scenario '{self._name}' attempt #{current_scenario.number_tries}") else: raise ValueError(f"Scenario result with ID {scenario_result_id} not found") @@ -698,23 +651,17 @@ async def _execute_scenario_async(self) -> ScenarioResult: remaining_attacks = await self._get_remaining_atomic_attacks_async() if not remaining_attacks: - logger.info( - f"Scenario '{self._name}' has no remaining objectives to execute" - ) + logger.info(f"Scenario '{self._name}' has no remaining objectives to execute") # Mark scenario as completed self._memory.update_scenario_run_state( scenario_result_id=scenario_result_id, scenario_run_state="COMPLETED" ) # Retrieve and return the current scenario result - scenario_results = self._memory.get_scenario_results( - scenario_result_ids=[scenario_result_id] - ) + scenario_results = self._memory.get_scenario_results(scenario_result_ids=[scenario_result_id]) if scenario_results: return scenario_results[0] else: - raise ValueError( - 
f"Scenario result with ID {scenario_result_id} not found" - ) + raise ValueError(f"Scenario result with ID {scenario_result_id} not found") logger.info( f"Scenario '{self._name}' has {len(remaining_attacks)} atomic attacks " @@ -722,9 +669,7 @@ async def _execute_scenario_async(self) -> ScenarioResult: ) # Mark scenario as in progress - self._memory.update_scenario_run_state( - scenario_result_id=scenario_result_id, scenario_run_state="IN_PROGRESS" - ) + self._memory.update_scenario_run_state(scenario_result_id=scenario_result_id, scenario_run_state="IN_PROGRESS") # Calculate starting index based on completed attacks completed_count = len(self._atomic_attacks) - len(remaining_attacks) @@ -771,9 +716,7 @@ async def _execute_scenario_async(self) -> ScenarioResult: # Log details of each incomplete objective for obj, exc in atomic_results.incomplete_objectives: - logger.error( - f" Incomplete objective '{obj[:50]}...': {str(exc)}" - ) + logger.error(f" Incomplete objective '{obj[:50]}...': {str(exc)}") # Mark scenario as failed self._memory.update_scenario_run_state( @@ -801,13 +744,8 @@ async def _execute_scenario_async(self) -> ScenarioResult: ) # Mark scenario as failed if not already done - scenario_results = self._memory.get_scenario_results( - scenario_result_ids=[scenario_result_id] - ) - if ( - scenario_results - and scenario_results[0].scenario_run_state != "FAILED" - ): + scenario_results = self._memory.get_scenario_results(scenario_result_ids=[scenario_result_id]) + if scenario_results and scenario_results[0].scenario_run_state != "FAILED": self._memory.update_scenario_run_state( scenario_result_id=scenario_result_id, scenario_run_state="FAILED", @@ -823,13 +761,9 @@ async def _execute_scenario_async(self) -> ScenarioResult: ) # Retrieve and return final scenario result - scenario_results = self._memory.get_scenario_results( - scenario_result_ids=[scenario_result_id] - ) + scenario_results = 
self._memory.get_scenario_results(scenario_result_ids=[scenario_result_id]) if not scenario_results: - raise ValueError( - f"Scenario result with ID {self._scenario_result_id} not found" - ) + raise ValueError(f"Scenario result with ID {self._scenario_result_id} not found") return scenario_results[0] diff --git a/pyrit/scenario/core/scenario_strategy.py b/pyrit/scenario/core/scenario_strategy.py index 769d9ea88..d1f1cdceb 100644 --- a/pyrit/scenario/core/scenario_strategy.py +++ b/pyrit/scenario/core/scenario_strategy.py @@ -108,11 +108,7 @@ def get_strategies_by_tag(cls: type[T], tag: str) -> Set[T]: any aggregate markers. """ aggregate_tags = cls.get_aggregate_tags() - return { - strategy - for strategy in cls - if tag in strategy.tags and strategy.value not in aggregate_tags - } + return {strategy for strategy in cls if tag in strategy.tags and strategy.value not in aggregate_tags} @classmethod def get_all_strategies(cls: type[T]) -> list[T]: @@ -177,17 +173,12 @@ def normalize_strategies(cls: type[T], strategies: Set[T]) -> Set[T]: # Find aggregate tags in the input and expand them aggregate_tags = cls.get_aggregate_tags() aggregates_to_expand = { - tag - for strategy in strategies - if strategy.value in aggregate_tags - for tag in strategy.tags + tag for strategy in strategies if strategy.value in aggregate_tags for tag in strategy.tags } for aggregate_tag in aggregates_to_expand: # Remove the aggregate marker itself - aggregate_marker = next( - (s for s in normalized_strategies if s.value == aggregate_tag), None - ) + aggregate_marker = next((s for s in normalized_strategies if s.value == aggregate_tag), None) if aggregate_marker: normalized_strategies.remove(aggregate_marker) @@ -246,10 +237,7 @@ def prepare_scenario_strategies( # Expand the default aggregate into concrete strategies expanded = cls.normalize_strategies({default_aggregate}) # Wrap each in a ScenarioCompositeStrategy - composite_strategies = [ - ScenarioCompositeStrategy(strategies=[strategy]) 
- for strategy in expanded - ] + composite_strategies = [ScenarioCompositeStrategy(strategies=[strategy]) for strategy in expanded] else: # Process the provided strategies composite_strategies = [] @@ -259,9 +247,7 @@ def prepare_scenario_strategies( composite_strategies.append(item) elif isinstance(item, cls): # Bare strategy enum - wrap it in a composite - composite_strategies.append( - ScenarioCompositeStrategy(strategies=[item]) - ) + composite_strategies.append(ScenarioCompositeStrategy(strategies=[item])) else: # Not our strategy type - skip or could raise error # For now, skip to allow flexibility @@ -277,9 +263,7 @@ def prepare_scenario_strategies( ) # Normalize compositions (expands aggregates, validates compositions) - normalized = ScenarioCompositeStrategy.normalize_compositions( - composite_strategies, strategy_type=cls - ) + normalized = ScenarioCompositeStrategy.normalize_compositions(composite_strategies, strategy_type=cls) return normalized @@ -436,9 +420,7 @@ def extract_single_strategy_values( ValueError: If any composite contains multiple strategies. 
""" # Check that all composites are single-strategy - multi_strategy_composites = [ - comp for comp in composites if not comp.is_single_strategy - ] + multi_strategy_composites = [comp for comp in composites if not comp.is_single_strategy] if multi_strategy_composites: composite_names = [comp.name for comp in multi_strategy_composites] raise ValueError( @@ -541,20 +523,14 @@ def normalize_compositions( raise ValueError("Empty compositions are not allowed") # Filter to only strategies of the specified type - typed_strategies = [ - s for s in composite.strategies if isinstance(s, strategy_type) - ] + typed_strategies = [s for s in composite.strategies if isinstance(s, strategy_type)] if not typed_strategies: # No strategies of this type - skip continue # Check if composition contains any aggregates - aggregates_in_composition = [ - s for s in typed_strategies if s.value in aggregate_tags - ] - concretes_in_composition = [ - s for s in typed_strategies if s.value not in aggregate_tags - ] + aggregates_in_composition = [s for s in typed_strategies if s.value in aggregate_tags] + concretes_in_composition = [s for s in typed_strategies if s.value not in aggregate_tags] # Error if mixing aggregates with concrete strategies if aggregates_in_composition and concretes_in_composition: @@ -578,9 +554,7 @@ def normalize_compositions( expanded = strategy_type.normalize_strategies({aggregate}) # Each expanded strategy becomes its own composition for strategy in expanded: - normalized_compositions.append( - ScenarioCompositeStrategy(strategies=[strategy]) - ) + normalized_compositions.append(ScenarioCompositeStrategy(strategies=[strategy])) else: # Concrete composition - validate and preserve as-is strategy_type.validate_composition(typed_strategies) diff --git a/tests/unit/scenarios/test_scenario.py b/tests/unit/scenarios/test_scenario.py index fe87135bd..d81101a6c 100644 --- a/tests/unit/scenarios/test_scenario.py +++ b/tests/unit/scenarios/test_scenario.py @@ -27,9 +27,7 @@ def 
create_mock_run_async(attack_results): async def mock_run_async(*args, **kwargs): # Save results to memory (mimics what real attacks do) save_attack_results_to_memory(attack_results) - return AttackExecutorResult( - completed_results=attack_results, incomplete_objectives=[] - ) + return AttackExecutorResult(completed_results=attack_results, incomplete_objectives=[]) return AsyncMock(side_effect=mock_run_async) @@ -211,9 +209,7 @@ class TestScenarioInitialization2: """Tests for Scenario initialize_async method.""" @pytest.mark.asyncio - async def test_initialize_async_populates_atomic_attacks( - self, mock_atomic_attacks, mock_objective_target - ): + async def test_initialize_async_populates_atomic_attacks(self, mock_atomic_attacks, mock_objective_target): """Test that initialize_async populates atomic attacks.""" scenario = ConcreteScenario( name="Test Scenario", @@ -263,9 +259,7 @@ async def test_initialize_async_sets_max_retries(self, mock_objective_target): version=1, ) - await scenario.initialize_async( - objective_target=mock_objective_target, max_retries=3 - ) + await scenario.initialize_async(objective_target=mock_objective_target, max_retries=3) assert scenario._max_retries == 3 @@ -277,9 +271,7 @@ async def test_initialize_async_sets_max_concurrency(self, mock_objective_target version=1, ) - await scenario.initialize_async( - objective_target=mock_objective_target, max_concurrency=5 - ) + await scenario.initialize_async(objective_target=mock_objective_target, max_concurrency=5) assert scenario._max_concurrency == 5 @@ -292,9 +284,7 @@ async def test_initialize_async_sets_memory_labels(self, mock_objective_target): version=1, ) - await scenario.initialize_async( - objective_target=mock_objective_target, memory_labels=labels - ) + await scenario.initialize_async(objective_target=mock_objective_target, memory_labels=labels) assert scenario._memory_labels == labels @@ -318,9 +308,7 @@ class TestScenarioExecution: """Tests for Scenario execution methods.""" 
@pytest.mark.asyncio - async def test_run_async_executes_all_runs( - self, mock_atomic_attacks, sample_attack_results, mock_objective_target - ): + async def test_run_async_executes_all_runs(self, mock_atomic_attacks, sample_attack_results, mock_objective_target): """Test that run_async executes all atomic attacks sequentially.""" # Configure each run to return different results for i, run in enumerate(mock_atomic_attacks): @@ -341,9 +329,7 @@ async def test_run_async_executes_all_runs( # Verify all runs were executed with correct concurrency assert len(result.attack_results) == 3 for run in mock_atomic_attacks: - run.run_async.assert_called_once_with( - max_concurrency=10, return_partial_on_failure=True - ) + run.run_async.assert_called_once_with(max_concurrency=10, return_partial_on_failure=True) # Verify results are aggregated correctly by atomic attack name assert "attack_run_1" in result.attack_results @@ -366,17 +352,13 @@ async def test_run_async_with_custom_concurrency( version=1, atomic_attacks_to_return=mock_atomic_attacks, ) - await scenario.initialize_async( - objective_target=mock_objective_target, max_concurrency=5 - ) + await scenario.initialize_async(objective_target=mock_objective_target, max_concurrency=5) result = await scenario.run_async() # Verify max_concurrency was passed to each run for run in mock_atomic_attacks: - run.run_async.assert_called_once_with( - max_concurrency=5, return_partial_on_failure=True - ) + run.run_async.assert_called_once_with(max_concurrency=5, return_partial_on_failure=True) # Verify result structure assert isinstance(result, ScenarioResult) @@ -388,15 +370,9 @@ async def test_run_async_aggregates_multiple_results( ): """Test that results from multiple atomic attacks are properly aggregated.""" # Configure runs to return different numbers of results - mock_atomic_attacks[0].run_async = create_mock_run_async( - sample_attack_results[0:2] - ) - mock_atomic_attacks[1].run_async = create_mock_run_async( - 
sample_attack_results[2:4] - ) - mock_atomic_attacks[2].run_async = create_mock_run_async( - sample_attack_results[4:5] - ) + mock_atomic_attacks[0].run_async = create_mock_run_async(sample_attack_results[0:2]) + mock_atomic_attacks[1].run_async = create_mock_run_async(sample_attack_results[2:4]) + mock_atomic_attacks[2].run_async = create_mock_run_async(sample_attack_results[4:5]) scenario = ConcreteScenario( name="Test Scenario", @@ -415,19 +391,11 @@ async def test_run_async_aggregates_multiple_results( assert len(result.attack_results["attack_run_3"]) == 1 @pytest.mark.asyncio - async def test_run_async_stops_on_error( - self, mock_atomic_attacks, sample_attack_results, mock_objective_target - ): + async def test_run_async_stops_on_error(self, mock_atomic_attacks, sample_attack_results, mock_objective_target): """Test that execution stops when an atomic attack fails.""" - mock_atomic_attacks[0].run_async = create_mock_run_async( - [sample_attack_results[0]] - ) - mock_atomic_attacks[1].run_async = AsyncMock( - side_effect=Exception("Test error") - ) - mock_atomic_attacks[2].run_async = create_mock_run_async( - [sample_attack_results[2]] - ) + mock_atomic_attacks[0].run_async = create_mock_run_async([sample_attack_results[0]]) + mock_atomic_attacks[1].run_async = AsyncMock(side_effect=Exception("Test error")) + mock_atomic_attacks[2].run_async = create_mock_run_async([sample_attack_results[2]]) scenario = ConcreteScenario( name="Test Scenario", @@ -454,9 +422,7 @@ async def test_run_async_fails_without_initialization(self, mock_objective_targe version=1, ) - with pytest.raises( - ValueError, match="Cannot run scenario with no atomic attacks" - ): + with pytest.raises(ValueError, match="Cannot run scenario with no atomic attacks"): await scenario.run_async() @pytest.mark.asyncio @@ -502,9 +468,7 @@ def test_name_property(self, mock_objective_target): assert scenario.name == "My Test Scenario" @pytest.mark.asyncio - async def test_atomic_attack_count_property( - 
self, mock_atomic_attacks, mock_objective_target - ): + async def test_atomic_attack_count_property(self, mock_atomic_attacks, mock_objective_target): """Test that atomic_attack_count returns the correct count.""" scenario = ConcreteScenario( name="Test Scenario", @@ -519,9 +483,7 @@ async def test_atomic_attack_count_property( assert scenario.atomic_attack_count == 3 @pytest.mark.asyncio - async def test_atomic_attack_count_with_different_sizes( - self, mock_objective_target - ): + async def test_atomic_attack_count_with_different_sizes(self, mock_objective_target): """Test atomic_attack_count with different numbers of atomic attacks.""" # Create mock attack strategy mock_attack = MagicMock() @@ -679,9 +641,7 @@ def test_scenario_identifier_with_custom_pyrit_version(self): def test_scenario_identifier_with_init_data(self): """Test ScenarioIdentifier with init_data.""" init_data = {"param1": "value1", "param2": 42} - identifier = ScenarioIdentifier( - name="TestScenario", scenario_version=1, init_data=init_data - ) + identifier = ScenarioIdentifier(name="TestScenario", scenario_version=1, init_data=init_data) assert identifier.init_data == init_data @@ -759,9 +719,7 @@ class TestScenarioBaselineOnlyExecution: """Tests for baseline-only execution (empty strategies with include_baseline=True).""" @pytest.mark.asyncio - async def test_initialize_async_with_empty_strategies_and_baseline( - self, mock_objective_target - ): + async def test_initialize_async_with_empty_strategies_and_baseline(self, mock_objective_target): """Test that baseline-only execution works when include_baseline=True and strategies is empty.""" from pyrit.models import SeedAttackGroup, SeedObjective @@ -791,9 +749,7 @@ async def test_initialize_async_with_empty_strategies_and_baseline( assert scenario._atomic_attacks[0].atomic_attack_name == "baseline" @pytest.mark.asyncio - async def test_baseline_only_execution_runs_successfully( - self, mock_objective_target, sample_attack_results - ): + async 
def test_baseline_only_execution_runs_successfully(self, mock_objective_target, sample_attack_results): """Test that baseline-only scenario can run successfully.""" from pyrit.models import SeedAttackGroup, SeedObjective @@ -818,9 +774,7 @@ async def test_baseline_only_execution_runs_successfully( ) # Mock the baseline attack's run_async - scenario._atomic_attacks[0].run_async = create_mock_run_async( - [sample_attack_results[0]] - ) + scenario._atomic_attacks[0].run_async = create_mock_run_async([sample_attack_results[0]]) # Run the scenario result = await scenario.run_async() @@ -831,9 +785,7 @@ async def test_baseline_only_execution_runs_successfully( assert len(result.attack_results["baseline"]) == 1 @pytest.mark.asyncio - async def test_empty_strategies_without_baseline_allows_initialization( - self, mock_objective_target - ): + async def test_empty_strategies_without_baseline_allows_initialization(self, mock_objective_target): """Test that empty strategies without include_baseline allows initialization but fails at run time.""" scenario = ConcreteScenario( name="No Baseline Test", @@ -852,15 +804,11 @@ async def test_empty_strategies_without_baseline_allows_initialization( ) # But running should fail because there are no atomic attacks - with pytest.raises( - ValueError, match="Cannot run scenario with no atomic attacks" - ): + with pytest.raises(ValueError, match="Cannot run scenario with no atomic attacks"): await scenario.run_async() @pytest.mark.asyncio - async def test_standalone_baseline_uses_dataset_config_seeds( - self, mock_objective_target - ): + async def test_standalone_baseline_uses_dataset_config_seeds(self, mock_objective_target): """Test that standalone baseline uses seed groups from dataset_config.""" from pyrit.models import SeedAttackGroup, SeedObjective