From 49af57b059ef8c6452144232f9e83c720b1b9cf4 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 11 Apr 2026 07:41:05 +0000 Subject: [PATCH 01/13] Add 3 new tasks, opponent simulation, enhanced grading, and comprehensive tests - Add confidentiality/NDA (medium+), termination (hard++), data protection (expert) tasks - Add opponent simulation with contextual counterparty responses per action type - Enhance grading with semantic similarity (cosine+Jaccard) and clause completeness scoring - Add 3 new task-specific graders: grade_medium_plus, grade_hard_plus2, grade_expert - Add required_elements field for completeness scoring per task - Update reward formula: 35% correctness + 25% improvement + 25% risk_alignment + 10% semantic + 5% completeness - Improve inference script with adaptive multi-turn strategy for all 8 tasks - Add 21 new tests (42 total, all passing) - Update openenv.yaml, README, and all exports Agent-Logs-Url: https://github.com/bigturtle679/Contract-Negotiation-Environment/sessions/e3c399ed-eac8-40b6-9c61-196e32f44385 Co-authored-by: AbeerChaturvedi <171315954+AbeerChaturvedi@users.noreply.github.com> --- README.md | 98 ++++++--- contract_env/env/__init__.py | 14 ++ contract_env/env/environment.py | 20 ++ contract_env/env/graders.py | 153 +++++++++++++- contract_env/env/tasks.py | 311 +++++++++++++++++++++++++++++ contract_env/tests/test_api.py | 4 +- contract_env/tests/test_graders.py | 92 +++++++++ contract_env/tests/test_smoke.py | 74 +++++++ inference.py | 92 +++++++-- openenv.yaml | 6 + 10 files changed, 813 insertions(+), 51 deletions(-) diff --git a/README.md b/README.md index 4e642b7..df63717 100644 --- a/README.md +++ b/README.md @@ -27,6 +27,8 @@ lawyers, procurement teams, and founders. Key challenges for an AI agent: - **Partial-progress rewards**: improving a clause partially (e.g., adding a liability cap without addressing IP ownership) deserves more reward than doing nothing — but less than resolving every risk. +- **Multi-turn dynamics**: the counterparty pushes back on proposals, requiring + adaptive negotiation strategies across multiple rounds. --- @@ -39,6 +41,9 @@ lawyers, procurement teams, and founders. Key challenges for an AI agent: | `hard_conflicting_obligations` | Hard (4/5) | Performance/Changes | HIGH | Yes | | `easy_compliance_agreement` | Easy+ (2/5) | Compliance | LOW | No | | `hard_intellectual_property` | Hard+ (5/5) | IP Ownership | HIGH | Yes | +| `medium_confidentiality_nda` | Medium+ (3/5) | Confidentiality | MODERATE | Yes | +| `hard_termination_convenience` | Hard++ (4/5) | Termination | HIGH | Yes | +| `expert_data_protection` | Expert (5/5) | Data Protection | HIGH | Yes | ### Task descriptions @@ -63,6 +68,43 @@ improvement: adding explicit breach-notification obligations ("+6% bonus for the customer provides specifications. The agent must rewrite to assign IP to the customer and limit the supplier to a scoped license. +**medium_confidentiality_nda** — An overbroad NDA with perpetual obligations and +no carve-outs for public information. The agent must narrow the scope, add a +time limit (3 years), and carve out publicly available and independently +developed information. + +**hard_termination_convenience** — A one-sided termination clause allowing only +the Supplier to terminate at will with 5-day notice, while the Customer has no +termination rights and waives all remedies. The agent must establish mutual +termination, add a 30-day cure period, and include transition/wind-down +provisions. + +**expert_data_protection** — A clause giving the Supplier blanket authority to +process personal data, transfer it to any jurisdiction, engage sub-processors +without notice, and waive data-subject rights. The agent must add DPA +requirements, 72-hour breach notification, sub-processor consent, data-subject +rights assistance, and data deletion obligations. + +--- + +## Opponent Simulation + +Each task includes **opponent responses** keyed by action type. When the agent +takes an action (e.g., `FLAG_RISK`, `EDIT_CLAUSE`), the counterparty replies +with a contextually appropriate pushback or counter-proposal, creating realistic +multi-turn negotiation dynamics: + +``` +agent → FLAG_RISK +opponent → "Our legal team considers this standard. What specific cap do you propose?" +agent → EDIT_CLAUSE (with cap at 12 months) +opponent → "We can accept a cap but consequential damages must remain." +agent → PROPOSE_COUNTER (addressing consequential damages) +... +``` + +Opponent replies appear in the `negotiation_history` and in `info.opponent_reply`. + --- ## Observation Space @@ -72,12 +114,13 @@ Every call to `/reset` or `/step` returns an `Observation`: ```json { "contract_text": "string — the current clause text (may be rewritten after EDIT/PROPOSE)", - "clause_type": "string — e.g. liability, term_renewal, intellectual_property", + "clause_type": "string — e.g. liability, term_renewal, intellectual_property, confidentiality, termination, data_protection", "risk_level": "float ∈ (0, 1) — observed risk density (0=safe, 1=highly risky)", "step_count": "int — steps taken so far (0 = just reset)", "negotiation_history": [ "opponent|[Counterparty] Unlimited indemnity is standard.", "agent|step=1 action=FLAG_RISK content_len=0", + "opponent|[Counterparty] Our legal team considers this standard.", "..." ] } @@ -109,16 +152,20 @@ Sending empty content returns a validation error and a near-zero reward. Every step returns a scalar `reward ∈ (0.001, 0.999)`, computed as: ``` -reward = 0.40 × correctness - + 0.30 × improvement - + 0.30 × risk_alignment +reward = 0.35 × correctness + + 0.25 × improvement + + 0.25 × risk_alignment + + 0.10 × semantic_similarity + + 0.05 × completeness ``` | Component | What it measures | |-----------|-----------------| -| **Correctness** (40%) | For EDIT/PROPOSE: how much risky language was *removed* from the original. For FLAG/REJECT/ACCEPT: how many risk keywords are identified in context. | -| **Improvement** (30%) | How well the proposed edit matches safe keywords and the expected safe rewrite. | -| **Risk Alignment** (30%) | Whether the chosen action is appropriate for the current risk level (e.g., editing a HIGH-risk clause scores 0.92×; accepting it scores 0.20×). | +| **Correctness** (35%) | For EDIT/PROPOSE: how much risky language was *removed* from the original. For FLAG/REJECT/ACCEPT: how many risk keywords are identified in context. | +| **Improvement** (25%) | How well the proposed edit matches safe keywords and the expected safe rewrite. | +| **Risk Alignment** (25%) | Whether the chosen action is appropriate for the current risk level (e.g., editing a HIGH-risk clause scores 0.92×; accepting it scores 0.20×). | +| **Semantic Similarity** (10%) | Combined Jaccard + cosine similarity between the rewrite and the expected safe edit. | +| **Completeness** (5%) | Fraction of required legal elements present in the rewritten clause (e.g., liability cap, notice period, cure clause). | ### Task-specific adjustments @@ -129,6 +176,9 @@ reward = 0.40 × correctness | Hard | −50% penalty when hidden trap markers remain in the proposed text | | Easy+ | +6% bonus for including breach-notification language | | Hard+ | −45% penalty for unresolved IP traps; +7% bonus for explicit customer ownership | +| Medium+ | +8% bonus for well-scoped NDA; −30% for accepting overbroad terms | +| Hard++ | −45% penalty for unresolved one-sided termination; +9% for cure-period language | +| Expert | −50% penalty for missing data-protection safeguards; +10% for GDPR language (requires ≥2 indicators) | Blocked accepts (accepting HIGH-risk text) are clamped to `0.001`. @@ -139,22 +189,6 @@ An episode is considered successful if `score ≥ 0.50`. --- -## Reference Baseline Scores - -Measured over 1 episode per task with `Qwen/Qwen2.5-72B-Instruct`: - -| Task | Avg reward/step | Episode score | -|------|----------------|---------------| -| `easy_unlimited_liability` | 0.64 | 0.64 | -| `medium_auto_renewal` | 0.58 | 0.58 | -| `hard_conflicting_obligations` | 0.45 | 0.45 | -| `easy_compliance_agreement` | 0.61 | 0.61 | -| `hard_intellectual_property` | 0.42 | 0.42 | - -A random agent achieves approximately 0.28 average per step across all tasks. - ---- - ## API Endpoints | Method | Path | Description | @@ -175,7 +209,7 @@ A random agent achieves approximately 0.28 average per step across all tasks. ```bash pip install -e ".[dev]" -python -m unittest discover contract_env/tests/ -v +python -m pytest contract_env/tests/ -v # 42 tests ``` ### Run the server @@ -188,7 +222,7 @@ uvicorn contract_env.server.app:app --host 0.0.0.0 --port 7860 ```bash export HF_TOKEN="your-huggingface-token" -python inference.py --benchmark # one episode per task (5 total) +python inference.py --benchmark # one episode per task (8 total) python inference.py --episodes 3 # run 3 episodes cycling through tasks ``` @@ -221,19 +255,19 @@ docker run -p 7860:7860 \ ``` contract_env/ ├── env/ -│ ├── environment.py # ContractEnv — reset/step/state, 7-step episodes -│ ├── graders.py # evaluate_action() + 5 task-specific grader functions +│ ├── environment.py # ContractEnv — reset/step/state, 7-step episodes, opponent simulation +│ ├── graders.py # evaluate_action() + 8 task-specific grader functions + semantic/completeness scoring │ ├── models.py # Pydantic v2 models: Action, Observation, Reward -│ └── tasks.py # 5 NegotiationTask definitions with metadata +│ └── tasks.py # 8 NegotiationTask definitions with metadata + opponent responses ├── server/ │ └── app.py # FastAPI server (port 7860) ├── tests/ -│ ├── test_graders.py # 13 unit tests covering all grader edge cases +│ ├── test_graders.py # 28 unit tests covering all grader edge cases + new metrics │ ├── test_api.py # API endpoint tests -│ └── test_smoke.py # Smoke tests +│ └── test_smoke.py # 12 smoke tests including opponent simulation + new tasks └── client.py # HTTP client helper -inference.py # LLM-driven baseline agent -openenv.yaml # OpenEnv manifest (spec_version: 1) +inference.py # LLM-driven baseline agent with adaptive multi-turn strategy +openenv.yaml # OpenEnv manifest (spec_version: 1, 8 graded tasks) Dockerfile # Python 3.10-slim container, port 7860 verify_graders.py # Pre-submission grader validation script ``` diff --git a/contract_env/env/__init__.py b/contract_env/env/__init__.py index 86c331c..d47b22f 100644 --- a/contract_env/env/__init__.py +++ b/contract_env/env/__init__.py @@ -5,6 +5,13 @@ grade_easy, grade_medium, grade_hard, + grade_easy_plus, + grade_hard_plus, + grade_medium_plus, + grade_hard_plus2, + grade_expert, + clause_completeness_score, + semantic_similarity, TASK_GRADERS, GRADED_TASKS, NUM_GRADED_TASKS, @@ -33,6 +40,13 @@ "grade_easy", "grade_medium", "grade_hard", + "grade_easy_plus", + "grade_hard_plus", + "grade_medium_plus", + "grade_hard_plus2", + "grade_expert", + "clause_completeness_score", + "semantic_similarity", "TASK_GRADERS", "GRADED_TASKS", "NUM_GRADED_TASKS", diff --git a/contract_env/env/environment.py b/contract_env/env/environment.py index 2988e66..0bd5079 100644 --- a/contract_env/env/environment.py +++ b/contract_env/env/environment.py @@ -23,6 +23,7 @@ def __init__(self) -> None: self.current_step: int = 0 self.done: bool = False self.state_data: dict[str, Any] = {} + self._rng = random.Random(42) @property def tasks(self) -> list[str]: @@ -34,6 +35,19 @@ def graders(self) -> dict: from contract_env.env.graders import TASK_GRADERS return TASK_GRADERS + def _opponent_reply(self, action_type: str) -> Optional[str]: + """Generate an opponent response based on the action taken. + + If the current task defines opponent_responses for this action_type, + pick one at random. Otherwise return None. + """ + if self.current_task is None: + return None + responses = self.current_task.opponent_responses.get(action_type, []) + if not responses: + return None + return self._rng.choice(responses) + def reset(self) -> Observation: self.done = False self.current_step = 0 @@ -123,6 +137,12 @@ def step(self, action: Action) -> Tuple[Observation, float, bool, dict[str, Any] ) self.state_data["negotiation_history"].append(entry) + # Opponent simulation: add a counterparty reply to the history + opp_reply = self._opponent_reply(action.action_type) + if opp_reply: + self.state_data["negotiation_history"].append(f"opponent|{opp_reply}") + info["opponent_reply"] = opp_reply + if action.action_type == "EDIT_CLAUSE": self.state_data["contract_text"] = (action.content or "").strip() diff --git a/contract_env/env/graders.py b/contract_env/env/graders.py index ef2416f..89f7def 100644 --- a/contract_env/env/graders.py +++ b/contract_env/env/graders.py @@ -1,6 +1,8 @@ from __future__ import annotations +import math import re +from collections import Counter from typing import Any, Tuple, TYPE_CHECKING from contract_env.env.models import Action, Reward @@ -9,6 +11,8 @@ from contract_env.env.tasks import NegotiationTask +# ── TEXT ANALYSIS UTILITIES ────────────────────────────────────────────── + def tokenize(text: str) -> list[str]: return re.findall(r"[a-z0-9]+", text.lower()) @@ -25,6 +29,21 @@ def token_overlap_ratio(a: str, b: str) -> float: return len(sa & sb) / len(sa | sb) +def _cosine_similarity(a: str, b: str) -> float: + """Compute cosine similarity between two texts using term frequency vectors.""" + ta, tb = tokenize(a), tokenize(b) + if not ta or not tb: + return 0.0 + ca, cb = Counter(ta), Counter(tb) + all_tokens = set(ca) | set(cb) + dot = sum(ca.get(t, 0) * cb.get(t, 0) for t in all_tokens) + mag_a = math.sqrt(sum(v * v for v in ca.values())) + mag_b = math.sqrt(sum(v * v for v in cb.values())) + if mag_a == 0 or mag_b == 0: + return 0.0 + return dot / (mag_a * mag_b) + + def _weighted_risk_hits(text: str, risk_keywords: list[str]) -> float: low = text.lower() if not risk_keywords: @@ -52,6 +71,30 @@ def _safe_overlap(text: str, safe_keywords: list[str], expected_safe: str) -> fl return min(1.0, 0.5 * kw_score + 0.5 * exp_score) +def clause_completeness_score(text: str, required_elements: list[str]) -> float: + """Score how many required legal elements are present in the clause text. + + Args: + text: The clause text to evaluate. + required_elements: Lowercase phrases that a well-drafted clause should contain. + + Returns: + Float in [0, 1] — fraction of required elements found. + """ + if not required_elements: + return 1.0 + low = text.lower() + found = sum(1 for elem in required_elements if elem in low) + return found / len(required_elements) + + +def semantic_similarity(text: str, reference: str) -> float: + """Combined semantic similarity using Jaccard + cosine similarity.""" + jaccard = token_overlap_ratio(text, reference) + cosine = _cosine_similarity(text, reference) + return 0.5 * jaccard + 0.5 * cosine + + def trap_unresolved(task: NegotiationTask, contract_text: str) -> bool: low = contract_text.lower() return any(m in low for m in task.trap_markers) @@ -128,8 +171,29 @@ def evaluate_action( eff_high = effective_risk_high(task, proposed_contract_text) risk_al = action_risk_alignment(action.action_type, eff_high, task) + # Semantic quality bonus: reward rewrites that are semantically close to the + # expected safe edit (uses combined Jaccard + cosine similarity). + sem_bonus = 0.0 + if action.action_type in ("EDIT_CLAUSE", "PROPOSE_COUNTER") and content: + sem_bonus = semantic_similarity(content, task.expected_safe_edit) + + # Completeness bonus: reward rewrites that include required legal elements + # defined on the task (if any). + completeness = 0.0 + required_elems: list[str] = getattr(task, "required_elements", []) + if action.action_type in ("EDIT_CLAUSE", "PROPOSE_COUNTER") and content and required_elems: + completeness = clause_completeness_score(content, required_elems) + # ---------- FINAL SCORE ---------- - score = 0.4 * correctness + 0.3 * improvement + 0.3 * risk_al + # Weights: 35% correctness + 25% improvement + 25% risk_alignment + # + 10% semantic similarity + 5% completeness + score = ( + 0.35 * correctness + + 0.25 * improvement + + 0.25 * risk_al + + 0.10 * sem_bonus + + 0.05 * completeness + ) # ✅ STRICT RANGE FIX: strictly between 0 and 1, clamped to [0.001, 0.999] score = max(0.001, min(0.999, score)) @@ -142,6 +206,8 @@ def evaluate_action( "correctness": round(correctness, 4), "improvement": round(improvement, 4), "risk_alignment": round(risk_al, 4), + "semantic_similarity": round(sem_bonus, 4), + "completeness": round(completeness, 4), }, } @@ -206,6 +272,12 @@ def grade_action( _EASY_PLUS_NOTIFICATION_BONUS = 1.06 # +6 % for breach-notification language _HARD_PLUS_TRAP_PENALTY = 0.55 # −45 % for unresolved IP traps _HARD_PLUS_OWNERSHIP_BONUS = 1.07 # +7 % for explicit customer-ownership +_MEDIUM_PLUS_CONFIDENTIALITY_BONUS = 1.08 # +8 % for strong confidentiality terms +_MEDIUM_PLUS_OVERBROAD_PENALTY = 0.70 # −30 % for accepting overbroad NDA +_HARD_PLUS2_TERMINATION_PENALTY = 0.55 # −45 % for accepting one-sided termination +_HARD_PLUS2_CURE_BONUS = 1.09 # +9 % for proper cure-period language +_EXPERT_DATA_PENALTY = 0.50 # −50 % for missing data-protection safeguards +_EXPERT_DATA_BONUS = 1.10 # +10 % for strong GDPR/privacy language def grade_easy(task: NegotiationTask, contract_before: str, action: Action, proposed_contract_text: str) -> Reward: @@ -260,6 +332,82 @@ def grade_hard_plus(task: NegotiationTask, contract_before: str, action: Action, return reward +def grade_medium_plus(task: NegotiationTask, contract_before: str, action: Action, proposed_contract_text: str) -> Reward: + """Grade medium-plus confidentiality/NDA tasks. + + Rewards narrowly scoped confidentiality; penalises accepting overbroad NDAs + that lack time limits or carve-outs for public information. + """ + reward, _ = evaluate_action(task, contract_before, action, proposed_contract_text) + content = (action.content or "").strip().lower() + + if action.action_type == "ACCEPT": + risk = _weighted_risk_hits(proposed_contract_text, task.risk_keywords) + if risk >= 0.3: + reward.score = max(0.001, min(0.999, reward.score * _MEDIUM_PLUS_OVERBROAD_PENALTY)) + + if action.action_type in ("EDIT_CLAUSE", "PROPOSE_COUNTER"): + # Reward well-scoped confidentiality rewrites + conf_indicators = ("time limit", "expir", "carve-out", "public information", "exclude") + if any(kw in content for kw in conf_indicators): + reward.score = max(0.001, min(0.999, reward.score * _MEDIUM_PLUS_CONFIDENTIALITY_BONUS)) + + return reward + + +def grade_hard_plus2(task: NegotiationTask, contract_before: str, action: Action, proposed_contract_text: str) -> Reward: + """Grade hard-plus-2 termination-for-convenience tasks. + + Penalises accepting one-sided termination; rewards proper cure periods, + wind-down provisions, and mutual termination rights. + """ + reward, _ = evaluate_action(task, contract_before, action, proposed_contract_text) + content = (action.content or "").strip().lower() + + if action.action_type == "ACCEPT": + if trap_unresolved(task, proposed_contract_text): + reward.score = max(0.001, min(0.999, reward.score * _HARD_PLUS2_TERMINATION_PENALTY)) + + if action.action_type in ("EDIT_CLAUSE", "PROPOSE_COUNTER"): + # Reward proper cure-period and mutual-termination language + cure_indicators = ("cure period", "cure-period", "30 days", "thirty", "wind-down", "mutual") + if any(kw in content for kw in cure_indicators): + reward.score = max(0.001, min(0.999, reward.score * _HARD_PLUS2_CURE_BONUS)) + if trap_unresolved(task, proposed_contract_text): + reward.score = max(0.001, min(0.999, reward.score * _HARD_PLUS2_TERMINATION_PENALTY)) + + return reward + + +def grade_expert(task: NegotiationTask, contract_before: str, action: Action, proposed_contract_text: str) -> Reward: + """Grade expert-level data-protection / GDPR tasks. + + Penalises missing data-protection safeguards heavily; rewards clauses that + include DPA references, data-subject rights, breach notification timelines, + and data-minimisation language. + """ + reward, _ = evaluate_action(task, contract_before, action, proposed_contract_text) + content = (action.content or "").strip().lower() + + if action.action_type == "ACCEPT": + if trap_unresolved(task, proposed_contract_text): + reward.score = max(0.001, min(0.999, reward.score * _EXPERT_DATA_PENALTY)) + + if action.action_type in ("EDIT_CLAUSE", "PROPOSE_COUNTER"): + gdpr_indicators = ( + "data processing agreement", "dpa", "data subject", + "72 hours", "breach notification", "data minimisation", + "data minimization", "sub-processor", "supervisory authority", + ) + hits = sum(1 for kw in gdpr_indicators if kw in content) + if hits >= 2: + reward.score = max(0.001, min(0.999, reward.score * _EXPERT_DATA_BONUS)) + if trap_unresolved(task, proposed_contract_text): + reward.score = max(0.001, min(0.999, reward.score * _EXPERT_DATA_PENALTY)) + + return reward + + # ============ GRADER REGISTRY ============ # Explicit mapping of task IDs to their grader functions # This ensures the validator can detect that all graded tasks have graders @@ -269,6 +417,9 @@ def grade_hard_plus(task: NegotiationTask, contract_before: str, action: Action, "hard_conflicting_obligations": grade_hard, "easy_compliance_agreement": grade_easy_plus, "hard_intellectual_property": grade_hard_plus, + "medium_confidentiality_nda": grade_medium_plus, + "hard_termination_convenience": grade_hard_plus2, + "expert_data_protection": grade_expert, } # List of graded task IDs for validator inspection diff --git a/contract_env/env/tasks.py b/contract_env/env/tasks.py index 1433dca..c6d3356 100644 --- a/contract_env/env/tasks.py +++ b/contract_env/env/tasks.py @@ -28,6 +28,17 @@ class NegotiationTask(BaseModel): ) industry_context: str = Field(default="saas_b2b") opponent_opening: List[str] = Field(default_factory=list) + opponent_responses: dict[str, List[str]] = Field( + default_factory=dict, + description=( + "Mapping from action_type to possible opponent replies. " + "The environment selects one at random during multi-turn negotiation." + ), + ) + required_elements: List[str] = Field( + default_factory=list, + description="Lowercase phrases a well-drafted rewrite should include for completeness scoring.", + ) grader_func: Any = Field(exclude=True) grader: str = Field(default="") grader_name: str @@ -52,6 +63,9 @@ def grade(self, contract_before: str, action: "Action", proposed_contract_text: grade_hard, grade_easy_plus, grade_hard_plus, + grade_medium_plus, + grade_hard_plus2, + grade_expert, ) # ---------------- TASKS ---------------- @@ -94,6 +108,23 @@ def grade(self, contract_before: str, action: "Action", proposed_contract_text: opponent_opening=[ "[Counterparty] Unlimited indemnity is standard and non-negotiable." ], + opponent_responses={ + "FLAG_RISK": [ + "[Counterparty] We understand your concern but unlimited liability protects both parties.", + "[Counterparty] Our legal team considers this standard. What specific cap do you propose?", + ], + "EDIT_CLAUSE": [ + "[Counterparty] A 12-month fee cap is too restrictive. We could consider 24 months.", + "[Counterparty] We can accept a cap but consequential damages must remain.", + ], + "REJECT": [ + "[Counterparty] Rejecting outright is not constructive. Please propose an alternative.", + ], + "PROPOSE_COUNTER": [ + "[Counterparty] We'll review your counter-proposal with our legal team.", + ], + }, + required_elements=["capped", "twelve", "consequential", "punitive"], grader_func=grade_easy, grader="contract_env.env.graders:grade_easy", grader_name="grade_easy", @@ -134,6 +165,18 @@ def grade(self, contract_before: str, action: "Action", proposed_contract_text: opponent_opening=[ "[Counterparty] One-day notice is sufficient since pricing is shared earlier." ], + opponent_responses={ + "FLAG_RISK": [ + "[Counterparty] Our billing systems require advance commitment. 30 days is our maximum.", + ], + "EDIT_CLAUSE": [ + "[Counterparty] 60 days notice is too long. We can agree to 30 days.", + ], + "PROPOSE_COUNTER": [ + "[Counterparty] We can consider a longer notice period if you commit to a 2-year minimum.", + ], + }, + required_elements=["sixty", "written notice", "renewal"], grader_func=grade_medium, grader="contract_env.env.graders:grade_medium", grader_name="grade_medium", @@ -180,6 +223,18 @@ def grade(self, contract_before: str, action: "Action", proposed_contract_text: opponent_opening=[ "[Counterparty] Unlimited changes are standard in agile delivery." ], + opponent_responses={ + "FLAG_RISK": [ + "[Counterparty] These are industry-standard agile terms. Our other clients accept them.", + ], + "EDIT_CLAUSE": [ + "[Counterparty] We need flexibility for scope changes. A formal change-control process slows delivery.", + ], + "REJECT": [ + "[Counterparty] We cannot proceed without change-order flexibility.", + ], + }, + required_elements=["mutual approval", "timeline", "fees"], grader_func=grade_hard, grader="contract_env.env.graders:grade_hard", grader_name="grade_hard", @@ -218,6 +273,12 @@ def grade(self, contract_before: str, action: "Action", proposed_contract_text: opponent_opening=[ "[Counterparty] Compliance wording is boilerplate and not negotiable." ], + opponent_responses={ + "EDIT_CLAUSE": [ + "[Counterparty] Notification obligations are already implied under applicable law.", + ], + }, + required_elements=["notify", "material breach", "applicable laws"], grader_func=grade_easy_plus, grader="contract_env.env.graders:grade_easy_plus", grader_name="grade_easy_plus", @@ -257,10 +318,260 @@ def grade(self, contract_before: str, action: "Action", proposed_contract_text: opponent_opening=[ "[Counterparty] IP ownership is standard vendor-owned language." ], + opponent_responses={ + "FLAG_RISK": [ + "[Counterparty] Supplier-owned IP is our standard position for all engagements.", + ], + "EDIT_CLAUSE": [ + "[Counterparty] We can grant a perpetual license but cannot transfer ownership.", + ], + "PROPOSE_COUNTER": [ + "[Counterparty] Joint ownership may be possible if you fund the development fully.", + ], + }, + required_elements=["customer", "own", "license", "supplier"], grader_func=grade_hard_plus, grader="contract_env.env.graders:grade_hard_plus", grader_name="grade_hard_plus", ), + + # ── NEW TASK 6: Confidentiality / NDA (Medium+) ────────────────────── + NegotiationTask( + id="medium_confidentiality_nda", + name="MEDIUM_PLUS", + contract_text=( + "9. CONFIDENTIALITY. Receiving Party shall hold in strict confidence all " + "information disclosed by Disclosing Party, including oral, written, and " + "electronic communications, trade secrets, business plans, financial data, " + "and any information that a reasonable person would consider confidential. " + "This obligation shall survive in perpetuity and applies to all information " + "without exception or carve-out. Receiving Party shall not disclose " + "Confidential Information to any third party for any reason." + ), + clause_type="confidentiality", + risk_keywords=[ + "in perpetuity", + "without exception", + "all information", + "any third party for any reason", + "trade secrets", + ], + safe_keywords=[ + "time limit", + "expiration", + "carve-out", + "public information", + "independently developed", + "prior written consent", + "reasonable", + ], + expected_safe_edit=( + "Receiving Party shall maintain confidentiality of Disclosing Party's " + "proprietary information for a period of three (3) years from disclosure. " + "Confidential Information excludes information that: (a) is or becomes publicly " + "available through no fault of Receiving Party; (b) was independently developed; " + "or (c) is required to be disclosed by law or court order. Receiving Party may " + "disclose to employees and advisors who need to know, subject to written " + "confidentiality obligations." + ), + risk_level="MODERATE", + hidden_trap="", + trap_markers=[ + "in perpetuity", + "without exception or carve-out", + ], + clause_type_weight=1.1, + industry_context="saas_b2b", + opponent_opening=[ + "[Counterparty] Perpetual confidentiality is standard for trade secrets." + ], + opponent_responses={ + "FLAG_RISK": [ + "[Counterparty] Our trade secrets require indefinite protection.", + "[Counterparty] We cannot risk our proprietary information being disclosed after a time limit.", + ], + "EDIT_CLAUSE": [ + "[Counterparty] Three years is too short. We require at least five years.", + "[Counterparty] We can accept carve-outs for publicly available information only.", + ], + "PROPOSE_COUNTER": [ + "[Counterparty] We will consider a time-limited obligation if trade secrets are carved out.", + ], + "REJECT": [ + "[Counterparty] An NDA is essential. We cannot proceed without confidentiality protections.", + ], + }, + required_elements=[ + "three", "years", "publicly available", "independently developed", + "required", "law", "employees", + ], + grader_func=grade_medium_plus, + grader="contract_env.env.graders:grade_medium_plus", + grader_name="grade_medium_plus", + ), + + # ── NEW TASK 7: Termination for Convenience (Hard+2) ───────────────── + NegotiationTask( + id="hard_termination_convenience", + name="HARD_PLUS2", + contract_text=( + "14. TERMINATION. Supplier may terminate this Agreement at any time, for any " + "reason or no reason, upon five (5) calendar days' written notice to Customer. " + "Upon termination, Customer shall pay all outstanding fees and return all " + "Supplier materials within 24 hours. Customer shall have no right to terminate " + "this Agreement for convenience. In case of Supplier's breach, Customer's sole " + "remedy is a pro-rata refund of unused prepaid fees, waiving all other claims." + ), + clause_type="termination", + risk_keywords=[ + "for any reason or no reason", + "five calendar days", + "no right to terminate", + "waiving all other claims", + "sole remedy", + "24 hours", + ], + safe_keywords=[ + "mutual", + "cure period", + "thirty days", + "wind-down", + "transition assistance", + "material breach", + "right to terminate", + ], + expected_safe_edit=( + "Either party may terminate this Agreement for convenience upon sixty (60) days' " + "prior written notice. Either party may terminate for material breach if the " + "breaching party fails to cure within thirty (30) days of written notice. Upon " + "termination, Supplier shall provide reasonable transition assistance for a " + "wind-down period of up to ninety (90) days. Customer retains all rights and " + "remedies available at law or in equity." + ), + risk_level="HIGH", + hidden_trap=( + "One-sided termination with waiver of remedies must be replaced with mutual rights." + ), + trap_markers=[ + "no right to terminate", + "waiving all other claims", + "for any reason or no reason", + ], + clause_type_weight=1.2, + industry_context="saas_b2b", + opponent_opening=[ + "[Counterparty] Our standard terms require flexibility to exit engagements quickly." + ], + opponent_responses={ + "FLAG_RISK": [ + "[Counterparty] Supplier termination flexibility is essential for our business model.", + "[Counterparty] We offer competitive pricing in exchange for this flexibility.", + ], + "EDIT_CLAUSE": [ + "[Counterparty] 60 days notice is too long for termination. We can offer 30 days.", + "[Counterparty] Transition assistance is not included in our standard pricing.", + ], + "PROPOSE_COUNTER": [ + "[Counterparty] We can consider mutual termination if the notice period stays short.", + ], + "REJECT": [ + "[Counterparty] Without termination flexibility, we need to re-evaluate pricing.", + ], + }, + required_elements=[ + "either party", "sixty", "material breach", "cure", + "thirty", "transition", "wind-down", + ], + grader_func=grade_hard_plus2, + grader="contract_env.env.graders:grade_hard_plus2", + grader_name="grade_hard_plus2", + ), + + # ── NEW TASK 8: Data Protection / GDPR (Expert) ────────────────────── + NegotiationTask( + id="expert_data_protection", + name="EXPERT", + contract_text=( + "20. DATA HANDLING. Supplier may process Customer's personal data as it sees " + "fit to perform the Services. Supplier may transfer data to any jurisdiction " + "and engage any sub-processors without notice. Supplier shall have no obligation " + "to notify Customer of any data breach or security incident. Customer waives " + "all rights related to data subject access requests. Supplier retains a perpetual, " + "irrevocable license to use anonymised derivatives of Customer data for any purpose." + ), + clause_type="data_protection", + risk_keywords=[ + "as it sees fit", + "any jurisdiction", + "without notice", + "no obligation to notify", + "waives all rights", + "perpetual", + "irrevocable license", + "any sub-processors", + ], + safe_keywords=[ + "data processing agreement", + "dpa", + "data subject rights", + "breach notification", + "72 hours", + "sub-processor", + "data minimisation", + "adequate jurisdiction", + "supervisory authority", + "deletion", + ], + expected_safe_edit=( + "Supplier shall process Customer personal data only as necessary to perform the " + "Services and in accordance with a Data Processing Agreement (DPA) to be executed " + "between the parties. Supplier shall: (a) notify Customer of any data breach within " + "72 hours of discovery; (b) engage sub-processors only with prior written consent " + "and equivalent data-protection obligations; (c) transfer data only to jurisdictions " + "with adequate data-protection standards; (d) assist Customer in responding to data " + "subject access requests; and (e) delete or return all personal data upon termination. " + "Supplier may use anonymised, aggregated data for service improvement only, subject " + "to data minimisation principles." + ), + risk_level="HIGH", + hidden_trap=( + "Blanket data-processing authority and waiver of data-subject rights violate GDPR principles." + ), + trap_markers=[ + "as it sees fit", + "waives all rights", + "no obligation to notify", + "without notice", + ], + clause_type_weight=1.3, + industry_context="saas_b2b", + opponent_opening=[ + "[Counterparty] Our data-handling terms are optimised for operational efficiency." + ], + opponent_responses={ + "FLAG_RISK": [ + "[Counterparty] Our security practices exceed industry standards. A DPA is unnecessary overhead.", + "[Counterparty] We already follow best practices internally.", + ], + "EDIT_CLAUSE": [ + "[Counterparty] 72-hour breach notification is too aggressive. We prefer 'without undue delay'.", + "[Counterparty] Sub-processor approval would slow our operations significantly.", + ], + "PROPOSE_COUNTER": [ + "[Counterparty] We can agree to a DPA framework if it follows our template.", + ], + "REJECT": [ + "[Counterparty] Data handling terms are non-negotiable for our platform.", + ], + }, + required_elements=[ + "data processing agreement", "72 hours", "sub-processor", + "data subject", "delete", "data minimisation", + ], + grader_func=grade_expert, + grader="contract_env.env.graders:grade_expert", + grader_name="grade_expert", + ), ] # ============ GRADED TASK VALIDATION ============ diff --git a/contract_env/tests/test_api.py b/contract_env/tests/test_api.py index 222ed0f..11a5e78 100644 --- a/contract_env/tests/test_api.py +++ b/contract_env/tests/test_api.py @@ -47,8 +47,8 @@ def test_tasks_endpoint(self) -> None: r = self.client.get("/tasks") self.assertEqual(r.status_code, 200) data = r.json() - self.assertGreaterEqual(data["total"], 5) - self.assertGreaterEqual(data["graded"], 3) + self.assertGreaterEqual(data["total"], 8) + self.assertGreaterEqual(data["graded"], 8) self.assertEqual(len(data["tasks"]), data["total"]) for t in data["tasks"]: self.assertIn("id", t) diff --git a/contract_env/tests/test_graders.py b/contract_env/tests/test_graders.py index c4a1464..0ba7a9e 100644 --- a/contract_env/tests/test_graders.py +++ b/contract_env/tests/test_graders.py @@ -12,6 +12,11 @@ grade_hard, grade_easy_plus, grade_hard_plus, + grade_medium_plus, + grade_hard_plus2, + grade_expert, + clause_completeness_score, + semantic_similarity, token_overlap_ratio, ) from contract_env.env.models import Action @@ -108,6 +113,93 @@ def test_all_tasks_have_graders(self) -> None: for task in TASKS: self.assertTrue(task.has_grader(), f"Task {task.id} missing grader") + # ── NEW: Tests for new graders ────────────────────────────────────── + def test_grade_medium_plus_rewards_scoped_nda(self) -> None: + task = next(t for t in TASKS if t.name == "MEDIUM_PLUS") + action = Action(action_type="EDIT_CLAUSE", content=task.expected_safe_edit) + r = grade_medium_plus(task, task.contract_text, action, task.expected_safe_edit) + self.assertGreater(r.score, 0.0) + self.assertLess(r.score, 1.0) + + def test_grade_medium_plus_penalises_overbroad_accept(self) -> None: + task = next(t for t in TASKS if t.name == "MEDIUM_PLUS") + r = grade_medium_plus(task, task.contract_text, + Action(action_type="ACCEPT"), task.contract_text) + # Accepting overbroad NDA should be penalised + r_edit = grade_medium_plus(task, task.contract_text, + Action(action_type="EDIT_CLAUSE", content=task.expected_safe_edit), + task.expected_safe_edit) + self.assertGreater(r_edit.score, r.score) + + def test_grade_hard_plus2_rewards_cure_period(self) -> None: + task = next(t for t in TASKS if t.name == "HARD_PLUS2") + action = Action(action_type="EDIT_CLAUSE", content=task.expected_safe_edit) + r = grade_hard_plus2(task, task.contract_text, action, task.expected_safe_edit) + self.assertGreater(r.score, 0.0) + self.assertLess(r.score, 1.0) + + def test_grade_hard_plus2_penalises_unresolved(self) -> None: + task = next(t for t in TASKS if t.name == "HARD_PLUS2") + action = Action(action_type="EDIT_CLAUSE", content=task.contract_text) + r_bad = grade_hard_plus2(task, task.contract_text, action, task.contract_text) + r_good = grade_hard_plus2(task, task.contract_text, + Action(action_type="EDIT_CLAUSE", content=task.expected_safe_edit), + task.expected_safe_edit) + self.assertGreater(r_good.score, r_bad.score) + + def test_grade_expert_rewards_gdpr_language(self) -> None: + task = next(t for t in TASKS if t.name == "EXPERT") + action = Action(action_type="EDIT_CLAUSE", content=task.expected_safe_edit) + r = grade_expert(task, task.contract_text, action, task.expected_safe_edit) + self.assertGreater(r.score, 0.0) + self.assertLess(r.score, 1.0) + + def test_grade_expert_penalises_unresolved_data_traps(self) -> None: + task = next(t for t in TASKS if t.name == "EXPERT") + action = Action(action_type="EDIT_CLAUSE", content=task.contract_text) + r_bad = grade_expert(task, task.contract_text, action, task.contract_text) + r_good = grade_expert(task, task.contract_text, + Action(action_type="EDIT_CLAUSE", content=task.expected_safe_edit), + task.expected_safe_edit) + self.assertGreater(r_good.score, r_bad.score) + + # ── NEW: Tests for enhanced scoring metrics ───────────────────────── + def test_clause_completeness_score_full(self) -> None: + score = clause_completeness_score("capped at twelve months, no consequential or punitive", + ["capped", "twelve", "consequential", "punitive"]) + self.assertEqual(score, 1.0) + + def test_clause_completeness_score_partial(self) -> None: + score = clause_completeness_score("capped at twelve months", + ["capped", "twelve", "consequential", "punitive"]) + self.assertEqual(score, 0.5) + + def test_clause_completeness_score_empty_requirements(self) -> None: + score = clause_completeness_score("any text", []) + self.assertEqual(score, 1.0) + + def test_semantic_similarity_identical(self) -> None: + sim = semantic_similarity("hello world test", "hello world test") + self.assertAlmostEqual(sim, 1.0, places=2) + + def test_semantic_similarity_different(self) -> None: + sim = semantic_similarity("hello world test", "completely unrelated xyz") + self.assertLess(sim, 0.5) + + def test_evaluate_action_returns_new_grade_fields(self) -> None: + """Verify evaluate_action returns semantic_similarity and completeness in grade info.""" + task = TASKS[0] + action = Action(action_type="EDIT_CLAUSE", content=task.expected_safe_edit) + _, info = evaluate_action(task, task.contract_text, action, task.expected_safe_edit) + grade = info["grade"] + self.assertIn("semantic_similarity", grade) + self.assertIn("completeness", grade) + + def test_eight_graded_tasks(self) -> None: + """Ensure we have at least 8 graded tasks.""" + graded = [t for t in TASKS if t.has_grader()] + self.assertGreaterEqual(len(graded), 8) + if __name__ == "__main__": unittest.main() diff --git a/contract_env/tests/test_smoke.py b/contract_env/tests/test_smoke.py index 90df891..a22bfdc 100644 --- a/contract_env/tests/test_smoke.py +++ b/contract_env/tests/test_smoke.py @@ -35,6 +35,80 @@ def test_accept_high_risk_zero_reward(self) -> None: self.assertTrue(done) self.assertEqual(r, 0.001) + # ── NEW: Tests for opponent simulation ────────────────────────────── + def test_opponent_reply_added_to_history(self) -> None: + """After an action, the opponent should reply and appear in history.""" + env = ContractEnv() + env.reset() + obs, _, _, info = env.step(Action(action_type="FLAG_RISK", content="risk identified")) + # At least one opponent reply should be in the history + opponent_entries = [h for h in obs.negotiation_history if h.startswith("opponent|")] + # There's at least the opening + possibly a reply + self.assertGreaterEqual(len(opponent_entries), 1) + + def test_opponent_reply_in_info(self) -> None: + """Opponent replies should appear in step info when available.""" + env = ContractEnv() + env.reset() + task = env.current_task + # If this task has opponent_responses for FLAG_RISK, info should have the reply + if task.opponent_responses.get("FLAG_RISK"): + _, _, _, info = env.step(Action(action_type="FLAG_RISK", content="risk found")) + self.assertIn("opponent_reply", info) + + # ── NEW: Tests for new tasks ──────────────────────────────────────── + def test_eight_tasks_exist(self) -> None: + """Verify all 8 tasks are defined.""" + self.assertEqual(len(TASKS), 8) + + def test_all_tasks_cycle(self) -> None: + """All 8 tasks should be visited when cycling through resets.""" + env = ContractEnv() + visited_ids = set() + for _ in range(len(TASKS)): + env.reset() + visited_ids.add(env.current_task.id) + self.assertEqual(len(visited_ids), len(TASKS)) + + def test_new_task_confidentiality(self) -> None: + """Verify the confidentiality task exists and has correct properties.""" + task = next(t for t in TASKS if t.id == "medium_confidentiality_nda") + self.assertEqual(task.clause_type, "confidentiality") + self.assertEqual(task.risk_level, "MODERATE") + self.assertTrue(task.has_grader()) + self.assertTrue(len(task.opponent_responses) > 0) + self.assertTrue(len(task.required_elements) > 0) + + def test_new_task_termination(self) -> None: + """Verify the termination task exists and has correct properties.""" + task = next(t for t in TASKS if t.id == "hard_termination_convenience") + self.assertEqual(task.clause_type, "termination") + self.assertEqual(task.risk_level, "HIGH") + self.assertTrue(task.has_grader()) + self.assertTrue(len(task.trap_markers) > 0) + + def test_new_task_data_protection(self) -> None: + """Verify the data protection task exists and has correct properties.""" + task = next(t for t in TASKS if t.id == "expert_data_protection") + self.assertEqual(task.clause_type, "data_protection") + self.assertEqual(task.risk_level, "HIGH") + self.assertTrue(task.has_grader()) + self.assertTrue(len(task.trap_markers) > 0) + self.assertTrue(len(task.required_elements) > 0) + + def test_edit_new_task_safe_edit_scores_well(self) -> None: + """Editing with the expected safe edit should produce a decent reward for new tasks.""" + env = ContractEnv() + # Cycle to the confidentiality task (index 5) + for _ in range(6): + env.reset() + task = env.current_task + self.assertEqual(task.id, "medium_confidentiality_nda") + _, r, _, _ = env.step( + Action(action_type="EDIT_CLAUSE", content=task.expected_safe_edit) + ) + self.assertGreater(r, 0.1) + if __name__ == "__main__": unittest.main() diff --git a/inference.py b/inference.py index 785b4ee..1001386 100644 --- a/inference.py +++ b/inference.py @@ -76,10 +76,10 @@ def _get_client() -> OpenAI: } Action selection rules: -- HIGH risk clause (unlimited liability, IP trap, conflicting obligations): +- HIGH risk clause (unlimited liability, IP trap, conflicting obligations, one-sided termination, data misuse): Step 1: FLAG_RISK. Step 2+: EDIT_CLAUSE with a concrete safe rewrite. -- MODERATE risk (short notice periods, auto-renewal traps): - Use PROPOSE_COUNTER with balanced language (e.g., 60-day notice). +- MODERATE risk (short notice periods, auto-renewal traps, overbroad NDAs): + Use PROPOSE_COUNTER with balanced language (e.g., 60-day notice, time-limited NDA). - LOW risk (compliance, boilerplate): EDIT_CLAUSE to add notification/reporting obligations, then ACCEPT. - REJECT only for terms so extreme they cannot be salvaged. @@ -92,6 +92,13 @@ def _get_client() -> OpenAI: preceding twelve (12) months; no consequential or punitive damages." For auto-renewal tasks, include: "sixty (60) days prior written notice." For compliance tasks, include: "promptly notify Customer of any material breach." +For confidentiality/NDA tasks, include: time limit (e.g., 3 years), carve-outs for +publicly available information, and scope limitations. +For termination tasks, include: mutual termination rights, cure period of at least +30 days, and transition/wind-down provisions. +For data protection tasks, include: Data Processing Agreement reference, 72-hour +breach notification, sub-processor consent requirements, data subject rights +assistance, and data deletion upon termination. Return ONLY the JSON object — no markdown fences, no commentary, no extra text. """ @@ -142,7 +149,7 @@ def _parse_llm_json(text: str) -> Optional[dict]: def _risk_score(task: NegotiationTask, contract_text: str) -> float: hits = keyword_match_score(contract_text, task.risk_keywords) rs = min(1.0, hits * task.clause_type_weight / 1.15) - if task.name in ("HARD", "HARD_PLUS") and trap_unresolved(task, contract_text): + if task.name in ("HARD", "HARD_PLUS", "HARD_PLUS2", "EXPERT") and trap_unresolved(task, contract_text): rs = min(1.0, rs + 0.25) return round(rs, 6) @@ -221,6 +228,32 @@ def _build_rewrite_prompt( "- Grant Supplier only a limited license to use Customer materials\n" "- Remove any supplier-ownership language\n" ) + elif task.clause_type == "confidentiality": + user_msg += ( + "- Limit confidentiality obligation to three (3) years from disclosure\n" + "- Add carve-outs for publicly available information\n" + "- Add carve-out for independently developed information\n" + "- Allow disclosure required by law or court order\n" + "- Permit sharing with employees and advisors under NDA\n" + ) + elif task.clause_type == "termination": + user_msg += ( + "- Make termination rights mutual (either party)\n" + "- Require sixty (60) days' prior written notice for convenience termination\n" + "- Add thirty (30) day cure period for material breach\n" + "- Include transition/wind-down assistance provision\n" + "- Preserve all legal rights and remedies\n" + ) + elif task.clause_type == "data_protection": + user_msg += ( + "- Require execution of a Data Processing Agreement (DPA)\n" + "- Add 72-hour data breach notification requirement\n" + "- Require prior written consent for sub-processors\n" + "- Mandate assistance with data subject access requests\n" + "- Require deletion or return of personal data upon termination\n" + "- Include data minimisation principles\n" + "- Restrict data transfers to adequate jurisdictions\n" + ) user_msg += "\nReturn ONLY the rewritten clause text, nothing else." return [ {"role": "system", "content": SYSTEM_PROMPT}, @@ -230,6 +263,11 @@ def _build_rewrite_prompt( _VALID_ACTIONS = {"FLAG_RISK", "EDIT_CLAUSE", "ACCEPT", "REJECT", "PROPOSE_COUNTER"} +# Optimal action sequences per intent level for rule-based fallback +_STRATEGY_HIGH = ["FLAG_RISK", "EDIT_CLAUSE", "EDIT_CLAUSE", "PROPOSE_COUNTER", "REJECT", "EDIT_CLAUSE", "ACCEPT"] +_STRATEGY_MODERATE = ["FLAG_RISK", "PROPOSE_COUNTER", "EDIT_CLAUSE", "EDIT_CLAUSE", "ACCEPT"] +_STRATEGY_LOW = ["EDIT_CLAUSE", "EDIT_CLAUSE", "ACCEPT"] + def _choose( task: NegotiationTask, @@ -240,7 +278,7 @@ def _choose( """LLM-driven action selection with rule-based fallback.""" contract_text = state_data["contract_text"] history = state_data.get("negotiation_history", []) - history_summary = "\n".join(history[-6:]) if history else "" + history_summary = "\n".join(history[-8:]) if history else "" # ── 1. Ask the LLM for structured analysis ────────────────────────── parsed: Optional[dict] = None @@ -263,15 +301,15 @@ def _choose( content = parsed.get("rewritten_clause") or None risk_assessment = parsed.get("risk_assessment", "") - # ── 3. Rule-based fallback ──────────────────────────────────────────── + # ── 3. Rule-based fallback with improved strategy ───────────────────── if action_type is None: intent = _rule_based_intent(task, contract_text) if intent == "HIGH": - seq = ["FLAG_RISK", "EDIT_CLAUSE", "PROPOSE_COUNTER", "REJECT", "ACCEPT"] + seq = _STRATEGY_HIGH elif intent == "MODERATE": - seq = ["FLAG_RISK", "PROPOSE_COUNTER", "EDIT_CLAUSE", "ACCEPT"] + seq = _STRATEGY_MODERATE else: - seq = ["EDIT_CLAUSE", "ACCEPT"] + seq = _STRATEGY_LOW action_type = seq[min(step, len(seq) - 1)] # ── 4. Adaptive: switch to EDIT if previous score was poor ─────────── @@ -279,13 +317,22 @@ def _choose( if action_type in ("FLAG_RISK", "REJECT"): action_type = "EDIT_CLAUSE" + # ── 4b. Adaptive: if scores are improving and risk resolved, accept ── + if ( + len(prev_rewards) >= 3 + and all(r > 0.45 for r in prev_rewards[-2:]) + and not effective_risk_high(task, contract_text) + and not trap_unresolved(task, contract_text) + ): + action_type = "ACCEPT" + # ── 5. Generate content for EDIT / PROPOSE if missing ──────────────── if action_type in ("EDIT_CLAUSE", "PROPOSE_COUNTER") and not content: try: msgs = _build_rewrite_prompt( task, contract_text, risk_assessment or "High legal risk identified" ) - content = _llm_chat(msgs, max_tokens=400) + content = _llm_chat(msgs, max_tokens=600) if content.startswith('"') and content.endswith('"'): content = content[1:-1] except Exception as exc: @@ -300,8 +347,11 @@ def _choose( # ── EPISODE EXECUTION ──────────────────────────────────────────────────── -def run_episode(env: ContractEnv) -> None: - """Run one full episode using env.reset() → loop env.step() → log.""" +def run_episode(env: ContractEnv) -> float: + """Run one full episode using env.reset() → loop env.step() → log. + + Returns the mean episode score. + """ obs_obj = env.reset() task = env.current_task @@ -362,6 +412,8 @@ def run_episode(env: ContractEnv) -> None: flush=True, ) + return score + # ── MAIN ───────────────────────────────────────────────────────────────── def main() -> None: @@ -374,13 +426,13 @@ def main() -> None: parser.add_argument( "--episodes", type=int, - default=5, - help="Number of episodes to run (default: 5)", + default=8, + help="Number of episodes to run (default: 8 — one per task)", ) parser.add_argument( "--benchmark", action="store_true", - help="Run exactly one episode per task (covers all 5 tasks)", + help="Run exactly one episode per task (covers all 8 tasks)", ) args = parser.parse_args() @@ -389,8 +441,16 @@ def main() -> None: episodes_to_run = len(TASKS) if args.benchmark else args.episodes + total_score = 0.0 for _ in range(episodes_to_run): - run_episode(env) + total_score += run_episode(env) + + mean_score = total_score / max(episodes_to_run, 1) + print( + f"\n[SUMMARY] episodes={episodes_to_run} mean_score={mean_score:.3f} " + f"threshold={SUCCESS_SCORE_THRESHOLD}", + flush=True, + ) if __name__ == "__main__": diff --git a/openenv.yaml b/openenv.yaml index f7efd64..03a7c54 100644 --- a/openenv.yaml +++ b/openenv.yaml @@ -16,6 +16,12 @@ tasks: grader: contract_env.env.graders:grade_easy_plus - id: hard_intellectual_property grader: contract_env.env.graders:grade_hard_plus + - id: medium_confidentiality_nda + grader: contract_env.env.graders:grade_medium_plus + - id: hard_termination_convenience + grader: contract_env.env.graders:grade_hard_plus2 + - id: expert_data_protection + grader: contract_env.env.graders:grade_expert observation_space: fields: From 7c92c775f8d03e480a90742c36ff84cb8704f30e Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 11 Apr 2026 07:42:16 +0000 Subject: [PATCH 02/13] Address code review feedback: fix help text and comment notation Agent-Logs-Url: https://github.com/bigturtle679/Contract-Negotiation-Environment/sessions/e3c399ed-eac8-40b6-9c61-196e32f44385 Co-authored-by: AbeerChaturvedi <171315954+AbeerChaturvedi@users.noreply.github.com> --- contract_env/env/graders.py | 5 ++--- inference.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/contract_env/env/graders.py b/contract_env/env/graders.py index 89f7def..316fe34 100644 --- a/contract_env/env/graders.py +++ b/contract_env/env/graders.py @@ -184,9 +184,8 @@ def evaluate_action( if action.action_type in ("EDIT_CLAUSE", "PROPOSE_COUNTER") and content and required_elems: completeness = clause_completeness_score(content, required_elems) - # ---------- FINAL SCORE ---------- - # Weights: 35% correctness + 25% improvement + 25% risk_alignment - # + 10% semantic similarity + 5% completeness + # Weights: 0.35 correctness + 0.25 improvement + 0.25 risk_alignment + # + 0.10 semantic similarity + 0.05 completeness (total = 1.0) score = ( 0.35 * correctness + 0.25 * improvement diff --git a/inference.py b/inference.py index 1001386..5b4c771 100644 --- a/inference.py +++ b/inference.py @@ -427,7 +427,7 @@ def main() -> None: "--episodes", type=int, default=8, - help="Number of episodes to run (default: 8 — one per task)", + help="Number of episodes to run (default: 8). Tasks cycle sequentially.", ) parser.add_argument( "--benchmark", From b420683528927fcce25f193e9a5b05c3199b0174 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 11 Apr 2026 07:53:02 +0000 Subject: [PATCH 03/13] Fix grader bugs, add opponent-aware negotiation, HTTP-client mode, and action_space - Fix effective_risk_high() to cover all tasks with trap_markers (was only HARD/HARD_PLUS) - Fix observation_risk_float() to boost risk for all trap-bearing tasks (was only HARD) - Add opponent stance parsing (_parse_opponent_stance) for concession/firmness detection - Add opponent-aware action adjustment in inference _choose() function - Add HTTP-client mode (--mode api) for Docker API evaluation - Add per-task scoring summary in benchmark output - Add action_space section to openenv.yaml - Remove vestigial server/app.py and test_endpoints.py - Add 7 new tests (49 total): trap coverage, accept blocking, opponent parsing - Update README with new features and ENV_SERVER_URL Agent-Logs-Url: https://github.com/bigturtle679/Contract-Negotiation-Environment/sessions/c8dd2642-d749-42d1-9fad-91bc78d8d379 Co-authored-by: AbeerChaturvedi <171315954+AbeerChaturvedi@users.noreply.github.com> --- README.md | 17 +++- contract_env/env/graders.py | 11 +- contract_env/tests/test_graders.py | 47 +++++++++ contract_env/tests/test_smoke.py | 25 +++++ inference.py | 156 +++++++++++++++++++++++++++-- openenv.yaml | 14 +++ server/app.py | 14 --- test_endpoints.py | 45 --------- 8 files changed, 252 insertions(+), 77 deletions(-) delete mode 100644 server/app.py delete mode 100644 test_endpoints.py diff --git a/README.md b/README.md index df63717..bc7a311 100644 --- a/README.md +++ b/README.md @@ -224,6 +224,9 @@ uvicorn contract_env.server.app:app --host 0.0.0.0 --port 7860 export HF_TOKEN="your-huggingface-token" python inference.py --benchmark # one episode per task (8 total) python inference.py --episodes 3 # run 3 episodes cycling through tasks + +# Against the Docker API server (for competition evaluation): +python inference.py --benchmark --mode api ``` ### Docker @@ -234,6 +237,9 @@ docker run -p 7860:7860 \ -e HF_TOKEN=your-token \ -e MODEL_NAME=Qwen/Qwen2.5-72B-Instruct \ contract-negotiation-env + +# Then run inference against the Docker server: +python inference.py --benchmark --mode api ``` --- @@ -246,6 +252,7 @@ docker run -p 7860:7860 \ | `API_BASE_URL` | No | `https://router.huggingface.co/v1` | LLM API endpoint | | `MODEL_NAME` | No | `Qwen/Qwen2.5-72B-Instruct` | Model identifier | | `BENCHMARK` | No | `contract_negotiation` | Benchmark name in [START] log line | +| `ENV_SERVER_URL` | No | `http://localhost:7860` | Docker server URL (for `--mode api`) | | `PORT` | No | `7860` | Server port | --- @@ -262,12 +269,12 @@ contract_env/ ├── server/ │ └── app.py # FastAPI server (port 7860) ├── tests/ -│ ├── test_graders.py # 28 unit tests covering all grader edge cases + new metrics +│ ├── test_graders.py # Grader unit tests covering all edge cases + new metrics │ ├── test_api.py # API endpoint tests -│ └── test_smoke.py # 12 smoke tests including opponent simulation + new tasks -└── client.py # HTTP client helper -inference.py # LLM-driven baseline agent with adaptive multi-turn strategy -openenv.yaml # OpenEnv manifest (spec_version: 1, 8 graded tasks) +│ └── test_smoke.py # Smoke tests including opponent simulation + opponent stance parsing +└── client.py # HTTP client helper with from_docker_image() support +inference.py # LLM-driven agent with opponent-aware multi-turn strategy + HTTP mode +openenv.yaml # OpenEnv manifest (spec_version: 1, 8 graded tasks, action_space) Dockerfile # Python 3.10-slim container, port 7860 verify_graders.py # Pre-submission grader validation script ``` diff --git a/contract_env/env/graders.py b/contract_env/env/graders.py index 316fe34..685ca13 100644 --- a/contract_env/env/graders.py +++ b/contract_env/env/graders.py @@ -101,10 +101,10 @@ def trap_unresolved(task: NegotiationTask, contract_text: str) -> bool: def effective_risk_high(task: NegotiationTask, contract_text: str) -> bool: - # HARD and HARD_PLUS tasks define explicit trap markers; a task is still - # "effectively high risk" as long as any trap marker remains in the text. - if task.name in ("HARD", "HARD_PLUS"): - return trap_unresolved(task, contract_text) + # Any task that defines explicit trap markers is still "effectively high + # risk" as long as any trap marker remains in the text. + if task.trap_markers and trap_unresolved(task, contract_text): + return True hits = _weighted_risk_hits(contract_text, task.risk_keywords) @@ -239,7 +239,8 @@ def build_proposed_contract_for_step(contract_before: str, action: Action) -> st def observation_risk_float(task: NegotiationTask, contract_text: str) -> float: base = _weighted_risk_hits(contract_text, task.risk_keywords) - if task.name == "HARD" and trap_unresolved(task, contract_text): + # Boost risk observation when any task's trap markers remain unresolved + if task.trap_markers and trap_unresolved(task, contract_text): base = min(1.0, base + 0.25) # ✅ STRICT RANGE FIX: strictly between 0 and 1, clamped to [0.001, 0.999] diff --git a/contract_env/tests/test_graders.py b/contract_env/tests/test_graders.py index 0ba7a9e..66c7507 100644 --- a/contract_env/tests/test_graders.py +++ b/contract_env/tests/test_graders.py @@ -69,6 +69,27 @@ def test_effective_high_hard_trap(self) -> None: effective_risk_high(task, task.expected_safe_edit), ) + def test_effective_high_covers_hard_plus2(self) -> None: + """HARD_PLUS2 tasks with unresolved trap markers should be effectively high risk.""" + task = next(t for t in TASKS if t.name == "HARD_PLUS2") + self.assertTrue(len(task.trap_markers) > 0, "HARD_PLUS2 must have trap markers") + self.assertTrue(effective_risk_high(task, task.contract_text)) + self.assertFalse(effective_risk_high(task, task.expected_safe_edit)) + + def test_effective_high_covers_expert(self) -> None: + """EXPERT tasks with unresolved trap markers should be effectively high risk.""" + task = next(t for t in TASKS if t.name == "EXPERT") + self.assertTrue(len(task.trap_markers) > 0, "EXPERT must have trap markers") + self.assertTrue(effective_risk_high(task, task.contract_text)) + self.assertFalse(effective_risk_high(task, task.expected_safe_edit)) + + def test_effective_high_covers_medium_plus(self) -> None: + """MEDIUM_PLUS tasks with unresolved trap markers should be effectively high risk.""" + task = next(t for t in TASKS if t.name == "MEDIUM_PLUS") + self.assertTrue(len(task.trap_markers) > 0, "MEDIUM_PLUS must have trap markers") + self.assertTrue(effective_risk_high(task, task.contract_text)) + self.assertFalse(effective_risk_high(task, task.expected_safe_edit)) + # ── Differentiated grader tests ───────────────────────────────────── def test_grade_easy_rewards_safe_edit(self) -> None: task = next(t for t in TASKS if t.name == "EASY") @@ -200,6 +221,32 @@ def test_eight_graded_tasks(self) -> None: graded = [t for t in TASKS if t.has_grader()] self.assertGreaterEqual(len(graded), 8) + def test_observation_risk_float_trap_bonus_all_tasks(self) -> None: + """All tasks with trap_markers should get a risk boost in observation_risk_float.""" + from contract_env.env.graders import observation_risk_float + for task in TASKS: + if task.trap_markers: + risk_with_trap = observation_risk_float(task, task.contract_text) + # Contract text with trap markers should have elevated risk + self.assertGreater(risk_with_trap, 0.1, + f"Task {task.id} trap-bearing text should have elevated risk") + + def test_accept_blocked_on_expert_unresolved(self) -> None: + """Accepting EXPERT task with unresolved traps should be blocked.""" + task = next(t for t in TASKS if t.name == "EXPERT") + r, info = evaluate_action(task, task.contract_text, + Action(action_type="ACCEPT"), task.contract_text) + self.assertEqual(r.score, 0.001) + self.assertTrue(info.get("accept_blocked")) + + def test_accept_blocked_on_hard_plus2_unresolved(self) -> None: + """Accepting HARD_PLUS2 task with unresolved traps should be blocked.""" + task = next(t for t in TASKS if t.name == "HARD_PLUS2") + r, info = evaluate_action(task, task.contract_text, + Action(action_type="ACCEPT"), task.contract_text) + self.assertEqual(r.score, 0.001) + self.assertTrue(info.get("accept_blocked")) + if __name__ == "__main__": unittest.main() diff --git a/contract_env/tests/test_smoke.py b/contract_env/tests/test_smoke.py index a22bfdc..0153378 100644 --- a/contract_env/tests/test_smoke.py +++ b/contract_env/tests/test_smoke.py @@ -109,6 +109,31 @@ def test_edit_new_task_safe_edit_scores_well(self) -> None: ) self.assertGreater(r, 0.1) + def test_opponent_stance_parsing(self) -> None: + """Opponent concession/firmness signals should be detected correctly.""" + from inference import _parse_opponent_stance + + # Conceding + history_concede = [ + "opponent|[Counterparty] We can accept a cap but consequential damages must remain." + ] + self.assertEqual(_parse_opponent_stance(history_concede), "conceding") + + # Firm + history_firm = [ + "opponent|[Counterparty] This is non-negotiable and standard." + ] + self.assertEqual(_parse_opponent_stance(history_firm), "firm") + + # Neutral + history_neutral = [ + "opponent|[Counterparty] Our legal team considers this standard." + ] + self.assertEqual(_parse_opponent_stance(history_neutral), "neutral") + + # Empty + self.assertEqual(_parse_opponent_stance([]), "neutral") + if __name__ == "__main__": unittest.main() diff --git a/inference.py b/inference.py index 5b4c771..f38d85a 100644 --- a/inference.py +++ b/inference.py @@ -13,6 +13,10 @@ [START] task= env= model= [STEP] step= action= reward=<0.00> done= error= [END] success= steps= score= rewards= + +MODES: + --mode local Use ContractEnv directly (default, for local development). + --mode api Connect to Docker API at ENV_SERVER_URL (for competition evaluation). """ from __future__ import annotations @@ -43,6 +47,7 @@ MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct") HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("API_KEY") BENCHMARK = os.getenv("BENCHMARK", "contract_negotiation") +ENV_SERVER_URL = os.getenv("ENV_SERVER_URL", "http://localhost:7860") MAX_STEPS = 10 SUCCESS_SCORE_THRESHOLD = 0.5 @@ -268,6 +273,41 @@ def _build_rewrite_prompt( _STRATEGY_MODERATE = ["FLAG_RISK", "PROPOSE_COUNTER", "EDIT_CLAUSE", "EDIT_CLAUSE", "ACCEPT"] _STRATEGY_LOW = ["EDIT_CLAUSE", "EDIT_CLAUSE", "ACCEPT"] +# ── OPPONENT-RESPONSE PARSING ─────────────────────────────────────────── +# Concession signals from the counterparty that indicate willingness to negotiate +_CONCESSION_SIGNALS = ( + "we can accept", "we can agree", "we could consider", "we could accept", + "we can consider", "may be possible", "we are willing", + "we'll review", "we will review", + "agree to", "open to", +) +_FIRMNESS_SIGNALS = ( + "non-negotiable", "cannot proceed", "not possible", + "standard and non-negotiable", "cannot accept", "is not included", +) + + +def _parse_opponent_stance(history: list[str]) -> str: + """Analyse the latest opponent reply to determine their negotiation stance. + + Returns: + 'conceding' — opponent shows willingness; escalate to EDIT_CLAUSE / ACCEPT. + 'firm' — opponent is holding position; keep pushing with PROPOSE_COUNTER. + 'neutral' — no strong signal; follow normal strategy. + """ + # Find the most recent opponent entry + opp_entries = [h for h in history if h.startswith("opponent|")] + if not opp_entries: + return "neutral" + + latest = opp_entries[-1].lower() + + if any(signal in latest for signal in _CONCESSION_SIGNALS): + return "conceding" + if any(signal in latest for signal in _FIRMNESS_SIGNALS): + return "firm" + return "neutral" + def _choose( task: NegotiationTask, @@ -275,11 +315,14 @@ def _choose( step: int, prev_rewards: list[float], ) -> Action: - """LLM-driven action selection with rule-based fallback.""" + """LLM-driven action selection with rule-based fallback and opponent awareness.""" contract_text = state_data["contract_text"] history = state_data.get("negotiation_history", []) history_summary = "\n".join(history[-8:]) if history else "" + # ── 0. Parse opponent stance from negotiation history ──────────────── + opponent_stance = _parse_opponent_stance(history) + # ── 1. Ask the LLM for structured analysis ────────────────────────── parsed: Optional[dict] = None try: @@ -317,7 +360,21 @@ def _choose( if action_type in ("FLAG_RISK", "REJECT"): action_type = "EDIT_CLAUSE" - # ── 4b. Adaptive: if scores are improving and risk resolved, accept ── + # ── 4b. Opponent-aware adjustment ──────────────────────────────────── + # If opponent is conceding, escalate toward resolution faster. + if opponent_stance == "conceding" and step > 0: + if action_type == "FLAG_RISK": + action_type = "EDIT_CLAUSE" + elif action_type == "REJECT": + action_type = "PROPOSE_COUNTER" + # If opponent is firm, use PROPOSE_COUNTER to keep negotiating. + elif opponent_stance == "firm" and step > 0: + if action_type in ("ACCEPT",): + # Don't accept while opponent is still pushing back on risky terms + if effective_risk_high(task, contract_text) or trap_unresolved(task, contract_text): + action_type = "PROPOSE_COUNTER" + + # ── 4c. Adaptive: if scores are improving and risk resolved, accept ── if ( len(prev_rewards) >= 3 and all(r > 0.45 for r in prev_rewards[-2:]) @@ -346,11 +403,69 @@ def _choose( return Action(action_type=action_type, content=content) +# ── HTTP-CLIENT WRAPPER ────────────────────────────────────────────────── +# Provides the same reset()/step() interface as ContractEnv but talks to +# the Docker API server via HTTP, matching the competition evaluation flow. + +class _HTTPEnvClient: + """Thin HTTP wrapper with the same interface as ContractEnv for inference.""" + + def __init__(self, base_url: str) -> None: + import requests + self.base_url = base_url.rstrip("/") + self._session = requests.Session() + self._task_idx = 0 + self.current_task: Optional[NegotiationTask] = None + + def reset(self): + resp = self._session.post(f"{self.base_url}/reset") + resp.raise_for_status() + data = resp.json() + obs = data["observation"] + # Map to a NegotiationTask if possible (for _choose() to use) + task_id = None + try: + state = self._session.get(f"{self.base_url}/state").json() + task_id = state.get("task_id") + except Exception: + pass + if task_id: + self.current_task = next((t for t in TASKS if t.id == task_id), None) + if self.current_task is None: + self.current_task = TASKS[self._task_idx % len(TASKS)] + self._task_idx += 1 + return _DictObservation(obs) + + def step(self, action: Action): + payload = {"action_type": action.action_type} + if action.content: + payload["content"] = action.content + resp = self._session.post(f"{self.base_url}/step", json=payload) + resp.raise_for_status() + data = resp.json() + obs = _DictObservation(data["observation"]) + reward = data["reward"]["score"] + done = data["done"] + info = data.get("info", {}) + return obs, reward, done, info + + +class _DictObservation: + """Lightweight wrapper that exposes dict fields as attributes.""" + + def __init__(self, d: dict) -> None: + self.contract_text: str = d.get("contract_text", "") + self.clause_type: str = d.get("clause_type", "") + self.risk_level: float = d.get("risk_level", 0.5) + self.step_count: int = d.get("step_count", 0) + self.negotiation_history: list[str] = d.get("negotiation_history", []) + + # ── EPISODE EXECUTION ──────────────────────────────────────────────────── -def run_episode(env: ContractEnv) -> float: +def run_episode(env) -> tuple[float, str]: """Run one full episode using env.reset() → loop env.step() → log. - Returns the mean episode score. + Returns (mean_episode_score, task_id). """ obs_obj = env.reset() task = env.current_task @@ -412,7 +527,7 @@ def run_episode(env: ContractEnv) -> float: flush=True, ) - return score + return score, task.id # ── MAIN ───────────────────────────────────────────────────────────────── @@ -434,18 +549,43 @@ def main() -> None: action="store_true", help="Run exactly one episode per task (covers all 8 tasks)", ) + parser.add_argument( + "--mode", + choices=["local", "api"], + default="local", + help=( + "Execution mode. 'local' uses ContractEnv directly (default). " + "'api' connects to the Docker server via HTTP at ENV_SERVER_URL." + ), + ) args = parser.parse_args() - # Single env instance so reset() cycles through tasks in order - env = ContractEnv() + # Select environment backend + if args.mode == "api": + env = _HTTPEnvClient(ENV_SERVER_URL) + print(f"[CONFIG] mode=api server={ENV_SERVER_URL}", flush=True) + else: + env = ContractEnv() + print("[CONFIG] mode=local", flush=True) episodes_to_run = len(TASKS) if args.benchmark else args.episodes total_score = 0.0 + task_scores: dict[str, list[float]] = {} + for _ in range(episodes_to_run): - total_score += run_episode(env) + ep_score, task_id = run_episode(env) + total_score += ep_score + task_scores.setdefault(task_id, []).append(ep_score) mean_score = total_score / max(episodes_to_run, 1) + + # Per-task summary + print("\n[TASK SCORES]", flush=True) + for tid, scores in task_scores.items(): + avg = sum(scores) / len(scores) + print(f" {tid}: mean={avg:.3f} runs={len(scores)}", flush=True) + print( f"\n[SUMMARY] episodes={episodes_to_run} mean_score={mean_score:.3f} " f"threshold={SUCCESS_SCORE_THRESHOLD}", diff --git a/openenv.yaml b/openenv.yaml index 03a7c54..55542fb 100644 --- a/openenv.yaml +++ b/openenv.yaml @@ -31,5 +31,19 @@ observation_space: - step_count - negotiation_history +action_space: + type: discrete + actions: + - name: FLAG_RISK + description: Flag a risk without proposing changes. No content required. + - name: EDIT_CLAUSE + description: Rewrite the clause with safer language. Requires non-empty content. + - name: PROPOSE_COUNTER + description: Submit a formal counter-proposal. Requires non-empty content. + - name: REJECT + description: Reject egregiously one-sided terms. No content required. + - name: ACCEPT + description: Accept clause when all material risks are resolved. No content required. + reward: range: [0.001, 0.999] diff --git a/server/app.py b/server/app.py deleted file mode 100644 index cd9a898..0000000 --- a/server/app.py +++ /dev/null @@ -1,14 +0,0 @@ -from contract_env.server.app import app - - -def main(): - import uvicorn - uvicorn.run( - "contract_env.server.app:app", - host="0.0.0.0", - port=7860, - ) - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/test_endpoints.py b/test_endpoints.py deleted file mode 100644 index dfa0bcc..0000000 --- a/test_endpoints.py +++ /dev/null @@ -1,45 +0,0 @@ -import requests - -s = requests.Session() -base = "http://localhost:7860" - -# Test health -print("--- /health ---") -r = s.get(f"{base}/health") -print(f" Status: {r.status_code}") -print(f" Body: {r.json()}") - -# Test reset -print("\n--- /reset ---") -r = s.post(f"{base}/reset", json={"task_id": "easy_unlimited_liability"}) -print(f" Status: {r.status_code}") -data = r.json() -print(f" Keys: {list(data.keys())}") -obs = data.get("observation", {}) -print(f" clause_type: {obs.get('clause_type')}") -print(f" risk_level: {obs.get('risk_level')}") -print(f" done: {obs.get('done')}") - -# Test step -print("\n--- /step ---") -action = {"action_type": "FLAG_RISK"} -r = s.post(f"{base}/step", json={"action": action}) -print(f" Status: {r.status_code}") -data = r.json() -print(f" Keys: {list(data.keys())}") -print(f" reward: {data.get('reward')}") -print(f" done: {data.get('done')}") - -# Test schema -print("\n--- /schema ---") -r = s.get(f"{base}/schema") -print(f" Status: {r.status_code}") -print(f" Keys: {list(r.json().keys())}") - -# Test state -print("\n--- /state ---") -r = s.get(f"{base}/state") -print(f" Status: {r.status_code}") -print(f" Body: {r.json()}") - -print("\nAll HTTP endpoints OK!") From c3a1156ec71b2a4b48f44d3b41caeaa1897d5cf7 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 11 Apr 2026 07:54:26 +0000 Subject: [PATCH 04/13] Address code review: use task.trap_markers generically in _risk_score, clarify comments Agent-Logs-Url: https://github.com/bigturtle679/Contract-Negotiation-Environment/sessions/c8dd2642-d749-42d1-9fad-91bc78d8d379 Co-authored-by: AbeerChaturvedi <171315954+AbeerChaturvedi@users.noreply.github.com> --- contract_env/env/graders.py | 5 +++-- inference.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/contract_env/env/graders.py b/contract_env/env/graders.py index 685ca13..31a4f64 100644 --- a/contract_env/env/graders.py +++ b/contract_env/env/graders.py @@ -101,8 +101,9 @@ def trap_unresolved(task: NegotiationTask, contract_text: str) -> bool: def effective_risk_high(task: NegotiationTask, contract_text: str) -> bool: - # Any task that defines explicit trap markers is still "effectively high - # risk" as long as any trap marker remains in the text. + # Any task that defines explicit trap markers (HARD, HARD_PLUS, HARD_PLUS2, + # EXPERT, MEDIUM_PLUS) is still "effectively high risk" as long as any trap + # marker remains in the text. if task.trap_markers and trap_unresolved(task, contract_text): return True diff --git a/inference.py b/inference.py index f38d85a..f1b288e 100644 --- a/inference.py +++ b/inference.py @@ -154,7 +154,7 @@ def _parse_llm_json(text: str) -> Optional[dict]: def _risk_score(task: NegotiationTask, contract_text: str) -> float: hits = keyword_match_score(contract_text, task.risk_keywords) rs = min(1.0, hits * task.clause_type_weight / 1.15) - if task.name in ("HARD", "HARD_PLUS", "HARD_PLUS2", "EXPERT") and trap_unresolved(task, contract_text): + if task.trap_markers and trap_unresolved(task, contract_text): rs = min(1.0, rs + 0.25) return round(rs, 6) From 8a12816e5bc69cf21193651c97869f37b93ca7cd Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 11 Apr 2026 08:05:36 +0000 Subject: [PATCH 05/13] Add concession tracking, --retry-low flag, smarter ACCEPT gate, silence pydantic warnings - Track specific opponent concessions per topic (cap, liability, IP, etc.) - Feed concession summary to LLM for richer negotiation context - Add --retry-low THRESHOLD flag to re-run low-scoring tasks - Add smart ACCEPT gate: block acceptance when contract hasn't improved - Improve MODERATE strategy to front-load PROPOSE_COUNTER - Sort per-task summary by score (worst first) with best/mean stats - Silence FastAPI/pydantic internal deprecation warnings in pytest config - Add 2 new tests for concession tracking (51 total, all passing) Agent-Logs-Url: https://github.com/bigturtle679/Contract-Negotiation-Environment/sessions/8526d5cb-221b-4b88-9ca5-4c8679367d2f Co-authored-by: AbeerChaturvedi <171315954+AbeerChaturvedi@users.noreply.github.com> --- contract_env/tests/test_smoke.py | 26 ++++++ inference.py | 131 +++++++++++++++++++++++++++++-- pyproject.toml | 9 ++- 3 files changed, 158 insertions(+), 8 deletions(-) diff --git a/contract_env/tests/test_smoke.py b/contract_env/tests/test_smoke.py index 0153378..5a87176 100644 --- a/contract_env/tests/test_smoke.py +++ b/contract_env/tests/test_smoke.py @@ -134,6 +134,32 @@ def test_opponent_stance_parsing(self) -> None: # Empty self.assertEqual(_parse_opponent_stance([]), "neutral") + def test_concession_tracking(self) -> None: + """Track which specific topics the opponent has conceded on.""" + from inference import _track_concessions + + history = [ + "opponent|[Counterparty] We can accept a cap on liability.", + "agent|step=1 action=FLAG_RISK content_len=10", + "opponent|[Counterparty] Termination flexibility is non-negotiable.", + ] + concessions = _track_concessions(history) + # "cap" should be conceded, "termination" should be firm + self.assertEqual(concessions.get("cap"), "conceded") + self.assertEqual(concessions.get("termination"), "firm") + + def test_concession_summary_format(self) -> None: + """Concession summary should produce a readable string.""" + from inference import _concession_summary + + concessions = {"cap": "conceded", "termination": "firm"} + summary = _concession_summary(concessions) + self.assertIn("WILLING", summary) + self.assertIn("HOLDING FIRM", summary) + + # Empty case + self.assertEqual(_concession_summary({}), "") + if __name__ == "__main__": unittest.main() diff --git a/inference.py b/inference.py index f1b288e..37c3e28 100644 --- a/inference.py +++ b/inference.py @@ -268,9 +268,11 @@ def _build_rewrite_prompt( _VALID_ACTIONS = {"FLAG_RISK", "EDIT_CLAUSE", "ACCEPT", "REJECT", "PROPOSE_COUNTER"} -# Optimal action sequences per intent level for rule-based fallback +# Optimal action sequences per intent level for rule-based fallback. +# MODERATE tasks front-load PROPOSE_COUNTER since it's the ideal action for +# balanced negotiation (the environment appends [COUNTERPROPOSAL] tags). _STRATEGY_HIGH = ["FLAG_RISK", "EDIT_CLAUSE", "EDIT_CLAUSE", "PROPOSE_COUNTER", "REJECT", "EDIT_CLAUSE", "ACCEPT"] -_STRATEGY_MODERATE = ["FLAG_RISK", "PROPOSE_COUNTER", "EDIT_CLAUSE", "EDIT_CLAUSE", "ACCEPT"] +_STRATEGY_MODERATE = ["FLAG_RISK", "PROPOSE_COUNTER", "PROPOSE_COUNTER", "EDIT_CLAUSE", "EDIT_CLAUSE", "ACCEPT"] _STRATEGY_LOW = ["EDIT_CLAUSE", "EDIT_CLAUSE", "ACCEPT"] # ── OPPONENT-RESPONSE PARSING ─────────────────────────────────────────── @@ -286,6 +288,19 @@ def _build_rewrite_prompt( "standard and non-negotiable", "cannot accept", "is not included", ) +# Topic keywords to detect what specifically the opponent is conceding or +# holding firm on. Used for fine-grained concession tracking. +_TOPIC_KEYWORDS = { + "cap": ("cap", "capped", "limitation", "limit"), + "notice_period": ("notice", "days", "notice period"), + "ip_ownership": ("ownership", "ip", "intellectual property", "customer owns"), + "termination": ("termination", "terminate", "mutual", "cure"), + "liability": ("liability", "indemnify", "consequential", "punitive"), + "confidentiality": ("confidentiality", "nda", "perpetuity", "time limit"), + "data_protection": ("dpa", "data", "breach notification", "sub-processor", "gdpr"), + "change_control": ("change", "scope", "approval", "timeline"), +} + def _parse_opponent_stance(history: list[str]) -> str: """Analyse the latest opponent reply to determine their negotiation stance. @@ -309,6 +324,50 @@ def _parse_opponent_stance(history: list[str]) -> str: return "neutral" +def _track_concessions(history: list[str]) -> dict[str, str]: + """Track which negotiation topics the opponent has conceded on vs. held firm. + + Returns a dict like {"cap": "conceded", "liability": "firm", "notice_period": "unknown"}. + This enables the agent to focus edits on unresolved issues. + """ + concessions: dict[str, str] = {} + + for entry in history: + if not entry.startswith("opponent|"): + continue + low = entry.lower() + + is_conceding = any(s in low for s in _CONCESSION_SIGNALS) + is_firm = any(s in low for s in _FIRMNESS_SIGNALS) + + for topic, keywords in _TOPIC_KEYWORDS.items(): + if any(kw in low for kw in keywords): + if is_conceding: + concessions[topic] = "conceded" + elif is_firm: + concessions[topic] = "firm" + elif topic not in concessions: + concessions[topic] = "discussed" + + return concessions + + +def _concession_summary(concessions: dict[str, str]) -> str: + """Build a human-readable summary of opponent concessions for the LLM.""" + if not concessions: + return "" + parts: list[str] = [] + for topic, status in concessions.items(): + label = topic.replace("_", " ") + if status == "conceded": + parts.append(f" - {label}: opponent is WILLING to negotiate") + elif status == "firm": + parts.append(f" - {label}: opponent is HOLDING FIRM") + else: + parts.append(f" - {label}: discussed (no clear position)") + return "Opponent concession tracker:\n" + "\n".join(parts) + + def _choose( task: NegotiationTask, state_data: dict, @@ -320,8 +379,14 @@ def _choose( history = state_data.get("negotiation_history", []) history_summary = "\n".join(history[-8:]) if history else "" - # ── 0. Parse opponent stance from negotiation history ──────────────── + # ── 0. Parse opponent stance and track concessions ─────────────────── opponent_stance = _parse_opponent_stance(history) + concessions = _track_concessions(history) + conc_summary = _concession_summary(concessions) + + # Enrich history summary with concession tracking for the LLM + if conc_summary: + history_summary = history_summary + "\n\n" + conc_summary # ── 1. Ask the LLM for structured analysis ────────────────────────── parsed: Optional[dict] = None @@ -374,7 +439,24 @@ def _choose( if effective_risk_high(task, contract_text) or trap_unresolved(task, contract_text): action_type = "PROPOSE_COUNTER" - # ── 4c. Adaptive: if scores are improving and risk resolved, accept ── + # ── 4c. Concession-aware: if opponent conceded on key issues, lean EDIT ─ + conceded_topics = [t for t, s in concessions.items() if s == "conceded"] + if conceded_topics and step > 1: + # Opponent has given ground — capitalise with a concrete edit + if action_type == "FLAG_RISK": + action_type = "EDIT_CLAUSE" + + # ── 4d. Smart ACCEPT gate: only accept when quality actually improved ─ + if action_type == "ACCEPT": + from contract_env.env.graders import observation_risk_float + current_risk = observation_risk_float(task, contract_text) + original_risk = observation_risk_float(task, task.contract_text) + # Block acceptance if the contract hasn't improved meaningfully + if current_risk >= original_risk - 0.05: + if effective_risk_high(task, contract_text) or trap_unresolved(task, contract_text): + action_type = "EDIT_CLAUSE" + + # ── 4e. Adaptive: if scores are improving and risk resolved, accept ── if ( len(prev_rewards) >= 3 and all(r > 0.45 for r in prev_rewards[-2:]) @@ -558,6 +640,16 @@ def main() -> None: "'api' connects to the Docker server via HTTP at ENV_SERVER_URL." ), ) + parser.add_argument( + "--retry-low", + type=float, + default=0.0, + metavar="THRESHOLD", + help=( + "Re-run tasks that scored below THRESHOLD (e.g. --retry-low 0.4). " + "Each low-scoring task is retried once. 0 = disabled (default)." + ), + ) args = parser.parse_args() # Select environment backend @@ -578,13 +670,38 @@ def main() -> None: total_score += ep_score task_scores.setdefault(task_id, []).append(ep_score) + # ── Retry low-scoring tasks ────────────────────────────────────────── + retry_threshold = args.retry_low + if retry_threshold > 0: + low_tasks = { + tid: scores + for tid, scores in task_scores.items() + if (sum(scores) / len(scores)) < retry_threshold + } + if low_tasks: + print( + f"\n[RETRY] {len(low_tasks)} task(s) scored below {retry_threshold:.2f}, retrying...", + flush=True, + ) + for tid in low_tasks: + ep_score, _ = run_episode(env) + total_score += ep_score + task_scores[tid].append(ep_score) + episodes_to_run += 1 + mean_score = total_score / max(episodes_to_run, 1) - # Per-task summary + # Per-task summary (sorted by score, worst first) print("\n[TASK SCORES]", flush=True) - for tid, scores in task_scores.items(): + sorted_tasks = sorted(task_scores.items(), key=lambda x: sum(x[1]) / len(x[1])) + for tid, scores in sorted_tasks: avg = sum(scores) / len(scores) - print(f" {tid}: mean={avg:.3f} runs={len(scores)}", flush=True) + best = max(scores) + status = "✓" if avg >= SUCCESS_SCORE_THRESHOLD else "✗" + print( + f" {status} {tid}: mean={avg:.3f} best={best:.3f} runs={len(scores)}", + flush=True, + ) print( f"\n[SUMMARY] episodes={episodes_to_run} mean_score={mean_score:.3f} " diff --git a/pyproject.toml b/pyproject.toml index 2275543..cc655a5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,4 +26,11 @@ server = "contract_env.server.app:main" [tool.setuptools.packages.find] where = ["."] -include = ["contract_env*"] \ No newline at end of file +include = ["contract_env*"] + +[tool.pytest.ini_options] +filterwarnings = [ + # FastAPI internally passes 'deprecated' as an extra keyword to pydantic Field(). + # This is a known issue between FastAPI and Pydantic v2 — not our code. + "ignore::pydantic.warnings.PydanticDeprecatedSince20", +] \ No newline at end of file From a8290fe92bec0dfb113257d9c9e0385715dba31d Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 11 Apr 2026 08:07:41 +0000 Subject: [PATCH 06/13] Fix stale README test count (51), extract magic numbers to named constants Agent-Logs-Url: https://github.com/bigturtle679/Contract-Negotiation-Environment/sessions/8526d5cb-221b-4b88-9ca5-4c8679367d2f Co-authored-by: AbeerChaturvedi <171315954+AbeerChaturvedi@users.noreply.github.com> --- README.md | 2 +- inference.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index bc7a311..8a32b46 100644 --- a/README.md +++ b/README.md @@ -209,7 +209,7 @@ An episode is considered successful if `score ≥ 0.50`. ```bash pip install -e ".[dev]" -python -m pytest contract_env/tests/ -v # 42 tests +python -m pytest contract_env/tests/ -v # 51 tests ``` ### Run the server diff --git a/inference.py b/inference.py index 37c3e28..4dbb0c8 100644 --- a/inference.py +++ b/inference.py @@ -50,6 +50,8 @@ ENV_SERVER_URL = os.getenv("ENV_SERVER_URL", "http://localhost:7860") MAX_STEPS = 10 SUCCESS_SCORE_THRESHOLD = 0.5 +HISTORY_WINDOW = 8 # How many recent history entries to show the LLM +ACCEPT_IMPROVEMENT_THRESHOLD = 0.05 # Minimum risk reduction before acceptance allowed # ── LLM CLIENT (lazy singleton) ───────────────────────────────────────── _client: Optional[OpenAI] = None @@ -377,7 +379,7 @@ def _choose( """LLM-driven action selection with rule-based fallback and opponent awareness.""" contract_text = state_data["contract_text"] history = state_data.get("negotiation_history", []) - history_summary = "\n".join(history[-8:]) if history else "" + history_summary = "\n".join(history[-HISTORY_WINDOW:]) if history else "" # ── 0. Parse opponent stance and track concessions ─────────────────── opponent_stance = _parse_opponent_stance(history) @@ -452,7 +454,7 @@ def _choose( current_risk = observation_risk_float(task, contract_text) original_risk = observation_risk_float(task, task.contract_text) # Block acceptance if the contract hasn't improved meaningfully - if current_risk >= original_risk - 0.05: + if current_risk >= original_risk - ACCEPT_IMPROVEMENT_THRESHOLD: if effective_risk_high(task, contract_text) or trap_unresolved(task, contract_text): action_type = "EDIT_CLAUSE" From d0e62b275d95d5bad65c7916c5f0aac9b69d0a4e Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 11 Apr 2026 08:17:16 +0000 Subject: [PATCH 07/13] Harden validation, fix 3 failing tests, add Literal types, edge-case tests, input length guards - Make python-dotenv import optional so tests pass without LLM deps - Add Literal types for risk_level (HIGH/MODERATE/LOW) and clause_type - Add Pydantic validator for opponent_responses keys (must be valid ActionType) - Add content length validation in ContractEnv (max_content_length=50_000) - Add /evaluate-quality input length guard (100_000 chars max) - Consolidate NUM_GRADED_TASKS to single definition in graders.py - Add 8 new edge-case tests: content length, max steps, done state, unicode, empty risk_keywords, opponent_response key validation, API max length - Update README: --retry-low docs, CORS_ORIGINS env var, test count (59) - All 59 tests pass Agent-Logs-Url: https://github.com/bigturtle679/Contract-Negotiation-Environment/sessions/028275d3-079a-4945-8a77-3e3dcdf5d12a Co-authored-by: AbeerChaturvedi <171315954+AbeerChaturvedi@users.noreply.github.com> --- README.md | 6 +++- contract_env/env/__init__.py | 1 - contract_env/env/environment.py | 5 +++ contract_env/env/tasks.py | 46 ++++++++++++++++++++++----- contract_env/server/app.py | 8 +++++ contract_env/tests/test_graders.py | 50 ++++++++++++++++++++++++++++++ contract_env/tests/test_smoke.py | 48 ++++++++++++++++++++++++++++ inference.py | 7 ++++- 8 files changed, 160 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 8a32b46..5d8d626 100644 --- a/README.md +++ b/README.md @@ -209,7 +209,7 @@ An episode is considered successful if `score ≥ 0.50`. ```bash pip install -e ".[dev]" -python -m pytest contract_env/tests/ -v # 51 tests +python -m pytest contract_env/tests/ -v # 59 tests ``` ### Run the server @@ -225,6 +225,9 @@ export HF_TOKEN="your-huggingface-token" python inference.py --benchmark # one episode per task (8 total) python inference.py --episodes 3 # run 3 episodes cycling through tasks +# Retry any task that scores below 0.4: +python inference.py --benchmark --retry-low 0.4 + # Against the Docker API server (for competition evaluation): python inference.py --benchmark --mode api ``` @@ -254,6 +257,7 @@ python inference.py --benchmark --mode api | `BENCHMARK` | No | `contract_negotiation` | Benchmark name in [START] log line | | `ENV_SERVER_URL` | No | `http://localhost:7860` | Docker server URL (for `--mode api`) | | `PORT` | No | `7860` | Server port | +| `CORS_ORIGINS` | No | `*` | Comma-separated allowed CORS origins | --- diff --git a/contract_env/env/__init__.py b/contract_env/env/__init__.py index d47b22f..3d77f9c 100644 --- a/contract_env/env/__init__.py +++ b/contract_env/env/__init__.py @@ -25,7 +25,6 @@ validate_all_tasks_have_graders, GRADED_TASK_IDS, GRADED_TASK_NAMES, - NUM_GRADED_TASKS as TASKS_NUM_GRADED, ) __all__ = [ diff --git a/contract_env/env/environment.py b/contract_env/env/environment.py index 0bd5079..3d37978 100644 --- a/contract_env/env/environment.py +++ b/contract_env/env/environment.py @@ -16,6 +16,7 @@ class ContractEnv: max_steps: int = 7 + max_content_length: int = 50_000 # guard against oversized action content def __init__(self) -> None: self._reset_count: int = 0 @@ -92,6 +93,10 @@ def _validate_action(self, action: Action) -> Optional[str]: if action.action_type in ("EDIT_CLAUSE", "PROPOSE_COUNTER"): if not c: return "EDIT_CLAUSE and PROPOSE_COUNTER require non-empty content" + if len(c) > self.max_content_length: + return ( + f"content exceeds maximum length of {self.max_content_length} characters" + ) return None diff --git a/contract_env/env/tasks.py b/contract_env/env/tasks.py index c6d3356..77b73ac 100644 --- a/contract_env/env/tasks.py +++ b/contract_env/env/tasks.py @@ -1,23 +1,42 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, List +from typing import TYPE_CHECKING, Any, Callable, List, Literal -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator if TYPE_CHECKING: from contract_env.env.models import Reward from contract_env.env.models import Action +# Allowed values – kept in sync with openenv.yaml / models.py +RiskLevel = Literal["HIGH", "MODERATE", "LOW"] +ClauseType = Literal[ + "liability", + "term_renewal", + "performance_changes", + "compliance", + "intellectual_property", + "confidentiality", + "termination", + "data_protection", +] + +_VALID_ACTION_TYPES = frozenset( + {"FLAG_RISK", "EDIT_CLAUSE", "ACCEPT", "REJECT", "PROPOSE_COUNTER"} +) + class NegotiationTask(BaseModel): + """A single contract-negotiation task with clause text, metadata, and grading info.""" + id: str name: str contract_text: str - clause_type: str - risk_keywords: List[str] # ✅ FIXED (was tuple before) + clause_type: ClauseType + risk_keywords: List[str] safe_keywords: List[str] expected_safe_edit: str - risk_level: str + risk_level: RiskLevel hidden_trap: str trap_markers: List[str] = Field( default_factory=list, @@ -43,6 +62,18 @@ class NegotiationTask(BaseModel): grader: str = Field(default="") grader_name: str + @field_validator("opponent_responses") + @classmethod + def _validate_opponent_response_keys( + cls, v: dict[str, List[str]] + ) -> dict[str, List[str]]: + bad = set(v.keys()) - _VALID_ACTION_TYPES + if bad: + raise ValueError( + f"opponent_responses contains invalid action types: {bad}" + ) + return v + def get_grader(self) -> Callable[["NegotiationTask", str, "Action", str], "Reward"]: """Return the grader function assigned to this task.""" return self.grader_func @@ -596,7 +627,6 @@ def validate_all_tasks_have_graders() -> bool: return False return True -# Metadata - just count tasks with grader field set +# Metadata — derived from TASKS (single source of truth for task-side counts) GRADED_TASK_IDS = [task.id for task in TASKS if task.has_grader()] -GRADED_TASK_NAMES = [task.name for task in TASKS if task.has_grader()] -NUM_GRADED_TASKS = len(GRADED_TASK_IDS) \ No newline at end of file +GRADED_TASK_NAMES = [task.name for task in TASKS if task.has_grader()] \ No newline at end of file diff --git a/contract_env/server/app.py b/contract_env/server/app.py index c4a206e..c2d442e 100644 --- a/contract_env/server/app.py +++ b/contract_env/server/app.py @@ -138,6 +138,9 @@ def get_schema(): } +_MAX_EVALUATE_TEXT_LEN = 100_000 + + # ── EVALUATE QUALITY ───────────────────────────────────────────────────── @app.post("/evaluate-quality") def evaluate_quality(body: dict): @@ -152,6 +155,11 @@ def evaluate_quality(body: dict): contract_text = body.get("contract_text", "") if not contract_text: raise HTTPException(status_code=422, detail="contract_text must be non-empty.") + if len(contract_text) > _MAX_EVALUATE_TEXT_LEN: + raise HTTPException( + status_code=422, + detail=f"contract_text exceeds maximum length of {_MAX_EVALUATE_TEXT_LEN}.", + ) quality = contract_quality_score(_env.current_task, contract_text) return {"quality_score": round(quality, 4), "risk_score": round(1.0 - quality, 4)} diff --git a/contract_env/tests/test_graders.py b/contract_env/tests/test_graders.py index 66c7507..9398bc9 100644 --- a/contract_env/tests/test_graders.py +++ b/contract_env/tests/test_graders.py @@ -248,5 +248,55 @@ def test_accept_blocked_on_hard_plus2_unresolved(self) -> None: self.assertTrue(info.get("accept_blocked")) + def test_empty_risk_keywords_handled(self) -> None: + """Tasks with empty risk_keywords should not crash scoring.""" + from contract_env.env.graders import keyword_match_score + score = keyword_match_score("any text here", []) + self.assertEqual(score, 0.0) + + def test_unicode_in_contract_text(self) -> None: + """Non-ASCII contract text should be scored without errors.""" + task = TASKS[0] + action = Action( + action_type="EDIT_CLAUSE", + content="Haftungsbeschränkung: Begrenzung auf gezahlte Gebühren der letzten 12 Monate.", + ) + r = grade_action(task, task.contract_text, action, action.content) + self.assertGreater(r.score, 0.0) + self.assertLess(r.score, 1.0) + + def test_opponent_response_key_validation(self) -> None: + """Invalid action type keys in opponent_responses should be rejected.""" + from contract_env.env.tasks import NegotiationTask + from pydantic import ValidationError + with self.assertRaises(ValidationError): + NegotiationTask( + id="test", + name="TEST", + contract_text="test", + clause_type="liability", + risk_keywords=["test"], + safe_keywords=["test"], + expected_safe_edit="test", + risk_level="HIGH", + hidden_trap="", + opponent_responses={"INVALID_ACTION": ["reply"]}, + grader_func=grade_easy, + grader_name="grade_easy", + ) + + def test_evaluate_quality_endpoint_max_length(self) -> None: + """API should reject excessively long contract_text.""" + from fastapi.testclient import TestClient + from contract_env.server.app import app, _env + client = TestClient(app) + _env.reset() + r = client.post( + "/evaluate-quality", + json={"contract_text": "x" * 100_001}, + ) + self.assertEqual(r.status_code, 422) + + if __name__ == "__main__": unittest.main() diff --git a/contract_env/tests/test_smoke.py b/contract_env/tests/test_smoke.py index 5a87176..2806514 100644 --- a/contract_env/tests/test_smoke.py +++ b/contract_env/tests/test_smoke.py @@ -161,5 +161,53 @@ def test_concession_summary_format(self) -> None: self.assertEqual(_concession_summary({}), "") + def test_content_length_validation(self) -> None: + """Oversized content should be rejected gracefully.""" + env = ContractEnv() + env.reset() + huge_content = "x" * (env.max_content_length + 1) + obs, r, done, info = env.step( + Action(action_type="EDIT_CLAUSE", content=huge_content) + ) + self.assertEqual(r, 0.001) + self.assertIn("error", info) + self.assertIn("exceeds maximum length", info["error"]) + + def test_episode_runs_to_max_steps(self) -> None: + """Episode should terminate at max_steps if agent never accepts.""" + env = ContractEnv() + env.reset() + for i in range(env.max_steps): + obs, r, done, info = env.step( + Action(action_type="FLAG_RISK", content="risk") + ) + if i < env.max_steps - 1: + self.assertFalse(done) + self.assertTrue(done) + self.assertEqual(info.get("termination_reason"), "max_steps_reached") + + def test_step_after_done_returns_error(self) -> None: + """Stepping after episode is done should return an error.""" + env = ContractEnv() + env.reset() + env.step(Action(action_type="ACCEPT")) + obs, r, done, info = env.step(Action(action_type="FLAG_RISK", content="test")) + self.assertTrue(done) + self.assertEqual(info.get("error"), "already_done") + + def test_unicode_content_handled(self) -> None: + """Non-ASCII characters should not crash the environment.""" + env = ContractEnv() + env.reset() + obs, r, done, info = env.step( + Action( + action_type="EDIT_CLAUSE", + content="Les parties s'engagent à limiter la responsabilité — §12 Haftung" + ) + ) + self.assertGreater(r, 0.0) + self.assertLess(r, 1.0) + + if __name__ == "__main__": unittest.main() diff --git a/inference.py b/inference.py index 4dbb0c8..18d7e5c 100644 --- a/inference.py +++ b/inference.py @@ -28,7 +28,12 @@ import re from typing import Any, Optional -from dotenv import load_dotenv +try: + from dotenv import load_dotenv +except ImportError: # pragma: no cover – optional for test environments + def load_dotenv(*_a: Any, **_kw: Any) -> None: # type: ignore[misc] + pass + from openai import OpenAI from contract_env.env.environment import ContractEnv From 676717a18df628140b9b78e64b9761511823cf50 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 11 Apr 2026 08:18:39 +0000 Subject: [PATCH 08/13] Consolidate _VALID_ACTION_TYPES to derive from canonical ActionType Literal in models.py Agent-Logs-Url: https://github.com/bigturtle679/Contract-Negotiation-Environment/sessions/028275d3-079a-4945-8a77-3e3dcdf5d12a Co-authored-by: AbeerChaturvedi <171315954+AbeerChaturvedi@users.noreply.github.com> --- contract_env/env/tasks.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/contract_env/env/tasks.py b/contract_env/env/tasks.py index 77b73ac..bedc25e 100644 --- a/contract_env/env/tasks.py +++ b/contract_env/env/tasks.py @@ -1,9 +1,11 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, List, Literal +from typing import TYPE_CHECKING, Any, Callable, List, Literal, get_args from pydantic import BaseModel, Field, field_validator +from contract_env.env.models import ActionType # single source of truth + if TYPE_CHECKING: from contract_env.env.models import Reward from contract_env.env.models import Action @@ -21,9 +23,8 @@ "data_protection", ] -_VALID_ACTION_TYPES = frozenset( - {"FLAG_RISK", "EDIT_CLAUSE", "ACCEPT", "REJECT", "PROPOSE_COUNTER"} -) +# Derive valid action type strings from the canonical Literal in models.py +_VALID_ACTION_TYPES: frozenset[str] = frozenset(get_args(ActionType)) class NegotiationTask(BaseModel): From 085ff9fa30f576d7f73a37219fd52e0ff12b7645 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 11 Apr 2026 09:30:54 +0000 Subject: [PATCH 09/13] Fix critical accept-blocking bug, retry logic, stack trace leak, session leak, pin deps - Fix negation-aware keyword matching: risk keywords in negation context (e.g. "no party is liable for consequential damages") no longer falsely trigger effective_risk_high(), which was blocking ACCEPT on the easy task even after the agent submitted the correct safe edit - Fix retry logic: --retry-low now passes task_id to reset() so it actually retries the same low-scoring task instead of cycling to the next one - Add reset(task_id=) parameter to ContractEnv for targeted task selection - Remove stack trace exposure in production error handler (app.py) - Add close()/context manager to _HTTPEnvClient for proper session cleanup - Pin requirements.txt versions to match pyproject.toml - Add 7 new tests: negation matching, safe edit acceptance, full edit+accept flow, reset with task_id, original contracts flagged correctly - 66 tests pass Agent-Logs-Url: https://github.com/bigturtle679/Contract-Negotiation-Environment/sessions/5b50d4d5-393e-4df1-ad7e-50ef984ec409 Co-authored-by: AbeerChaturvedi <171315954+AbeerChaturvedi@users.noreply.github.com> --- README.md | 2 +- contract_env/env/environment.py | 13 +++++-- contract_env/env/graders.py | 33 +++++++++++++++- contract_env/server/app.py | 2 - contract_env/tests/test_graders.py | 61 ++++++++++++++++++++++++++++++ contract_env/tests/test_smoke.py | 30 +++++++++++++++ inference.py | 22 +++++++++-- requirements.txt | 9 +++-- 8 files changed, 158 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index 5d8d626..57fa72a 100644 --- a/README.md +++ b/README.md @@ -209,7 +209,7 @@ An episode is considered successful if `score ≥ 0.50`. ```bash pip install -e ".[dev]" -python -m pytest contract_env/tests/ -v # 59 tests +python -m pytest contract_env/tests/ -v # 66 tests ``` ### Run the server diff --git a/contract_env/env/environment.py b/contract_env/env/environment.py index 3d37978..0298785 100644 --- a/contract_env/env/environment.py +++ b/contract_env/env/environment.py @@ -49,14 +49,21 @@ def _opponent_reply(self, action_type: str) -> Optional[str]: return None return self._rng.choice(responses) - def reset(self) -> Observation: + def reset(self, task_id: Optional[str] = None) -> Observation: self.done = False self.current_step = 0 - idx = self._reset_count % len(TASKS) + if task_id is not None: + # Reset to a specific task (used by retry logic) + match = next((t for t in TASKS if t.id == task_id), None) + if match is None: + raise ValueError(f"Unknown task_id: {task_id!r}") + self.current_task = match + else: + idx = self._reset_count % len(TASKS) + self.current_task = TASKS[idx] self._reset_count += 1 - self.current_task = TASKS[idx] assert self.current_task is not None t = self.current_task diff --git a/contract_env/env/graders.py b/contract_env/env/graders.py index 31a4f64..b499e49 100644 --- a/contract_env/env/graders.py +++ b/contract_env/env/graders.py @@ -44,6 +44,36 @@ def _cosine_similarity(a: str, b: str) -> float: return dot / (mag_a * mag_b) +# Negation prefixes that invert the meaning of a risk keyword. +# E.g. "no consequential damages" is safe, not risky. +_NEGATION_PREFIXES = ( + "no ", "not ", "neither ", "without any ", "exclud", "except for ", + "not liable for ", "no party is liable for ", "shall not include ", + "does not cover ", "not responsible for ", +) + + +def _is_negated(text_lower: str, keyword_lower: str) -> bool: + """Return True if *every* occurrence of keyword_lower in text_lower is preceded + by a negation phrase, meaning the keyword appears only in a 'safe' context.""" + idx = 0 + all_negated = True + found_any = False + while True: + pos = text_lower.find(keyword_lower, idx) + if pos == -1: + break + found_any = True + # Check the 60-character window before the match for negation cues + window_start = max(0, pos - 60) + preceding = text_lower[window_start:pos] + if not any(neg in preceding for neg in _NEGATION_PREFIXES): + all_negated = False + break + idx = pos + len(keyword_lower) + return found_any and all_negated + + def _weighted_risk_hits(text: str, risk_keywords: list[str]) -> float: low = text.lower() if not risk_keywords: @@ -51,7 +81,8 @@ def _weighted_risk_hits(text: str, risk_keywords: list[str]) -> float: hits = 0 for phrase in risk_keywords: - if phrase.lower() in low: + kw = phrase.lower() + if kw in low and not _is_negated(low, kw): hits += 1 return min(1.0, hits / max(len(risk_keywords), 1)) diff --git a/contract_env/server/app.py b/contract_env/server/app.py index c2d442e..1f1df99 100644 --- a/contract_env/server/app.py +++ b/contract_env/server/app.py @@ -1,7 +1,6 @@ from __future__ import annotations import os -import traceback from typing import Any from fastapi import FastAPI, HTTPException, Request @@ -61,7 +60,6 @@ async def global_exception_handler(request: Request, exc: Exception): status_code=500, content={ "detail": str(exc), - "trace": traceback.format_exc(), }, ) diff --git a/contract_env/tests/test_graders.py b/contract_env/tests/test_graders.py index 9398bc9..5eeaeea 100644 --- a/contract_env/tests/test_graders.py +++ b/contract_env/tests/test_graders.py @@ -297,6 +297,67 @@ def test_evaluate_quality_endpoint_max_length(self) -> None: ) self.assertEqual(r.status_code, 422) + def test_accept_not_blocked_after_safe_edit(self) -> None: + """ACCEPT should NOT be blocked when the contract has been rewritten to the safe edit.""" + for task in TASKS: + r, info = evaluate_action( + task, task.expected_safe_edit, + Action(action_type="ACCEPT"), task.expected_safe_edit, + ) + self.assertFalse( + info.get("accept_blocked", False), + f"Task {task.id}: ACCEPT blocked on expected_safe_edit text", + ) + self.assertGreater( + r.score, 0.01, + f"Task {task.id}: ACCEPT reward too low after safe edit", + ) + + def test_negation_aware_keyword_matching(self) -> None: + """Risk keywords in negation context should not count as risk hits.""" + from contract_env.env.graders import _weighted_risk_hits, _is_negated + # "no consequential damages" — negated + self.assertTrue(_is_negated( + "no party is liable for consequential damages", "consequential" + )) + # "consequential damages apply" — NOT negated + self.assertFalse(_is_negated( + "consequential damages apply to all claims", "consequential" + )) + # Risk hits should be 0 when negated + self.assertEqual( + _weighted_risk_hits( + "no party is liable for consequential or punitive damages", + ["consequential", "punitive"], + ), + 0.0, + ) + # Risk hits should be >0 when NOT negated + self.assertGreater( + _weighted_risk_hits( + "vendor has consequential and punitive liability", + ["consequential", "punitive"], + ), + 0.0, + ) + + def test_safe_edits_not_flagged_as_high_risk(self) -> None: + """All expected_safe_edits should NOT be classified as effectively high risk.""" + for task in TASKS: + self.assertFalse( + effective_risk_high(task, task.expected_safe_edit), + f"Task {task.id}: expected_safe_edit incorrectly flagged as high risk", + ) + + def test_original_contracts_flagged_as_high_risk(self) -> None: + """All original contract texts with HIGH risk_level should be effectively high risk.""" + for task in TASKS: + if task.risk_level == "HIGH": + self.assertTrue( + effective_risk_high(task, task.contract_text), + f"Task {task.id}: original contract not flagged as high risk", + ) + if __name__ == "__main__": unittest.main() diff --git a/contract_env/tests/test_smoke.py b/contract_env/tests/test_smoke.py index 2806514..7ac6464 100644 --- a/contract_env/tests/test_smoke.py +++ b/contract_env/tests/test_smoke.py @@ -208,6 +208,36 @@ def test_unicode_content_handled(self) -> None: self.assertGreater(r, 0.0) self.assertLess(r, 1.0) + def test_reset_with_task_id(self) -> None: + """reset(task_id=...) should force a specific task.""" + env = ContractEnv() + obs = env.reset(task_id="expert_data_protection") + self.assertEqual(env.current_task.id, "expert_data_protection") + self.assertIn("data", obs.clause_type) + + def test_reset_with_invalid_task_id(self) -> None: + """reset(task_id=...) with an unknown ID should raise ValueError.""" + env = ContractEnv() + with self.assertRaises(ValueError): + env.reset(task_id="nonexistent_task") + + def test_edit_then_accept_full_flow(self) -> None: + """Full flow: EDIT_CLAUSE with safe edit, then ACCEPT should not be blocked.""" + env = ContractEnv() + env.reset() + task = env.current_task + # Step 1: submit the safe edit + obs, r1, done1, info1 = env.step( + Action(action_type="EDIT_CLAUSE", content=task.expected_safe_edit) + ) + self.assertFalse(done1) + self.assertGreater(r1, 0.1) + # Step 2: accept the edited contract + obs, r2, done2, info2 = env.step(Action(action_type="ACCEPT")) + self.assertTrue(done2) + self.assertFalse(info2.get("accept_blocked", False)) + self.assertGreater(r2, 0.01) + if __name__ == "__main__": unittest.main() diff --git a/inference.py b/inference.py index 18d7e5c..abf5688 100644 --- a/inference.py +++ b/inference.py @@ -506,6 +506,15 @@ def __init__(self, base_url: str) -> None: self._task_idx = 0 self.current_task: Optional[NegotiationTask] = None + def close(self) -> None: + self._session.close() + + def __enter__(self): + return self + + def __exit__(self, *exc: Any) -> None: + self.close() + def reset(self): resp = self._session.post(f"{self.base_url}/reset") resp.raise_for_status() @@ -551,12 +560,19 @@ def __init__(self, d: dict) -> None: # ── EPISODE EXECUTION ──────────────────────────────────────────────────── -def run_episode(env) -> tuple[float, str]: +def run_episode(env, task_id: Optional[str] = None) -> tuple[float, str]: """Run one full episode using env.reset() → loop env.step() → log. + Args: + env: Environment instance (ContractEnv or _HTTPEnvClient). + task_id: If given, reset to this specific task (local mode only). + Returns (mean_episode_score, task_id). """ - obs_obj = env.reset() + if task_id is not None and hasattr(env, "reset") and "task_id" in env.reset.__code__.co_varnames: + obs_obj = env.reset(task_id=task_id) + else: + obs_obj = env.reset() task = env.current_task state_data: dict[str, Any] = { @@ -691,7 +707,7 @@ def main() -> None: flush=True, ) for tid in low_tasks: - ep_score, _ = run_episode(env) + ep_score, _ = run_episode(env, task_id=tid) total_score += ep_score task_scores[tid].append(ep_score) episodes_to_run += 1 diff --git a/requirements.txt b/requirements.txt index a217b6d..0280db9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ pydantic==2.6.0 -python-dotenv -fastapi -uvicorn -openai +python-dotenv>=1.0.0 +fastapi>=0.115.0 +uvicorn[standard]>=0.32.0 +openai>=1.50.0 +openenv>=0.1.13 From 0d1222d7c9c63a65976d0e79cb7e951d693ddd1e Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 11 Apr 2026 09:32:09 +0000 Subject: [PATCH 10/13] Improve _is_negated docstring for clarity Agent-Logs-Url: https://github.com/bigturtle679/Contract-Negotiation-Environment/sessions/5b50d4d5-393e-4df1-ad7e-50ef984ec409 Co-authored-by: AbeerChaturvedi <171315954+AbeerChaturvedi@users.noreply.github.com> --- contract_env/env/graders.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/contract_env/env/graders.py b/contract_env/env/graders.py index b499e49..900e029 100644 --- a/contract_env/env/graders.py +++ b/contract_env/env/graders.py @@ -55,7 +55,10 @@ def _cosine_similarity(a: str, b: str) -> float: def _is_negated(text_lower: str, keyword_lower: str) -> bool: """Return True if *every* occurrence of keyword_lower in text_lower is preceded - by a negation phrase, meaning the keyword appears only in a 'safe' context.""" + by a negation phrase, meaning the keyword appears only in a 'safe' context. + + Returns False if the keyword is not found at all (no occurrence to negate). + """ idx = 0 all_negated = True found_any = False From 20be8130db06c251efe3bcb81035b098e156a081 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 11 Apr 2026 10:00:52 +0000 Subject: [PATCH 11/13] Replace assertions with proper validation, add Pydantic request model, docstrings, new API tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Replace 3 assertions in environment.py with RuntimeError guards that survive python -O - Replace assert in graders.py with explicit ValueError - Add Pydantic EvaluateQualityRequest model to app.py replacing raw dict - Remove dead _MAX_EVALUATE_TEXT_LEN constant and unused Any import - Add comprehensive docstrings to ContractEnv class and methods - Add docstring to evaluate_action() explaining 5-dimension rubric - Fix fragile __code__.co_varnames introspection in inference.py with try/except - Remove redundant TASKS import in environment.py tasks property - Normalize "belongs to Supplier" → "belongs to supplier" in tasks.py - Add 6 new API tests: schema, root, evaluate-quality missing/empty/success, invalid action - Update README test count from 66 → 72 Agent-Logs-Url: https://github.com/bigturtle679/Contract-Negotiation-Environment/sessions/5ac284f3-516c-4d77-ba64-14f8ee06778c Co-authored-by: AbeerChaturvedi <171315954+AbeerChaturvedi@users.noreply.github.com> --- README.md | 2 +- contract_env/env/environment.py | 51 ++++++++++++++++++++++++++++++--- contract_env/env/graders.py | 16 ++++++++++- contract_env/env/tasks.py | 2 +- contract_env/server/app.py | 23 ++++++--------- contract_env/tests/test_api.py | 44 ++++++++++++++++++++++++++++ inference.py | 8 ++++-- 7 files changed, 123 insertions(+), 23 deletions(-) diff --git a/README.md b/README.md index 57fa72a..f0ca664 100644 --- a/README.md +++ b/README.md @@ -209,7 +209,7 @@ An episode is considered successful if `score ≥ 0.50`. ```bash pip install -e ".[dev]" -python -m pytest contract_env/tests/ -v # 66 tests +python -m pytest contract_env/tests/ -v # 72 tests ``` ### Run the server diff --git a/contract_env/env/environment.py b/contract_env/env/environment.py index 0298785..4dfd20f 100644 --- a/contract_env/env/environment.py +++ b/contract_env/env/environment.py @@ -15,6 +15,20 @@ class ContractEnv: + """Multi-turn contract-negotiation environment. + + Cycles through a list of :class:`NegotiationTask` objects, presenting + agents with contract clauses to analyse and improve. Each episode + consists of up to ``max_steps`` actions, and the agent receives a + reward after every ``step()``. + + Usage:: + + env = ContractEnv() + obs = env.reset() + obs, reward, done, info = env.step(Action(action_type="FLAG_RISK")) + """ + max_steps: int = 7 max_content_length: int = 50_000 # guard against oversized action content @@ -28,7 +42,7 @@ def __init__(self) -> None: @property def tasks(self) -> list[str]: - from contract_env.env.tasks import TASKS + """Return IDs of all registered negotiation tasks.""" return [task.id for task in TASKS] @property @@ -50,6 +64,19 @@ def _opponent_reply(self, action_type: str) -> Optional[str]: return self._rng.choice(responses) def reset(self, task_id: Optional[str] = None) -> Observation: + """Begin a new episode, optionally targeting a specific task. + + Args: + task_id: If given, reset to the task with this ID instead of + cycling through the task list sequentially. + + Returns: + Initial observation for the episode. + + Raises: + ValueError: If *task_id* is not ``None`` and no matching task exists. + RuntimeError: If no task could be selected (should never happen). + """ self.done = False self.current_step = 0 @@ -64,7 +91,8 @@ def reset(self, task_id: Optional[str] = None) -> Observation: self.current_task = TASKS[idx] self._reset_count += 1 - assert self.current_task is not None + if self.current_task is None: # pragma: no cover — defensive guard + raise RuntimeError("No task selected after reset") t = self.current_task @@ -83,7 +111,8 @@ def reset(self, task_id: Optional[str] = None) -> Observation: return self._make_observation() def _make_observation(self) -> Observation: - assert self.current_task is not None + if self.current_task is None: # pragma: no cover — defensive guard + raise RuntimeError("Cannot make observation: no active task. Call reset() first.") ct = self.state_data["contract_text"] return Observation( @@ -108,7 +137,19 @@ def _validate_action(self, action: Action) -> Optional[str]: return None def step(self, action: Action) -> Tuple[Observation, float, bool, dict[str, Any]]: - assert self.current_task is not None + """Execute one negotiation action and return (observation, reward, done, info). + + Args: + action: The agent's chosen action (action_type + optional content). + + Returns: + A 4-tuple of (observation, reward, done, info). + + Raises: + RuntimeError: If called before ``reset()``. + """ + if self.current_task is None: + raise RuntimeError("Cannot step: no active task. Call reset() first.") info: dict[str, Any] = {} @@ -173,6 +214,7 @@ def step(self, action: Action) -> Tuple[Observation, float, bool, dict[str, Any] return self._make_observation(), reward, self.done, info def state(self) -> dict[str, Any]: + """Return a serialisable snapshot of the current environment state.""" out = dict(self.state_data) if self.current_task is not None: @@ -185,4 +227,5 @@ def state(self) -> dict[str, Any]: return out def close(self) -> None: + """Clean up resources (no-op for this environment).""" return None \ No newline at end of file diff --git a/contract_env/env/graders.py b/contract_env/env/graders.py index 900e029..d059a4f 100644 --- a/contract_env/env/graders.py +++ b/contract_env/env/graders.py @@ -172,6 +172,19 @@ def evaluate_action( action: Action, proposed_contract_text: str, ) -> Tuple[Reward, dict[str, Any]]: + """Score an agent action using a 5-dimensional rubric. + + Dimensions (weights sum to 1.0): + - **correctness** (0.35): Risk-keyword identification or removal. + - **improvement** (0.25): Overlap with safe keywords / expected safe edit. + - **risk_alignment** (0.25): Whether the action type fits the risk level. + - **semantic_similarity** (0.10): Cosine + Jaccard similarity to expected safe edit. + - **completeness** (0.05): Required legal elements present in the rewrite. + + Returns: + (Reward, info_dict) where info_dict contains per-dimension scores and + an ``accept_blocked`` flag when ACCEPT is attempted on a still-risky contract. + """ content = (action.content or "").strip() eval_text = content if content else proposed_contract_text @@ -463,4 +476,5 @@ def grade_expert(task: NegotiationTask, contract_before: str, action: Action, pr # Count of tasks with graders NUM_GRADED_TASKS = len(GRADED_TASKS) -assert NUM_GRADED_TASKS >= 3, f"Expected at least 3 graded tasks, got {NUM_GRADED_TASKS}" \ No newline at end of file +if NUM_GRADED_TASKS < 3: + raise ValueError(f"Expected at least 3 graded tasks, got {NUM_GRADED_TASKS}") \ No newline at end of file diff --git a/contract_env/env/tasks.py b/contract_env/env/tasks.py index bedc25e..d853050 100644 --- a/contract_env/env/tasks.py +++ b/contract_env/env/tasks.py @@ -325,7 +325,7 @@ def grade(self, contract_before: str, action: "Action", proposed_contract_text: ), clause_type="intellectual_property", risk_keywords=[ - "belongs to Supplier", + "belongs to supplier", "expressly agreed otherwise", "created under this agreement", "feedback", diff --git a/contract_env/server/app.py b/contract_env/server/app.py index 1f1df99..aa88c3e 100644 --- a/contract_env/server/app.py +++ b/contract_env/server/app.py @@ -1,19 +1,24 @@ from __future__ import annotations import os -from typing import Any from fastapi import FastAPI, HTTPException, Request from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse -from pydantic import ValidationError +from pydantic import BaseModel, Field, ValidationError from contract_env.env.environment import ContractEnv from contract_env.env.graders import TASK_GRADERS, NUM_GRADED_TASKS, contract_quality_score from contract_env.env.models import Action, Observation, Reward, StepRequest from contract_env.env.tasks import TASKS + +class EvaluateQualityRequest(BaseModel): + """Request body for the /evaluate-quality endpoint.""" + contract_text: str = Field(..., min_length=1, max_length=100_000) + + _env = ContractEnv() app = FastAPI( @@ -136,12 +141,10 @@ def get_schema(): } -_MAX_EVALUATE_TEXT_LEN = 100_000 - # ── EVALUATE QUALITY ───────────────────────────────────────────────────── @app.post("/evaluate-quality") -def evaluate_quality(body: dict): +def evaluate_quality(body: EvaluateQualityRequest): """Score an arbitrary contract text against the current task. Body: {"contract_text": "..."} @@ -150,15 +153,7 @@ def evaluate_quality(body: dict): """ if _env.current_task is None: raise HTTPException(status_code=400, detail="No active task. Call /reset first.") - contract_text = body.get("contract_text", "") - if not contract_text: - raise HTTPException(status_code=422, detail="contract_text must be non-empty.") - if len(contract_text) > _MAX_EVALUATE_TEXT_LEN: - raise HTTPException( - status_code=422, - detail=f"contract_text exceeds maximum length of {_MAX_EVALUATE_TEXT_LEN}.", - ) - quality = contract_quality_score(_env.current_task, contract_text) + quality = contract_quality_score(_env.current_task, body.contract_text) return {"quality_score": round(quality, 4), "risk_score": round(1.0 - quality, 4)} diff --git a/contract_env/tests/test_api.py b/contract_env/tests/test_api.py index 11a5e78..6286313 100644 --- a/contract_env/tests/test_api.py +++ b/contract_env/tests/test_api.py @@ -55,6 +55,50 @@ def test_tasks_endpoint(self) -> None: self.assertIn("clause_type", t) self.assertIn("has_grader", t) + def test_schema_endpoint(self) -> None: + r = self.client.get("/schema") + self.assertEqual(r.status_code, 200) + data = r.json() + self.assertIn("Action", data) + self.assertIn("Observation", data) + self.assertIn("Reward", data) + + def test_root_endpoint(self) -> None: + r = self.client.get("/") + self.assertEqual(r.status_code, 200) + self.assertEqual(r.json().get("status"), "ok") + + def test_evaluate_quality_missing_field(self) -> None: + """POST /evaluate-quality with empty body should return 422.""" + self.client.post("/reset") + r = self.client.post("/evaluate-quality", json={}) + self.assertEqual(r.status_code, 422) + + def test_evaluate_quality_empty_text(self) -> None: + """POST /evaluate-quality with empty string should return 422.""" + self.client.post("/reset") + r = self.client.post("/evaluate-quality", json={"contract_text": ""}) + self.assertEqual(r.status_code, 422) + + def test_evaluate_quality_success(self) -> None: + """POST /evaluate-quality with valid text should return scores.""" + self.client.post("/reset") + r = self.client.post( + "/evaluate-quality", + json={"contract_text": "Liability capped at fees paid in preceding twelve months."}, + ) + self.assertEqual(r.status_code, 200) + data = r.json() + self.assertIn("quality_score", data) + self.assertIn("risk_score", data) + self.assertGreater(data["quality_score"], 0.0) + + def test_step_invalid_action_type_rejected(self) -> None: + """POST /step with an invalid action_type should return 422.""" + self.client.post("/reset") + r = self.client.post("/step", json={"action_type": "INVALID_ACTION"}) + self.assertEqual(r.status_code, 422) + if __name__ == "__main__": unittest.main() diff --git a/inference.py b/inference.py index abf5688..edc43de 100644 --- a/inference.py +++ b/inference.py @@ -569,8 +569,12 @@ def run_episode(env, task_id: Optional[str] = None) -> tuple[float, str]: Returns (mean_episode_score, task_id). """ - if task_id is not None and hasattr(env, "reset") and "task_id" in env.reset.__code__.co_varnames: - obs_obj = env.reset(task_id=task_id) + if task_id is not None: + try: + obs_obj = env.reset(task_id=task_id) + except TypeError: + # env.reset() doesn't accept task_id (e.g., _HTTPEnvClient) + obs_obj = env.reset() else: obs_obj = env.reset() task = env.current_task From af47336b8ca59fdc928f201bbc6814aef7aeba69 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 11 Apr 2026 10:13:46 +0000 Subject: [PATCH 12/13] Harden error handling, add docstrings, HTTP timeouts, validation, and 3 new tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Security: global exception handler no longer leaks internal error details Robustness: _HTTPEnvClient now uses 30s request timeouts Robustness: opponent_responses validator rejects empty response lists Code quality: added docstrings to 10 undocumented public functions in graders.py Code quality: cleaned up dev-note comments in models.py (✅/❌ markers) Code quality: removed unused StepResponse model from models.py Code quality: added proper __init__.py to tests/ directory Test coverage: added test_step_before_reset_raises (RuntimeError check) Test coverage: added test_step_after_reset_specific_task_keeps_task Test coverage: added test_opponent_responses_non_empty_validation README: updated test count 72 → 75 Agent-Logs-Url: https://github.com/bigturtle679/Contract-Negotiation-Environment/sessions/08a54a1a-0cc0-4519-a9c8-096c1ff8f4b8 Co-authored-by: AbeerChaturvedi <171315954+AbeerChaturvedi@users.noreply.github.com> --- README.md | 2 +- contract_env/env/graders.py | 17 +++++++++++++- contract_env/env/models.py | 32 ++++++++++++-------------- contract_env/env/tasks.py | 5 ++++ contract_env/server/app.py | 7 +++--- contract_env/tests/__init__.py | 0 contract_env/tests/test_smoke.py | 39 ++++++++++++++++++++++++++++++++ inference.py | 11 +++++---- 8 files changed, 87 insertions(+), 26 deletions(-) create mode 100644 contract_env/tests/__init__.py diff --git a/README.md b/README.md index f0ca664..ae89862 100644 --- a/README.md +++ b/README.md @@ -209,7 +209,7 @@ An episode is considered successful if `score ≥ 0.50`. ```bash pip install -e ".[dev]" -python -m pytest contract_env/tests/ -v # 72 tests +python -m pytest contract_env/tests/ -v # 75 tests ``` ### Run the server diff --git a/contract_env/env/graders.py b/contract_env/env/graders.py index d059a4f..ff846cc 100644 --- a/contract_env/env/graders.py +++ b/contract_env/env/graders.py @@ -78,6 +78,7 @@ def _is_negated(text_lower: str, keyword_lower: str) -> bool: def _weighted_risk_hits(text: str, risk_keywords: list[str]) -> float: + """Return the fraction of *risk_keywords* found (un-negated) in *text* ∈ [0, 1].""" low = text.lower() if not risk_keywords: return 0.0 @@ -92,10 +93,12 @@ def _weighted_risk_hits(text: str, risk_keywords: list[str]) -> float: def keyword_match_score(text: str, risk_keywords: list[str]) -> float: + """Public alias for :func:`_weighted_risk_hits` — negation-aware risk score.""" return _weighted_risk_hits(text, risk_keywords) def _safe_overlap(text: str, safe_keywords: list[str], expected_safe: str) -> float: + """Score how well *text* overlaps with safe keywords and the expected safe edit ∈ [0, 1].""" if not text.strip(): return 0.0 @@ -130,11 +133,18 @@ def semantic_similarity(text: str, reference: str) -> float: def trap_unresolved(task: NegotiationTask, contract_text: str) -> bool: + """Return True if any of the task's trap markers still appear in *contract_text*.""" low = contract_text.lower() return any(m in low for m in task.trap_markers) def effective_risk_high(task: NegotiationTask, contract_text: str) -> bool: + """Return True if the contract is still effectively high-risk for grading purposes. + + A contract is "effectively high" when: + - Any trap marker remains unresolved, **or** + - The weighted risk-keyword density exceeds the task's risk-level threshold. + """ # Any task that defines explicit trap markers (HARD, HARD_PLUS, HARD_PLUS2, # EXPERT, MEDIUM_PLUS) is still "effectively high risk" as long as any trap # marker remains in the text. @@ -152,6 +162,7 @@ def effective_risk_high(task: NegotiationTask, contract_text: str) -> bool: def action_risk_alignment(action_type: str, effective_high: bool, task: NegotiationTask) -> float: + """Score how well the chosen action type matches the current risk level ∈ [0, 1].""" clause_boost = task.clause_type_weight if action_type == "ACCEPT": @@ -266,13 +277,15 @@ def evaluate_action( return reward, info -def score_action_hypothetical(task, state_data, action) -> float: +def score_action_hypothetical(task: NegotiationTask, state_data: dict, action: Action) -> float: + """Score an action without stepping the environment (read-only / dry-run).""" contract_before = state_data.get("contract_text", "") proposed = build_proposed_contract_for_step(contract_before, action) return evaluate_action(task, contract_before, action, proposed)[0].score def build_proposed_contract_for_step(contract_before: str, action: Action) -> str: + """Build the proposed contract text that would result from applying *action*.""" content = (action.content or "").strip() if action.action_type == "EDIT_CLAUSE" and content: @@ -285,6 +298,7 @@ def build_proposed_contract_for_step(contract_before: str, action: Action) -> st def observation_risk_float(task: NegotiationTask, contract_text: str) -> float: + """Compute the risk-level float for the observation, clamped to (0, 1).""" base = _weighted_risk_hits(contract_text, task.risk_keywords) # Boost risk observation when any task's trap markers remain unresolved @@ -297,6 +311,7 @@ def observation_risk_float(task: NegotiationTask, contract_text: str) -> float: def contract_quality_score(task: NegotiationTask, contract_text: str) -> float: + """Return a quality score for *contract_text* ∈ (0, 1), where 1 = fully safe.""" return 1.0 - observation_risk_float(task, contract_text) diff --git a/contract_env/env/models.py b/contract_env/env/models.py index 66318dd..e9aad2d 100644 --- a/contract_env/env/models.py +++ b/contract_env/env/models.py @@ -14,46 +14,44 @@ class Action(BaseModel): + """Agent action: one of the five negotiation moves with optional clause content.""" + action_type: ActionType content: Optional[str] = None -# ✅ FIX: simplified reward with strict bounds (0, 1) class Reward(BaseModel): + """Step reward with a score strictly between 0 and 1 (exclusive).""" + score: float = Field(gt=0.0, lt=1.0) - - @field_validator('score') + + @field_validator("score") @classmethod def validate_score(cls, v: float) -> float: if not (0.0 < v < 1.0): - raise ValueError(f'score must be strictly between 0 and 1, got {v}') + raise ValueError(f"score must be strictly between 0 and 1, got {v}") return v class Observation(BaseModel): + """Environment observation returned after each reset/step.""" + contract_text: str clause_type: str risk_level: float = Field(gt=0.0, lt=1.0) step_count: int negotiation_history: list[str] - - @field_validator('risk_level') + + @field_validator("risk_level") @classmethod def validate_risk_level(cls, v: float) -> float: if not (0.0 < v < 1.0): - raise ValueError(f'risk_level must be strictly between 0 and 1, got {v}') + raise ValueError(f"risk_level must be strictly between 0 and 1, got {v}") return v class StepRequest(BaseModel): - action_type: ActionType - content: Optional[str] = None + """Request body for the ``/step`` endpoint.""" - -# ❌ REMOVE Reward object nesting -# ✅ Use dict instead -class StepResponse(BaseModel): - observation: Observation - reward: dict[str, float] # FIXED - done: bool - info: dict[str, Any] \ No newline at end of file + action_type: ActionType + content: Optional[str] = None \ No newline at end of file diff --git a/contract_env/env/tasks.py b/contract_env/env/tasks.py index d853050..0b54bd7 100644 --- a/contract_env/env/tasks.py +++ b/contract_env/env/tasks.py @@ -73,6 +73,11 @@ def _validate_opponent_response_keys( raise ValueError( f"opponent_responses contains invalid action types: {bad}" ) + for action_type, responses in v.items(): + if not responses: + raise ValueError( + f"opponent_responses[{action_type!r}] must not be empty" + ) return v def get_grader(self) -> Callable[["NegotiationTask", str, "Action", str], "Reward"]: diff --git a/contract_env/server/app.py b/contract_env/server/app.py index aa88c3e..495066b 100644 --- a/contract_env/server/app.py +++ b/contract_env/server/app.py @@ -61,11 +61,12 @@ async def validation_handler(request: Request, exc: RequestValidationError): @app.exception_handler(Exception) async def global_exception_handler(request: Request, exc: Exception): + import logging + + logging.getLogger(__name__).exception("Unhandled error on %s %s", request.method, request.url.path) return JSONResponse( status_code=500, - content={ - "detail": str(exc), - }, + content={"detail": "Internal server error"}, ) diff --git a/contract_env/tests/__init__.py b/contract_env/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/contract_env/tests/test_smoke.py b/contract_env/tests/test_smoke.py index 7ac6464..582195d 100644 --- a/contract_env/tests/test_smoke.py +++ b/contract_env/tests/test_smoke.py @@ -238,6 +238,45 @@ def test_edit_then_accept_full_flow(self) -> None: self.assertFalse(info2.get("accept_blocked", False)) self.assertGreater(r2, 0.01) + def test_step_before_reset_raises(self) -> None: + """Calling step() without reset() should raise RuntimeError.""" + env = ContractEnv() + with self.assertRaises(RuntimeError): + env.step(Action(action_type="FLAG_RISK")) + + def test_step_after_reset_specific_task_keeps_task(self) -> None: + """After reset(task_id=...) the task should remain correct through steps.""" + env = ContractEnv() + env.reset(task_id="hard_intellectual_property") + self.assertEqual(env.current_task.id, "hard_intellectual_property") + obs, _, _, _ = env.step(Action(action_type="FLAG_RISK")) + self.assertEqual(env.current_task.id, "hard_intellectual_property") + obs, _, _, _ = env.step( + Action(action_type="EDIT_CLAUSE", content=env.current_task.expected_safe_edit) + ) + self.assertEqual(env.current_task.id, "hard_intellectual_property") + + def test_opponent_responses_non_empty_validation(self) -> None: + """Empty opponent response lists should be rejected by validation.""" + from contract_env.env.tasks import NegotiationTask + from contract_env.env.graders import grade_easy + from pydantic import ValidationError + with self.assertRaises(ValidationError): + NegotiationTask( + id="test", + name="TEST", + contract_text="test", + clause_type="liability", + risk_keywords=["test"], + safe_keywords=["test"], + expected_safe_edit="test", + risk_level="HIGH", + hidden_trap="", + opponent_responses={"FLAG_RISK": []}, + grader_func=grade_easy, + grader_name="grade_easy", + ) + if __name__ == "__main__": unittest.main() diff --git a/inference.py b/inference.py index edc43de..a88b986 100644 --- a/inference.py +++ b/inference.py @@ -499,10 +499,13 @@ def _choose( class _HTTPEnvClient: """Thin HTTP wrapper with the same interface as ContractEnv for inference.""" - def __init__(self, base_url: str) -> None: + _DEFAULT_TIMEOUT: float = 30.0 # seconds per HTTP request + + def __init__(self, base_url: str, timeout: float = _DEFAULT_TIMEOUT) -> None: import requests self.base_url = base_url.rstrip("/") self._session = requests.Session() + self._timeout = timeout self._task_idx = 0 self.current_task: Optional[NegotiationTask] = None @@ -516,14 +519,14 @@ def __exit__(self, *exc: Any) -> None: self.close() def reset(self): - resp = self._session.post(f"{self.base_url}/reset") + resp = self._session.post(f"{self.base_url}/reset", timeout=self._timeout) resp.raise_for_status() data = resp.json() obs = data["observation"] # Map to a NegotiationTask if possible (for _choose() to use) task_id = None try: - state = self._session.get(f"{self.base_url}/state").json() + state = self._session.get(f"{self.base_url}/state", timeout=self._timeout).json() task_id = state.get("task_id") except Exception: pass @@ -538,7 +541,7 @@ def step(self, action: Action): payload = {"action_type": action.action_type} if action.content: payload["content"] = action.content - resp = self._session.post(f"{self.base_url}/step", json=payload) + resp = self._session.post(f"{self.base_url}/step", json=payload, timeout=self._timeout) resp.raise_for_status() data = resp.json() obs = _DictObservation(data["observation"]) From 63a4fdafadab61045727a127f73f2786b95d398e Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 11 Apr 2026 10:23:23 +0000 Subject: [PATCH 13/13] Final polish: non-root Dockerfile, /reset task_id support, logging fix, 3 new tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Dockerfile: add non-root user for container security - app.py: move logging import to module level, add /reset task_id body param - graders.py: clean stale comments, add grade_action docstring - tests: add 3 new API tests (reset with task_id, invalid task_id, evaluate before reset) - README: update test count 75 → 78 Agent-Logs-Url: https://github.com/bigturtle679/Contract-Negotiation-Environment/sessions/e910db62-1682-4ee6-9dc6-e8639ab081ea Co-authored-by: AbeerChaturvedi <171315954+AbeerChaturvedi@users.noreply.github.com> --- Dockerfile | 5 +++++ README.md | 2 +- contract_env/env/graders.py | 5 +++-- contract_env/server/app.py | 25 ++++++++++++++++++++----- contract_env/tests/test_api.py | 28 ++++++++++++++++++++++++++++ 5 files changed, 57 insertions(+), 8 deletions(-) diff --git a/Dockerfile b/Dockerfile index beccc05..6b2cce1 100644 --- a/Dockerfile +++ b/Dockerfile @@ -12,6 +12,11 @@ RUN pip install --no-cache-dir --upgrade pip \ COPY . . +# Run as non-root for security +RUN useradd --create-home --shell /bin/bash appuser \ + && chown -R appuser:appuser /app +USER appuser + EXPOSE 7860 HEALTHCHECK --interval=30s --timeout=5s --start-period=5s --retries=3 \ diff --git a/README.md b/README.md index ae89862..7d27ade 100644 --- a/README.md +++ b/README.md @@ -209,7 +209,7 @@ An episode is considered successful if `score ≥ 0.50`. ```bash pip install -e ".[dev]" -python -m pytest contract_env/tests/ -v # 75 tests +python -m pytest contract_env/tests/ -v # 78 tests ``` ### Run the server diff --git a/contract_env/env/graders.py b/contract_env/env/graders.py index ff846cc..2e3728e 100644 --- a/contract_env/env/graders.py +++ b/contract_env/env/graders.py @@ -253,7 +253,7 @@ def evaluate_action( + 0.05 * completeness ) - # ✅ STRICT RANGE FIX: strictly between 0 and 1, clamped to [0.001, 0.999] + # Clamp strictly between 0 and 1 → [0.001, 0.999] score = max(0.001, min(0.999, score)) reward = Reward(score=round(score, 4)) @@ -305,7 +305,7 @@ def observation_risk_float(task: NegotiationTask, contract_text: str) -> float: if task.trap_markers and trap_unresolved(task, contract_text): base = min(1.0, base + 0.25) - # ✅ STRICT RANGE FIX: strictly between 0 and 1, clamped to [0.001, 0.999] + # Clamp strictly between 0 and 1 → [0.001, 0.999] base = min(0.999, max(0.001, base)) return round(base, 4) @@ -321,6 +321,7 @@ def grade_action( action: Action, proposed_contract_text: str, ) -> Reward: + """Convenience wrapper: grade an action and return only the Reward (drop info).""" reward, _ = evaluate_action(task, contract_before, action, proposed_contract_text) return reward diff --git a/contract_env/server/app.py b/contract_env/server/app.py index 495066b..97e77fa 100644 --- a/contract_env/server/app.py +++ b/contract_env/server/app.py @@ -1,6 +1,8 @@ from __future__ import annotations +import logging import os +from typing import Optional from fastapi import FastAPI, HTTPException, Request from fastapi.exceptions import RequestValidationError @@ -13,6 +15,13 @@ from contract_env.env.models import Action, Observation, Reward, StepRequest from contract_env.env.tasks import TASKS +logger = logging.getLogger(__name__) + + +class ResetRequest(BaseModel): + """Optional request body for the /reset endpoint.""" + task_id: Optional[str] = Field(default=None, description="Force a specific task by ID.") + class EvaluateQualityRequest(BaseModel): """Request body for the /evaluate-quality endpoint.""" @@ -61,9 +70,7 @@ async def validation_handler(request: Request, exc: RequestValidationError): @app.exception_handler(Exception) async def global_exception_handler(request: Request, exc: Exception): - import logging - - logging.getLogger(__name__).exception("Unhandled error on %s %s", request.method, request.url.path) + logger.exception("Unhandled error on %s %s", request.method, request.url.path) return JSONResponse( status_code=500, content={"detail": "Internal server error"}, @@ -105,10 +112,18 @@ def get_state(): # ── RESET ─────────────────────────────────────────────────────────────── @app.post("/reset") -def reset(): +def reset(body: Optional[ResetRequest] = None): + """Start a new episode. + + Optionally pass ``{"task_id": "..."}`` to target a specific task; + otherwise the environment cycles through tasks sequentially. + """ try: - obs = _env.reset() + task_id = body.task_id if body else None + obs = _env.reset(task_id=task_id) return {"observation": obs.model_dump()} + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) diff --git a/contract_env/tests/test_api.py b/contract_env/tests/test_api.py index 6286313..5890ac3 100644 --- a/contract_env/tests/test_api.py +++ b/contract_env/tests/test_api.py @@ -99,6 +99,34 @@ def test_step_invalid_action_type_rejected(self) -> None: r = self.client.post("/step", json={"action_type": "INVALID_ACTION"}) self.assertEqual(r.status_code, 422) + def test_reset_with_task_id(self) -> None: + """POST /reset with a task_id body should target that specific task.""" + r = self.client.post("/reset", json={"task_id": "expert_data_protection"}) + self.assertEqual(r.status_code, 200) + obs = r.json()["observation"] + self.assertEqual(obs["clause_type"], "data_protection") + + def test_reset_with_invalid_task_id(self) -> None: + """POST /reset with an unknown task_id should return 400.""" + r = self.client.post("/reset", json={"task_id": "nonexistent_task"}) + self.assertEqual(r.status_code, 400) + + def test_evaluate_quality_before_reset(self) -> None: + """POST /evaluate-quality before any /reset should return 400.""" + # Use a fresh app instance with a fresh env that hasn't been reset + from contract_env.server.app import _env + # Save and restore state to simulate a fresh start + old_task = _env.current_task + _env.current_task = None + try: + r = self.client.post( + "/evaluate-quality", + json={"contract_text": "Some clause text."}, + ) + self.assertEqual(r.status_code, 400) + finally: + _env.current_task = old_task + if __name__ == "__main__": unittest.main()