Another attempt, with a more comprehensive solution:
--
diff --git a/net/kcm/kcmsock.c b/net/kcm/kcmsock.c
index 84b7d5c6fec8..3b78abfb300c 100644
--- a/net/kcm/kcmsock.c
+++ b/net/kcm/kcmsock.c
@@ -223,7 +223,7 @@ static void requeue_rx_msgs(struct kcm_mux *mux, struct sk_buff_head *head)
struct sk_buff *skb;
struct kcm_sock *kcm;
- while ((skb = __skb_dequeue(head))) {
+ while ((skb = skb_dequeue(head))) {
/* Reset destructor to avoid calling kcm_rcv_ready */
skb->destructor = sock_rfree;
skb_orphan(skb);
@@ -1080,12 +1080,28 @@ static int kcm_sendmsg(struct socket *sock, struct msghdr *msg, size_t len)
return err;
}
-static struct sk_buff *kcm_wait_data(struct sock *sk, int flags,
+static struct sk_buff *kcm_dequeue_or_peek(struct sock *sk, bool peek) /* Fetch the head skb of sk->sk_receive_queue; in both modes the callers in this patch release the returned skb themselves. */
+{
+ struct sk_buff *skb;
+ unsigned long flags; /* IRQ state saved/restored around the queue lock below */
+
+ if (!peek)
+ return skb_dequeue(&sk->sk_receive_queue); /* consuming read: head is unlinked from the queue. NOTE(review): a caller that only partially consumes it would have to requeue it -- confirm the partial-read path. */
+
+ spin_lock_irqsave(&sk->sk_receive_queue.lock, flags); /* peek and reference grab happen under the queue lock */
+ skb = skb_peek(&sk->sk_receive_queue);
+ if (skb)
+ skb_get(skb); /* peeking read: skb is not unlinked here; take an extra reference for the caller */
+ spin_unlock_irqrestore(&sk->sk_receive_queue.lock, flags);
+ return skb;
+}
+
+static struct sk_buff *kcm_wait_data(struct sock *sk, int flags, bool peek,
long timeo, int *err)
{
struct sk_buff *skb;
- while (!(skb = skb_peek(&sk->sk_receive_queue))) {
+ while (!(skb = kcm_dequeue_or_peek(sk, peek))) {
if (sk->sk_err) {
*err = sock_error(sk);
return NULL;
@@ -1116,6 +1132,7 @@ static int kcm_recvmsg(struct socket *sock, struct msghdr *msg,
{
struct sock *sk = sock->sk;
struct kcm_sock *kcm = kcm_sk(sk);
+ bool peek = flags & MSG_PEEK;
int err = 0;
long timeo;
struct strp_msg *stm;
@@ -1126,7 +1143,7 @@ static int kcm_recvmsg(struct socket *sock, struct msghdr *msg,
lock_sock(sk);
- skb = kcm_wait_data(sk, flags, timeo, &err);
+ skb = kcm_wait_data(sk, flags, peek, timeo, &err);
if (!skb)
goto out;
@@ -1138,11 +1155,13 @@ static int kcm_recvmsg(struct socket *sock, struct msghdr *msg,
len = stm->full_len;
err = skb_copy_datagram_msg(skb, stm->offset, msg, len);
- if (err < 0)
+ if (err < 0) {
+ kfree_skb(skb);
goto out;
+ }
copied = len;
- if (likely(!(flags & MSG_PEEK))) {
+ if (likely(!peek)) {
KCM_STATS_ADD(kcm->stats.rx_bytes, copied);
if (copied < stm->full_len) {
if (sock->type == SOCK_DGRAM) {
@@ -1157,10 +1176,9 @@ static int kcm_recvmsg(struct socket *sock, struct msghdr *msg,
/* Finished with message */
msg->msg_flags |= MSG_EOR;
KCM_STATS_INCR(kcm->stats.rx_msgs);
- skb_unlink(skb, &sk->sk_receive_queue);
- kfree_skb(skb);
}
}
+ consume_skb(skb);
out:
release_sock(sk);
@@ -1186,7 +1204,7 @@ static ssize_t kcm_splice_read(struct socket *sock, loff_t *ppos,
lock_sock(sk);
- skb = kcm_wait_data(sk, flags, timeo, &err);
+ skb = kcm_wait_data(sk, flags, true, timeo, &err);
if (!skb)
goto err_out;
@@ -1200,6 +1218,7 @@ static ssize_t kcm_splice_read(struct socket *sock, loff_t *ppos,
copied = skb_splice_bits(skb, sk, stm->offset, pipe, len, flags);
if (copied < 0) {
err = copied;
+ kfree_skb(skb);
goto err_out;
}
@@ -1214,6 +1233,7 @@ static ssize_t kcm_splice_read(struct socket *sock, loff_t *ppos,
* finish reading the message.
*/
+ consume_skb(skb);
release_sock(sk);
return copied;