diff --git a/src/Makefile b/src/Makefile index fee7951..6502c3d 100644 --- a/src/Makefile +++ b/src/Makefile @@ -33,7 +33,7 @@ endif CFLAGS ?= -O3 CFLAGS += -std=gnu11 -CFLAGS += -pedantic -Wall -Wextra +CFLAGS += -Wall -Wextra CFLAGS += -MMD -MP CFLAGS += -DRUNSTATEDIR="\"$(RUNSTATEDIR)\"" LDLIBS += -lresolv diff --git a/src/ipc.c b/src/ipc.c index 6237961..05609b4 100644 --- a/src/ipc.c +++ b/src/ipc.c @@ -18,7 +18,6 @@ #include #include #include -#include #include #include #include @@ -41,7 +40,7 @@ struct inflatable_buffer { size_t pos; }; -#define max(a, b) (a > b ? a : b) +#define max(a, b) ((a) > (b) ? (a) : (b)) static int add_next_to_inflatable_buffer(struct inflatable_buffer *buffer) { @@ -190,68 +189,75 @@ out: return (int)ret; } +#define READ_BYTES(bytes) ({ \ + void *__p; \ + size_t __bytes = (bytes); \ + if (bytes_left < __bytes) { \ + offset = p - buffer; \ + bytes_left += buffer_size; \ + buffer_size *= 2; \ + ret = -ENOMEM; \ + p = realloc(buffer, buffer_size); \ + if (!p) \ + goto out; \ + buffer = p; \ + p += offset; \ + } \ + bytes_left -= __bytes; \ + ret = read(fd, p, __bytes); \ + if (ret < 0) \ + goto out; \ + if ((size_t)ret != __bytes) { \ + ret = -EBADMSG; \ + goto out; \ + } \ + __p = p; \ + p += __bytes; \ + __p; \ +}) static int userspace_get_device(struct wgdevice **dev, const char *interface) { - struct pollfd pollfd = { .events = POLLIN }; - int len; - char byte = 0; - size_t i; - struct wgpeer *peer; + unsigned int len = 0, i; + size_t buffer_size, bytes_left; ssize_t ret; + ptrdiff_t offset; + uint8_t *buffer = NULL, *p, byte = 0; + int fd = userspace_interface_fd(interface); if (fd < 0) return fd; - *dev = NULL; + ret = write(fd, &byte, sizeof(byte)); if (ret < 0) goto out; - - pollfd.fd = fd; - if (poll(&pollfd, 1, -1) < 0) - goto out; - ret = -ECONNABORTED; - if (!(pollfd.revents & POLLIN)) - goto out; - - ret = ioctl(fd, FIONREAD, &len); - if (ret < 0) { - ret = -errno; - goto out; - } - ret = -EBADMSG; - if ((size_t)len < sizeof(struct wgdevice)) - goto out; - - ret = -ENOMEM; - *dev = malloc(len); - if (!*dev) - goto out; - - ret = read(fd, *dev, len); - if (ret < 0) - goto out; - if (ret != len) { + if (ret != sizeof(byte)) { ret = -EBADMSG; goto out; } - ret = -EBADMSG; - for_each_wgpeer(*dev, peer, i) { - if ((uint8_t *)peer + sizeof(struct wgpeer) > (uint8_t *)*dev + len) - goto out; - if ((uint8_t *)peer + sizeof(struct wgpeer) + sizeof(struct wgipmask) * peer->num_ipmasks > (uint8_t *)*dev + len) + ioctl(fd, FIONREAD, &len); + bytes_left = buffer_size = max(len, sizeof(struct wgdevice) + sizeof(struct wgpeer) + sizeof(struct wgipmask)); + p = buffer = malloc(buffer_size); + ret = -ENOMEM; + if (!buffer) goto out; - } + + len = ((struct wgdevice *)READ_BYTES(sizeof(struct wgdevice)))->num_peers; + for (i = 0; i < len; ++i) + READ_BYTES(sizeof(struct wgipmask) * ((struct wgpeer *)READ_BYTES(sizeof(struct wgpeer)))->num_ipmasks); ret = 0; out: - if (*dev && ret) { - free(*dev); - *dev = NULL; + if (buffer && ret) { + free(buffer); + buffer = NULL; } + *dev = (struct wgdevice *)buffer; close(fd); errno = -ret; return ret; + } +#undef READ_BYTES #ifdef __linux__ static int parse_linkinfo(const struct nlattr *attr, void *data)