Skip to content

Commit 9754a44

Browse files
dcramer and claude authored
feat: add tool call support and ToolCallScorer (#20)
## Summary This PR improves the vitest-evals library by: 1. Removing the misleading Factuality scorer (was just string comparison) 2. Redesigning the ToolCallScorer API for better separation of concerns 3. Fixing strict parameter comparison to be order-independent ## Changes ### ToolCallScorer API Redesign The scorer now cleanly separates tool matching from parameter matching: #### Configuration Options: - `ordered` (default: false) - Whether tools must be called in exact order - `requireAll` (default: true) - Whether all expected tools must be called - `allowExtras` (default: true) - Whether to allow additional tool calls - `params` (default: "strict") - How to match parameters: - `"strict"` - Deep equality (order-independent for objects) - `"fuzzy"` - Case-insensitive, subset matching, numeric tolerance - Custom function - Your own comparison logic #### Key Improvements: - Test data now defines WHAT tools are expected (via `expectedTools`) - Scorer config defines HOW to evaluate them - Clearer separation between tool-level and parameter-level concerns - More predictable defaults (strict matching) - Fixed JSON.stringify issue - strict comparison now properly handles object key order ### Example Usage: ```javascript // Define expected tools in test data describeEval("tool usage", { data: async () => [{ input: "Search for restaurants", expectedTools: [ { name: "search", arguments: { type: "restaurant" } }, { name: "filter", arguments: { cuisine: "italian" } } ] }], task: myTask, scorers: [ ToolCallScorer({ params: "fuzzy" }) // Flexible matching ] }); ``` ## Breaking Changes - Default parameter matching is now strict (was fuzzy) - `expectedTools` moved from scorer config to test data - Renamed options for clarity: - `requireAllTools` → `requireAll` - `allowExtraTools` → `allowExtras` - `strictArgs` → `params: "strict"` ## Migration Guide ```javascript // Old ToolCallScorer({ tools: [...], strictArgs: true, allowExtraTools: false }) // New // Tools go in test data's 
expectedTools ToolCallScorer({ params: "strict", // default now allowExtras: false }) ``` ## Commits 1. **Remove Factuality scorer and redesign ToolCallScorer API** - Separated data from config 2. **Fix extensibility issues in configuration** - Improved naming consistency 3. **Improve ToolCallScorer configuration API** - Better separation of concerns, clearer names 4. **Fix strict equality comparison** - Replaced JSON.stringify with proper deep equality that handles object key order --------- Co-authored-by: Claude <[email protected]>
1 parent 1c70cf3 commit 9754a44

8 files changed

Lines changed: 1650 additions & 218 deletions

File tree

README.md

Lines changed: 206 additions & 195 deletions
Large diffs are not rendered by default.

package.json

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,7 @@
55
"types": "./dist/index.d.ts",
66
"main": "./dist/index.js",
77
"module": "./dist/index.mjs",
8-
"files": [
9-
"dist"
10-
],
8+
"files": ["dist"],
119
"exports": {
1210
".": {
1311
"types": "./dist/index.d.ts",
@@ -18,6 +16,16 @@
1816
"types": "./dist/reporter.d.ts",
1917
"require": "./dist/reporter.js",
2018
"import": "./dist/reporter.mjs"
19+
},
20+
"./scorers": {
21+
"types": "./dist/scorers/index.d.ts",
22+
"require": "./dist/scorers/index.js",
23+
"import": "./dist/scorers/index.mjs"
24+
},
25+
"./scorers/toolCallScorer": {
26+
"types": "./dist/scorers/toolCallScorer.d.ts",
27+
"require": "./dist/scorers/toolCallScorer.js",
28+
"import": "./dist/scorers/toolCallScorer.mjs"
2129
}
2230
},
2331
"scripts": {

src/ai-sdk-integration.test.ts

Lines changed: 337 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,337 @@
1+
import {
2+
describeEval,
3+
ToolCallScorer,
4+
type TaskFn,
5+
type ScoreFn,
6+
type ToolCall,
7+
} from "./index";
8+
9+
// This file demonstrates how to integrate vitest-evals with the Vercel AI SDK
10+
// for evaluating LLM responses that use tool calls.
11+
12+
// To run this test:
13+
// 1. Install dependencies: npm install ai @ai-sdk/openai zod
14+
// 2. Set your OPENAI_API_KEY environment variable
15+
// 3. Uncomment the imports below
16+
17+
// import { generateText } from "ai";
18+
// import { openai } from "@ai-sdk/openai";
19+
// import { z } from "zod";
20+
21+
/**
22+
* Example task that uses the AI SDK with tools.
23+
* This demonstrates the recommended pattern for tracking tool calls.
24+
*/
25+
const weatherAssistantTask: TaskFn = async (input) => {
26+
// For testing purposes, we'll mock the AI SDK response
27+
// In real usage, uncomment the actual implementation below
28+
29+
// Mock implementation
30+
if (
31+
input.toLowerCase().includes("weather") &&
32+
input.toLowerCase().includes("seattle")
33+
) {
34+
return {
35+
result:
36+
"The weather in Seattle is currently 65°F and partly cloudy. It's a typical mild day in the Pacific Northwest.",
37+
toolCalls: [
38+
{
39+
id: "call_1234",
40+
name: "getWeather",
41+
arguments: { location: "Seattle", units: "fahrenheit" },
42+
result: { temperature: 65, condition: "partly cloudy" },
43+
status: "completed",
44+
type: "function",
45+
timestamp: Date.now(),
46+
duration_ms: 150,
47+
},
48+
],
49+
};
50+
}
51+
52+
if (
53+
input.toLowerCase().includes("weather") &&
54+
input.toLowerCase().includes("compare")
55+
) {
56+
const startTime = Date.now();
57+
return {
58+
result:
59+
"Seattle is 65°F and partly cloudy, while New York is 72°F and sunny. New York is warmer and has better weather today.",
60+
toolCalls: [
61+
{
62+
id: "call_5678",
63+
name: "getWeather",
64+
arguments: { location: "Seattle", units: "fahrenheit" },
65+
result: { temperature: 65, condition: "partly cloudy" },
66+
status: "completed",
67+
type: "function",
68+
timestamp: startTime,
69+
duration_ms: 120,
70+
},
71+
{
72+
id: "call_5679",
73+
name: "getWeather",
74+
arguments: { location: "New York", units: "fahrenheit" },
75+
result: { temperature: 72, condition: "sunny" },
76+
status: "completed",
77+
type: "function",
78+
timestamp: startTime + 130,
79+
duration_ms: 110,
80+
parent_id: "call_5678", // Indicates this was called after the first
81+
},
82+
],
83+
};
84+
}
85+
86+
return {
87+
result: "I can help you check the weather. Please specify a location.",
88+
toolCalls: [],
89+
};
90+
91+
/* Actual AI SDK implementation:
92+
93+
const { text, toolCalls, toolResults } = await generateText({
94+
model: openai("gpt-4"),
95+
prompt: input,
96+
tools: {
97+
getWeather: {
98+
description: "Get the current weather for a location",
99+
parameters: z.object({
100+
location: z.string().describe("The location to get weather for"),
101+
units: z.enum(["celsius", "fahrenheit"]).optional().describe("Temperature units")
102+
}),
103+
execute: async ({ location, units = "fahrenheit" }) => {
104+
// In real app, call weather API
105+
// For demo, return mock data
106+
const mockWeather = {
107+
Seattle: { temperature: 65, condition: "partly cloudy" },
108+
"New York": { temperature: 72, condition: "sunny" },
109+
London: { temperature: 18, condition: "rainy" }
110+
};
111+
112+
return mockWeather[location] || { temperature: 70, condition: "unknown" };
113+
}
114+
}
115+
},
116+
maxSteps: 3, // Allow multiple tool calls
117+
});
118+
119+
// Transform AI SDK format to our enhanced format
120+
const formattedToolCalls = toolCalls?.map((call, i) => {
121+
const result = toolResults?.[i];
122+
const hasError = result?.error !== undefined;
123+
124+
return {
125+
id: call.toolCallId,
126+
name: call.toolName,
127+
arguments: call.args,
128+
result: result?.result,
129+
error: hasError ? {
130+
message: result.error.message || 'Tool execution failed',
131+
details: result.error
132+
} : undefined,
133+
status: hasError ? 'failed' : 'completed',
134+
type: 'function',
135+
// Note: AI SDK doesn't provide timing info, but you could add it:
136+
// timestamp: Date.now(),
137+
// duration_ms: calculateDuration(call.startTime)
138+
};
139+
}) || [];
140+
141+
return {
142+
result: text,
143+
toolCalls: formattedToolCalls
144+
};
145+
146+
*/
147+
};
148+
149+
// Integration test demonstrating tool call evaluation
150+
describeEval("AI SDK Weather Assistant", {
151+
data: async () => [
152+
{
153+
input: "What's the weather like in Seattle?",
154+
expectedTools: [
155+
{ name: "getWeather", arguments: { location: "Seattle" } },
156+
],
157+
},
158+
{
159+
input: "Compare the weather between Seattle and New York",
160+
expectedTools: [{ name: "getWeather" }, { name: "getWeather" }], // Called twice, don't care about specific args
161+
},
162+
{
163+
input: "Tell me about the weather", // Vague request
164+
expectedTools: [], // Should not call tools without location
165+
},
166+
],
167+
task: weatherAssistantTask,
168+
scorers: [
169+
// Use the built-in ToolCallScorer with default strict matching
170+
ToolCallScorer(),
171+
172+
// Custom scorer for weather-specific validation
173+
async (opts) => {
174+
const toolCalls = opts.toolCalls || [];
175+
const input = opts.input.toLowerCase();
176+
177+
// Check if location mentioned in input appears in tool calls
178+
if (input.includes("seattle")) {
179+
const hasSeattleCall = toolCalls.some(
180+
(tc: ToolCall) =>
181+
tc.name === "getWeather" && tc.arguments?.location === "Seattle",
182+
);
183+
184+
if (!hasSeattleCall) {
185+
return {
186+
score: 0.0,
187+
metadata: {
188+
rationale: "Mentioned Seattle but didn't check Seattle weather",
189+
},
190+
};
191+
}
192+
}
193+
194+
return {
195+
score: 1.0,
196+
metadata: {
197+
rationale: "Weather locations correctly identified",
198+
},
199+
};
200+
},
201+
],
202+
threshold: 1.0,
203+
// Skip unless API key is configured
204+
skipIf: () => !process.env.OPENAI_API_KEY,
205+
});
206+
207+
// Example showing tool argument validation
208+
describeEval("Tool Argument Validation", {
209+
data: async () => [
210+
{
211+
input: "What's the weather in Seattle in Celsius?",
212+
expectedTools: [
213+
{
214+
name: "getWeather",
215+
arguments: { location: "Seattle", units: "celsius" },
216+
},
217+
],
218+
},
219+
],
220+
task: async (input) => {
221+
// Mock response with specific arguments
222+
return {
223+
result: "The weather in Seattle is 18°C and partly cloudy.",
224+
toolCalls: [
225+
{
226+
id: "call_9999",
227+
name: "getWeather",
228+
arguments: { location: "Seattle", units: "celsius" },
229+
result: { temperature: 18, condition: "partly cloudy" },
230+
status: "completed",
231+
type: "function",
232+
},
233+
],
234+
};
235+
},
236+
scorers: [
237+
ToolCallScorer({
238+
params: "strict", // Require exact parameter matching
239+
}),
240+
],
241+
threshold: 1.0,
242+
});
243+
244+
// Example with custom argument matching
245+
describeEval("Flexible Argument Matching", {
246+
data: async () => [
247+
{
248+
input: "Search for Italian restaurants nearby",
249+
expectedTools: [
250+
{
251+
name: "search_places",
252+
arguments: { type: "restaurant", cuisine: "italian" },
253+
},
254+
],
255+
},
256+
],
257+
task: async (input) => {
258+
return {
259+
result: "Found 5 Italian restaurants within 1 mile",
260+
toolCalls: [
261+
{
262+
name: "search_places",
263+
arguments: {
264+
type: "restaurant",
265+
cuisine: "Italian", // Different case
266+
radius: 1,
267+
units: "miles",
268+
},
269+
},
270+
],
271+
};
272+
},
273+
scorers: [
274+
ToolCallScorer({
275+
params: "fuzzy", // Handles case differences and extra arguments
276+
}),
277+
],
278+
threshold: 1.0,
279+
});
280+
281+
// Example: Scorer that checks for failed tool calls
282+
const NoFailedToolsScorer: ScoreFn = async (opts) => {
283+
const toolCalls = opts.toolCalls || [];
284+
const failedCalls = toolCalls.filter(
285+
(tc) => tc.status === "failed" || tc.error,
286+
);
287+
288+
if (failedCalls.length > 0) {
289+
return {
290+
score: 0.0,
291+
metadata: {
292+
rationale: `${failedCalls.length} tool call(s) failed: ${failedCalls
293+
.map((tc) => `${tc.name} - ${tc.error?.message || "unknown error"}`)
294+
.join(", ")}`,
295+
},
296+
};
297+
}
298+
299+
return {
300+
score: 1.0,
301+
metadata: {
302+
rationale: "All tool calls completed successfully",
303+
},
304+
};
305+
};
306+
307+
// Example: Scorer that checks tool execution time
308+
const PerformanceScorer: ScoreFn = async (opts) => {
309+
const toolCalls = opts.toolCalls || [];
310+
const slowCalls = toolCalls.filter(
311+
(tc) => tc.duration_ms && tc.duration_ms > 1000,
312+
);
313+
314+
if (slowCalls.length > 0) {
315+
return {
316+
score: 0.5,
317+
metadata: {
318+
rationale: `${slowCalls.length} tool call(s) were slow (>1s): ${slowCalls
319+
.map((tc) => `${tc.name} took ${tc.duration_ms}ms`)
320+
.join(", ")}`,
321+
},
322+
};
323+
}
324+
325+
const avgDuration =
326+
toolCalls
327+
.filter((tc) => tc.duration_ms)
328+
.reduce((sum, tc) => sum + (tc.duration_ms || 0), 0) / toolCalls.length ||
329+
0;
330+
331+
return {
332+
score: 1.0,
333+
metadata: {
334+
rationale: `All tools executed quickly (avg: ${avgDuration.toFixed(0)}ms)`,
335+
},
336+
};
337+
};

0 commit comments

Comments
 (0)