diff options
Diffstat (limited to 'io_uring/sqpoll.c')
-rw-r--r-- | io_uring/sqpoll.c | 34 |
1 files changed, 24 insertions, 10 deletions
diff --git a/io_uring/sqpoll.c b/io_uring/sqpoll.c index 0625a421626f..268d2fbe6160 100644 --- a/io_uring/sqpoll.c +++ b/io_uring/sqpoll.c @@ -30,7 +30,7 @@ enum { void io_sq_thread_unpark(struct io_sq_data *sqd) __releases(&sqd->lock) { - WARN_ON_ONCE(sqd->thread == current); + WARN_ON_ONCE(sqpoll_task_locked(sqd) == current); /* * Do the dance but not conditional clear_bit() because it'd race with @@ -46,24 +46,32 @@ void io_sq_thread_unpark(struct io_sq_data *sqd) void io_sq_thread_park(struct io_sq_data *sqd) __acquires(&sqd->lock) { - WARN_ON_ONCE(data_race(sqd->thread) == current); + struct task_struct *tsk; atomic_inc(&sqd->park_pending); set_bit(IO_SQ_THREAD_SHOULD_PARK, &sqd->state); mutex_lock(&sqd->lock); - if (sqd->thread) - wake_up_process(sqd->thread); + + tsk = sqpoll_task_locked(sqd); + if (tsk) { + WARN_ON_ONCE(tsk == current); + wake_up_process(tsk); + } } void io_sq_thread_stop(struct io_sq_data *sqd) { - WARN_ON_ONCE(sqd->thread == current); + struct task_struct *tsk; + WARN_ON_ONCE(test_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state)); set_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state); mutex_lock(&sqd->lock); - if (sqd->thread) - wake_up_process(sqd->thread); + tsk = sqpoll_task_locked(sqd); + if (tsk) { + WARN_ON_ONCE(tsk == current); + wake_up_process(tsk); + } mutex_unlock(&sqd->lock); wait_for_completion(&sqd->exited); } @@ -486,7 +494,10 @@ __cold int io_sq_offload_create(struct io_ring_ctx *ctx, goto err_sqpoll; } - sqd->thread = tsk; + mutex_lock(&sqd->lock); + rcu_assign_pointer(sqd->thread, tsk); + mutex_unlock(&sqd->lock); + task_to_put = get_task_struct(tsk); ret = io_uring_alloc_task_context(tsk, ctx); wake_up_new_task(tsk); @@ -514,10 +525,13 @@ __cold int io_sqpoll_wq_cpu_affinity(struct io_ring_ctx *ctx, int ret = -EINVAL; if (sqd) { + struct task_struct *tsk; + io_sq_thread_park(sqd); /* Don't set affinity for a dying thread */ - if (sqd->thread) - ret = io_wq_cpu_affinity(sqd->thread->io_uring, mask); + tsk = sqpoll_task_locked(sqd); + if (tsk) + ret = io_wq_cpu_affinity(tsk->io_uring, mask); io_sq_thread_unpark(sqd); } |