wg: ipc: read from socket incrementally

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
Jason A. Donenfeld 2017-01-10 04:50:42 +01:00
parent e92e0dca14
commit 12904a1095
2 changed files with 50 additions and 44 deletions

View File

@ -33,7 +33,7 @@ endif
CFLAGS ?= -O3
CFLAGS += -std=gnu11
CFLAGS += -pedantic -Wall -Wextra
CFLAGS += -Wall -Wextra
CFLAGS += -MMD -MP
CFLAGS += -DRUNSTATEDIR="\"$(RUNSTATEDIR)\""
LDLIBS += -lresolv

View File

@ -18,7 +18,6 @@
#include <unistd.h>
#include <time.h>
#include <dirent.h>
#include <poll.h>
#include <signal.h>
#include <sys/socket.h>
#include <sys/types.h>
@ -41,7 +40,7 @@ struct inflatable_buffer {
size_t pos;
};
#define max(a, b) (a > b ? a : b)
#define max(a, b) ((a) > (b) ? (a) : (b))
static int add_next_to_inflatable_buffer(struct inflatable_buffer *buffer)
{
@ -190,68 +189,75 @@ out:
return (int)ret;
}
#define READ_BYTES(bytes) ({ \
void *__p; \
size_t __bytes = (bytes); \
if (bytes_left < __bytes) { \
offset = p - buffer; \
bytes_left += buffer_size; \
buffer_size *= 2; \
ret = -ENOMEM; \
p = realloc(buffer, buffer_size); \
if (!p) \
goto out; \
buffer = p; \
p += offset; \
} \
bytes_left -= __bytes; \
ret = read(fd, p, __bytes); \
if (ret < 0) \
goto out; \
if ((size_t)ret != __bytes) { \
ret = -EBADMSG; \
goto out; \
} \
__p = p; \
p += __bytes; \
__p; \
})
static int userspace_get_device(struct wgdevice **dev, const char *interface)
{
struct pollfd pollfd = { .events = POLLIN };
int len;
char byte = 0;
size_t i;
struct wgpeer *peer;
unsigned int len = 0, i;
size_t buffer_size, bytes_left;
ssize_t ret;
ptrdiff_t offset;
uint8_t *buffer = NULL, *p, byte = 0;
int fd = userspace_interface_fd(interface);
if (fd < 0)
return fd;
*dev = NULL;
ret = write(fd, &byte, sizeof(byte));
if (ret < 0)
goto out;
pollfd.fd = fd;
if (poll(&pollfd, 1, -1) < 0)
goto out;
ret = -ECONNABORTED;
if (!(pollfd.revents & POLLIN))
goto out;
ret = ioctl(fd, FIONREAD, &len);
if (ret < 0) {
ret = -errno;
if (ret != sizeof(byte)) {
ret = -EBADMSG;
goto out;
}
ret = -EBADMSG;
if ((size_t)len < sizeof(struct wgdevice))
goto out;
ioctl(fd, FIONREAD, &len);
bytes_left = buffer_size = max(len, sizeof(struct wgdevice) + sizeof(struct wgpeer) + sizeof(struct wgipmask));
p = buffer = malloc(buffer_size);
ret = -ENOMEM;
*dev = malloc(len);
if (!*dev)
if (!buffer)
goto out;
ret = read(fd, *dev, len);
if (ret < 0)
goto out;
if (ret != len) {
ret = -EBADMSG;
goto out;
}
ret = -EBADMSG;
for_each_wgpeer(*dev, peer, i) {
if ((uint8_t *)peer + sizeof(struct wgpeer) > (uint8_t *)*dev + len)
goto out;
if ((uint8_t *)peer + sizeof(struct wgpeer) + sizeof(struct wgipmask) * peer->num_ipmasks > (uint8_t *)*dev + len)
goto out;
}
len = ((struct wgdevice *)READ_BYTES(sizeof(struct wgdevice)))->num_peers;
for (i = 0; i < len; ++i)
READ_BYTES(sizeof(struct wgipmask) * ((struct wgpeer *)READ_BYTES(sizeof(struct wgpeer)))->num_ipmasks);
ret = 0;
out:
if (*dev && ret) {
free(*dev);
*dev = NULL;
if (buffer && ret) {
free(buffer);
buffer = NULL;
}
*dev = (struct wgdevice *)buffer;
close(fd);
errno = -ret;
return ret;
}
#undef READ_BYTES
#ifdef __linux__
static int parse_linkinfo(const struct nlattr *attr, void *data)