diff --git a/ex2/Makefile b/ex2/Makefile index 7303076..721b049 100644 --- a/ex2/Makefile +++ b/ex2/Makefile @@ -1,4 +1,5 @@ CFLAGS += -Wall -Werror +LDLIBS += -ltls -lssl -lcrypto all: echo client diff --git a/ex2/client.c b/ex2/client.c index 4bd750d..140102e 100644 --- a/ex2/client.c +++ b/ex2/client.c @@ -34,6 +34,7 @@ #include #include #include +#include #include @@ -56,15 +57,17 @@ struct server { int state; unsigned char *readptr, *writeptr, *nextptr; unsigned char buf[BUFLEN]; + struct tls *ctx; }; static struct server server; static void -server_init(struct server *server) +server_init(struct server *server, struct tls *ctx) { server->readptr = server->writeptr = server->nextptr = server->buf; server->state = STATE_NONE; + server->ctx = ctx; } static ssize_t @@ -134,6 +137,13 @@ server_put(struct server *server, const unsigned char *inbuf, size_t inlen) static void closeconn (struct pollfd *pfd) { + int i; + + do { + i = tls_close(server.ctx); + } while (i == TLS_WANT_POLLIN || i == TLS_WANT_POLLOUT); + tls_free(server.ctx); + close(pfd->fd); pfd->fd = -1; pfd->revents = 0; @@ -166,8 +176,16 @@ handle_server(struct pollfd *pfd, struct server *server) if (server->state == STATE_READING) { ssize_t w = 0; ssize_t written = 0; - len = read(pfd->fd, buf, sizeof(buf)); - if (len > 0) { + len = tls_read(server->ctx, buf, sizeof(buf)); + if (len == TLS_WANT_POLLIN) + pfd->events = POLLIN | POLLHUP; + else if (len == TLS_WANT_POLLOUT) + pfd->events = POLLOUT | POLLHUP; + else if (len < 0) + err(1, "tls_write: %s", tls_error(server->ctx)); + else if (len == 0) + closeconn(pfd); + else { do { w = write(STDOUT_FILENO, buf, len); if (w == -1) { @@ -182,26 +200,21 @@ handle_server(struct pollfd *pfd, struct server *server) pfd->events = POLLHUP; } } - else if (len == 0) - closeconn(pfd); - else - pfd->events = POLLIN | POLLHUP; } else if (server->state == STATE_WRITING) { - ssize_t w = 0; - ssize_t written = 0; - do { - len = server_get(server, buf, sizeof(buf)); - w = write(pfd->fd, buf, len); - if (w == -1) { - if (errno != EINTR) - closeconn(pfd); - } - else { - written += w; - server_consume(server, w); - } - } while (written < len); - if (pfd->fd > 0) { + ssize_t ret = 0; + len = server_get(server, buf, sizeof(buf)); + if (len) { + ret = tls_write(server->ctx, buf, len); + if (ret == TLS_WANT_POLLIN) + pfd->events = POLLIN | POLLHUP; + else if (ret == TLS_WANT_POLLOUT) + pfd->events = POLLOUT | POLLHUP; + else if (ret < 0) + err(1, "tls_write: %s", tls_error(server->ctx)); + else + server_consume(server, ret); + } + if (ret == len) { server->state = STATE_READING; pfd->events = POLLIN | POLLHUP; } @@ -210,7 +223,8 @@ handle_server(struct pollfd *pfd, struct server *server) } int main(int argc, char **argv) { - + struct tls_config *tls_cfg = NULL; + struct tls *tls_ctx = NULL; struct addrinfo hints, *res; int serverfd, error; struct pollfd pollfd; @@ -227,14 +241,31 @@ int main(int argc, char **argv) { usage(); } + /* now set up TLS */ + if (tls_init() == -1) + errx(1, "unable to initialize TLS"); + if ((tls_cfg = tls_config_new()) == NULL) + errx(1, "unable to allocate TLS config"); + if (tls_config_set_ca_file(tls_cfg, "../CA/root.pem") == -1) + errx(1, "unable to set root CA file"); + if ((serverfd = socket(AF_INET, SOCK_STREAM, 0)) == -1) err(1, "socket failed"); if (connect(serverfd, res->ai_addr, res->ai_addrlen) == -1) err(1, "connect failed"); + if ((tls_ctx = tls_client()) == NULL) + errx(1, "tls client creation failed"); + if (tls_configure(tls_ctx, tls_cfg) == -1) + errx(1, "tls configuration failed (%s)", + tls_error(tls_ctx)); + if (tls_connect_socket(tls_ctx, serverfd, "localhost") == -1) + errx(1, "tls connection failed (%s)", + tls_error(tls_ctx)); + newconn(&pollfd, serverfd, 0); - server_init(&server); + server_init(&server, tls_ctx); while(1) { if (server.state == STATE_NONE) { diff --git a/ex2/echo.c b/ex2/echo.c index 7a1a0a2..11d7256 100644 --- a/ex2/echo.c +++ b/ex2/echo.c @@ -34,6 +34,7 @@ #include #include #include +#include #include #define MAX_CONNECTIONS 256 @@ -55,6 +56,7 @@ struct client { int state; unsigned char *readptr, *writeptr, *nextptr; unsigned char buf[BUFLEN]; + struct tls *ctx; }; static struct client clients[MAX_CONNECTIONS]; @@ -62,10 +64,11 @@ static struct pollfd pollfds[MAX_CONNECTIONS]; static int throttle = 0; static void -client_init(struct client *client) +client_init(struct client *client, struct tls *ctx) { client->readptr = client->writeptr = client->nextptr = client->buf; client->state = STATE_READING; + client->ctx = ctx; } static ssize_t @@ -133,8 +136,15 @@ client_put(struct client *client, const unsigned char *inbuf, size_t inlen) } static void -closeconn (struct pollfd *pfd) +closeconn (struct pollfd *pfd, struct client *client) { + int i; + + do { + i = tls_close(client->ctx); + } while (i == TLS_WANT_POLLIN || i == TLS_WANT_POLLOUT); + tls_free(client->ctx); + close(pfd->fd); pfd->fd = -1; pfd->revents = 0; @@ -159,43 +169,46 @@ handle_client(struct pollfd *pfd, struct client *client) { if ((pfd->revents & (POLLERR | POLLNVAL))) errx(1, "bad fd %d", pfd->fd); - if (pfd->revents & POLLHUP) - closeconn(pfd); + if (pfd->revents & POLLHUP) { + closeconn(pfd, client); + } else if (pfd->revents & pfd->events) { char buf[BUFLEN]; ssize_t len = 0; if (client->state == STATE_READING) { - len = read(pfd->fd, buf, sizeof(buf)); - if (len > 0) { - if (client_put(client, buf, len) - != len) { + len = tls_read(client->ctx, buf, sizeof(buf)); + if (len == TLS_WANT_POLLIN) + pfd->events = POLLIN | POLLHUP; + else if (len == TLS_WANT_POLLOUT) + pfd->events = POLLOUT | POLLHUP; + else if (len < 0) + warn("tls_read: %s", tls_error(client->ctx)); + else if (len == 0) + closeconn(pfd, client); + else { + if (client_put(client, buf, len) != len) { warnx("client buffer failed"); - closeconn(pfd); + closeconn(pfd, client); } else { client->state=STATE_WRITING; pfd->events = POLLOUT | POLLHUP; } } - else if (len == 0) - closeconn(pfd); - else - pfd->events = POLLIN | POLLHUP; } else if (client->state == STATE_WRITING) { - ssize_t w = 0; - ssize_t written = 0; - do { - len = client_get(client, buf, sizeof(buf)); - w = write(pfd->fd, buf, len); - if (w == -1) { - if (errno != EINTR) - closeconn(pfd); - } - else { - written += w; - client_consume(client, w); - } - } while (written < len); - if (pfd->fd > 0) { + ssize_t ret = 0; + len = client_get(client, buf, sizeof(buf)); + if (len) { + ret = tls_write(client->ctx, buf, len); + if (ret == TLS_WANT_POLLIN) + pfd->events = POLLIN | POLLHUP; + else if (ret == TLS_WANT_POLLOUT) + pfd->events = POLLOUT | POLLHUP; + else if (ret < 0) + warn("tls_write: %s", tls_error(client->ctx)); + else + client_consume(client, ret); + } + if (ret == len) { client->state = STATE_READING; pfd->events = POLLIN | POLLHUP; } @@ -204,13 +217,31 @@ handle_client(struct pollfd *pfd, struct client *client) } int main(int argc, char **argv) { - + struct tls_config *tls_cfg = NULL; + struct tls *tls_ctx = NULL; + struct tls *tls_cctx = NULL; struct addrinfo hints, *res; int i, listenfd, error; + if (argc != 3) usage(); + /* now set up TLS */ + + if ((tls_cfg = tls_config_new()) == NULL) + errx(1, "unable to allocate TLS config"); + if (tls_config_set_ca_file(tls_cfg, "../CA/root.pem") == -1) + errx(1, "unable to set root CA filet"); + if (tls_config_set_cert_file(tls_cfg, "../CA/server.crt") == -1) + errx(1, "unable to set TLS certificate file"); + if (tls_config_set_key_file(tls_cfg, "../CA/server.key") == -1) + errx(1, "unable to set TLS key file"); + if ((tls_ctx = tls_server()) == NULL) + errx(1, "tls server creation failed"); + if (tls_configure(tls_ctx, tls_cfg) == -1) + errx(1, "tls configuration failed (%s)", tls_error(tls_ctx)); + bzero(&hints, sizeof(hints)); hints.ai_family = AF_INET; hints.ai_socktype = SOCK_STREAM; @@ -258,9 +289,16 @@ int main(int argc, char **argv) { throttle = 1; for (i = 1; fd >= 0 && i < MAX_CONNECTIONS; i++) { if (pollfds[i].fd == -1) { - newconn(&pollfds[i], fd); - client_init(&clients[i]); throttle = 0; + if (tls_accept_socket(tls_ctx, + &tls_cctx, fd) == -1) { + warnx("tls accept failed (%s)", + tls_error(tls_ctx)); + close(fd); + break; + } + newconn(&pollfds[i], fd); + client_init(&clients[i], tls_cctx); break; } }