Update: Arvid Norlander has gone through the trouble of refactoring this code into a crate and publishing it. Thank you, Arvid!

Rust’s BTreeMap and corresponding BTreeSet are excellent, B-tree-based sorted map and set types. The map implements the ergonomic entry API, more flexible than other map APIs and made possible by the borrow checker. Both are implemented with the more performant but gnarlier B-tree data structure, rather than the more common AVL or red-black trees. All in all, they are an excellent piece of engineering, and an excellent standard library feature.

But they aren’t perfect, as I learned recently when I needed to perform a very specific operation on one. I scanned the method lists diligently, trying to find the one I needed, but it was not there. range came close, but not close enough, so I would simply have to implement the operation by hand. range is defined in terms of a start key (where, at our option, the range includes keys greater than or equal to that key, or strictly greater than it) and an end key (where the keys in the range are either less than or equal to that key, or strictly less than it).

Here is an example of the use of range:

let set = {
    let mut set = BTreeSet::new();

    set.insert("ABC");
    set.insert("DEF");
    set.insert("DEG");
    set.insert("HIJ");
    set.insert("KLM");
    set.insert("NOP");

    set
};

for elem in set.range("DEF".."N") {
    println!("{elem}");
}

The output starts with "DEF" and continues in order through the set, but does not include "NOP", as that is greater than "N" (lexicographically, and therefore according to &str’s Ord implementation). If "N" itself were in the set, it would not be printed either, as .. is exclusive on the right side; ..= would include it.

Maps and sets: A brief aside

This discussion only concerns the keys of a map. For simplicity’s sake, throughout the discussion, I’ll be using BTreeSet, a wrapper around BTreeMap for when there are just keys (that are still unique and sorted) and no values. Internally, it contains a BTreeMap with the zero-sized struct SetValZST as its value type.

The Problem

But that isn’t the exact operation I needed. I needed all of the keys (which were also String) that started with a certain prefix. So, if the set was as in the example above, and the prefix was "DE", this operation would give me "DEF" and "DEG". As you can see from the example, and as is easy to prove in general, when the keys are sorted, all the keys starting with a given prefix form a contiguous range. But it is not a range that can be expressed with the range operation.
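As a quick sanity check of that contiguity claim (not a proof), here is a throwaway helper of my own, prefix_positions, that finds where the prefixed keys sit in a sorted slice:

```rust
/// Throwaway helper: indexes of the keys in a sorted slice that start
/// with `prefix`. When the slice is sorted, these form one unbroken run.
fn prefix_positions(sorted_keys: &[&str], prefix: &str) -> Vec<usize> {
    sorted_keys
        .iter()
        .enumerate()
        .filter(|(_, key)| key.starts_with(prefix))
        .map(|(i, _)| i)
        .collect()
}
```

For the example set, prefix_positions(&["ABC", "DEF", "DEG", "HIJ", "KLM", "NOP"], "DE") yields the adjacent indexes 1 and 2.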

It’s close, tantalizingly close. Due to the definition of Ord on String, our prefix-based range starts with the first key that is greater than or equal to the prefix, as strings starting with a prefix always compare greater than or equal to the prefix. This side of the range is therefore expressible with the range operation.

It’s the other side that causes the problem. We don’t have a key that is greater than all of the keys starting with the prefix. We know that once we hit a key that doesn’t start with the prefix, it must be greater than all the keys that do, as must all subsequent keys, but we cannot easily express this bound in terms of the prefix. We would need an element that is either the greatest possible key that starts with the prefix, or else the least possible key that compares greater but does not.

There is a lot of efficiency to be gained by taking advantage of the fact that the range we want is contiguous, which is why the range method exists. But there is no operation that covers this scenario, because of the narrowness of how the range operation is defined.

On the one hand, this is frustrating. We are so close to being able to do this straightforwardly with the provided operations. It also seems like it would be more performant to determine the bounds of that range by doing a tree search than by implementing the operation by hand. Without this operation being available, we seem doomed to slowness.

On the other hand, it’s understandable. The key type of a map is only really expected to implement the Ord trait, and nothing about Ord has anything to do with prefixes. Creating ranges with range was allowed, but based on inclusive and exclusive bounds, which is to say, purely based on ordering of opaque elements. Evaluating a prefix as a range, on the other hand (or even merely proving that the keys forming a prefix do indeed constitute a contiguous range) would be outside of the scope of the operations represented by the Ord trait.

So I needed a way of getting keys that start with a specific prefix. What did I do? I simply coded a manual form of the operation, looping from the beginning of the range and checking on each iteration whether we’d left the range yet:

for key in keys.range::<str, _>((Bound::Included(prefix), Bound::Unbounded)) {
    if !key.starts_with(prefix) {
        break; // We've gone past the end of the range
    }
    // ... Actually do something with the key
}

This seemed reasonable enough. My colleagues asked me to add a comment clarifying that, since the map is sorted, all the items with the prefix are contiguous, so break is correct rather than continue. It worked, and was performant enough for my purposes, but perhaps not as performant as ideally achievable. I couldn’t help but wonder whether it could be made a little faster as part of the standard library, with insight into and access to the inner structure of how a BTreeSet is laid out. Obviously, in that case the code would also be more concise, and (more importantly) obviously correct, without need for a comment.

Any performance gains, however, would be minimal. Looping through a BTreeSet is a reasonable operation, and I took advantage of the fact that my range was contiguous to stop once we’d gone past the last item. At best, explicit library support for prefixes would detect this end condition slightly sooner, further up in the tree, without having to actually find the node with the first offending item.
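If this pattern comes up repeatedly, the loop can be packaged as an iterator. Here is a sketch (iter_prefixed is my own name, not a standard-library method) that combines range with take_while, relying on the contiguity argument above:

```rust
use std::collections::BTreeSet;
use std::ops::Bound;

// Hypothetical helper, not in the standard library: yields the keys that
// start with `prefix`, stopping at the first key that does not.
fn iter_prefixed<'a>(
    set: &'a BTreeSet<String>,
    prefix: &'a str,
) -> impl Iterator<Item = &'a String> {
    set.range::<str, _>((Bound::Included(prefix), Bound::Unbounded))
        .take_while(move |key| key.starts_with(prefix))
}
```

Like the loop, this stops at the first non-matching key rather than scanning the rest of the set.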

The next bit of code I wrote was for a closely related operation: dropping values outside of the prefix. What I wrote seemed certain to be substantially less performant than a specially coded operation from the standard library would be. It was certainly harder to prove correct:

fn prefixed(mut set: BTreeSet<String>, prefix: &str) -> BTreeSet<String> {
    // Split off everything >= the prefix; this is the start of our range.
    let mut set = set.split_off(prefix);

    // Find the first key past the prefixed range, cloning it to end the borrow.
    let not_in_prefix = set.iter().find(|s| !s.starts_with(prefix));
    let not_in_prefix = not_in_prefix.map(|s| s.to_owned());
    if let Some(not_in_prefix) = not_in_prefix {
        // Split off everything from that key onward, leaving only prefixed keys.
        set.split_off(&not_in_prefix);
    }

    set
}

This uses two calls to split_off, which, like range, needs a concrete value to serve as the comparison point for where to split. And it is certainly less performant than a dedicated method would be, as it also uses a call to find to locate a concrete String for the end of the range, which constitutes an additional loop through all the strings in the range.
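For a concrete picture of what a single split_off call does (the set and values here are illustrative, and demo_split_off is a made-up name):

```rust
use std::collections::BTreeSet;

// Minimal illustration of `split_off`: everything greater than or equal to
// the given value moves into the returned set; the rest stays behind.
fn demo_split_off() -> (BTreeSet<String>, BTreeSet<String>) {
    let mut set: BTreeSet<String> = ["ABC", "DEF", "DEG", "HIJ"]
        .iter()
        .map(|s| s.to_string())
        .collect();

    let tail = set.split_off("DEF"); // set keeps {"ABC"}; tail gets {"DEF", "DEG", "HIJ"}
    (set, tail)
}
```

prefixed makes two such cuts: one at the prefix itself, and one at the first key found past the prefixed range.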

Questions

This raised two questions in my mind:

  1. Is there a way to convert a prefix into a range that can be used with range and split_off? More concretely, is there a way to construct a String that is greater than every possible string starting with our prefix, but less than or equal to every string that comes after them? Would doing so in fact improve performance?

  2. How hard would it be to add this feature to the standard library, both for iterating and for splitting the set?

In this blog post, we will focus on the first question. The second question is reserved for a future blog post.

Testing prefixed

The prefixed function needs the optimization more than the loop does, so we’ll focus on it in our discussion. And since we’re optimizing the prefixed function, which is in any case a gnarly function, we will want to write some unit tests for it.

Here’s one example:

#[test]
fn it_works() {
    let set = {
        let mut set = BTreeSet::new();
        set.insert("Hi".to_string());
        set.insert("Hey".to_string());
        set.insert("Hello".to_string());
        set.insert("heyyy".to_string());
        set.insert("".to_string());
        set.insert("H".to_string());
        set
    };
    let set = prefixed(set, "H");
    assert_eq!(set.len(), 4);
    assert!(!set.contains("heyyy"));
}

This probably isn’t enough. Additional unit tests will be left as an exercise to the reader.
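One useful trick for those extra tests is to compare prefixed against a brute-force oracle that filters the whole set; the two should always agree. (brute_force_prefixed is a test-only helper of my own devising, not part of the code under discussion.)

```rust
use std::collections::BTreeSet;

// Test-only oracle: the obvious O(n) filter that `prefixed` should agree
// with on every input.
fn brute_force_prefixed(set: &BTreeSet<String>, prefix: &str) -> BTreeSet<String> {
    set.iter()
        .filter(|key| key.starts_with(prefix))
        .cloned()
        .collect()
}
```

A test can then assert that prefixed(set.clone(), p) equals brute_force_prefixed(&set, p) across a handful of sets and prefixes, including empty and multi-byte ones.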

Constructing an upper bound

So, let us return to our example, where the prefix was "DE". As discussed, the lower bound is easy: everything that starts with "DE" is greater than or equal to "DE", while strings before the range are not:

println!("{}", "DD" >= "DE");       // Prints "false"
println!("{}", "DE" >= "DE");       // Prints "true"
println!("{}", "DEF" >= "DE");      // Prints "true"
println!("{}", "DEG" >= "DE");      // Prints "true"
println!("{}", "DF" >= "DE");       // Also prints "true" -- we need an upper bound
println!("{}", "NOP" >= "DE");      // Also prints "true" -- we need an upper bound

The upper bound is actually also easy enough: we just need to increment the last character. Anything that starts with "DE" will also compare strictly less than "DF":

println!("{}", "DE" < "DF");        // Prints "true"
println!("{}", "DEF" < "DF");       // Prints "true"
println!("{}", "DEG" < "DF");       // Prints "true"
println!("{}", "DF" < "DF");        // Prints "false"
println!("{}", "NOP" < "DF");       // Prints "false"

This seems easy enough to handle. We just need to write a function that increments the last character in a string, something with this signature:

fn upper_bound_from_prefix(prefix: &str) -> String;

Incrementing the last character in a string seems like it’s just a matter of incrementing the last byte, so let’s see what that looks like:

fn upper_bound_from_prefix(prefix: &str) -> String {
    let mut prefix = prefix.to_string();
    unsafe {
        // SAFETY: It is not. ☹️. XXX
        let prefix_bytes = prefix.as_bytes_mut();
        prefix_bytes[prefix_bytes.len() - 1] += 1;
    }
    prefix
}

Well, that’s not good. It passes the unit test I wrote, but that’s because we need to write more unit tests. Unfortunately, like many programmers before us, we have forgotten about UTF-8. Rust requires all its strings to be stored as valid UTF-8 as a safety invariant. Fortunately, because we’re using Rust, we notice that we’re violating this invariant, because the operation we have to invoke is marked as unsafe.

In order to capture this failure, we would have to write a unit test where the prefix ends in a multi-byte Unicode character. Unfortunately, because this is a safety issue, the test might not even fail (but it might be worth doing as an exercise anyway).
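We can, however, observe the corruption safely by performing the same increment on a copied Vec&lt;u8&gt; instead of inside a String. (This helper is illustrative only; note that some multi-byte increments happen to produce valid UTF-8 anyway, which is part of what makes the bug sneaky.)

```rust
// Illustrative check: does bumping the final byte of `s`'s UTF-8 encoding
// still yield valid UTF-8? Done on a copied Vec<u8>, so no unsafety needed.
fn last_byte_increment_is_valid_utf8(s: &str) -> bool {
    let mut bytes = s.as_bytes().to_vec();
    *bytes.last_mut().expect("non-empty string") += 1;
    std::str::from_utf8(&bytes).is_ok()
}
```

For "DE" this yields valid UTF-8 ("DF"), and even "é" happens to survive (its bytes 0xC3 0xA9 become 0xC3 0xAA, which is "ê"), but "ÿ" (encoded as 0xC3 0xBF) becomes 0xC3 0xC0, and 0xC0 never appears in valid UTF-8.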

That isn’t even to mention the possibility that the prefix is empty, which would result in a panic in this code!

So, how can we get the last character of a string? get lets us take substrings by byte index, returning None if the indexes don’t form a valid substring. We can loop backwards until we find an index that works for the split, and we can return an Option in case the string is empty:

fn upper_bound_from_prefix(prefix: &str) -> Option<String> {
    for i in (0 .. prefix.len()).rev() {
        if let Some(last_char_str) = prefix.get(i..) {
            let rest_of_prefix = {
                debug_assert!(prefix.is_char_boundary(i));
                &prefix[0..i]
            };

            // ???
        }
    }

    None
}

But that gives us two strs, and we want to increment a char. So we have to extract the single char from last_char_str, which we know contains exactly one char. Looking over the operations on str, we have only one real option:

let last_char = last_char_str
    .chars()
    .next()
    .expect("last_char_str will contain exactly one char");

Walking Through chars

But once we do have a char, we cannot simply do + 1 on it. This operation isn’t defined on a char. And before you say that we should convert it to u32 and back, you should know that the operation is left undefined on char for a reason: chars are supposed to remain valid Unicode scalar values.

So, we must do something else that will skip over invalid code points. There is no obvious operation in char that will do it, but if we look in the “Trait Implementations” section, we find something that looks potentially relevant: Step. And looking at char’s implementation of Step, we see the exact function we want:

fn forward_checked(start: char, count: usize) -> Option<char> {
    let start = start as u32;
    let mut res = Step::forward_checked(start, count)?;
    if start < 0xD800 && 0xD800 <= res {
        res = Step::forward_checked(res, 0x800)?;
    }
    if res <= char::MAX as u32 {
        // SAFETY: res is a valid unicode scalar
        // (below 0x110000 and not in 0xD800..0xE000)
        Some(unsafe { char::from_u32_unchecked(res) })
    } else {
        None
    }
}

Unfortunately, this gives us an Option. Why? Well, you can see that from the code: What if last_char is the highest possible Unicode code point, 0x10FFFF, also known as char::MAX? We’re going to procrastinate handling this (admittedly rare) situation, and panic for now. Spoiler: Fortunately, there is a solution, which we will discuss later.

This is a great example of why Rust is great. Because this operation is defined to return an Option, we have to explicitly say what we’re doing in case it returns None. We don’t even have to have a unit test for 0x10FFFF code-points in our prefix to realize that we have to cover this case (although now would be a great time to write one).

Also unfortunately, we can’t directly call forward_checked … not if we want to use stable Rust, in any case. It’s marked as a nightly-only “unstable API.” Fortunately, however, we can access it indirectly, through the range APIs. Some rooting around in the standard library reveals that nth, on an iterator over a closed range, calls forward_checked, yielding:

let last_char_incr = (last_char ..= char::MAX)
    .nth(1)
    .expect("XXX fixme: can't handle highest possible codepoint");

This actually works, setting aside the caveat of handling char::MAX. All my unit tests except my 0x10FFFF one pass. Altogether, here is the state of things: we have a prefixed function that uses this to call split_off with an appropriate value, without iterating through all of the in-range strings in the set:

fn upper_bound_from_prefix(prefix: &str) -> Option<String> {
    for i in (0..prefix.len()).rev() {
        if let Some(last_char_str) = prefix.get(i..) {
            let rest_of_prefix = {
                debug_assert!(prefix.is_char_boundary(i));
                &prefix[0..i]
            };

            let last_char = last_char_str
                .chars()
                .next()
                .expect("last_char_str will contain exactly one char");
            let last_char_incr = (last_char ..= char::MAX)
                .nth(1)
                .expect("XXX fixme used highest possible codepoint");

            let new_string = format!("{rest_of_prefix}{last_char_incr}");

            return Some(new_string);
        }
    }

    None
}

pub fn prefixed(mut set: BTreeSet<String>, prefix: &str) -> BTreeSet<String> {
    let mut set = set.split_off(prefix);

    if let Some(not_in_prefix) = upper_bound_from_prefix(prefix) {
        set.split_off(&not_in_prefix);
    }

    set
}

Cleaning Up the Edge Case

OK, now that we’ve got something that (kind of) works, it’s time to do some clean-up.

So, first, of course, we should address the XXX fixme: the 0x10FFFF case. What do we do in that case? Well, if we use X to stand in for this “highest code point” character, we can reason about it a little.

Let’s say the prefix is "deX". For a string to be beyond the range of the prefix, it can’t start with "deY" for some character 'Y' greater than 'X', because no such character exists. So it would have to differ on an earlier character: it would have to start with "df" or greater.

So, if our prefix ends with this special character, we can simply drop it, move one character back, and increment that character instead. Strangely enough, that just means going through our for loop again (and no, I did not plan this). If we keep going backwards to find another character to increment, we’ll reach the previous character. Our way of extracting characters from the suffix works even if the second substring contains more than one character: it just takes the first character, which is exactly what we want.

So we can actually write:

let Some(last_char_incr) = (last_char ..= char::MAX).nth(1) else {
    continue;
};

Adding some comments to explain, and adjusting the existing code to no longer lie to the reader (last_char_str might now contain more than one character), we get this:

fn upper_bound_from_prefix(prefix: &str) -> Option<String> {
    for i in (0..prefix.len()).rev() {
        if let Some(last_char_str) = prefix.get(i..) {
            let rest_of_prefix = {
                debug_assert!(prefix.is_char_boundary(i));
                &prefix[0..i]
            };

            let last_char = last_char_str
                .chars()
                .next()
                .expect("last_char_str will contain at least one char");
            let Some(last_char_incr) = (last_char ..= char::MAX).nth(1) else {
                // Last character is highest possible code point.
                // Go to second-to-last character instead.
                continue;
            };
            
            let new_string = format!("{rest_of_prefix}{last_char_incr}");

            return Some(new_string);
        }
    }

    None
}

If our string contains only copies of this highest possible code point, this returns None, which is appropriate, because no string is greater than every string prefixed with those characters, just as nothing comes after names that start with “Z” in alphabetical order, nor after names that start with “Zz”.

Note that if we want to save the other sets that are created by split_off, we can. We can easily modify this function to return all three sets: the keys that come lexicographically before the prefix, the keys that start with the prefix, and the keys that come after the keys that start with the prefix.
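Here is a sketch of that variant (partition_by_prefix is my own name; upper_bound_from_prefix is repeated from above so the example stands on its own):

```rust
use std::collections::BTreeSet;

// upper_bound_from_prefix, as derived above, repeated for self-containment.
fn upper_bound_from_prefix(prefix: &str) -> Option<String> {
    for i in (0..prefix.len()).rev() {
        if let Some(last_char_str) = prefix.get(i..) {
            let rest_of_prefix = &prefix[0..i];
            let last_char = last_char_str
                .chars()
                .next()
                .expect("last_char_str will contain at least one char");
            let Some(last_char_incr) = (last_char..=char::MAX).nth(1) else {
                continue; // highest code point: move back a character
            };
            return Some(format!("{rest_of_prefix}{last_char_incr}"));
        }
    }
    None
}

// Hypothetical variant of `prefixed` that keeps all three pieces:
// keys before the prefix range, keys in it, and keys after it.
fn partition_by_prefix(
    mut set: BTreeSet<String>,
    prefix: &str,
) -> (BTreeSet<String>, BTreeSet<String>, BTreeSet<String>) {
    let mut within = set.split_off(prefix);
    let after = match upper_bound_from_prefix(prefix) {
        Some(bound) => within.split_off(&bound),
        None => BTreeSet::new(),
    };
    (set, within, after)
}
```

The middle set is exactly what prefixed returns; the other two fall out of the same two split_off calls for free.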

Performance

This code certainly hasn’t been optimized to the fullest extent possible. To do that, we would probably want some more extreme optimizations, like working with Vec<u8> rather than String and checking UTF-8 validity only at the point where it’s necessary (if our application needs it at all). Or, alternatively, we might fork the standard library’s B-tree implementation and actually add this operation. Both are gnarly, but if the absolute best possible performance were truly our goal, they would both be in scope.

But I am reserving that for a future blog post. Detailed profiling of different implementations of this operation would require that level of optimization to be fully interesting and is therefore also reserved for a future blog post. Instead, here, I will walk through some informal reasoning about the performance of this new implementation of prefixed, and whether it is also useful for iteration rather than splitting off a new set.

So, let’s do some back-of-the-envelope reckoning. Creating the upper bound requires reconstructing the prefix string, which costs us an allocation as well as a string copy. In exchange, we save the call to find, which might have had to loop over many, many strings that start with the prefix. We can therefore expect this implementation of prefixed to be more performant when many strings start with the prefix (and the prefix is not pathologically long).

For iterating over the range, however, we would be making an allocation while only potentially saving some walking through the tree. Given that allocations are expensive (and potentially involve some walking around memory themselves), it’s probably not worth it unless the tree is extremely large.

A Warning unto the Test-Shy

In an earlier draft of this post, I had the following code to increment a char rather than what I wrote above:

(last_char ..).nth(1)

This seems like it should work, in spite of having no upper bound. It stands to reason that char::MAX would, in such a case, serve as an implicit upper bound. It does still return an Option<char>, and when would None happen if not in such a situation?

But fortunately, I had a test case:

#[test]
fn maxicode() {
    let set = {
        let mut set = BTreeSet::new();
        set.insert("Hi".to_string());
        set.insert("Hey".to_string());
        set.insert("Hello".to_string());
        set.insert("heyyy".to_string());
        set.insert("H\u{10FFFF}eyyy".to_string());
        set.insert("H\u{10FFFF}".to_string());
        set.insert("I".to_string());
        set.insert("".to_string());
        set.insert("H".to_string());
        set
    };
    let set = prefixed(set, "H\u{10FFFF}");
    assert_eq!(set.len(), 2);
    assert!(!set.contains("I"));
}

This test case, in that earlier code, actually panicked! It turns out that for an open-ended range like (last_char ..), which has the type RangeFrom, it is simply assumed that stepping forward is always possible. Instead of calling forward_checked, its nth method calls forward:

#[inline]
fn nth(&mut self, n: usize) -> Option<A> {
    let plus_n = Step::forward(self.start.clone(), n);
    self.start = Step::forward(plus_n.clone(), 1);
    Some(plus_n)
}

And in forward, every None is converted into a panic:

fn forward(start: Self, count: usize) -> Self {
    Step::forward_checked(start, count).expect("overflow in `Step::forward`")
}
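The two behaviors are easy to check directly; next_char here is a throwaway wrapper of my own around the closed-range trick:

```rust
// Throwaway wrapper: successor of `c` via the closed range, which calls
// `forward_checked` and so returns None instead of panicking at char::MAX.
fn next_char(c: char) -> Option<char> {
    (c ..= char::MAX).nth(1)
}
```

next_char(char::MAX) is None, and next_char('\u{D7FF}') skips the surrogate gap, yielding '\u{E000}'; the equivalent (char::MAX ..).nth(1), by contrast, panics with "overflow in `Step::forward`".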

Conclusion

I hope you enjoyed this walk-through. You can find the final version of prefixed and two test cases here.

Please let me know what you think of this format in the comments. Also let me know if you have any follow-up topics you want me to explore, or other problems you would want walk-throughs of.

And, of course, please feel free to provide corrections and even nit-picks!