Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions .dockerignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
.git
.github
.gitignore
.venv
venv
__pycache__
*.py[cod]
*.egg-info
.pytest_cache
.mypy_cache
.cursor
uv.lock
hf_create.py
verify_graders.py
contract_env/tests
contract_env/scripts
contract_env/.dockerignore
contract_env/docker-compose.yml
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ USER appuser

EXPOSE 7860

HEALTHCHECK --interval=30s --timeout=5s --start-period=5s --retries=3 \
HEALTHCHECK --interval=30s --timeout=15s --start-period=10s --retries=3 \
CMD python -c "import urllib.request; urllib.request.urlopen('http://127.0.0.1:7860/health', timeout=3)"

CMD ["uvicorn", "contract_env.server.app:app", "--host", "0.0.0.0", "--port", "7860"]
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -251,9 +251,10 @@ python inference.py --benchmark --mode api

| Variable | Required | Default | Description |
|----------|----------|---------|-------------|
| `HF_TOKEN` | Yes | — | HuggingFace / LLM API key |
| `HF_TOKEN` | Yes | — | HuggingFace / LLM API key (falls back to `API_KEY` if unset) |
| `API_BASE_URL` | No | `https://router.huggingface.co/v1` | LLM API endpoint |
| `MODEL_NAME` | No | `Qwen/Qwen2.5-72B-Instruct` | Model identifier |
| `LOCAL_IMAGE_NAME` | No | `contract-negotiation-env` | Docker image name for `from_docker_image()` client usage |
| `BENCHMARK` | No | `contract_negotiation` | Benchmark name in [START] log line |
| `ENV_SERVER_URL` | No | `http://localhost:7860` | Docker server URL (for `--mode api`) |
| `PORT` | No | `7860` | Server port |
Expand Down
2 changes: 1 addition & 1 deletion contract_env/env/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def step(self, action: Action) -> Tuple[Observation, float, bool, dict[str, Any]
info: dict[str, Any] = {}

if self.done:
return self._make_observation(), 0.0, True, {"error": "already_done"}
return self._make_observation(), 0.001, True, {"error": "already_done"}

err = self._validate_action(action)
if err:
Expand Down
6 changes: 3 additions & 3 deletions contract_env/env/graders.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ def _is_negated(text_lower: str, keyword_lower: str) -> bool:
if pos == -1:
break
found_any = True
# Check the 60-character window before the match for negation cues
window_start = max(0, pos - 60)
# Check the 150-character window before the match for negation cues
window_start = max(0, pos - 150)
preceding = text_lower[window_start:pos]
if not any(neg in preceding for neg in _NEGATION_PREFIXES):
all_negated = False
Expand Down Expand Up @@ -239,7 +239,7 @@ def evaluate_action(
# Completeness bonus: reward rewrites that include required legal elements
# defined on the task (if any).
completeness = 0.0
required_elems: list[str] = getattr(task, "required_elements", [])
required_elems = task.required_elements
if action.action_type in ("EDIT_CLAUSE", "PROPOSE_COUNTER") and content and required_elems:
completeness = clause_completeness_score(content, required_elems)

Expand Down
18 changes: 11 additions & 7 deletions contract_env/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import logging
import os
from typing import Optional
from typing import Any, Optional

from fastapi import FastAPI, HTTPException, Request
from fastapi.exceptions import RequestValidationError
Expand Down Expand Up @@ -41,7 +41,7 @@ class EvaluateQualityRequest(BaseModel):

app.add_middleware(
CORSMiddleware,
allow_origins=os.getenv("CORS_ORIGINS", "*").split(","),
allow_origins=[o.strip() for o in os.getenv("CORS_ORIGINS", "*").split(",")],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
Expand Down Expand Up @@ -112,7 +112,7 @@ def get_state():

# ── RESET ───────────────────────────────────────────────────────────────
@app.post("/reset")
def reset(body: Optional[ResetRequest] = None):
def reset(body: Optional[ResetRequest] = None) -> dict[str, Any]:
"""Start a new episode.

Optionally pass ``{"task_id": "..."}`` to target a specific task;
Expand All @@ -124,13 +124,14 @@ def reset(body: Optional[ResetRequest] = None):
return {"observation": obs.model_dump()}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
except Exception:
logger.exception("Unexpected error during /reset")
raise HTTPException(status_code=500, detail="Internal server error")


# ── STEP ────────────────────────────────────────────────────────────────
@app.post("/step")
def step(req: StepRequest):
def step(req: StepRequest) -> dict[str, Any]:
try:
action = Action(action_type=req.action_type, content=req.content)
obs, reward, done, info = _env.step(action)
Expand All @@ -144,6 +145,9 @@ def step(req: StepRequest):

except ValidationError as e:
raise HTTPException(status_code=422, detail=e.errors())
except Exception:
logger.exception("Unexpected error during /step")
raise HTTPException(status_code=500, detail="Internal server error")


# ── SCHEMA ──────────────────────────────────────────────────────────────
Expand All @@ -160,7 +164,7 @@ def get_schema():

# ── EVALUATE QUALITY ─────────────────────────────────────────────────────
@app.post("/evaluate-quality")
def evaluate_quality(body: EvaluateQualityRequest):
def evaluate_quality(body: EvaluateQualityRequest) -> dict[str, float]:
"""Score an arbitrary contract text against the current task.

Body: {"contract_text": "..."}
Expand Down
10 changes: 10 additions & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
services:
contract-env:
build: .
image: contract-negotiation-env:latest
ports:
- "7860:7860"
environment:
HOST: "0.0.0.0"
PORT: "7860"
HF_TOKEN: "${HF_TOKEN:-}"
69 changes: 43 additions & 26 deletions inference.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
"""
Inference Script — Contract Negotiation Environment
=====================================================
LLM-driven agent that analyses contract clauses, identifies legal risks,
and proposes safer alternatives through multi-turn negotiation.

MANDATORY environment variables:
API_BASE_URL The API endpoint for the LLM.
MODEL_NAME The model identifier to use for inference.
HF_TOKEN Your Hugging Face / API key.
Inference Script Example
===================================
MANDATORY
- Before submitting, ensure the following variables are defined in your
environment configuration:
API_BASE_URL The API endpoint for the LLM.
MODEL_NAME The model identifier to use for inference.
HF_TOKEN Your Hugging Face / API key.
LOCAL_IMAGE_NAME The name of the local Docker image to use for the
environment if you are using from_docker_image() method.

- Defaults are set only for API_BASE_URL and MODEL_NAME
(and should reflect your active inference setup):
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")

STDOUT FORMAT (strictly followed):
[START] task=<task_id> env=<benchmark> model=<model_name>
Expand All @@ -26,7 +32,7 @@
import os
import random
import re
from typing import Any, Optional
from typing import Any, Optional, get_args

try:
from dotenv import load_dotenv
Expand All @@ -40,9 +46,10 @@ def load_dotenv(*_a: Any, **_kw: Any) -> None: # type: ignore[misc]
from contract_env.env.graders import (
effective_risk_high,
keyword_match_score,
observation_risk_float,
trap_unresolved,
)
from contract_env.env.models import Action
from contract_env.env.models import Action, ActionType
from contract_env.env.tasks import TASKS, NegotiationTask

log = logging.getLogger(__name__)
Expand All @@ -53,6 +60,7 @@ def load_dotenv(*_a: Any, **_kw: Any) -> None: # type: ignore[misc]
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
BENCHMARK = os.getenv("BENCHMARK", "contract_negotiation")
ENV_SERVER_URL = os.getenv("ENV_SERVER_URL", "http://localhost:7860")
LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME", "contract-negotiation-env")
MAX_STEPS = 10
SUCCESS_SCORE_THRESHOLD = 0.5
HISTORY_WINDOW = 8 # How many recent history entries to show the LLM
Expand Down Expand Up @@ -139,7 +147,7 @@ def _llm_chat(
if attempt == _MAX_RETRIES:
raise
log.warning("[DEBUG] LLM call attempt %d failed: %s", attempt + 1, exc)
return ""
raise RuntimeError("LLM call failed after all retries") # unreachable; satisfies type-checker


def _parse_llm_json(text: str) -> Optional[dict]:
Expand Down Expand Up @@ -273,7 +281,7 @@ def _build_rewrite_prompt(
]


_VALID_ACTIONS = {"FLAG_RISK", "EDIT_CLAUSE", "ACCEPT", "REJECT", "PROPOSE_COUNTER"}
_VALID_ACTIONS = set(get_args(ActionType))

# Optimal action sequences per intent level for rule-based fallback.
# MODERATE tasks front-load PROPOSE_COUNTER since it's the ideal action for
Expand Down Expand Up @@ -455,7 +463,6 @@ def _choose(

# ── 4d. Smart ACCEPT gate: only accept when quality actually improved ─
if action_type == "ACCEPT":
from contract_env.env.graders import observation_risk_float
current_risk = observation_risk_float(task, contract_text)
original_risk = observation_risk_float(task, task.contract_text)
# Block acceptance if the contract hasn't improved meaningfully
Expand Down Expand Up @@ -518,20 +525,25 @@ def __enter__(self):
def __exit__(self, *exc: Any) -> None:
self.close()

def reset(self):
resp = self._session.post(f"{self.base_url}/reset", timeout=self._timeout)
def reset(self, task_id: Optional[str] = None):
body: dict[str, Any] = {}
if task_id is not None:
body["task_id"] = task_id
resp = self._session.post(
f"{self.base_url}/reset", json=body or None, timeout=self._timeout,
)
resp.raise_for_status()
data = resp.json()
obs = data["observation"]
# Map to a NegotiationTask if possible (for _choose() to use)
task_id = None
resolved_task_id = None
try:
state = self._session.get(f"{self.base_url}/state", timeout=self._timeout).json()
task_id = state.get("task_id")
resolved_task_id = state.get("task_id")
except Exception:
pass
if task_id:
self.current_task = next((t for t in TASKS if t.id == task_id), None)
if resolved_task_id:
self.current_task = next((t for t in TASKS if t.id == resolved_task_id), None)
if self.current_task is None:
self.current_task = TASKS[self._task_idx % len(TASKS)]
self._task_idx += 1
Expand Down Expand Up @@ -568,16 +580,12 @@ def run_episode(env, task_id: Optional[str] = None) -> tuple[float, str]:

Args:
env: Environment instance (ContractEnv or _HTTPEnvClient).
task_id: If given, reset to this specific task (local mode only).
task_id: If given, reset to this specific task.

Returns (mean_episode_score, task_id).
"""
if task_id is not None:
try:
obs_obj = env.reset(task_id=task_id)
except TypeError:
# env.reset() doesn't accept task_id (e.g., _HTTPEnvClient)
obs_obj = env.reset()
obs_obj = env.reset(task_id=task_id)
else:
obs_obj = env.reset()
task = env.current_task
Expand Down Expand Up @@ -690,6 +698,15 @@ def main() -> None:
env = ContractEnv()
print("[CONFIG] mode=local", flush=True)

try:
_run_episodes(env, args)
finally:
if hasattr(env, "close"):
env.close()


def _run_episodes(env, args) -> None:
"""Execute episode loop, retries, and print summary."""
episodes_to_run = len(TASKS) if args.benchmark else args.episodes

total_score = 0.0
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ dependencies = [
"fastapi>=0.115.0",
"uvicorn[standard]>=0.32.0",
"openai>=1.50.0",
"openenv>=0.1.13"
"openenv>=0.1.13",
"requests>=2.31.0"
]

[project.optional-dependencies]
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ fastapi>=0.115.0
uvicorn[standard]>=0.32.0
openai>=1.50.0
openenv>=0.1.13
requests>=2.31.0
Empty file added server/__init__.py
Empty file.
7 changes: 7 additions & 0 deletions server/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""Re-export the FastAPI app for multi-mode deployment compatibility.

The openenv validator expects ``server/app.py`` at the repository root.
The canonical implementation lives in ``contract_env.server.app``.
"""

from contract_env.server.app import app # noqa: F401
Loading