
TCP Server in Zig - Part 4 - Multithreading

Oct 11, 2024

We finished Part 1 with a simple single-threaded server, which we could describe as:

  1. Create our socket
  2. Bind it to an address
  3. Put it in "server" mode (i.e. call listen on it)
  4. Accept a connection
  5. Application logic involving reading/writing to the socket
  6. Close the connection
  7. Goto step 4

While this approach is useful for getting familiar with various socket APIs and networking concepts, it isn't practical for most real-world applications. The issue is that it can only serve one client at a time. Any additional clients will block (or fail to connect) until the currently connected client is finished and the server calls accept again to process the next client in line.

There are a few different ways to deal with this problem, some of which are complementary. But a common place to start is to move steps 5 and 6 from the above list into their own thread:

const std = @import("std");
const net = std.net;
const posix = std.posix;

pub fn main() !void {
    // All of the same existing code from before, to create, bind and listen on
    // the socket. A complete working example is given at the end of this post.

    while (true) {
        var client_address: net.Address = undefined;
        var client_address_len: posix.socklen_t = @sizeOf(net.Address);
        const socket = posix.accept(listener, &client_address.any, &client_address_len, 0) catch |err| {
            // Rare that this happens, but in later parts we'll
            // see examples where it does.
            std.debug.print("error accept: {}\n", .{err});
            continue;
        };
        const thread = try std.Thread.spawn(.{}, run, .{socket, client_address});
        thread.detach();
    }
}

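fn run(socket: posix.socket_t, address: std.net.Address) !void {
    defer posix.close(socket);
    // Zig rejects unused function parameters, so put the address to use.
    std.debug.print("{} connected\n", .{address});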

    const timeout = posix.timeval{ .sec = 2, .usec = 500_000 };
    try posix.setsockopt(socket, posix.SOL.SOCKET, posix.SO.RCVTIMEO, &std.mem.toBytes(timeout));
    try posix.setsockopt(socket, posix.SOL.SOCKET, posix.SO.SNDTIMEO, &std.mem.toBytes(timeout));

    var buf: [1024]u8 = undefined;
    var reader = Reader{ .pos = 0, .buf = &buf, .socket = socket };

    while (true) {
        const msg = try reader.readMessage();
        std.debug.print("Got: {s}\n", .{msg});
    }
}

Thread.spawn is used to launch a new thread. The first parameter is a SpawnConfig, which lets us set options such as the stack size (which defaults to 16MB). The second parameter is the function to run, and the last parameter is a tuple of arguments to pass to that function. Here, the function is run and its arguments are the socket followed by the client_address. As a small tweak, we can create a Client to start encompassing our client-handling logic:

const Client = struct {
    socket: posix.socket_t,
    address: std.net.Address,

    fn handle(self: Client) !void {
        const socket = self.socket;

        defer posix.close(socket);
        std.debug.print("{} connected\n", .{self.address});

        const timeout = posix.timeval{ .sec = 2, .usec = 500_000 };
        try posix.setsockopt(socket, posix.SOL.SOCKET, posix.SO.RCVTIMEO, &std.mem.toBytes(timeout));
        try posix.setsockopt(socket, posix.SOL.SOCKET, posix.SO.SNDTIMEO, &std.mem.toBytes(timeout));

        var buf: [1024]u8 = undefined;
        var reader = Reader{ .pos = 0, .buf = &buf, .socket = socket };

        while (true) {
            const msg = try reader.readMessage();
            std.debug.print("Got: {s}\n", .{msg});
        }
    }

};

This is the same code as above, just organized a little better. To call this from our main listening thread, the change is small:

while (true) {
    var client_address: net.Address = undefined;
    var client_address_len: posix.socklen_t = @sizeOf(net.Address);
    const socket = posix.accept(listener, &client_address.any, &client_address_len, 0) catch |err| {
        // same code as before
        std.debug.print("error accept: {}\n", .{err});
        continue;
    };

    const client = Client{ .socket = socket, .address = client_address };
    const thread = try std.Thread.spawn(.{}, Client.handle, .{client});
    thread.detach();
}

The client variable in the above code is scoped to, and only valid within, the while block. Care must be taken with the values we pass to Thread.spawn, because the spawned thread runs independently of the scope it was spawned from. Passing client by value is safe because spawn copies its arguments for the new thread. If we changed the above to pass a reference to client, &client, we'd be in undefined behavior territory. Our code would probably crash, but before that happens, it could send the wrong message to the wrong client.
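A small sketch of the by-value rule, using a hypothetical Job struct in place of Client. Each loop iteration creates a Job that only lives for that iteration, yet every thread still sees the right id, because Thread.spawn copies its argument tuple for the new thread. The results slice is passed by value too, but it only carries a pointer to memory owned by the caller, which stays alive until every thread has been joined.

```zig
const std = @import("std");

// Hypothetical stand-in for Client: a small struct copied by value.
const Job = struct {
    id: usize,

    // `self` is this thread's own copy of the Job, so it stays valid
    // after the loop iteration that created the original ends.
    fn run(self: Job, results: []usize) void {
        results[self.id] = self.id * 10;
    }
};

// Spawns one thread per slot, passing each Job by value, then joins them all.
pub fn runAll(results: []usize, threads: []std.Thread) !void {
    std.debug.assert(results.len == threads.len);
    var spawned: usize = 0;
    // Join whatever was started, even if a later spawn fails.
    defer for (threads[0..spawned]) |t| t.join();

    for (threads, 0..) |*t, i| {
        const job = Job{ .id = i }; // only lives for this iteration
        // Pass `job`, not `&job`: spawn copies the argument tuple
        // for the new thread before it returns.
        t.* = try std.Thread.spawn(.{}, Job.run, .{ job, results });
        spawned += 1;
    }
}

pub fn main() !void {
    var results: [4]usize = undefined;
    var threads: [4]std.Thread = undefined;
    try runAll(&results, &threads);
    std.debug.print("{any}\n", .{results});
}
```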

spawn returns a std.Thread. You always want to call either join or detach on it. join blocks the caller until the thread exits. detach tells the threading implementation (which is platform specific) that we will never call join on the thread, so it can release any data associated with the thread as soon as the thread exits. If you don't call detach, some state has to stick around, because you might call join at some point in the future.
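One consequence of detach is that we lose the built-in way to wait for a thread. If the spawner does need to know when the work is done, it has to use some other signal. A minimal sketch using std.Thread.ResetEvent (background and runDetached are hypothetical names):

```zig
const std = @import("std");

// Hypothetical background task that reports completion through `done`.
fn background(done: *std.Thread.ResetEvent, out: *u32) void {
    out.* = 7;
    // There is no join for a detached thread, so signal explicitly.
    done.set();
}

pub fn runDetached() !u32 {
    var done: std.Thread.ResetEvent = .{};
    var out: u32 = 0;
    const thread = try std.Thread.spawn(.{}, background, .{ &done, &out });
    // We promise never to join, so the thread's resources can be
    // released as soon as it exits.
    thread.detach();
    // Block until background() has called set(); its write to `out`
    // is visible to us after wait() returns.
    done.wait();
    return out;
}

pub fn main() !void {
    std.debug.print("{d}\n", .{try runDetached()});
}
```

In our server, the main thread never needs to know when a client handler finishes, which is why a bare detach is enough there.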

If we replaced our call to detach with join, our implementation would behave a lot like our initial single-threaded example. Our main thread would accept the connection and spawn a new thread, but then would block on the call to join until the newly spawned thread terminated.
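Here is spawn followed by join in isolation, as a small sketch (runOnce and addTo are hypothetical names, and the 1MB stack size is an arbitrary choice to show SpawnConfig in use). Unlike the server, it passes a pointer, &result, to the thread. That's only sound because join guarantees the thread has finished before result goes out of scope; swap join for detach and this becomes the undefined behavior described earlier.

```zig
const std = @import("std");

// Hypothetical worker used only for this sketch: adds `n` to `*out`.
fn addTo(out: *u64, n: u64) void {
    out.* += n;
}

pub fn runOnce(n: u64) !u64 {
    var result: u64 = 0;
    // First argument: SpawnConfig (only stack_size is overridden here).
    // Second: the function to run. Third: a tuple of its arguments.
    const thread = try std.Thread.spawn(.{ .stack_size = 1024 * 1024 }, addTo, .{ &result, n });
    // join blocks until addTo returns, so `result` is still in scope
    // and safe to read afterwards.
    thread.join();
    return result;
}

pub fn main() !void {
    std.debug.print("{d}\n", .{try runOnce(42)});
}
```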

Spawning a thread per connection is a simple modification to our existing code, but our current implementation will spawn as many threads as there are connections. That could be an issue if we're expecting thousands of concurrent connections. There's no one-size-fits-all rule for the number of threads a system can support. It'll depend on the hardware we're running and, critically, what those threads are doing. If the threads are doing heavy CPU work, having more threads than CPU cores could hurt performance.
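If you want a starting point anyway, std.Thread.getCpuCount reports the number of CPUs. The sizing function below is a hypothetical heuristic, not a recommendation from this post: roughly one thread per core for CPU-bound work, and a multiple of that (an assumed io_multiplier) for work that mostly blocks on I/O, clamped to a hard cap.

```zig
const std = @import("std");

// Hypothetical heuristic: the multiplier and cap are assumptions to tune
// for your own hardware and workload.
pub fn suggestedWorkers(cpu_bound: bool, io_multiplier: usize, cap: usize) usize {
    // Fall back to 1 if the CPU count can't be determined.
    const cpus = std.Thread.getCpuCount() catch 1;
    const n = if (cpu_bound) cpus else cpus * io_multiplier;
    // Never return 0, never exceed the cap.
    return @max(1, @min(n, cap));
}

pub fn main() void {
    std.debug.print("cpu-bound: {d}, io-bound: {d}\n", .{
        suggestedWorkers(true, 8, 64),
        suggestedWorkers(false, 8, 64),
    });
}
```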

One common solution to this problem is to use a ThreadPool. Let's modify our code, then look at how it works:

pub fn main() !void {
    // our thread pool needs an allocator
    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
    const allocator = gpa.allocator();

    var pool: std.Thread.Pool = undefined;
    try std.Thread.Pool.init(&pool, .{.allocator = allocator, .n_jobs = 64});
    defer pool.deinit();

    // ...
    // all the same socket setup stuff as before
    // removed to keep this snippet more readable..
    // ...

    while (true) {
        var client_address: net.Address = undefined;
        var client_address_len: posix.socklen_t = @sizeOf(net.Address);
        const socket = posix.accept(listener, &client_address.any, &client_address_len, 0) catch |err| {
            // same code as before
            std.debug.print("error accept: {}\n", .{err});
            continue;
        };

        const client = Client{ .socket = socket, .address = client_address };
        try pool.spawn(Client.handle, .{client});
    }
}

We also need to make a slight change to our Client.handle, but let's review the above first. For the first time in this series, we need an allocator. You might find it odd that we define pool and then pass its address into Thread.Pool.init. This is because the pool's worker threads keep a pointer to it, so init needs a stable pointer to the Thread.Pool instance. Taking a pointer gives the caller the flexibility to decide how that stability is achieved (the standard library doesn't like calling allocator.create(Self) within init functions).
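If the pool has to outlive the function that creates it (for example, as a field of a longer-lived struct), one option is to put it on the heap. The Server type below is a hypothetical sketch of that choice: because it only stores a *std.Thread.Pool, copying the Server around never moves the pool the workers point at.

```zig
const std = @import("std");

// Hypothetical wrapper: the pool lives on the heap, so the *Pool stays
// valid even if the Server value itself is copied or moved.
const Server = struct {
    allocator: std.mem.Allocator,
    pool: *std.Thread.Pool,

    fn init(allocator: std.mem.Allocator, n_jobs: u32) !Server {
        const pool = try allocator.create(std.Thread.Pool);
        errdefer allocator.destroy(pool);
        // init is called on the pool's final address, since the worker
        // threads it starts keep a pointer to it.
        try pool.init(.{ .allocator = allocator, .n_jobs = n_jobs });
        return .{ .allocator = allocator, .pool = pool };
    }

    fn deinit(self: Server) void {
        self.pool.deinit(); // stops and joins the worker threads
        self.allocator.destroy(self.pool);
    }
};

pub fn main() !void {
    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
    defer _ = gpa.deinit();
    const server = try Server.init(gpa.allocator(), 4);
    defer server.deinit();
}
```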

Previously, our Client.handle returned !void, which Thread.spawn tolerates: if the function returns an error, spawn's wrapper prints it. Thread.Pool doesn't do this, so we need to change handle to return void instead of !void. Since I like the ergonomics of try, I think the best option is to make handle a wrapper around a function that returns an error:

const Client = struct {
    // ...

    fn handle(self: Client) void {
        self._handle() catch |err| switch (err) {
            error.Closed => {},
            else => std.debug.print("[{any}] client handle error: {}\n", .{self.address, err}),
        };
    }

    fn _handle(self: Client) !void {
        // same handle as before
    }
};
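The same wrapper pattern, stripped of sockets so it can run on its own. work and workWrapper are hypothetical names: error.Closed plays the role of a client disconnecting (expected, so it's silently ignored), and error.TooLarge stands in for any unexpected failure that gets logged.

```zig
const std = @import("std");

// Hypothetical fallible job, standing in for Client._handle.
fn work(input: u32) !u32 {
    if (input == 0) return error.Closed; // like a client disconnecting
    if (input > 100) return error.TooLarge; // some unexpected failure
    return input * 2;
}

// The void-returning wrapper a Thread.Pool job needs: every error is
// handled here, so nothing escapes to the pool.
fn workWrapper(input: u32, out: *?u32) void {
    const value = work(input) catch |err| {
        switch (err) {
            error.Closed => {}, // expected, nothing to report
            else => std.debug.print("job {d} failed: {}\n", .{ input, err }),
        }
        out.* = null;
        return;
    };
    out.* = value;
}

pub fn main() void {
    var result: ?u32 = null;
    workWrapper(21, &result);
    std.debug.print("{?d}\n", .{result});
}
```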

One issue with Zig's Thread.Pool is that it can be memory-intensive. Because it's generic (each invocation can be given a different function to run and different arguments), every spawn has to heap-allocate a closure around the arguments. My understanding is that Zig's Thread.Pool was designed for long-running jobs, where this initial overhead is a relatively small cost compared to the overall work being done. If you're running many short-lived jobs, like processing HTTP requests, you might want to look at something more optimized for that use case.

Finally, in the above code, we initialized our Thread.Pool with n_jobs set to 64. This is the number of worker threads our pool will have, and therefore the maximum number of jobs it can run concurrently. For our server, that caps the number of clients we can actively serve at once to 64: additional connections are still accepted, but their jobs wait in the pool's queue until a worker becomes free. On some hardware, for some workloads, you could set this number much higher.
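To see the queueing behavior, here's a sketch that gives a pool 2 workers and 10 jobs. slowJob, Stats, Result and runQueued are hypothetical, the 5ms sleep stands in for a slow client, and std.Thread.WaitGroup (not used elsewhere in this post) is how the caller waits for every job. It assumes the same Zig version as the rest of the post (std.time.sleep, Thread.Pool's n_jobs option).

```zig
const std = @import("std");

// Shared bookkeeping so the sketch can observe how many jobs overlap.
const Stats = struct {
    mutex: std.Thread.Mutex = .{},
    active: u32 = 0,
    max_active: u32 = 0,
    completed: u32 = 0,
};

pub const Result = struct { completed: u32, max_active: u32 };

// Hypothetical job standing in for Client.handle.
fn slowJob(stats: *Stats, wg: *std.Thread.WaitGroup) void {
    defer wg.finish();

    stats.mutex.lock();
    stats.active += 1;
    stats.max_active = @max(stats.max_active, stats.active);
    stats.mutex.unlock();

    std.time.sleep(5 * std.time.ns_per_ms); // simulate a slow client

    stats.mutex.lock();
    defer stats.mutex.unlock();
    stats.active -= 1;
    stats.completed += 1;
}

// Submits `jobs` jobs to a pool with `n_jobs` workers. Jobs beyond the
// worker count wait in the pool's queue rather than being rejected.
pub fn runQueued(allocator: std.mem.Allocator, n_jobs: u32, jobs: u32) !Result {
    var pool: std.Thread.Pool = undefined;
    try pool.init(.{ .allocator = allocator, .n_jobs = n_jobs });
    defer pool.deinit();

    var stats = Stats{};
    var wg: std.Thread.WaitGroup = .{};
    for (0..jobs) |_| {
        wg.start();
        try pool.spawn(slowJob, .{ &stats, &wg });
    }
    wg.wait(); // every submitted job has finished

    stats.mutex.lock();
    defer stats.mutex.unlock();
    return .{ .completed = stats.completed, .max_active = stats.max_active };
}

pub fn main() !void {
    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
    defer _ = gpa.deinit();
    const r = try runQueued(gpa.allocator(), 2, 10);
    std.debug.print("completed={d} max_active={d}\n", .{ r.completed, r.max_active });
}
```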

Our 64-worker ThreadPool is considerably better than our single-threaded implementation without adding much complexity. The most significant change we made, introducing the Client, was housecleaning unrelated to our multithreaded initiative. You can forgo the ThreadPool and call Thread.spawn directly, but you should only do this in more controlled environments - where the number of clients is known (and relatively small) and you aren't exposing your service to the public internet.
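If you do call Thread.spawn directly, one way to keep the thread count bounded is a counting semaphore: the accept loop takes a permit before spawning, and each handler gives it back when it exits. This swaps in std.Thread.Semaphore, which isn't used elsewhere in this post; handle, Tracker and acceptLoop are hypothetical, and the loop over `connections` stands in for repeated calls to accept.

```zig
const std = @import("std");

// Bookkeeping so the sketch can observe how many handlers overlap.
const Tracker = struct {
    mutex: std.Thread.Mutex = .{},
    active: usize = 0,
    max_active: usize = 0,
};

// Hypothetical connection handler; it holds one permit for its lifetime.
fn handle(sem: *std.Thread.Semaphore, tracker: *Tracker, wg: *std.Thread.WaitGroup) void {
    defer wg.finish(); // runs last
    defer sem.post(); // runs first on exit: returns the permit

    tracker.mutex.lock();
    tracker.active += 1;
    tracker.max_active = @max(tracker.max_active, tracker.active);
    tracker.mutex.unlock();

    std.time.sleep(5 * std.time.ns_per_ms); // simulate talking to a client

    tracker.mutex.lock();
    tracker.active -= 1;
    tracker.mutex.unlock();
}

// At most `max_threads` detached handler threads exist at once.
// Returns the highest concurrency observed.
pub fn acceptLoop(max_threads: usize, connections: usize) !usize {
    var sem = std.Thread.Semaphore{ .permits = max_threads };
    var tracker = Tracker{};
    var wg: std.Thread.WaitGroup = .{};

    for (0..connections) |_| {
        sem.wait(); // blocks while every permit is taken
        wg.start();
        const thread = std.Thread.spawn(.{}, handle, .{ &sem, &tracker, &wg }) catch |err| {
            wg.finish();
            sem.post();
            wg.wait(); // running handlers use our locals; let them finish
            return err;
        };
        thread.detach();
    }
    wg.wait();

    tracker.mutex.lock();
    defer tracker.mutex.unlock();
    return tracker.max_active;
}

pub fn main() !void {
    std.debug.print("max concurrent handlers: {d}\n", .{try acceptLoop(3, 12)});
}
```

While all permits are taken, the loop stops calling accept, so new clients wait in the listen backlog (128 in our server) rather than in a user-space queue.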

In the next part, we'll take our first look at a nonblocking implementation, which is how we'll be able to support a much larger number of connections. However, as we'll see, the nonblocking implementation doesn't have to be an alternative to using a ThreadPool. The two can complement each other, so everything we've learnt here will continue to be useful.

Here's a complete working server implementation:

// Multithreaded server with a ThreadPool
const std = @import("std");
const net = std.net;
const posix = std.posix;

pub fn main() !void {
    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
    const allocator = gpa.allocator();

    var pool: std.Thread.Pool = undefined;
    try std.Thread.Pool.init(&pool, .{.allocator = allocator, .n_jobs = 64});

    const address = try std.net.Address.parseIp("127.0.0.1", 5882);

    const tpe: u32 = posix.SOCK.STREAM;
    const protocol = posix.IPPROTO.TCP;
    const listener = try posix.socket(address.any.family, tpe, protocol);
    defer posix.close(listener);

    try posix.setsockopt(listener, posix.SOL.SOCKET, posix.SO.REUSEADDR, &std.mem.toBytes(@as(c_int, 1)));
    try posix.bind(listener, &address.any, address.getOsSockLen());
    try posix.listen(listener, 128);

    while (true) {
        var client_address: net.Address = undefined;
        var client_address_len: posix.socklen_t = @sizeOf(net.Address);
        const socket = posix.accept(listener, &client_address.any, &client_address_len, 0) catch |err| {
            std.debug.print("error accept: {}\n", .{err});
            continue;
        };

        const client = Client{ .socket = socket, .address = client_address };
        try pool.spawn(Client.handle, .{client});
    }
}

const Client = struct {
    socket: posix.socket_t,
    address: std.net.Address,

    fn handle(self: Client) void {
        defer posix.close(self.socket);
        self._handle() catch |err| switch (err) {
            error.Closed => {},
            error.WouldBlock => {}, // read or write timeout
            else => std.debug.print("[{any}] client handle error: {}\n", .{self.address, err}),
        };
    }

    fn _handle(self: Client) !void {
        const socket = self.socket;
        std.debug.print("[{}] connected\n", .{self.address});

        const timeout = posix.timeval{ .sec = 2, .usec = 500_000 };
        try posix.setsockopt(socket, posix.SOL.SOCKET, posix.SO.RCVTIMEO, &std.mem.toBytes(timeout));
        try posix.setsockopt(socket, posix.SOL.SOCKET, posix.SO.SNDTIMEO, &std.mem.toBytes(timeout));

        var buf: [1024]u8 = undefined;
        var reader = Reader{ .pos = 0, .buf = &buf, .socket = socket };

        while (true) {
            const msg = try reader.readMessage();
            std.debug.print("[{}] sent: {s}\n", .{self.address, msg});
        }
    }
};

const Reader = struct {
    buf: []u8,
    pos: usize = 0,
    start: usize = 0,
    socket: posix.socket_t,

    fn readMessage(self: *Reader) ![]u8 {
        // `const`: the slice itself is never reassigned (Zig rejects a `var` that is never mutated)
        const buf = self.buf;

        while (true) {
            if (try self.bufferedMessage()) |msg| {
                return msg;
            }
            const pos = self.pos;
            const n = try posix.read(self.socket, buf[pos..]);
            if (n == 0) {
                return error.Closed;
            }
            self.pos = pos + n;
        }
    }

    fn bufferedMessage(self: *Reader) !?[]u8 {
        const buf = self.buf;
        const pos = self.pos;
        const start = self.start;

        std.debug.assert(pos >= start);
        const unprocessed = buf[start..pos];
        if (unprocessed.len < 4) {
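            // Make room for the whole 4-byte prefix starting at `start`.
            // Asking only for the missing bytes could leave no free space
            // after `pos`, making the next read return 0 (seen as Closed).
            self.ensureSpace(4) catch unreachable;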
            return null;
        }

        const message_len = std.mem.readInt(u32, unprocessed[0..4], .little);

        // the length of our message + the length of our prefix
        // (widened to usize so a huge length prefix can't overflow u32)
        const total_len = @as(usize, message_len) + 4;

        if (unprocessed.len < total_len) {
            try self.ensureSpace(total_len);
            return null;
        }

        self.start += total_len;
        return unprocessed[4..total_len];
    }

    fn ensureSpace(self: *Reader, space: usize) error{BufferTooSmall}!void {
        const buf = self.buf;
        if (buf.len < space) {
            return error.BufferTooSmall;
        }

        const start = self.start;
        const spare = buf.len - start;
        if (spare >= space) {
            return;
        }

        const unprocessed = buf[start..self.pos];
        std.mem.copyForwards(u8, buf[0..unprocessed.len], unprocessed);
        self.start = 0;
        self.pos = unprocessed.len;
    }
};

Which you can test with this dummy client:

// Test client
const std = @import("std");
const posix = std.posix;

pub fn main() !void {
    const address = try std.net.Address.parseIp("127.0.0.1", 5882);

    const tpe: u32 = posix.SOCK.STREAM;
    const protocol = posix.IPPROTO.TCP;
    const socket = try posix.socket(address.any.family, tpe, protocol);
    defer posix.close(socket);

    try posix.connect(socket, &address.any, address.getOsSockLen());
    try writeMessage(socket, "Hello World");
    try writeMessage(socket, "It's Over 9000!!");
}

fn writeMessage(socket: posix.socket_t, msg: []const u8) !void {
    var buf: [4]u8 = undefined;
    std.mem.writeInt(u32, &buf, @intCast(msg.len), .little);

    var vec = [2]posix.iovec_const{
      .{ .len = 4, .base = &buf },
      .{ .len = msg.len, .base = msg.ptr },
    };
    try writeAllVectored(socket, &vec);
}

fn writeAllVectored(socket: posix.socket_t, vec: []posix.iovec_const) !void {
    var i: usize = 0;
    while (true) {
        var n = try posix.writev(socket, vec[i..]);
        while (n >= vec[i].len) {
            n -= vec[i].len;
            i += 1;
            if (i >= vec.len) return;
        }
        vec[i].base += n;
        vec[i].len -= n;
    }
}
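Both programs share one wire format: a 4-byte little-endian length prefix followed by the payload. Here's a sketch of that framing as two pure functions. encode, decode and Frame are hypothetical helpers, not part of the code above. Because they need no socket, it's easy to test the partial-message cases that Reader has to handle.

```zig
const std = @import("std");

// One decoded frame: the payload plus how many input bytes it used.
pub const Frame = struct {
    msg: []const u8,
    consumed: usize,
};

// Writes a 4-byte little-endian length prefix followed by `msg` into `out`.
pub fn encode(out: []u8, msg: []const u8) ![]u8 {
    if (msg.len > std.math.maxInt(u32)) return error.MessageTooLarge;
    const total = 4 + msg.len;
    if (out.len < total) return error.BufferTooSmall;
    std.mem.writeInt(u32, out[0..4], @intCast(msg.len), .little);
    @memcpy(out[4..total], msg);
    return out[0..total];
}

// Returns the first complete frame in `data`, or null when more bytes are
// needed: the same two checks Reader.bufferedMessage performs.
pub fn decode(data: []const u8) ?Frame {
    if (data.len < 4) return null; // length prefix not complete yet
    const len = std.mem.readInt(u32, data[0..4], .little);
    const total = @as(usize, len) + 4; // widen first so this can't overflow
    if (data.len < total) return null; // payload not complete yet
    return .{ .msg = data[4..total], .consumed = total };
}

pub fn main() !void {
    var buf: [64]u8 = undefined;
    const frame = try encode(&buf, "Hello World");
    if (decode(frame)) |f| std.debug.print("{s} ({d} bytes)\n", .{ f.msg, f.consumed });
}
```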