Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,10 @@ vim.keymap.set("v", "<leader>ai", ":PiAskSelection<CR>", { desc = "Ask pi (selec
- Trims oversized context for speed instead of always sending the full file.


## API

`pi.nvim` exposes `get_cmd()` and `run()` for programmatic use. See [this gist](https://gist.github.com/nhlmg93/49c1e5ec1e1df20b5050c770840cd7b2) for a minimal `:PiSearch` example built on `run()`.

## License

MIT
58 changes: 45 additions & 13 deletions lua/pi/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,14 @@ local function build_append_system_prompt(cfg)
return table.concat(prompts, "\n\n")
end

local function get_pi_cmd()
function M.get_cmd()
local cfg = config.get()
local binary = { "pi" }
if cfg.binary then
if type(cfg.binary) == "table" then
binary = vim.deepcopy(cfg.binary)
else
binary = { cfg.binary }
binary = { cfg.binary }
end
end
local cmd = vim.list_extend(binary, { "--mode", "rpc", "--no-session" })
Expand Down Expand Up @@ -173,7 +173,15 @@ local function finish_session(session, status, opts)
ui.update(session)
runner.finish(session)
else
reload_changed_file_buffers(session)
if session.on_done then
local ok, err = pcall(session.on_done, session)
if not ok then
vim.notify("pi on_done error: " .. tostring(err), vim.log.levels.ERROR)
end
end
if not session.skip_reload then
reload_changed_file_buffers(session)
end
ui.close(session)
runner.finish(session)
end
Expand All @@ -186,7 +194,15 @@ local function finish_session(session, status, opts)
log.append_session(nil, session, session.last_message, status, session.source_path)
end

local function start_session(message, build_context)
function M.run(opts)
opts = vim.deepcopy(opts or {})
local message = opts.message
local build_context_fn = opts.build_context
local bufnr = opts.bufnr or vim.api.nvim_get_current_buf()
local cmd = opts.cmd or M.get_cmd()
local skip_reload = opts.skip_reload
local on_done = opts.on_done

if active_session then
vim.notify("pi is already running, please wait", vim.log.levels.WARN)
return
Expand All @@ -197,16 +213,24 @@ local function start_session(message, build_context)
return
end

local source_bufnr = vim.api.nvim_get_current_buf()
if not build_context_fn then
build_context_fn = function()
return context.get_buffer_context(bufnr, config.get())
end
end

local source_bufnr = bufnr
local session = session_mod.new(source_bufnr)
session.file_snapshots = snapshot_loaded_file_buffers()
session.last_message = message
session.skip_reload = skip_reload
session.on_done = on_done
active_session = session
last_session = session
ui.open(session)
set_status(session, "collecting_context")

local ok, built_context = pcall(build_context)
local ok, built_context = pcall(build_context_fn)
if not ok then
finish_session(session, "error", { error = built_context })
return
Expand All @@ -219,7 +243,7 @@ local function start_session(message, build_context)

set_status(session, "starting")

local process, err = runner.start(session, get_pi_cmd(), payload, {
local process, err = runner.start(session, cmd, payload, {
on_event = function(event)
if not active_session or active_session ~= session or session.cancelled then
return
Expand Down Expand Up @@ -294,9 +318,13 @@ function M.prompt_with_buffer()

vim.ui.input({ prompt = context.format_prompt_label(bufnr, nil) }, function(input)
if input then
start_session(input, function()
return context.get_buffer_context(bufnr, config.get())
end)
M.run({
message = input,
bufnr = bufnr,
build_context = function()
return context.get_buffer_context(bufnr, config.get())
end,
})
end
end)
end
Expand All @@ -311,9 +339,13 @@ function M.prompt_with_selection()
local range = context.get_visual_selection_range()
vim.ui.input({ prompt = context.format_prompt_label(bufnr, range) }, function(input)
if input then
start_session(input, function()
return context.get_visual_context(bufnr, config.get())
end)
M.run({
message = input,
bufnr = bufnr,
build_context = function()
return context.get_visual_context(bufnr, config.get())
end,
})
end
end)
end
Expand Down
145 changes: 145 additions & 0 deletions tests/test_pi_commands.lua
Original file line number Diff line number Diff line change
Expand Up @@ -537,6 +537,142 @@ local function test_reloaded_buffer_can_be_written_without_changed_since_reading
MiniTest.expect.equality(last_notification(), nil)
end

local function test_run_with_custom_context_builder()
setup_test_env()
setup_buffer({ "code" }, "/test/file.lua")

local system = mock_system()
child.lua([[
require("pi").run({
message = "custom ctx",
bufnr = 0,
build_context = function() return "CUSTOM_CONTEXT" end,
})
]])
flush()

local prompt = decode_prompt(system.get_stdin())
MiniTest.expect.equality(prompt.message:match("custom ctx"), "custom ctx")
MiniTest.expect.equality(prompt.message:match("CUSTOM_CONTEXT"), "CUSTOM_CONTEXT")
end

local function test_run_with_custom_cmd()
setup_test_env()
setup_buffer({ "code" }, "/test/file.lua")

local system = mock_system()
child.lua([[
require("pi").run({
message = "custom cmd",
cmd = { "custom-binary", "--flag" },
build_context = function() return "ctx" end,
})
]])
flush()

local cmd = system.get_cmd()
MiniTest.expect.equality(cmd[1], "custom-binary")
MiniTest.expect.equality(cmd[2], "--flag")
end

local function test_get_cmd_returns_default_command()
setup_test_env()
local cmd = child.lua_get([[require("pi").get_cmd()]])

MiniTest.expect.equality(cmd[1], "pi")
MiniTest.expect.equality(has_arg(cmd, "--mode"), 2)
MiniTest.expect.equality(has_arg(cmd, "--no-session"), 4)
end

local function test_run_calls_on_done_before_success()
setup_test_env()
setup_buffer({ "code" }, "/test/file.lua")

local system = mock_system()
child.lua([[
_G.__pi_test_on_done_called = false
_G.__pi_test_on_done_session = nil
require("pi").run({
message = "test",
build_context = function() return "ctx" end,
on_done = function(session)
_G.__pi_test_on_done_called = true
_G.__pi_test_on_done_session = session and session.status or nil
end,
})
]])
system.stdout('{"type":"agent_end"}\n')
system.exit(0, 0)

MiniTest.expect.equality(child.lua_get([[_G.__pi_test_on_done_called]]), true)
MiniTest.expect.equality(child.lua_get([[_G.__pi_test_on_done_session]]), "done")
end

local function test_run_on_done_error_still_finishes_success()
setup_test_env()
setup_buffer({ "code" }, "/test/file.lua")

local system = mock_system()
child.lua([[
require("pi").run({
message = "test",
build_context = function() return "ctx" end,
on_done = function()
error("on_done boom")
end,
})
]])
system.stdout('{"type":"agent_end"}\n')
system.exit(0, 0)

MiniTest.expect.equality(child.lua_get([[require("pi").is_running()]]), false)
MiniTest.expect.equality(child.lua_get([[require("pi")._get_last_session().status]]), "done")
end

local function test_run_skip_reload_prevents_buffer_reload()
setup_test_env()
local file = child.lua_get([[vim.fn.tempname() .. ".lua"]])
write_file(file, { "from disk" })
setup_buffer({ "code" }, file)
child.lua([[vim.bo.modified = true]])
-- skip_reload bypasses the post-success reload_changed_file_buffers() gate entirely,
-- so no loaded file-backed buffers (including this one) are reloaded.

local system = mock_system()
child.lua([[
require("pi").run({
message = "test",
bufnr = vim.api.nvim_get_current_buf(),
skip_reload = true,
build_context = function() return "ctx" end,
})
]])
write_file(file, { "updated on disk" })
system.stdout('{"type":"agent_end"}\n')
system.exit(0, 0)

MiniTest.expect.equality(child.lua_get([[vim.bo.modified]]), true)
local lines = child.lua_get([[vim.api.nvim_buf_get_lines(0, 0, -1, false)]])
MiniTest.expect.equality(lines[1], "code")
end

local function test_run_build_context_error_finishes_with_error()
setup_test_env()
setup_buffer({ "code" }, "/test/file.lua")

child.lua([[
require("pi").run({
message = "test",
build_context = function() error("ctx boom") end,
})
]])
flush()

MiniTest.expect.equality(child.lua_get([[require("pi").is_running()]]), false)
MiniTest.expect.equality(child.lua_get([[require("pi")._get_last_session().status]]), "error")
MiniTest.expect.no_equality(last_notification().msg:match("ctx boom"), nil)
end

local T = MiniTest.new_set()

T["PiAsk"] = MiniTest.new_set()
Expand Down Expand Up @@ -568,4 +704,13 @@ T["Session"]["turn_end does not finish session (multi-turn tool use)"] = test_tu
T["Session"]["turn_end followed by agent_end completes"] = test_turn_end_followed_by_agent_end_completes
T["Session"]["cancel closes immediately"] = test_cancel_kills_process_and_closes_immediately

T["run API"] = MiniTest.new_set()
T["run API"]["custom context builder"] = test_run_with_custom_context_builder
T["run API"]["custom cmd"] = test_run_with_custom_cmd
T["run API"]["get_cmd returns default command"] = test_get_cmd_returns_default_command
T["run API"]["calls on_done before success"] = test_run_calls_on_done_before_success
T["run API"]["on_done error still finishes success"] = test_run_on_done_error_still_finishes_success
T["run API"]["skip_reload prevents buffer reload"] = test_run_skip_reload_prevents_buffer_reload
T["run API"]["build_context error finishes with error"] = test_run_build_context_error_finishes_with_error

return T
Loading