Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 71 additions & 26 deletions pyrit/scenario/core/scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import textwrap
import uuid
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Sequence, Set, Type, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple, Type, Union

from tqdm.auto import tqdm

Expand All @@ -32,6 +32,10 @@
)
from pyrit.score import Scorer

if TYPE_CHECKING:
from pyrit.executor.attack.core.attack_config import AttackScoringConfig
from pyrit.models import SeedAttackGroup

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -228,14 +232,16 @@ async def initialize_async(
self._memory_labels = memory_labels or {}

# Prepare scenario strategies using the stored configuration
# Allow empty strategies when include_baseline is True (baseline-only execution)
self._scenario_composites = self._strategy_class.prepare_scenario_strategies(
scenario_strategies, default_aggregate=self.get_default_strategy()
scenario_strategies,
default_aggregate=self.get_default_strategy(),
)

self._atomic_attacks = await self._get_atomic_attacks_async()

if self._include_baseline:
baseline_attack = self._get_baseline_from_first_attack()
baseline_attack = self._get_baseline()
self._atomic_attacks.insert(0, baseline_attack)

# Store original objectives for each atomic attack (before any mutations during execution)
Expand Down Expand Up @@ -281,34 +287,21 @@ async def initialize_async(
self._scenario_result_id = str(result.id)
logger.info(f"Created new scenario result with ID: {self._scenario_result_id}")

def _get_baseline_from_first_attack(self) -> AtomicAttack:
def _get_baseline(self) -> AtomicAttack:
"""
Get a baseline AtomicAttack, which simply sends all the objectives without any modifications.

If other atomic attacks exist, derives baseline data from the first attack.
Otherwise, creates a standalone baseline from the dataset configuration and scenario settings.

Returns:
AtomicAttack: The baseline AtomicAttack instance.

Raises:
ValueError: If no atomic attacks are available to derive baseline from.
ValueError: If required data (seed_groups, objective_target, attack_scoring_config)
is not available.
"""
if not self._atomic_attacks or len(self._atomic_attacks) == 0:
raise ValueError("No atomic attacks available to derive baseline from.")

first_attack = self._atomic_attacks[0]

# Copy seed_groups, scoring, target from the first attack
seed_groups = first_attack.seed_groups
attack_scoring_config = first_attack._attack.get_attack_scoring_config()
objective_target = first_attack._attack.get_objective_target()

if not seed_groups or len(seed_groups) == 0:
raise ValueError("First atomic attack must have seed_groups to create baseline.")

if not objective_target:
raise ValueError("Objective target is required to create baseline attack.")

if not attack_scoring_config:
raise ValueError("Attack scoring config is required to create baseline attack.")
seed_groups, attack_scoring_config, objective_target = self._get_baseline_data()

# Create baseline attack with no converters
attack = PromptSendingAttack(
Expand All @@ -323,6 +316,55 @@ def _get_baseline_from_first_attack(self) -> AtomicAttack:
memory_labels=self._memory_labels,
)

def _get_baseline_data(self) -> Tuple[List["SeedAttackGroup"], "AttackScoringConfig", PromptTarget]:
"""
Get the data needed to create a baseline attack.

Returns either the first attack's data or the scenario-level data
depending on whether other atomic attacks exist.

Returns:
Tuple containing (seed_groups, attack_scoring_config, objective_target)

Raises:
ValueError: If required data is not available.
"""
if self._atomic_attacks and len(self._atomic_attacks) > 0:
# Derive from first attack
first_attack = self._atomic_attacks[0]
seed_groups = first_attack.seed_groups
attack_scoring_config = first_attack._attack.get_attack_scoring_config()
objective_target = first_attack._attack.get_objective_target()
else:
# Create from scenario-level settings
if not self._objective_target:
raise ValueError("Objective target is required to create baseline attack.")
if not self._dataset_config:
raise ValueError("Dataset config is required to create baseline attack.")
if not self._objective_scorer:
raise ValueError("Objective scorer is required to create baseline attack.")

seed_groups = self._dataset_config.get_all_seed_attack_groups()
objective_target = self._objective_target

# Import here to avoid circular imports
from typing import cast

from pyrit.executor.attack.core.attack_config import AttackScoringConfig
from pyrit.score import TrueFalseScorer

attack_scoring_config = AttackScoringConfig(objective_scorer=cast(TrueFalseScorer, self._objective_scorer))

# Validate required data
if not seed_groups or len(seed_groups) == 0:
raise ValueError("Seed groups are required to create baseline attack.")
if not objective_target:
raise ValueError("Objective target is required to create baseline attack.")
if not attack_scoring_config:
raise ValueError("Attack scoring config is required to create baseline attack.")

return seed_groups, attack_scoring_config, objective_target

def _raise_dataset_exception(self) -> None:
error_msg = textwrap.dedent(
f"""
Expand Down Expand Up @@ -650,7 +692,8 @@ async def _execute_scenario_async(self) -> ScenarioResult:

try:
atomic_results = await atomic_attack.run_async(
max_concurrency=self._max_concurrency, return_partial_on_failure=True
max_concurrency=self._max_concurrency,
return_partial_on_failure=True,
)

# Always save completed results, even if some objectives didn't complete
Expand All @@ -677,7 +720,8 @@ async def _execute_scenario_async(self) -> ScenarioResult:

# Mark scenario as failed
self._memory.update_scenario_run_state(
scenario_result_id=scenario_result_id, scenario_run_state="FAILED"
scenario_result_id=scenario_result_id,
scenario_run_state="FAILED",
)

# Raise exception with detailed information
Expand All @@ -703,7 +747,8 @@ async def _execute_scenario_async(self) -> ScenarioResult:
scenario_results = self._memory.get_scenario_results(scenario_result_ids=[scenario_result_id])
if scenario_results and scenario_results[0].scenario_run_state != "FAILED":
self._memory.update_scenario_run_state(
scenario_result_id=scenario_result_id, scenario_run_state="FAILED"
scenario_result_id=scenario_result_id,
scenario_run_state="FAILED",
)

raise
Expand Down
5 changes: 5 additions & 0 deletions pyrit/scenario/core/scenario_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,12 +213,14 @@ def prepare_scenario_strategies(
strategies (Sequence[T | ScenarioCompositeStrategy] | None): The strategies to prepare.
Can be a mix of bare strategy enums and composite strategies.
If None, uses default_aggregate to determine defaults.
If an empty sequence, returns an empty list (useful for baseline-only execution).
default_aggregate (T | None): The aggregate strategy to use when strategies is None.
Common values: MyStrategy.ALL, MyStrategy.EASY. If None when strategies is None,
raises ValueError.

Returns:
List[ScenarioCompositeStrategy]: Normalized list of composite strategies ready for use.
May be empty if an empty sequence was explicitly provided.

Raises:
ValueError: If strategies is None and default_aggregate is None, or if compositions
Expand Down Expand Up @@ -251,7 +253,10 @@ def prepare_scenario_strategies(
# For now, skip to allow flexibility
pass

# Allow empty list if explicitly provided (for baseline-only execution)
if not composite_strategies:
if strategies is not None and len(strategies) == 0:
return []
raise ValueError(
f"No valid {cls.__name__} strategies provided. "
f"Provide at least one {cls.__name__} enum or ScenarioCompositeStrategy."
Expand Down
Loading