TCP Server in Zig - Part 4 - Multithreading
Oct 11, 2024
We finished Part 1 with a simple single-threaded server, which we could describe as:
1. Create our socket
2. Bind it to an address
3. Put it in "server" mode (i.e. call `listen` on it)
4. Accept a connection
5. Application logic involving reading/writing to the socket
6. Close the connection
7. Goto step 4
While this approach is useful for getting familiar with various socket APIs and networking concepts, it isn't practical for most real-world applications. The issue is that it can only service one client at a time. Any additional clients will block (or fail to connect) until the currently connected client is finished and the server calls `accept` again to process the next client in line.
There are a few different ways to deal with this problem, some of which are complementary. But a common place to start is to move steps 5 and 6 from the list above into their own thread:
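```zig
const std = @import("std");
const net = std.net;
const posix = std.posix;

pub fn main() !void {
    // listener setup, unchanged from the previous parts
    const address = try net.Address.parseIp("127.0.0.1", 5882);
    const tpe: u32 = posix.SOCK.STREAM;
    const protocol = posix.IPPROTO.TCP;
    const listener = try posix.socket(address.any.family, tpe, protocol);
    defer posix.close(listener);
    try posix.setsockopt(listener, posix.SOL.SOCKET, posix.SO.REUSEADDR, &std.mem.toBytes(@as(c_int, 1)));
    try posix.bind(listener, &address.any, address.getOsSockLen());
    try posix.listen(listener, 128);

    while (true) {
        var client_address: net.Address = undefined;
        var client_address_len: posix.socklen_t = @sizeOf(net.Address);
        const socket = posix.accept(listener, &client_address.any, &client_address_len, 0) catch |err| {
            std.debug.print("error accept: {}\n", .{err});
            continue;
        };

        // steps 5 and 6 now run on their own thread, so we go straight back to accept
        const thread = try std.Thread.spawn(.{}, run, .{ socket, client_address });
        thread.detach();
    }
}

fn run(socket: posix.socket_t, address: net.Address) !void {
    defer posix.close(socket);
    std.debug.print("{} connected\n", .{address});

    const timeout = posix.timeval{ .sec = 2, .usec = 500_000 };
    try posix.setsockopt(socket, posix.SOL.SOCKET, posix.SO.RCVTIMEO, &std.mem.toBytes(timeout));
    try posix.setsockopt(socket, posix.SOL.SOCKET, posix.SO.SNDTIMEO, &std.mem.toBytes(timeout));

    // Reader is the length-prefixed message reader from the previous parts
    // (it's included in the full listing at the end)
    var buf: [1024]u8 = undefined;
    var reader = Reader{ .pos = 0, .buf = &buf, .socket = socket };
    while (true) {
        const msg = try reader.readMessage();
        std.debug.print("Got: {s}\n", .{msg});
    }
}
```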
`Thread.spawn` is used to launch a new thread. The first parameter is a `SpawnConfig`, which lets us set a stack size (defaults to 16MB). The second parameter is the function to run, and the last parameter is a tuple of arguments to pass to that function. Here, the first argument is the `socket` and the second is the `client_address`.
As a small tweak, we can create a `Client` to start encompassing our client-handling logic:
```zig
const Client = struct {
    socket: posix.socket_t,
    address: std.net.Address,

    fn handle(self: Client) !void {
        const socket = self.socket;
        defer posix.close(socket);
        std.debug.print("{} connected\n", .{self.address});

        const timeout = posix.timeval{ .sec = 2, .usec = 500_000 };
        try posix.setsockopt(socket, posix.SOL.SOCKET, posix.SO.RCVTIMEO, &std.mem.toBytes(timeout));
        try posix.setsockopt(socket, posix.SOL.SOCKET, posix.SO.SNDTIMEO, &std.mem.toBytes(timeout));

        var buf: [1024]u8 = undefined;
        var reader = Reader{ .pos = 0, .buf = &buf, .socket = socket };
        while (true) {
            const msg = try reader.readMessage();
            std.debug.print("Got: {s}\n", .{msg});
        }
    }
};
```
This is the same code as above, just organized a little better. To call this from our main listening thread, the change is small:
```zig
while (true) {
    var client_address: net.Address = undefined;
    var client_address_len: posix.socklen_t = @sizeOf(net.Address);
    const socket = posix.accept(listener, &client_address.any, &client_address_len, 0) catch |err| {
        std.debug.print("error accept: {}\n", .{err});
        continue;
    };

    const client = Client{ .socket = socket, .address = client_address };
    // client is passed by value, so the new thread gets its own copy
    const thread = try std.Thread.spawn(.{}, Client.handle, .{client});
    thread.detach();
}
```
The `client` variable in the above code is scoped to, and only valid within, the `while` block. Care must be taken with respect to the values we pass to `Thread.spawn`; the spawned thread runs independently of the scope it was spawned from. Passing `client` by value is safe because the thread receives its own copy. If we changed the above to pass a reference, `&client`, we'd be in undefined behavior territory: the next loop iteration reuses that memory for the next client while the previous thread may still be reading it. Our code would probably crash, but before that happens, it could send the wrong message to the wrong client.
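If a thread really does need a pointer (say the struct is large, or has to outlive the loop iteration), one option is to heap-allocate it and hand ownership to the thread, which frees it when it's done. The following is a minimal sketch of that pattern, not code from this series: `Job`, `worker` and `spawnOwned` are hypothetical names, and it assumes the allocator is thread-safe (`GeneralPurposeAllocator` is, by default, in multi-threaded builds).

```zig
const std = @import("std");

// Hypothetical stand-in for Client: any struct handed to a thread works the same way.
const Job = struct {
    id: u32,
    result: *std.atomic.Value(u32),
};

// Runs on the spawned thread, which owns `job` and frees it once it's done.
fn worker(allocator: std.mem.Allocator, job: *Job) void {
    defer allocator.destroy(job);
    job.result.store(job.id * 2, .release);
}

// Heap-allocates the Job so it outlives the caller's stack frame (or loop iteration).
pub fn spawnOwned(allocator: std.mem.Allocator, id: u32, result: *std.atomic.Value(u32)) !std.Thread {
    const job = try allocator.create(Job);
    // if spawning fails, the thread never takes ownership, so free the job here
    errdefer allocator.destroy(job);
    job.* = .{ .id = id, .result = result };
    return std.Thread.spawn(.{}, worker, .{ allocator, job });
}
```

In an accept loop you would `detach` the returned thread rather than join it; the thread still frees its job on its own.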
`spawn` returns a `std.Thread`. You always want to call either `join` or `detach` on it. `join` blocks the caller until the thread exits. `detach` indicates that we never intend to call `join` on the thread. That might seem odd, but in doing so, the threading implementation (which is platform specific) can release any data associated with the thread as soon as the thread exits. If you don't call `detach`, some state has to stick around because you might call `join` at some point in the future.
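To see `join` and `SpawnConfig` together, here's a small sketch (`work` and `sumOnThread` are made-up names for illustration). Because we `join` before `out` goes out of scope, passing `&out` to the thread is safe here, unlike `&client` in the accept loop. The 1MB `stack_size` is an arbitrary value chosen only to show the config field; the default would work too.

```zig
const std = @import("std");

// Runs on the spawned thread and writes its result through the pointer.
fn work(out: *u64, n: u64) void {
    var sum: u64 = 0;
    var i: u64 = 1;
    while (i <= n) : (i += 1) sum += i;
    out.* = sum;
}

pub fn sumOnThread(n: u64) !u64 {
    var out: u64 = 0;
    // SpawnConfig first, then the function, then a tuple of its arguments
    const thread = try std.Thread.spawn(.{ .stack_size = 1024 * 1024 }, work, .{ &out, n });
    // join blocks until work returns, so `out` is still alive while the thread writes to it
    thread.join();
    return out;
}
```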
If we replaced our call to `detach` with `join`, our implementation would behave a lot like our initial single-threaded example. Our main thread would accept the connection and spawn a new thread, but then would block on the call to `join` until the newly spawned thread terminated.
Spawning a thread per connection is a simple modification to our existing code, but our current implementation will spawn as many threads as there are connections. That could be an issue if we're expecting thousands of concurrent connections. There's no one-size-fits-all rule for the number of threads a system can support. It'll depend on the hardware we're running on and, critically, what those threads are doing. If the threads are doing heavy CPU work, having more threads than CPU cores could hurt performance.
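As a starting point, `std.Thread.getCpuCount()` reports the number of CPU cores, which is a reasonable worker count for CPU-bound work. Below is a hypothetical helper, `workerCount`, sketching that idea; the 8x multiplier for I/O-bound work (threads that spend most of their time blocked on sockets) is an arbitrary assumption, not a rule, and should be tuned by measuring.

```zig
const std = @import("std");

/// Hypothetical helper: picks a worker count from the number of CPU cores.
/// CPU-bound work gets one thread per core; I/O-bound work, where threads
/// mostly sit blocked, gets an assumed (arbitrary) 8x multiplier.
pub fn workerCount(cpu_bound: bool) usize {
    // fall back to a single core if the count can't be determined
    const cores = std.Thread.getCpuCount() catch 1;
    return if (cpu_bound) cores else cores * 8;
}
```

A number like this only helps if something actually caps how many threads we create.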
One common solution to this problem is to use a ThreadPool. Let's modify our code, then look at how it works:
```zig
pub fn main() !void {
    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
    const allocator = gpa.allocator();

    var pool: std.Thread.Pool = undefined;
    try std.Thread.Pool.init(&pool, .{ .allocator = allocator, .n_jobs = 64 });

    // listener setup (socket, bind, listen) is the same as before; see the full listing at the end

    while (true) {
        var client_address: net.Address = undefined;
        var client_address_len: posix.socklen_t = @sizeOf(net.Address);
        const socket = posix.accept(listener, &client_address.any, &client_address_len, 0) catch |err| {
            std.debug.print("error accept: {}\n", .{err});
            continue;
        };

        const client = Client{ .socket = socket, .address = client_address };
        // queues the job; one of the pool's 64 workers will run Client.handle(client)
        try pool.spawn(Client.handle, .{client});
    }
}
```
We also need to make a slight change to our `Client.handle`, but let's review the above first. For the first time in this series, we need an allocator. You might find it odd that we declare `pool` and then pass its address into `Thread.Pool.init`. This is because `init` needs a stable pointer to the `Thread.Pool` instance (each worker thread it spawns holds a pointer back to the pool), and this approach gives the caller the flexibility to decide how that should be achieved (the standard library doesn't like calling `allocator.create(Self)` within `init` functions).
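Our accept loop runs forever, so it never waits for jobs or shuts the pool down. For a finite batch of work, a `std.Thread.WaitGroup` can track outstanding jobs, and `pool.deinit()` stops and joins the workers. The sketch below (`square` and `squaresWithPool` are made-up names) shows that shape. Note that, just like `Client.handle` will need to, the job function returns `void`.

```zig
const std = @import("std");

// A job: square `n` into `out`, then tell the WaitGroup this job is done.
fn square(wg: *std.Thread.WaitGroup, out: *u64, n: u64) void {
    defer wg.finish();
    out.* = n * n;
}

pub fn squaresWithPool(allocator: std.mem.Allocator, out: []u64) !void {
    var pool: std.Thread.Pool = undefined;
    try std.Thread.Pool.init(&pool, .{ .allocator = allocator, .n_jobs = 4 });
    // deinit stops and joins the worker threads
    defer pool.deinit();

    var wg: std.Thread.WaitGroup = .{};
    for (out, 0..) |*slot, i| {
        // register the job before queuing it, so wait() can't return early
        wg.start();
        try pool.spawn(square, .{ &wg, slot, @as(u64, i) });
    }
    // block until every queued job has called finish()
    wg.wait();
}
```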
Previously, our `Client.handle` returned `!void`, which `Thread.spawn` handled (it prints any error the function returns). But `Thread.Pool` doesn't: the function we give to `pool.spawn` must return `void`, so we need to change `handle` to return `void` instead of `!void`. Since I like the ergonomics of `try`, I think the best option is to make `handle` a wrapper around a function that returns an error:
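```zig
const Client = struct {
    socket: posix.socket_t,
    address: std.net.Address,

    // pool.spawn needs a function returning void, so handle logs errors
    // and delegates the real work to _handle
    fn handle(self: Client) void {
        defer posix.close(self.socket);
        self._handle() catch |err| switch (err) {
            error.Closed => {},
            else => std.debug.print("[{any}] client handle error: {}\n", .{ self.address, err }),
        };
    }

    fn _handle(self: Client) !void {
        // same body as our previous handle: set the timeouts, then loop on reader.readMessage()
    }
};
```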
One issue with Zig's `Thread.Pool` is that it can be memory-intensive. Because it's generic, each invocation can be given a different function to run and different parameters to use, so every `spawn` has to allocate a closure around the arguments. My understanding is that Zig's `Thread.Pool` was designed for long-running jobs, where that initial overhead is a relatively small cost compared to the overall work being done. If you're running many short-lived jobs, like processing HTTP requests, you might want to look at something more optimized for that use-case.
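To make the trade-off concrete, here is one possible shape for a more specialized pool: every job has the same type and is stored by value in a fixed-size ring buffer guarded by a `Mutex` and two `Condition`s, so queuing a job never allocates. This is a hypothetical sketch (`FixedPool`, `push`, `pop`, `stop` and `worker` are invented names), not a drop-in replacement for `Thread.Pool`; it leaves spawning and joining the worker threads to the caller.

```zig
const std = @import("std");

/// Hypothetical non-generic pool: every job is a `T` stored by value in a
/// fixed-size ring buffer, so queuing a job allocates nothing.
pub fn FixedPool(comptime T: type, comptime capacity: usize) type {
    comptime std.debug.assert(capacity > 0);
    return struct {
        const Self = @This();

        mutex: std.Thread.Mutex = .{},
        not_empty: std.Thread.Condition = .{},
        not_full: std.Thread.Condition = .{},
        items: [capacity]T = undefined,
        head: usize = 0,
        len: usize = 0,
        stopped: bool = false,

        /// Queues a job, blocking while the buffer is full (natural backpressure).
        pub fn push(self: *Self, item: T) void {
            self.mutex.lock();
            defer self.mutex.unlock();
            while (self.len == capacity) self.not_full.wait(&self.mutex);
            self.items[(self.head + self.len) % capacity] = item;
            self.len += 1;
            self.not_empty.signal();
        }

        /// Returns the next job, or null once stop() was called and the buffer is drained.
        pub fn pop(self: *Self) ?T {
            self.mutex.lock();
            defer self.mutex.unlock();
            while (self.len == 0) {
                if (self.stopped) return null;
                self.not_empty.wait(&self.mutex);
            }
            const item = self.items[self.head];
            self.head = (self.head + 1) % capacity;
            self.len -= 1;
            self.not_full.signal();
            return item;
        }

        /// Wakes all idle workers so they can exit after draining the buffer.
        pub fn stop(self: *Self) void {
            self.mutex.lock();
            defer self.mutex.unlock();
            self.stopped = true;
            self.not_empty.broadcast();
        }

        /// Worker loop: each worker thread runs this with the same handler.
        pub fn worker(self: *Self, handler: *const fn (T) void) void {
            while (self.pop()) |item| handler(item);
        }
    };
}
```

In a server like ours, `T` could be `Client` and the handler `&Client.handle`; the accept loop would call `push` instead of `pool.spawn`.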
Finally, in the above code, we initialized our `Thread.Pool` with `n_jobs` set to 64. This is the number of workers our pool will have and represents the maximum number of jobs it can run concurrently. This limits the number of connections our system can service concurrently to 64: additional connections are still accepted, but their jobs wait in the pool's queue until a worker frees up. On some hardware, for some workloads, you could set this number much higher.
Our 64-worker `Thread.Pool` is considerably better than our single-threaded implementation without adding much complexity. The most significant change we made, introducing the `Client`, was housecleaning unrelated to our multithreaded initiative. You can forgo the `Thread.Pool` and call `Thread.spawn` directly, but you should only do this in more controlled environments: where the number of clients is known (and relatively small) and you aren't exposing your service to the public internet.
In the next part, we'll take our first look at a nonblocking implementation, which is how we'll be able to support a much larger number of connections. However, as we'll see, the nonblocking implementation doesn't have to be an alternative to using a thread pool. The two can complement each other, so everything we've learnt here will continue to be useful.
Here's a complete working server implementation:
```zig
const std = @import("std");
const net = std.net;
const posix = std.posix;

pub fn main() !void {
    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
    const allocator = gpa.allocator();

    var pool: std.Thread.Pool = undefined;
    try std.Thread.Pool.init(&pool, .{ .allocator = allocator, .n_jobs = 64 });

    const address = try std.net.Address.parseIp("127.0.0.1", 5882);

    const tpe: u32 = posix.SOCK.STREAM;
    const protocol = posix.IPPROTO.TCP;
    const listener = try posix.socket(address.any.family, tpe, protocol);
    defer posix.close(listener);

    try posix.setsockopt(listener, posix.SOL.SOCKET, posix.SO.REUSEADDR, &std.mem.toBytes(@as(c_int, 1)));
    try posix.bind(listener, &address.any, address.getOsSockLen());
    try posix.listen(listener, 128);

    while (true) {
        var client_address: net.Address = undefined;
        var client_address_len: posix.socklen_t = @sizeOf(net.Address);
        const socket = posix.accept(listener, &client_address.any, &client_address_len, 0) catch |err| {
            std.debug.print("error accept: {}\n", .{err});
            continue;
        };

        const client = Client{ .socket = socket, .address = client_address };
        try pool.spawn(Client.handle, .{client});
    }
}

const Client = struct {
    socket: posix.socket_t,
    address: std.net.Address,

    fn handle(self: Client) void {
        defer posix.close(self.socket);
        self._handle() catch |err| switch (err) {
            error.Closed => {},
            // a read/write timeout (SO_RCVTIMEO / SO_SNDTIMEO) surfaces as WouldBlock
            error.WouldBlock => {},
            else => std.debug.print("[{any}] client handle error: {}\n", .{ self.address, err }),
        };
    }

    fn _handle(self: Client) !void {
        const socket = self.socket;
        std.debug.print("[{}] connected\n", .{self.address});

        const timeout = posix.timeval{ .sec = 2, .usec = 500_000 };
        try posix.setsockopt(socket, posix.SOL.SOCKET, posix.SO.RCVTIMEO, &std.mem.toBytes(timeout));
        try posix.setsockopt(socket, posix.SOL.SOCKET, posix.SO.SNDTIMEO, &std.mem.toBytes(timeout));

        var buf: [1024]u8 = undefined;
        var reader = Reader{ .pos = 0, .buf = &buf, .socket = socket };
        while (true) {
            const msg = try reader.readMessage();
            std.debug.print("[{}] sent: {s}\n", .{ self.address, msg });
        }
    }
};

const Reader = struct {
    buf: []u8,
    pos: usize = 0,
    start: usize = 0,
    socket: posix.socket_t,

    fn readMessage(self: *Reader) ![]u8 {
        // const: the slice itself is never reassigned (Zig rejects a var that is never mutated)
        const buf = self.buf;
        while (true) {
            if (try self.bufferedMessage()) |msg| {
                return msg;
            }
            const pos = self.pos;
            const n = try posix.read(self.socket, buf[pos..]);
            if (n == 0) {
                return error.Closed;
            }
            self.pos = pos + n;
        }
    }

    fn bufferedMessage(self: *Reader) !?[]u8 {
        const buf = self.buf;
        const pos = self.pos;
        const start = self.start;

        std.debug.assert(pos >= start);
        const unprocessed = buf[start..pos];

        if (unprocessed.len < 4) {
            // we need room for the whole 4-byte length prefix, counted from start;
            // asking for less could leave pos at the end of buf, making the next read 0 bytes
            self.ensureSpace(4) catch unreachable;
            return null;
        }

        const message_len = std.mem.readInt(u32, unprocessed[0..4], .little);
        // checked add: a malicious prefix like 0xFFFFFFFF must not overflow
        const total_len = try std.math.add(usize, message_len, 4);

        if (unprocessed.len < total_len) {
            try self.ensureSpace(total_len);
            return null;
        }

        self.start += total_len;
        return unprocessed[4..total_len];
    }

    fn ensureSpace(self: *Reader, space: usize) error{BufferTooSmall}!void {
        const buf = self.buf;
        if (buf.len < space) {
            return error.BufferTooSmall;
        }

        const start = self.start;
        const spare = buf.len - start;
        if (spare >= space) {
            return;
        }

        // not enough room after start: shift the unprocessed bytes to the front
        const unprocessed = buf[start..self.pos];
        std.mem.copyForwards(u8, buf[0..unprocessed.len], unprocessed);
        self.start = 0;
        self.pos = unprocessed.len;
    }
};
```
Which you can test with this dummy client:
```zig
const std = @import("std");
const posix = std.posix;

pub fn main() !void {
    const address = try std.net.Address.parseIp("127.0.0.1", 5882);

    const tpe: u32 = posix.SOCK.STREAM;
    const protocol = posix.IPPROTO.TCP;
    const socket = try posix.socket(address.any.family, tpe, protocol);
    defer posix.close(socket);

    try posix.connect(socket, &address.any, address.getOsSockLen());
    try writeMessage(socket, "Hello World");
    try writeMessage(socket, "It's Over 9000!!");
}

fn writeMessage(socket: posix.socket_t, msg: []const u8) !void {
    var buf: [4]u8 = undefined;
    std.mem.writeInt(u32, &buf, @intCast(msg.len), .little);

    var vec = [2]posix.iovec_const{
        .{ .len = 4, .base = &buf },
        .{ .len = msg.len, .base = msg.ptr },
    };
    try writeAllVectored(socket, &vec);
}

fn writeAllVectored(socket: posix.socket_t, vec: []posix.iovec_const) !void {
    var i: usize = 0;
    while (true) {
        var n = try posix.writev(socket, vec[i..]);
        while (n >= vec[i].len) {
            n -= vec[i].len;
            i += 1;
            if (i >= vec.len) return;
        }
        vec[i].base += n;
        vec[i].len -= n;
    }
}
```