Skip to content

Commit 86b1ff1

Browse files
fix(chat): watchers use the same diff format as in patch (olimorris#1655)
1 parent a6e226c commit 86b1ff1

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

lua/codecompanion/strategies/chat/watchers.lua

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ end
105105
---@param old_content table
106106
---@param new_content table
107107
---@return string
108-
local function format_changes_as_diff(old_content, new_content)
108+
function Watchers:format_changes_as_diff(old_content, new_content)
109109
-- Convert line arrays to strings for vim.diff
110110
local old_str = table.concat(old_content, "\n") .. "\n"
111111
local new_str = table.concat(new_content, "\n") .. "\n"
@@ -116,7 +116,9 @@ local function format_changes_as_diff(old_content, new_content)
116116
algorithm = "myers",
117117
})
118118
if diff_result and diff_result ~= "" then
119-
return fmt("```diff\n%s```", diff_result)
119+
-- replace line numbers in diff to keep a common format
120+
diff_result = diff_result:gsub("^@@ .+ @@\n", "@@\n")
121+
return fmt("```\n%s```", diff_result)
120122
end
121123

122124
return ""
@@ -132,7 +134,7 @@ function Watchers:check_for_changes(chat)
132134
if has_changed and old_content then
133135
local filename = vim.fn.fnamemodify(api.nvim_buf_get_name(ref.bufnr), ":.")
134136
local current_content = api.nvim_buf_get_lines(ref.bufnr, 0, -1, false)
135-
local diff_content = format_changes_as_diff(old_content, current_content)
137+
local diff_content = self:format_changes_as_diff(old_content, current_content)
136138

137139
if diff_content ~= "" then
138140
local delta = fmt("The file `%s`, has been modified. Here are the changes:\n%s", filename, diff_content)

tests/strategies/chat/test_watcher.lua

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,4 +310,12 @@ T["Watchers"]["doesn't watch invalid buffers"] = function()
310310
h.eq(old_content, nil)
311311
end
312312

313+
T["Watchers"]["format_changes_as_diff returns correct unified diff"] = function()
314+
local old = { "one", "two", "three" }
315+
local new = { "one", "TWO", "three", "four" }
316+
local watcher = Watcher.new()
317+
local diff = watcher:format_changes_as_diff(old, new)
318+
h.eq(diff, "```\n@@\n one\n-two\n+TWO\n three\n+four\n```")
319+
end
320+
313321
return T

0 commit comments

Comments
 (0)