Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions assert/assertions.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,11 @@ func ObjectsAreEqualValues(expected, actual interface{}) bool {
}
expectedValue := reflect.ValueOf(expected)
if expectedValue.IsValid() && expectedValue.Type().ConvertibleTo(actualType) {
if isEqualler(actualType) {
res := reflect.ValueOf(actual).MethodByName(equalMethod).Call([]reflect.Value{expectedValue})
return res[0].Bool()
}

// Attempt comparison after type conversion
return reflect.DeepEqual(expectedValue.Convert(actualType).Interface(), actual)
}
Expand Down
15 changes: 15 additions & 0 deletions assert/assertions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,22 @@ func TestObjectsAreEqual(t *testing.T) {
if ObjectsAreEqualValues(nil, 0) {
t.Fail()
}
}

func TestObjectsAreEqualValuesByEqualMethod(t *testing.T) {
f1a := Foo{1, 1}
f1b := Foo{1, 2}
f2 := Foo{2, 3}

if !ObjectsAreEqualValues(f1a, f1a) {
t.Fail()
}
if !ObjectsAreEqualValues(f1a, f1b) {
t.Fail()
}
if ObjectsAreEqualValues(f1a, f2) {
t.Fail()
}
}

func TestImplements(t *testing.T) {
Expand Down
59 changes: 59 additions & 0 deletions assert/isequaller.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package assert

import (
"reflect"
"sync"
)

const equalMethod = "Equal"

var equallerCacheMu sync.RWMutex
var equallerCache map[reflect.Type]bool

func init() {
equallerCache = make(map[reflect.Type]bool, 0)
}

func isEqualler(t reflect.Type) bool {
isEqualler, cached := isEquallerCached(t)
if !cached {
isEqualler = determineIsEqualler(t)
setIsEquallerCached(t, isEqualler)
}

return isEqualler
}

func determineIsEqualler(t reflect.Type) bool {
equalMethod, hasEqualMethod := t.MethodByName(equalMethod)
if hasEqualMethod {
// should have only 1 return value which should be a bool
// and should have exactly 2 arguments (pointer method so first is self)
// of which the 2nd argument should also be of its own type
if equalMethod.Type.NumOut() != 1 || equalMethod.Type.Out(0).Kind() != reflect.Bool {
return false
} else if equalMethod.Type.NumIn() != 2 || !t.ConvertibleTo(equalMethod.Type.In(1)) {
return false
} else {
return true
}
}

return false
}

func isEquallerCached(t reflect.Type) (bool, bool) {
equallerCacheMu.RLock()
defer equallerCacheMu.RUnlock()

isEqualler, cached := equallerCache[t]

return isEqualler, cached
}

func setIsEquallerCached(t reflect.Type, isEqualler bool) {
equallerCacheMu.Lock()
defer equallerCacheMu.Unlock()

equallerCache[t] = isEqualler
}
138 changes: 138 additions & 0 deletions assert/isequaller_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
package assert

import (
"reflect"
"testing"
)

type Foo struct {
id int
ignore int
}

func (f Foo) Equal(other Foo) bool {
return other.id == f.id
}

func TestDetermineIsEqualler(t *testing.T) {
fooT := reflect.TypeOf(Foo{})
fooTIsEqualler := determineIsEqualler(fooT)
if !fooTIsEqualler {
t.Errorf("Foo should be isEqualler")
}

fooPtrT := reflect.TypeOf(&Foo{})
fooPtrTIsEqualler := determineIsEqualler(fooPtrT)
if fooPtrTIsEqualler {
t.Errorf("*Foo should not be isEqualler")
}
}

type FooNoEq struct {
id int
}

func TestDetermineIsEquallerNoMethod(t *testing.T) {
fooT := reflect.TypeOf(FooNoEq{})
fooTIsEqualler := determineIsEqualler(fooT)
if fooTIsEqualler {
t.Errorf("FooNoEq doesn't have Equal method, shouldn't be isEqualler")
}
}

type FooFunkyEqReturn struct {
id int
}

func (f FooFunkyEqReturn) Equal(other FooFunkyEqReturn) (bool, bool) {
return other.id == f.id, true
}

func TestDetermineIsEquallerFunkyEqReturn(t *testing.T) {
fooT := reflect.TypeOf(FooFunkyEqReturn{})
fooTIsEqualler := determineIsEqualler(fooT)
if fooTIsEqualler {
t.Errorf("FooFunkyEqReturn has a weird return value for Equal, shouldn't be isEqualler")
}
}

type FooFunkyEqArg struct {
id int
}

func (f FooFunkyEqArg) Equal(other Foo) (bool, bool) {
return other.id == f.id, true
}

func TestDetermineIsEquallerFunkyEqArg(t *testing.T) {
fooT := reflect.TypeOf(FooFunkyEqArg{})
fooTIsEqualler := determineIsEqualler(fooT)
if fooTIsEqualler {
t.Errorf("FooFunkyEqArg has a weird argument value for Equal, shouldn't be isEqualler")
}
}

func TestIsEquallerCache(t *testing.T) {
fooT := reflect.TypeOf(Foo{})
fooPtrT := reflect.TypeOf(&Foo{})

// reset cache
equallerCache = make(map[reflect.Type]bool, 0)

if _, isCached := isEquallerCached(fooT); isCached {
t.Errorf("Foo shouldn't be cached yet")
}
if _, isCached := isEquallerCached(fooPtrT); isCached {
t.Errorf("*Foo shouldn't be cached yet")
}

setIsEquallerCached(fooT, true)

if isEqualler, isCached := isEquallerCached(fooT); !isCached && !isEqualler {
t.Errorf("Foo should be cached and true")
}
if _, isCached := isEquallerCached(fooPtrT); isCached {
t.Errorf("*Foo shouldn't be cached yet")
}

setIsEquallerCached(fooPtrT, false)

if isEqualler, isCached := isEquallerCached(fooT); !isCached && !isEqualler {
t.Errorf("Foo should be cached and true")
}
if isEqualler, isCached := isEquallerCached(fooPtrT); isCached && isEqualler {
t.Errorf("*Foo should be cached and false")
}

}

// the tests for determineIsEqualler should cover most cases, here we just test we are using the cache
func TestIsEqualler(t *testing.T) {
fooT := reflect.TypeOf(Foo{})
fooPtrT := reflect.TypeOf(&Foo{})

// reset cache
equallerCache = make(map[reflect.Type]bool, 0)

setIsEquallerCached(fooT, true)

if !isEqualler(fooT) {
t.Errorf("Foo should be cached and true")
}

setIsEquallerCached(fooPtrT, true)

if !isEqualler(fooPtrT) {
t.Errorf("*Foo should be cached and true")
}

setIsEquallerCached(fooPtrT, false)

if !isEqualler(fooT) {
t.Errorf("Foo should be cached and true")
}
if isEqualler(fooPtrT) {
t.Errorf("*Foo should be cached and false")
}

}