From 5e7c25c155e3c3e9c77060c6b9ce9c0d65513efb Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Wed, 13 May 2026 16:40:01 -0400 Subject: [PATCH] [INITIAL] Add ghstack pull command with merge conflict resolution Implement `ghstack pull` as a new subcommand that pulls remote updates for a ghstack PR into the local working tree. The command resolves the PR from HEAD's commit message or an explicit PR argument, fetches the remote head/orig refs, and merges remote changes with local changes. Key behaviors: - Fast-forward: if the local commit is an ancestor of the remote orig, simply check out the remote orig directly. - Clean merge: use `git merge-tree --write-tree` to produce a merged tree without touching the worktree, then create a new orig commit with the merged tree and updated ghstack-source-id. - Conflict: fall back to `git merge-recursive` to materialize conflicts in the worktree, save state to `.git/GHSTACK_PULL`, and instruct the user to resolve and run `ghstack pull --continue`. - Continue: `--continue` validates that conflicts are resolved and all changes are staged, then writes the merged tree and checks out the result. The merge strategy works by finding the remote head commit whose tree matches the local ghstack-source-id (establishing the common ancestor), creating an imputed head commit from the local tree, and merging that against the remote head. Includes three test scenarios: basic non-conflicting pull, conflicting pull with manual resolution via --continue, and explicit PR argument with fast-forward. [ghstack-poisoned] --- src/ghstack/cli.py | 29 +++ src/ghstack/pull.py | 269 +++++++++++++++++++++ src/ghstack/test_prelude.py | 14 ++ test/pull/basic.py.test | 26 ++ test/pull/conflict.py.test | 32 +++ test/pull/explicit_pr_fast_forward.py.test | 14 ++ 6 files changed, 384 insertions(+) create mode 100644 src/ghstack/pull.py create mode 100644 test/pull/basic.py.test create mode 100644 test/pull/conflict.py.test create mode 100644 test/pull/explicit_pr_fast_forward.py.test diff --git a/src/ghstack/cli.py b/src/ghstack/cli.py index c762cc2..e2c196a 100644 --- a/src/ghstack/cli.py +++ b/src/ghstack/cli.py @@ -15,6 +15,7 @@ import ghstack.land import ghstack.log import ghstack.logs +import ghstack.pull import ghstack.rage import ghstack.status import ghstack.submit @@ -282,6 +283,34 @@ def checkout(same_base: bool, pull_request: str) -> None: ) +@main.command("pull") +@click.option( + "--continue", + "continue_", + is_flag=True, + help="Finish a ghstack pull after resolving conflicts", +) +@click.argument("pull_request", metavar="PR", required=False) +def pull(continue_: bool, pull_request: Optional[str]) -> None: + """ + Pull remote updates for a ghstack PR + """ + with cli_context(request_github_token=False) as (shell, config, github): + run_async( + run_with_github( + github, + ghstack.pull.main( + pull_request=pull_request, + github=github, + sh=shell, + remote_name=config.remote_name, + github_url=config.github_url, + continue_=continue_, + ), + ) + ) + + @main.command("cherry-pick") @click.option( "--stack", diff --git a/src/ghstack/pull.py b/src/ghstack/pull.py new file mode 100644 index 0000000..4c370d4 --- /dev/null +++ b/src/ghstack/pull.py @@ -0,0 +1,269 @@ +#!/usr/bin/env python3 + +import asyncio +import json +import os +import re +from typing import Any, Dict, List, Optional, Tuple + +import ghstack.checkout +import ghstack.diff +import ghstack.github +import ghstack.github_utils +import ghstack.shell +import ghstack.submit + + +async def _run_git_for_status( + sh: ghstack.shell.Shell, args: List[str] +) -> Tuple[int, str]: + ghstack.shell.log_command(["git", *args]) + proc = await asyncio.create_subprocess_exec( + "git", + *args, + cwd=sh.cwd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.STDOUT, + ) + out, _ = await proc.communicate() + assert proc.returncode is not None + return proc.returncode, out.decode(errors="backslashreplace") + + +async def _resolve_params( + *, + pull_request: Optional[str], + github_url: str, + sh: ghstack.shell.Shell, + remote_name: str, +) -> ghstack.github_utils.GitHubPullRequestParams: + if pull_request is not None: + return await ghstack.github_utils.parse_pull_request( + pull_request, sh=sh, remote_name=remote_name + ) + + commit_msg = await sh.agit("log", "-1", "--format=%B", "HEAD") + pr = ghstack.diff.PullRequestResolved.search(commit_msg, github_url) + if pr is None: + raise RuntimeError( + "HEAD commit is not associated with a ghstack pull request " + "(no Pull-Request trailer found). Check out the commit for the " + "PR you want to pull, or pass the PR explicitly." + ) + return { + "github_url": pr.github_url, + "owner": pr.owner, + "name": pr.repo, + "number": pr.number, + } + + +def _replace_source_id(commit_msg: str, source_id: str) -> str: + line = f"ghstack-source-id: {source_id}\n" + if ghstack.submit.RE_GHSTACK_SOURCE_ID.search(commit_msg) is None: + return commit_msg.rstrip() + "\n" + line + return ghstack.submit.RE_GHSTACK_SOURCE_ID.sub(line, commit_msg) + + +async def _state_path(sh: ghstack.shell.Shell) -> str: + path = await sh.agit("rev-parse", "--git-path", "GHSTACK_PULL") + return path if os.path.isabs(path) else sh.abspath(path) + + +async def _read_state(sh: ghstack.shell.Shell) -> Dict[str, Any]: + path = await _state_path(sh) + if not os.path.exists(path): + raise RuntimeError("No ghstack pull conflict in progress.") + with open(path, encoding="utf-8") as f: + return json.load(f) + + +async def _write_state(sh: ghstack.shell.Shell, state: Dict[str, Any]) -> None: + path = await _state_path(sh) + with open(path, "w", encoding="utf-8") as f: + json.dump(state, f) + f.write("\n") + + +async def _clear_state(sh: ghstack.shell.Shell) -> None: + path = await _state_path(sh) + try: + os.unlink(path) + except FileNotFoundError: + pass + + +async def _find_head_with_tree( + sh: ghstack.shell.Shell, *, remote_head: str, tree: str +) -> str: + log = await sh.agit("log", "--first-parent", "--format=%H %T", remote_head) + for line in log.splitlines(): + commit, commit_tree = line.split() + if commit_tree == tree: + return commit + raise RuntimeError( + "Could not find the previously checked out ghstack head commit. " + "The local ghstack-source-id does not appear in the remote head history." + ) + + +async def _is_worktree_clean(sh: ghstack.shell.Shell) -> bool: + return bool( + await sh.agit("diff", "--quiet", exitcode=True) + and await sh.agit("diff", "--cached", "--quiet", exitcode=True) + ) + + +async def _finish_pull(sh: ghstack.shell.Shell, state: Dict[str, Any]) -> None: + unmerged = await sh.agit("ls-files", "-u") + if unmerged: + raise RuntimeError( + "There are still unresolved merge conflicts. Resolve them and run " + "`ghstack pull --continue` again." + ) + if not await sh.agit("diff", "--quiet", exitcode=True): + raise RuntimeError( + "There are unstaged changes. Stage the resolved files with `git add`, " + "then run `ghstack pull --continue` again." + ) + + merged_tree = await sh.agit("write-tree") + pulled_commit_msg = _replace_source_id( + state["commit_msg"], state["remote_source_id"] + ) + pulled_orig = await sh.agit( + "commit-tree", + "-p", + state["parent"], + merged_tree, + input=pulled_commit_msg, + env={ + "GIT_AUTHOR_NAME": state["author_name"], + "GIT_AUTHOR_EMAIL": state["author_email"], + }, + ) + await sh.agit("checkout", pulled_orig) + await _clear_state(sh) + + +async def main( + github: ghstack.github.GitHubEndpoint, + sh: ghstack.shell.Shell, + remote_name: str, + github_url: str, + pull_request: Optional[str] = None, + continue_: bool = False, +) -> None: + if continue_: + await _finish_pull(sh, await _read_state(sh)) + return + + params = await _resolve_params( + pull_request=pull_request, + github_url=github_url, + sh=sh, + remote_name=remote_name, + ) + head_ref = await github.get_head_ref(**params) + orig_ref = re.sub(r"/head$", "/orig", head_ref) + if orig_ref == head_ref: + raise RuntimeError(f"The ref {head_ref} doesn't look like a ghstack reference") + + await ghstack.checkout._fetch_refs( + sh, remote_name=remote_name, refs=[head_ref, orig_ref] + ) + remote_head = f"{remote_name}/{head_ref}" + remote_orig = f"{remote_name}/{orig_ref}" + + if await sh.agit("merge-base", "--is-ancestor", "HEAD", remote_orig, exitcode=True): + await sh.agit("checkout", remote_orig) + await _clear_state(sh) + return + + state_path = await _state_path(sh) + if os.path.exists(state_path): + raise RuntimeError( + "A ghstack pull conflict is already in progress. Resolve it and run " + "`ghstack pull --continue`." + ) + + if not await _is_worktree_clean(sh): + raise RuntimeError( + "Working tree has uncommitted changes; commit or stash them first." + ) + + local_commit_msg = await sh.agit("log", "-1", "--format=%B", "HEAD") + m_local_source_id = ghstack.submit.RE_GHSTACK_SOURCE_ID.search(local_commit_msg) + if m_local_source_id is None: + raise RuntimeError( + "HEAD has no ghstack-source-id trailer, so ghstack cannot determine " + "which remote head version your local changes are based on." + ) + local_source_id = m_local_source_id.group(1) + + old_head = await _find_head_with_tree( + sh, remote_head=remote_head, tree=local_source_id + ) + local_tree = await sh.agit("rev-parse", "HEAD^{tree}") + local_imputed_head = await sh.agit( + "commit-tree", + "-p", + old_head, + local_tree, + input="Local changes for ghstack pull\n\n[ghstack-poisoned]\n", + ) + + returncode, merge_tree_output = await _run_git_for_status( + sh, + ["merge-tree", "--write-tree", "--messages", remote_head, local_imputed_head], + ) + merged_tree = merge_tree_output.splitlines()[0] if returncode == 0 else None + + remote_orig_commit_msg = await sh.agit("log", "-1", "--format=%B", remote_orig) + m_remote_source_id = ghstack.submit.RE_GHSTACK_SOURCE_ID.search( + remote_orig_commit_msg + ) + remote_source_id = ( + m_remote_source_id.group(1) + if m_remote_source_id is not None + else await sh.agit("rev-parse", f"{remote_orig}^{{tree}}") + ) + remote_orig_parent = await sh.agit("rev-parse", f"{remote_orig}^") + + author_name = await sh.agit("log", "-1", "--format=%an", "HEAD") + author_email = await sh.agit("log", "-1", "--format=%ae", "HEAD") + state = { + "parent": remote_orig_parent, + "remote_source_id": remote_source_id, + "commit_msg": local_commit_msg, + "author_name": author_name, + "author_email": author_email, + } + + if returncode != 0: + await _write_state(sh, state) + recursive_returncode, recursive_output = await _run_git_for_status( + sh, ["merge-recursive", old_head, "--", local_imputed_head, remote_head] + ) + if recursive_returncode == 0: + await _finish_pull(sh, state) + return + raise RuntimeError( + "Automatic ghstack pull merge failed. Resolve the conflicts, then run " + "`ghstack pull --continue`.\n" + recursive_output + ) + + pulled_commit_msg = _replace_source_id(local_commit_msg, remote_source_id) + assert merged_tree is not None + pulled_orig = await sh.agit( + "commit-tree", + "-p", + remote_orig_parent, + merged_tree, + input=pulled_commit_msg, + env={ + "GIT_AUTHOR_NAME": author_name, + "GIT_AUTHOR_EMAIL": author_email, + }, + ) + await sh.agit("checkout", pulled_orig) diff --git a/src/ghstack/test_prelude.py b/src/ghstack/test_prelude.py index 96bf1a5..69ea201 100644 --- a/src/ghstack/test_prelude.py +++ b/src/ghstack/test_prelude.py @@ -32,6 +32,7 @@ import ghstack.github_utils import ghstack.land import ghstack.log +import ghstack.pull import ghstack.shell import ghstack.submit import ghstack.sync @@ -49,6 +50,7 @@ "gh_cherry_pick", "gh_checkout", "gh_log", + "gh_pull", "gh_sync", "GitCommitHash", "checkout", @@ -307,6 +309,18 @@ async def gh_log(pull_request: Optional[str] = None, args: Sequence[str] = ()) - ) +async def gh_pull(pull_request: Optional[str] = None, continue_: bool = False) -> None: + self = CTX + return await ghstack.pull.main( + github=self.github, + sh=self.sh, + remote_name="origin", + github_url="github.com", + pull_request=pull_request, + continue_=continue_, + ) + + async def gh_sync() -> GitCommitHash: self = CTX return await ghstack.sync.main( diff --git a/test/pull/basic.py.test b/test/pull/basic.py.test new file mode 100644 index 0000000..e224981 --- /dev/null +++ b/test/pull/basic.py.test @@ -0,0 +1,26 @@ +from ghstack.test_prelude import * + +await init_test() + +await commit("A") +(A,) = await gh_submit("Initial") +old_orig = A.orig + +await write_file_and_add("remote.txt", "remote change") +await git("commit", "--amend", "--no-edit") +await gh_submit("Remote update") + +await checkout(old_orig) +await write_file_and_add("local.txt", "local change") +await git("commit", "--amend", "--no-edit") + +await gh_pull() + +assert_eq(await git("show", "HEAD:remote.txt"), "remote change") +assert_eq(await git("show", "HEAD:local.txt"), "local change") + +# The pulled commit records that it is based on the latest remote orig, so a +# normal submit should not need --force. +await gh_submit("Local update") + +ok() diff --git a/test/pull/conflict.py.test b/test/pull/conflict.py.test new file mode 100644 index 0000000..4d144da --- /dev/null +++ b/test/pull/conflict.py.test @@ -0,0 +1,32 @@ +from ghstack.test_prelude import * + +await init_test() + +await commit("A") +(A,) = await gh_submit("Initial") +old_orig = A.orig + +await write_file_and_add("A.txt", "remote change") +await git("commit", "--amend", "--no-edit") +await gh_submit("Remote update") + +await checkout(old_orig) +await write_file_and_add("A.txt", "local change") +await git("commit", "--amend", "--no-edit") + +await assert_raises(RuntimeError, gh_pull) + +status = await git("status", "--porcelain") +assert "UU A.txt" in status +contents = await git("show", ":2:A.txt") +assert_eq(contents, "local change") +contents = await git("show", ":3:A.txt") +assert_eq(contents, "remote change") + +await write_file_and_add("A.txt", "resolved change") +await gh_pull(continue_=True) + +assert_eq(await git("show", "HEAD:A.txt"), "resolved change") +await gh_submit("Resolved update") + +ok() diff --git a/test/pull/explicit_pr_fast_forward.py.test b/test/pull/explicit_pr_fast_forward.py.test new file mode 100644 index 0000000..d2bc08a --- /dev/null +++ b/test/pull/explicit_pr_fast_forward.py.test @@ -0,0 +1,14 @@ +from ghstack.test_prelude import * + +await init_test() + +await commit("A") +(A,) = await gh_submit("Initial") + +await git("checkout", "main") +await gh_pull(f"https://github.com/pytorch/pytorch/pull/{A.number}") + +current_log = await git("log", "--oneline", "-n", "1") +assert "Commit A" in current_log + +ok()