Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 30 additions & 10 deletions net/sunrpc/xprtsock.c
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ xs_alloc_sparse_pages(struct xdr_buf *buf, size_t want, gfp_t gfp)

static int
xs_sock_process_cmsg(struct socket *sock, struct msghdr *msg,
struct cmsghdr *cmsg, int ret)
unsigned int *msg_flags, struct cmsghdr *cmsg, int ret)
{
u8 content_type = tls_get_record_type(sock->sk, cmsg);
u8 level, description;
Expand All @@ -380,7 +380,7 @@ xs_sock_process_cmsg(struct socket *sock, struct msghdr *msg,
* record, even though there might be more frames
* waiting to be decrypted.
*/
msg->msg_flags &= ~MSG_EOR;
*msg_flags &= ~MSG_EOR;
break;
case TLS_RECORD_TYPE_ALERT:
tls_alert_recv(sock->sk, msg, &level, &description);
Expand All @@ -395,19 +395,33 @@ xs_sock_process_cmsg(struct socket *sock, struct msghdr *msg,
}

static int
xs_sock_recv_cmsg(struct socket *sock, struct msghdr *msg, int flags)
xs_sock_recv_cmsg(struct socket *sock, unsigned int *msg_flags, int flags)
{
union {
struct cmsghdr cmsg;
u8 buf[CMSG_SPACE(sizeof(u8))];
} u;
u8 alert[2];
struct kvec alert_kvec = {
.iov_base = alert,
.iov_len = sizeof(alert),
};
struct msghdr msg = {
.msg_flags = *msg_flags,
.msg_control = &u,
.msg_controllen = sizeof(u),
};
int ret;

msg->msg_control = &u;
msg->msg_controllen = sizeof(u);
ret = sock_recvmsg(sock, msg, flags);
if (msg->msg_controllen != sizeof(u))
ret = xs_sock_process_cmsg(sock, msg, &u.cmsg, ret);
iov_iter_kvec(&msg.msg_iter, ITER_DEST, &alert_kvec, 1,
alert_kvec.iov_len);
ret = sock_recvmsg(sock, &msg, flags);
if (ret > 0) {
if (tls_get_record_type(sock->sk, &u.cmsg) == TLS_RECORD_TYPE_ALERT)
iov_iter_revert(&msg.msg_iter, ret);
ret = xs_sock_process_cmsg(sock, &msg, msg_flags, &u.cmsg,
-EAGAIN);
}
return ret;
}

Expand All @@ -417,7 +431,13 @@ xs_sock_recvmsg(struct socket *sock, struct msghdr *msg, int flags, size_t seek)
ssize_t ret;
if (seek != 0)
iov_iter_advance(&msg->msg_iter, seek);
ret = xs_sock_recv_cmsg(sock, msg, flags);
ret = sock_recvmsg(sock, msg, flags);
/* Handle TLS inband control message lazily */
if (msg->msg_flags & MSG_CTRUNC) {
msg->msg_flags &= ~(MSG_CTRUNC | MSG_EOR);
if (ret == 0 || ret == -EIO)
ret = xs_sock_recv_cmsg(sock, &msg->msg_flags, flags);
}
return ret > 0 ? ret + seek : ret;
}

Expand All @@ -443,7 +463,7 @@ xs_read_discard(struct socket *sock, struct msghdr *msg, int flags,
size_t count)
{
iov_iter_discard(&msg->msg_iter, READ, count);
return xs_sock_recv_cmsg(sock, msg, flags);
return xs_sock_recvmsg(sock, msg, flags, 0);
}

#if ARCH_IMPLEMENTS_FLUSH_DCACHE_PAGE
Expand Down
31 changes: 22 additions & 9 deletions net/tls/tls_sw.c
Original file line number Diff line number Diff line change
Expand Up @@ -1773,6 +1773,9 @@ int decrypt_skb(struct sock *sk, struct scatterlist *sgout)
return tls_decrypt_sg(sk, NULL, sgout, &darg);
}

/* All records returned from a recvmsg() call must have the same type.
* 0 is not a valid content type. Use it as "no type reported, yet".
*/
static int tls_record_content_type(struct msghdr *msg, struct tls_msg *tlm,
u8 *control)
{
Expand Down Expand Up @@ -1811,7 +1814,8 @@ static int process_rx_list(struct tls_sw_context_rx *ctx,
u8 *control,
size_t skip,
size_t len,
bool is_peek)
bool is_peek,
bool *more)
{
struct sk_buff *skb = skb_peek(&ctx->rx_list);
struct tls_msg *tlm;
Expand All @@ -1824,7 +1828,7 @@ static int process_rx_list(struct tls_sw_context_rx *ctx,

err = tls_record_content_type(msg, tlm, control);
if (err <= 0)
goto out;
goto more;

if (skip < rxm->full_len)
break;
Expand All @@ -1842,12 +1846,12 @@ static int process_rx_list(struct tls_sw_context_rx *ctx,

err = tls_record_content_type(msg, tlm, control);
if (err <= 0)
goto out;
goto more;

err = skb_copy_datagram_msg(skb, rxm->offset + skip,
msg, chunk);
if (err < 0)
goto out;
goto more;

len = len - chunk;
copied = copied + chunk;
Expand Down Expand Up @@ -1883,6 +1887,10 @@ static int process_rx_list(struct tls_sw_context_rx *ctx,

out:
return copied ? : err;
more:
if (more)
*more = true;
goto out;
}

static bool
Expand Down Expand Up @@ -1987,6 +1995,7 @@ int tls_sw_recvmsg(struct sock *sk,
int target, err;
bool is_kvec = iov_iter_is_kvec(&msg->msg_iter);
bool is_peek = flags & MSG_PEEK;
bool rx_more = false;
bool released = true;
bool bpf_strp_enabled;
bool zc_capable;
Expand All @@ -2008,12 +2017,14 @@ int tls_sw_recvmsg(struct sock *sk,
goto end;

/* Process pending decrypted records. It must be non-zero-copy */
err = process_rx_list(ctx, msg, &control, 0, len, is_peek);
err = process_rx_list(ctx, msg, &control, 0, len, is_peek, &rx_more);
if (err < 0)
goto end;

/* process_rx_list() will set @control if it processed any records */
copied = err;
if (len <= copied)
if (len <= copied || rx_more ||
(control && control != TLS_RECORD_TYPE_DATA))
goto end;

target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
Expand Down Expand Up @@ -2170,11 +2181,13 @@ int tls_sw_recvmsg(struct sock *sk,
/* Drain records from the rx_list & copy if required */
if (is_peek || is_kvec)
err = process_rx_list(ctx, msg, &control, copied,
decrypted, is_peek);
decrypted, is_peek, NULL);
else
err = process_rx_list(ctx, msg, &control, 0,
async_copy_bytes, is_peek);
decrypted += max(err, 0);
async_copy_bytes, is_peek, NULL);

/* we could have copied less than we wanted, and possibly nothing */
decrypted += max(err, 0) - async_copy_bytes;
}

copied += decrypted;
Expand Down
Loading