Skip to content

Commit 2de0ff7

Browse files
authored
chore: bump luassert to the latest master (#377)
1 parent 46e8bb9 commit 2de0ff7

10 files changed

Lines changed: 204 additions & 71 deletions

File tree

lua/luassert/assert.lua

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
local s = require 'say'
22
local astate = require 'luassert.state'
33
local util = require 'luassert.util'
4-
local unpack = require 'luassert.compatibility'.unpack
4+
local unpack = util.unpack
55
local obj -- the returned module table
66
local level_mt = {}
77

@@ -39,8 +39,7 @@ local __state_meta = {
3939
end
4040
end
4141

42-
local arguments = {...}
43-
arguments.n = select('#', ...) -- add argument count for trailing nils
42+
local arguments = util.make_arglist(...)
4443
local val, retargs = assertion.callback(self, arguments, util.errorlevel())
4544

4645
if not val == self.mod then
@@ -57,8 +56,7 @@ local __state_meta = {
5756
end
5857
return ...
5958
else
60-
local arguments = {...}
61-
arguments.n = select('#', ...)
59+
local arguments = util.make_arglist(...)
6260
self.tokens = {}
6361

6462
for _, key in ipairs(keys) do
@@ -135,25 +133,25 @@ obj = {
135133
set_parameter = function(self, name, value)
136134
astate.set_parameter(name, value)
137135
end,
138-
136+
139137
get_parameter = function(self, name)
140138
return astate.get_parameter(name)
141-
end,
142-
139+
end,
140+
143141
add_spy = function(self, spy)
144142
astate.add_spy(spy)
145143
end,
146-
144+
147145
snapshot = function(self)
148146
return astate.snapshot()
149147
end,
150-
148+
151149
level = function(self, level)
152150
return setmetatable({
153151
level = level
154152
}, level_mt)
155153
end,
156-
154+
157155
-- returns the level if a level-value, otherwise nil
158156
get_level = function(self, level)
159157
if getmetatable(level) ~= level_mt then

lua/luassert/assertions.lua

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,11 +142,15 @@ local function same(state, arguments, level)
142142
end
143143

144144
local function truthy(state, arguments, level)
145+
local argcnt = arguments.n
146+
assert(argcnt > 0, s("assertion.internal.argtolittle", { "truthy", 1, tostring(argcnt) }), level)
145147
set_failure_message(state, arguments[2])
146148
return arguments[1] ~= false and arguments[1] ~= nil
147149
end
148150

149151
local function falsy(state, arguments, level)
152+
local argcnt = arguments.n
153+
assert(argcnt > 0, s("assertion.internal.argtolittle", { "falsy", 1, tostring(argcnt) }), level)
150154
return not truthy(state, arguments, level)
151155
end
152156

lua/luassert/compatibility.lua

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
-- no longer needed, only for backward compatibility
2+
local unpack = require ("luassert.util").unpack
3+
14
return {
2-
unpack = table.unpack or unpack,
5+
unpack = function(...)
6+
print(debug.traceback("WARN: calling deprecated function 'luassert.compatibility.unpack' use 'luassert.util.unpack' instead"))
7+
return unpack(...)
8+
end
39
}

lua/luassert/formatters/init.lua

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
-- module will not return anything, only register formatters with the main assert engine
22
local assert = require('luassert.assert')
3+
local match = require('luassert.match')
4+
local util = require('luassert.util')
35

46
local colors = setmetatable({
57
none = function(c) return c end
@@ -201,6 +203,35 @@ local function fmt_thread(arg)
201203
end
202204
end
203205

206+
local function fmt_matcher(arg)
207+
if not match.is_matcher(arg) then
208+
return
209+
end
210+
local not_inverted = {
211+
[true] = "is.",
212+
[false] = "no.",
213+
}
214+
local args = {}
215+
for idx = 1, arg.arguments.n do
216+
table.insert(args, assert:format({ arg.arguments[idx], n = 1, })[1])
217+
end
218+
return string.format("(matcher) %s%s(%s)",
219+
not_inverted[arg.mod],
220+
tostring(arg.name),
221+
table.concat(args, ", "))
222+
end
223+
224+
local function fmt_arglist(arglist)
225+
if not util.is_arglist(arglist) then
226+
return
227+
end
228+
local formatted_vals = {}
229+
for idx = 1, arglist.n do
230+
table.insert(formatted_vals, assert:format({ arglist[idx], n = 1, })[1])
231+
end
232+
return "(values list) (" .. table.concat(formatted_vals, ", ") .. ")"
233+
end
234+
204235
assert:add_formatter(fmt_string)
205236
assert:add_formatter(fmt_number)
206237
assert:add_formatter(fmt_boolean)
@@ -209,6 +240,8 @@ assert:add_formatter(fmt_table)
209240
assert:add_formatter(fmt_function)
210241
assert:add_formatter(fmt_userdata)
211242
assert:add_formatter(fmt_thread)
243+
assert:add_formatter(fmt_matcher)
244+
assert:add_formatter(fmt_arglist)
212245
-- Set default table display depth for table formatter
213246
assert:set_parameter("TableFormatLevel", 3)
214247
assert:set_parameter("TableFormatShowRecursion", false)

lua/luassert/match.lua

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,17 +25,16 @@ local state_mt = {
2525
end
2626
end
2727

28-
local arguments = {...}
29-
arguments.n = select('#', ...) -- add argument count for trailing nils
28+
local arguments = util.make_arglist(...)
3029
local matches = matcher.callback(self, arguments, util.errorlevel())
3130
return setmetatable({
3231
name = matcher.name,
3332
mod = self.mod,
3433
callback = matches,
34+
arguments = arguments,
3535
}, matcher_mt)
3636
else
37-
local arguments = {...}
38-
arguments.n = select('#', ...) -- add argument count for trailing nils
37+
local arguments = util.make_arglist(...)
3938

4039
for _, key in ipairs(keys) do
4140
if namespace.modifier[key] then

lua/luassert/matchers/core.lua

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ local function matches(state, arguments, level)
6060
local pattern = arguments[1]
6161
local init = arguments[2]
6262
local plain = arguments[3]
63-
local stringtype = "string or object convertible to a string"
6463
assert(type(pattern) == "string", s("assertion.internal.badargtype", { 1, "matches", "string", type(arguments[1]) }), level)
6564
assert(init == nil or tonumber(init), s("assertion.internal.badargtype", { 2, "matches", "number", type(arguments[2]) }), level)
6665

lua/luassert/spy.lua

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,10 @@ local util = require('luassert.util')
55
-- Spy metatable
66
local spy_mt = {
77
__call = function(self, ...)
8-
local arguments = {...}
9-
arguments.n = select('#',...) -- add argument count for trailing nils
8+
local arguments = util.make_arglist(...)
109
table.insert(self.calls, util.copyargs(arguments))
1110
local function get_returns(...)
12-
local returnvals = {...}
13-
returnvals.n = select('#',...) -- add argument count for trailing nils
11+
local returnvals = util.make_arglist(...)
1412
table.insert(self.returnvals, util.copyargs(returnvals))
1513
return ...
1614
end
@@ -59,11 +57,27 @@ spy = {
5957
end,
6058

6159
called_with = function(self, args)
62-
return util.matchargs(self.calls, args) ~= nil
60+
local last_arglist = nil
61+
if #self.calls > 0 then
62+
last_arglist = self.calls[#self.calls].vals
63+
end
64+
local matching_arglists = util.matchargs(self.calls, args)
65+
if matching_arglists ~= nil then
66+
return true, matching_arglists.vals
67+
end
68+
return false, last_arglist
6369
end,
6470

6571
returned_with = function(self, args)
66-
return util.matchargs(self.returnvals, args) ~= nil
72+
local last_returnvallist = nil
73+
if #self.returnvals > 0 then
74+
last_returnvallist = self.returnvals[#self.returnvals].vals
75+
end
76+
local matching_returnvallists = util.matchargs(self.returnvals, args)
77+
if matching_returnvallists ~= nil then
78+
return true, matching_returnvallists.vals
79+
end
80+
return false, last_returnvallist
6781
end
6882
}, spy_mt)
6983
assert:add_spy(s) -- register with the current state
@@ -96,7 +110,12 @@ local function returned_with(state, arguments, level)
96110
local level = (level or 1) + 1
97111
local payload = rawget(state, "payload")
98112
if payload and payload.returned_with then
99-
return state.payload:returned_with(arguments)
113+
local assertion_holds, matching_or_last_returnvallist = state.payload:returned_with(arguments)
114+
local expected_returnvallist = util.shallowcopy(arguments)
115+
util.cleararglist(arguments)
116+
util.tinsert(arguments, 1, matching_or_last_returnvallist)
117+
util.tinsert(arguments, 2, expected_returnvallist)
118+
return assertion_holds
100119
else
101120
error("'returned_with' must be chained after 'spy(aspy)'", level)
102121
end
@@ -106,7 +125,12 @@ local function called_with(state, arguments, level)
106125
local level = (level or 1) + 1
107126
local payload = rawget(state, "payload")
108127
if payload and payload.called_with then
109-
return state.payload:called_with(arguments)
128+
local assertion_holds, matching_or_last_arglist = state.payload:called_with(arguments)
129+
local expected_arglist = util.shallowcopy(arguments)
130+
util.cleararglist(arguments)
131+
util.tinsert(arguments, 1, matching_or_last_arglist)
132+
util.tinsert(arguments, 2, expected_arglist)
133+
return assertion_holds
110134
else
111135
error("'called_with' must be chained after 'spy(aspy)'", level)
112136
end

lua/luassert/state.lua

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ state.revert = function(self)
3030
end
3131
end
3232
if getmetatable(self) ~= state_mt then error("Value provided is not a valid snapshot", 2) end
33-
33+
3434
if self.next then
3535
self.next:revert()
3636
end
@@ -52,7 +52,6 @@ end
5252
-- Creates a new snapshot.
5353
-- @return snapshot table
5454
state.snapshot = function()
55-
local s = current
5655
local new = setmetatable ({
5756
formatters = {},
5857
parameters = {},

lua/luassert/stub.lua

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
local assert = require 'luassert.assert'
33
local spy = require 'luassert.spy'
44
local util = require 'luassert.util'
5-
local unpack = require 'luassert.compatibility'.unpack
5+
local unpack = util.unpack
6+
local pack = util.pack
67

78
local stub = {}
89

@@ -12,22 +13,20 @@ function stub.new(object, key, ...)
1213
object = {}
1314
key = ""
1415
end
15-
local return_values_count = select("#", ...)
16-
local return_values = {...}
16+
local return_values = pack(...)
1717
assert(type(object) == "table" and key ~= nil, "stub.new(): Can only create stub on a table key, call with 2 params; table, key", util.errorlevel())
1818
assert(object[key] == nil or util.callable(object[key]), "stub.new(): The element for which to create a stub must either be callable, or be nil", util.errorlevel())
1919
local old_elem = object[key] -- keep existing element (might be nil!)
2020

21-
local fn = (return_values_count == 1 and util.callable(return_values[1]) and return_values[1])
21+
local fn = (return_values.n == 1 and util.callable(return_values[1]) and return_values[1])
2222
local defaultfunc = fn or function()
23-
return unpack(return_values, 1, return_values_count)
23+
return unpack(return_values)
2424
end
2525
local oncalls = {}
2626
local callbacks = {}
2727
local stubfunc = function(...)
28-
local args = {...}
29-
args.n = select('#', ...)
30-
local match = util.matchargs(oncalls, args)
28+
local args = util.make_arglist(...)
29+
local match = util.matchoncalls(oncalls, args)
3130
if match then
3231
return callbacks[match](...)
3332
end
@@ -48,10 +47,9 @@ function stub.new(object, key, ...)
4847
end
4948

5049
s.returns = function(...)
51-
local return_args = {...}
52-
local n = select('#', ...)
50+
local return_args = pack(...)
5351
defaultfunc = function()
54-
return unpack(return_args, 1, n)
52+
return unpack(return_args)
5553
end
5654
return s
5755
end
@@ -69,16 +67,14 @@ function stub.new(object, key, ...)
6967
}
7068

7169
s.on_call_with = function(...)
72-
local match_args = {...}
73-
match_args.n = select('#', ...)
70+
local match_args = util.make_arglist(...)
7471
match_args = util.copyargs(match_args)
7572
return {
7673
returns = function(...)
77-
local return_args = {...}
78-
local n = select('#', ...)
74+
local return_args = pack(...)
7975
table.insert(oncalls, match_args)
8076
callbacks[match_args] = function()
81-
return unpack(return_args, 1, n)
77+
return unpack(return_args)
8278
end
8379
return s
8480
end,

0 commit comments

Comments
 (0)