diff --git a/net/ipv4/af_inet.c b/net/ipv4/af_inet.c index 72fde2888ad2f..76c1cef0d4f80 100644 --- a/net/ipv4/af_inet.c +++ b/net/ipv4/af_inet.c @@ -313,7 +313,10 @@ static int inet_create(struct net *net, struct socket *sock, int protocol, answer_flags = answer->flags; rcu_read_unlock(); +#if !IS_ENABLED(CONFIG_KASAN) + /* with kasan we use kmalloc */ WARN_ON(!answer_prot->slab); +#endif err = -ENOMEM; sk = sk_alloc(net, PF_INET, GFP_KERNEL, answer_prot, kern); diff --git a/net/ipv6/af_inet6.c b/net/ipv6/af_inet6.c index 8fe7900f19499..7412d18241514 100644 --- a/net/ipv6/af_inet6.c +++ b/net/ipv6/af_inet6.c @@ -178,7 +178,10 @@ static int inet6_create(struct net *net, struct socket *sock, int protocol, answer_flags = answer->flags; rcu_read_unlock(); +#if !IS_ENABLED(CONFIG_KASAN) + /* with kasan we use kmalloc */ WARN_ON(!answer_prot->slab); +#endif err = -ENOBUFS; sk = sk_alloc(net, PF_INET6, GFP_KERNEL, answer_prot, kern); diff --git a/net/mptcp/protocol.c b/net/mptcp/protocol.c index 1c72f25f083ea..326cda968a48a 100644 --- a/net/mptcp/protocol.c +++ b/net/mptcp/protocol.c @@ -2542,6 +2542,10 @@ static int __mptcp_init_sock(struct sock *sk) timer_setup(&msk->sk.icsk_retransmit_timer, mptcp_retransmit_timer, 0); timer_setup(&sk->sk_timer, mptcp_timeout_timer, 0); +#if IS_ENABLED(CONFIG_KASAN) + sock_set_flag(sk, SOCK_RCU_FREE); +#endif + return 0; } @@ -2914,7 +2918,9 @@ struct sock *mptcp_sk_clone(const struct sock *sk, WRITE_ONCE(msk->rcv_wnd_sent, ack_seq); } +#if !IS_ENABLED(CONFIG_KASAN) sock_reset_flag(nsk, SOCK_RCU_FREE); +#endif /* will be fully established after successful MPC subflow creation */ inet_sk_state_store(nsk, TCP_SYN_RECV); @@ -3684,6 +3690,12 @@ static int mptcp_napi_poll(struct napi_struct *napi, int budget) return work_done; } +#if IS_ENABLED(CONFIG_KASAN) +#define MPTCP_USE_SLAB 0 +#else +#define MPTCP_USE_SLAB 1 +#endif + void __init mptcp_proto_init(void) { struct mptcp_delegated_action *delegated; @@ -3707,7 +3719,7 @@ void __init mptcp_proto_init(void) mptcp_pm_init(); mptcp_token_init(); - if (proto_register(&mptcp_prot, 1) != 0) + if (proto_register(&mptcp_prot, MPTCP_USE_SLAB) != 0) panic("Failed to register MPTCP proto.\n"); inet_register_protosw(&mptcp_protosw); @@ -3767,7 +3779,7 @@ int __init mptcp_proto_v6_init(void) mptcp_v6_prot.destroy = mptcp_v6_destroy; mptcp_v6_prot.obj_size = sizeof(struct mptcp6_sock); - err = proto_register(&mptcp_v6_prot, 1); + err = proto_register(&mptcp_v6_prot, MPTCP_USE_SLAB); if (err) return err;