summaryrefslogtreecommitdiff
path: root/include/net/psp/functions.h
diff options
context:
space:
mode:
Diffstat (limited to 'include/net/psp/functions.h')
-rw-r--r--include/net/psp/functions.h114
1 files changed, 105 insertions, 9 deletions
diff --git a/include/net/psp/functions.h b/include/net/psp/functions.h
index 1ccc5fc238b8..0d7141230f47 100644
--- a/include/net/psp/functions.h
+++ b/include/net/psp/functions.h
@@ -4,7 +4,9 @@
#define __NET_PSP_HELPERS_H
#include <linux/skbuff.h>
+#include <linux/rcupdate.h>
#include <net/sock.h>
+#include <net/tcp.h>
#include <net/psp/types.h>
struct inet_timewait_sock;
@@ -16,41 +18,130 @@ psp_dev_create(struct net_device *netdev, struct psp_dev_ops *psd_ops,
void psp_dev_unregister(struct psp_dev *psd);
/* Kernel-facing API */
+void psp_assoc_put(struct psp_assoc *pas);
+
+static inline void *psp_assoc_drv_data(struct psp_assoc *pas)
+{
+ return pas->drv_data;
+}
+
#if IS_ENABLED(CONFIG_INET_PSP)
-static inline void psp_sk_assoc_free(struct sock *sk) { }
-static inline void
-psp_twsk_init(struct inet_timewait_sock *tw, const struct sock *sk) { }
-static inline void psp_twsk_assoc_free(struct inet_timewait_sock *tw) { }
-static inline void
-psp_reply_set_decrypted(struct sk_buff *skb) { }
+unsigned int psp_key_size(u32 version);
+void psp_sk_assoc_free(struct sock *sk);
+void psp_twsk_init(struct inet_timewait_sock *tw, const struct sock *sk);
+void psp_twsk_assoc_free(struct inet_timewait_sock *tw);
+void psp_reply_set_decrypted(struct sk_buff *skb);
+
+static inline struct psp_assoc *psp_sk_assoc(const struct sock *sk)
+{
+ return rcu_dereference_check(sk->psp_assoc, lockdep_sock_is_held(sk));
+}
static inline void
psp_enqueue_set_decrypted(struct sock *sk, struct sk_buff *skb)
{
+ struct psp_assoc *pas;
+
+ pas = psp_sk_assoc(sk);
+ if (pas && pas->tx.spi)
+ skb->decrypted = 1;
}
static inline unsigned long
__psp_skb_coalesce_diff(const struct sk_buff *one, const struct sk_buff *two,
unsigned long diffs)
{
+ struct psp_skb_ext *a, *b;
+
+ a = skb_ext_find(one, SKB_EXT_PSP);
+ b = skb_ext_find(two, SKB_EXT_PSP);
+
+ diffs |= (!!a) ^ (!!b);
+ if (!diffs && unlikely(a))
+ diffs |= memcmp(a, b, sizeof(*a));
return diffs;
}
+static inline bool
+psp_is_allowed_nondata(struct sk_buff *skb, struct psp_assoc *pas)
+{
+ bool fin = !!(TCP_SKB_CB(skb)->tcp_flags & TCPHDR_FIN);
+ u32 end_seq = TCP_SKB_CB(skb)->end_seq;
+ u32 seq = TCP_SKB_CB(skb)->seq;
+ bool pure_fin;
+
+ pure_fin = fin && end_seq - seq == 1;
+
+ return seq == end_seq || (pure_fin && seq == pas->upgrade_seq);
+}
+
+static inline bool
+psp_pse_matches_pas(struct psp_skb_ext *pse, struct psp_assoc *pas)
+{
+ return pse && pas->rx.spi == pse->spi &&
+ pas->generation == pse->generation &&
+ pas->version == pse->version &&
+ pas->dev_id == pse->dev_id;
+}
+
+static inline enum skb_drop_reason
+__psp_sk_rx_policy_check(struct sk_buff *skb, struct psp_assoc *pas)
+{
+ struct psp_skb_ext *pse = skb_ext_find(skb, SKB_EXT_PSP);
+
+ if (!pas)
+ return pse ? SKB_DROP_REASON_PSP_INPUT : 0;
+
+ if (likely(psp_pse_matches_pas(pse, pas))) {
+ if (unlikely(!pas->peer_tx))
+ pas->peer_tx = 1;
+
+ return 0;
+ }
+
+ if (!pse) {
+ if (!pas->tx.spi ||
+ (!pas->peer_tx && psp_is_allowed_nondata(skb, pas)))
+ return 0;
+ }
+
+ return SKB_DROP_REASON_PSP_INPUT;
+}
+
static inline enum skb_drop_reason
psp_sk_rx_policy_check(struct sock *sk, struct sk_buff *skb)
{
- return 0;
+ return __psp_sk_rx_policy_check(skb, psp_sk_assoc(sk));
}
static inline enum skb_drop_reason
psp_twsk_rx_policy_check(struct inet_timewait_sock *tw, struct sk_buff *skb)
{
- return 0;
+ return __psp_sk_rx_policy_check(skb, rcu_dereference(tw->psp_assoc));
+}
+
+static inline struct psp_assoc *psp_sk_get_assoc_rcu(struct sock *sk)
+{
+ struct inet_timewait_sock *tw;
+ struct psp_assoc *pas;
+ int state;
+
+ state = 1 << READ_ONCE(sk->sk_state);
+ if (!sk_is_inet(sk) || state & TCPF_NEW_SYN_RECV)
+ return NULL;
+
+ tw = inet_twsk(sk);
+ pas = state & TCPF_TIME_WAIT ? rcu_dereference(tw->psp_assoc) :
+ rcu_dereference(sk->psp_assoc);
+ return pas;
}
static inline struct psp_assoc *psp_skb_get_assoc_rcu(struct sk_buff *skb)
{
- return NULL;
+ if (!skb->decrypted || !skb->sk)
+ return NULL;
+
+ return psp_sk_get_assoc_rcu(skb->sk);
}
#else
static inline void psp_sk_assoc_free(struct sock *sk) { }
@@ -60,6 +151,11 @@ static inline void psp_twsk_assoc_free(struct inet_timewait_sock *tw) { }
static inline void
psp_reply_set_decrypted(struct sk_buff *skb) { }
+static inline struct psp_assoc *psp_sk_assoc(const struct sock *sk)
+{
+ return NULL;
+}
+
static inline void
psp_enqueue_set_decrypted(struct sock *sk, struct sk_buff *skb) { }