Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions squire/agents/di.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from squire.agents.comment_generation.di import CommentGenerationDependencyInjector
from squire.agents.file_analysis.di import FileAnalysisDependencyInjector
from squire.agents.file_review.di import FileReviewDependencyInjector
from squire.agents.preview_rule.di import PreviewRuleDependencyInjector
from squire.agents.utils import LLMInvoker
from squire.config import Config

Expand All @@ -26,3 +27,4 @@ def configure(self, binder):
binder.install(FileReviewDependencyInjector(self.config))
binder.install(CommentGenerationDependencyInjector(self.config))
binder.install(ChatAgentDependencyInjector(self.config))
binder.install(PreviewRuleDependencyInjector(self.config))
32 changes: 32 additions & 0 deletions squire/agents/models/violation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from langchain_core.pydantic_v1 import BaseModel, Field
from pydantic import BaseModel as BaseModelV2
from pydantic import Field as FieldV2


class Violation(BaseModel):
Expand Down Expand Up @@ -30,3 +31,34 @@ class Violation(BaseModel):

class Config:
arbitrary_types_allowed = True


class ViolationV2(BaseModelV2):
    # Pydantic v2 model carrying the per-violation fields below.
    #
    # Fields are declared with pydantic's own `Field` (imported as `FieldV2`).
    # The `Field` re-exported by `langchain_core.pydantic_v1` belongs to the v1
    # API and is not understood by a v2 `BaseModel`, so the required markers
    # and the `ge`/`le` bounds below would not be enforced with it.
    description: str = FieldV2(..., description="Description of the violation")
    reason: str = FieldV2(
        ...,
        description="Identify your reasoning regarding how the code violates the identified standard",
    )
    standard_id: int = FieldV2(..., description="ID of the violated standard")
    hunk_id: str = FieldV2(
        ..., description="ID of the hunk this violation is associated with"
    )
    start_line: int = FieldV2(
        ..., description="Starting line number from the hunk where the violation starts"
    )
    # The description previously said "Starting line number ... ends".
    end_line: int = FieldV2(
        ..., description="Ending line number from the hunk where the violation ends"
    )
    severity: int = FieldV2(
        ...,
        ge=1,
        le=10,
        description="How severe is the issue stemming from this violation? 1 being the least severe and 10 being breaking change",
    )
    confidence: int = FieldV2(
        ...,
        ge=1,
        le=10,
        description="How confident are you that this is a valid issue that must be resolved? 1 being the least confident and 10 being extremely confident",
    )

    class Config:
        arbitrary_types_allowed = True
26 changes: 26 additions & 0 deletions squire/agents/preview_rule/di.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from injector import Module, provider, singleton

from squire.agents.utils import ContextManager, LLMInvoker
from squire.config import Config

from .nodes import RulePreviewNode


class PreviewRuleDependencyInjector(Module):
    """Injector module that registers the rule-preview agent's bindings."""

    def __init__(self, config: Config):
        """Keep the application config so `configure` can bind it."""
        self.config = config

    def configure(self, binder):
        """Bind the config instance and the `RulePreviewNode` factory as singletons."""
        binder.bind(Config, to=self.config, scope=singleton)

        node_factory = self.provide_rule_preview_node
        binder.bind(RulePreviewNode, to=node_factory, scope=singleton)

    @provider
    @singleton
    def provide_rule_preview_node(
        self, llm_invoker: LLMInvoker, context_manager: ContextManager
    ) -> RulePreviewNode:
        """Build a `RulePreviewNode` from the injected invoker and context manager."""
        node = RulePreviewNode(llm_invoker, context_manager)
        return node
5 changes: 5 additions & 0 deletions squire/agents/preview_rule/nodes/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .rule_preview_node import RulePreviewNode

# Public names of this package.
__all__ = ["RulePreviewNode"]
103 changes: 103 additions & 0 deletions squire/agents/preview_rule/nodes/rule_preview_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
from typing import Dict, List, Type

from langchain_core.messages import BaseMessage
from langchain_core.prompts import (
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from langchain_core.pydantic_v1 import BaseModel, Field

from squire.agents.preview_rule.state import RulePreviewState
from squire.monitoring import logger

from ...models import Violation
from ...utils.formatters import format_file_diff, format_rule
from ...utils.single_inference_node import SingleInferenceNode


class ViolationList(BaseModel):
    # Response model returned by RulePreviewNode.response_type.
    # Defaults to an empty list when no violations are supplied.
    violations: List[Violation] = Field(
        description="List of identified violations",
        default_factory=list,
    )


class RulePreviewNode(SingleInferenceNode[RulePreviewState, ViolationList]):
    """Single-inference node that checks one file diff against one rule.

    The prompt is built from the `standard` and `file_diff` entries of the
    state; violations that pass the confidence/severity filter are appended to
    the state's `violations` list.
    """

    NAME = "RulePreviewNode"

    @property
    def response_type(self) -> Type[ViolationList]:
        """Model class describing this node's inference result."""
        return ViolationList

    @property
    def inference_temperature(self) -> float:
        """Temperature value used by this node (always 1.0)."""
        return 1.0

    def _build_system_prompt(
        self, state: RulePreviewState
    ) -> SystemMessagePromptTemplate:
        """Return the fixed system prompt; `state` is not consulted."""
        system_template = """
As a senior software developer, identify issues in proposed code changes (code diff) based on provided standards and best practices.
Guidelines:
1. Only identify violations introduced by the proposed changes.
2. Verify violations are accurate and relevant in the broader context.
3. Provide detailed reasoning for each violation.
4. If unsure, skip reporting the violation.
5. Consider ongoing work when evaluating changes.
6. Be conservative with severity ratings (e.g., unused imports generally below 5).
7. Focus on logic and code-related changes, not static content.
8. Avoid assumptions and stating the obvious.
9. Always verify violations in the full file context and based on tool calls.
"""  # noqa
        return SystemMessagePromptTemplate.from_template(system_template)

    def _build_node_prompt(self, state: RulePreviewState) -> HumanMessagePromptTemplate:
        """Return the human prompt template with `rule`, `file_name` and `code_diff` slots."""
        node_template = """
Analyze the following code diff against the provided rule:

Rule:
{rule}

File: {file_name}
Code diff:
```
{code_diff}
```

Identify violations introduced by the proposed changes. For each violation:
1. Explain how it violates the standard.
2. Assign a severity score (1-10) based on the violation's impact.
3. Consider the broader context, including any available commit messages.
"""  # noqa
        return HumanMessagePromptTemplate.from_template(node_template)

    async def _build_prompt_variables(self, state: RulePreviewState) -> Dict[str, str]:
        """Format the state's rule and file diff into the prompt template variables."""
        return {
            "rule": format_rule(state["standard"]),
            "file_name": state["file_diff"].path,
            "code_diff": format_file_diff(state["file_diff"]),
        }

    async def post_inference_hook(self, state: RulePreviewState, result):
        """Keep only violations with confidence >= 8 and severity >= 5.

        Accepted violations are appended to `state["violations"]`; rejected
        ones are logged. Returns the same `state` object.
        """
        logger.info(f"{self.NAME}: Identified {len(result.violations)} violations")

        for candidate in result.violations:
            accepted = candidate.confidence >= 8 and candidate.severity >= 5
            if not accepted:
                logger.info(
                    f"{self.NAME}: Skipping violation with low confidence or severity: {candidate}"
                )
                continue
            state["violations"].append(candidate)

        return state

    async def _update_historical_context(
        self, state: RulePreviewState, node_prompt: BaseMessage, result
    ):
        """Record only the node prompt in the historical context.

        The violations in `result` are deliberately left out; the prompt alone
        tells the LLM what it was asked to look for.
        """
        self.context_manager.add_to_historical_context(state, node_prompt)
138 changes: 138 additions & 0 deletions squire/agents/preview_rule/service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
from typing import List, Optional

from nanoid import generate
from pydantic import BaseModel
from unidiff import PatchedFile, PatchSet

from squire.agents.file_review.nodes.identify_violations_node import (
IdentifyViolationsNode,
)
from squire.agents.models.violation import ViolationV2
from squire.agents.preview_rule.state import RulePreviewState
from squire.database.postgres.models.actor import ActorModel
from squire.database.postgres.models.framework.standard_model import StandardCreateDTO
from squire.database.postgres.repositories import (
OrganizationRepository,
PullRequestRepository,
)
from squire.database.postgres.repositories.data_source_repository import (
DataSourceRepository,
)
from squire.di import injector
from squire.scm.github.github_pull_request_service import GitHubPullRequestService

from .nodes import RulePreviewNode


class Hunk(BaseModel):
    # Serializable copy of one diff hunk; RulePreviewService.ainvoke fills every
    # field from the same-named attribute of the parsed hunk.
    source_start: int
    source_length: int
    target_start: int
    target_length: int
    section_header: str
    added: int
    removed: int
    # One entry per line of the hunk, taken from each line's `value`.
    lines: List[str]


class SerializableFileDiff(BaseModel):
    # Serializable copy of one file's diff, built from a parsed file diff in
    # RulePreviewService.ainvoke.
    source_file: str
    target_file: str
    is_rename: bool
    is_new_file: bool  # populated from the parsed diff's `is_added_file`
    is_deleted_file: bool  # populated from the parsed diff's `is_removed_file`
    hunks: List[Hunk]


class IdentifiedViolation(BaseModel):
    # Pairs one file's serialized diff with the violations reported for it.
    file_diff: SerializableFileDiff
    # Defaults to None when no violations list is supplied.
    violations: Optional[List[ViolationV2]] = None

    class Config:
        arbitrary_types_allowed = True


class RulePreviewService:
    """Previews a single rule against every file of a GitHub pull request.

    For each file in the pull request diff the rule-preview node is invoked,
    and the file's diff is returned together with the violations the node kept.
    """

    def __init__(self):
        """Resolve all collaborators from the shared injector."""
        self.identify_violations_node: IdentifyViolationsNode = injector.get(
            IdentifyViolationsNode
        )
        # Attribute name (including its spelling) is kept as-is so existing
        # references to it keep working.
        self.pull_request_respository: PullRequestRepository = injector.get(
            PullRequestRepository
        )
        self.organization_repository: OrganizationRepository = injector.get(
            OrganizationRepository
        )
        self.data_source_repository: DataSourceRepository = injector.get(
            DataSourceRepository
        )
        self.github_pull_request_service: GitHubPullRequestService = injector.get(
            GitHubPullRequestService
        )
        self.rule_preview_node: RulePreviewNode = injector.get(RulePreviewNode)

    async def ainvoke(
        self, org_public_id: str, pull_request_url: str, standard: StandardCreateDTO
    ) -> List[IdentifiedViolation]:
        """Run `standard` against each file changed in the pull request.

        Args:
            org_public_id: Public id used to look up the organization.
            pull_request_url: URL of the pull request whose diff is fetched.
            standard: Rule definition to preview.

        Returns:
            One `IdentifiedViolation` per file in the diff, in diff order.

        Raises:
            Exception: If the organization or its GitHub installation is not found.
        """
        org: ActorModel = await self.organization_repository.get(
            org_public_id, use_public_id=True
        )
        if not org:
            raise Exception(f"Organization with public id {org_public_id} not found")

        data_source = await self.data_source_repository.get_github_installation(org.id)
        # Fail with a clear message instead of an AttributeError on `.ext_id`.
        if data_source is None:
            raise Exception(
                f"GitHub installation for organization {org_public_id} not found"
            )

        diff = await self.github_pull_request_service.get_diff_by_pull_url(
            pull_request_url,
            data_source.ext_id,
        )

        results = []
        for file_diff in PatchSet(diff):
            llm_result = await self._identify_violations(standard, file_diff)
            results.append(
                IdentifiedViolation(
                    file_diff=self._serialize_file_diff(file_diff),
                    # Re-validate each node violation as a ViolationV2.
                    violations=[
                        ViolationV2(**violation.dict())
                        for violation in llm_result["violations"]
                    ],
                )
            )

        return results

    @staticmethod
    def _serialize_file_diff(file_diff: PatchedFile) -> SerializableFileDiff:
        """Copy a parsed file diff and its hunks into serializable models."""
        return SerializableFileDiff(
            source_file=file_diff.source_file,
            target_file=file_diff.target_file,
            is_rename=file_diff.is_rename,
            is_new_file=file_diff.is_added_file,
            is_deleted_file=file_diff.is_removed_file,
            hunks=[
                Hunk(
                    source_start=hunk.source_start,
                    source_length=hunk.source_length,
                    target_start=hunk.target_start,
                    target_length=hunk.target_length,
                    section_header=hunk.section_header,
                    added=hunk.added,
                    removed=hunk.removed,
                    # NOTE(review): `line.value` presumably omits the +/-/space
                    # marker; confirm consumers do not need the line type.
                    lines=[line.value for line in hunk],
                )
                for hunk in file_diff
            ],
        )

    async def _identify_violations(
        self,
        standard: StandardCreateDTO,
        file_diff: PatchedFile,
    ) -> RulePreviewState:
        """Invoke the rule-preview node on one file with a fresh state."""
        state: RulePreviewState = {
            "uuid": generate(),
            "historical_context": [],
            "chat_history_id": None,
            "standard": standard.to_model(),
            "violations": [],
            "file_diff": file_diff,
        }

        return await self.rule_preview_node.ainvoke(state)
13 changes: 13 additions & 0 deletions squire/agents/preview_rule/state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from typing import List

from unidiff.patch import PatchedFile

from squire.agents.models.state import BaseContextualState
from squire.agents.models.violation import Violation
from squire.database.postgres.models import StandardModel


class RulePreviewState(BaseContextualState):
    # State handed to and returned by the rule-preview node.
    # NOTE(review): other keys set by callers (uuid, historical_context,
    # chat_history_id) presumably come from BaseContextualState — confirm.
    standard: StandardModel  # rule being previewed
    file_diff: PatchedFile  # parsed diff of a single file
    violations: List[Violation]  # violations kept by the node's post-inference filter
8 changes: 8 additions & 0 deletions squire/agents/utils/formatters.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from unidiff import Hunk, PatchedFile

from squire.database.postgres.models import FrameworkModel
from squire.database.postgres.models.framework.standard_model import StandardModel


def format_file_diff(patched_file: PatchedFile) -> str:
Expand Down Expand Up @@ -83,3 +84,10 @@ def format_standards(framework: FrameworkModel) -> str:
)

return "\n\n".join(formatted_standards)


def format_rule(rule: StandardModel) -> str:
    """Render a rule's id, key and prompt as three labelled, newline-terminated lines."""
    labelled_fields = (
        ("rule_id", rule.id),
        ("Key", rule.key),
        ("Description", rule.prompt),
    )
    return "".join(f"{label}: {value}\n" for label, value in labelled_fields)
6 changes: 6 additions & 0 deletions squire/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
router as organization_framework_router,
)
from squire.api.controllers.ping_controller import router as ping_router
from squire.api.controllers.preview_rule_controller import router as preview_rule_router
from squire.api.controllers.repository import router as repository_router
from squire.api.controllers.repository_framework import (
router as repository_framework_router,
Expand Down Expand Up @@ -58,6 +59,11 @@
prefix="/organization/{org_public_id}/repository/{repo_public_id}/framework",
tags=["organization", "repository", "framework"],
)
# Mount the rule-preview routes under the organization-scoped prefix.
api_router.include_router(
    preview_rule_router,
    prefix="/organization/{org_public_id}",
    tags=["organization", "rule"],
)

app.include_router(api_router)
app.add_exception_handler(404, NotFoundHandler.handle_404)
Loading