From a4fb614f3519c53081f22b43dcb73e65e17f7659 Mon Sep 17 00:00:00 2001 From: Michael Brown Date: Mon, 27 May 2019 12:51:44 +0100 Subject: [PATCH] Fix __ge__ and __gt__ comparison methods Quoting the Python documentation for sets: "The subset and equality comparisons do not generalize to a total ordering function. For example, any two nonempty disjoint sets are not equal and are not subsets of each other, so all of the following return False: ab." It is therefore not possible to define __ge__ as the inverse of __lt__ (and similarly __gt__ as the inverse of __le__), since this will give false positive results. Fix by implementing __ge__ and __gt__ using equivalent logic to that used in __le__ and __lt__. Signed-off-by: Michael Brown --- lib/orderedset/_orderedset.pyx | 18 ++++++++++-------- tests/test_orderedset.py | 15 +++++++++++++++ 2 files changed, 25 insertions(+), 8 deletions(-) diff --git a/lib/orderedset/_orderedset.pyx b/lib/orderedset/_orderedset.pyx index 85a5507..e46e934 100644 --- a/lib/orderedset/_orderedset.pyx +++ b/lib/orderedset/_orderedset.pyx @@ -499,13 +499,15 @@ class OrderedSet(_OrderedSet, MutableSet): return NotImplemented def __ge__(self, other): - ret = self < other - if ret is NotImplemented: - return ret - return not ret + if isinstance(other, Set): + return len(self) >= len(other) and set(self) >= set(other) + elif isinstance(other, list): + return len(self) >= len(other) and list(self) >= list(other) + return NotImplemented def __gt__(self, other): - ret = self <= other - if ret is NotImplemented: - return ret - return not ret + if isinstance(other, Set): + return len(self) > len(other) and set(self) > set(other) + elif isinstance(other, list): + return len(self) > len(other) and list(self) > list(other) + return NotImplemented diff --git a/tests/test_orderedset.py b/tests/test_orderedset.py index f8f3a49..fc99be2 100755 --- a/tests/test_orderedset.py +++ b/tests/test_orderedset.py @@ -371,6 +371,21 @@ def test_ordering(self): self.assertGreater(oset1, set(oset3)) self.assertGreater(oset1, list(oset3)) + oset4 = OrderedSet(self.lst[1:]) + + self.assertFalse(oset3 < oset4) + self.assertFalse(oset3 < set(oset4)) + self.assertFalse(oset3 < list(oset4)) + self.assertFalse(oset3 >= oset4) + self.assertFalse(oset3 >= set(oset4)) + self.assertFalse(oset3 >= list(oset4)) + self.assertFalse(oset3 < oset4) + self.assertFalse(oset3 < set(oset4)) + self.assertFalse(oset3 < list(oset4)) + self.assertFalse(oset3 >= oset4) + self.assertFalse(oset3 >= set(oset4)) + self.assertFalse(oset3 >= list(oset4)) + if __name__ == '__main__': unittest.main()