SIMD with Zig
May 02, 2023
To find the index of the first instance of a character within a body of [ASCII] text, you might write something like:
fn indexOf(haystack: []const u8, needle: u8) ?usize {
    for (haystack, 0..) |c, i| {
        if (c == needle) return i;
    }
    return null;
}
Or use the std.mem.indexOfScalar function from the standard library, which is essentially the same implementation. This implementation loops through the input and checks each character, one by one, to see if it's equal to our target.
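To check that claim, here's a small sketch that runs the loop above next to std.mem.indexOfScalar on a few arbitrary inputs. Only std.mem.indexOfScalar comes from the standard library; the inputs are made up for illustration.

```zig
const std = @import("std");

// The scalar search from above: compares one byte per iteration.
fn indexOf(haystack: []const u8, needle: u8) ?usize {
    for (haystack, 0..) |c, i| {
        if (c == needle) return i;
    }
    return null;
}

test "indexOf agrees with std.mem.indexOfScalar" {
    // Arbitrary inputs: a match, an empty input, a match at 0, no match.
    const inputs = [_][]const u8{ "Hello Jo", "", "ooo", "xyz" };
    for (inputs) |input| {
        try std.testing.expectEqual(std.mem.indexOfScalar(u8, input, 'o'), indexOf(input, 'o'));
    }
}
```

Running zig test on the file exercises both; indexOf("Hello Jo", 'o') returns 4.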
With SIMD, we can leverage CPU instructions to check multiple characters of our input in parallel. Let's do that in Zig.
To keep this simple for now, let's pretend that our input is always exactly 8 characters long (we'll look at inputs of any length later; 8 characters is short enough that we can show the full content of our vectors).
As an example, say we have "Hello Jo" and we want the first index of "o" (which is 4). Our first step is to create a vector (think of it as an array) of 8 elements, each containing the value "o":
const vector_len = 8;
const vector_needles: @Vector(vector_len, u8) = @splat(@as(u8, 'o'));
Our @as cast is a very Zig-specific thing. 'o' on its own is a comptime_int. If we try to use that, we'll get an error message: expected integer, float, bool, or pointer for the vector element type. So we use @as to coerce it to a u8. The more important part of the above code is the @splat builtin, which creates a vector with each element containing the specified value. The length of the created vector is inferred from the type we're assigning to, which above is @Vector(vector_len, u8).
If we were to print vector_needles, we'd see 'o' eight times (the ASCII value of 'o' is 111):
{ 111, 111, 111, 111, 111, 111, 111, 111 }
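A small sketch of @splat with a runtime needle (the helper name needles is hypothetical). Because needle already has the concrete type u8, no @as is needed here, and the vector's length still comes from the declared return type. Converting the vector back to an array makes its lanes easy to inspect.

```zig
const std = @import("std");

const vector_len = 8;

// Hypothetical helper: builds a vector with `needle` in every lane.
// @splat takes no length argument; the 8 comes from the return type.
fn needles(needle: u8) @Vector(vector_len, u8) {
    return @splat(needle);
}

test "needles('o') holds 111 in every lane" {
    // Vectors coerce back to arrays of the same length and element type.
    const lanes: [vector_len]u8 = needles('o');
    const expected = [_]u8{111} ** vector_len;
    try std.testing.expectEqualSlices(u8, &expected, &lanes);
}
```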
The next step is to take our input and also convert that into a vector. Because we've said that, for the time being, our input will be limited to 8 characters, this is easy:
const haystack = "Hello Jo";
const vector_haystack: @Vector(vector_len, u8) = haystack.*;
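The @Vector builtin returns a type. Our haystack, a string literal, is a pointer to an array of 8 characters, so haystack.* gives us the array itself. Because that array has the same length as our vector_len, it coerces to our vector type, and we can store our full haystack in this new vector. If we were to print vector_haystack, we'd get: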
{ 72, 101, 108, 108, 111, 32, 74, 111 }
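The same coercion, sketched as a function (toVector is a hypothetical name). It accepts a pointer to exactly vector_len bytes, which a string literal like "Hello Jo" coerces to. A plain []const u8 slice has no compile-time length and would not coerce this way, which is why, once we handle inputs of any length, we slice to a comptime-known length first.

```zig
const std = @import("std");

const vector_len = 8;

// Hypothetical helper. "Hello Jo" is a pointer to an 8-byte array (with a
// 0 sentinel), which coerces to *const [8]u8. Dereferencing gives the array,
// and an array coerces to a vector with the same length and element type.
fn toVector(text: *const [vector_len]u8) @Vector(vector_len, u8) {
    return text.*;
}

test "Hello Jo as a vector" {
    const lanes: [vector_len]u8 = toVector("Hello Jo");
    const expected = [_]u8{ 72, 101, 108, 108, 111, 32, 74, 111 };
    try std.testing.expectEqualSlices(u8, &expected, &lanes);
}
```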
We now have two vectors, each containing eight u8 (byte) values. Our first SIMD operation will be to compare the two using the equality operator (==):
const matches = vector_haystack == vector_needles;
This line of code is powerful largely because it's simple. Vectors support the same operators as their element types (arithmetic, bitwise and comparison), applied element by element. Here, comparing two @Vector(vector_len, u8) values results in a new vector of type @Vector(vector_len, bool), and its content will be:
{ false, false, false, false, true, false, false, true }
If you look at this closely, you'll note that we're getting close to our answer. The first true occurs at index 4, which is the correct answer. We've compared our haystack with our needle and gotten the result for each index in parallel. But we still need to extract the index from the above. We could loop through matches, but then we'd be back at iterating and comparing one value at a time.
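Here's that comparison packaged as a hypothetical matchMask function, so the eight results can be checked directly. Converting the bool vector to an array is only for inspection; the SIMD work is the single == between two vectors.

```zig
const std = @import("std");

const vector_len = 8;

// Hypothetical helper: compares all 8 bytes against the needle at once and
// returns one bool per lane (true where the byte equals the needle).
fn matchMask(haystack: *const [vector_len]u8, needle: u8) [vector_len]bool {
    const vector_haystack: @Vector(vector_len, u8) = haystack.*;
    const vector_needles: @Vector(vector_len, u8) = @splat(needle);
    const matches = vector_haystack == vector_needles; // @Vector(8, bool)
    return matches; // coerces to [8]bool
}

test "matches for 'o' in Hello Jo" {
    const mask = matchMask("Hello Jo", 'o');
    const expected = [_]bool{ false, false, false, false, true, false, false, true };
    try std.testing.expectEqualSlices(bool, &expected, &mask);
}
```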
A quick solution to the above problem is to use std.simd.firstTrue, which will give us 4:
const vector_len = 8;
const vector_needles: @Vector(vector_len, u8) = @splat(@as(u8, 'o'));
const haystack = "Hello Jo";
const vector_haystack: @Vector(vector_len, u8) = haystack.*;
const matches = vector_haystack == vector_needles;
const index = std.simd.firstTrue(matches);
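The snippet above is a fragment; here it is wrapped in a hypothetical firstO function so it compiles and can be called. std.simd.firstTrue returns an optional index (null when no lane is true), which the function widens to ?usize.

```zig
const std = @import("std");

const vector_len = 8;

// Hypothetical wrapper around the snippet above, searching for 'o'.
fn firstO(haystack: *const [vector_len]u8) ?usize {
    const vector_needles: @Vector(vector_len, u8) = @splat(@as(u8, 'o'));
    const vector_haystack: @Vector(vector_len, u8) = haystack.*;
    const matches = vector_haystack == vector_needles;
    // firstTrue yields null when no lane is true; otherwise the lane index,
    // which we widen to usize.
    if (std.simd.firstTrue(matches)) |index| return index;
    return null;
}

test "firstO" {
    try std.testing.expectEqual(@as(?usize, 4), firstO("Hello Jo"));
    try std.testing.expectEqual(@as(?usize, null), firstO("abcdefgh"));
}
```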
But how does firstTrue work? Let's build our own simple implementation.
The first thing that we'll do is check if we have any matches. This is possibly a step that we can avoid, but it'll help us get the ball rolling:
if (!@reduce(.Or, matches)) {
    return null;
}
The @reduce builtin takes a vector, applies the operation across all of its elements, and returns a scalar (a single value). Here we're applying the std.builtin.ReduceOp.Or operation on our matches. If any of the values are true, this will return true. If all values are false, this will return false. The full set of operations is: .And, .Or, .Xor, .Min, .Max, .Add and .Mul (some operations are only available for some types of vectors; for example, .Add doesn't make sense for a vector of booleans).
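A few of these operations side by side, as small hypothetical helpers: .Or and .And applied to the bool vector from our comparison, and .Min and .Max applied to the haystack bytes themselves.

```zig
const std = @import("std");

const vector_len = 8;

// true if at least one lane is true
fn anyTrue(v: @Vector(vector_len, bool)) bool {
    return @reduce(.Or, v);
}

// true only if every lane is true
fn allTrue(v: @Vector(vector_len, bool)) bool {
    return @reduce(.And, v);
}

test "reducing Hello Jo" {
    const haystack: @Vector(vector_len, u8) = "Hello Jo".*;
    const o: @Vector(vector_len, u8) = @splat(@as(u8, 'o'));
    try std.testing.expect(anyTrue(haystack == o)); // two lanes match
    try std.testing.expect(!allTrue(haystack == o)); // but not all of them
    // .Min and .Max on integers: the smallest byte is ' ' (32), the largest 'o' (111)
    try std.testing.expectEqual(@as(u8, 32), @reduce(.Min, haystack));
    try std.testing.expectEqual(@as(u8, 111), @reduce(.Max, haystack));
}
```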
If we get past this check, we know that we have at least 1 match (i.e. a true) in our vector. How do we get its index? Admittedly, it's a bit more complicated than you might think. We'll need to use the @select builtin. @select is a bit like a parallel if statement; if we were to write it using normal code, it might look like:
fn select(comptime T: type, pred: [8]bool, a: [8]T, b: [8]T) [8]T {
    var out: [8]T = undefined;
    for (pred, 0..) |p, i| {
        out[i] = if (p) a[i] else b[i];
    }
    return out;
}
It always takes a vector of booleans, and based on the true/false values within this vector, it will select values from either input source a (when true) or input source b (when false).
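For comparison with the loop above, here is the builtin doing the same job on vectors (selectLanes is a hypothetical name). The first argument is the element type, followed by the bool predicate and the two sources.

```zig
const std = @import("std");

const vector_len = 8;

// Hypothetical helper: lane i of the result is a[i] where pred[i] is true,
// and b[i] where it is false. Returned as an array for easy inspection.
fn selectLanes(
    pred: @Vector(vector_len, bool),
    a: @Vector(vector_len, u8),
    b: @Vector(vector_len, u8),
) [vector_len]u8 {
    return @select(u8, pred, a, b);
}

test "select alternating lanes" {
    const pred: @Vector(vector_len, bool) = .{ true, false, true, false, true, false, true, false };
    const a: @Vector(vector_len, u8) = .{ 1, 2, 3, 4, 5, 6, 7, 8 };
    const b: @Vector(vector_len, u8) = @splat(@as(u8, 0));
    const out = selectLanes(pred, a, b);
    const expected = [_]u8{ 1, 0, 3, 0, 5, 0, 7, 0 };
    try std.testing.expectEqualSlices(u8, &expected, &out);
}
```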
How does @select help us? First we'll create two new vectors. These will act as our input source a and input source b. The first is a vector that contains the index of each index. Huh? Let's let the code explain:
const indexes = std.simd.iota(u8, vector_len);
If we print it out, we get:
{ 0, 1, 2, 3, 4, 5, 6, 7 }
It's ok if it isn't immediately obvious how that helps. Our next vector will be even weirder:
const nulls: @Vector(vector_len, u8) = @splat(@as(u8, 255));
We've seen @splat already, so we know the output to this is:
{ 255, 255, 255, 255, 255, 255, 255, 255 }
We now have 3 vectors, and we know a little bit about @select. Let's look at those 3 vectors side by side, and see if any pattern emerges:
{ false, false, false, false, true, false, false, true }
{ 0, 1, 2, 3, 4, 5, 6, 7 }
{ 255, 255, 255, 255, 255, 255, 255, 255 }
I still don't really see any pattern here. But what if we pass these into @select:
const result = @select(u8, matches, indexes, nulls);
This results in:
{ 255, 255, 255, 255, 4, 255, 255, 7 }
We'll revisit @select in a second, but I see something I want to keep exploring: the smallest value in this vector is what we're after. We want to extract the 4, and we already know that @reduce is our tool for turning a vector into a scalar. We also briefly mentioned that one of the operations is .Min. Putting that together, we end up with:
const index = @reduce(.Min, result);
Which returns 4, our final and correct answer.
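Putting the pieces together, a minimal sketch of our own firstTrue for 8-lane vectors (myFirstTrue is a hypothetical name, not the standard library's implementation). It follows the steps above: the early .Or check, iota for the indexes, 255 as the "no match" value, @select, then .Min.

```zig
const std = @import("std");

const vector_len = 8;

// Hypothetical hand-rolled firstTrue: keep the index where a lane is true,
// replace it with 255 where it is false, then take the smallest value.
fn myFirstTrue(matches: @Vector(vector_len, bool)) ?usize {
    if (!@reduce(.Or, matches)) {
        return null;
    }
    const indexes = std.simd.iota(u8, vector_len); // { 0, 1, ..., 7 }
    const nulls: @Vector(vector_len, u8) = @splat(@as(u8, 255));
    const result = @select(u8, matches, indexes, nulls);
    return @reduce(.Min, result);
}

test "first 'o' in Hello Jo" {
    const haystack: @Vector(vector_len, u8) = "Hello Jo".*;
    const o: @Vector(vector_len, u8) = @splat(@as(u8, 'o'));
    try std.testing.expectEqual(@as(?usize, 4), myFirstTrue(haystack == o));
}
```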
There are a few things we still want to explore, but first let's go back and look at @select. We had a vector of booleans and we wanted to get the first index which was true. The first thing we did was create another vector that just contains the indexes. On its own, that didn't help. What we needed was a vector with indexes where our match was true, and some invalid value where our match was false. Only this way could we use @reduce, as .Min would effectively "ignore" the invalid values from non-matches. For our invalid value, we used 255. We couldn't use null (vector elements must be integers, floats, booleans or pointers, so an optional isn't allowed) and we couldn't use a negative number since we relied on the .Min operator to extract our first (valid) index. Since we're limiting ourselves to a vector_len of 8, and our largest valid index is 7, we could have used any value of 8 or greater. If our vector was larger, say, 512, we could use 512 or any larger value (we'd have to change the type of our indexes and nulls vectors, and of our @select, from u8 to u16 or something wider, but it would all still work fine).
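To make the 512 case concrete, here's a sketch with the lane count as a comptime parameter and u16 in place of u8, as described above (firstTrueIndex is a hypothetical name). It uses len itself as the "no match" value, since that's one past the largest valid index.

```zig
const std = @import("std");

// Hypothetical generalisation of our firstTrue to any lane count up to 65535.
// u16 lanes can hold every index 0..len-1 plus the "no match" value len.
fn firstTrueIndex(comptime len: usize, matches: @Vector(len, bool)) ?usize {
    if (len > std.math.maxInt(u16)) @compileError("too many lanes for u16 indexes");
    if (!@reduce(.Or, matches)) return null;
    const indexes = std.simd.iota(u16, len);
    const nulls: @Vector(len, u16) = @splat(@as(u16, len));
    const result = @select(u16, matches, indexes, nulls);
    return @reduce(.Min, result);
}

test "512 lanes: an index that doesn't fit in a u8" {
    var lanes = [_]bool{false} ** 512;
    lanes[300] = true;
    lanes[400] = true;
    const v: @Vector(512, bool) = lanes;
    try std.testing.expectEqual(@as(?usize, 300), firstTrueIndex(512, v));
}
```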
@select gave us a vector that we could @reduce to get our desired result. But it wasn't immediately obvious, to me at least, how to get there. When I think about "finding the index of the first true", I think about iterating one value at a time and discarding any non-matches. This isn't a suitable mindset when working with vectors. We couldn't discard individual elements - we had to transform them in a way that would let us use a reduce operation.
Vector Length
There are two details we need to figure out. The first is: what happens when our input isn't the same size as our vector length? This is easy to solve: we process the input one vector_len-sized chunk at a time.
const std = @import("std");

fn firstIndexOf(haystack: []const u8, needle: u8) ?usize {
    const vector_len = 8;
    const vector_needles: @Vector(vector_len, u8) = @splat(@as(u8, needle));
    // { 0, 1, 2, ..., 7 }: the index of each lane
    const indexes = std.simd.iota(u8, vector_len);
    // the "no match" value; must be at least vector_len
    const nulls: @Vector(vector_len, u8) = @splat(@as(u8, 255));
    var pos: usize = 0;
    var left = haystack.len;
    while (left > 0) {
        if (left < vector_len) {
            // not enough input left to fill a vector: finish with a linear search
            return std.mem.indexOfScalarPos(u8, haystack, pos, needle);
        }
        // slicing to a comptime-known length gives an array we can load into a vector
        const h: @Vector(vector_len, u8) = haystack[pos..][0..vector_len].*;
        const matches = h == vector_needles;
        if (@reduce(.Or, matches)) {
            const result = @select(u8, matches, indexes, nulls);
            return @reduce(.Min, result) + pos;
        }
        pos += vector_len;
        left -= vector_len;
    }
    return null;
}
Every iteration of the loop examines vector_len characters at a time. If there's no match within one iteration, we move on to the next vector_len characters. It's possible that our input is shorter than vector_len, either originally or eventually as we iterate through it. When this happens, we fall back to the standard linear search (std.mem.indexOfScalarPos).
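Chunk boundaries and the fallback are where a function like this tends to go wrong, so it's worth testing it against std.mem.indexOfScalar for every short input length and every match position. A sketch of such a test follows; the 40-byte buffer size is arbitrary, and firstIndexOf is repeated from above so the block stands alone.

```zig
const std = @import("std");

fn firstIndexOf(haystack: []const u8, needle: u8) ?usize {
    const vector_len = 8;
    const vector_needles: @Vector(vector_len, u8) = @splat(needle);
    const indexes = std.simd.iota(u8, vector_len);
    const nulls: @Vector(vector_len, u8) = @splat(@as(u8, 255));
    var pos: usize = 0;
    var left = haystack.len;
    while (left > 0) {
        // fewer than vector_len bytes left: linear search for the tail
        if (left < vector_len) {
            return std.mem.indexOfScalarPos(u8, haystack, pos, needle);
        }
        const h: @Vector(vector_len, u8) = haystack[pos..][0..vector_len].*;
        const matches = h == vector_needles;
        if (@reduce(.Or, matches)) {
            const result = @select(u8, matches, indexes, nulls);
            return @reduce(.Min, result) + pos;
        }
        pos += vector_len;
        left -= vector_len;
    }
    return null;
}

test "firstIndexOf agrees with std.mem.indexOfScalar" {
    var buf: [40]u8 = undefined;
    var len: usize = 0;
    while (len <= buf.len) : (len += 1) {
        const input = buf[0..len];
        // No match anywhere in the input.
        @memset(input, 'a');
        try std.testing.expectEqual(std.mem.indexOfScalar(u8, input, 'o'), firstIndexOf(input, 'o'));
        // A single match at each possible position, including chunk edges.
        var at: usize = 0;
        while (at < len) : (at += 1) {
            @memset(input, 'a');
            input[at] = 'o';
            try std.testing.expectEqual(std.mem.indexOfScalar(u8, input, 'o'), firstIndexOf(input, 'o'));
        }
    }
}
```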
The second question we must answer is: what vector length should we use? So far, we've been using 8, but would larger vectors run faster? One option is to use std.simd.suggestVectorSize, which returns a ?usize. This function attempts to return the maximum supported vector length of the system for a given element type, or null if SIMD operations aren't supported. But using the largest possible vector length won't always be ideal. You need to consider your input length. Using a vector length of 512 with inputs that are small (say 128-2048) will result in the fallback to std.mem.indexOfScalarPos being executed often: an input shorter than 512 bytes never uses SIMD at all, and up to 511 trailing bytes of any input are searched linearly. One solution to this problem is to use different vector lengths based on the input length. On the flip side, if you're dealing with large inputs (megabytes and beyond), a large vector size would be ideal as you'll spend comparatively little time dealing with small parts of the input.
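One way to try different widths is to make the vector length a comptime parameter and pick a width from the input length. The sketch below assumes this shape: firstIndexOfN, indexOfDispatch, the 1024-byte threshold and the widths 8 and 32 are all hypothetical placeholders to be replaced by benchmark results on your own hardware. It also swaps the u8 indexes for u16 so widths above 255 work.

```zig
const std = @import("std");

// Hypothetical: firstIndexOf with the vector length as a comptime parameter.
// u16 lane indexes allow any vector_len up to 65535.
fn firstIndexOfN(comptime vector_len: usize, haystack: []const u8, needle: u8) ?usize {
    if (vector_len > std.math.maxInt(u16)) @compileError("vector_len too large for u16 indexes");
    const V = @Vector(vector_len, u8);
    const vector_needles: V = @splat(needle);
    const indexes = std.simd.iota(u16, vector_len);
    // vector_len is one past the largest lane index, so it never wins .Min
    const nulls: @Vector(vector_len, u16) = @splat(@as(u16, vector_len));
    var pos: usize = 0;
    while (haystack.len - pos >= vector_len) : (pos += vector_len) {
        const h: V = haystack[pos..][0..vector_len].*;
        const matches = h == vector_needles;
        if (@reduce(.Or, matches)) {
            const result = @select(u16, matches, indexes, nulls);
            return @as(usize, @reduce(.Min, result)) + pos;
        }
    }
    // Fewer than vector_len bytes remain: finish with the scalar search.
    return std.mem.indexOfScalarPos(u8, haystack, pos, needle);
}

// Hypothetical dispatcher: threshold and widths are placeholders, not tuned values.
fn indexOfDispatch(haystack: []const u8, needle: u8) ?usize {
    if (haystack.len >= 1024) return firstIndexOfN(32, haystack, needle);
    return firstIndexOfN(8, haystack, needle);
}
```

Each width used is compiled as its own instantiation of firstIndexOfN, so adding a width costs code size but no runtime branching beyond the dispatcher's length check.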
Conclusion
SIMD requires a different way to think about variables and how they are processed. It's also an optimization technique that's largely up to the developer; you can't count on the compiler or runtime to apply it for you. You need to benchmark, test and tweak in order to figure out what works best. Benchmarking is particularly important because, unless you're dealing with large data or very hot code, there's a good chance that effort won't yield measurable benefits.
Having said that, Zig exposes a pretty concise API: a few builtins (@splat, @select, @Vector and @reduce) along with functions in std.simd. Once you understand these functions and how they can work together, experimentation is straightforward.