Skip to content
Open
285 changes: 257 additions & 28 deletions src/rotator_library/providers/copilot_auth_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,14 @@
import time
import asyncio
import logging
import re
from pathlib import Path
from ..utils.paths import get_oauth_dir
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Import will fail - paths.py doesn't exist in base branch

The file src/rotator_library/utils/paths.py does not exist in the base branch feature/copilot-provider. This import will fail with ModuleNotFoundError at runtime.

The paths.py module exists in main branch (added in commit 467f294), but the target branch doesn't have it yet. Either:

  1. Merge/rebase with main to get paths.py first
  2. Revert to using Path.cwd() / "oauth_creds" directly (like other files in this branch do)
Suggested change
from ..utils.paths import get_oauth_dir
# from ..utils.paths import get_oauth_dir # TODO: add after merging with main
Prompt To Fix With AI
This is a comment left during a code review.
Path: src/rotator_library/providers/copilot_auth_base.py
Line: 21

Comment:
Import will fail - `paths.py` doesn't exist in base branch

The file `src/rotator_library/utils/paths.py` does not exist in the base branch `feature/copilot-provider`. This import will fail with `ModuleNotFoundError` at runtime.

The `paths.py` module exists in `main` branch (added in commit 467f294), but the target branch doesn't have it yet. Either:
1. Merge/rebase with `main` to get `paths.py` first
2. Revert to using `Path.cwd() / "oauth_creds"` directly (like other files in this branch do)

```suggestion
# from ..utils.paths import get_oauth_dir  # TODO: add after merging with main
```

How can I resolve this? If you propose a fix, please make it concise.

from typing import Dict, Any, Optional, Union
import tempfile
import shutil
from dataclasses import dataclass, field
from glob import glob

import httpx
from rich.console import Console
Expand All @@ -32,23 +36,26 @@
console = Console()


@dataclass
class CopilotCredentialSetupResult:
"""Standardized result for Copilot credential setup operations."""

success: bool
file_path: Optional[str] = None
email: Optional[str] = None
tier: Optional[str] = None
project_id: Optional[str] = None
is_update: bool = False
error: Optional[str] = None
credentials: Optional[Dict[str, Any]] = field(default=None, repr=False)


class CopilotAuthBase:
"""
GitHub Copilot OAuth2 authentication using Device Flow.

This provider uses GitHub's Device Authorization Grant flow, which is
more suitable for CLI applications than the web-based Authorization Code flow.

Key differences from GoogleOAuthBase:
- Uses GitHub Device Flow (polls for authorization)
- Two-token system: GitHub OAuth token + Copilot API token
- Copilot API tokens expire quickly (~30 min) and need frequent refresh

Subclasses may override:
- ENV_PREFIX: Prefix for environment variables (default: "COPILOT")
- REFRESH_EXPIRY_BUFFER_SECONDS: Time buffer before token expiry

Supports both github.com and GitHub Enterprise deployments.
"""

# GitHub Copilot OAuth Client ID (from VS Code Copilot extension)
Expand Down Expand Up @@ -273,17 +280,20 @@ def _is_token_expired(self, creds: Dict[str, Any]) -> bool:
return True

async def _refresh_copilot_token(
self, path: str, creds: Dict[str, Any], force: bool = False
self, path: Optional[str], creds: Dict[str, Any], force: bool = False
) -> Dict[str, Any]:
"""
Refresh the Copilot API token using the GitHub OAuth token.

The GitHub OAuth token (refresh_token) is long-lived.
The Copilot API token (access_token) expires after ~30 minutes.
"""
async with await self._get_lock(path):
lock_key = path or f"in-memory://copilot/{id(creds)}"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using id(creds) for an in-memory lock key is clever, but be aware that it might lead to redundant locks if the creds dictionary is copied or recreated (e.g., during serialization/deserialization cycles). If the credentials have a unique identifier like an email, that might be a more stable key.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair point — id() can be recycled if the dict gets garbage collected and recreated. In practice the in-memory path only triggers for env-based credentials which are long-lived singletons, but using the email from _proxy_metadata would be more robust. Open to changing if preferred.

display_name = Path(path).name if path else "in-memory credential"

async with await self._get_lock(lock_key):
# Skip if token is still valid (unless forced)
cached_creds = self._credentials_cache.get(path, creds)
cached_creds = self._credentials_cache.get(lock_key, creds)
if not force and not self._is_token_expired(cached_creds):
return cached_creds

Expand All @@ -302,7 +312,7 @@ async def _refresh_copilot_token(
urls = self._get_urls(domain)

lib_logger.debug(
f"Refreshing {self.ENV_PREFIX} Copilot API token for '{Path(path).name}' (forced: {force})..."
f"Refreshing {self.ENV_PREFIX} Copilot API token for '{display_name}' (forced: {force})..."
)

async with httpx.AsyncClient() as client:
Expand All @@ -319,10 +329,14 @@ async def _refresh_copilot_token(

if response.status_code == 401:
lib_logger.warning(
f"GitHub token invalid for '{Path(path).name}' (HTTP 401). "
f"GitHub token invalid for '{display_name}' (HTTP 401). "
f"Token may have been revoked. Starting re-authentication..."
)
return await self.initialize_token(path)
if path:
return await self.initialize_token(path)
raise ValueError(
"GitHub token invalid for in-memory credential and cannot re-auth without a file path"
)

response.raise_for_status()
token_data = response.json()
Expand All @@ -338,9 +352,12 @@ async def _refresh_copilot_token(
creds["_proxy_metadata"] = {}
creds["_proxy_metadata"]["last_check_timestamp"] = time.time()

await self._save_credentials(path, creds)
if path:
await self._save_credentials(path, creds)
else:
self._credentials_cache[lock_key] = creds
lib_logger.debug(
f"Successfully refreshed {self.ENV_PREFIX} Copilot API token for '{Path(path).name}'."
f"Successfully refreshed {self.ENV_PREFIX} Copilot API token for '{display_name}'."
)
return creds

Expand Down Expand Up @@ -396,9 +413,13 @@ async def initialize_token(
)

try:
creds = (
await self._load_credentials(creds_or_path) if path else creds_or_path
)
if path:
creds: Dict[str, Any] = await self._load_credentials(path)
elif isinstance(creds_or_path, dict):
creds = creds_or_path
else:
raise ValueError("Invalid credential input for Copilot initialization")

needs_auth = False
reason = ""

Expand Down Expand Up @@ -545,10 +566,13 @@ async def initialize_token(
)
if user_response.is_success:
user_info = user_response.json()
resolved_identity = (
user_info.get("email")
or user_info.get("login")
or "unknown"
)
new_creds["_proxy_metadata"]["email"] = (
user_info.get(
"email", user_info.get("login", "unknown")
)
resolved_identity
)
except Exception as e:
lib_logger.warning(f"Failed to fetch user info: {e}")
Expand Down Expand Up @@ -591,7 +615,12 @@ async def get_user_info(
) -> Dict[str, Any]:
"""Get user info from cached metadata or API."""
path = creds_or_path if isinstance(creds_or_path, str) else None
creds = await self._load_credentials(creds_or_path) if path else creds_or_path
if path:
creds: Dict[str, Any] = await self._load_credentials(path)
elif isinstance(creds_or_path, dict):
creds = creds_or_path
else:
return {"email": "unknown"}

if creds.get("_proxy_metadata", {}).get("email"):
return {"email": creds["_proxy_metadata"]["email"]}
Expand All @@ -615,8 +644,10 @@ async def get_user_info(
)
if response.is_success:
user_info = response.json()
email = user_info.get(
"email", user_info.get("login", "unknown")
email = (
user_info.get("email")
or user_info.get("login")
or "unknown"
)
creds["_proxy_metadata"] = {
"email": email,
Expand All @@ -629,3 +660,201 @@ async def get_user_info(
lib_logger.warning(f"Failed to fetch user info: {e}")

return {"email": "unknown"}

def _get_oauth_base_dir(self) -> Path:
"""Return the OAuth credentials base directory."""
return get_oauth_dir()

def _get_provider_file_prefix(self) -> str:
"""Return file prefix for Copilot credential files."""
return "copilot"

def _find_existing_credential_by_email(
self, email: str, base_dir: Optional[Path] = None
) -> Optional[Path]:
"""Find existing credential file by email for deduplication."""
if base_dir is None:
base_dir = self._get_oauth_base_dir()

for cred in self.list_credentials(base_dir):
if cred.get("email", "").lower() == email.lower():
return Path(cred["file_path"])
return None

def _get_next_credential_number(self, base_dir: Optional[Path] = None) -> int:
"""Get next available credential number."""
if base_dir is None:
base_dir = self._get_oauth_base_dir()

prefix = self._get_provider_file_prefix()
pattern = str(base_dir / f"{prefix}_oauth_*.json")

existing_numbers = []
for cred_file in glob(pattern):
match = re.search(r"_oauth_(\d+)\.json$", cred_file)
if match:
existing_numbers.append(int(match.group(1)))

if not existing_numbers:
return 1
return max(existing_numbers) + 1

def _build_credential_path(
self, base_dir: Optional[Path] = None, number: Optional[int] = None
) -> Path:
"""Build path for a new Copilot credential file."""
if base_dir is None:
base_dir = self._get_oauth_base_dir()

if number is None:
number = self._get_next_credential_number(base_dir)

prefix = self._get_provider_file_prefix()
return base_dir / f"{prefix}_oauth_{number}.json"

async def setup_credential(
self, base_dir: Optional[Path] = None
) -> CopilotCredentialSetupResult:
"""Complete credential setup flow: OAuth -> save."""
if base_dir is None:
base_dir = self._get_oauth_base_dir()

base_dir.mkdir(parents=True, exist_ok=True)

try:
temp_creds = {"_proxy_metadata": {"display_name": "new Copilot credential"}}
new_creds = await self.initialize_token(temp_creds)

user_info = await self.get_user_info(new_creds)
email = user_info.get("email")
if not email:
return CopilotCredentialSetupResult(
success=False, error="Could not retrieve email from OAuth response"
)
Comment on lines +730 to +733
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Check will never trigger - get_user_info() always returns a non-empty email

get_user_info() always returns {"email": <value>} where value is never falsy (minimum is "unknown"). This check if not email: will never be True.

If email fetch fails, the credential will be created with email "unknown", which could cause deduplication issues if multiple failed setups occur.

Suggested change
if not email:
return CopilotCredentialSetupResult(
success=False, error="Could not retrieve email from OAuth response"
)
if not email or email == "unknown":
return CopilotCredentialSetupResult(
success=False, error="Could not retrieve email from OAuth response"
)
Prompt To Fix With AI
This is a comment left during a code review.
Path: src/rotator_library/providers/copilot_auth_base.py
Line: 730-733

Comment:
Check will never trigger - `get_user_info()` always returns a non-empty email

`get_user_info()` always returns `{"email": <value>}` where value is never falsy (minimum is `"unknown"`). This check `if not email:` will never be True.

If email fetch fails, the credential will be created with email `"unknown"`, which could cause deduplication issues if multiple failed setups occur.

```suggestion
            if not email or email == "unknown":
                return CopilotCredentialSetupResult(
                    success=False, error="Could not retrieve email from OAuth response"
                )
```

How can I resolve this? If you propose a fix, please make it concise.


existing_path = self._find_existing_credential_by_email(email, base_dir)
is_update = existing_path is not None

if is_update:
file_path = existing_path
lib_logger.info(
f"Found existing credential for {email}, updating {file_path.name}"
)
else:
file_path = self._build_credential_path(base_dir)
lib_logger.info(
f"Creating new credential for {email} at {file_path.name}"
)

await self._save_credentials(str(file_path), new_creds)

return CopilotCredentialSetupResult(
success=True,
file_path=str(file_path),
email=email,
is_update=is_update,
credentials=new_creds,
)

except Exception as e:
lib_logger.error(f"Copilot credential setup failed: {e}")
return CopilotCredentialSetupResult(success=False, error=str(e))

def build_env_lines(self, creds: Dict[str, Any], cred_number: int) -> list[str]:
"""Generate .env file lines for a Copilot credential."""
email = creds.get("_proxy_metadata", {}).get("email", "unknown")
prefix = f"COPILOT_{cred_number}"

lines = [
f"# COPILOT Credential #{cred_number} for: {email}",
f"# Exported from: copilot_oauth_{cred_number}.json",
f"# Generated at: {time.strftime('%Y-%m-%d %H:%M:%S')}",
"",
f"{prefix}_GITHUB_TOKEN={creds.get('refresh_token', '')}",
f"{prefix}_ENTERPRISE_URL={creds.get('enterprise_url', '')}",
f"{prefix}_EMAIL={email}",
]

return lines

def export_credential_to_env(
self, credential_path: str, output_dir: Optional[Path] = None
) -> Optional[str]:
"""Export a Copilot credential file to .env format."""
try:
cred_path = Path(credential_path)
with open(cred_path, "r") as f:
creds = json.load(f)

email = creds.get("_proxy_metadata", {}).get("email", "unknown")
match = re.search(r"_oauth_(\d+)\.json$", cred_path.name)
cred_number = int(match.group(1)) if match else 1

if output_dir is None:
output_dir = cred_path.parent

safe_email = email.replace("@", "_at_").replace(".", "_")
env_filename = f"copilot_{cred_number}_{safe_email}.env"
env_path = output_dir / env_filename

with open(env_path, "w") as f:
f.write("\n".join(self.build_env_lines(creds, cred_number)))

return str(env_path)
except Exception as e:
lib_logger.error(f"Failed to export Copilot credential: {e}")
return None

def list_credentials(self, base_dir: Optional[Path] = None) -> list[Dict[str, Any]]:
"""List all Copilot credential files."""
if base_dir is None:
base_dir = self._get_oauth_base_dir()

prefix = self._get_provider_file_prefix()
pattern = str(base_dir / f"{prefix}_oauth_*.json")

credentials = []
for cred_file in sorted(glob(pattern)):
try:
with open(cred_file, "r") as f:
creds = json.load(f)

metadata = creds.get("_proxy_metadata", {})
match = re.search(r"_oauth_(\d+)\.json$", cred_file)
number = int(match.group(1)) if match else 0

credentials.append(
{
"file_path": cred_file,
"email": metadata.get("email", "unknown"),
"number": number,
}
)
except Exception as e:
lib_logger.debug(f"Could not read credential file {cred_file}: {e}")

return credentials

def delete_credential(self, credential_path: str) -> bool:
"""Delete a Copilot credential file."""
try:
cred_path = Path(credential_path)
prefix = self._get_provider_file_prefix()

if not cred_path.name.startswith(f"{prefix}_oauth_"):
lib_logger.error(
f"File {cred_path.name} does not appear to be a Copilot credential"
)
return False

if not cred_path.exists():
lib_logger.warning(f"Credential file does not exist: {credential_path}")
return False

self._credentials_cache.pop(credential_path, None)
cred_path.unlink()
lib_logger.info(f"Deleted credential file: {credential_path}")
return True
except Exception as e:
lib_logger.error(f"Failed to delete Copilot credential: {e}")
return False
Loading