class ViolationV2(BaseModelV2):
    """Violation record declared on ``pydantic.BaseModel`` (imported as ``BaseModelV2``).

    Carries the same fields as ``Violation`` so that a ``Violation`` dumped to
    a dict can be re-validated into this model.
    """

    # NOTE(review): ``Field`` in this module is imported from
    # ``langchain_core.pydantic_v1`` while the base class comes from
    # ``pydantic``. Presumably that ``Field`` is the pydantic-v1 one, in which
    # case this model would not treat it as field metadata (descriptions and
    # ``ge``/``le`` bounds would be ignored) -- confirm, and import ``Field``
    # from ``pydantic`` for this class if so.

    # Human-readable summary of what is wrong.
    description: str = Field(..., description="Description of the violation")
    # Explanation of why the code breaks the referenced standard.
    reason: str = Field(
        ...,
        description="Identify your reasoning regarding how the code violates the identified standard",
    )
    standard_id: int = Field(..., description="ID of the violated standard")
    hunk_id: str = Field(description="ID of the hunk this violation is associated with")
    start_line: int = Field(
        ..., description="Starting line number from the hunk where the violation starts"
    )
    # Fixed: the description previously said "Starting line number" for the
    # end of the range, which contradicted the field name.
    end_line: int = Field(
        ..., description="Ending line number from the hunk where the violation ends"
    )
    # Both scores are declared with an inclusive 1..10 range.
    severity: int = Field(
        ...,
        ge=1,
        le=10,
        description="How severe is the issue stemming from this violation? 1 being the least severe and 10 being breaking change",
    )
    confidence: int = Field(
        ...,
        ge=1,
        le=10,
        description="How confident are you that this is a valid issue that must be resolved? 1 being the least confident and 10 being extremely confident",
    )

    class Config:
        arbitrary_types_allowed = True
langchain_core.pydantic_v1 import BaseModel, Field + +from squire.agents.preview_rule.state import RulePreviewState +from squire.monitoring import logger + +from ...models import Violation +from ...utils.formatters import format_file_diff, format_rule +from ...utils.single_inference_node import SingleInferenceNode + + +class ViolationList(BaseModel): + violations: List[Violation] = Field( + default_factory=list, description="List of identified violations" + ) + + +class RulePreviewNode(SingleInferenceNode[RulePreviewState, ViolationList]): + NAME = "RulePreviewNode" + + @property + def response_type(self) -> Type[ViolationList]: + return ViolationList + + @property + def inference_temperature(self) -> float: + return 1.0 + + def _build_system_prompt( + self, state: RulePreviewState + ) -> SystemMessagePromptTemplate: + return SystemMessagePromptTemplate.from_template( + """ + As a senior software developer, identify issues in proposed code changes (code diff) based on provided standards and best practices. + Guidelines: + 1. Only identify violations introduced by the proposed changes. + 2. Verify violations are accurate and relevant in the broader context. + 3. Provide detailed reasoning for each violation. + 4. If unsure, skip reporting the violation. + 5. Consider ongoing work when evaluating changes. + 6. Be conservative with severity ratings (e.g., unused imports generally below 5). + 7. Focus on logic and code-related changes, not static content. + 8. Avoid assumptions and stating the obvious. + 9. Always verify violations in the full file context and based on tool calls. + """ # noqa + ) + + def _build_node_prompt(self, state: RulePreviewState) -> HumanMessagePromptTemplate: + return HumanMessagePromptTemplate.from_template( + """ + Analyze the following code diff against the provided rule: + + Rule: + {rule} + + File: {file_name} + Code diff: + ``` + {code_diff} + ``` + + Identify violations introduced by the proposed changes. For each violation: + 1. 
Explain how it violates the standard. + 2. Assign a severity score (1-10) based on the violation's impact. + 3. Consider the broader context, including any available commit messages. + """ # noqa + ) + + async def _build_prompt_variables(self, state: RulePreviewState) -> Dict[str, str]: + rule = format_rule(state["standard"]) + file_name = state["file_diff"].path + code_diff = format_file_diff(state["file_diff"]) + + return { + "rule": rule, + "file_name": file_name, + "code_diff": code_diff, + } + + async def post_inference_hook(self, state: RulePreviewState, result): + logger.info(f"{self.NAME}: Identified {len(result.violations)} violations") + + for violation in result.violations: + if violation.confidence >= 8 and violation.severity >= 5: + state["violations"].append(violation) + else: + logger.info( + f"{self.NAME}: Skipping violation with low confidence or severity: {violation}" + ) + + return state + + async def _update_historical_context( + self, state: RulePreviewState, node_prompt: BaseMessage, result + ): + """We don't want to update the historical context with all the violations however we should inform the LLM of what it was looking for.""" + self.context_manager.add_to_historical_context(state, node_prompt) diff --git a/squire/agents/preview_rule/service.py b/squire/agents/preview_rule/service.py new file mode 100644 index 00000000..c3621a9f --- /dev/null +++ b/squire/agents/preview_rule/service.py @@ -0,0 +1,138 @@ +from typing import List, Optional + +from nanoid import generate +from pydantic import BaseModel +from unidiff import PatchedFile, PatchSet + +from squire.agents.file_review.nodes.identify_violations_node import ( + IdentifyViolationsNode, +) +from squire.agents.models.violation import ViolationV2 +from squire.agents.preview_rule.state import RulePreviewState +from squire.database.postgres.models.actor import ActorModel +from squire.database.postgres.models.framework.standard_model import StandardCreateDTO +from 
class RulePreviewService:
    """Preview a single, not-yet-saved standard against a GitHub pull request.

    The pull request diff is fetched, parsed with ``unidiff.PatchSet`` and
    each changed file is passed to ``RulePreviewNode`` together with the
    standard. One ``IdentifiedViolation`` is produced per file.
    """

    def __init__(self):
        # All collaborators are resolved from the application-wide injector.
        self.identify_violations_node: IdentifyViolationsNode = injector.get(
            IdentifyViolationsNode
        )
        # (sic) the misspelt attribute name is kept so external readers of the
        # attribute keep working.
        self.pull_request_respository: PullRequestRepository = injector.get(
            PullRequestRepository
        )
        self.organization_repository: OrganizationRepository = injector.get(
            OrganizationRepository
        )
        self.data_source_repository: DataSourceRepository = injector.get(
            DataSourceRepository
        )
        self.github_pull_request_service: GitHubPullRequestService = injector.get(
            GitHubPullRequestService
        )
        self.rule_preview_node: RulePreviewNode = injector.get(RulePreviewNode)

    async def ainvoke(
        self, org_public_id: str, pull_request_url: str, standard: StandardCreateDTO
    ) -> List[IdentifiedViolation]:
        """Evaluate ``standard`` against every file changed by the pull request.

        Args:
            org_public_id: Public id of the organization requesting the preview.
            pull_request_url: URL of the GitHub pull request to analyse.
            standard: The standard (rule) to preview.

        Returns:
            One ``IdentifiedViolation`` per file in the diff, in diff order.

        Raises:
            ValueError: If no organization matches ``org_public_id``.
                ``ValueError`` is an ``Exception`` subclass, so handlers that
                caught the previous bare ``Exception`` still catch it.
        """
        org: ActorModel = await self.organization_repository.get(
            org_public_id, use_public_id=True
        )
        if not org:
            raise ValueError(f"Organization with public id {org_public_id} not found")

        data_source = await self.data_source_repository.get_github_installation(org.id)
        # ``get_github_installation`` returns Optional; dereferencing ``.ext_id``
        # on None used to raise AttributeError. ``get_diff_by_pull_url`` accepts
        # a missing installation id and raises its own ValueError when the
        # repository turns out to be private.
        ext_installation_id = data_source.ext_id if data_source else None

        diff = await self.github_pull_request_service.get_diff_by_pull_url(
            pull_request_url,
            ext_installation_id,
        )

        results: List[IdentifiedViolation] = []
        for file_diff in PatchSet(diff):
            llm_result = await self._identify_violations(standard, file_diff)
            results.append(
                IdentifiedViolation(
                    file_diff=self._serialize_file_diff(file_diff),
                    violations=[
                        ViolationV2(**violation.dict())
                        for violation in llm_result["violations"]
                    ],
                )
            )

        return results

    @staticmethod
    def _serialize_file_diff(file_diff: PatchedFile) -> SerializableFileDiff:
        """Copy the fields of a ``PatchedFile`` into a serializable model."""
        return SerializableFileDiff(
            source_file=file_diff.source_file,
            target_file=file_diff.target_file,
            is_rename=file_diff.is_rename,
            is_new_file=file_diff.is_added_file,
            is_deleted_file=file_diff.is_removed_file,
            hunks=[
                Hunk(
                    source_start=hunk.source_start,
                    source_length=hunk.source_length,
                    target_start=hunk.target_start,
                    target_length=hunk.target_length,
                    section_header=hunk.section_header,
                    added=hunk.added,
                    removed=hunk.removed,
                    lines=[line.value for line in hunk],
                )
                for hunk in file_diff
            ],
        )

    async def _identify_violations(
        self,
        standard: StandardCreateDTO,
        file_diff: PatchedFile,
    ) -> RulePreviewState:
        """Run ``RulePreviewNode`` on one file with a fresh state."""
        state: RulePreviewState = {
            "uuid": generate(),
            "historical_context": [],
            "chat_history_id": None,
            "standard": standard.to_model(),
            "violations": [],
            "file_diff": file_diff,
        }

        return await self.rule_preview_node.ainvoke(state)
def format_rule(rule: StandardModel) -> str:
    """Render one standard as a three-line text block.

    The lines are ``rule_id``, ``Key`` and ``Description``, taken from
    ``rule.id``, ``rule.key`` and ``rule.prompt`` respectively; each line,
    including the last, ends with a newline.
    """
    labelled_values = (
        ("rule_id", rule.id),
        ("Key", rule.key),
        ("Description", rule.prompt),
    )
    return "".join(f"{label}: {value}\n" for label, value in labelled_values)
import RulePreviewService +from squire.api.base import BaseAPI +from squire.api.dependencies import org_dependencies +from squire.database.postgres.models.framework.standard_model import StandardCreateDTO +from squire.monitoring.logging import logger + + +class PreviewRuleRequest(BaseModel): + standard: StandardCreateDTO + pull_request_url: str + + +router = APIRouter() + + +class PreviewRuleController(BaseAPI): + async def preview_rule(self, org_public_id: str, request: PreviewRuleRequest): + try: + service = RulePreviewService() + response = await service.ainvoke( + org_public_id, request.pull_request_url, request.standard + ) + return response + except Exception as e: + logger.error("Error previewing rule", exc_info=True) + raise HTTPException(status_code=400, detail=str(e)) from e + + +controller = PreviewRuleController() + + +@router.post("/preview-rule", dependencies=[Depends(org_dependencies.is_org_member)]) +@controller.handle_request +async def review_code( + org_public_id: Annotated[str, Path()], + request: Annotated[PreviewRuleRequest, Body()], +): + return await controller.preview_rule(org_public_id, request) diff --git a/squire/database/postgres/di.py b/squire/database/postgres/di.py index 0013ba41..23e79ddb 100644 --- a/squire/database/postgres/di.py +++ b/squire/database/postgres/di.py @@ -16,9 +16,11 @@ from squire.database.postgres.models.repository import RepositoryModel from squire.database.postgres.models.review import ReviewModel from squire.database.postgres.repositories import ( + DataSourceRepository, FrameworkRepository, OrganizationRepository, OrganizationUserRepository, + PullRequestRepository, RepositoryRepository, StandardRepository, ) @@ -117,6 +119,16 @@ def configure(self, binder): to=self.provide_organization_user_repository, scope=singleton, ) + binder.bind( + PullRequestRepository, + to=self.provide_pull_request_repository, + scope=singleton, + ) + binder.bind( + DataSourceRepository, + to=self.provide_data_source_repository, + 
scope=singleton, + ) @provider def provide_async_engine(self) -> AsyncEngine: @@ -219,3 +231,15 @@ def provide_organization_user_repository( self, session_factory: sessionmaker ) -> OrganizationUserRepository: return OrganizationUserRepository(session_factory) + + @provider + def provide_pull_request_repository( + self, session_factory: sessionmaker + ) -> PullRequestRepository: + return PullRequestRepository(session_factory) + + @provider + def provide_data_source_repository( + self, session_factory: sessionmaker + ) -> DataSourceRepository: + return DataSourceRepository(session_factory) diff --git a/squire/database/postgres/models/framework/standard_model.py b/squire/database/postgres/models/framework/standard_model.py index 3b572c3b..e3706615 100644 --- a/squire/database/postgres/models/framework/standard_model.py +++ b/squire/database/postgres/models/framework/standard_model.py @@ -56,18 +56,15 @@ class Config: class StandardCreateDTO(StandardBaseDTO): - public_id: str - def to_model(self) -> StandardModel: return StandardModel(**self.model_dump()) class StandardPublicDTO(StandardBaseDTO): public_id: str - key: Optional[str] = None - title: Optional[str] = None - prompt: Optional[str] = None - framework_id: Optional[int] = None + + def to_model(self) -> StandardModel: + return StandardModel(**self.model_dump()) class StandardUpdateDTO(BaseModel): @@ -88,7 +85,7 @@ class StandardDTO(StandardBaseDTO): created_at: datetime updated_at: datetime - class Config: + class Config(StandardBaseDTO.Config): arbitrary_types_allowed = True from_attributes = True orm_mode = True diff --git a/squire/database/postgres/models/pull_request.py b/squire/database/postgres/models/pull_request.py index 386d5462..2f2a47d9 100644 --- a/squire/database/postgres/models/pull_request.py +++ b/squire/database/postgres/models/pull_request.py @@ -5,7 +5,7 @@ from pydantic import BaseModel from sqlalchemy import TIMESTAMP, BigInteger, Column from sqlalchemy import Enum as SQLAlchemyEnum -from 
sqlalchemy import Integer, Text +from sqlalchemy import ForeignKey, Integer, Text from .base import Base @@ -25,7 +25,7 @@ class PullRequestModel(Base): id = Column(Integer, primary_key=True, autoincrement=True) ext_pull_id = Column("extPullId", BigInteger) ext_pull_number = Column("extPullNumber", Integer) - repository_id = Column("repositoryId", Integer) + repository_id = Column("repositoryId", Integer, ForeignKey("repository.id")) state = Column( SQLAlchemyEnum( PullRequestState, diff --git a/squire/database/postgres/repositories/__init__.py b/squire/database/postgres/repositories/__init__.py index 5a45f345..7eb9674b 100644 --- a/squire/database/postgres/repositories/__init__.py +++ b/squire/database/postgres/repositories/__init__.py @@ -1,9 +1,11 @@ from .base_postgres_repository import BasePostgresRepository +from .data_source_repository import DataSourceRepository from .framework_repository import FrameworkRepository from .organization_repository import OrganizationRepository -from .standard_repository import StandardRepository from .organization_user_repository import OrganizationUserRepository +from .pull_request_respository import PullRequestRepository from .repository_repository import RepositoryRepository +from .standard_repository import StandardRepository __all__ = [ "BasePostgresRepository", @@ -12,4 +14,6 @@ "StandardRepository", "RepositoryRepository", "OrganizationUserRepository", + "PullRequestRepository", + "DataSourceRepository", ] diff --git a/squire/database/postgres/repositories/data_source_repository.py b/squire/database/postgres/repositories/data_source_repository.py new file mode 100644 index 00000000..e50f40c8 --- /dev/null +++ b/squire/database/postgres/repositories/data_source_repository.py @@ -0,0 +1,30 @@ +from typing import Optional + +from sqlalchemy import Select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm.session import sessionmaker + +from squire.database.postgres.models import DataSourceModel +from 
class PullRequestRepository(BasePostgresRepository):
    """Repository for ``PullRequestModel`` rows."""

    def __init__(self, session_maker: sessionmaker):
        super().__init__(session_maker, PullRequestModel)

    @override
    @provide_session
    async def get_many(
        self,
        session: AsyncSession,
        repo_public_id: str,
        limit: int = 5,
        offset: int = 0,
    ) -> List[PullRequestModel]:
        """Return one page of pull requests belonging to a repository.

        Args:
            session: Session to run the query on; the ``provide_session``
                decorator is presumably what supplies it -- confirm against
                the decorator's definition.
            repo_public_id: ``public_id`` of the repository to filter by.
            limit: Maximum number of rows to return.
            offset: Number of rows to skip before the first returned row.

        Returns:
            The matching pull requests, ordered by primary key.
        """
        query = (
            select(PullRequestModel)
            .join(
                RepositoryModel,
                RepositoryModel.id == PullRequestModel.repository_id,
            )
            .filter(RepositoryModel.public_id == repo_public_id)
            # LIMIT/OFFSET without ORDER BY lets the database return rows in
            # any order, so consecutive pages could overlap or skip rows.
            # NOTE(review): ascending id is chosen only for determinism;
            # switch to newest-first if callers expect that.
            .order_by(PullRequestModel.id)
            .limit(limit)
            .offset(offset)
        )
        result = await session.execute(query)
        # ``.all()`` yields a Sequence; materialize it to honour the List
        # return annotation.
        return list(result.scalars().all())
async def get_diff_by_pull_url(
    self,
    pull_url: str,
    ext_installation_id: Optional[str] = None,
) -> str:
    """Fetch the unified diff of a GitHub pull request identified by its URL.

    Args:
        pull_url: GitHub pull request URL
            (``github.com/<owner>/<repo>/pull/<number>``).
        ext_installation_id: GitHub App installation id; required when the
            repository is not public.

    Returns:
        The diff text returned by GitHub.

    Raises:
        ValueError: If ``pull_url`` is not a pull request URL, or the
            repository is private and no installation id was given.
        GitCommandError: If the HTTP request fails or returns an error status.
    """
    # Local import so this method does not depend on a module-level import.
    import asyncio

    repo_info = extract_repo_info_from_pull_url(pull_url)
    logger.debug(f"Info: {repo_info}")
    if not repo_info:
        raise ValueError("Invalid pull request URL")
    owner, repo_name, ext_pull_number = repo_info
    full_repo_name = f"{owner}/{repo_name}"

    # NOTE(review): the github_app_service / PyGithub calls below look
    # synchronous and would block the event loop too -- confirm.
    if self.github_app_service.is_repository_public(full_repo_name):
        github_client = self.github_app_service.get_authed_integration()
        auth_token = github_client.auth.token
        # Build the URL from the parsed parts instead of appending to the raw
        # input: the URL validator also accepts scheme-less and ``www.``
        # forms, and ``requests`` rejects a URL without a scheme.
        endpoint_url = (
            f"https://github.com/{full_repo_name}/pull/{ext_pull_number}.diff"
        )
    else:
        if not ext_installation_id:
            raise ValueError("Private repository requires installation_id")
        installation_id = int(ext_installation_id)
        github_client = self.github_app_service.get_authenticated_client(
            installation_id
        )
        authed_integration = self.github_app_service.get_authed_integration()
        auth_token = authed_integration.get_access_token(installation_id).token
        github_repo = github_client.get_repo(full_repo_name)
        github_pr = github_repo.get_pull(int(ext_pull_number))
        endpoint_url = github_pr.url

    # Ask for the diff representation using authenticated headers.
    headers = {
        "Authorization": f"Token {auth_token}",
        "Accept": "application/vnd.github.diff",
        "X-GitHub-Api-Version": "2022-11-28",
    }
    try:
        # ``requests`` is blocking; run it in a worker thread so the event
        # loop keeps serving other coroutines for up to the 10 s timeout.
        response = await asyncio.to_thread(
            requests.get, endpoint_url, headers=headers, timeout=10
        )
        response.raise_for_status()  # Raise an exception for HTTP errors
    except requests.exceptions.RequestException as e:
        logger.error("HTTP error in get_diff_by_pull_url: %s", e)
        raise GitCommandError(f"HTTP error: {e}") from e
    except Exception as e:
        logger.error("Unexpected error in get_diff_by_pull_url: %s", e)
        raise GitCommandError(f"Unexpected error: {e}") from e

    # Log size only: the body is the (possibly private, possibly large) diff.
    logger.debug(
        "Response: status=%s, %d characters",
        response.status_code,
        len(response.text),
    )
    return response.text
# Compiled once at import time and shared by both helpers below, so the two
# can never disagree about what counts as a pull request URL. It is applied
# with ``fullmatch``, so no ``^``/``$`` anchors are needed. Owner and
# repository segments may not contain ``/`` or whitespace.
_GITHUB_PULL_URL_PATTERN = re.compile(
    r"(?:https?://)?(?:www\.)?github\.com/([^/\s]+)/([^/\s]+)/pull/(\d+)"
)


def extract_repo_info_from_pull_url(url: str) -> Optional[tuple]:
    """
    Extract the owner, repository name and pull request number from a GitHub pull request URL.

    Args:
        url (str): The GitHub pull request URL.

    Returns:
        Optional[tuple]: ``(owner, repo_name, pull_number)`` -- all three as
        strings -- or None if the URL is not a pull request URL.
    """
    # ``fullmatch`` rejects a trailing newline, which ``match`` with a ``$``
    # anchor accepts.
    match = _GITHUB_PULL_URL_PATTERN.fullmatch(url)
    if match is None:
        return None
    owner, repo_name, pull_number = match.groups()
    return owner, repo_name, pull_number


def is_valid_pull_url(url: str) -> bool:
    """
    Validate if the given URL is a valid GitHub pull request URL.

    Args:
        url (str): The URL to validate.

    Returns:
        bool: True if the URL is a valid GitHub pull request URL, False otherwise.
    """
    return _GITHUB_PULL_URL_PATTERN.fullmatch(url) is not None