From 89d00405f97115e9322c0f7502427cae004010ec Mon Sep 17 00:00:00 2001 From: Ruben de Vries Date: Mon, 28 Jun 2021 13:37:33 +0200 Subject: [PATCH] when using EqualValues for 2 objects which are of the same type and implement an `(T) Equal(T) bool` method, use the method to compare. --- assert/assertions.go | 5 ++ assert/assertions_test.go | 15 +++++ assert/isequaller.go | 59 ++++++++++++++++ assert/isequaller_test.go | 138 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 217 insertions(+) create mode 100644 assert/isequaller.go create mode 100644 assert/isequaller_test.go diff --git a/assert/assertions.go b/assert/assertions.go index bcac4401f..64ffab020 100644 --- a/assert/assertions.go +++ b/assert/assertions.go @@ -88,6 +88,11 @@ func ObjectsAreEqualValues(expected, actual interface{}) bool { } expectedValue := reflect.ValueOf(expected) if expectedValue.IsValid() && expectedValue.Type().ConvertibleTo(actualType) { + if isEqualler(actualType) { + res := reflect.ValueOf(actual).MethodByName(equalMethod).Call([]reflect.Value{expectedValue}) + return res[0].Bool() + } + // Attempt comparison after type conversion return reflect.DeepEqual(expectedValue.Convert(actualType).Interface(), actual) } diff --git a/assert/assertions_test.go b/assert/assertions_test.go index f32362b8e..d1faa99cd 100644 --- a/assert/assertions_test.go +++ b/assert/assertions_test.go @@ -145,7 +145,22 @@ func TestObjectsAreEqual(t *testing.T) { if ObjectsAreEqualValues(nil, 0) { t.Fail() } +} + +func TestObjectsAreEqualValuesByEqualMethod(t *testing.T) { + f1a := Foo{1, 1} + f1b := Foo{1, 2} + f2 := Foo{2, 3} + if !ObjectsAreEqualValues(f1a, f1a) { + t.Fail() + } + if !ObjectsAreEqualValues(f1a, f1b) { + t.Fail() + } + if ObjectsAreEqualValues(f1a, f2) { + t.Fail() + } } func TestImplements(t *testing.T) { diff --git a/assert/isequaller.go b/assert/isequaller.go new file mode 100644 index 000000000..48c582686 --- /dev/null +++ b/assert/isequaller.go @@ -0,0 +1,59 @@ +package assert + +import ( + "reflect" + "sync" +) + +const equalMethod = "Equal" + +var equallerCacheMu sync.RWMutex +var equallerCache map[reflect.Type]bool + +func init() { + equallerCache = make(map[reflect.Type]bool, 0) +} + +func isEqualler(t reflect.Type) bool { + isEqualler, cached := isEquallerCached(t) + if !cached { + isEqualler = determineIsEqualler(t) + setIsEquallerCached(t, isEqualler) + } + + return isEqualler +} + +func determineIsEqualler(t reflect.Type) bool { + equalMethod, hasEqualMethod := t.MethodByName(equalMethod) + if hasEqualMethod { + // should have only 1 return value which should be a bool + // and should have exactly 2 arguments (pointer method so first is self) + // of which the 2nd argument should also be of its own type + if equalMethod.Type.NumOut() != 1 || equalMethod.Type.Out(0).Kind() != reflect.Bool { + return false + } else if equalMethod.Type.NumIn() != 2 || !t.ConvertibleTo(equalMethod.Type.In(1)) { + return false + } else { + return true + } + } + + return false +} + +func isEquallerCached(t reflect.Type) (bool, bool) { + equallerCacheMu.RLock() + defer equallerCacheMu.RUnlock() + + isEqualler, cached := equallerCache[t] + + return isEqualler, cached +} + +func setIsEquallerCached(t reflect.Type, isEqualler bool) { + equallerCacheMu.Lock() + defer equallerCacheMu.Unlock() + + equallerCache[t] = isEqualler +} diff --git a/assert/isequaller_test.go b/assert/isequaller_test.go new file mode 100644 index 000000000..5ae5a0904 --- /dev/null +++ b/assert/isequaller_test.go @@ -0,0 +1,138 @@ +package assert + +import ( + "reflect" + "testing" +) + +type Foo struct { + id int + ignore int +} + +func (f Foo) Equal(other Foo) bool { + return other.id == f.id +} + +func TestDetermineIsEqualler(t *testing.T) { + fooT := reflect.TypeOf(Foo{}) + fooTIsEqualler := determineIsEqualler(fooT) + if !fooTIsEqualler { + t.Errorf("Foo should be isEqualler") + } + + fooPtrT := reflect.TypeOf(&Foo{}) + fooPtrTIsEqualler := determineIsEqualler(fooPtrT) + if fooPtrTIsEqualler { + t.Errorf("*Foo should not be isEqualler") + } +} + +type FooNoEq struct { + id int +} + +func TestDetermineIsEquallerNoMethod(t *testing.T) { + fooT := reflect.TypeOf(FooNoEq{}) + fooTIsEqualler := determineIsEqualler(fooT) + if fooTIsEqualler { + t.Errorf("FooNoEq doesn't have Equal method, shouldn't be isEqualler") + } +} + +type FooFunkyEqReturn struct { + id int +} + +func (f FooFunkyEqReturn) Equal(other FooFunkyEqReturn) (bool, bool) { + return other.id == f.id, true +} + +func TestDetermineIsEquallerFunkyEqReturn(t *testing.T) { + fooT := reflect.TypeOf(FooFunkyEqReturn{}) + fooTIsEqualler := determineIsEqualler(fooT) + if fooTIsEqualler { + t.Errorf("FooFunkyEqReturn has a weird return value for Equal, shouldn't be isEqualler") + } +} + +type FooFunkyEqArg struct { + id int +} + +func (f FooFunkyEqArg) Equal(other Foo) (bool, bool) { + return other.id == f.id, true +} + +func TestDetermineIsEquallerFunkyEqArg(t *testing.T) { + fooT := reflect.TypeOf(FooFunkyEqArg{}) + fooTIsEqualler := determineIsEqualler(fooT) + if fooTIsEqualler { + t.Errorf("FooFunkyEqArg has a weird argument value for Equal, shouldn't be isEqualler") + } +} + +func TestIsEquallerCache(t *testing.T) { + fooT := reflect.TypeOf(Foo{}) + fooPtrT := reflect.TypeOf(&Foo{}) + + // reset cache + equallerCache = make(map[reflect.Type]bool, 0) + + if _, isCached := isEquallerCached(fooT); isCached { + t.Errorf("Foo shouldn't be cached yet") + } + if _, isCached := isEquallerCached(fooPtrT); isCached { + t.Errorf("*Foo shouldn't be cached yet") + } + + setIsEquallerCached(fooT, true) + + if isEqualler, isCached := isEquallerCached(fooT); !isCached && !isEqualler { + t.Errorf("Foo should be cached and true") + } + if _, isCached := isEquallerCached(fooPtrT); isCached { + t.Errorf("*Foo shouldn't be cached yet") + } + + setIsEquallerCached(fooPtrT, false) + + if isEqualler, isCached := isEquallerCached(fooT); !isCached && !isEqualler { + t.Errorf("Foo should be cached and true") + } + if isEqualler, isCached := isEquallerCached(fooPtrT); isCached && isEqualler { + t.Errorf("*Foo should be cached and false") + } + +} + +// the tests for determineIsEqualler should cover most cases, here we just test we are using the cache +func TestIsEqualler(t *testing.T) { + fooT := reflect.TypeOf(Foo{}) + fooPtrT := reflect.TypeOf(&Foo{}) + + // reset cache + equallerCache = make(map[reflect.Type]bool, 0) + + setIsEquallerCached(fooT, true) + + if !isEqualler(fooT) { + t.Errorf("Foo should be cached and true") + } + + setIsEquallerCached(fooPtrT, true) + + if !isEqualler(fooPtrT) { + t.Errorf("*Foo should be cached and true") + } + + setIsEquallerCached(fooPtrT, false) + + if !isEqualler(fooT) { + t.Errorf("Foo should be cached and true") + } + if isEqualler(fooPtrT) { + t.Errorf("*Foo should be cached and false") + } + +}