TCP Server in Zig - Part 5b - Poll
Oct 17, 2024
In the previous part we introduced non-blocking sockets and used them, along with the poll system call, to maximize the efficiency of our server. Rather than having a thread per connection, each waiting on data, a single thread can now manage multiple client connections. But this performance leap doesn't come for free: our code has gotten more complex.
At a high level, I think the idea behind evented I/O is straightforward. We ask the operating system to monitor a list of sockets, and it notifies us when those sockets are ready. We'll soon switch to alternatives to poll which perform better, have nicer APIs and are more powerful (at the cost of more complexity), but despite poll's awkwardness we've managed to create a relatively clean TCP server implementation.
Of course, we're also missing a number of important features. Our last working example crashes if too many clients try to connect, it doesn't write to the client and it doesn't implement any timeouts. We still have a lot of work to do if we want something approaching production-ready.
The most glaring bug is that our server crashes if too many simultaneous clients try to connect. This happens because, on a new connection, we don't do any bounds checking on the clients and client_polls slices. Exactly how we fix this will be app-dependent. Maybe you want to allow new connections by disconnecting old ones. Maybe you want to disallow the connection but write an error message. We'll do something simpler: only accept a connection if we have an available slot.
One way to stop accepting connections is to remove the listener's pollfd from our polls slice. But since we'll want to re-enable this notification once free slots become available, a better way is to modify the existing entry:
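fn accept(self: *Server, listener: posix.socket_t) !void {
    // never accept more connections than we have free slots
    const available = self.client_polls.len - self.connected;
    for (0..available) |_| {
        // ... existing accept logic, which returns on error.WouldBlock ...
    } else {
        // the loop ran to completion: we're full, stop monitoring the listener
        self.polls[0].events = 0;
    }
}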
In addition to the original code which accepts connections until accept returns error.WouldBlock, we now also limit the number of connections to the available space we have. Once we've reached the available space (the else on a for only executes when the loop runs to completion, not on a break or return), we disable the notification on our listening socket. Here we disable it by setting events to 0. Another option is to negate the file descriptor, which poll would then ignore. This has the benefit of preserving the events value for when we re-enable it by negating it back to a positive value.
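To make the negation trick concrete, here is a minimal sketch. It assumes the same std.posix API used by the listings in this post; the toggle and readyCount helpers are hypothetical names, and a pipe with one unread byte stands in for a listener with a pending connection.

```zig
const std = @import("std");
const posix = std.posix;

/// Flips a pollfd between monitored and ignored by negating its fd.
/// poll skips negative descriptors, so `events` is left untouched.
/// Caveat: this cannot work for descriptor 0, since -0 == 0.
fn toggle(pfd: *posix.pollfd) void {
    pfd.fd = -pfd.fd;
}

/// Polls without blocking (timeout of 0) and returns how many entries are ready.
fn readyCount(fds: []posix.pollfd) !usize {
    return posix.poll(fds, 0);
}

pub fn main() !void {
    const pipe = try posix.pipe();
    defer posix.close(pipe[0]);
    defer posix.close(pipe[1]);
    _ = try posix.write(pipe[1], "x");

    var fds = [_]posix.pollfd{.{ .fd = pipe[0], .events = posix.POLL.IN, .revents = 0 }};
    std.debug.print("monitored: {d} ready\n", .{try readyCount(&fds)});

    toggle(&fds[0]); // disable, events is preserved
    std.debug.print("ignored:   {d} ready\n", .{try readyCount(&fds)});

    toggle(&fds[0]); // re-enable, nothing to restore
    std.debug.print("restored:  {d} ready\n", .{try readyCount(&fds)});
}
```

Compared to setting events to 0, the re-enable step doesn't need to know which events were being monitored.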
All that's left to do is re-enable the monitor when space becomes available:
fn removeClient(self: *Server, at: usize) void {
    // ... existing removal logic ...

    // a slot just opened up, resume monitoring the listener
    self.polls[0].events = posix.POLL.IN;
}
Just like that, we've fixed a bad bug. More importantly, we've seen how we can modify a pollfd value. As above, this can be used to disable and enable monitors, but it can also be used to alternate between monitoring for read-readiness and write-readiness.
When a client connects we monitor it by adding a new pollfd entry to client_polls and, through the events field, we register our interest in posix.POLL.IN. What would happen if we also registered our interest in posix.POLL.OUT:
self.client_polls[connected] = .{
    .fd = socket,
    .revents = 0,
    .events = posix.POLL.IN | posix.POLL.OUT,
};
This is almost certainly not what you want to do because the socket is almost always ready to be written to. If you make this change, run the code and connect a client, poll will constantly fire because revents & posix.POLL.OUT == posix.POLL.OUT will be true. What we need to do is register our interest in this event only when we have something to write.
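A small sketch of why an unconditional POLL.OUT is so noisy. It assumes the same std.posix API as the rest of the post; writableNow is a hypothetical helper, and the write end of an empty pipe stands in for a socket whose send buffer has room.

```zig
const std = @import("std");
const posix = std.posix;

/// Returns true if `fd` is reported as writable right now (poll timeout of 0).
fn writableNow(fd: posix.fd_t) !bool {
    var fds = [_]posix.pollfd{.{ .fd = fd, .events = posix.POLL.OUT, .revents = 0 }};
    _ = try posix.poll(&fds, 0);
    return fds[0].revents & posix.POLL.OUT == posix.POLL.OUT;
}

pub fn main() !void {
    const pipe = try posix.pipe();
    defer posix.close(pipe[0]);
    defer posix.close(pipe[1]);

    // We have nothing to send, yet poll still reports write-readiness.
    std.debug.print("writable with nothing to send: {}\n", .{try writableNow(pipe[1])});
}
```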
As we've said before, as common as it is for read to return a number of bytes completely unrelated to any application-specific message boundaries, it's equally common for write to succeed in a single call. Despite this, correctly writing using non-blocking I/O is much more nuanced than reading. I've seen many simple implementations assume that write can't fail and never does a partial write. This is a dangerous assumption to make.
One challenge we have is that different applications have different requirements. Are multiple threads allowed to write to a socket? Are we implementing a strict request -> response flow or can there be multiple incoming messages? Just like we potentially have to poll and read multiple times before getting a whole message, so too might we have to poll and write multiple times before sending a whole message - this means that the lifetime of those bytes might have to outlive an application's call to writeMessage. Furthermore, in Part 3, we introduced vectored I/O (writev) as an optimization, but now that we have to make our write stateful, it's another small complexity to worry about.
To make our life easier, we're going to assume we want to implement a request -> response flow and we'll revert to using write rather than writev. Because write might return error.WouldBlock before writing our whole message, we need to add more state to our Client:
const Client = struct {
    reader: Reader,
    socket: posix.socket_t,
    address: std.net.Address,
    to_write: []u8,
    write_buf: []u8,

    fn init(allocator: Allocator, socket: posix.socket_t, address: std.net.Address) !Client {
        const reader = try Reader.init(allocator, 4096);
        errdefer reader.deinit(allocator);

        const write_buf = try allocator.alloc(u8, 4096);
        errdefer allocator.free(write_buf);

        return .{
            .reader = reader,
            .socket = socket,
            .address = address,
            .to_write = &.{},
            .write_buf = write_buf,
        };
    }

    fn deinit(self: *const Client, allocator: Allocator) void {
        self.reader.deinit(allocator);
        allocator.free(self.write_buf);
    }
    ...
};
We've added and initialized a write_buf field. When writeMessage is called, we'll copy the bytes (along with the length prefix) here. to_write is a slice of write_buf which represents the bytes we still need to write. You'll often see implementations add a mode to reflect whether the client is currently reading or writing data. For now, we'll just use to_write - when to_write.len == 0, it means we're reading (or waiting for) a message.
writeMessage must change to copy the message into write_buf:
fn writeMessage(self: *Client, msg: []const u8) !bool {
    if (self.to_write.len > 0) {
        // the previous message hasn't been fully written yet
        return error.PendingMessage;
    }

    if (msg.len + 4 > self.write_buf.len) {
        return error.MessageTooLarge;
    }

    // copy the 4-byte length prefix and the message into our own buffer, so
    // the bytes outlive this call if the write is only partial
    std.mem.writeInt(u32, self.write_buf[0..4], @intCast(msg.len), .little);
    const end = msg.len + 4;
    @memcpy(self.write_buf[4..end], msg);

    self.to_write = self.write_buf[0..end];
    return self.write();
}

// Returns true when everything was written, false when the socket would block.
fn write(self: *Client) !bool {
    var buf = self.to_write;
    // however we exit, remember what is still left to write
    defer self.to_write = buf;
    while (buf.len > 0) {
        const n = posix.write(self.socket, buf) catch |err| switch (err) {
            error.WouldBlock => return false,
            else => return err,
        };
        if (n == 0) {
            return error.Closed;
        }
        buf = buf[n..];
    } else {
        return true;
    }
}
writeMessage creates and stores the length-prefixed message in our new write_buf, and then calls write to write as much of the message as possible. For its part, write writes as many of our unsent bytes, stored in to_write, as the socket will accept. Crucially, it returns false if the write is incomplete, and true otherwise. This is used by our run function to switch the socket between read and write mode. Here's the modified poll loop within our server's run method. Only a handful of lines, near the end, have changed:
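while (i < self.connected) {
    const revents = self.client_polls[i].revents;
    if (revents == 0) {
        i += 1;
        continue;
    }

    var client = &self.clients[i];
    if (revents & posix.POLL.IN == posix.POLL.IN) {
        while (true) {
            const msg = client.readMessage() catch {
                // removeClient moves the last client into slot i, so i stays put
                self.removeClient(i);
                break;
            } orelse {
                // no more messages for now, move on to the next client
                i += 1;
                break;
            };

            const written = client.writeMessage(msg) catch {
                self.removeClient(i);
                break;
            };

            if (written == false) {
                // partial write: switch to write-mode and move on to the next
                // client, otherwise we'd keep reading with a message still pending
                self.client_polls[i].events = posix.POLL.OUT;
                i += 1;
                break;
            }
        }
    } else if (revents & posix.POLL.OUT == posix.POLL.OUT) {
        const written = client.write() catch {
            self.removeClient(i);
            continue;
        };
        if (written) {
            self.client_polls[i].events = posix.POLL.IN;
        }
        // done with this client for this round, whether or not the write completed
        i += 1;
    } else {
        // an error or hang-up was reported without read or write readiness
        self.removeClient(i);
    }
}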
We now attempt to write the message back to the client. If writeMessage is able to immediately write the entire message (if it returns true), then nothing changes - we go back to trying to read more messages from the client. If writeMessage returns false, then our message was not fully written and we switch to "write-mode", which is to say, we stop monitoring POLL.IN and start monitoring POLL.OUT.
In this implementation, we're either monitoring POLL.IN or POLL.OUT, never both. This is why we can use an else if to check if revents is signaling write-readiness. And, if it is, we try to write more data. Once all data is written, we can revert to monitoring POLL.IN.
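Partial writes are hard to trigger with small messages, so here is a sketch that forces one. It assumes std.posix.pipe2 accepts a flags struct with a NONBLOCK field, as in the standard library version this series appears to target; Pending and drain are hypothetical stand-ins for the write half of Client and for a peer that reads, and a non-blocking pipe replaces the socket.

```zig
const std = @import("std");
const posix = std.posix;

/// Hypothetical stand-in for the write half of Client.
const Pending = struct {
    fd: posix.fd_t,
    to_write: []const u8,

    /// Writes as much as the kernel will take. Returns true once nothing is left.
    fn flush(self: *Pending) !bool {
        var buf = self.to_write;
        defer self.to_write = buf; // remember what is still unsent
        while (buf.len > 0) {
            const n = posix.write(self.fd, buf) catch |err| switch (err) {
                error.WouldBlock => return false,
                else => return err,
            };
            if (n == 0) return error.Closed;
            buf = buf[n..];
        }
        return true;
    }
};

/// Plays the role of the peer: reads and discards whatever is buffered in `fd`.
fn drain(fd: posix.fd_t) !void {
    var scratch: [4096]u8 = undefined;
    while (true) {
        const n = posix.read(fd, &scratch) catch |err| switch (err) {
            error.WouldBlock => return,
            else => return err,
        };
        if (n == 0) return;
    }
}

pub fn main() !void {
    const pipe = try posix.pipe2(.{ .NONBLOCK = true });
    defer posix.close(pipe[0]);
    defer posix.close(pipe[1]);

    // Larger than a typical pipe buffer, so one flush can't take it all.
    const message = [_]u8{'z'} ** (256 * 1024);
    var pending = Pending{ .fd = pipe[1], .to_write = &message };

    var rounds: usize = 0;
    while (true) {
        if (try pending.flush()) break;
        rounds += 1;
        std.debug.print("round {d}: {d} bytes still pending\n", .{ rounds, pending.to_write.len });
        try drain(pipe[0]); // in the server, this is where we'd wait for POLL.OUT
    }
    std.debug.print("done after {d} partial writes\n", .{rounds});
}
```

The number of rounds depends on the platform's pipe buffer size, so treat the printed counts as illustrative.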
It's possible to monitor both POLL.IN and POLL.OUT at the same time, but as we saw, we should only monitor POLL.OUT if we actually have something to write, else we'll get endless and pointless notifications. We could support multiple pending write messages by appending new messages to write_buf and expanding to_write accordingly. Or, we could use an array or ArrayList of buffers - one per pending message.
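The first option, appending to write_buf, could look something like the sketch below. WriteQueue, push and consume are hypothetical names rather than part of the server in this post; buf[start..end] plays the role of to_write, and the compaction step is one possible policy among several.

```zig
const std = @import("std");

/// Hypothetical buffer that queues several length-prefixed messages.
const WriteQueue = struct {
    buf: []u8,
    start: usize = 0,
    end: usize = 0,

    /// The bytes that still have to be written (our `to_write`).
    fn pending(self: *const WriteQueue) []const u8 {
        return self.buf[self.start..self.end];
    }

    /// Appends a 4-byte little-endian length followed by `msg`.
    fn push(self: *WriteQueue, msg: []const u8) !void {
        const needed = msg.len + 4;
        if (self.start == self.end) {
            // nothing pending, reuse the whole buffer
            self.start = 0;
            self.end = 0;
        } else if (self.buf.len - self.end < needed and self.start > 0) {
            // slide the pending bytes to the front to free up the tail
            const len = self.end - self.start;
            std.mem.copyForwards(u8, self.buf[0..len], self.buf[self.start..self.end]);
            self.start = 0;
            self.end = len;
        }
        if (self.buf.len - self.end < needed) return error.BufferFull;

        const at = self.end;
        std.mem.writeInt(u32, self.buf[at..][0..4], @intCast(msg.len), .little);
        @memcpy(self.buf[at + 4 .. at + needed], msg);
        self.end = at + needed;
    }

    /// Records that the socket accepted `n` bytes.
    fn consume(self: *WriteQueue, n: usize) void {
        std.debug.assert(n <= self.end - self.start);
        self.start += n;
    }
};

pub fn main() !void {
    var storage: [64]u8 = undefined;
    var queue = WriteQueue{ .buf = &storage };

    try queue.push("hello");
    try queue.push("world!"); // queued behind the first message
    queue.consume(3); // pretend write() only accepted 3 bytes
    std.debug.print("{d} bytes still to write\n", .{queue.pending().len});
}
```

With this shape, writeMessage no longer needs error.PendingMessage; it only fails when the buffer is genuinely full.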
As-is, our Client has no way to reach into the corresponding pollfd. That means that we can't expose the writeMessage function for the application to call. Only our Server's run method can call writeMessage because only it can handle a partial write by changing the pollfd's events to POLL.OUT. This is not an insurmountable problem - the client could hold a reference to the server as well as its client_polls index. However, this becomes much easier to solve using epoll and kqueue, so we'll leave it until then.
Finally, because write often succeeds in a single call, there is an opportunity to optimize writeMessage. We could try our vectored write first, avoiding having to copy bytes to our write_buf. However, writev, like write, might do a partial write and then return error.WouldBlock, in which case we'd need to copy only the unwritten bytes to write_buf. I'll leave this optimization to you.
We previously improved the organization of our code by introducing a Server structure and extracting various behaviors into their own methods. We also introduced a Client to maintain state, such as our read and write buffers, associated with a socket. It isn't hard to imagine the Client becoming the main point of interaction with the rest of the application. As-is, that wouldn't be possible because the client doesn't have a fixed location, i.e. it can't safely be referenced from outside the server. Because of the way we handle removal from the clients array, a Client instance doesn't have a stable address.
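To see the problem in isolation, here is a sketch with a trimmed-down, hypothetical Client and a swapRemove helper that mirrors what removeClient does to the clients array.

```zig
const std = @import("std");

/// Hypothetical, trimmed-down client.
const Client = struct { id: u32 };

/// Same strategy as removeClient: the last entry fills the hole.
fn swapRemove(clients: []Client, connected: *usize, at: usize) void {
    const last = connected.* - 1;
    clients[at] = clients[last];
    connected.* = last;
}

pub fn main() void {
    var clients = [_]Client{ .{ .id = 1 }, .{ .id = 2 }, .{ .id = 3 } };
    var connected: usize = clients.len;

    // what the application might hold on to
    const held = &clients[0];
    std.debug.print("before removal: held.id = {d}\n", .{held.id});

    // client 1 disconnects; the pointer is still valid memory, but it now
    // refers to whichever client was moved into that slot
    swapRemove(&clients, &connected, 0);
    std.debug.print("after removal:  held.id = {d}\n", .{held.id});
}
```

Nothing crashes here, which is what makes this kind of bug unpleasant: the pointer quietly starts referring to a different client.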
We'll fix this by allocating clients on the heap. To accommodate this change, we'll add a MemoryPool to our server and change the type of value our clients array holds from Client to *Client:
const Server = struct {
    // creates our Client instances, giving each a stable address
    client_pool: std.heap.MemoryPool(Client),
    clients: []*Client,
    // ... other fields unchanged ...
};
Our Server's init and deinit need to be adjusted:
fn init(allocator: Allocator, max: usize) !Server {
    // + 1 for the listening socket
    const polls = try allocator.alloc(posix.pollfd, max + 1);
    errdefer allocator.free(polls);

    const clients = try allocator.alloc(*Client, max);
    errdefer allocator.free(clients);

    return .{
        .polls = polls,
        .clients = clients,
        .client_polls = polls[1..],
        .connected = 0,
        .allocator = allocator,
        .client_pool = std.heap.MemoryPool(Client).init(allocator),
    };
}

fn deinit(self: *Server) void {
    self.allocator.free(self.polls);
    self.allocator.free(self.clients);
    self.client_pool.deinit();
}
If you aren't familiar with Zig's MemoryPool, it's an allocator, backed by an arena, that is specialized to create values of a single type. It maintains a free-list of previously destroyed values for re-use, so subsequent calls to create can be very cheap.
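A minimal sketch of that behavior, assuming the init/create/destroy/deinit API used by the listings in this post and a trimmed-down, hypothetical Client. That a destroyed slot is handed out again by the very next create is how the free-list behaves in the standard library versions I'm aware of, not a documented guarantee.

```zig
const std = @import("std");

/// Hypothetical, trimmed-down client.
const Client = struct { id: u32 };

pub fn main() !void {
    var pool = std.heap.MemoryPool(Client).init(std.heap.page_allocator);
    defer pool.deinit();

    const a = try pool.create();
    a.* = .{ .id = 1 };
    const b = try pool.create();
    b.* = .{ .id = 2 };

    // destroying `a` puts its slot on the free-list...
    pool.destroy(a);
    // ...so the next create can reuse it without a new allocation
    const c = try pool.create();
    c.* = .{ .id = 3 };

    std.debug.print("slot reused: {}\n", .{a == c});
    // `b` never moved, which is exactly the stable address we were after
    std.debug.print("b is still client {d}\n", .{b.id});
}
```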
We need to make three final changes. First, we need to change our accept method to use client_pool to create an instance:
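const client = try self.client_pool.create();
errdefer self.client_pool.destroy(client);
client.* = Client.init(self.allocator, socket, address) catch |err| {
    // we return without an error, so the errdefer above won't run:
    // give the slot back to the pool ourselves
    self.client_pool.destroy(client);
    posix.close(socket);
    log.err("failed to initialize client: {}", .{err});
    return;
};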
Then, for cleanup, when removeClient is called, we need to destroy the client:
fn removeClient(self: *Server, at: usize) void {
    var client = self.clients[at];
    defer self.client_pool.destroy(client);
    // ... rest of the removal logic unchanged ...
}
Finally, in our run method, we were previously getting a reference to the Client stored in our array: var client = &self.clients[i];. We no longer need to take the address of the array element since the element is already a pointer. Thus, the code becomes: var client = self.clients[i]; (notice the & is removed).
Allowing clients to be referenced is our first step in making Client a first-class citizen in our system. Next we'll look at implementing a read timeout, and we'll now be able to safely reference a Client.
In Part 1 we called setsockopt with the SO.RCVTIMEO and SO.SNDTIMEO options to set a timeout on subsequent read and write operations. I wish I could tell you that you can use the same mechanism and, on timeout, poll will notify you. Unfortunately, that isn't the case. As far as I know, there isn't a built-in way to hook read/write timeouts with evented I/O. What we do have is the ability to pass a timeout to poll itself. So far we've been passing a timeout of -1, which means poll blocks until there's at least one event.
When poll is given a timeout, in milliseconds, it'll return after the timeout expires even if no socket is ready. We need a way to find the client which is going to time out next, and set poll's timeout based on that. This is something that we'll need to do before every call to poll, so it seems like it'll be prohibitively expensive: looping through every client to find the one closest to timing out. But there's a time and memory efficient data structure that's perfect for solving this problem: a doubly linked list.
Say we want to enforce an idle timeout of 5 minutes. If a client doesn't send a message within 5 minutes of connecting or within 5 minutes of their last message, they'll get disconnected. Initially we have no clients, so we have an empty linked list:
head -> null null <- tail
At this point, we can set a timeout of -1 (infinity). When the first client connects, we set its read_timeout field to now + 5 minutes and append it to the linked list (the full listing at the end uses a shorter, 60 second READ_TIMEOUT_MS). In the name of readability, I'm going to use a timestamp to display the absolute timeout of a client:
head -> c1[to=13:00] <-tail
Any subsequent client that connects will have a timeout further in the future than this client. If three more clients connect, we set their read_timeout field and can append them to our list:
head -> c1[to=13:00] <-> c2[to=13:02] <-> c3[to=13:02] <-> c4[to=13:05] <-tail
Put differently, the oldest connection is the one that'll time out soonest. Thus, by always appending new connections to the end of our list, we can traverse the list from the head to find timed-out clients and stop iterating as soon as we find a client with a timeout in the future. This client will be the next to time out. This also holds true after a message is received. Let's say that both c1 and c3 send us a message. All we need to do is move them to the end of our list:
head -> c2[to=13:02] <-> c4[to=13:05] <-> c1[to=13:07] <-> c3[to=13:07] <-tail
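The diagrams above can be reproduced with a few lines of list manipulation. This sketch deliberately uses a small hand-rolled list with the links stored in a hypothetical, trimmed-down Client, rather than the std.DoublyLinkedList the full listing uses, so that it doesn't depend on a particular standard library version. Times are minutes past 12:00, so 60 means 13:00.

```zig
const std = @import("std");

/// Hypothetical, trimmed-down client: just what the timeout list needs.
const Client = struct {
    name: []const u8,
    read_timeout: i64 = 0,
    prev: ?*Client = null,
    next: ?*Client = null,
};

const TimeoutList = struct {
    first: ?*Client = null,
    last: ?*Client = null,

    fn append(self: *TimeoutList, c: *Client) void {
        c.prev = self.last;
        c.next = null;
        if (self.last) |tail| {
            tail.next = c;
        } else {
            self.first = c;
        }
        self.last = c;
    }

    fn remove(self: *TimeoutList, c: *Client) void {
        if (c.prev) |p| {
            p.next = c.next;
        } else {
            self.first = c.next;
        }
        if (c.next) |n| {
            n.prev = c.prev;
        } else {
            self.last = c.prev;
        }
    }

    /// A new connection: set the first deadline and join the tail.
    fn connect(self: *TimeoutList, c: *Client, now: i64, timeout: i64) void {
        c.read_timeout = now + timeout;
        self.append(c);
    }

    /// A message arrived: extend the deadline and move to the tail,
    /// which keeps the list ordered by deadline.
    fn touch(self: *TimeoutList, c: *Client, now: i64, timeout: i64) void {
        c.read_timeout = now + timeout;
        self.remove(c);
        self.append(c);
    }
};

pub fn main() void {
    const timeout = 5; // minutes, to match the diagrams
    var c1 = Client{ .name = "c1" };
    var c2 = Client{ .name = "c2" };
    var c3 = Client{ .name = "c3" };
    var c4 = Client{ .name = "c4" };

    var list = TimeoutList{};
    list.connect(&c1, 55, timeout);
    list.connect(&c2, 57, timeout);
    list.connect(&c3, 57, timeout);
    list.connect(&c4, 60, timeout);

    // c1 and c3 each send a message at 13:02
    list.touch(&c1, 62, timeout);
    list.touch(&c3, 62, timeout);

    var it = list.first;
    while (it) |c| : (it = c.next) {
        std.debug.print("{s}[to={d}] ", .{ c.name, c.read_timeout });
    }
    std.debug.print("\n", .{});
}
```

Both operations are O(1), which is why this structure works so well when it has to be consulted before every call to poll.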
You might be wondering: what if you want to have two separate timeouts? For example, you might want a short timeout for initial messages (maybe as part of an authentication flow) and then a much longer timeout for all subsequent messages. As with most problems in computer science, the solution is: add more linked lists! For each distinct timeout value, we need a distinct list. We can figure out the next timeout value to pass to poll by taking the minimum of the next timeout values across all lists.
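Picking the value to pass to poll from several lists could be sketched like this. pollTimeout is a hypothetical helper; it takes the deadline at the head of each list (null for an empty list), with all values in milliseconds, and returning 0 when a deadline has already passed is a choice made for this sketch.

```zig
const std = @import("std");

/// Picks the timeout to hand to poll, given the earliest deadline of each
/// timeout list (null when that list is empty).
fn pollTimeout(now: i64, heads: []const ?i64) i32 {
    var earliest: ?i64 = null;
    for (heads) |head| {
        const d = head orelse continue;
        if (earliest == null or d < earliest.?) earliest = d;
    }

    // no clients in any list: block until there's an event
    const deadline = earliest orelse return -1;

    const diff = deadline - now;
    // something has already expired: don't block at all
    if (diff <= 0) return 0;
    return @intCast(@min(diff, std.math.maxInt(i32)));
}

pub fn main() void {
    // e.g. an empty list, an authentication list and an idle list
    const heads = [_]?i64{ null, 4_000, 2_500 };
    std.debug.print("poll timeout: {d}ms\n", .{pollTimeout(1_000, &heads)});
}
```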
Most of this code has little to do with network programming - we're mostly just moving linked list nodes around. You can see the full implementation at the end. The way we enforce the timeout and get the timeout value to pass to poll is worth reviewing though:
while (true) {
    const next_timeout = self.enforceTimeout();
    _ = try posix.poll(self.polls[0..self.connected + 1], next_timeout);
    // ... handle the listener and clients as before ...
}
We no longer pass -1 as the timeout to poll. The new enforceTimeout, which is always called before we poll, not only disconnects timed-out clients, it also returns the time, in milliseconds, until our next client times out. This is the maximum amount of time we can block on poll, and so it is our timeout:
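fn enforceTimeout(self: *Server) i32 {
    const now = std.time.milliTimestamp();
    var node = self.read_timeout_list.first;
    while (node) |n| {
        const client = n.data;
        const diff = client.read_timeout - now;
        if (diff > 0) {
            // the list is ordered by deadline, so this is the next client to time out
            return @intCast(diff);
        }
        // timed out: shutting down the read side makes the socket readable with
        // a 0-byte read, which the poll loop treats as a closed connection
        posix.shutdown(client.socket, .recv) catch {};
        node = n.next;
    } else {
        // no client with a future deadline: block until there's an event
        return -1;
    }
}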
The read_timeout field given to each client is the absolute time at which the client will time out. It is set on connection and updated after each message is received, extending the timeout. Because our list keeps clients ordered by timeout, we can iterate through the list and disconnect any timed-out clients. As soon as we find a client with a future timeout, we can return and use the amount of time until this client times out as the timeout to pass to poll.
The poll system call, while far from perfect, is a wonderful introduction to evented I/O. It teaches us the need to manage per-connection state (i.e. read and write buffers) so that we can respond to the OS's notifications about the readiness of monitored sockets. With blocking I/O, the stream-oriented nature of TCP is easy to gloss over - we can read in a loop until we have a message and write in a loop until we've written a message. With non-blocking sockets, it's something we have to put a lot more thought and care into. We still "loop" until our message is fully read or written, but that loop happens over a much wider expanse of code.
In the next part we'll look at epoll, a more powerful and better performing Linux-specific version of poll. Almost everything that we've learnt so far will be directly applicable to epoll as well as kqueue (our subject after epoll). One of the most annoying parts of our implementation so far has been the index-sharing sync between polls and client_polls. In fairness to poll, using an AutoHashMap might have made our life easier. Still, I'm glad to say that both epoll and kqueue will allow us to clean up that mess!
const std = @import("std");
const net = std.net;
const posix = std.posix;
const Allocator = std.mem.Allocator;

const log = std.log.scoped(.tcp_demo);

pub fn main() !void {
    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
    const allocator = gpa.allocator();

    var server = try Server.init(allocator, 4096);
    defer server.deinit();

    const address = try std.net.Address.parseIp("127.0.0.1", 5882);
    try server.run(address);
    std.debug.print("STOPPED\n", .{});
}

const READ_TIMEOUT_MS = 60_000;

const ClientList = std.DoublyLinkedList(*Client);
const ClientNode = ClientList.Node;
const Server = struct {
    allocator: Allocator,

    // number of connected clients
    connected: usize,

    // polls[0] is always our listening socket
    polls: []posix.pollfd,

    // creates our Client instances, giving each a stable address
    client_pool: std.heap.MemoryPool(Client),

    // clients[i] corresponds to client_polls[i]
    clients: []*Client,

    // polls[1..], kept separate so client indexes line up with clients
    client_polls: []posix.pollfd,

    // clients ordered by when they will time out
    read_timeout_list: ClientList,

    // creates the nodes of read_timeout_list
    client_node_pool: std.heap.MemoryPool(ClientList.Node),

    fn init(allocator: Allocator, max: usize) !Server {
        // + 1 for the listening socket
        const polls = try allocator.alloc(posix.pollfd, max + 1);
        errdefer allocator.free(polls);

        const clients = try allocator.alloc(*Client, max);
        errdefer allocator.free(clients);

        return .{
            .polls = polls,
            .clients = clients,
            .client_polls = polls[1..],
            .connected = 0,
            .allocator = allocator,
            .read_timeout_list = .{},
            .client_pool = std.heap.MemoryPool(Client).init(allocator),
            .client_node_pool = std.heap.MemoryPool(ClientNode).init(allocator),
        };
    }

    fn deinit(self: *Server) void {
        self.allocator.free(self.polls);
        self.allocator.free(self.clients);
        self.client_pool.deinit();
        self.client_node_pool.deinit();
    }
    fn run(self: *Server, address: std.net.Address) !void {
        const tpe: u32 = posix.SOCK.STREAM | posix.SOCK.NONBLOCK;
        const protocol = posix.IPPROTO.TCP;
        const listener = try posix.socket(address.any.family, tpe, protocol);
        defer posix.close(listener);

        try posix.setsockopt(listener, posix.SOL.SOCKET, posix.SO.REUSEADDR, &std.mem.toBytes(@as(c_int, 1)));
        try posix.bind(listener, &address.any, address.getOsSockLen());
        try posix.listen(listener, 128);

        self.polls[0] = .{
            .fd = listener,
            .revents = 0,
            .events = posix.POLL.IN,
        };

        var read_timeout_list = &self.read_timeout_list;

        while (true) {
            const next_timeout = self.enforceTimeout();
            _ = try posix.poll(self.polls[0..self.connected + 1], next_timeout);

            if (self.polls[0].revents != 0) {
                self.accept(listener) catch |err| log.err("failed to accept: {}", .{err});
            }

            var i: usize = 0;
            while (i < self.connected) {
                const revents = self.client_polls[i].revents;
                if (revents == 0) {
                    i += 1;
                    continue;
                }

                var client = self.clients[i];
                if (revents & posix.POLL.IN == posix.POLL.IN) {
                    while (true) {
                        const msg = client.readMessage() catch {
                            // removeClient moves the last client into slot i, so i stays put
                            self.removeClient(i);
                            break;
                        } orelse {
                            // no more messages for now, move on to the next client
                            i += 1;
                            break;
                        };

                        // a message arrived: extend the deadline and move to the tail
                        client.read_timeout = std.time.milliTimestamp() + READ_TIMEOUT_MS;
                        read_timeout_list.remove(client.read_timeout_node);
                        read_timeout_list.append(client.read_timeout_node);

                        const written = client.writeMessage(msg) catch {
                            self.removeClient(i);
                            break;
                        };

                        if (written == false) {
                            // partial write: switch to write-mode and move on to the next
                            // client, otherwise we'd keep reading with a message still pending
                            self.client_polls[i].events = posix.POLL.OUT;
                            i += 1;
                            break;
                        }
                    }
                } else if (revents & posix.POLL.OUT == posix.POLL.OUT) {
                    const written = client.write() catch {
                        self.removeClient(i);
                        continue;
                    };
                    if (written) {
                        self.client_polls[i].events = posix.POLL.IN;
                    }
                    // done with this client for this round, whether or not the write completed
                    i += 1;
                } else {
                    // an error or hang-up was reported without read or write readiness
                    self.removeClient(i);
                }
            }
        }
    }
    fn enforceTimeout(self: *Server) i32 {
        const now = std.time.milliTimestamp();
        var node = self.read_timeout_list.first;
        while (node) |n| {
            const client = n.data;
            const diff = client.read_timeout - now;
            if (diff > 0) {
                // the list is ordered by deadline, so this is the next client to time out
                return @intCast(diff);
            }
            // timed out: the poll loop will see a 0-byte read and remove the client
            posix.shutdown(client.socket, .recv) catch {};
            node = n.next;
        } else {
            return -1;
        }
    }
    fn accept(self: *Server, listener: posix.socket_t) !void {
        // never accept more connections than we have free slots
        const space = self.client_polls.len - self.connected;
        for (0..space) |_| {
            var address: net.Address = undefined;
            var address_len: posix.socklen_t = @sizeOf(net.Address);
            const socket = posix.accept(listener, &address.any, &address_len, posix.SOCK.NONBLOCK) catch |err| switch (err) {
                error.WouldBlock => return,
                else => return err,
            };
            // don't leak the socket if a later allocation fails
            errdefer posix.close(socket);

            const client = try self.client_pool.create();
            errdefer self.client_pool.destroy(client);
            client.* = Client.init(self.allocator, socket, address) catch |err| {
                // we return without an error, so the errdefers above won't run
                self.client_pool.destroy(client);
                posix.close(socket);
                log.err("failed to initialize client: {}", .{err});
                return;
            };
            errdefer client.deinit(self.allocator);

            client.read_timeout = std.time.milliTimestamp() + READ_TIMEOUT_MS;
            client.read_timeout_node = try self.client_node_pool.create();
            errdefer self.client_node_pool.destroy(client.read_timeout_node);

            client.read_timeout_node.* = .{
                .next = null,
                .prev = null,
                .data = client,
            };
            self.read_timeout_list.append(client.read_timeout_node);

            const connected = self.connected;
            self.clients[connected] = client;
            self.client_polls[connected] = .{
                .fd = socket,
                .revents = 0,
                .events = posix.POLL.IN,
            };
            self.connected = connected + 1;
        } else {
            // the loop ran to completion: we're full, stop monitoring the listener
            self.polls[0].events = 0;
        }
    }
    fn removeClient(self: *Server, at: usize) void {
        var client = self.clients[at];
        defer {
            posix.close(client.socket);
            self.client_node_pool.destroy(client.read_timeout_node);
            client.deinit(self.allocator);
            self.client_pool.destroy(client);
        }

        // swap-remove: the last client takes over the freed slot
        const last_index = self.connected - 1;
        self.clients[at] = self.clients[last_index];
        self.client_polls[at] = self.client_polls[last_index];
        self.connected = last_index;

        // a slot just opened up, resume monitoring the listener
        self.polls[0].events = posix.POLL.IN;

        self.read_timeout_list.remove(client.read_timeout_node);
    }
};
const Client = struct {
    socket: posix.socket_t,
    address: std.net.Address,

    // reads length-prefixed messages
    reader: Reader,

    // bytes we still need to send, always a slice of write_buf
    to_write: []u8,

    // holds the length-prefixed message being sent
    write_buf: []u8,

    // absolute time, in milliseconds, at which this client times out
    read_timeout: i64,

    // this client's node in the server's read_timeout_list
    read_timeout_node: *ClientNode,

    fn init(allocator: Allocator, socket: posix.socket_t, address: std.net.Address) !Client {
        const reader = try Reader.init(allocator, 4096);
        errdefer reader.deinit(allocator);

        const write_buf = try allocator.alloc(u8, 4096);
        errdefer allocator.free(write_buf);

        return .{
            .reader = reader,
            .socket = socket,
            .address = address,
            .to_write = &.{},
            .write_buf = write_buf,
            .read_timeout = 0, // set by the server when the client is accepted
            .read_timeout_node = undefined, // set by the server when the client is accepted
        };
    }

    fn deinit(self: *const Client, allocator: Allocator) void {
        self.reader.deinit(allocator);
        allocator.free(self.write_buf);
    }

    fn readMessage(self: *Client) !?[]const u8 {
        return self.reader.readMessage(self.socket) catch |err| switch (err) {
            error.WouldBlock => return null,
            else => return err,
        };
    }

    fn writeMessage(self: *Client, msg: []const u8) !bool {
        if (self.to_write.len > 0) {
            // the previous message hasn't been fully written yet
            return error.PendingMessage;
        }

        if (msg.len + 4 > self.write_buf.len) {
            return error.MessageTooLarge;
        }

        // copy the 4-byte length prefix and the message into our own buffer
        std.mem.writeInt(u32, self.write_buf[0..4], @intCast(msg.len), .little);
        const end = msg.len + 4;
        @memcpy(self.write_buf[4..end], msg);

        self.to_write = self.write_buf[0..end];
        return self.write();
    }

    // Returns true when everything was written, false when the socket would block.
    fn write(self: *Client) !bool {
        var buf = self.to_write;
        defer self.to_write = buf;
        while (buf.len > 0) {
            const n = posix.write(self.socket, buf) catch |err| switch (err) {
                error.WouldBlock => return false,
                else => return err,
            };
            if (n == 0) {
                return error.Closed;
            }
            buf = buf[n..];
        } else {
            return true;
        }
    }
};
const Reader = struct {
    buf: []u8,

    // where in buf the next read should write to
    pos: usize = 0,

    // where in buf the next unprocessed message starts
    start: usize = 0,

    fn init(allocator: Allocator, size: usize) !Reader {
        const buf = try allocator.alloc(u8, size);
        return .{
            .pos = 0,
            .start = 0,
            .buf = buf,
        };
    }

    fn deinit(self: *const Reader, allocator: Allocator) void {
        allocator.free(self.buf);
    }

    fn readMessage(self: *Reader, socket: posix.socket_t) ![]u8 {
        var buf = self.buf;

        while (true) {
            // we might already have a full message buffered from a previous read
            if (try self.bufferedMessage()) |msg| {
                return msg;
            }

            const pos = self.pos;
            const n = try posix.read(socket, buf[pos..]);
            if (n == 0) {
                return error.Closed;
            }
            self.pos = pos + n;
        }
    }

    fn bufferedMessage(self: *Reader) !?[]u8 {
        const buf = self.buf;
        const pos = self.pos;
        const start = self.start;

        std.debug.assert(pos >= start);
        const unprocessed = buf[start..pos];
        if (unprocessed.len < 4) {
            // not even a full length prefix yet
            self.ensureSpace(4 - unprocessed.len) catch unreachable;
            return null;
        }

        const message_len = std.mem.readInt(u32, unprocessed[0..4], .little);
        const total_len = message_len + 4;

        if (unprocessed.len < total_len) {
            // we have the prefix, but not the whole message
            try self.ensureSpace(total_len);
            return null;
        }

        self.start += total_len;
        return unprocessed[4..total_len];
    }

    fn ensureSpace(self: *Reader, space: usize) error{BufferTooSmall}!void {
        const buf = self.buf;
        if (buf.len < space) {
            return error.BufferTooSmall;
        }

        const start = self.start;
        const spare = buf.len - start;
        if (spare >= space) {
            return;
        }

        // move the unprocessed bytes to the front of the buffer to make room
        const unprocessed = buf[start..self.pos];
        std.mem.copyForwards(u8, buf[0..unprocessed.len], unprocessed);
        self.start = 0;
        self.pos = unprocessed.len;
    }
};