Using SIMD to Tell if the High Bit is Set
Jan 21, 2025
One of the first Zig-related blog posts I wrote was an overview of SIMD with Zig. I recently needed to revisit this topic when enhancing my smtp client library. Specifically, SMTP mostly expects printable ASCII characters. Almost all other characters, including UTF-8 text, must be encoded.
I found the various SMTP and MIME RFCs confusing, so I settled on a simple approach: if the high bit of a character is set, I'll base64 encode the value. The simple way to detect this is:
fn isHighBitSet(input: []const u8) bool {
    for (input) |c| {
        if (c > 127) {
            return true;
        }
    }
    return false;
}
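As a quick sanity check of this scalar version, here is a small sketch of a test. The function is repeated so the file runs on its own with `zig test`, and the sample strings are made up for illustration; they aren't from the library.

```zig
const std = @import("std");

// Repeated from above so this file can be run on its own.
fn isHighBitSet(input: []const u8) bool {
    for (input) |c| {
        if (c > 127) return true;
    }
    return false;
}

test "scalar high-bit check" {
    // Plain ASCII: every byte is <= 127.
    try std.testing.expect(!isHighBitSet("hello world"));
    // "é" is encoded in UTF-8 as 0xC3 0xA9, and both bytes are above 127.
    try std.testing.expect(isHighBitSet("h\xc3\xa9llo"));
    // An empty input has no bytes to flag.
    try std.testing.expect(!isHighBitSet(""));
}
```

This also gives us a baseline: whatever the SIMD version does, it should return the same answers as this loop.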
But with SIMD, we can do this check on multiple bytes at a time. The first thing we have to do is get the ideal size for SIMD operations for our CPU:
if (comptime std.simd.suggestVectorLength(u8)) |vector_len| {
}
As you can tell, suggestVectorLength returns an optional value: some platforms don't support SIMD. On my computer, this returns 16, which means that I can process 16 bytes at a time. We can extend our skeleton:
if (comptime std.simd.suggestVectorLength(u8)) |vector_len| {
    var remaining = input;
    while (remaining.len > vector_len) {
        // chunk is inspected in the next step
        const chunk: @Vector(vector_len, u8) = remaining[0..vector_len].*;
        remaining = remaining[vector_len..];
    }
}
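Above we're breaking our input into chunks of vector_len bytes. The @Vector builtin returns a type (in Zig, by convention, upper-case functions return types). Because vector_len is known at compile time, remaining[0..vector_len] is a pointer to a fixed-size array; dereferencing it with .* gives us the array, which coerces to our vector type. To see if our chunk has a high bit set, we use @reduce: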
if (@reduce(.Max, chunk) > 127) {
    return true;
}
Like @Vector, @reduce is one of a handful of SIMD-specific builtins. Its job is to take a std.builtin.ReduceOp and a vector input (.Max and chunk) and return a scalar value. The possible operations depend on the vector's element type. For example, std.builtin.ReduceOp.And is valid for vectors of integers or booleans, but not for vectors of floats. With .Max, we're asking @reduce to return the highest value in the provided vector.
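To make @reduce a little more concrete, here is a small standalone sketch. The helper name anyHighBit, the fixed vector length of 4 and the sample values are all made up for illustration; the real code uses whatever length suggestVectorLength returns.

```zig
const std = @import("std");

// Hypothetical helper for illustration: true if any lane has its high bit set.
fn anyHighBit(chunk: @Vector(4, u8)) bool {
    // .Max folds all four lanes into the single largest u8.
    return @reduce(.Max, chunk) > 127;
}

test "reduce with .Max" {
    const ascii = @Vector(4, u8){ 'a', 'b', 'c', 'd' };
    // 'a', then "é" as its two UTF-8 bytes, then 'd'.
    const mixed = @Vector(4, u8){ 'a', 0xC3, 0xA9, 'd' };

    // The largest lane in the ASCII vector is 'd'.
    try std.testing.expectEqual(@as(u8, 'd'), @reduce(.Max, ascii));
    try std.testing.expect(!anyHighBit(ascii));
    try std.testing.expect(anyHighBit(mixed));
}
```

Since a u8 with its high bit set is always 128 or more, checking the maximum against 127 is enough: if the largest byte in the chunk is fine, every byte is.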
As I said, on my computer, the code will process 16 bytes of data at a time, but our input might not be perfectly divisible by 16. Our while loop exits once remaining.len is no longer greater than vector_len; if our input was 35 bytes long, we'd process 2 chunks (2 * 16) and be left with 3 bytes. These last 3 bytes still need to be checked:
const std = @import("std");

fn isHighBitSet(input: []const u8) bool {
    var remaining = input;
    if (comptime std.simd.suggestVectorLength(u8)) |vector_len| {
        while (remaining.len > vector_len) {
            const chunk: @Vector(vector_len, u8) = remaining[0..vector_len].*;
            if (@reduce(.Max, chunk) > 127) {
                return true;
            }
            remaining = remaining[vector_len..];
        }
    }
    // Whatever the vector loop didn't consume is checked byte by byte.
    for (remaining) |c| {
        if (c > 127) {
            return true;
        }
    }
    return false;
}
We've made another subtle change: we moved remaining to the outer scope. Our code not only handles inputs whose length isn't perfectly divisible by vector_len, it also handles cases where suggestVectorLength returns null.
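The chunk arithmetic can be checked with a small sketch. Split and splitLen are hypothetical names used only for this illustration; the helper mirrors the loop condition used above and reports how many vector chunks are processed and how many bytes are left for the scalar loop.

```zig
const std = @import("std");

// Hypothetical type and helper, for illustration only.
const Split = struct { chunks: usize, tail: usize };

fn splitLen(len: usize, vector_len: usize) Split {
    var remaining = len;
    var chunks: usize = 0;
    // Same condition as the while loop in isHighBitSet.
    while (remaining > vector_len) {
        remaining -= vector_len;
        chunks += 1;
    }
    return .{ .chunks = chunks, .tail = remaining };
}

test "how an input is split" {
    // 35 bytes: two vector chunks, three bytes for the scalar loop.
    try std.testing.expectEqual(Split{ .chunks = 2, .tail = 3 }, splitLen(35, 16));
    // 32 bytes: because the condition is `>`, the last 16 bytes go to the scalar loop.
    try std.testing.expectEqual(Split{ .chunks = 1, .tail = 16 }, splitLen(32, 16));
    // Shorter than one vector: everything is handled by the scalar loop.
    try std.testing.expectEqual(Split{ .chunks = 0, .tail = 10 }, splitLen(10, 16));
}
```

Changing the condition to >= would let the vector loop take that final full chunk as well. Either way the function returns the same result, since the scalar loop picks up whatever is left.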
Finally, more recent versions of Zig have introduced different backends. Not all of these necessarily support SIMD operations. So, for completeness, we need one more check:
const std = @import("std");

const backend_supports_vectors = switch (@import("builtin").zig_backend) {
    .stage2_llvm, .stage2_c => true,
    else => false,
};

fn isHighBitSet(input: []const u8) bool {
    var remaining = input;
    if (comptime backend_supports_vectors) {
        if (comptime std.simd.suggestVectorLength(u8)) |vector_len| {
            while (remaining.len > vector_len) {
                const chunk: @Vector(vector_len, u8) = remaining[0..vector_len].*;
                if (@reduce(.Max, chunk) > 127) {
                    return true;
                }
                remaining = remaining[vector_len..];
            }
        }
    }
    // Whatever the vector loop didn't consume is checked byte by byte.
    for (remaining) |c| {
        if (c > 127) {
            return true;
        }
    }
    return false;
}
Hopefully this is something that will get cleaned up, but note that it's only really necessary if you're going to use a different backend (I'm using it because the code is in a library, and I don't know how users of my library will compile it).
For very short strings, the SIMD version is the same as the linear version, so there's no performance difference. But for a string of "a" ** 100 (which requires scanning the entire string), the SIMD version is ~2x faster and for a string of "a" ** 1000, it's ~8x faster.