diff options
Diffstat (limited to 'net/core/skmsg.c')
| -rw-r--r-- | net/core/skmsg.c | 195 | 
1 files changed, 151 insertions, 44 deletions
| diff --git a/net/core/skmsg.c b/net/core/skmsg.c index 649583158983..654182ecf87b 100644 --- a/net/core/skmsg.c +++ b/net/core/skmsg.c @@ -433,10 +433,12 @@ static int sk_psock_skb_ingress(struct sk_psock *psock, struct sk_buff *skb)  static int sk_psock_handle_skb(struct sk_psock *psock, struct sk_buff *skb,  			       u32 off, u32 len, bool ingress)  { -	if (ingress) -		return sk_psock_skb_ingress(psock, skb); -	else +	if (!ingress) { +		if (!sock_writeable(psock->sk)) +			return -EAGAIN;  		return skb_send_sock_locked(psock->sk, skb, off, len); +	} +	return sk_psock_skb_ingress(psock, skb);  }  static void sk_psock_backlog(struct work_struct *work) @@ -494,14 +496,34 @@ end:  struct sk_psock *sk_psock_init(struct sock *sk, int node)  { -	struct sk_psock *psock = kzalloc_node(sizeof(*psock), -					      GFP_ATOMIC | __GFP_NOWARN, -					      node); -	if (!psock) -		return NULL; +	struct sk_psock *psock; +	struct proto *prot; + +	write_lock_bh(&sk->sk_callback_lock); + +	if (inet_csk_has_ulp(sk)) { +		psock = ERR_PTR(-EINVAL); +		goto out; +	} + +	if (sk->sk_user_data) { +		psock = ERR_PTR(-EBUSY); +		goto out; +	} + +	psock = kzalloc_node(sizeof(*psock), GFP_ATOMIC | __GFP_NOWARN, node); +	if (!psock) { +		psock = ERR_PTR(-ENOMEM); +		goto out; +	} +	prot = READ_ONCE(sk->sk_prot);  	psock->sk = sk; -	psock->eval =  __SK_NONE; +	psock->eval = __SK_NONE; +	psock->sk_proto = prot; +	psock->saved_unhash = prot->unhash; +	psock->saved_close = prot->close; +	psock->saved_write_space = sk->sk_write_space;  	INIT_LIST_HEAD(&psock->link);  	spin_lock_init(&psock->link_lock); @@ -516,6 +538,8 @@ struct sk_psock *sk_psock_init(struct sock *sk, int node)  	rcu_assign_sk_user_data_nocopy(sk, psock);  	sock_hold(sk); +out: +	write_unlock_bh(&sk->sk_callback_lock);  	return psock;  }  EXPORT_SYMBOL_GPL(sk_psock_init); @@ -603,6 +627,8 @@ void sk_psock_drop(struct sock *sk, struct sk_psock *psock)  	rcu_assign_sk_user_data(sk, NULL);  	if (psock->progs.skb_parser)  		sk_psock_stop_strp(sk, psock); +	else if (psock->progs.skb_verdict) +		sk_psock_stop_verdict(sk, psock);  	write_unlock_bh(&sk->sk_callback_lock);  	sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED); @@ -660,19 +686,8 @@ EXPORT_SYMBOL_GPL(sk_psock_msg_verdict);  static int sk_psock_bpf_run(struct sk_psock *psock, struct bpf_prog *prog,  			    struct sk_buff *skb)  { -	int ret; - -	skb->sk = psock->sk;  	bpf_compute_data_end_sk_skb(skb); -	ret = bpf_prog_run_pin_on_cpu(prog, skb); -	/* strparser clones the skb before handing it to a upper layer, -	 * meaning skb_orphan has been called. We NULL sk on the way out -	 * to ensure we don't trigger a BUG_ON() in skb/sk operations -	 * later and because we are not charging the memory of this skb -	 * to any socket yet. -	 */ -	skb->sk = NULL; -	return ret; +	return bpf_prog_run_pin_on_cpu(prog, skb);  }  static struct sk_psock *sk_psock_from_strp(struct strparser *strp) @@ -687,38 +702,35 @@ static void sk_psock_skb_redirect(struct sk_buff *skb)  {  	struct sk_psock *psock_other;  	struct sock *sk_other; -	bool ingress;  	sk_other = tcp_skb_bpf_redirect_fetch(skb); +	/* This error is a buggy BPF program, it returned a redirect +	 * return code, but then didn't set a redirect interface. +	 */  	if (unlikely(!sk_other)) {  		kfree_skb(skb);  		return;  	}  	psock_other = sk_psock(sk_other); +	/* This error indicates the socket is being torn down or had another +	 * error that caused the pipe to break. We can't send a packet on +	 * a socket that is in this state so we drop the skb. +	 */  	if (!psock_other || sock_flag(sk_other, SOCK_DEAD) ||  	    !sk_psock_test_state(psock_other, SK_PSOCK_TX_ENABLED)) {  		kfree_skb(skb);  		return;  	} -	ingress = tcp_skb_bpf_ingress(skb); -	if ((!ingress && sock_writeable(sk_other)) || -	    (ingress && -	     atomic_read(&sk_other->sk_rmem_alloc) <= -	     sk_other->sk_rcvbuf)) { -		if (!ingress) -			skb_set_owner_w(skb, sk_other); -		skb_queue_tail(&psock_other->ingress_skb, skb); -		schedule_work(&psock_other->work); -	} else { -		kfree_skb(skb); -	} +	skb_queue_tail(&psock_other->ingress_skb, skb); +	schedule_work(&psock_other->work);  } -static void sk_psock_tls_verdict_apply(struct sk_buff *skb, int verdict) +static void sk_psock_tls_verdict_apply(struct sk_buff *skb, struct sock *sk, int verdict)  {  	switch (verdict) {  	case __SK_REDIRECT: +		skb_set_owner_r(skb, sk);  		sk_psock_skb_redirect(skb);  		break;  	case __SK_PASS: @@ -736,11 +748,17 @@ int sk_psock_tls_strp_read(struct sk_psock *psock, struct sk_buff *skb)  	rcu_read_lock();  	prog = READ_ONCE(psock->progs.skb_verdict);  	if (likely(prog)) { +		/* We skip full set_owner_r here because if we do a SK_PASS +		 * or SK_DROP we can skip skb memory accounting and use the +		 * TLS context. +		 */ +		skb->sk = psock->sk;  		tcp_skb_bpf_redirect_clear(skb);  		ret = sk_psock_bpf_run(psock, prog, skb);  		ret = sk_psock_map_verd(ret, tcp_skb_bpf_redirect_fetch(skb)); +		skb->sk = NULL;  	} -	sk_psock_tls_verdict_apply(skb, ret); +	sk_psock_tls_verdict_apply(skb, psock->sk, ret);  	rcu_read_unlock();  	return ret;  } @@ -749,7 +767,9 @@ EXPORT_SYMBOL_GPL(sk_psock_tls_strp_read);  static void sk_psock_verdict_apply(struct sk_psock *psock,  				   struct sk_buff *skb, int verdict)  { +	struct tcp_skb_cb *tcp;  	struct sock *sk_other; +	int err = -EIO;  	switch (verdict) {  	case __SK_PASS: @@ -758,16 +778,24 @@ static void sk_psock_verdict_apply(struct sk_psock *psock,  		    !sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED)) {  			goto out_free;  		} -		if (atomic_read(&sk_other->sk_rmem_alloc) <= -		    sk_other->sk_rcvbuf) { -			struct tcp_skb_cb *tcp = TCP_SKB_CB(skb); -			tcp->bpf.flags |= BPF_F_INGRESS; +		tcp = TCP_SKB_CB(skb); +		tcp->bpf.flags |= BPF_F_INGRESS; + +		/* If the queue is empty then we can submit directly +		 * into the msg queue. If its not empty we have to +		 * queue work otherwise we may get OOO data. Otherwise, +		 * if sk_psock_skb_ingress errors will be handled by +		 * retrying later from workqueue. +		 */ +		if (skb_queue_empty(&psock->ingress_skb)) { +			err = sk_psock_skb_ingress(psock, skb); +		} +		if (err < 0) {  			skb_queue_tail(&psock->ingress_skb, skb);  			schedule_work(&psock->work); -			break;  		} -		goto out_free; +		break;  	case __SK_REDIRECT:  		sk_psock_skb_redirect(skb);  		break; @@ -792,9 +820,9 @@ static void sk_psock_strp_read(struct strparser *strp, struct sk_buff *skb)  		kfree_skb(skb);  		goto out;  	} +	skb_set_owner_r(skb, sk);  	prog = READ_ONCE(psock->progs.skb_verdict);  	if (likely(prog)) { -		skb_orphan(skb);  		tcp_skb_bpf_redirect_clear(skb);  		ret = sk_psock_bpf_run(psock, prog, skb);  		ret = sk_psock_map_verd(ret, tcp_skb_bpf_redirect_fetch(skb)); @@ -817,8 +845,11 @@ static int sk_psock_strp_parse(struct strparser *strp, struct sk_buff *skb)  	rcu_read_lock();  	prog = READ_ONCE(psock->progs.skb_parser); -	if (likely(prog)) +	if (likely(prog)) { +		skb->sk = psock->sk;  		ret = sk_psock_bpf_run(psock, prog, skb); +		skb->sk = NULL; +	}  	rcu_read_unlock();  	return ret;  } @@ -842,6 +873,57 @@ static void sk_psock_strp_data_ready(struct sock *sk)  	rcu_read_unlock();  } +static int sk_psock_verdict_recv(read_descriptor_t *desc, struct sk_buff *skb, +				 unsigned int offset, size_t orig_len) +{ +	struct sock *sk = (struct sock *)desc->arg.data; +	struct sk_psock *psock; +	struct bpf_prog *prog; +	int ret = __SK_DROP; +	int len = skb->len; + +	/* clone here so sk_eat_skb() in tcp_read_sock does not drop our data */ +	skb = skb_clone(skb, GFP_ATOMIC); +	if (!skb) { +		desc->error = -ENOMEM; +		return 0; +	} + +	rcu_read_lock(); +	psock = sk_psock(sk); +	if (unlikely(!psock)) { +		len = 0; +		kfree_skb(skb); +		goto out; +	} +	skb_set_owner_r(skb, sk); +	prog = READ_ONCE(psock->progs.skb_verdict); +	if (likely(prog)) { +		tcp_skb_bpf_redirect_clear(skb); +		ret = sk_psock_bpf_run(psock, prog, skb); +		ret = sk_psock_map_verd(ret, tcp_skb_bpf_redirect_fetch(skb)); +	} +	sk_psock_verdict_apply(psock, skb, ret); +out: +	rcu_read_unlock(); +	return len; +} + +static void sk_psock_verdict_data_ready(struct sock *sk) +{ +	struct socket *sock = sk->sk_socket; +	read_descriptor_t desc; + +	if (unlikely(!sock || !sock->ops || !sock->ops->read_sock)) +		return; + +	desc.arg.data = sk; +	desc.error = 0; +	desc.count = 1; + +	sock->ops->read_sock(sk, &desc, sk_psock_verdict_recv); +} +  static void sk_psock_write_space(struct sock *sk)  {  	struct sk_psock *psock; @@ -871,6 +953,19 @@ int sk_psock_init_strp(struct sock *sk, struct sk_psock *psock)  	return strp_init(&psock->parser.strp, sk, &cb);  } +void sk_psock_start_verdict(struct sock *sk, struct sk_psock *psock) +{ +	struct sk_psock_parser *parser = &psock->parser; + +	if (parser->enabled) +		return; + +	parser->saved_data_ready = sk->sk_data_ready; +	sk->sk_data_ready = sk_psock_verdict_data_ready; +	sk->sk_write_space = sk_psock_write_space; +	parser->enabled = true; +} +  void sk_psock_start_strp(struct sock *sk, struct sk_psock *psock)  {  	struct sk_psock_parser *parser = &psock->parser; @@ -896,3 +991,15 @@ void sk_psock_stop_strp(struct sock *sk, struct sk_psock *psock)  	strp_stop(&parser->strp);  	parser->enabled = false;  } + +void sk_psock_stop_verdict(struct sock *sk, struct sk_psock *psock) +{ +	struct sk_psock_parser *parser = &psock->parser; + +	if (!parser->enabled) +		return; + +	sk->sk_data_ready = parser->saved_data_ready; +	parser->saved_data_ready = NULL; +	parser->enabled = false; +} | 
