8354433: Assert in AbstractRBTree::visit_range_in_order(const K& from, const K& to, F f) is wrong

Reviewed-by: jsjolen, aboldtch
This commit is contained in:
Casper Norrbin 2025-05-08 16:21:14 +00:00 committed by Johan Sjölen
parent 7f3191a630
commit 1e8927dded
3 changed files with 58 additions and 41 deletions

View File

@ -37,10 +37,14 @@
// - an int < 0 when a < b
// - an int == 0 when a == b
// - an int > 0 when a > b
// A second static function `cmp(const IntrusiveRBNode* a, const IntrusiveRBNode* b)`
// used for `verify_self` and other extra validation can optionally be provided. This should return:
// - true if a < b
// - false otherwise
// Additional static functions used for extra validation can optionally be provided:
// `cmp(K a, K b)` which returns:
// - an int < 0 when a < b
// - an int == 0 when a == b
// - an int > 0 when a > b
// `cmp(const IntrusiveRBNode* a, const IntrusiveRBNode* b)` which returns:
// - true if a < b
// - false otherwise
// K needs to be of a type that is trivially destructible.
// K needs to be stored by the user and is not stored inside the tree.
// Nodes are address stable and will not change during its lifetime.
@ -48,7 +52,7 @@
// A red-black tree is constructed with four template parameters:
// K is the key type stored in the tree nodes.
// V is the value type stored in the tree nodes.
// COMPARATOR must have one of the static functions `cmp(K a, K b)` or `cmp(K a, const RBNode<K, V>* b)` which returns:
// COMPARATOR must have a static function `cmp(K a, K b)` which returns:
// - an int < 0 when a < b
// - an int == 0 when a == b
// - an int > 0 when a > b
@ -198,20 +202,20 @@ private:
struct has_cmp_type<CMP, RET, ARG1, ARG2, decltype(static_cast<RET(*)(ARG1, ARG2)>(CMP::cmp), void())> : std::true_type {};
template <typename CMP>
static constexpr bool IsKeyComparator = has_cmp_type<CMP, int, K, K>::value;
static constexpr bool HasKeyComparator = has_cmp_type<CMP, int, K, K>::value;
template <typename CMP>
static constexpr bool IsNodeComparator = has_cmp_type<CMP, int, K, const NodeType*>::value;
static constexpr bool HasNodeComparator = has_cmp_type<CMP, int, K, const NodeType*>::value;
template <typename CMP>
static constexpr bool HasNodeVerifier = has_cmp_type<CMP, bool, const NodeType*, const NodeType*>::value;
template <typename CMP = COMPARATOR, ENABLE_IF(IsKeyComparator<CMP>)>
template <typename CMP = COMPARATOR, ENABLE_IF(HasKeyComparator<CMP> && !HasNodeComparator<CMP>)>
int cmp(const K& a, const NodeType* b) const {
return COMPARATOR::cmp(a, b->key());
}
template <typename CMP = COMPARATOR, ENABLE_IF(IsNodeComparator<CMP>)>
template <typename CMP = COMPARATOR, ENABLE_IF(HasNodeComparator<CMP>)>
int cmp(const K& a, const NodeType* b) const {
return COMPARATOR::cmp(a, b);
}
@ -226,24 +230,13 @@ private:
return COMPARATOR::cmp(a, b);
}
template <typename CMP = COMPARATOR, ENABLE_IF(IsKeyComparator<CMP>)>
void assert_leq(const K& a, const NodeType* b) const {
assert(COMPARATOR::cmp(a, b->key()) <= 0, "key not <= node");
}
// Cannot assert if no key comparator exist.
template <typename CMP = COMPARATOR, ENABLE_IF(!HasKeyComparator<CMP>)>
void assert_key_leq(K a, K b) const {}
template <typename CMP = COMPARATOR, ENABLE_IF(IsNodeComparator<CMP>)>
void assert_leq(const K& a, const NodeType* b) const {
assert(COMPARATOR::cmp(a, b) <= 0, "key not <= node");
}
template <typename CMP = COMPARATOR, ENABLE_IF(IsKeyComparator<CMP>)>
void assert_geq(const K& a, const NodeType* b) const {
assert(COMPARATOR::cmp(a, b->key()) >= 0, "key not >= node");
}
template <typename CMP = COMPARATOR, ENABLE_IF(IsNodeComparator<CMP>)>
void assert_geq(const K& a, const NodeType* b) const {
assert(COMPARATOR::cmp(a, b) >= 0, "key not >= node");
template <typename CMP = COMPARATOR, ENABLE_IF(HasKeyComparator<CMP>)>
void assert_key_leq(K a, K b) const {
assert(COMPARATOR::cmp(a, b) <= 0, "key a must be less or equal to key b");
}
// True if node is black (nil nodes count as black)
@ -272,7 +265,7 @@ public:
AbstractRBTree() : _num_nodes(0), _root(nullptr) DEBUG_ONLY(COMMA _expected_visited(false)) {
static_assert(std::is_trivially_destructible<K>::value, "key type must be trivially destructable");
static_assert(IsKeyComparator<COMPARATOR> || IsNodeComparator<COMPARATOR>,
static_assert(HasKeyComparator<COMPARATOR> || HasNodeComparator<COMPARATOR>,
"comparator must be of correct type");
}
@ -425,12 +418,12 @@ public:
verify_self([](const NodeType* a, const NodeType* b){ return COMPARATOR::cmp(a, b);});
}
template <typename CMP = COMPARATOR, ENABLE_IF(IsKeyComparator<CMP> && !HasNodeVerifier<CMP>)>
template <typename CMP = COMPARATOR, ENABLE_IF(HasKeyComparator<CMP> && !HasNodeVerifier<CMP>)>
void verify_self() const {
verify_self([](const NodeType* a, const NodeType* b){ return COMPARATOR::cmp(a->key(), b->key()) < 0; });
}
template <typename CMP = COMPARATOR, ENABLE_IF(IsNodeComparator<CMP> && !HasNodeVerifier<CMP>)>
template <typename CMP = COMPARATOR, ENABLE_IF(HasNodeComparator<CMP> && !HasKeyComparator<CMP> && !HasNodeVerifier<CMP>)>
void verify_self() const {
verify_self([](const NodeType*, const NodeType*){ return true;});
}

View File

@ -550,19 +550,19 @@ inline void AbstractRBTree<K, NodeType, COMPARATOR>::replace_at_cursor(NodeType*
new_node->_parent = old_node->_parent;
if (new_node->is_left_child()) {
assert(cmp((const NodeType*)new_node, (const NodeType*)new_node->_parent), "new node not < parent");
assert(cmp(static_cast<const NodeType*>(new_node), static_cast<const NodeType*>(new_node->parent())), "new node not < parent");
} else if (new_node->is_right_child()) {
assert(cmp((const NodeType*)new_node->_parent, (const NodeType*)new_node->_right), "new node not > parent");
assert(cmp(static_cast<const NodeType*>(new_node->parent()), static_cast<const NodeType*>(new_node)), "new node not > parent");
}
new_node->_left = old_node->_left;
new_node->_right = old_node->_right;
if (new_node->_left != nullptr) {
assert(cmp((const NodeType*)new_node->_left, (const NodeType*)new_node), "left child not < new node");
assert(cmp(static_cast<const NodeType*>(new_node->_left), static_cast<const NodeType*>(new_node)), "left child not < new node");
new_node->_left->set_parent(new_node);
}
if (new_node->_right != nullptr) {
assert(cmp((const NodeType*)new_node, (const NodeType*)new_node->_right), "right child not > new node");
assert(cmp(static_cast<const NodeType*>(new_node), static_cast<const NodeType*>(new_node->_right)), "right child not > new node");
new_node->_right->set_parent(new_node);
}
@ -606,6 +606,7 @@ inline void AbstractRBTree<K, NodeType, COMPARATOR>::visit_in_order(F f) const {
template <typename K, typename NodeType, typename COMPARATOR>
template <typename F>
inline void AbstractRBTree<K, NodeType, COMPARATOR>::visit_range_in_order(const K& from, const K& to, F f) const {
assert_key_leq(from, to);
if (_root == nullptr) {
return;
}
@ -615,13 +616,6 @@ inline void AbstractRBTree<K, NodeType, COMPARATOR>::visit_range_in_order(const
const NodeType* start = cursor_start.found() ? cursor_start.node() : next(cursor_start).node();
const NodeType* end = next(cursor_end).node();
if (start != nullptr) {
assert_leq(from, start);
assert_geq(to, start);
} else {
assert(end == nullptr, "end node found but not start node");
}
while (start != end) {
f(start);
start = start->next();

View File

@ -33,10 +33,16 @@
class RBTreeTest : public testing::Test {
public:
using RBTreeIntNode = RBNode<int, int>;
struct Cmp {
static int cmp(int a, int b) {
return a - b;
}
static bool cmp(const RBTreeIntNode* a, const RBTreeIntNode* b) {
return a->key() < b->key();
}
};
struct CmpInverse {
@ -73,7 +79,6 @@ struct ArrayAllocator {
};
using RBTreeInt = RBTreeCHeap<int, int, Cmp, mtTest>;
using RBTreeIntNode = RBNode<int, int>;
using IntrusiveTreeNode = IntrusiveRBNode;
struct IntrusiveHolder {
@ -93,6 +98,10 @@ struct ArrayAllocator {
return a - IntrusiveHolder::cast_to_self(b)->key;
}
static int cmp(int a, int b) {
return a - b;
}
// true if a < b
static bool cmp(const IntrusiveTreeNode* a, const IntrusiveTreeNode* b) {
return (IntrusiveHolder::cast_to_self(a)->key -
@ -300,6 +309,23 @@ public:
}
}
void test_visit_outside_range() {
RBTreeInt rbtree;
using Node = RBTreeIntNode;
rbtree.upsert(2, 0);
rbtree.upsert(5, 0);
constexpr int test_cases[9][2] = {{0, 0}, {0, 1}, {1, 1}, {3, 3}, {3, 4},
{4, 4}, {6, 6}, {6, 7}, {7, 7}};
for (const int (&test_case)[2] : test_cases) {
rbtree.visit_range_in_order(test_case[0], test_case[1], [&](const Node* x) {
FAIL() << "Range should not visit nodes";
});
}
}
void test_closest_leq() {
using Node = RBTreeIntNode;
{
@ -802,6 +828,10 @@ TEST_VM_F(RBTreeTest, TestVisitors) {
this->test_visitors();
}
TEST_VM_F(RBTreeTest, TestVisitOutsideRange) {
this->test_visit_outside_range();
}
TEST_VM_F(RBTreeTest, TestClosestLeq) {
this->test_closest_leq();
}