Skip to content

Commit f96a39b

Browse files
kkharjiConni2461
andauthored
feat(benchmark): add util to benchmark functions (#228)
* feat(benchmark): init * update interface and output * make it work with a single function * apply conni patch and update interface co-authored-by: Conni2461 <[email protected]> * enhance co-authored-by: Conni2461 <[email protected]> Co-authored-by: Conni2461 <[email protected]>
1 parent 6c80b83 commit f96a39b

3 files changed

Lines changed: 216 additions & 0 deletions

File tree

lua/plenary/benchmark/init.lua

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
local B = {}
2+
local stat = require "plenary.benchmark.stat"
3+
4+
local get_stats = function(results)
5+
local ret = {}
6+
7+
ret.max, ret.min = stat.maxmin(results)
8+
ret.mean = stat.mean(results)
9+
ret.median = stat.median(results)
10+
ret.std = stat.std_dev(results)
11+
12+
return ret
13+
end
14+
15+
local get_output = function(index, res, runs)
16+
-- divine with a sutable one / 1e3, 1e6, 1e9
17+
local time_types = { "ns", "μs", "ms" }
18+
19+
local get_leading = function(time)
20+
time = math.floor(time)
21+
local count = 0
22+
repeat
23+
time = math.floor(time / 10)
24+
count = count + 1
25+
until time <= 0
26+
return count
27+
end
28+
29+
local get_best_fmt = function(time)
30+
for _, v in ipairs(time_types) do
31+
if math.abs(time) < 1000.0 then
32+
return string.format("%s%3.1f %s", string.rep(" ", 3 - get_leading(time)), time, v)
33+
end
34+
time = time / 1000.0
35+
end
36+
return string.format("%.1f %s", time, "s")
37+
end
38+
39+
return string.format(
40+
"Benchmark #%d: '%s'\n Time(mean ± σ): %s ± %s\n Range(min … max): %s … %s %d runs\n",
41+
index,
42+
res.name,
43+
get_best_fmt(res.stats.mean),
44+
get_best_fmt(res.stats.std),
45+
get_best_fmt(res.stats.min),
46+
get_best_fmt(res.stats.max),
47+
runs
48+
)
49+
end
50+
51+
local get_summary = function(res)
52+
if #res == 1 then
53+
return ""
54+
end
55+
56+
local fastest_mean = math.huge
57+
local fastest_index = 1
58+
for i, benchmark in ipairs(res) do
59+
if benchmark.stats.mean < fastest_mean then
60+
fastest_mean = benchmark.stats.mean
61+
fastest_index = i
62+
end
63+
end
64+
65+
if fastest_mean == math.huge then
66+
return ""
67+
end
68+
69+
local output = {}
70+
local fastest = res[fastest_index].stats
71+
for i, benchmark in ipairs(res) do
72+
if i ~= fastest_index then
73+
local result = benchmark.stats
74+
local ratio = result.mean / fastest.mean
75+
76+
-- // https://en.wikipedia.org/wiki/Propagation_of_uncertainty#Example_formulas
77+
-- // Covariance asssumed to be 0, i.e. variables are assumed to be independent
78+
local ratio_std = ratio
79+
* math.sqrt(math.pow(result.std / result.mean, 2) + math.pow(fastest.std / fastest.mean, 2))
80+
81+
table.insert(output, string.format(" %.1f ± %.1f times faster than '%s'\n", ratio, ratio_std, benchmark.name))
82+
end
83+
end
84+
85+
return string.format("Summary\n '%s' ran\n%s", res[fastest_index].name, table.concat(output, ""))
86+
end
87+
88+
---@class benchmark_run_opts
89+
---@field warmup number @number of initial runs before starting to track time.
90+
---@field runs number @number of runs to make
91+
---@field fun table<array<string, function>> @functions to execute
92+
93+
---Benchmark a function
94+
---@param name string @benchmark name
95+
---@param opts benchmark_run_opts
96+
local bench = function(name, opts)
97+
vim.validate {
98+
opts = { opts, "table" },
99+
fun = { opts.fun, "table" },
100+
}
101+
opts.warmup = vim.F.if_nil(opts.warmup, 3)
102+
opts.runs = vim.F.if_nil(opts.runs, 5)
103+
104+
opts.fun = type(opts.fun) == "function" and { opts.fun } or opts.fun
105+
local output = { string.format("Benchmark Group: '%s' -----------------------\n", name) }
106+
local res = {}
107+
for i, fun in ipairs(opts.fun) do
108+
res[i] = { name = fun[1], results = {} }
109+
for _ = 1, opts.warmup do
110+
fun[2]()
111+
end
112+
for j = 1, opts.runs do
113+
local start = vim.loop.hrtime()
114+
fun[2]()
115+
res[i].results[j] = vim.loop.hrtime() - start
116+
end
117+
res[i].stats = get_stats(res[i].results)
118+
table.insert(output, get_output(i, res[i], opts.runs))
119+
end
120+
121+
print(string.format("%s\n%s", table.concat(output, ""), get_summary(res)))
122+
123+
return res
124+
end
125+
126+
return bench

lua/plenary/benchmark/stat.lua

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
local stat = {}
2+
3+
---Calculate mean
4+
---@param t number[] @double
5+
---@return number @double
6+
stat.mean = function(t)
7+
local sum = 0
8+
local count = 0
9+
10+
for _, v in pairs(t) do
11+
if type(v) == "number" then
12+
sum = sum + v
13+
count = count + 1
14+
end
15+
end
16+
17+
return (sum / count)
18+
end
19+
20+
-- Get the median of a table.
21+
---@param t number[]
22+
---@return number
23+
stat.median = function(t)
24+
local temp = {}
25+
26+
-- deep copy table so that when we sort it, the original is unchanged
27+
-- also weed out any non numbers
28+
for _, v in pairs(t) do
29+
if type(v) == "number" then
30+
table.insert(temp, v)
31+
end
32+
end
33+
34+
table.sort(temp)
35+
36+
-- If we have an even number of table elements or odd.
37+
if math.fmod(#temp, 2) == 0 then
38+
-- return mean value of middle two elements
39+
return (temp[#temp / 2] + temp[(#temp / 2) + 1]) / 2
40+
else
41+
-- return middle element
42+
return temp[math.ceil(#temp / 2)]
43+
end
44+
end
45+
46+
--- Get the standard deviation of a table
47+
---@param t number[]
48+
stat.std_dev = function(t)
49+
local m, vm, result
50+
local sum = 0
51+
local count = 0
52+
53+
m = stat.mean(t)
54+
55+
for _, v in pairs(t) do
56+
if type(v) == "number" then
57+
vm = v - m
58+
sum = sum + (vm * vm)
59+
count = count + 1
60+
end
61+
end
62+
63+
result = math.sqrt(sum / (count - 1))
64+
65+
return result
66+
end
67+
68+
---Get the max and min for a table
69+
---@param t number[]
70+
---@return number
71+
---@return number
72+
stat.maxmin = function(t)
73+
local max = -math.huge
74+
local min = math.huge
75+
76+
for _, v in pairs(t) do
77+
if type(v) == "number" then
78+
max = math.max(max, v)
79+
min = math.min(min, v)
80+
end
81+
end
82+
83+
return max, min
84+
end
85+
86+
return stat

lua/plenary/test_harness.lua

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,10 @@ function harness.test_directory(directory, opts)
6161
if res.border_win_id then
6262
vim.api.nvim_win_set_option(res.border_win_id, "winhl", "Normal:Normal")
6363
end
64+
65+
if res.bufnr then
66+
vim.api.nvim_buf_set_option(res.bufnr, "filetype", "PlenaryTestPopup")
67+
end
6468
vim.cmd "mode"
6569
end
6670

0 commit comments

Comments
 (0)