TCP Server in Zig - Part 3 - Minimizing Writes & Reads

Oct 08, 2024

Before we look at making our server multi-threaded, and then move to polling, there are two optimization techniques worth exploring. You might think that we should finalize our code before applying optimizations, but I think optimizations in general can teach us things to look out for / consider, and it's particularly true in both of these cases.

In the previous parts, we made use of numerous system calls to set up our server and then communicate with the client. System calls (aka syscalls) are how our program asks the operating system to do something, like writing bytes to a socket. There's overhead to making syscalls, so it's something we want to keep an eye on and, if possible, minimize. This overhead is small (100-200ns plus some thrashing of various caches, from what I could find), so it isn't a concern for infrequent calls like the ones we used to set up our server - socket, setsockopt, bind and listen. read and write are a different story though: they're often in a loop and, for a server, the number of calls to read and write will grow with the number of active connections.

In Part 2, when we added a header to our messages, we made two calls to writeAll, one to write the header and the other to write the message:

fn writeMessage(socket: posix.socket_t, msg: []const u8) !void {
    var buf: [4]u8 = undefined;
    std.mem.writeInt(u32, &buf, @intCast(msg.len), .little);
    try writeAll(socket, &buf);
    try writeAll(socket, msg);
}

Our writeAll function calls posix.write until the buffer is empty:

fn writeAll(socket: posix.socket_t, msg: []const u8) !void {
    var pos: usize = 0;
    while (pos < msg.len) {
        const written = try posix.write(socket, msg[pos..]);
        if (written == 0) {
            return error.Closed;
        }
        pos += written;
    }
}

The loop is necessary because posix.write can write anywhere from 1 to msg.len bytes (or zero if the connection is closed). However, under normal conditions, where the connection is stable and the sender isn't overwhelming the receiver and where our messages fit in the OS' send buffer, write will often complete in a single operation. So, while we absolutely do need the loop, we can also say that, in general, our writeMessage function will result in two calls to posix.write, which is a wrapper around the write(2) system call.

There's one obvious way to reduce our two calls to writeAll, which would hopefully reduce the number of calls to write: we can concatenate our prefix and message:

fn writeMessage(allocator: Allocator, socket: posix.socket_t, msg: []const u8) !void {
    // const, not var: we write through the slice but never reassign it
    const buf = try allocator.alloc(u8, 4 + msg.len);
    defer allocator.free(buf);

    std.mem.writeInt(u32, buf[0..4], @intCast(msg.len), .little);
    @memcpy(buf[4..], msg);
    try writeAll(socket, buf);
}

But this also has considerable overhead: we need to allocate a larger buffer and copy the message. Since we're looking to optimize this code because it is presumably in a hot path, can we do better?

Ideally, we'd like to make a single system call with our two distinct buffers (the header prefix and the message). This is exactly what writev does - the "v" stands for vector. It's part of a family of operations known as vectored I/O or scatter/gather I/O (because we're gathering data from multiple buffers). To leverage writev, we can rewrite our writeMessage and writeAll functions like so:

fn writeMessage(socket: posix.socket_t, msg: []const u8) !void {
    var buf: [4]u8 = undefined;
    std.mem.writeInt(u32, &buf, @intCast(msg.len), .little);

    var vec = [2]posix.iovec_const{
        .{ .len = 4, .base = &buf },
        .{ .len = msg.len, .base = msg.ptr },
    };

    try writeAllVectored(socket, &vec);
}

fn writeAllVectored(socket: posix.socket_t, vec: []posix.iovec_const) !void {
    var i: usize = 0;
    while (true) {
        var n = try posix.writev(socket, vec[i..]);
        while (n >= vec[i].len) {
            n -= vec[i].len;
            i += 1;
            if (i >= vec.len) return;
        }
        vec[i].base += n;
        vec[i].len -= n;
    }
}

Rather than having to concatenate the two buffers together, we can create an iovec_const which references each buffer. Because the Zig structure closely copies the C structure, we can't use our slices and must instead provide the pointer to the data and the length of the data (ideally, when Zig development starts to focus on the standard library, these types of things will get cleaned up). Our call to posix.write has been replaced with a call to posix.writev, but this function isn't magical; we still need to loop until all bytes are written. The arithmetic for a partial write is a bit more complicated, since we need to advance across two values: the number of vectors we have and the number of bytes in each vector. It's code we only have to write once though.
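
To make the partial-write arithmetic easier to follow, here's a minimal sketch that pulls it into a stand-alone function so it can be exercised without a socket. The advance helper is hypothetical (it isn't part of the server code) and it assumes the same posix.iovec_const fields, base and len, used above.

```zig
const std = @import("std");
const posix = std.posix;

// Hypothetical helper, for illustration only. Consumes `written` bytes from
// `vec`, starting at index `i`. Returns the index of the first vector that
// still has unwritten bytes, or null once everything has been written.
fn advance(vec: []posix.iovec_const, i: usize, written: usize) ?usize {
    var idx = i;
    var n = written;

    // Skip every vector that was written in full
    while (idx < vec.len and n >= vec[idx].len) {
        n -= vec[idx].len;
        idx += 1;
    }
    if (idx == vec.len) return null;

    // The vector at idx was only partially written: move its start forward
    // and shrink its length by the same amount.
    vec[idx].base += n;
    vec[idx].len -= n;
    return idx;
}
```

With a 4-byte header followed by "Hello", if writev reports 6 bytes written, advance skips the header entirely and returns index 1, leaving that vector pointing at the 3 remaining bytes, "llo". A report of 9 bytes returns null, since nothing is left to write.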

What was previously always at least two syscalls is now potentially (and likely) one. In some cases, that won't really matter, but keep two things in mind. First, we're specifically talking about optimizing code known to be in a critical path. Second, writev isn't limited to two buffers.

There's also a readv function which does the opposite: reading data into multiple buffers. I've personally never found a use for it. I suspect that it's a combination of being generally less useful than writev, lack of imagination on my part, and the type of systems I've worked on.

At the other end of our communication, readMessage requires a minimum of two calls to read: at least one call to fill our 4-byte header buffer, and then at least one more to fill our message buffer. Reading from a socket is generally more chaotic than writing to it. Reading is the end of our data's journey and can be delayed because of network latency or dropped packets. Furthermore, clients can be buggy or even malicious. All this to say that there's no "best" way to handle incoming data.

Despite this, one common approach is to have a buffer-per-connection and to read as much data as possible. This has the same worst case behavior as doing multiple explicit reads, but the best case behavior isn't just that we read both the prefix and the message in a single read, but that we also read all of, or at least part of, any subsequent messages.
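
To make that best case concrete, here's a small sketch of what a single read might hand us: one complete message followed by part of the next one. The firstMessage helper is hypothetical and only illustrates the length-prefix check; it works on bytes already in memory and doesn't touch a socket.

```zig
const std = @import("std");

// Hypothetical helper, for illustration only. Given the bytes received so
// far, returns the first complete message (without its 4-byte length
// prefix), or null if we need more data.
fn firstMessage(data: []const u8) ?[]const u8 {
    // Not enough bytes to even know the length
    if (data.len < 4) return null;

    const len: usize = std.mem.readInt(u32, data[0..4], .little);

    // We know the length, but the message body isn't complete yet
    if (data.len - 4 < len) return null;

    return data[4 .. 4 + len];
}
```

Given the 12 bytes 5,0,0,0,'H','e','l','l','o',6,0,0, firstMessage returns "Hello". Calling it again on the 3 leftover bytes returns null, because the next length prefix is incomplete. Those leftover bytes are what we need to hold on to until the next read.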

For this to work, we need to maintain state between calls in order to handle any extra bytes a previous read got. We need to know where in our buffer new data needs to be written to and where in our buffer the next message starts. To this end, I like to build a Reader to encapsulate this:

const Reader = struct {
    // This is what we'll read into and where we'll look for a complete message
    buf: []u8,

    // This is where in buf we've read up to; any subsequent reads need
    // to start from here
    pos: usize = 0,

    // This is where our next message starts at
    start: usize = 0,

    // The socket to read from
    socket: posix.socket_t,

    fn readMessage(self: *Reader) ![]u8 {
        const buf = self.buf;

        // loop until we've read a message, or the connection was closed
        while (true) {

            // Check if we already have a message in our buffer
            if (try self.bufferedMessage()) |msg| {
                return msg;
            }

            // read data from the socket, we need to read this into buf from
            // the end of where we have data (aka, self.pos)
            const pos = self.pos;
            const n = try posix.read(self.socket, buf[pos..]);
            if (n == 0) {
                return error.Closed;
            }

            self.pos = pos + n;
        }
    }

    // Checks if there's a full message in self.buf already.
    // If there isn't, checks that we have enough spare space in self.buf for
    // the next message.
    fn bufferedMessage(self: *Reader) !?[]u8 {
        const buf = self.buf;
        // position up to where we have valid data
        const pos = self.pos;

        // position where the next message starts
        const start = self.start;

        // pos - start represents bytes that we've read from the socket
        // but that we haven't yet returned as a "message" - possibly because
        // it's incomplete.
        std.debug.assert(pos >= start);
        const unprocessed = buf[start..pos];

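        if (unprocessed.len < 4) {
            // We always need at least 4 bytes of data (the length prefix).
            // ensureSpace measures spare space from self.start, so we ask for
            // all 4 bytes, not just the ones we're still missing. This can
            // only fail if buf itself is smaller than 4 bytes.
            self.ensureSpace(4) catch unreachable;
            return null;
        }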

        // The length of the message
        const message_len = std.mem.readInt(u32, unprocessed[0..4], .little);

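        // the length of our message + the length of our prefix; widened to
        // usize first so that, on 64-bit targets, a huge (possibly malicious)
        // length can't overflow the addition
        const total_len = @as(usize, message_len) + 4;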

        if (unprocessed.len < total_len) {
            // We know the length of the message, but we don't have all the
            // bytes yet.
            try self.ensureSpace(total_len);
            return null;
        }

        // Position start at the start of the next message. We might not have
        // any data for this next message, but we know that it'll start where
        // our last message ended.
        self.start += total_len;
        return unprocessed[4..total_len];
    }

    // We want to make sure we have enough spare space in our buffer. This can
    // mean two things:
    //   1 - If we know the length of the next message, we need to make sure
    //       that our buffer is large enough for that message. If our buffer
    //       isn't large enough, we return an error (as an alternative, we could
    //       do something else, like dynamically allocate memory or pull a large
    //       buffer from a buffer pool).
    //   2 - At any point that we need to read more data, we need to make sure
    //       that our "spare" space (self.buf.len - self.start) is large enough
    //       for the required data. If it isn't, we need to shift our buffer
    //       around and move whatever unprocessed data we have back to the start.
    fn ensureSpace(self: *Reader, space: usize) error{BufferTooSmall}!void {
        const buf = self.buf;
        if (buf.len < space) {
            // Even if we compacted our buffer (moving any unprocessed data back
            // to the start), we wouldn't have enough space for this message in
            // our buffer. Alternatively: dynamically allocate or pull a large
            // buffer from a buffer pool.
            return error.BufferTooSmall;
        }

        const start = self.start;
        const spare = buf.len - start;
        if (spare >= space) {
            // We have enough spare space in our buffer, nothing to do.
            return;
        }

        // At this point, we know that our buffer is large enough for the data
        // we want to read, but we don't have enough spare space. We need to
        // "compact" our buffer, moving any unprocessed data back to the start
        // of the buffer.
        const unprocessed = buf[start..self.pos];
        std.mem.copyForwards(u8, buf[0..unprocessed.len], unprocessed);
        self.start = 0;
        self.pos = unprocessed.len;
    }
};

The trick to the above code is to "compact" the buffer when we have no more space for the next message, moving whatever data we've read (but haven't processed) back to the start. For example, imagine that after a single read our buffer looks like:

5,0,0,0,'H','e','l','l','o', 6,0,0,_,_,_,_,_,_,_,_,_,_

readMessage would return "Hello", and the state of our Reader would look like:

5,0,0,0,'H','e','l','l','o', 6,0,0,_,_,_,_,_,_,_,_,_,_
                        start^     ^pos

For our next read, we only have 3 unprocessed bytes (pos - start, or {6, 0, 0}). In this case, we need at least 1 more byte in order to get the next message's length. So we make sure that we have enough room in our buffer for that additional byte, which we do - in fact we have enough room for 10 more bytes. We return null from bufferedMessage and read more data from the socket, hoping that we won't just read the missing length-byte, but the whole next message too. After our next readMessage, our reader would look like:

5,0,0,0,'H','e','l','l','o', 6,0,0,0,'W','o','r','l','d','!',9,0,_
                                                        start^   ^pos

The call to readMessage will return the next message "World!". But notice that we no longer have enough space in our buffer for the next message - we're missing two bytes for our length, but only have 1 spare byte. In this case, the next time we call readMessage, the inner call to bufferedMessage will result in our buffer being compacted:

⌄start
9,0,0,0,'H','e','l','l','o', 6,0,0,0,'W','o','r','l','d','!',9,0,_
    ^pos

For the sake of accuracy, our buf is still filled with data from the previous reads, which is why you still see the "Hello" and "World!" messages, but we've copied the unprocessed data (just the 9, 0 in this case) to the start and adjusted start and pos. After this compaction, our buffer has enough space to read our header prefix.
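
The compaction step can also be replayed in isolation. This is only a sketch: Window and compact are hypothetical names that mirror the start and pos fields and the tail end of ensureSpace, pulled out so the example can run on a plain array.

```zig
const std = @import("std");

// Hypothetical stand-in for the Reader's start and pos fields
const Window = struct { start: usize, pos: usize };

// Moves the unprocessed bytes (buf[start..pos]) to the front of buf and
// returns the adjusted positions. The rest of buf is left untouched.
fn compact(buf: []u8, w: Window) Window {
    const unprocessed = buf[w.start..w.pos];
    // copyForwards is safe here even if the two ranges overlap, since the
    // destination always starts at or before the source
    std.mem.copyForwards(u8, buf[0..unprocessed.len], unprocessed);
    return .{ .start = 0, .pos = unprocessed.len };
}
```

Running compact on the 22-byte buffer from the example, with start = 19 and pos = 21, gives back start = 0 and pos = 2. The buffer then begins with 9,0,0,0,'H': the first two bytes are the copied prefix bytes, and everything after them is stale data from the earlier reads.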

To keep this implementation relatively simple, if a message is too big to fit in our buffer, we return an error. That might be reasonable if you know the maximum possible message length and you're willing to allocate that much space per-connection. As an alternative, you could opt to dynamically allocate larger buffers as needed or have a large buffer pool.
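
As a rough sketch of the dynamic allocation alternative, something like the following could replace the error.BufferTooSmall case. Everything here is an assumption on my part: growBuffer, the max parameter and error.MessageTooLarge are hypothetical, and the sketch assumes buf was originally allocated with the same allocator.

```zig
const std = @import("std");

// Hypothetical helper, for illustration only. Returns a buffer with room for
// at least `needed` bytes, growing `buf` if necessary. `max` caps how much
// memory a single connection is allowed to take.
fn growBuffer(allocator: std.mem.Allocator, buf: []u8, needed: usize, max: usize) ![]u8 {
    if (needed > max) return error.MessageTooLarge;

    // Already large enough, nothing to do
    if (buf.len >= needed) return buf;

    // realloc keeps the existing bytes, so any unprocessed data survives,
    // but the buffer may move: the caller must store the returned slice.
    return allocator.realloc(buf, needed);
}
```

The cap matters because the length prefix comes from the client; without it, a single 4-byte header could ask us to allocate up to 4GB.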

Also, this stateful approach demands a reader per connection and is therefore more resource intensive. If you have many clients, but relatively little traffic, it might not be efficient to give each connection its own buffer - if you have 100K clients each with a 4K buffer, that's 400MB of memory.

As a final precaution, in an attempt to minimize the possibility of a Denial of Service attack, many services minimize the amount of resources they'll allocate until a client can be authenticated. When a new connection is established, you might want to allocate a much smaller buffer, or use a special buffer pool, to read an initial "authentication" message. Once the legitimacy of the client is verified, you can assign a full Reader with a dedicated buffer. Of course, this approach isn't applicable to all cases. But the internet is a hostile place; you should always be mindful of the resources you're willing to commit just because someone has established a TCP connection to your server.

System calls aren't evil and they aren't even particularly slow; they're a fundamental part of network programming. But we can still be mindful about their usage. In some cases, it's a matter of knowing and using the right system call for the job, as with writev vs write. In other cases, it's about organizing our code differently, as with our Reader.

As our implementation progresses, we'll start to rely on more advanced operating system features, such as Linux's epoll and BSD's kqueue. To use these, we'll introduce new system calls, some of which could end up in your critical path. In fact, kqueue and epoll differ with respect to the number of system calls required - with kqueue allowing for designs that require fewer system calls. Hopefully, given what we've explored above, you'll be in a better position to notice and consider these types of differences.