wg: use stream instead of seqpacket

To support OS X and Windows, we have to. Ugh.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
Jason A. Donenfeld 2016-07-22 21:07:12 +02:00
parent ec890556e4
commit d6b3bc6948
1 changed files with 31 additions and 18 deletions

View File

@ -20,6 +20,7 @@
#include <dirent.h> #include <dirent.h>
#include <sys/socket.h> #include <sys/socket.h>
#include <sys/types.h> #include <sys/types.h>
#include <sys/poll.h>
#include <sys/ioctl.h> #include <sys/ioctl.h>
#include <sys/types.h> #include <sys/types.h>
#include <sys/stat.h> #include <sys/stat.h>
@ -99,7 +100,7 @@ static int userspace_interface_fd(const char *interface)
if (!S_ISSOCK(sbuf.st_mode)) if (!S_ISSOCK(sbuf.st_mode))
goto out; goto out;
ret = fd = socket(AF_UNIX, SOCK_SEQPACKET, 0); ret = fd = socket(AF_UNIX, SOCK_STREAM, 0);
if (ret < 0) if (ret < 0)
goto out; goto out;
@ -172,10 +173,10 @@ static int userspace_set_device(struct wgdevice *dev)
ret = -EBADMSG; ret = -EBADMSG;
if (!len) if (!len)
goto out; goto out;
ret = send(fd, dev, len, 0); ret = write(fd, dev, len);
if (ret < 0) if (ret < 0)
goto out; goto out;
ret = recv(fd, &ret_code, sizeof(ret_code), 0); ret = read(fd, &ret_code, sizeof(ret_code));
if (ret < 0) if (ret < 0)
goto out; goto out;
ret = ret_code; ret = ret_code;
@ -187,50 +188,62 @@ out:
static int userspace_get_device(struct wgdevice **dev, const char *interface) static int userspace_get_device(struct wgdevice **dev, const char *interface)
{ {
#ifdef __linux__ struct pollfd pollfd = { .events = POLLIN };
ssize_t len;
#else
int len; int len;
#endif char byte = 0;
size_t i;
struct wgpeer *peer;
ssize_t ret; ssize_t ret;
int fd = userspace_interface_fd(interface); int fd = userspace_interface_fd(interface);
if (fd < 0) if (fd < 0)
return fd; return fd;
*dev = NULL; *dev = NULL;
ret = send(fd, NULL, 0, 0); ret = write(fd, &byte, sizeof(byte));
if (ret < 0) if (ret < 0)
goto out; goto out;
#ifdef __linux__ pollfd.fd = fd;
ret = len = recv(fd, NULL, 0, MSG_PEEK | MSG_TRUNC); if (poll(&pollfd, 1, -1) < 0)
if (len < 0)
goto out; goto out;
#else ret = -ECONNABORTED;
ret = recv(fd, &ret, 1, MSG_PEEK); if (!(pollfd.revents & POLLIN))
if (ret < 0)
goto out; goto out;
ret = ioctl(fd, FIONREAD, &len); ret = ioctl(fd, FIONREAD, &len);
if (ret < 0) { if (ret < 0) {
ret = -errno; ret = -errno;
goto out; goto out;
} }
#endif
ret = -EBADMSG; ret = -EBADMSG;
if ((size_t)len < sizeof(struct wgdevice)) if ((size_t)len < sizeof(struct wgdevice))
goto out; goto out;
ret = -ENOMEM; ret = -ENOMEM;
*dev = calloc(len, 1); *dev = malloc(len);
if (!*dev) if (!*dev)
goto out; goto out;
ret = recv(fd, *dev, len, 0); ret = read(fd, *dev, len);
if (ret < 0) if (ret < 0)
goto out; goto out;
if (ret != len) {
ret = -EBADMSG;
goto out;
}
ret = -EBADMSG;
for_each_wgpeer(*dev, peer, i) {
if ((uint8_t *)peer + sizeof(struct wgpeer) > (uint8_t *)*dev + len)
goto out;
if ((uint8_t *)peer + sizeof(struct wgpeer) + sizeof(struct wgipmask) * peer->num_ipmasks > (uint8_t *)*dev + len)
goto out;
}
ret = 0; ret = 0;
out: out:
if (*dev && ret) if (*dev && ret) {
free(*dev); free(*dev);
*dev = NULL;
}
close(fd); close(fd);
errno = -ret; errno = -ret;
return ret; return ret;