From 7caa02c2f26670a20d68b2f446b6b0fb7fb0e71a Mon Sep 17 00:00:00 2001 From: Tomas Slusny Date: Thu, 16 Apr 2026 10:28:37 +0200 Subject: [PATCH] feat(trusted-tools): auto-execute trusted tool calls Add `trusted_tools` config option to allow automatic execution of trusted tool calls without manual approval. By default, all tool calls require approval, but users can now trust specific tools, groups, or all tools. Update README and type annotations to document the new behavior. Trusted tools are determined by function definition, group, or name. Closes #1534 --- README.md | 176 ++++++++++++++++++--------- lua/CopilotChat/config.lua | 2 + lua/CopilotChat/config/functions.lua | 1 + lua/CopilotChat/config/mappings.lua | 23 ---- lua/CopilotChat/init.lua | 137 +++++++++++++++++++-- lua/CopilotChat/prompts.lua | 144 ++++++++++++++-------- lua/CopilotChat/ui/chat.lua | 23 ++-- lua/CopilotChat/ui/overlay.lua | 10 +- 8 files changed, 359 insertions(+), 157 deletions(-) diff --git a/README.md b/README.md index 9a54c75e..b37e2cfe 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ https://github.com/user-attachments/assets/8cad5643-63b2-4641-a5c4-68bc313f20e6 CopilotChat.nvim brings GitHub Copilot Chat capabilities directly into Neovim with a focus on transparency and user control. - 🤖 **Multiple AI Models** - GitHub Copilot (including GPT-4o, Gemini 2.5 Pro, Claude 4 Sonnet, Claude 3.7 Sonnet, Claude 3.5 Sonnet, o3-mini, o4-mini) + custom providers (Ollama, Mistral.ai). The exact list of available models depends on your [GitHub Copilot settings](https://github.com/settings/copilot/features) and the models provided by GitHub's API. -- 🔧 **Tool Calling** - LLM can call workspace functions (file reading, git operations, search) with your explicit approval +- 🔧 **Tool Calling** - LLM can call workspace functions (file reading, git operations, search) with manual approval or automatic execution for trusted tools - 🔒 **Privacy First** - Only shares what you explicitly request - no background data collection - 📝 **Interactive Chat** - Interactive UI with completion, diffs, and quickfix integration - 🎯 **Smart Prompts** - Composable templates and sticky prompts for consistent context @@ -92,7 +92,7 @@ EOF # Core Concepts - **Resources** (`#`) - Add specific content (files, git diffs, URLs) to your prompt -- **Tools** (`@`) - Give LLM access to functions it can call with your approval +- **Tools** (`@`) - Give LLM access to functions it can call during the chat, with manual approval by default - **Sticky Prompts** (`> `) - Persist context across single chat session - **Models** (`$`) - Specify which AI model to use for the chat - **Prompts** (`/PromptName`) - Use predefined prompt templates for common tasks @@ -114,7 +114,15 @@ EOF > You are a helpful coding assistant ``` -When you use `@copilot`, the LLM can call functions like `bash`, `edit`, `file`, `glob`, `grep`, `gitdiff` etc. You'll see the proposed function call and can approve/reject it before execution. +When you use `@copilot`, the LLM can call functions from the `copilot` group such as `bash`, `edit`, `file`, `glob`, `grep`, and `gitdiff`. + +- By default, proposed tool calls wait for your approval. +- You can configure `trusted_tools` to automatically run specific tools or groups. +- Resources added with `#...` are resolved immediately and shared as context. +- Tool call results are sent back to the model as plain output, while manual resources keep their `##` references in chat. + +> [!WARNING] +> `trusted_tools = true` allows the model to run every enabled tool without asking. Only use it if you fully trust the tool set and workspace. # Usage @@ -136,21 +144,20 @@ When you use `@copilot`, the LLM can call functions like `bash`, `edit`, `file`, ## Chat Key Mappings -| Insert | Normal | Action | -| ------- | ------- | ------------------------------------------ | -| `` | - | Trigger/accept completion menu for tokens | -| `` | `q` | Close the chat window | -| `` | `` | Reset and clear the chat window | -| `` | `` | Submit the current prompt | -| - | `grr` | Toggle sticky prompt for line under cursor | -| `` | `` | Accept nearest diff | -| - | `gj` | Jump to section of nearest diff | -| - | `gqa` | Add all answers from chat to quickfix list | -| - | `gqd` | Add all diffs from chat to quickfix list | -| - | `gy` | Yank nearest diff to register | -| - | `gd` | Show diff between source and nearest diff | -| - | `gc` | Show info about current chat | -| - | `gh` | Show help message | +| Insert | Normal | Action | +| ------- | ------- | ----------------------------------------- | +| `` | - | Trigger/accept completion menu for tokens | +| `` | `q` | Close the chat window | +| `` | `` | Reset and clear the chat window | +| `` | `` | Submit the current prompt | +| `` | `` | Accept nearest diff | +| - | `gj` | Jump to section of nearest diff | +| - | `gqa` | Add all answers from chat to quickfix | +| - | `gqd` | Add all diffs from chat to quickfix | +| - | `gy` | Yank nearest diff to register | +| - | `gd` | Show diff between source and nearest diff | +| - | `gc` | Show info about current chat | +| - | `gh` | Show help message | > [!WARNING] > Some plugins (e.g. `copilot.vim`) may also map common keys like `` in insert mode. @@ -167,23 +174,24 @@ When you use `@copilot`, the LLM can call functions like `bash`, `edit`, `file`, All predefined functions belong to the `copilot` group. -| Function | Type | Description | Example Usage | -| ----------- | -------- | ------------------------------------------------------ | -------------------- | -| `bash` | Tool | Executes a bash command and returns output | `@copilot` only | -| `buffer` | Resource | Retrieves content from buffer(s) with diagnostics | `#buffer:active` | -| `clipboard` | Resource | Provides access to system clipboard content | `#clipboard` | -| `edit` | Tool | Applies a unified diff to a file | `@copilot` only | -| `file` | Resource | Reads content from a specified file path | `#file:path/to/file` | -| `gitdiff` | Resource | Retrieves git diff information | `#gitdiff:staged` | -| `glob` | Resource | Lists filenames matching a pattern in workspace | `#glob:**/*.lua` | -| `grep` | Resource | Searches for a pattern across files in workspace | `#grep:TODO` | -| `selection` | Resource | Includes the current visual selection with diagnostics | `#selection` | -| `url` | Resource | Fetches content from a specified URL | `#url:https://...` | +| Function | Manual `#...` | Description | Example Usage | +| ----------- | ------------- | ------------------------------------------------------ | -------------------- | +| `bash` | No | Executes a bash command and returns output | `@copilot` | +| `buffer` | Yes | Retrieves content from buffer(s) with diagnostics | `#buffer:active` | +| `clipboard` | Yes | Provides access to system clipboard content | `#clipboard` | +| `edit` | No | Applies a unified diff to a file | `@copilot` | +| `file` | Yes | Reads content from a specified file path | `#file:path/to/file` | +| `gitdiff` | Yes | Retrieves git diff information | `#gitdiff:staged` | +| `glob` | Yes | Lists filenames matching a pattern in workspace | `#glob:**/*.lua` | +| `grep` | Yes | Searches for a pattern across files in workspace | `#grep:TODO` | +| `selection` | Yes | Includes the current visual selection with diagnostics | `#selection` | +| `url` | Yes | Fetches content from a specified URL | `#url:https://...` | + +`#...` resolves a function immediately and adds its output as chat context. -**Type Legend:** +`@copilot` shares the enabled functions with the model so it can choose when to call them. -- **Resource**: Can be used manually via `#function` syntax -- **Tool**: Can only be called by LLM via `@copilot` (for safety/complexity reasons) +Only `bash` and `edit` are tool-only. The rest can be used both as manual resources and as callable tools. ## Predefined Prompts @@ -209,6 +217,7 @@ Most users only need to configure a few options: { model = 'gpt-4.1', -- AI model to use temperature = 0.1, -- Lower = focused, higher = creative + trusted_tools = nil, -- Require approval for all tool calls window = { layout = 'vertical', -- 'vertical', 'horizontal', 'float' width = 0.5, -- 50% of screen width @@ -241,12 +250,14 @@ Most users only need to configure a few options: } ``` +`window.layout` also supports `'replace'` to reuse the current window. + ## Buffer Behavior ```lua -- Auto-command to customize chat buffer behavior vim.api.nvim_create_autocmd('BufEnter', { - pattern = 'copilot-*', + pattern = 'copilot-chat', callback = function() vim.opt_local.relativenumber = false vim.opt_local.number = false @@ -278,6 +289,7 @@ Types of copilot highlights: - `CopilotChatModel` - Model highlight in chat buffer (e.g. `$gpt-4.1`) - `CopilotChatUri` - URI highlight in chat buffer (e.g. `##https://...`) - `CopilotChatAnnotation` - Annotation highlight in chat buffer (file headers, tool call headers, tool call body) +- `CopilotChatAnnotationHeader` - Annotation header highlight in chat buffer ## Prompts @@ -304,14 +316,44 @@ Define your own prompts in the configuration: ## Functions +Use `trusted_tools` to control which tool calls are executed automatically: + +```lua +{ + trusted_tools = nil, -- default: require approval for all tool calls + + -- trust all functions in a group + -- trusted_tools = 'copilot', + + -- trust specific functions by name or groups by name + -- trusted_tools = { 'file', 'glob', 'grep' }, + + -- trust every enabled tool call + -- trusted_tools = true, +} +``` + +A tool is trusted when any of these match: + +- Its function definition sets `trusted = true` +- Its function name appears in `trusted_tools` +- Its function group appears in `trusted_tools` +- `trusted_tools = true` + +For most setups, trusting a few read-only functions such as `file`, `glob`, or `grep` is safer than trusting everything. + +> [!WARNING] +> Trusted tools run without asking for confirmation. Be especially careful with tools like `bash` and `edit`, which can change your workspace. + Define your own functions in the configuration with input handling and schema: ```lua { functions = { birthday = { - description = "Retrieves birthday information for a person", - uri = "birthday://{name}", + description = 'Retrieves birthday information for a person', + uri = 'birthday://{name}', + trusted = false, schema = { type = 'object', required = { 'name' }, @@ -329,14 +371,16 @@ Define your own functions in the configuration with input handling and schema: uri = 'birthday://' .. input.name, mimetype = 'text/plain', data = input.name .. ' birthday info', - } + }, } - end - } + end, + }, } } ``` +If a function has a `uri`, it can be used manually with `#birthday:Alice`. Functions without a `uri` are tool-only and can only be called by the model. + ## Providers Add custom AI providers: @@ -345,9 +389,9 @@ Add custom AI providers: { providers = { my_provider = { - get_url = function(opts) return "https://api.example.com/chat" end, - get_headers = function() return { ["Authorization"] = "Bearer " .. api_key } end, - get_models = function() return { { id = "gpt-4.1", name = "GPT-4.1 model" } } end, + get_url = function(opts) return 'https://api.example.com/chat' end, + get_headers = function() return { ['Authorization'] = 'Bearer ' .. api_key } end, + get_models = function() return { { id = 'gpt-4.1', name = 'GPT-4.1 model' } } end, prepare_input = require('CopilotChat.config.providers').copilot.prepare_input, prepare_output = require('CopilotChat.config.providers').copilot.prepare_output, } @@ -363,7 +407,7 @@ Add custom AI providers: disabled?: boolean, -- Optional: Extra info about the provider displayed in info panel - get_info?(): string[] + get_info?(headers: table): string[] -- Optional: Get extra request headers with optional expiration time get_headers?(): table, number?, @@ -379,20 +423,23 @@ Add custom AI providers: -- Optional: Get available models get_models?(headers: table): table, + + -- Optional: Resolve a user-facing model id to a provider model id + resolve_model?(headers: table, model: string): string, } ``` **Built-in providers:** - `copilot` - GitHub Copilot (default) -- `github_models` - GitHub Marketplace models (disabled by default) +- `github_models` - GitHub Models (disabled by default) # API Reference ## Core ```lua -local chat = require("CopilotChat") +local chat = require('CopilotChat') -- Basic Chat Functions chat.ask(prompt, config) -- Ask a question with optional config @@ -422,7 +469,7 @@ chat.log_level(level) -- Set log level (debug, info, etc.) You can also access the chat window UI methods through the `chat.chat` object: ```lua -local window = require("CopilotChat").chat +local window = require('CopilotChat').chat -- Chat UI State window:visible() -- Check if chat window is visible @@ -441,8 +488,8 @@ window:start() -- Start writing to chat window window:finish() -- Finish writing to chat window -- Source Management -window.get_source() -- Get the current source buffer and window -window.set_source(winnr) -- Set the source window +window:get_source() -- Get the current source buffer and window +window:set_source(winnr) -- Set the source window -- Navigation window:follow() -- Move cursor to end of chat content @@ -455,10 +502,11 @@ window:overlay(opts) -- Show overlay with specified options ## Prompt parser ```lua -local parser = require("CopilotChat.prompts") +local parser = require('CopilotChat.prompts') parser.resolve_prompt() -- Resolve prompt references -parser.resolve_tools() -- Resolve tools that are available for automatic use by LLM +parser.resolve_tools() -- Resolve tools shared with the model via @... +parser.resolve_functions() -- Resolve manual function/resource references via #... parser.resolve_model() -- Resolve model from prompt (WARN: async, requires plenary.async.run) ``` @@ -466,22 +514,26 @@ parser.resolve_model() -- Resolve model from prompt (WARN: async, requi ```lua -- Open chat, ask a question and handle response -require("CopilotChat").open() -require("CopilotChat").ask("#buffer Explain this code", { +require('CopilotChat').open() +require('CopilotChat').ask('#buffer Explain this code', { callback = function(response) - vim.notify("Got response: " .. response:sub(1, 50) .. "...") - return response + vim.notify('Got response: ' .. vim.trim(response.content):sub(1, 50) .. '...') end, }) -- Save and load chat history -require("CopilotChat").save("my_debugging_session") -require("CopilotChat").load("my_debugging_session") +require('CopilotChat').save('my_debugging_session') +require('CopilotChat').load('my_debugging_session') -- Use custom sticky and model -require("CopilotChat").ask("How can I optimize this?", { - model = "gpt-4.1", - sticky = {"#buffer", "#gitdiff:staged"} +require('CopilotChat').ask('How can I optimize this?', { + model = 'gpt-4.1', + sticky = { '#buffer', '#gitdiff:staged' }, +}) + +-- Automatically trust a small read-only tool set +require('CopilotChat').setup({ + trusted_tools = { 'file', 'glob', 'grep' }, }) ``` @@ -512,6 +564,12 @@ To run tests: make test ``` +To run the same formatting check as CI: + +```bash +stylua --check . +``` + ## Contributing 1. Fork the repository diff --git a/lua/CopilotChat/config.lua b/lua/CopilotChat/config.lua index 1e0adda3..a54a66f3 100644 --- a/lua/CopilotChat/config.lua +++ b/lua/CopilotChat/config.lua @@ -19,6 +19,7 @@ ---@field tools string|table|nil ---@field resources string|table|nil ---@field sticky string|table|nil +---@field trusted_tools boolean|string|table|nil ---@field diff 'block'|'unified'? ---@field language string? ---@field temperature number? @@ -64,6 +65,7 @@ return { tools = nil, -- Default tool or array of tools (or groups) to share with LLM (can be specified manually in prompt via @). resources = 'selection', -- Default resources to share with LLM (can be specified manually in prompt via #). sticky = nil, -- Default sticky prompt or array of sticky prompts to use at start of every new chat (can be specified manually in prompt via >). + trusted_tools = nil, -- Trust tool calls from specific functions or groups, or all trusted tools when true (e.g., {'buffer', 'file'} or 'copilot'). diff = 'block', -- Default diff format to use, 'block' or 'unified'. language = 'English', -- Default language to use for answers diff --git a/lua/CopilotChat/config/functions.lua b/lua/CopilotChat/config/functions.lua index 991759dd..0ec22b10 100644 --- a/lua/CopilotChat/config/functions.lua +++ b/lua/CopilotChat/config/functions.lua @@ -44,6 +44,7 @@ end ---@field description string? ---@field schema table? ---@field group string? +---@field trusted boolean? ---@field uri string? ---@field resolve fun(input: table, source: CopilotChat.ui.chat.Source):CopilotChat.client.Resource[] diff --git a/lua/CopilotChat/config/mappings.lua b/lua/CopilotChat/config/mappings.lua index a72529f9..c3028f8a 100644 --- a/lua/CopilotChat/config/mappings.lua +++ b/lua/CopilotChat/config/mappings.lua @@ -290,7 +290,6 @@ return { async.run(function() local config, prompt = prompts.resolve_prompt(message.content) local system_prompt = config.system_prompt - local resolved_resources = prompts.resolve_functions(prompt, config) local selected_tools = prompts.resolve_tools(prompt, config) local selected_model = prompts.resolve_model(prompt, config) local infos = client:info() @@ -357,28 +356,6 @@ return { table.insert(lines, '') end - if not utils.empty(resolved_resources) then - table.insert(lines, '**Resources**') - table.insert(lines, '') - end - - for _, resource in ipairs(resolved_resources) do - local resource_lines = vim.split(resource.data, '\n') - local preview = vim.list_slice(resource_lines, 1, math.min(10, #resource_lines)) - local header = string.format('**%s** (%s lines)', resource.name or resource.uri, #resource_lines) - if #resource_lines > 10 then - header = header .. ' (truncated)' - end - - table.insert(lines, header) - table.insert(lines, '```' .. files.mimetype_to_filetype(resource.mimetype)) - for _, line in ipairs(preview) do - table.insert(lines, line) - end - table.insert(lines, '```') - table.insert(lines, '') - end - chat:overlay({ text = vim.trim(table.concat(lines, '\n')) .. '\n', }) diff --git a/lua/CopilotChat/init.lua b/lua/CopilotChat/init.lua index b5052845..f6a2cc3b 100644 --- a/lua/CopilotChat/init.lua +++ b/lua/CopilotChat/init.lua @@ -2,6 +2,7 @@ local async = require('plenary.async') local log = require('plenary.log') local client = require('CopilotChat.client') local constants = require('CopilotChat.constants') +local functions = require('CopilotChat.functions') local prompts = require('CopilotChat.prompts') local select = require('CopilotChat.select') local utils = require('CopilotChat.utils') @@ -30,6 +31,37 @@ local M = setmetatable({}, { end, }) +---@param config CopilotChat.config.Shared +---@param tool_name string +---@return boolean +local function is_trusted_tool(config, tool_name) + local tool_spec = config.functions[tool_name] + if not tool_spec then + return false + end + + if tool_spec.trusted then + return true + end + + local trusted_tools = config.trusted_tools + if trusted_tools == true then + return true + end + + for _, trusted_pattern in ipairs(utils.to_table(trusted_tools)) do + if tool_name == trusted_pattern then + return true + end + + if tool_spec.group == trusted_pattern then + return true + end + end + + return false +end + --- Process sticky values from prompt and config --- Extracts stickies from prompt, adds config-based stickies, stores them, returns clean prompt ---@param prompt string @@ -116,7 +148,7 @@ end --- Finish writing to chat buffer. ---@param start_of_chat boolean? -local function finish(start_of_chat) +local function finish(start_of_chat, remaining_tool_calls) if start_of_chat then local sticky = {} if M.config.sticky then @@ -128,8 +160,11 @@ local function finish(start_of_chat) end local prompt_content = '' - local assistant_message = M.chat:get_message(constants.ROLE.ASSISTANT) - local tool_calls = assistant_message and assistant_message.tool_calls or {} + local tool_calls = remaining_tool_calls + if not tool_calls then + local assistant_message = M.chat:get_message(constants.ROLE.ASSISTANT) + tool_calls = assistant_message and assistant_message.tool_calls or {} + end local current_sticky = M.chat:get_sticky() if not utils.empty(current_sticky) then @@ -430,16 +465,9 @@ function M.ask(prompt, config) config, prompt = prompts.resolve_prompt(prompt, config) local system_prompt = config.system_prompt or '' local selected_tools, prompt = prompts.resolve_tools(prompt, config) - local resolved_resources, resolved_tools, resolved_stickies, prompt = prompts.resolve_functions(prompt, config) + local resolved_resources, resolved_tools, prompt = prompts.resolve_functions(prompt, config) local selected_model, prompt = prompts.resolve_model(prompt, config) - -- Store resolved stickies to chat - local current_sticky = M.chat:get_sticky() - for _, sticky in ipairs(resolved_stickies) do - table.insert(current_sticky, sticky) - end - M.chat:set_sticky(current_sticky) - prompt = vim.trim(prompt) if not config.headless then @@ -547,6 +575,93 @@ function M.ask(prompt, config) M.chat.token_count = token_count M.chat.token_max_count = token_max_count + -- Execute trusted tool calls automatically + if response.tool_calls and #response.tool_calls > 0 then + local trusted_tool_calls = {} + local untrusted_tool_calls = {} + + for _, tool_call in ipairs(response.tool_calls) do + if is_trusted_tool(config, tool_call.name) then + table.insert(trusted_tool_calls, tool_call) + else + table.insert(untrusted_tool_calls, tool_call) + end + end + + if #trusted_tool_calls > 0 then + async.run(handle_error(config, function() + local trusted_tool_results = {} + local source = M.chat:get_source() + + for _, tool_call in ipairs(trusted_tool_calls) do + local input = {} + if not utils.empty(tool_call.arguments) then + input = utils.json_decode(tool_call.arguments) + end + + local ok, output = prompts.execute_tool_call(tool_call.name, input, config, source) + local result = prompts.format_tool_output(ok, output) + + table.insert(trusted_tool_results, { + id = tool_call.id, + result = result, + }) + end + + if not utils.empty(trusted_tool_results) then + utils.schedule_main() + for _, tool in ipairs(trusted_tool_results) do + M.chat:add_message({ + id = tool.id, + role = constants.ROLE.TOOL, + tool_call_id = tool.id, + content = '\n' .. tool.result .. '\n', + }) + end + + if #untrusted_tool_calls > 0 then + finish(nil, untrusted_tool_calls) + else + local continue_response = client:ask({ + headless = config.headless, + history = M.chat:get_messages(), + resources = resolved_resources, + tools = selected_tools, + system_prompt = system_prompt, + model = selected_model, + temperature = config.temperature, + on_progress = vim.schedule_wrap(function(message) + if not config.headless then + M.chat:add_message(message) + end + end), + }) + + if continue_response then + local continue_message = continue_response.message + continue_message.content = vim.trim(continue_message.content) + if utils.empty(continue_message.content) then + continue_message.content = '' + else + continue_message.content = '\n' .. continue_message.content .. '\n' + end + + utils.schedule_main() + M.chat:add_message(continue_message, true) + M.chat.token_count = continue_response.token_count + M.chat.token_max_count = continue_response.token_max_count + end + + finish() + end + else + finish() + end + end)) + return + end + end + finish() end end)) diff --git a/lua/CopilotChat/prompts.lua b/lua/CopilotChat/prompts.lua index 461a0770..f5c76525 100644 --- a/lua/CopilotChat/prompts.lua +++ b/lua/CopilotChat/prompts.lua @@ -91,10 +91,67 @@ function M.resolve_tools(prompt, config) return enabled_tools:values(), prompt end +--- Execute a tool call and return the raw output. +---@param name string Tool name +---@param input table|string Input arguments +---@param config CopilotChat.config.Shared +---@param source CopilotChat.client.Source +---@return boolean ok +---@return any output +---@async +function M.execute_tool_call(name, input, config, source) + local tool = config.functions[name] + if not tool or not tool.resolve then + return false, 'Tool not found: ' .. name + end + + local schema = nil + for _, t in ipairs(functions.parse_tools(config.functions)) do + if t.name == name then + schema = t.schema + break + end + end + + local ok, output + if config.stop_on_function_failure then + output = tool.resolve(functions.parse_input(input, schema), source) + ok = true + else + ok, output = pcall(tool.resolve, functions.parse_input(input, schema), source) + end + + return ok, output +end + +--- Format tool output as plain text. +---@param ok boolean +---@param output any +---@return string +function M.format_tool_output(ok, output) + local result = '' + if not ok then + result = utils.make_string(output) + elseif type(output) ~= 'table' then + result = utils.make_string(output) + else + for _, content in ipairs(output) do + if content then + local data = content.data or content.uri + if data then + result = result .. (utils.empty(result) and '' or '\n') .. data + end + end + end + end + + return result +end + --- Call and resolve function calls from the prompt. ---@param prompt string? ---@param config CopilotChat.config.Shared? ----@return table, table, table, string +---@return table, table, string ---@async function M.resolve_functions(prompt, config) config, prompt = M.resolve_prompt(prompt, config) @@ -102,11 +159,6 @@ function M.resolve_functions(prompt, config) local chat = require('CopilotChat').chat local source = chat:get_source() - local tools = {} - for _, tool in ipairs(functions.parse_tools(config.functions)) do - tools[tool.name] = tool - end - if config.resources then local resources = utils.to_table(config.resources) local lines = utils.split_lines(prompt) @@ -119,7 +171,6 @@ function M.resolve_functions(prompt, config) local resolved_resources = {} local resolved_tools = {} - local resolved_stickies = {} local tool_calls = {} utils.schedule_main() @@ -199,58 +250,51 @@ function M.resolve_functions(prompt, config) return nil end - local schema = tools[name] and tools[name].schema or nil - local ok, output - if config.stop_on_function_failure then - output = tool.resolve(functions.parse_input(input, schema), source) - ok = true - else - ok, output = pcall(tool.resolve, functions.parse_input(input, schema), source) + local ok, output = M.execute_tool_call(name, input, config, source) + + if tool_id then + table.insert(resolved_tools, { + id = tool_id, + result = M.format_tool_output(ok, output), + }) + + return '' end - local result = '' if not ok then - result = utils.make_string(output) - else - for _, content in ipairs(output) do - if content then - local content_out = nil - if content.uri then - if - not vim.tbl_contains(resolved_resources, function(resource) - return resource.uri == content.uri - end, { predicate = true }) - then - content_out = '##' .. content.uri - table.insert(resolved_resources, content) - end - - if tool_id then - table.insert(resolved_stickies, '##' .. content.uri) - end - else - content_out = content.data + return utils.make_string(output) + end + + if type(output) ~= 'table' then + return utils.make_string(output) + end + + local result = '' + for _, content in ipairs(output) do + if content then + local content_out = nil + if content.uri then + if + not vim.tbl_contains(resolved_resources, function(resource) + return resource.uri == content.uri + end, { predicate = true }) + then + content_out = '##' .. content.uri + table.insert(resolved_resources, content) end + else + content_out = content.data + end - if content_out then - if not utils.empty(result) then - result = result .. '\n' - end - result = result .. content_out + if content_out then + if not utils.empty(result) then + result = result .. '\n' end + result = result .. content_out end end end - if tool_id then - table.insert(resolved_tools, { - id = tool_id, - result = result, - }) - - return '' - end - return result end @@ -266,7 +310,7 @@ function M.resolve_functions(prompt, config) end end - return resolved_resources, resolved_tools, resolved_stickies, prompt + return resolved_resources, resolved_tools, prompt end --- Resolve the final prompt and config from prompt template. diff --git a/lua/CopilotChat/ui/chat.lua b/lua/CopilotChat/ui/chat.lua index 0f3b4871..f20d08b3 100644 --- a/lua/CopilotChat/ui/chat.lua +++ b/lua/CopilotChat/ui/chat.lua @@ -99,9 +99,9 @@ end ---@field content string ---@class CopilotChat.ui.chat.Section ----@field start_line number ----@field end_line number ----@field blocks table +---@field start_line integer +---@field end_line integer +---@field blocks CopilotChat.ui.chat.Block[] ---@class CopilotChat.ui.chat.Message : CopilotChat.client.Message ---@field id string? @@ -113,7 +113,7 @@ end --- @field cwd fun():string ---@class CopilotChat.ui.chat.Chat : CopilotChat.ui.overlay.Overlay ----@field winnr number? +---@field winnr integer? ---@field config CopilotChat.config.Shared ---@field token_count number? ---@field token_max_count number? @@ -125,7 +125,7 @@ end ---@field private chat_overlay CopilotChat.ui.overlay.Overlay ---@field private last_changedtick number? ---@field private source CopilotChat.ui.chat.Source ----@field private sticky table +---@field private sticky string[] local Chat = class(function(self, config, on_buf_create) Overlay.init(self, 'copilot-chat', utils.key_to_info('show_help', config.mappings.show_help), on_buf_create) @@ -243,7 +243,7 @@ function Chat:get_block(role, cursor) end --- Get list of all chat messages ----@return table +---@return CopilotChat.ui.chat.Message[] function Chat:get_messages() self:parse() return self.messages:values() @@ -269,7 +269,12 @@ function Chat:get_message(role, cursor) for _, message in ipairs(messages) do local section = message.section local matches_role = not role or message.role == role - if matches_role and section.start_line <= cursor_line and section.start_line > max_line_below_cursor then + if + matches_role + and section + and section.start_line <= cursor_line + and section.start_line > max_line_below_cursor + then max_line_below_cursor = section.start_line closest_message = message end @@ -288,13 +293,13 @@ function Chat:get_message(role, cursor) end --- Get the current sticky array. ----@return table +---@return string[] function Chat:get_sticky() return self.sticky end --- Set the sticky array. ----@param sticky table +---@param sticky string[] function Chat:set_sticky(sticky) self.sticky = sticky end diff --git a/lua/CopilotChat/ui/overlay.lua b/lua/CopilotChat/ui/overlay.lua index 32157685..ace646c4 100644 --- a/lua/CopilotChat/ui/overlay.lua +++ b/lua/CopilotChat/ui/overlay.lua @@ -2,7 +2,7 @@ local utils = require('CopilotChat.utils') local class = require('CopilotChat.utils.class') ---@class CopilotChat.ui.overlay.Overlay : Class ----@field bufnr number? +---@field bufnr integer? ---@field protected name string ---@field protected help string ---@field private cursor integer[]? @@ -23,11 +23,11 @@ end) --- Show the overlay buffer ---@param text string ----@param winnr number +---@param winnr integer ---@param filetype? string ---@param syntax string? ----@param on_show? fun(bufnr: number) ----@param on_hide? fun(bufnr: number) +---@param on_show? fun(bufnr: integer) +---@param on_hide? fun(bufnr: integer) function Overlay:show(text, winnr, filetype, syntax, on_show, on_hide) if not text or text == '' then return @@ -75,7 +75,7 @@ function Overlay:delete() end --- Create the overlay buffer ----@return number +---@return integer ---@protected function Overlay:create() local bufnr = vim.api.nvim_create_buf(false, true)