Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions forklet/core/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""

import asyncio
import hashlib
from pathlib import Path
from typing import List, Optional, Dict, Set, Tuple, Callable, Any
from dataclasses import dataclass
Expand All @@ -15,6 +16,7 @@
ProgressInfo,
DownloadStatus,
GitHubFile,
VerificationMethod,
)
from ..services import GitHubAPIService, DownloadService
from .filter import FilterEngine
Expand Down Expand Up @@ -348,6 +350,42 @@ async def _download_single_file(
self.progress_tracker.update_file_progress(bytes_written, file.path)
self.progress_tracker.complete_file()

# Verify integrity if requested
if (
request.verify_integrity
and request.verification_method != VerificationMethod.NONE
):
verified = False
try:
if request.verification_method == VerificationMethod.GIT_BLOB_SHA1:
# Verify using Git blob SHA1
verified = await self._verify_git_blob_sha1(
target_path, file.sha
)
elif request.verification_method == VerificationMethod.SIZE:
# Verify using file size
actual_size = await asyncio.to_thread(
lambda: target_path.stat().st_size
)
verified = actual_size == file.size
# Add other methods as needed
except Exception as e:
logger.warning(
f"Integrity verification failed for {file.path}: {e}"
)
verified = False

if verified:
self.progress_tracker.verified_files.append(file.path)
logger.debug(f"Integrity verified for {file.path}")
else:
self.progress_tracker.verification_failures[file.path] = (
"Integrity verification failed"
)
logger.warning(f"Integrity verification failed for {file.path}")
# Treat verification failure as a failure? For now, we'll still return the bytes but track it.
# Optionally we could delete the file and return None to trigger retry.

logger.debug(f"Downloaded {file.path} ({bytes_written} bytes)")
return bytes_written

Expand All @@ -356,6 +394,34 @@ async def _download_single_file(
self.progress_tracker.add_failed_file(file.path, str(e))
raise

async def _verify_git_blob_sha1(self, file_path: Path, expected_sha: str) -> bool:
    r"""
    Verify a file's SHA-1 hash matches the expected Git blob SHA-1.

    Git blob SHA-1 is computed as: SHA1("blob " + <size> + "\0" + <content>)

    The file is hashed in fixed-size chunks inside a worker thread, so a
    large file is neither loaded into memory in one piece nor hashed on
    the event loop.

    Args:
        file_path: Path to the file to verify
        expected_sha: Expected SHA-1 hex digest (compared case-insensitively,
            ignoring surrounding whitespace)

    Returns:
        True if the computed digest equals ``expected_sha``. False on a
        mismatch, when ``expected_sha`` is empty, or when the file cannot
        be read or hashed.
    """
    # Nothing to compare against: treat as "not verified" without touching disk.
    if not expected_sha:
        return False

    chunk_size = 1024 * 1024  # 1 MiB per read keeps memory use flat

    def _hash_blob() -> str:
        # Git hashes a "blob <size>\0" header followed by the raw content.
        # If the file changes size between stat() and the reads, the digest
        # simply will not match, which is the correct outcome.
        size = file_path.stat().st_size
        digest = hashlib.sha1(f"blob {size}\0".encode("utf-8"))
        with file_path.open("rb") as handle:
            while chunk := handle.read(chunk_size):
                digest.update(chunk)
        return digest.hexdigest()

    try:
        # Both the blocking reads and the CPU-bound hashing run off the loop.
        actual_sha = await asyncio.to_thread(_hash_blob)
        return actual_sha == expected_sha.strip().lower()
    except Exception as e:
        # Deliberately best-effort: any failure to read or hash the file is
        # reported to the caller as "not verified" rather than raised.
        logger.debug(f"Git blob SHA1 verification failed for {file_path}: {e}")
        return False

# Delegate methods to state controller for external control
def cancel(self) -> Optional[DownloadResult]:
"""
Expand Down
31 changes: 31 additions & 0 deletions forklet/core/progress_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,42 @@ class ProgressTracker:
# File tracking sets
_completed_files: Set[str] = field(default_factory=set)
_failed_files: Dict[str, str] = field(default_factory=dict)
_verified_files: Set[str] = field(default_factory=set)
_verification_failures: Dict[str, str] = field(default_factory=dict)
_skipped_count: int = 0

# Matched files for reporting (populated by orchestrator)
matched_files: List[str] = field(default_factory=list)

def add_verified_file(self, file_path: str) -> None:
    """Record ``file_path`` as a file whose integrity check passed.

    Paths are kept in a set, so recording the same path again has no
    additional effect.
    """
    verified = self._verified_files
    verified.add(file_path)

def add_verification_failure(self, file_path: str, error: str) -> None:
    """Record why integrity verification failed for ``file_path``.

    A later failure recorded for the same path replaces the earlier message.
    """
    self._verification_failures.update({file_path: error})

def get_verification_results(self) -> tuple[List[str], Dict[str, str]]:
    """
    Return snapshots of the verification outcomes recorded so far.

    Both containers are fresh copies, so callers can mutate them without
    affecting the tracker's internal state.

    Returns:
        Tuple of (verified_files, verification_failures)
    """
    verified_snapshot = [*self._verified_files]
    failure_snapshot = {**self._verification_failures}
    return verified_snapshot, failure_snapshot

def reset(self) -> None:
    """Return the tracker to its initial, empty state.

    The progress snapshot is replaced by a zeroed ``ProgressInfo``; the
    tracking collections are emptied in place rather than rebound.
    """
    zeroed_counters = dict(
        total_files=0, downloaded_files=0, total_bytes=0, downloaded_bytes=0
    )
    self.progress = ProgressInfo(**zeroed_counters)

    tracked_collections = (
        self._completed_files,
        self._failed_files,
        self._verified_files,
        self._verification_failures,
        self.matched_files,
    )
    for collection in tracked_collections:
        collection.clear()

    self._skipped_count = 0

def update_file_progress(
self, bytes_downloaded: int, current_file: Optional[str] = None
) -> None:
Expand Down
2 changes: 2 additions & 0 deletions forklet/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from .download import (
DownloadStrategy,
DownloadStatus,
VerificationMethod,
FilterCriteria,
DownloadRequest,
FileDownloadInfo,
Expand All @@ -32,6 +33,7 @@
# Download models
"DownloadStrategy",
"DownloadStatus",
"VerificationMethod",
"FilterCriteria",
"DownloadRequest",
"FileDownloadInfo",
Expand Down
18 changes: 18 additions & 0 deletions forklet/models/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@ class DownloadStatus(Enum):
PAUSED = "paused"


class VerificationMethod(Enum):
    """Methods for verifying file integrity.

    Selected through ``DownloadRequest.verification_method``.
    """

    # No integrity check is performed.
    NONE = "none"
    # Compare the SHA-1 of "blob <size>\0" + content with the file's Git blob SHA.
    GIT_BLOB_SHA1 = "git_blob_sha1"
    # Compare the on-disk size in bytes with the file's expected size.
    SIZE = "size"


@dataclass
class FilterCriteria:
"""Flexible filtering criteria for repository content."""
Expand Down Expand Up @@ -108,6 +116,10 @@ class DownloadRequest:
timeout: int = 300
stream_threshold: int = 10 * 1024 * 1024 # 10 MB default

# Integrity options
verify_integrity: bool = False
verification_method: VerificationMethod = VerificationMethod.GIT_BLOB_SHA1

# Authentication
token: Optional[str] = None

Expand All @@ -131,6 +143,8 @@ def __post_init__(self) -> None:
raise ValueError("timeout must be positive")
if self.stream_threshold < 0:
raise ValueError("stream_threshold must be non-negative")
if self.stream_threshold < 0:
raise ValueError("stream_threshold must be non-negative")


@dataclass
Expand Down Expand Up @@ -206,6 +220,10 @@ class DownloadResult:
# Matched file paths (populated by orchestrator for verbose reporting)
matched_files: List[str] = field(default_factory=list)

# Integrity verification results
verified_files: List[str] = field(default_factory=list)
verification_failures: Dict[str, str] = field(default_factory=dict)

# Metadata
started_at: datetime = field(default_factory=datetime.now)
completed_at: Optional[datetime] = None
Expand Down
Loading