Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 118 additions & 0 deletions examples/playwright_page_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
"""
Example: use a Playwright Page with the Stagehand Python SDK.

What this demonstrates:
- Start a Stagehand session (remote Stagehand API / Browserbase browser)
- Attach Playwright to the same browser via CDP (`cdp_url`)
- Pass the Playwright `page` into `session.observe/act/extract` so Stagehand
auto-detects the correct `frame_id` for that page.

Environment variables required:
- MODEL_API_KEY
- BROWSERBASE_API_KEY
- BROWSERBASE_PROJECT_ID

Optional:
- STAGEHAND_BASE_URL (defaults to https://api.stagehand.browserbase.com)
"""

from __future__ import annotations

import os
import sys
from typing import Optional

from stagehand import Stagehand


def main() -> None:
    """Run the Playwright + Stagehand example end to end.

    Reads ``MODEL_API_KEY``, ``BROWSERBASE_API_KEY`` and
    ``BROWSERBASE_PROJECT_ID`` from the environment, starts a remote Stagehand
    session, attaches Playwright to the session's browser through its CDP URL,
    and then calls ``observe``, ``extract`` and ``act`` with the Playwright
    ``page`` passed in.

    The process exits through ``sys.exit`` with an explanatory message when a
    required environment variable is missing, when Playwright cannot be
    imported, or when the session has no ``cdp_url``.

    The Stagehand session is always ended once it has been created, including
    on the early-exit and error paths.
    """
    model_api_key = os.environ.get("MODEL_API_KEY")
    if not model_api_key:
        sys.exit("Set the MODEL_API_KEY environment variable to run this example.")

    bb_api_key = os.environ.get("BROWSERBASE_API_KEY")
    bb_project_id = os.environ.get("BROWSERBASE_PROJECT_ID")
    if not bb_api_key or not bb_project_id:
        sys.exit(
            "Set BROWSERBASE_API_KEY and BROWSERBASE_PROJECT_ID to run this example."
        )

    # Playwright is an optional dependency of this example, so import it lazily
    # and turn only a genuine import failure into a friendly message.
    try:
        from playwright.sync_api import sync_playwright  # type: ignore[import-not-found]
    except ImportError:
        sys.exit(
            "Playwright is not installed. Install it with:\n"
            " uv pip install playwright\n"
            "and ensure browsers are installed (e.g. `playwright install chromium`)."
        )

    with Stagehand(
        server="remote",
        browserbase_api_key=bb_api_key,
        browserbase_project_id=bb_project_id,
        model_api_key=model_api_key,
    ) as client:
        print("⏳ Starting Stagehand session...")
        session = client.sessions.create(
            model_name="openai/gpt-5-nano",
            browser={"type": "browserbase"},
        )

        # From here on a remote browser session exists; make sure it is ended
        # no matter how we leave this block (including via sys.exit below).
        try:
            cdp_url = session.data.cdp_url
            if not cdp_url:
                sys.exit(
                    "No cdp_url returned from the API for this session; cannot attach Playwright."
                )

            print(f"✅ Session started: {session.id}")
            print("🔌 Connecting Playwright to the same browser over CDP...")

            with sync_playwright() as p:
                # Attach to the same browser session Stagehand is controlling.
                browser = p.chromium.connect_over_cdp(cdp_url)
                try:
                    # Reuse an existing context/page if present; otherwise create one.
                    context = browser.contexts[0] if browser.contexts else browser.new_context()
                    page = context.pages[0] if context.pages else context.new_page()

                    page.goto("https://example.com", wait_until="domcontentloaded")

                    print("👀 Stagehand.observe(page=...) ...")
                    actions = session.observe(
                        instruction="Find the most relevant click target on this page",
                        page=page,
                    )
                    print(f"Observed {len(actions.data.result)} actions")

                    print("🧠 Stagehand.extract(page=...) ...")
                    extracted = session.extract(
                        instruction="Extract the page title and the primary heading (h1) text",
                        schema={
                            "type": "object",
                            "properties": {
                                "title": {"type": "string"},
                                "h1": {"type": "string"},
                            },
                            "required": ["title", "h1"],
                            "additionalProperties": False,
                        },
                        page=page,
                    )
                    print("Extracted:", extracted.data.result)

                    print("🖱️ Stagehand.act(page=...) ...")
                    session.act(
                        input="Click the 'More information' link",
                        page=page,
                    )
                    print("Done.")
                finally:
                    browser.close()
        finally:
            session.end()


if __name__ == "__main__":
    main()

134 changes: 122 additions & 12 deletions src/stagehand/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Union
from typing import TYPE_CHECKING, Any, Union, cast
from datetime import datetime
from typing_extensions import Unpack, Literal
import inspect
from typing_extensions import Unpack, Literal, Protocol

import httpx

Expand All @@ -23,11 +24,110 @@
from .types.session_extract_response import SessionExtractResponse
from .types.session_observe_response import SessionObserveResponse
from .types.session_navigate_response import SessionNavigateResponse
from ._exceptions import StagehandError

if TYPE_CHECKING:
from ._client import Stagehand, AsyncStagehand


class _PlaywrightCDPSession(Protocol):
def send(self, method: str, params: Any = ...) -> Any: # noqa: ANN401
...


class _PlaywrightContext(Protocol):
def new_cdp_session(self, page: Any) -> Any: # noqa: ANN401
...


def _extract_frame_id_from_playwright_page(page: Any) -> str:
context = getattr(page, "context", None)
if context is None:
raise StagehandError("page must be a Playwright Page with a .context attribute")

if callable(context):
context = context()

new_cdp_session = getattr(context, "new_cdp_session", None)
if not callable(new_cdp_session):
raise StagehandError(
"page must be a Playwright Page; expected page.context.new_cdp_session(...) to exist"
)

pw_context = cast(_PlaywrightContext, context)
cdp = pw_context.new_cdp_session(page)
if inspect.isawaitable(cdp):
raise StagehandError(
"Expected a synchronous Playwright Page, but received an async CDP session; use AsyncSession methods"
)

send = getattr(cdp, "send", None)
if not callable(send):
raise StagehandError("Playwright CDP session missing .send(...) method")

pw_cdp = cast(_PlaywrightCDPSession, cdp)
result = pw_cdp.send("Page.getFrameTree")
if inspect.isawaitable(result):
raise StagehandError(
"Expected a synchronous Playwright Page, but received an async CDP session; use AsyncSession methods"
)

try:
return result["frameTree"]["frame"]["id"]
except Exception as e: # noqa: BLE001
raise StagehandError("Failed to extract frame id from Playwright CDP Page.getFrameTree response") from e


async def _extract_frame_id_from_playwright_page_async(page: Any) -> str:
context = getattr(page, "context", None)
if context is None:
raise StagehandError("page must be a Playwright Page with a .context attribute")

if callable(context):
context = context()

new_cdp_session = getattr(context, "new_cdp_session", None)
if not callable(new_cdp_session):
raise StagehandError(
"page must be a Playwright Page; expected page.context.new_cdp_session(...) to exist"
)

pw_context = cast(_PlaywrightContext, context)
cdp = pw_context.new_cdp_session(page)
if inspect.isawaitable(cdp):
cdp = await cdp

send = getattr(cdp, "send", None)
if not callable(send):
raise StagehandError("Playwright CDP session missing .send(...) method")

pw_cdp = cast(_PlaywrightCDPSession, cdp)
result = pw_cdp.send("Page.getFrameTree")
if inspect.isawaitable(result):
result = await result

try:
return result["frameTree"]["frame"]["id"]
except Exception as e: # noqa: BLE001
raise StagehandError("Failed to extract frame id from Playwright CDP Page.getFrameTree response") from e


def _maybe_inject_frame_id(params: dict[str, Any], page: Any | None) -> dict[str, Any]:
if page is None:
return params
if "frame_id" in params:
return params
return {**params, "frame_id": _extract_frame_id_from_playwright_page(page)}


async def _maybe_inject_frame_id_async(params: dict[str, Any], page: Any | None) -> dict[str, Any]:
if page is None:
return params
if "frame_id" in params:
return params
return {**params, "frame_id": await _extract_frame_id_from_playwright_page_async(page)}


class Session(SessionStartResponse):
"""A Stagehand session bound to a specific `session_id`."""

Expand All @@ -41,6 +141,7 @@ def __init__(self, client: Stagehand, id: str, data: SessionStartResponseData, s
def navigate(
self,
*,
page: Any | None = None,
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
Expand All @@ -53,12 +154,13 @@ def navigate(
extra_query=extra_query,
extra_body=extra_body,
timeout=timeout,
**params,
**_maybe_inject_frame_id(dict(params), page),
)

def act(
self,
*,
page: Any | None = None,
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
Expand All @@ -71,12 +173,13 @@ def act(
extra_query=extra_query,
extra_body=extra_body,
timeout=timeout,
**params,
**_maybe_inject_frame_id(dict(params), page),
)

def observe(
self,
*,
page: Any | None = None,
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
Expand All @@ -89,12 +192,13 @@ def observe(
extra_query=extra_query,
extra_body=extra_body,
timeout=timeout,
**params,
**_maybe_inject_frame_id(dict(params), page),
)

def extract(
self,
*,
page: Any | None = None,
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
Expand All @@ -107,12 +211,13 @@ def extract(
extra_query=extra_query,
extra_body=extra_body,
timeout=timeout,
**params,
**_maybe_inject_frame_id(dict(params), page),
)

def execute(
self,
*,
page: Any | None = None,
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
Expand All @@ -125,7 +230,7 @@ def execute(
extra_query=extra_query,
extra_body=extra_body,
timeout=timeout,
**params,
**_maybe_inject_frame_id(dict(params), page),
)

def end(
Expand Down Expand Up @@ -161,6 +266,7 @@ def __init__(self, client: AsyncStagehand, id: str, data: SessionStartResponseDa
async def navigate(
self,
*,
page: Any | None = None,
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
Expand All @@ -173,12 +279,13 @@ async def navigate(
extra_query=extra_query,
extra_body=extra_body,
timeout=timeout,
**params,
**(await _maybe_inject_frame_id_async(dict(params), page)),
)

async def act(
self,
*,
page: Any | None = None,
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
Expand All @@ -191,12 +298,13 @@ async def act(
extra_query=extra_query,
extra_body=extra_body,
timeout=timeout,
**params,
**(await _maybe_inject_frame_id_async(dict(params), page)),
)

async def observe(
self,
*,
page: Any | None = None,
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
Expand All @@ -209,12 +317,13 @@ async def observe(
extra_query=extra_query,
extra_body=extra_body,
timeout=timeout,
**params,
**(await _maybe_inject_frame_id_async(dict(params), page)),
)

async def extract(
self,
*,
page: Any | None = None,
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
Expand All @@ -227,12 +336,13 @@ async def extract(
extra_query=extra_query,
extra_body=extra_body,
timeout=timeout,
**params,
**(await _maybe_inject_frame_id_async(dict(params), page)),
)

async def execute(
self,
*,
page: Any | None = None,
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
Expand All @@ -245,7 +355,7 @@ async def execute(
extra_query=extra_query,
extra_body=extra_body,
timeout=timeout,
**params,
**(await _maybe_inject_frame_id_async(dict(params), page)),
)

async def end(
Expand Down
Loading
Loading