Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 28 additions & 15 deletions higher_order.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,32 +101,44 @@ func init() {
higherOrderForms[tryIndexFormName] = formTryIndex
}

// iterItems evaluates collExpr and converts the result to a []any for
// itemSeq is an indexable view over a list-shaped collection. Typed
// slices stay behind a reflect.Value and box elements lazily in at(),
// so forms never materialize an intermediate []any just to iterate.
type itemSeq struct {
items []any // set when the collection already is a []any
rv reflect.Value // used otherwise (typed slices and arrays)
n int
}

func (s itemSeq) at(i int) any {
if s.items != nil {
return s.items[i]
}
return s.rv.Index(i).Interface()
}

// iterItems evaluates collExpr and wraps the result in an itemSeq for
// predicate iteration. nil is treated as an empty list so
// `map(nil, it)` / `filter(nil, it > 0)` return empty without error.
// Maps and other non-list shapes return a user-friendly error naming
// the form, so users do not have to guess which argument was wrong.
func (p *Program) iterItems(ctx context.Context, name string, collExpr ast.Expr, env any, depth int) ([]any, error) {
func (p *Program) iterItems(ctx context.Context, name string, collExpr ast.Expr, env any, depth int) (itemSeq, error) {
coll, err := p.eval(ctx, collExpr, env, depth)
if err != nil {
return nil, err
return itemSeq{}, err
}
if coll == nil {
return nil, nil
return itemSeq{}, nil
}
if s, ok := coll.([]any); ok {
return s, nil
return itemSeq{items: s, n: len(s)}, nil
}
rv := reflect.ValueOf(coll)
switch rv.Kind() {
case reflect.Slice, reflect.Array:
out := make([]any, rv.Len())
for i := 0; i < rv.Len(); i++ {
out[i] = rv.Index(i).Interface()
}
return out, nil
return itemSeq{rv: rv, n: rv.Len()}, nil
}
return nil, fmt.Errorf("%w: %s expects a list as its first argument, got %T",
return itemSeq{}, fmt.Errorf("%w: %s expects a list as its first argument, got %T",
ErrEvaluate, name, coll)
}

Expand All @@ -152,14 +164,15 @@ func checkFormArity(name string, got int) error {
func (p *Program) forEach(
ctx context.Context,
name string,
items []any,
items itemSeq,
predicate ast.Expr,
env any,
depth int,
body func(item any, result any) (stop bool, err error),
) error {
scope := &itEnv{parent: env}
for i, item := range items {
for i := 0; i < items.n; i++ {
item := items.at(i)
scope.it = item
scope.index = int64(i)
v, err := p.eval(ctx, predicate, scope, depth)
Expand Down Expand Up @@ -361,7 +374,7 @@ func formMap(p *Program, ctx context.Context, n *ast.CallExpr, env any, depth in
if err != nil {
return nil, err
}
out := make([]any, 0, len(items))
out := make([]any, 0, items.n)
err = p.forEach(ctx, "map", items, n.Args[1], env, depth, func(_ any, v any) (bool, error) {
out = append(out, v)
return false, nil
Expand All @@ -380,7 +393,7 @@ func formFilter(p *Program, ctx context.Context, n *ast.CallExpr, env any, depth
if err != nil {
return nil, err
}
out := make([]any, 0, len(items))
out := make([]any, 0, items.n)
err = p.forEach(ctx, "filter", items, n.Args[1], env, depth, func(item any, v any) (bool, error) {
if isTruthy(v) {
out = append(out, item)
Expand Down
119 changes: 119 additions & 0 deletions index_chain_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
package expr

import (
"strings"
"testing"

"github.com/deepnoodle-ai/expr/internal/require"
)

// These tests pin the reflect-space index-chain fast path
// (evalIndexChainRV) to the general path's behavior. Struct envs route
// ident-rooted chains like Grid[1][2] or Inner.Vals[0] through the fast
// path, so values and error messages must match what indexValue and
// selectField produce for the same shapes.

type indexChainEnv struct {
Grid [2][3]int
Items []string
Meta map[string]int
Mixed map[string]any
Inner struct {
Vals []int
}
Text string
}

func newIndexChainEnv() indexChainEnv {
env := indexChainEnv{
Grid: [2][3]int{{1, 2, 3}, {4, 5, 6}},
Items: []string{"a", "b", "c"},
Meta: map[string]int{"x": 7},
Mixed: map[string]any{"list": []any{int64(10), int64(20)}},
Text: "héllo",
}
env.Inner.Vals = []int{42, 43}
return env
}

func TestIndexChainStructEnv(t *testing.T) {
env := newIndexChainEnv()
cases := []struct {
src string
want any
}{
{src: `Grid[1][2]`, want: 6},
{src: `Grid[0][0]`, want: 1},
{src: `Items[1]`, want: "b"},
{src: `Meta["x"]`, want: 7},
{src: `Mixed["list"][1]`, want: int64(20)},
{src: `Inner.Vals[0]`, want: 42},
{src: `Text[1]`, want: "é"},
{src: `Items[Grid[0][0]]`, want: "b"},
}
for _, tc := range cases {
t.Run(tc.src, func(t *testing.T) {
got, err := evalExpr(t.Context(), tc.src, env)
require.NoError(t, err)
require.Equal(t, tc.want, got)
})
}
}

func TestIndexChainStructEnvErrors(t *testing.T) {
env := newIndexChainEnv()
cases := []struct {
src string
wantErr string
}{
{src: `Items[5]`, wantErr: "index 5 out of range [0, 3)"},
{src: `Grid[0][9]`, wantErr: "index 9 out of range [0, 3)"},
{src: `Items[-1]`, wantErr: "index -1 out of range"},
// Non-map[string]any maps format the missing key with %v,
// matching indexValue's general map branch.
{src: `Meta["nope"]`, wantErr: `key nope not found`},
{src: `Mixed[0]`, wantErr: "map index must be string, got int64"},
{src: `Grid[0]["x"]`, wantErr: "index must be integer"},
{src: `Meta["x"][0]`, wantErr: "cannot index int"},
{src: `Inner.Nope[0]`, wantErr: `field "Nope" not found`},
{src: `Text[99]`, wantErr: "index 99 out of range [0, 5)"},
}
for _, tc := range cases {
t.Run(tc.src, func(t *testing.T) {
_, err := evalExpr(t.Context(), tc.src, env)
require.Error(t, err)
require.ErrorIs(t, err, ErrEvaluate)
if !strings.Contains(err.Error(), tc.wantErr) {
t.Fatalf("error %q does not contain %q", err.Error(), tc.wantErr)
}
})
}
}

// The same expressions evaluated against a map env take the general
// path (single hops) or enter the chain with a boxed root (multi hop);
// results must agree with the struct-env fast path.
func TestIndexChainMapEnvParity(t *testing.T) {
structEnv := newIndexChainEnv()
mapEnv := map[string]any{
"Grid": structEnv.Grid,
"Items": structEnv.Items,
"Meta": structEnv.Meta,
"Mixed": structEnv.Mixed,
"Inner": structEnv.Inner,
"Text": structEnv.Text,
}
srcs := []string{
`Grid[1][2]`, `Items[1]`, `Meta["x"]`, `Mixed["list"][1]`,
`Inner.Vals[0]`, `Text[1]`, `Items[Grid[0][0]]`,
}
for _, src := range srcs {
t.Run(src, func(t *testing.T) {
fromStruct, err := evalExpr(t.Context(), src, structEnv)
require.NoError(t, err)
fromMap, err := evalExpr(t.Context(), src, mapEnv)
require.NoError(t, err)
require.Equal(t, fromStruct, fromMap)
})
}
}
3 changes: 1 addition & 2 deletions prepared.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,7 @@ func prepareFunc(name string, fn any) (*preparedFunc, error) {
p.paramTypes[i] = ft.In(p.paramOff + i)
}
if p.numOut == 2 {
errType := reflect.TypeOf((*error)(nil)).Elem()
if !ft.Out(1).Implements(errType) {
if !ft.Out(1).Implements(errValType) {
return nil, fmt.Errorf("function %q: second return must be error, got %v", name, ft.Out(1))
}
p.hasErrRet = true
Expand Down
Loading
Loading