diff --git a/lua/opencode/services/messaging.lua b/lua/opencode/services/messaging.lua index 23c5c9fa..4bd433b5 100644 --- a/lua/opencode/services/messaging.lua +++ b/lua/opencode/services/messaging.lua @@ -83,6 +83,7 @@ M.send_message = Promise.async(function(prompt, opts) update_sent_message_count(-1) session_runtime.cancel():await() end) + :await() end) ---@param prompt string diff --git a/tests/unit/services_messaging_spec.lua b/tests/unit/services_messaging_spec.lua index cd8975f0..234398f7 100644 --- a/tests/unit/services_messaging_spec.lua +++ b/tests/unit/services_messaging_spec.lua @@ -78,7 +78,6 @@ describe('opencode.services.messaging', function() local count_before = state.user_message_count['sess1'] or 0 local count_during = nil - local count_after = nil local orig = state.api_client.create_message state.api_client.create_message = function(_, sid, params) @@ -90,12 +89,9 @@ describe('opencode.services.messaging', function() }) end - messaging.send_message('hello world') + messaging.send_message('hello world'):wait() - vim.wait(50, function() - count_after = state.user_message_count['sess1'] or 0 - return count_after == 0 - end) + local count_after = state.user_message_count['sess1'] or 0 assert.equal(0, count_before) assert.equal(1, count_during) @@ -111,7 +107,6 @@ describe('opencode.services.messaging', function() local count_before = state.user_message_count['sess1'] or 0 local count_during = nil - local count_after = nil local orig = state.api_client.create_message state.api_client.create_message = function(_, sid, params) @@ -120,14 +115,11 @@ describe('opencode.services.messaging', function() end local orig_cancel = session_runtime.cancel - stub(session_runtime, 'cancel') + stub(session_runtime, 'cancel').returns(Promise.new():resolve(nil)) - messaging.send_message('hello world') + messaging.send_message('hello world'):wait() - vim.wait(50, function() - count_after = state.user_message_count['sess1'] or 0 - return count_after == 0 - end) + local count_after = state.user_message_count['sess1'] or 0 assert.equal(0, count_before) assert.equal(1, count_during)