From 3b9d9cf42a596eefff93214a21f497dd4f0234db Mon Sep 17 00:00:00 2001 From: Nathan Helmig Date: Sun, 10 May 2026 13:53:49 -0500 Subject: [PATCH] Expose run API for downstream workflows --- README.md | 4 + lua/pi/init.lua | 58 +++++++++++---- tests/test_pi_commands.lua | 145 +++++++++++++++++++++++++++++++++++++ 3 files changed, 194 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 32a4e69..79384ff 100644 --- a/README.md +++ b/README.md @@ -145,6 +145,10 @@ vim.keymap.set("v", "ai", ":PiAskSelection", { desc = "Ask pi (selec - Trims oversized context for speed instead of always sending the full file. +## API + +`pi.nvim` exposes `get_cmd()` and `run()` for programmatic use. See [this gist](https://gist.github.com/nhlmg93/49c1e5ec1e1df20b5050c770840cd7b2) for a minimal `:PiSearch` example built on `run()`. + ## License MIT diff --git a/lua/pi/init.lua b/lua/pi/init.lua index c65d2bf..c8bb3de 100644 --- a/lua/pi/init.lua +++ b/lua/pi/init.lua @@ -33,14 +33,14 @@ local function build_append_system_prompt(cfg) return table.concat(prompts, "\n\n") end -local function get_pi_cmd() +function M.get_cmd() local cfg = config.get() local binary = { "pi" } if cfg.binary then if type(cfg.binary) == "table" then binary = vim.deepcopy(cfg.binary) else - binary = { cfg.binary } + binary = { cfg.binary } end end local cmd = vim.list_extend(binary, { "--mode", "rpc", "--no-session" }) @@ -173,7 +173,15 @@ local function finish_session(session, status, opts) ui.update(session) runner.finish(session) else - reload_changed_file_buffers(session) + if session.on_done then + local ok, err = pcall(session.on_done, session) + if not ok then + vim.notify("pi on_done error: " .. tostring(err), vim.log.levels.ERROR) + end + end + if not session.skip_reload then + reload_changed_file_buffers(session) + end ui.close(session) runner.finish(session) end @@ -186,7 +194,15 @@ local function finish_session(session, status, opts) log.append_session(nil, session, session.last_message, status, session.source_path) end -local function start_session(message, build_context) +function M.run(opts) + opts = vim.deepcopy(opts or {}) + local message = opts.message + local build_context_fn = opts.build_context + local bufnr = opts.bufnr or vim.api.nvim_get_current_buf() + local cmd = opts.cmd or M.get_cmd() + local skip_reload = opts.skip_reload + local on_done = opts.on_done + if active_session then vim.notify("pi is already running, please wait", vim.log.levels.WARN) return @@ -197,16 +213,24 @@ local function start_session(message, build_context) return end - local source_bufnr = vim.api.nvim_get_current_buf() + if not build_context_fn then + build_context_fn = function() + return context.get_buffer_context(bufnr, config.get()) + end + end + + local source_bufnr = bufnr local session = session_mod.new(source_bufnr) session.file_snapshots = snapshot_loaded_file_buffers() session.last_message = message + session.skip_reload = skip_reload + session.on_done = on_done active_session = session last_session = session ui.open(session) set_status(session, "collecting_context") - local ok, built_context = pcall(build_context) + local ok, built_context = pcall(build_context_fn) if not ok then finish_session(session, "error", { error = built_context }) return @@ -219,7 +243,7 @@ local function start_session(message, build_context) set_status(session, "starting") - local process, err = runner.start(session, get_pi_cmd(), payload, { + local process, err = runner.start(session, cmd, payload, { on_event = function(event) if not active_session or active_session ~= session or session.cancelled then return @@ -294,9 +318,13 @@ function M.prompt_with_buffer() vim.ui.input({ prompt = context.format_prompt_label(bufnr, nil) }, function(input) if input then - start_session(input, function() - return context.get_buffer_context(bufnr, config.get()) - end) + M.run({ + message = input, + bufnr = bufnr, + build_context = function() + return context.get_buffer_context(bufnr, config.get()) + end, + }) end end) end @@ -311,9 +339,13 @@ function M.prompt_with_selection() local range = context.get_visual_selection_range() vim.ui.input({ prompt = context.format_prompt_label(bufnr, range) }, function(input) if input then - start_session(input, function() - return context.get_visual_context(bufnr, config.get()) - end) + M.run({ + message = input, + bufnr = bufnr, + build_context = function() + return context.get_visual_context(bufnr, config.get()) + end, + }) end end) end diff --git a/tests/test_pi_commands.lua b/tests/test_pi_commands.lua index 448915f..9188df2 100644 --- a/tests/test_pi_commands.lua +++ b/tests/test_pi_commands.lua @@ -537,6 +537,142 @@ local function test_reloaded_buffer_can_be_written_without_changed_since_reading MiniTest.expect.equality(last_notification(), nil) end +local function test_run_with_custom_context_builder() + setup_test_env() + setup_buffer({ "code" }, "/test/file.lua") + + local system = mock_system() + child.lua([[ + require("pi").run({ + message = "custom ctx", + bufnr = 0, + build_context = function() return "CUSTOM_CONTEXT" end, + }) + ]]) + flush() + + local prompt = decode_prompt(system.get_stdin()) + MiniTest.expect.equality(prompt.message:match("custom ctx"), "custom ctx") + MiniTest.expect.equality(prompt.message:match("CUSTOM_CONTEXT"), "CUSTOM_CONTEXT") +end + +local function test_run_with_custom_cmd() + setup_test_env() + setup_buffer({ "code" }, "/test/file.lua") + + local system = mock_system() + child.lua([[ + require("pi").run({ + message = "custom cmd", + cmd = { "custom-binary", "--flag" }, + build_context = function() return "ctx" end, + }) + ]]) + flush() + + local cmd = system.get_cmd() + MiniTest.expect.equality(cmd[1], "custom-binary") + MiniTest.expect.equality(cmd[2], "--flag") +end + +local function test_get_cmd_returns_default_command() + setup_test_env() + local cmd = child.lua_get([[require("pi").get_cmd()]]) + + MiniTest.expect.equality(cmd[1], "pi") + MiniTest.expect.equality(has_arg(cmd, "--mode"), 2) + MiniTest.expect.equality(has_arg(cmd, "--no-session"), 4) +end + +local function test_run_calls_on_done_before_success() + setup_test_env() + setup_buffer({ "code" }, "/test/file.lua") + + local system = mock_system() + child.lua([[ + _G.__pi_test_on_done_called = false + _G.__pi_test_on_done_session = nil + require("pi").run({ + message = "test", + build_context = function() return "ctx" end, + on_done = function(session) + _G.__pi_test_on_done_called = true + _G.__pi_test_on_done_session = session and session.status or nil + end, + }) + ]]) + system.stdout('{"type":"agent_end"}\n') + system.exit(0, 0) + + MiniTest.expect.equality(child.lua_get([[_G.__pi_test_on_done_called]]), true) + MiniTest.expect.equality(child.lua_get([[_G.__pi_test_on_done_session]]), "done") +end + +local function test_run_on_done_error_still_finishes_success() + setup_test_env() + setup_buffer({ "code" }, "/test/file.lua") + + local system = mock_system() + child.lua([[ + require("pi").run({ + message = "test", + build_context = function() return "ctx" end, + on_done = function() + error("on_done boom") + end, + }) + ]]) + system.stdout('{"type":"agent_end"}\n') + system.exit(0, 0) + + MiniTest.expect.equality(child.lua_get([[require("pi").is_running()]]), false) + MiniTest.expect.equality(child.lua_get([[require("pi")._get_last_session().status]]), "done") +end + +local function test_run_skip_reload_prevents_buffer_reload() + setup_test_env() + local file = child.lua_get([[vim.fn.tempname() .. ".lua"]]) + write_file(file, { "from disk" }) + setup_buffer({ "code" }, file) + child.lua([[vim.bo.modified = true]]) + -- skip_reload bypasses the post-success reload_changed_file_buffers() gate entirely, + -- so no loaded file-backed buffers (including this one) are reloaded. + + local system = mock_system() + child.lua([[ + require("pi").run({ + message = "test", + bufnr = vim.api.nvim_get_current_buf(), + skip_reload = true, + build_context = function() return "ctx" end, + }) + ]]) + write_file(file, { "updated on disk" }) + system.stdout('{"type":"agent_end"}\n') + system.exit(0, 0) + + MiniTest.expect.equality(child.lua_get([[vim.bo.modified]]), true) + local lines = child.lua_get([[vim.api.nvim_buf_get_lines(0, 0, -1, false)]]) + MiniTest.expect.equality(lines[1], "code") +end + +local function test_run_build_context_error_finishes_with_error() + setup_test_env() + setup_buffer({ "code" }, "/test/file.lua") + + child.lua([[ + require("pi").run({ + message = "test", + build_context = function() error("ctx boom") end, + }) + ]]) + flush() + + MiniTest.expect.equality(child.lua_get([[require("pi").is_running()]]), false) + MiniTest.expect.equality(child.lua_get([[require("pi")._get_last_session().status]]), "error") + MiniTest.expect.no_equality(last_notification().msg:match("ctx boom"), nil) +end + local T = MiniTest.new_set() T["PiAsk"] = MiniTest.new_set() @@ -568,4 +704,13 @@ T["Session"]["turn_end does not finish session (multi-turn tool use)"] = test_tu T["Session"]["turn_end followed by agent_end completes"] = test_turn_end_followed_by_agent_end_completes T["Session"]["cancel closes immediately"] = test_cancel_kills_process_and_closes_immediately +T["run API"] = MiniTest.new_set() +T["run API"]["custom context builder"] = test_run_with_custom_context_builder +T["run API"]["custom cmd"] = test_run_with_custom_cmd +T["run API"]["get_cmd returns default command"] = test_get_cmd_returns_default_command +T["run API"]["calls on_done before success"] = test_run_calls_on_done_before_success +T["run API"]["on_done error still finishes success"] = test_run_on_done_error_still_finishes_success +T["run API"]["skip_reload prevents buffer reload"] = test_run_skip_reload_prevents_buffer_reload +T["run API"]["build_context error finishes with error"] = test_run_build_context_error_finishes_with_error + return T