diff --git a/README.md b/README.md index 7d35e0bc..655bc72a 100644 --- a/README.md +++ b/README.md @@ -126,6 +126,7 @@ challenge_final_results: bool = False judge_intervention: Optional[str] = None judge_metric: Optional[str] = None judge_endpoint_url: Optional[str] = None +judge_model_name: Optional[str] = None judge_api_key: str = "-" judge_always_intervene: bool = False ``` diff --git a/mallm/coordinator.py b/mallm/coordinator.py index e5499d00..8fe7fd8c 100644 --- a/mallm/coordinator.py +++ b/mallm/coordinator.py @@ -13,8 +13,8 @@ from mallm.agents.draftProposer import DraftProposer from mallm.agents.judge import Judge from mallm.agents.panelist import Panelist -from mallm.decision_protocol.protocol import DecisionProtocol -from mallm.discourse_policy.policy import DiscoursePolicy +from mallm.decision_protocols.protocol import DecisionProtocol +from mallm.discussion_paradigms.paradigm import DiscussionParadigm from mallm.models.Chat import Chat from mallm.models.discussion.ResponseGenerator import ResponseGenerator from mallm.models.discussion.SimpleResponseGenerator import SimpleResponseGenerator @@ -273,7 +273,7 @@ def discuss( raise Exception( f"No valid discourse policy for paradigm {config.discussion_paradigm}" ) - policy: DiscoursePolicy = DISCUSSION_PARADIGMS[config.discussion_paradigm]() + policy: DiscussionParadigm = DISCUSSION_PARADIGMS[config.discussion_paradigm]() logger.info( f"""Starting discussion with coordinator {self.id}... diff --git a/mallm/decision_protocol/__init__.py b/mallm/decision_protocols/__init__.py similarity index 100% rename from mallm/decision_protocol/__init__.py rename to mallm/decision_protocols/__init__.py diff --git a/mallm/decision_protocol/approval_voting.py b/mallm/decision_protocols/approval_voting.py similarity index 98% rename from mallm/decision_protocol/approval_voting.py rename to mallm/decision_protocols/approval_voting.py index 1659fd93..3f0ed559 100644 --- a/mallm/decision_protocol/approval_voting.py +++ b/mallm/decision_protocols/approval_voting.py @@ -3,7 +3,7 @@ from typing import Any, Optional from mallm.agents.panelist import Panelist -from mallm.decision_protocol.protocol import DecisionProtocol +from mallm.decision_protocols.protocol import DecisionProtocol from mallm.models.discussion.ResponseGenerator import ResponseGenerator from mallm.utils.config import Config from mallm.utils.enums import DecisionAlteration diff --git a/mallm/decision_protocol/consensus.py b/mallm/decision_protocols/consensus.py similarity index 98% rename from mallm/decision_protocol/consensus.py rename to mallm/decision_protocols/consensus.py index d85d0162..58fbcc47 100644 --- a/mallm/decision_protocol/consensus.py +++ b/mallm/decision_protocols/consensus.py @@ -2,7 +2,7 @@ from typing import Any, Optional from mallm.agents.panelist import Panelist -from mallm.decision_protocol.protocol import DecisionProtocol +from mallm.decision_protocols.protocol import DecisionProtocol from mallm.utils.config import Config from mallm.utils.enums import DecisionAlteration from mallm.utils.types import Agreement, VotingResult, VotingResultList, WorkerFunctions diff --git a/mallm/decision_protocol/consensus_voting.py b/mallm/decision_protocols/consensus_voting.py similarity index 99% rename from mallm/decision_protocol/consensus_voting.py rename to mallm/decision_protocols/consensus_voting.py index 924b72ab..38f786ae 100644 --- a/mallm/decision_protocol/consensus_voting.py +++ b/mallm/decision_protocols/consensus_voting.py @@ -5,7 +5,7 @@ from typing import Any, Optional from mallm.agents.panelist import Panelist -from mallm.decision_protocol.protocol import DecisionProtocol +from mallm.decision_protocols.protocol import DecisionProtocol from mallm.models.discussion.ResponseGenerator import ResponseGenerator from mallm.utils.config import Config from mallm.utils.enums import DecisionAlteration diff --git a/mallm/decision_protocol/cumulative_voting.py b/mallm/decision_protocols/cumulative_voting.py similarity index 98% rename from mallm/decision_protocol/cumulative_voting.py rename to mallm/decision_protocols/cumulative_voting.py index 1e1ebf0a..f92c7cac 100644 --- a/mallm/decision_protocol/cumulative_voting.py +++ b/mallm/decision_protocols/cumulative_voting.py @@ -3,7 +3,7 @@ from typing import Any, Optional from mallm.agents.panelist import Panelist -from mallm.decision_protocol.protocol import DecisionProtocol +from mallm.decision_protocols.protocol import DecisionProtocol from mallm.models.discussion.ResponseGenerator import ResponseGenerator from mallm.utils.config import Config from mallm.utils.enums import DecisionAlteration diff --git a/mallm/decision_protocol/summary.py b/mallm/decision_protocols/judge.py similarity index 93% rename from mallm/decision_protocol/summary.py rename to mallm/decision_protocols/judge.py index d7918f1e..4e74d033 100644 --- a/mallm/decision_protocol/summary.py +++ b/mallm/decision_protocols/judge.py @@ -2,7 +2,7 @@ from typing import Any, Optional from mallm.agents.panelist import Panelist -from mallm.decision_protocol.protocol import DecisionProtocol +from mallm.decision_protocols.protocol import DecisionProtocol from mallm.models.discussion.ResponseGenerator import ResponseGenerator from mallm.utils.config import Config from mallm.utils.enums import DecisionAlteration @@ -11,9 +11,9 @@ logger = logging.getLogger("mallm") -class Summary(DecisionProtocol): +class Judge(DecisionProtocol): """ - The Summary decision protocol creates a summary of all answers after a certain number of turns. + The Judge decision protocol creates a summary of all answers after a certain number of turns. """ _name = "summary" diff --git a/mallm/decision_protocol/protocol.py b/mallm/decision_protocols/protocol.py similarity index 100% rename from mallm/decision_protocol/protocol.py rename to mallm/decision_protocols/protocol.py diff --git a/mallm/decision_protocol/ranked_voting.py b/mallm/decision_protocols/ranked_voting.py similarity index 98% rename from mallm/decision_protocol/ranked_voting.py rename to mallm/decision_protocols/ranked_voting.py index 0174a8be..792c77c5 100644 --- a/mallm/decision_protocol/ranked_voting.py +++ b/mallm/decision_protocols/ranked_voting.py @@ -2,7 +2,7 @@ from typing import Any, Optional from mallm.agents.panelist import Panelist -from mallm.decision_protocol.protocol import DecisionProtocol +from mallm.decision_protocols.protocol import DecisionProtocol from mallm.models.discussion.ResponseGenerator import ResponseGenerator from mallm.utils.config import Config from mallm.utils.enums import DecisionAlteration diff --git a/mallm/decision_protocol/simple_voting.py b/mallm/decision_protocols/simple_voting.py similarity index 98% rename from mallm/decision_protocol/simple_voting.py rename to mallm/decision_protocols/simple_voting.py index e71a096d..a057cb17 100644 --- a/mallm/decision_protocol/simple_voting.py +++ b/mallm/decision_protocols/simple_voting.py @@ -3,7 +3,7 @@ from typing import Any, Optional from mallm.agents.panelist import Panelist -from mallm.decision_protocol.protocol import DecisionProtocol +from mallm.decision_protocols.protocol import DecisionProtocol from mallm.models.discussion.ResponseGenerator import ResponseGenerator from mallm.utils.config import Config from mallm.utils.enums import DecisionAlteration diff --git a/mallm/discourse_policy/__init__.py b/mallm/discussion_paradigms/__init__.py similarity index 100% rename from mallm/discourse_policy/__init__.py rename to mallm/discussion_paradigms/__init__.py diff --git a/mallm/discourse_policy/collective_refinement.py b/mallm/discussion_paradigms/collective_refinement.py similarity index 98% rename from mallm/discourse_policy/collective_refinement.py rename to mallm/discussion_paradigms/collective_refinement.py index a6886d26..b13ab972 100644 --- a/mallm/discourse_policy/collective_refinement.py +++ b/mallm/discussion_paradigms/collective_refinement.py @@ -8,7 +8,7 @@ from mallm.agents.draftProposer import DraftProposer from mallm.agents.panelist import Panelist -from mallm.discourse_policy.policy import DiscoursePolicy +from mallm.discussion_paradigms.paradigm import DiscussionParadigm from mallm.utils.types import Agreement, Memory, TemplateFilling, VotingResultList if TYPE_CHECKING: @@ -18,7 +18,7 @@ logger = logging.getLogger("mallm") -class CollectiveRefinement(DiscoursePolicy): +class CollectiveRefinement(DiscussionParadigm): """ A discussion protocol where agents improve their answers through multiple rounds of feedback. diff --git a/mallm/discourse_policy/debate.py b/mallm/discussion_paradigms/debate.py similarity index 98% rename from mallm/discourse_policy/debate.py rename to mallm/discussion_paradigms/debate.py index a0ed5ddf..4e0243a7 100644 --- a/mallm/discourse_policy/debate.py +++ b/mallm/discussion_paradigms/debate.py @@ -8,7 +8,7 @@ from mallm.agents.draftProposer import DraftProposer from mallm.agents.judge import Judge from mallm.agents.panelist import Panelist -from mallm.discourse_policy.policy import DiscoursePolicy +from mallm.discussion_paradigms.paradigm import DiscussionParadigm from mallm.utils.types import Agreement, TemplateFilling, VotingResultList if TYPE_CHECKING: @@ -18,7 +18,7 @@ logger = logging.getLogger("mallm") -class DiscourseDebate(DiscoursePolicy): +class DiscussionDebate(DiscussionParadigm): def draft_proposer_call( self, draft_proposer: DraftProposer, @@ -120,7 +120,8 @@ def discuss( logger.debug( f"Discussion {coordinator.id} goes into debate round: {r!s}" ) - debate_agreements: list[Agreement] = [] + if r == 0: + debate_agreements = self.agreements for i, a in enumerate( coordinator.agents[1:] ): # similar to relay paradigm diff --git a/mallm/discourse_policy/memory.py b/mallm/discussion_paradigms/memory.py similarity index 94% rename from mallm/discourse_policy/memory.py rename to mallm/discussion_paradigms/memory.py index 4922739f..0a21aec1 100644 --- a/mallm/discourse_policy/memory.py +++ b/mallm/discussion_paradigms/memory.py @@ -5,7 +5,7 @@ from mallm.agents.draftProposer import DraftProposer from mallm.agents.panelist import Panelist -from mallm.discourse_policy.policy import DiscoursePolicy +from mallm.discussion_paradigms.paradigm import DiscussionParadigm from mallm.utils.types import TemplateFilling if TYPE_CHECKING: @@ -13,7 +13,7 @@ logger = logging.getLogger("mallm") -class DiscourseMemory(DiscoursePolicy): +class DiscussionMemory(DiscussionParadigm): def __init__(self) -> None: super().__init__( """Paradigm: Memory diff --git a/mallm/discourse_policy/policy.py b/mallm/discussion_paradigms/paradigm.py similarity index 99% rename from mallm/discourse_policy/policy.py rename to mallm/discussion_paradigms/paradigm.py index 43a19a2b..94fec7a6 100644 --- a/mallm/discourse_policy/policy.py +++ b/mallm/discussion_paradigms/paradigm.py @@ -19,7 +19,7 @@ logger = logging.getLogger("mallm") -class DiscoursePolicy(ABC): +class DiscussionParadigm(ABC): def __init__(self, paradigm_str: str = "") -> None: self.paradigm_str = paradigm_str self.decision = False diff --git a/mallm/discourse_policy/relay.py b/mallm/discussion_paradigms/relay.py similarity index 95% rename from mallm/discourse_policy/relay.py rename to mallm/discussion_paradigms/relay.py index 1bde0a72..ddf3a3be 100644 --- a/mallm/discourse_policy/relay.py +++ b/mallm/discussion_paradigms/relay.py @@ -5,7 +5,7 @@ from mallm.agents.draftProposer import DraftProposer from mallm.agents.panelist import Panelist -from mallm.discourse_policy.policy import DiscoursePolicy +from mallm.discussion_paradigms.paradigm import DiscussionParadigm from mallm.utils.types import TemplateFilling if TYPE_CHECKING: @@ -13,7 +13,7 @@ logger = logging.getLogger("mallm") -class DiscourseRelay(DiscoursePolicy): +class DiscussionRelay(DiscussionParadigm): def __init__(self) -> None: super().__init__( """Paradigm: Relay diff --git a/mallm/discourse_policy/report.py b/mallm/discussion_paradigms/report.py similarity index 96% rename from mallm/discourse_policy/report.py rename to mallm/discussion_paradigms/report.py index 968a81ab..9a6e4e4b 100644 --- a/mallm/discourse_policy/report.py +++ b/mallm/discussion_paradigms/report.py @@ -5,7 +5,7 @@ from mallm.agents.draftProposer import DraftProposer from mallm.agents.panelist import Panelist -from mallm.discourse_policy.policy import DiscoursePolicy +from mallm.discussion_paradigms.paradigm import DiscussionParadigm from mallm.utils.types import TemplateFilling if TYPE_CHECKING: @@ -13,7 +13,7 @@ logger = logging.getLogger("mallm") -class DiscourseReport(DiscoursePolicy): +class DiscussionReport(DiscussionParadigm): def __init__(self) -> None: super().__init__( """Paradigm: Report diff --git a/mallm/utils/dicts.py b/mallm/utils/dicts.py index 17dfb1b9..3bdcfef9 100644 --- a/mallm/utils/dicts.py +++ b/mallm/utils/dicts.py @@ -1,22 +1,22 @@ -from mallm.decision_protocol.approval_voting import ApprovalVoting -from mallm.decision_protocol.consensus import ( +from mallm.decision_protocols.approval_voting import ApprovalVoting +from mallm.decision_protocols.consensus import ( HybridMajorityConsensus, MajorityConsensus, SupermajorityConsensus, UnanimityConsensus, ) -from mallm.decision_protocol.consensus_voting import ConsensusVoting -from mallm.decision_protocol.cumulative_voting import CumulativeVoting -from mallm.decision_protocol.protocol import DecisionProtocol -from mallm.decision_protocol.ranked_voting import RankedVoting -from mallm.decision_protocol.simple_voting import SimpleVoting -from mallm.decision_protocol.summary import Summary -from mallm.discourse_policy.collective_refinement import CollectiveRefinement -from mallm.discourse_policy.debate import DiscourseDebate -from mallm.discourse_policy.memory import DiscourseMemory -from mallm.discourse_policy.policy import DiscoursePolicy -from mallm.discourse_policy.relay import DiscourseRelay -from mallm.discourse_policy.report import DiscourseReport +from mallm.decision_protocols.consensus_voting import ConsensusVoting +from mallm.decision_protocols.cumulative_voting import CumulativeVoting +from mallm.decision_protocols.protocol import DecisionProtocol +from mallm.decision_protocols.ranked_voting import RankedVoting +from mallm.decision_protocols.simple_voting import SimpleVoting +from mallm.decision_protocols.judge import Judge +from mallm.discussion_paradigms.collective_refinement import CollectiveRefinement +from mallm.discussion_paradigms.debate import DiscussionDebate +from mallm.discussion_paradigms.memory import DiscussionMemory +from mallm.discussion_paradigms.paradigm import DiscussionParadigm +from mallm.discussion_paradigms.relay import DiscussionRelay +from mallm.discussion_paradigms.report import DiscussionReport from mallm.models.discussion.CriticalResponseGenerator import CriticalResponseGenerator from mallm.models.discussion.FreeTextResponseGenerator import FreeTextResponseGenerator from mallm.models.discussion.ReasoningResponseGenerator import ( @@ -43,14 +43,14 @@ "cumulative_voting": CumulativeVoting, "ranked_voting": RankedVoting, "consensus_voting": ConsensusVoting, - "summary": Summary, + "summary": Judge, } -DISCUSSION_PARADIGMS: dict[str, type[DiscoursePolicy]] = { - "memory": DiscourseMemory, - "report": DiscourseReport, - "relay": DiscourseRelay, - "debate": DiscourseDebate, +DISCUSSION_PARADIGMS: dict[str, type[DiscussionParadigm]] = { + "memory": DiscussionMemory, + "report": DiscussionReport, + "relay": DiscussionRelay, + "debate": DiscussionDebate, "collective_refinement": CollectiveRefinement, } diff --git a/test/decision_protocol/test_majority_consensus.py b/test/decision_protocol/test_majority_consensus.py index 3fc48bea..6822672c 100644 --- a/test/decision_protocol/test_majority_consensus.py +++ b/test/decision_protocol/test_majority_consensus.py @@ -1,7 +1,7 @@ from mallm.agents.draftProposer import DraftProposer from mallm.agents.panelist import Panelist from mallm.coordinator import Coordinator -from mallm.decision_protocol.consensus import ( +from mallm.decision_protocols.consensus import ( HybridMajorityConsensus, UnanimityConsensus, )