diff options
-rw-r--r-- | include/linux/netdevice.h | 4 | ||||
-rw-r--r-- | net/core/dev.c | 38 | ||||
-rw-r--r-- | net/ipv6/addrconf.c | 12 | ||||
-rw-r--r-- | net/ipv6/anycast.c | 100 | ||||
-rw-r--r-- | net/ipv6/ipv6_sockglue.c | 28 | ||||
-rw-r--r-- | net/ipv6/mcast.c | 332 | ||||
-rw-r--r-- | net/ipv6/ndisc.c | 13 |
7 files changed, 260 insertions, 267 deletions
diff --git a/include/linux/netdevice.h b/include/linux/netdevice.h index 5847c20994d3..a80d21a14612 100644 --- a/include/linux/netdevice.h +++ b/include/linux/netdevice.h @@ -3332,8 +3332,8 @@ int dev_get_iflink(const struct net_device *dev); int dev_fill_metadata_dst(struct net_device *dev, struct sk_buff *skb); int dev_fill_forward_path(const struct net_device *dev, const u8 *daddr, struct net_device_path_stack *stack); -struct net_device *__dev_get_by_flags(struct net *net, unsigned short flags, - unsigned short mask); +struct net_device *dev_get_by_flags_rcu(struct net *net, unsigned short flags, + unsigned short mask); struct net_device *dev_get_by_name(struct net *net, const char *name); struct net_device *dev_get_by_name_rcu(struct net *net, const char *name); struct net_device *__dev_get_by_name(struct net *net, const char *name); diff --git a/net/core/dev.c b/net/core/dev.c index fe677ccec5b0..e365b099484e 100644 --- a/net/core/dev.c +++ b/net/core/dev.c @@ -1267,33 +1267,31 @@ struct net_device *dev_getfirstbyhwtype(struct net *net, unsigned short type) EXPORT_SYMBOL(dev_getfirstbyhwtype); /** - * __dev_get_by_flags - find any device with given flags - * @net: the applicable net namespace - * @if_flags: IFF_* values - * @mask: bitmask of bits in if_flags to check + * dev_get_by_flags_rcu - find any device with given flags + * @net: the applicable net namespace + * @if_flags: IFF_* values + * @mask: bitmask of bits in if_flags to check * - * Search for any interface with the given flags. Returns NULL if a device - * is not found or a pointer to the device. Must be called inside - * rtnl_lock(), and result refcount is unchanged. + * Search for any interface with the given flags. + * + * Context: rcu_read_lock() must be held. + * Returns: NULL if a device is not found or a pointer to the device. */ - -struct net_device *__dev_get_by_flags(struct net *net, unsigned short if_flags, - unsigned short mask) +struct net_device *dev_get_by_flags_rcu(struct net *net, unsigned short if_flags, + unsigned short mask) { - struct net_device *dev, *ret; - - ASSERT_RTNL(); + struct net_device *dev; - ret = NULL; - for_each_netdev(net, dev) { - if (((dev->flags ^ if_flags) & mask) == 0) { - ret = dev; - break; + for_each_netdev_rcu(net, dev) { + if (((READ_ONCE(dev->flags) ^ if_flags) & mask) == 0) { + dev_hold(dev); + return dev; } } - return ret; + + return NULL; } -EXPORT_SYMBOL(__dev_get_by_flags); +EXPORT_IPV6_MOD(dev_get_by_flags_rcu); /** * dev_valid_name - check if name is okay for network device diff --git a/net/ipv6/addrconf.c b/net/ipv6/addrconf.c index 9c297974d3a6..3c200157634e 100644 --- a/net/ipv6/addrconf.c +++ b/net/ipv6/addrconf.c @@ -2229,32 +2229,29 @@ errdad: in6_ifa_put(ifp); } -/* Join to solicited addr multicast group. - * caller must hold RTNL */ +/* Join to solicited addr multicast group. */ void addrconf_join_solict(struct net_device *dev, const struct in6_addr *addr) { struct in6_addr maddr; - if (dev->flags&(IFF_LOOPBACK|IFF_NOARP)) + if (READ_ONCE(dev->flags) & (IFF_LOOPBACK | IFF_NOARP)) return; addrconf_addr_solict_mult(addr, &maddr); ipv6_dev_mc_inc(dev, &maddr); } -/* caller must hold RTNL */ void addrconf_leave_solict(struct inet6_dev *idev, const struct in6_addr *addr) { struct in6_addr maddr; - if (idev->dev->flags&(IFF_LOOPBACK|IFF_NOARP)) + if (READ_ONCE(idev->dev->flags) & (IFF_LOOPBACK | IFF_NOARP)) return; addrconf_addr_solict_mult(addr, &maddr); __ipv6_dev_mc_dec(idev, &maddr); } -/* caller must hold RTNL */ static void addrconf_join_anycast(struct inet6_ifaddr *ifp) { struct in6_addr addr; @@ -2267,7 +2264,6 @@ static void addrconf_join_anycast(struct inet6_ifaddr *ifp) __ipv6_dev_ac_inc(ifp->idev, &addr); } -/* caller must hold RTNL */ static void addrconf_leave_anycast(struct inet6_ifaddr *ifp) { struct in6_addr addr; @@ -3865,7 +3861,7 @@ static int addrconf_ifdown(struct net_device *dev, bool unregister) * Do not dev_put! */ if (unregister) { - idev->dead = 1; + WRITE_ONCE(idev->dead, 1); /* protected by rtnl_lock */ RCU_INIT_POINTER(dev->ip6_ptr, NULL); diff --git a/net/ipv6/anycast.c b/net/ipv6/anycast.c index 21e01695b48c..53cf68e0242b 100644 --- a/net/ipv6/anycast.c +++ b/net/ipv6/anycast.c @@ -47,6 +47,9 @@ static struct hlist_head inet6_acaddr_lst[IN6_ADDR_HSIZE]; static DEFINE_SPINLOCK(acaddr_hash_lock); +#define ac_dereference(a, idev) \ + rcu_dereference_protected(a, lockdep_is_held(&(idev)->lock)) + static int ipv6_dev_ac_dec(struct net_device *dev, const struct in6_addr *addr); static u32 inet6_acaddr_hash(const struct net *net, @@ -64,14 +67,11 @@ static u32 inet6_acaddr_hash(const struct net *net, int ipv6_sock_ac_join(struct sock *sk, int ifindex, const struct in6_addr *addr) { struct ipv6_pinfo *np = inet6_sk(sk); + struct ipv6_ac_socklist *pac = NULL; + struct net *net = sock_net(sk); struct net_device *dev = NULL; struct inet6_dev *idev; - struct ipv6_ac_socklist *pac; - struct net *net = sock_net(sk); - int ishost = !net->ipv6.devconf_all->forwarding; - int err = 0; - - ASSERT_RTNL(); + int err = 0, ishost; if (!ns_capable(net->user_ns, CAP_NET_ADMIN)) return -EPERM; @@ -79,32 +79,43 @@ int ipv6_sock_ac_join(struct sock *sk, int ifindex, const struct in6_addr *addr) return -EINVAL; if (ifindex) - dev = __dev_get_by_index(net, ifindex); + dev = dev_get_by_index(net, ifindex); - if (ipv6_chk_addr_and_flags(net, addr, dev, true, 0, IFA_F_TENTATIVE)) - return -EINVAL; + if (ipv6_chk_addr_and_flags(net, addr, dev, true, 0, IFA_F_TENTATIVE)) { + err = -EINVAL; + goto error; + } pac = sock_kmalloc(sk, sizeof(struct ipv6_ac_socklist), GFP_KERNEL); - if (!pac) - return -ENOMEM; + if (!pac) { + err = -ENOMEM; + goto error; + } + pac->acl_next = NULL; pac->acl_addr = *addr; + ishost = !READ_ONCE(net->ipv6.devconf_all->forwarding); + if (ifindex == 0) { struct rt6_info *rt; + rcu_read_lock(); rt = rt6_lookup(net, addr, NULL, 0, NULL, 0); if (rt) { - dev = rt->dst.dev; + dev = dst_dev(&rt->dst); + dev_hold(dev); ip6_rt_put(rt); } else if (ishost) { + rcu_read_unlock(); err = -EADDRNOTAVAIL; goto error; } else { /* router, no matching interface: just pick one */ - dev = __dev_get_by_flags(net, IFF_UP, - IFF_UP | IFF_LOOPBACK); + dev = dev_get_by_flags_rcu(net, IFF_UP, + IFF_UP | IFF_LOOPBACK); } + rcu_read_unlock(); } if (!dev) { @@ -112,7 +123,7 @@ int ipv6_sock_ac_join(struct sock *sk, int ifindex, const struct in6_addr *addr) goto error; } - idev = __in6_dev_get(dev); + idev = in6_dev_get(dev); if (!idev) { if (ifindex) err = -ENODEV; @@ -120,8 +131,9 @@ int ipv6_sock_ac_join(struct sock *sk, int ifindex, const struct in6_addr *addr) err = -EADDRNOTAVAIL; goto error; } + /* reset ishost, now that we have a specific device */ - ishost = !idev->cnf.forwarding; + ishost = !READ_ONCE(idev->cnf.forwarding); pac->acl_ifindex = dev->ifindex; @@ -134,7 +146,7 @@ int ipv6_sock_ac_join(struct sock *sk, int ifindex, const struct in6_addr *addr) if (ishost) err = -EADDRNOTAVAIL; if (err) - goto error; + goto error_idev; } err = __ipv6_dev_ac_inc(idev, addr); @@ -144,7 +156,11 @@ int ipv6_sock_ac_join(struct sock *sk, int ifindex, const struct in6_addr *addr) pac = NULL; } +error_idev: + in6_dev_put(idev); error: + dev_put(dev); + if (pac) sock_kfree_s(sk, pac, sizeof(*pac)); return err; @@ -155,12 +171,10 @@ error: */ int ipv6_sock_ac_drop(struct sock *sk, int ifindex, const struct in6_addr *addr) { - struct ipv6_pinfo *np = inet6_sk(sk); - struct net_device *dev; struct ipv6_ac_socklist *pac, *prev_pac; + struct ipv6_pinfo *np = inet6_sk(sk); struct net *net = sock_net(sk); - - ASSERT_RTNL(); + struct net_device *dev; prev_pac = NULL; for (pac = np->ipv6_ac_list; pac; pac = pac->acl_next) { @@ -176,9 +190,11 @@ int ipv6_sock_ac_drop(struct sock *sk, int ifindex, const struct in6_addr *addr) else np->ipv6_ac_list = pac->acl_next; - dev = __dev_get_by_index(net, pac->acl_ifindex); - if (dev) + dev = dev_get_by_index(net, pac->acl_ifindex); + if (dev) { ipv6_dev_ac_dec(dev, &pac->acl_addr); + dev_put(dev); + } sock_kfree_s(sk, pac, sizeof(*pac)); return 0; @@ -187,21 +203,20 @@ int ipv6_sock_ac_drop(struct sock *sk, int ifindex, const struct in6_addr *addr) void __ipv6_sock_ac_close(struct sock *sk) { struct ipv6_pinfo *np = inet6_sk(sk); + struct net *net = sock_net(sk); struct net_device *dev = NULL; struct ipv6_ac_socklist *pac; - struct net *net = sock_net(sk); - int prev_index; + int prev_index = 0; - ASSERT_RTNL(); pac = np->ipv6_ac_list; np->ipv6_ac_list = NULL; - prev_index = 0; while (pac) { struct ipv6_ac_socklist *next = pac->acl_next; if (pac->acl_ifindex != prev_index) { - dev = __dev_get_by_index(net, pac->acl_ifindex); + dev_put(dev); + dev = dev_get_by_index(net, pac->acl_ifindex); prev_index = pac->acl_ifindex; } if (dev) @@ -209,6 +224,8 @@ void __ipv6_sock_ac_close(struct sock *sk) sock_kfree_s(sk, pac, sizeof(*pac)); pac = next; } + + dev_put(dev); } void ipv6_sock_ac_close(struct sock *sk) @@ -217,9 +234,8 @@ void ipv6_sock_ac_close(struct sock *sk) if (!np->ipv6_ac_list) return; - rtnl_lock(); + __ipv6_sock_ac_close(sk); - rtnl_unlock(); } static void ipv6_add_acaddr_hash(struct net *net, struct ifacaddr6 *aca) @@ -319,16 +335,14 @@ int __ipv6_dev_ac_inc(struct inet6_dev *idev, const struct in6_addr *addr) struct net *net; int err; - ASSERT_RTNL(); - write_lock_bh(&idev->lock); if (idev->dead) { err = -ENODEV; goto out; } - for (aca = rtnl_dereference(idev->ac_list); aca; - aca = rtnl_dereference(aca->aca_next)) { + for (aca = ac_dereference(idev->ac_list, idev); aca; + aca = ac_dereference(aca->aca_next, idev)) { if (ipv6_addr_equal(&aca->aca_addr, addr)) { aca->aca_users++; err = 0; @@ -380,12 +394,10 @@ int __ipv6_dev_ac_dec(struct inet6_dev *idev, const struct in6_addr *addr) { struct ifacaddr6 *aca, *prev_aca; - ASSERT_RTNL(); - write_lock_bh(&idev->lock); prev_aca = NULL; - for (aca = rtnl_dereference(idev->ac_list); aca; - aca = rtnl_dereference(aca->aca_next)) { + for (aca = ac_dereference(idev->ac_list, idev); aca; + aca = ac_dereference(aca->aca_next, idev)) { if (ipv6_addr_equal(&aca->aca_addr, addr)) break; prev_aca = aca; @@ -414,14 +426,18 @@ int __ipv6_dev_ac_dec(struct inet6_dev *idev, const struct in6_addr *addr) return 0; } -/* called with rtnl_lock() */ static int ipv6_dev_ac_dec(struct net_device *dev, const struct in6_addr *addr) { - struct inet6_dev *idev = __in6_dev_get(dev); + struct inet6_dev *idev = in6_dev_get(dev); + int err; if (!idev) return -ENODEV; - return __ipv6_dev_ac_dec(idev, addr); + + err = __ipv6_dev_ac_dec(idev, addr); + in6_dev_put(idev); + + return err; } void ipv6_ac_destroy_dev(struct inet6_dev *idev) @@ -429,7 +445,7 @@ void ipv6_ac_destroy_dev(struct inet6_dev *idev) struct ifacaddr6 *aca; write_lock_bh(&idev->lock); - while ((aca = rtnl_dereference(idev->ac_list)) != NULL) { + while ((aca = ac_dereference(idev->ac_list, idev)) != NULL) { rcu_assign_pointer(idev->ac_list, aca->aca_next); write_unlock_bh(&idev->lock); diff --git a/net/ipv6/ipv6_sockglue.c b/net/ipv6/ipv6_sockglue.c index 1e225e6489ea..e66ec623972e 100644 --- a/net/ipv6/ipv6_sockglue.c +++ b/net/ipv6/ipv6_sockglue.c @@ -117,26 +117,6 @@ struct ipv6_txoptions *ipv6_update_options(struct sock *sk, return opt; } -static bool setsockopt_needs_rtnl(int optname) -{ - switch (optname) { - case IPV6_ADDRFORM: - case IPV6_ADD_MEMBERSHIP: - case IPV6_DROP_MEMBERSHIP: - case IPV6_JOIN_ANYCAST: - case IPV6_LEAVE_ANYCAST: - case MCAST_JOIN_GROUP: - case MCAST_LEAVE_GROUP: - case MCAST_JOIN_SOURCE_GROUP: - case MCAST_LEAVE_SOURCE_GROUP: - case MCAST_BLOCK_SOURCE: - case MCAST_UNBLOCK_SOURCE: - case MCAST_MSFILTER: - return true; - } - return false; -} - static int copy_group_source_from_sockptr(struct group_source_req *greqs, sockptr_t optval, int optlen) { @@ -395,9 +375,8 @@ int do_ipv6_setsockopt(struct sock *sk, int level, int optname, { struct ipv6_pinfo *np = inet6_sk(sk); struct net *net = sock_net(sk); - int val, valbool; int retv = -ENOPROTOOPT; - bool needs_rtnl = setsockopt_needs_rtnl(optname); + int val, valbool; if (sockptr_is_null(optval)) val = 0; @@ -562,8 +541,7 @@ int do_ipv6_setsockopt(struct sock *sk, int level, int optname, return 0; } } - if (needs_rtnl) - rtnl_lock(); + sockopt_lock_sock(sk); /* Another thread has converted the socket into IPv4 with @@ -969,8 +947,6 @@ done: unlock: sockopt_release_sock(sk); - if (needs_rtnl) - rtnl_unlock(); return retv; diff --git a/net/ipv6/mcast.c b/net/ipv6/mcast.c index 65831b4fee1f..6c875721d423 100644 --- a/net/ipv6/mcast.c +++ b/net/ipv6/mcast.c @@ -108,9 +108,9 @@ static int __ipv6_dev_mc_inc(struct net_device *dev, int sysctl_mld_max_msf __read_mostly = IPV6_MLD_MAX_MSF; int sysctl_mld_qrv __read_mostly = MLD_QRV_DEFAULT; -/* - * socket join on multicast group - */ +#define mc_assert_locked(idev) \ + lockdep_assert_held(&(idev)->mc_lock) + #define mc_dereference(e, idev) \ rcu_dereference_protected(e, lockdep_is_held(&(idev)->mc_lock)) @@ -169,17 +169,18 @@ static int unsolicited_report_interval(struct inet6_dev *idev) return iv > 0 ? iv : 1; } +/* + * socket join on multicast group + */ static int __ipv6_sock_mc_join(struct sock *sk, int ifindex, const struct in6_addr *addr, unsigned int mode) { - struct net_device *dev = NULL; - struct ipv6_mc_socklist *mc_lst; struct ipv6_pinfo *np = inet6_sk(sk); + struct ipv6_mc_socklist *mc_lst; struct net *net = sock_net(sk); + struct net_device *dev = NULL; int err; - ASSERT_RTNL(); - if (!ipv6_addr_is_multicast(addr)) return -EINVAL; @@ -199,13 +200,18 @@ static int __ipv6_sock_mc_join(struct sock *sk, int ifindex, if (ifindex == 0) { struct rt6_info *rt; + + rcu_read_lock(); rt = rt6_lookup(net, addr, NULL, 0, NULL, 0); if (rt) { - dev = rt->dst.dev; + dev = dst_dev(&rt->dst); + dev_hold(dev); ip6_rt_put(rt); } - } else - dev = __dev_get_by_index(net, ifindex); + rcu_read_unlock(); + } else { + dev = dev_get_by_index(net, ifindex); + } if (!dev) { sock_kfree_s(sk, mc_lst, sizeof(*mc_lst)); @@ -216,12 +222,11 @@ static int __ipv6_sock_mc_join(struct sock *sk, int ifindex, mc_lst->sfmode = mode; RCU_INIT_POINTER(mc_lst->sflist, NULL); - /* - * now add/increase the group membership on the device - */ - + /* now add/increase the group membership on the device */ err = __ipv6_dev_mc_inc(dev, addr, mode); + dev_put(dev); + if (err) { sock_kfree_s(sk, mc_lst, sizeof(*mc_lst)); return err; @@ -248,14 +253,36 @@ int ipv6_sock_mc_join_ssm(struct sock *sk, int ifindex, /* * socket leave on multicast group */ +static void __ipv6_sock_mc_drop(struct sock *sk, struct ipv6_mc_socklist *mc_lst) +{ + struct net *net = sock_net(sk); + struct net_device *dev; + + dev = dev_get_by_index(net, mc_lst->ifindex); + if (dev) { + struct inet6_dev *idev = in6_dev_get(dev); + + ip6_mc_leave_src(sk, mc_lst, idev); + + if (idev) { + __ipv6_dev_mc_dec(idev, &mc_lst->addr); + in6_dev_put(idev); + } + + dev_put(dev); + } else { + ip6_mc_leave_src(sk, mc_lst, NULL); + } + + atomic_sub(sizeof(*mc_lst), &sk->sk_omem_alloc); + kfree_rcu(mc_lst, rcu); +} + int ipv6_sock_mc_drop(struct sock *sk, int ifindex, const struct in6_addr *addr) { struct ipv6_pinfo *np = inet6_sk(sk); - struct ipv6_mc_socklist *mc_lst; struct ipv6_mc_socklist __rcu **lnk; - struct net *net = sock_net(sk); - - ASSERT_RTNL(); + struct ipv6_mc_socklist *mc_lst; if (!ipv6_addr_is_multicast(addr)) return -EINVAL; @@ -265,23 +292,8 @@ int ipv6_sock_mc_drop(struct sock *sk, int ifindex, const struct in6_addr *addr) lnk = &mc_lst->next) { if ((ifindex == 0 || mc_lst->ifindex == ifindex) && ipv6_addr_equal(&mc_lst->addr, addr)) { - struct net_device *dev; - *lnk = mc_lst->next; - - dev = __dev_get_by_index(net, mc_lst->ifindex); - if (dev) { - struct inet6_dev *idev = __in6_dev_get(dev); - - ip6_mc_leave_src(sk, mc_lst, idev); - if (idev) - __ipv6_dev_mc_dec(idev, &mc_lst->addr); - } else { - ip6_mc_leave_src(sk, mc_lst, NULL); - } - - atomic_sub(sizeof(*mc_lst), &sk->sk_omem_alloc); - kfree_rcu(mc_lst, rcu); + __ipv6_sock_mc_drop(sk, mc_lst); return 0; } } @@ -290,31 +302,36 @@ int ipv6_sock_mc_drop(struct sock *sk, int ifindex, const struct in6_addr *addr) } EXPORT_SYMBOL(ipv6_sock_mc_drop); -static struct inet6_dev *ip6_mc_find_dev_rtnl(struct net *net, - const struct in6_addr *group, - int ifindex) +static struct inet6_dev *ip6_mc_find_dev(struct net *net, + const struct in6_addr *group, + int ifindex) { struct net_device *dev = NULL; - struct inet6_dev *idev = NULL; + struct inet6_dev *idev; if (ifindex == 0) { - struct rt6_info *rt = rt6_lookup(net, group, NULL, 0, NULL, 0); + struct rt6_info *rt; + rcu_read_lock(); + rt = rt6_lookup(net, group, NULL, 0, NULL, 0); if (rt) { - dev = rt->dst.dev; + dev = dst_dev(&rt->dst); + dev_hold(dev); ip6_rt_put(rt); } + rcu_read_unlock(); } else { - dev = __dev_get_by_index(net, ifindex); + dev = dev_get_by_index(net, ifindex); } - if (!dev) return NULL; - idev = __in6_dev_get(dev); + + idev = in6_dev_get(dev); + dev_put(dev); + if (!idev) return NULL; - if (idev->dead) - return NULL; + return idev; } @@ -322,28 +339,10 @@ void __ipv6_sock_mc_close(struct sock *sk) { struct ipv6_pinfo *np = inet6_sk(sk); struct ipv6_mc_socklist *mc_lst; - struct net *net = sock_net(sk); - - ASSERT_RTNL(); while ((mc_lst = sock_dereference(np->ipv6_mc_list, sk)) != NULL) { - struct net_device *dev; - np->ipv6_mc_list = mc_lst->next; - - dev = __dev_get_by_index(net, mc_lst->ifindex); - if (dev) { - struct inet6_dev *idev = __in6_dev_get(dev); - - ip6_mc_leave_src(sk, mc_lst, idev); - if (idev) - __ipv6_dev_mc_dec(idev, &mc_lst->addr); - } else { - ip6_mc_leave_src(sk, mc_lst, NULL); - } - - atomic_sub(sizeof(*mc_lst), &sk->sk_omem_alloc); - kfree_rcu(mc_lst, rcu); + __ipv6_sock_mc_drop(sk, mc_lst); } } @@ -354,24 +353,22 @@ void ipv6_sock_mc_close(struct sock *sk) if (!rcu_access_pointer(np->ipv6_mc_list)) return; - rtnl_lock(); lock_sock(sk); __ipv6_sock_mc_close(sk); release_sock(sk); - rtnl_unlock(); } int ip6_mc_source(int add, int omode, struct sock *sk, - struct group_source_req *pgsr) + struct group_source_req *pgsr) { + struct ipv6_pinfo *inet6 = inet6_sk(sk); struct in6_addr *source, *group; + struct net *net = sock_net(sk); struct ipv6_mc_socklist *pmc; - struct inet6_dev *idev; - struct ipv6_pinfo *inet6 = inet6_sk(sk); struct ip6_sf_socklist *psl; - struct net *net = sock_net(sk); - int i, j, rv; + struct inet6_dev *idev; int leavegroup = 0; + int i, j, rv; int err; source = &((struct sockaddr_in6 *)&pgsr->gsr_source)->sin6_addr; @@ -380,13 +377,19 @@ int ip6_mc_source(int add, int omode, struct sock *sk, if (!ipv6_addr_is_multicast(group)) return -EINVAL; - idev = ip6_mc_find_dev_rtnl(net, group, pgsr->gsr_interface); + idev = ip6_mc_find_dev(net, group, pgsr->gsr_interface); if (!idev) return -ENODEV; + mutex_lock(&idev->mc_lock); + + if (idev->dead) { + err = -ENODEV; + goto done; + } + err = -EADDRNOTAVAIL; - mutex_lock(&idev->mc_lock); for_each_pmc_socklock(inet6, sk, pmc) { if (pgsr->gsr_interface && pmc->ifindex != pgsr->gsr_interface) continue; @@ -483,6 +486,7 @@ int ip6_mc_source(int add, int omode, struct sock *sk, ip6_mc_add_src(idev, group, omode, 1, source, 1); done: mutex_unlock(&idev->mc_lock); + in6_dev_put(idev); if (leavegroup) err = ipv6_sock_mc_drop(sk, pgsr->gsr_interface, group); return err; @@ -491,12 +495,12 @@ done: int ip6_mc_msfilter(struct sock *sk, struct group_filter *gsf, struct sockaddr_storage *list) { - const struct in6_addr *group; - struct ipv6_mc_socklist *pmc; - struct inet6_dev *idev; struct ipv6_pinfo *inet6 = inet6_sk(sk); struct ip6_sf_socklist *newpsl, *psl; struct net *net = sock_net(sk); + const struct in6_addr *group; + struct ipv6_mc_socklist *pmc; + struct inet6_dev *idev; int leavegroup = 0; int i, err; @@ -508,10 +512,17 @@ int ip6_mc_msfilter(struct sock *sk, struct group_filter *gsf, gsf->gf_fmode != MCAST_EXCLUDE) return -EINVAL; - idev = ip6_mc_find_dev_rtnl(net, group, gsf->gf_interface); + idev = ip6_mc_find_dev(net, group, gsf->gf_interface); if (!idev) return -ENODEV; + mutex_lock(&idev->mc_lock); + + if (idev->dead) { + err = -ENODEV; + goto done; + } + err = 0; if (gsf->gf_fmode == MCAST_INCLUDE && gsf->gf_numsrc == 0) { @@ -544,24 +555,19 @@ int ip6_mc_msfilter(struct sock *sk, struct group_filter *gsf, psin6 = (struct sockaddr_in6 *)list; newpsl->sl_addr[i] = psin6->sin6_addr; } - mutex_lock(&idev->mc_lock); + err = ip6_mc_add_src(idev, group, gsf->gf_fmode, newpsl->sl_count, newpsl->sl_addr, 0); if (err) { - mutex_unlock(&idev->mc_lock); sock_kfree_s(sk, newpsl, struct_size(newpsl, sl_addr, newpsl->sl_max)); goto done; } - mutex_unlock(&idev->mc_lock); } else { newpsl = NULL; - mutex_lock(&idev->mc_lock); ip6_mc_add_src(idev, group, gsf->gf_fmode, 0, NULL, 0); - mutex_unlock(&idev->mc_lock); } - mutex_lock(&idev->mc_lock); psl = sock_dereference(pmc->sflist, sk); if (psl) { ip6_mc_del_src(idev, group, pmc->sfmode, @@ -571,12 +577,14 @@ int ip6_mc_msfilter(struct sock *sk, struct group_filter *gsf, } else { ip6_mc_del_src(idev, group, pmc->sfmode, 0, NULL, 0); } + rcu_assign_pointer(pmc->sflist, newpsl); - mutex_unlock(&idev->mc_lock); kfree_rcu(psl, rcu); pmc->sfmode = gsf->gf_fmode; err = 0; done: + mutex_unlock(&idev->mc_lock); + in6_dev_put(idev); if (leavegroup) err = ipv6_sock_mc_drop(sk, gsf->gf_interface, group); return err; @@ -597,10 +605,6 @@ int ip6_mc_msfget(struct sock *sk, struct group_filter *gsf, if (!ipv6_addr_is_multicast(group)) return -EINVAL; - /* changes to the ipv6_mc_list require the socket lock and - * rtnl lock. We have the socket lock, so reading the list is safe. - */ - for_each_pmc_socklock(inet6, sk, pmc) { if (pmc->ifindex != gsf->gf_interface) continue; @@ -668,12 +672,13 @@ bool inet6_mc_check(const struct sock *sk, const struct in6_addr *mc_addr, return rv; } -/* called with mc_lock */ static void igmp6_group_added(struct ifmcaddr6 *mc) { struct net_device *dev = mc->idev->dev; char buf[MAX_ADDR_LEN]; + mc_assert_locked(mc->idev); + if (IPV6_ADDR_MC_SCOPE(&mc->mca_addr) < IPV6_ADDR_SCOPE_LINKLOCAL) return; @@ -703,12 +708,13 @@ static void igmp6_group_added(struct ifmcaddr6 *mc) mld_ifc_event(mc->idev); } -/* called with mc_lock */ static void igmp6_group_dropped(struct ifmcaddr6 *mc) { struct net_device *dev = mc->idev->dev; char buf[MAX_ADDR_LEN]; + mc_assert_locked(mc->idev); + if (IPV6_ADDR_MC_SCOPE(&mc->mca_addr) < IPV6_ADDR_SCOPE_LINKLOCAL) return; @@ -729,14 +735,13 @@ static void igmp6_group_dropped(struct ifmcaddr6 *mc) refcount_dec(&mc->mca_refcnt); } -/* - * deleted ifmcaddr6 manipulation - * called with mc_lock - */ +/* deleted ifmcaddr6 manipulation */ static void mld_add_delrec(struct inet6_dev *idev, struct ifmcaddr6 *im) { struct ifmcaddr6 *pmc; + mc_assert_locked(idev); + /* this is an "ifmcaddr6" for convenience; only the fields below * are actually used. In particular, the refcnt and users are not * used for management of the delete list. Using the same structure @@ -770,13 +775,14 @@ static void mld_add_delrec(struct inet6_dev *idev, struct ifmcaddr6 *im) rcu_assign_pointer(idev->mc_tomb, pmc); } -/* called with mc_lock */ static void mld_del_delrec(struct inet6_dev *idev, struct ifmcaddr6 *im) { struct ip6_sf_list *psf, *sources, *tomb; struct in6_addr *pmca = &im->mca_addr; struct ifmcaddr6 *pmc, *pmc_prev; + mc_assert_locked(idev); + pmc_prev = NULL; for_each_mc_tomb(idev, pmc) { if (ipv6_addr_equal(&pmc->mca_addr, pmca)) @@ -813,11 +819,12 @@ static void mld_del_delrec(struct inet6_dev *idev, struct ifmcaddr6 *im) } } -/* called with mc_lock */ static void mld_clear_delrec(struct inet6_dev *idev) { struct ifmcaddr6 *pmc, *nextpmc; + mc_assert_locked(idev); + pmc = mc_dereference(idev->mc_tomb, idev); RCU_INIT_POINTER(idev->mc_tomb, NULL); @@ -861,11 +868,6 @@ static void mld_clear_report(struct inet6_dev *idev) spin_unlock_bh(&idev->mc_report_lock); } -static void mca_get(struct ifmcaddr6 *mc) -{ - refcount_inc(&mc->mca_refcnt); -} - static void ma_put(struct ifmcaddr6 *mc) { if (refcount_dec_and_test(&mc->mca_refcnt)) { @@ -874,13 +876,14 @@ static void ma_put(struct ifmcaddr6 *mc) } } -/* called with mc_lock */ static struct ifmcaddr6 *mca_alloc(struct inet6_dev *idev, const struct in6_addr *addr, unsigned int mode) { struct ifmcaddr6 *mc; + mc_assert_locked(idev); + mc = kzalloc(sizeof(*mc), GFP_KERNEL); if (!mc) return NULL; @@ -945,23 +948,22 @@ error: static int __ipv6_dev_mc_inc(struct net_device *dev, const struct in6_addr *addr, unsigned int mode) { - struct ifmcaddr6 *mc; struct inet6_dev *idev; - - ASSERT_RTNL(); + struct ifmcaddr6 *mc; /* we need to take a reference on idev */ idev = in6_dev_get(dev); - if (!idev) return -EINVAL; - if (idev->dead) { + mutex_lock(&idev->mc_lock); + + if (READ_ONCE(idev->dead)) { + mutex_unlock(&idev->mc_lock); in6_dev_put(idev); return -ENODEV; } - mutex_lock(&idev->mc_lock); for_each_mc_mclock(idev, mc) { if (ipv6_addr_equal(&mc->mca_addr, addr)) { mc->mca_users++; @@ -982,13 +984,11 @@ static int __ipv6_dev_mc_inc(struct net_device *dev, rcu_assign_pointer(mc->next, idev->mc_list); rcu_assign_pointer(idev->mc_list, mc); - mca_get(mc); - mld_del_delrec(idev, mc); igmp6_group_added(mc); inet6_ifmcaddr_notify(dev, mc, RTM_NEWMULTICAST); mutex_unlock(&idev->mc_lock); - ma_put(mc); + return 0; } @@ -1005,9 +1005,8 @@ int __ipv6_dev_mc_dec(struct inet6_dev *idev, const struct in6_addr *addr) { struct ifmcaddr6 *ma, __rcu **map; - ASSERT_RTNL(); - mutex_lock(&idev->mc_lock); + for (map = &idev->mc_list; (ma = mc_dereference(*map, idev)); map = &ma->next) { @@ -1038,13 +1037,12 @@ int ipv6_dev_mc_dec(struct net_device *dev, const struct in6_addr *addr) struct inet6_dev *idev; int err; - ASSERT_RTNL(); - - idev = __in6_dev_get(dev); + idev = in6_dev_get(dev); if (!idev) - err = -ENODEV; - else - err = __ipv6_dev_mc_dec(idev, addr); + return -ENODEV; + + err = __ipv6_dev_mc_dec(idev, addr); + in6_dev_put(idev); return err; } @@ -1091,46 +1089,51 @@ unlock: return rv; } -/* called with mc_lock */ static void mld_gq_start_work(struct inet6_dev *idev) { unsigned long tv = get_random_u32_below(idev->mc_maxdelay); + mc_assert_locked(idev); + idev->mc_gq_running = 1; if (!mod_delayed_work(mld_wq, &idev->mc_gq_work, tv + 2)) in6_dev_hold(idev); } -/* called with mc_lock */ static void mld_gq_stop_work(struct inet6_dev *idev) { + mc_assert_locked(idev); + idev->mc_gq_running = 0; if (cancel_delayed_work(&idev->mc_gq_work)) __in6_dev_put(idev); } -/* called with mc_lock */ static void mld_ifc_start_work(struct inet6_dev *idev, unsigned long delay) { unsigned long tv = get_random_u32_below(delay); + mc_assert_locked(idev); + if (!mod_delayed_work(mld_wq, &idev->mc_ifc_work, tv + 2)) in6_dev_hold(idev); } -/* called with mc_lock */ static void mld_ifc_stop_work(struct inet6_dev *idev) { + mc_assert_locked(idev); + idev->mc_ifc_count = 0; if (cancel_delayed_work(&idev->mc_ifc_work)) __in6_dev_put(idev); } -/* called with mc_lock */ static void mld_dad_start_work(struct inet6_dev *idev, unsigned long delay) { unsigned long tv = get_random_u32_below(delay); + mc_assert_locked(idev); + if (!mod_delayed_work(mld_wq, &idev->mc_dad_work, tv + 2)) in6_dev_hold(idev); } @@ -1155,14 +1158,13 @@ static void mld_report_stop_work(struct inet6_dev *idev) __in6_dev_put(idev); } -/* - * IGMP handling (alias multicast ICMPv6 messages) - * called with mc_lock - */ +/* IGMP handling (alias multicast ICMPv6 messages) */ static void igmp6_group_queried(struct ifmcaddr6 *ma, unsigned long resptime) { unsigned long delay = resptime; + mc_assert_locked(ma->idev); + /* Do not start work for these addresses */ if (ipv6_addr_is_ll_all_nodes(&ma->mca_addr) || IPV6_ADDR_MC_SCOPE(&ma->mca_addr) < IPV6_ADDR_SCOPE_LINKLOCAL) @@ -1181,15 +1183,15 @@ static void igmp6_group_queried(struct ifmcaddr6 *ma, unsigned long resptime) ma->mca_flags |= MAF_TIMER_RUNNING; } -/* mark EXCLUDE-mode sources - * called with mc_lock - */ +/* mark EXCLUDE-mode sources */ static bool mld_xmarksources(struct ifmcaddr6 *pmc, int nsrcs, const struct in6_addr *srcs) { struct ip6_sf_list *psf; int i, scount; + mc_assert_locked(pmc->idev); + scount = 0; for_each_psf_mclock(pmc, psf) { if (scount == nsrcs) @@ -1212,13 +1214,14 @@ static bool mld_xmarksources(struct ifmcaddr6 *pmc, int nsrcs, return true; } -/* called with mc_lock */ static bool mld_marksources(struct ifmcaddr6 *pmc, int nsrcs, const struct in6_addr *srcs) { struct ip6_sf_list *psf; int i, scount; + mc_assert_locked(pmc->idev); + if (pmc->mca_sfmode == MCAST_EXCLUDE) return mld_xmarksources(pmc, nsrcs, srcs); @@ -1913,7 +1916,6 @@ static struct sk_buff *add_grhead(struct sk_buff *skb, struct ifmcaddr6 *pmc, #define AVAILABLE(skb) ((skb) ? skb_availroom(skb) : 0) -/* called with mc_lock */ static struct sk_buff *add_grec(struct sk_buff *skb, struct ifmcaddr6 *pmc, int type, int gdeleted, int sdeleted, int crsend) @@ -1927,6 +1929,8 @@ static struct sk_buff *add_grec(struct sk_buff *skb, struct ifmcaddr6 *pmc, struct mld2_report *pmr; unsigned int mtu; + mc_assert_locked(idev); + if (pmc->mca_flags & MAF_NOREPORT) return skb; @@ -2045,12 +2049,13 @@ empty_source: return skb; } -/* called with mc_lock */ static void mld_send_report(struct inet6_dev *idev, struct ifmcaddr6 *pmc) { struct sk_buff *skb = NULL; int type; + mc_assert_locked(idev); + if (!pmc) { for_each_mc_mclock(idev, pmc) { if (pmc->mca_flags & MAF_NOREPORT) @@ -2072,10 +2077,7 @@ static void mld_send_report(struct inet6_dev *idev, struct ifmcaddr6 *pmc) mld_sendpack(skb); } -/* - * remove zero-count source records from a source filter list - * called with mc_lock - */ +/* remove zero-count source records from a source filter list */ static void mld_clear_zeros(struct ip6_sf_list __rcu **ppsf, struct inet6_dev *idev) { struct ip6_sf_list *psf_prev, *psf_next, *psf; @@ -2099,7 +2101,6 @@ static void mld_clear_zeros(struct ip6_sf_list __rcu **ppsf, struct inet6_dev *i } } -/* called with mc_lock */ static void mld_send_cr(struct inet6_dev *idev) { struct ifmcaddr6 *pmc, *pmc_prev, *pmc_next; @@ -2263,13 +2264,14 @@ err_out: goto out; } -/* called with mc_lock */ static void mld_send_initial_cr(struct inet6_dev *idev) { - struct sk_buff *skb; struct ifmcaddr6 *pmc; + struct sk_buff *skb; int type; + mc_assert_locked(idev); + if (mld_in_v1_mode(idev)) return; @@ -2316,13 +2318,14 @@ static void mld_dad_work(struct work_struct *work) in6_dev_put(idev); } -/* called with mc_lock */ static int ip6_mc_del1_src(struct ifmcaddr6 *pmc, int sfmode, - const struct in6_addr *psfsrc) + const struct in6_addr *psfsrc) { struct ip6_sf_list *psf, *psf_prev; int rv = 0; + mc_assert_locked(pmc->idev); + psf_prev = NULL; for_each_psf_mclock(pmc, psf) { if (ipv6_addr_equal(&psf->sf_addr, psfsrc)) @@ -2359,7 +2362,6 @@ static int ip6_mc_del1_src(struct ifmcaddr6 *pmc, int sfmode, return rv; } -/* called with mc_lock */ static int ip6_mc_del_src(struct inet6_dev *idev, const struct in6_addr *pmca, int sfmode, int sfcount, const struct in6_addr *psfsrc, int delta) @@ -2371,6 +2373,8 @@ static int ip6_mc_del_src(struct inet6_dev *idev, const struct in6_addr *pmca, if (!idev) return -ENODEV; + mc_assert_locked(idev); + for_each_mc_mclock(idev, pmc) { if (ipv6_addr_equal(pmca, &pmc->mca_addr)) break; @@ -2412,15 +2416,14 @@ static int ip6_mc_del_src(struct inet6_dev *idev, const struct in6_addr *pmca, return err; } -/* - * Add multicast single-source filter to the interface list - * called with mc_lock - */ +/* Add multicast single-source filter to the interface list */ static int ip6_mc_add1_src(struct ifmcaddr6 *pmc, int sfmode, - const struct in6_addr *psfsrc) + const struct in6_addr *psfsrc) { struct ip6_sf_list *psf, *psf_prev; + mc_assert_locked(pmc->idev); + psf_prev = NULL; for_each_psf_mclock(pmc, psf) { if (ipv6_addr_equal(&psf->sf_addr, psfsrc)) @@ -2443,11 +2446,12 @@ static int ip6_mc_add1_src(struct ifmcaddr6 *pmc, int sfmode, return 0; } -/* called with mc_lock */ static void sf_markstate(struct ifmcaddr6 *pmc) { - struct ip6_sf_list *psf; int mca_xcount = pmc->mca_sfcount[MCAST_EXCLUDE]; + struct ip6_sf_list *psf; + + mc_assert_locked(pmc->idev); for_each_psf_mclock(pmc, psf) { if (pmc->mca_sfcount[MCAST_EXCLUDE]) { @@ -2460,14 +2464,15 @@ static void sf_markstate(struct ifmcaddr6 *pmc) } } -/* called with mc_lock */ static int sf_setstate(struct ifmcaddr6 *pmc) { - struct ip6_sf_list *psf, *dpsf; int mca_xcount = pmc->mca_sfcount[MCAST_EXCLUDE]; + struct ip6_sf_list *psf, *dpsf; int qrv = pmc->idev->mc_qrv; int new_in, rv; + mc_assert_locked(pmc->idev); + rv = 0; for_each_psf_mclock(pmc, psf) { if (pmc->mca_sfcount[MCAST_EXCLUDE]) { @@ -2526,10 +2531,7 @@ static int sf_setstate(struct ifmcaddr6 *pmc) return rv; } -/* - * Add multicast source filter list to the interface list - * called with mc_lock - */ +/* Add multicast source filter list to the interface list */ static int ip6_mc_add_src(struct inet6_dev *idev, const struct in6_addr *pmca, int sfmode, int sfcount, const struct in6_addr *psfsrc, int delta) @@ -2541,6 +2543,8 @@ static int ip6_mc_add_src(struct inet6_dev *idev, const struct in6_addr *pmca, if (!idev) return -ENODEV; + mc_assert_locked(idev); + for_each_mc_mclock(idev, pmc) { if (ipv6_addr_equal(pmca, &pmc->mca_addr)) break; @@ -2588,11 +2592,12 @@ static int ip6_mc_add_src(struct inet6_dev *idev, const struct in6_addr *pmca, return err; } -/* called with mc_lock */ static void ip6_mc_clear_src(struct ifmcaddr6 *pmc) { struct ip6_sf_list *psf, *nextpsf; + mc_assert_locked(pmc->idev); + for (psf = mc_dereference(pmc->mca_tomb, pmc->idev); psf; psf = nextpsf) { @@ -2613,11 +2618,12 @@ static void ip6_mc_clear_src(struct ifmcaddr6 *pmc) WRITE_ONCE(pmc->mca_sfcount[MCAST_EXCLUDE], 1); } -/* called with mc_lock */ static void igmp6_join_group(struct ifmcaddr6 *ma) { unsigned long delay; + mc_assert_locked(ma->idev); + if (ma->mca_flags & MAF_NOREPORT) return; @@ -2664,9 +2670,10 @@ static int ip6_mc_leave_src(struct sock *sk, struct ipv6_mc_socklist *iml, return err; } -/* called with mc_lock */ static void igmp6_leave_group(struct ifmcaddr6 *ma) { + mc_assert_locked(ma->idev); + if (mld_in_v1_mode(ma->idev)) { if (ma->mca_flags & MAF_LAST_REPORTER) { igmp6_send(&ma->mca_addr, ma->idev->dev, @@ -2711,9 +2718,10 @@ static void mld_ifc_work(struct work_struct *work) in6_dev_put(idev); } -/* called with mc_lock */ static void mld_ifc_event(struct inet6_dev *idev) { + mc_assert_locked(idev); + if (mld_in_v1_mode(idev)) return; @@ -2868,8 +2876,6 @@ static void ipv6_mc_rejoin_groups(struct inet6_dev *idev) { struct ifmcaddr6 *pmc; - ASSERT_RTNL(); - mutex_lock(&idev->mc_lock); if (mld_in_v1_mode(idev)) { for_each_mc_mclock(idev, pmc) diff --git a/net/ipv6/ndisc.c b/net/ipv6/ndisc.c index f2299b61221b..28f35cbb6577 100644 --- a/net/ipv6/ndisc.c +++ b/net/ipv6/ndisc.c @@ -377,24 +377,25 @@ static int ndisc_constructor(struct neighbour *neigh) static int pndisc_constructor(struct pneigh_entry *n) { struct in6_addr *addr = (struct in6_addr *)&n->key; - struct in6_addr maddr; struct net_device *dev = n->dev; + struct in6_addr maddr; - if (!dev || !__in6_dev_get(dev)) + if (!dev) return -EINVAL; + addrconf_addr_solict_mult(addr, &maddr); - ipv6_dev_mc_inc(dev, &maddr); - return 0; + return ipv6_dev_mc_inc(dev, &maddr); } static void pndisc_destructor(struct pneigh_entry *n) { struct in6_addr *addr = (struct in6_addr *)&n->key; - struct in6_addr maddr; struct net_device *dev = n->dev; + struct in6_addr maddr; - if (!dev || !__in6_dev_get(dev)) + if (!dev) return; + addrconf_addr_solict_mult(addr, &maddr); ipv6_dev_mc_dec(dev, &maddr); } |