-
-
Notifications
You must be signed in to change notification settings - Fork 406
Expand file tree
/
Copy pathchange_adapter.lua
More file actions
264 lines (226 loc) · 7.39 KB
/
change_adapter.lua
File metadata and controls
264 lines (226 loc) · 7.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
local config = require("codecompanion.config")
local log = require("codecompanion.utils.log")
local utils = require("codecompanion.utils")
local M = {}
---Build the `vim.ui.select` options used by the adapter picker.
---The entry equal to `conditional` is prefixed with "* "; every other entry is
---prefixed with two spaces so all names start in the same column.
---@param prompt string The prompt to display
---@param conditional string The item to mark as current
---@return table
local function select_opts(prompt, conditional)
  return {
    prompt = prompt,
    kind = "codecompanion.nvim",
    format_item = function(item)
      if item == conditional then
        return "* " .. item
      end
      -- Same width as the "* " marker so unmarked items stay aligned
      return "  " .. item
    end,
  }
end
---Collect the adapter names (ACP and HTTP) that the adapter picker offers.
---Every configured adapter except the current one is sorted alphabetically,
---and the current adapter is then placed at the head of the list.
---@param current_adapter string The currently selected adapter
---@return table List of adapter names with current adapter first
function M.get_adapters_list(current_adapter)
  local merged =
    vim.tbl_deep_extend("force", {}, vim.deepcopy(config.adapters.acp), vim.deepcopy(config.adapters.http))

  -- Keys that live alongside the adapters but are not adapters themselves
  local excluded = { opts = true, acp = true, http = true }

  local names = {}
  for name in pairs(merged) do
    if not excluded[name] and name ~= current_adapter then
      names[#names + 1] = name
    end
  end

  table.sort(names)
  table.insert(names, 1, current_adapter)
  return names
end
---Get list of available models for an adapter
---
---Choices come from `adapter.schema.model.choices`, which may be a list of
---model id strings, a table keyed by model id, or a function returning one of
---those (called synchronously with `{ async = false }`). The current model
---(`schema.model.default`, called first if it is a function) is placed at the
---front of the result when it is found among the choices. All other entries
---are sorted by `formatted_name` or `id`; plain strings sort by their value.
---
---Side effect: model tables that lack an `id` field have one assigned in place
---(their table key), so the tables inside `choices` are mutated.
---@param adapter CodeCompanion.HTTPAdapter
---@return table|nil # nil when there are fewer than two models to choose from
function M.list_http_models(adapter)
local models = adapter.schema.model.choices
-- Check if we should show model choices or just the default
local show_choices = config.adapters
and config.adapters.http
and config.adapters.http.opts
and config.adapters.http.opts.show_model_choices
if not show_choices then
-- Only the default is offered. A single entry fails the count check below,
-- so the function effectively returns nil in this case.
models = { adapter.schema.model.default }
end
if type(models) == "function" then
-- When user explicitly wants to change models, force token creation
models = models(adapter, { async = false })
end
if not models or vim.tbl_count(models) < 2 then
return nil
end
-- Resolve the id of the model currently configured on the adapter
local current_model_id = adapter.schema.model.default
if type(current_model_id) == "function" then
current_model_id = current_model_id(adapter)
end
local current_model = nil
-- First look for the id among list-style (string) entries...
for _, model_str in ipairs(models) do
if model_str == current_model_id then
current_model = model_str
break
end
end
-- ...then fall back to a table entry keyed by that id
if not current_model and models[current_model_id] then
current_model = models[current_model_id]
-- If it's a table without an id, create one
if type(current_model) == "table" and not current_model.id then
current_model.id = current_model_id
end
end
-- Normalise every entry to either an id string or a model table carrying `id`.
-- vim.iter passes only the value for list-like tables (so `key` is the id
-- string and `value` is nil) and (key, value) pairs for other tables.
-- NOTE(review): in a purely list-like table, entries that are tables reach
-- this callback with `value == nil` and are therefore dropped; confirm intent.
local models_list = vim
.iter(models)
:map(function(key, value)
if type(key) == "string" and value == nil then
-- `models` is already a list
return key
end
if type(value) == "table" and not value.id then
value.id = key
end
return value
end)
-- Drop the current model; it is re-inserted at the front after sorting
:filter(function(model)
local model_id = type(model) == "table" and model.id or model
return model_id ~= current_model_id
end)
:totable()
-- Order by display name (`formatted_name`) when present, otherwise by id
table.sort(models_list, function(a, b)
local id_a = type(a) == "table" and (a.formatted_name or a.id) or a
local id_b = type(b) == "table" and (b.formatted_name or b.id) or b
return id_a < id_b
end)
if current_model then
table.insert(models_list, 1, current_model)
end
return models_list
end
---List available models for an ACP adapter.
---Returns nil, meaning there is nothing to choose, in any of these cases:
---  * model choices are disabled via `adapters.acp.opts.show_model_choices`;
---  * no connection is given;
---  * the connection reports no `availableModels` table;
---  * fewer than two models are available.
---@param connection CodeCompanion.ACP.Connection|nil
---@return table|nil The models table returned by `connection:get_models()`
function M.list_acp_models(connection)
  local show_choices = config.adapters
    and config.adapters.acp
    and config.adapters.acp.opts
    and config.adapters.acp.opts.show_model_choices
  if not show_choices then
    return nil
  end

  -- The chat's ACP connection can be nil (e.g. it is cleared before an adapter
  -- switch), so don't call a method on it blindly
  if not connection then
    return nil
  end

  local models = connection:get_models()
  -- vim.tbl_count raises on a non-table argument, so check the field first
  if not models or type(models.availableModels) ~= "table" then
    return nil
  end
  if vim.tbl_count(models.availableModels) < 2 then
    return nil
  end
  return models
end
---Refresh the chat's system message after the adapter has changed.
---Only acts when the configured system prompt is a function and the first chat
---message has the "system" role. In every other case the chat is left as is.
---@param chat CodeCompanion.Chat
function M.update_system_prompt(chat)
  local build_prompt = config.interactions.chat.opts.system_prompt
  if type(build_prompt) ~= "function" then
    return
  end

  local first_message = chat.messages[1]
  if not first_message or first_message.role ~= "system" then
    return
  end

  first_message.content = build_prompt(chat:make_system_prompt_context())
end
---Let the user pick a model for the chat's current adapter and apply it.
---
---How the choices are gathered depends on the adapter type:
---  * HTTP: the list from `M.list_http_models` (current model first when found).
---  * ACP: the agent's `availableModels`, as returned by `M.list_acp_models`.
---Nothing is shown when there is nothing to choose from, or when the adapter type
---is neither "http" nor "acp". The picked model's id goes to `chat:change_model`.
---@param chat CodeCompanion.Chat
---@return nil
function M.select_model(chat)
  local adapter_type = chat.adapter.type
  local current_model = nil
  local models_list = nil

  if adapter_type == "http" then
    ---@diagnostic disable-next-line: param-type-mismatch
    models_list = M.list_http_models(chat.adapter)
    if not models_list then
      return log:debug("No models to select for the HTTP adapter")
    end
    -- NOTE(review): list_http_models puts the current model first only when it
    -- finds it among the choices. Otherwise this is just the first model in
    -- sorted order, and it gets marked as current. Confirm whether that matters.
    current_model = models_list[1]
  end

  if adapter_type == "acp" then
    ---@diagnostic disable-next-line: param-type-mismatch
    local acp_models = M.list_acp_models(chat.acp_connection)
    models_list = acp_models and acp_models.availableModels or nil
    if not acp_models or not models_list then
      return log:debug("No models to select for the ACP adapter")
    end
    -- May be nil if the agent did not report one; then no entry is marked
    current_model = acp_models.currentModelId
  end

  if not models_list then
    return
  end

  -- Resolve an entry to its id. HTTP tables carry `id`, ACP tables carry
  -- `modelId`, and plain strings are already ids. Non-table values (including
  -- nil) are returned unchanged rather than indexed.
  local function get_model_id(model)
    if type(model) ~= "table" then
      return model
    end
    return model.id or model.modelId or model
  end

  -- Human-readable label for an entry, without the current-model marker
  local function get_display(model)
    if type(model) ~= "table" then
      return model
    end
    if adapter_type == "http" then
      return model.description or model.formatted_name or model.id or "Unknown"
    end
    -- ACP: fall back to the id when the agent sent no name, so the label is never nil
    local display = model.name or model.modelId or "Unknown"
    if model.description then
      display = string.format("%s - %s", display, model.description)
    end
    return display
  end

  local current_id = get_model_id(current_model)
  local opts = {
    prompt = "Select Model",
    kind = "codecompanion.nvim",
    format_item = function(model)
      -- Mark the current model; pad the others to the same width as "* "
      local prefix = get_model_id(model) == current_id and "* " or "  "
      return prefix .. get_display(model)
    end,
  }

  vim.ui.select(models_list, opts, function(selected_model)
    if not selected_model then
      return
    end
    chat:change_model({ model = get_model_id(selected_model) })
  end)
end
---Keymap entry point: let the user switch the chat's adapter, then offer a
---model choice for it. Switching is refused, with a warning, while
---`display.chat.show_settings` is enabled.
---@param chat CodeCompanion.Chat
---@return nil
function M.callback(chat)
  if config.display.chat.show_settings then
    return utils.notify("Adapter can't be changed when `display.chat.show_settings = true`", vim.log.levels.WARN)
  end

  local active_adapter = chat.adapter.name
  local choices = M.get_adapters_list(active_adapter)

  -- Runs once the chosen adapter is in place (immediately if it did not change)
  local function after_switch()
    -- Chats started from the prompt library may ignore the system prompt, so
    -- only refresh it when it is not being ignored
    if not chat.opts.ignore_system_prompt then
      M.update_system_prompt(chat)
    end
    return M.select_model(chat)
  end

  vim.ui.select(choices, select_opts("Select Adapter", active_adapter), function(picked)
    if not picked then
      return
    end
    if picked == active_adapter then
      return after_switch()
    end
    chat.acp_connection = nil -- Ensure we reset this
    chat:change_adapter(picked, after_switch)
  end)
end
return M