@@ -581,6 +581,7 @@ enum {
581581
582582struct async_poll {
583583 struct io_poll_iocb poll ;
584+ struct io_poll_iocb * double_poll ;
584585 struct io_wq_work work ;
585586};
586587
@@ -4220,9 +4221,9 @@ static bool io_poll_rewait(struct io_kiocb *req, struct io_poll_iocb *poll)
42204221 return false;
42214222}
42224223
4223- static void io_poll_remove_double (struct io_kiocb * req )
4224+ static void io_poll_remove_double (struct io_kiocb * req , void * data )
42244225{
4225- struct io_poll_iocb * poll = ( struct io_poll_iocb * ) req -> io ;
4226+ struct io_poll_iocb * poll = data ;
42264227
42274228 lockdep_assert_held (& req -> ctx -> completion_lock );
42284229
@@ -4242,7 +4243,7 @@ static void io_poll_complete(struct io_kiocb *req, __poll_t mask, int error)
42424243{
42434244 struct io_ring_ctx * ctx = req -> ctx ;
42444245
4245- io_poll_remove_double (req );
4246+ io_poll_remove_double (req , req -> io );
42464247 req -> poll .done = true;
42474248 io_cqring_fill_event (req , error ? error : mangle_poll (mask ));
42484249 io_commit_cqring (ctx );
@@ -4285,21 +4286,21 @@ static int io_poll_double_wake(struct wait_queue_entry *wait, unsigned mode,
42854286 int sync , void * key )
42864287{
42874288 struct io_kiocb * req = wait -> private ;
4288- struct io_poll_iocb * poll = ( struct io_poll_iocb * ) req -> io ;
4289+ struct io_poll_iocb * poll = req -> apoll -> double_poll ;
42894290 __poll_t mask = key_to_poll (key );
42904291
42914292 /* for instances that support it check for an event match first: */
42924293 if (mask && !(mask & poll -> events ))
42934294 return 0 ;
42944295
4295- if (req -> poll . head ) {
4296+ if (poll && poll -> head ) {
42964297 bool done ;
42974298
4298- spin_lock (& req -> poll . head -> lock );
4299- done = list_empty (& req -> poll . wait .entry );
4299+ spin_lock (& poll -> head -> lock );
4300+ done = list_empty (& poll -> wait .entry );
43004301 if (!done )
4301- list_del_init (& req -> poll . wait .entry );
4302- spin_unlock (& req -> poll . head -> lock );
4302+ list_del_init (& poll -> wait .entry );
4303+ spin_unlock (& poll -> head -> lock );
43034304 if (!done )
43044305 __io_async_wake (req , poll , mask , io_poll_task_func );
43054306 }
@@ -4319,7 +4320,8 @@ static void io_init_poll_iocb(struct io_poll_iocb *poll, __poll_t events,
43194320}
43204321
43214322static void __io_queue_proc (struct io_poll_iocb * poll , struct io_poll_table * pt ,
4322- struct wait_queue_head * head )
4323+ struct wait_queue_head * head ,
4324+ struct io_poll_iocb * * poll_ptr )
43234325{
43244326 struct io_kiocb * req = pt -> req ;
43254327
@@ -4330,7 +4332,7 @@ static void __io_queue_proc(struct io_poll_iocb *poll, struct io_poll_table *pt,
43304332 */
43314333 if (unlikely (poll -> head )) {
43324334 /* already have a 2nd entry, fail a third attempt */
4333- if (req -> io ) {
4335+ if (* poll_ptr ) {
43344336 pt -> error = - EINVAL ;
43354337 return ;
43364338 }
@@ -4342,7 +4344,7 @@ static void __io_queue_proc(struct io_poll_iocb *poll, struct io_poll_table *pt,
43424344 io_init_poll_iocb (poll , req -> poll .events , io_poll_double_wake );
43434345 refcount_inc (& req -> refs );
43444346 poll -> wait .private = req ;
4345- req -> io = ( void * ) poll ;
4347+ * poll_ptr = poll ;
43464348 }
43474349
43484350 pt -> error = 0 ;
@@ -4354,8 +4356,9 @@ static void io_async_queue_proc(struct file *file, struct wait_queue_head *head,
43544356 struct poll_table_struct * p )
43554357{
43564358 struct io_poll_table * pt = container_of (p , struct io_poll_table , pt );
4359+ struct async_poll * apoll = pt -> req -> apoll ;
43574360
4358- __io_queue_proc (& pt -> req -> apoll -> poll , pt , head );
4361+ __io_queue_proc (& apoll -> poll , pt , head , & apoll -> double_poll );
43594362}
43604363
43614364static void io_sq_thread_drop_mm (struct io_ring_ctx * ctx )
@@ -4409,6 +4412,7 @@ static void io_async_task_func(struct callback_head *cb)
44094412 memcpy (& req -> work , & apoll -> work , sizeof (req -> work ));
44104413
44114414 if (canceled ) {
4415+ kfree (apoll -> double_poll );
44124416 kfree (apoll );
44134417 io_cqring_ev_posted (ctx );
44144418end_req :
@@ -4426,6 +4430,7 @@ static void io_async_task_func(struct callback_head *cb)
44264430 __io_queue_sqe (req , NULL );
44274431 mutex_unlock (& ctx -> uring_lock );
44284432
4433+ kfree (apoll -> double_poll );
44294434 kfree (apoll );
44304435}
44314436
@@ -4497,7 +4502,6 @@ static bool io_arm_poll_handler(struct io_kiocb *req)
44974502 struct async_poll * apoll ;
44984503 struct io_poll_table ipt ;
44994504 __poll_t mask , ret ;
4500- bool had_io ;
45014505
45024506 if (!req -> file || !file_can_poll (req -> file ))
45034507 return false;
@@ -4509,10 +4513,10 @@ static bool io_arm_poll_handler(struct io_kiocb *req)
45094513 apoll = kmalloc (sizeof (* apoll ), GFP_ATOMIC );
45104514 if (unlikely (!apoll ))
45114515 return false;
4516+ apoll -> double_poll = NULL ;
45124517
45134518 req -> flags |= REQ_F_POLLED ;
45144519 memcpy (& apoll -> work , & req -> work , sizeof (req -> work ));
4515- had_io = req -> io != NULL ;
45164520
45174521 get_task_struct (current );
45184522 req -> task = current ;
@@ -4531,12 +4535,10 @@ static bool io_arm_poll_handler(struct io_kiocb *req)
45314535 ret = __io_arm_poll_handler (req , & apoll -> poll , & ipt , mask ,
45324536 io_async_wake );
45334537 if (ret ) {
4534- ipt .error = 0 ;
4535- /* only remove double add if we did it here */
4536- if (!had_io )
4537- io_poll_remove_double (req );
4538+ io_poll_remove_double (req , apoll -> double_poll );
45384539 spin_unlock_irq (& ctx -> completion_lock );
45394540 memcpy (& req -> work , & apoll -> work , sizeof (req -> work ));
4541+ kfree (apoll -> double_poll );
45404542 kfree (apoll );
45414543 return false;
45424544 }
@@ -4567,11 +4569,13 @@ static bool io_poll_remove_one(struct io_kiocb *req)
45674569 bool do_complete ;
45684570
45694571 if (req -> opcode == IORING_OP_POLL_ADD ) {
4570- io_poll_remove_double (req );
4572+ io_poll_remove_double (req , req -> io );
45714573 do_complete = __io_poll_remove_one (req , & req -> poll );
45724574 } else {
45734575 struct async_poll * apoll = req -> apoll ;
45744576
4577+ io_poll_remove_double (req , apoll -> double_poll );
4578+
45754579 /* non-poll requests have submit ref still */
45764580 do_complete = __io_poll_remove_one (req , & apoll -> poll );
45774581 if (do_complete ) {
@@ -4582,6 +4586,7 @@ static bool io_poll_remove_one(struct io_kiocb *req)
45824586 * final reference.
45834587 */
45844588 memcpy (& req -> work , & apoll -> work , sizeof (req -> work ));
4589+ kfree (apoll -> double_poll );
45854590 kfree (apoll );
45864591 }
45874592 }
@@ -4682,7 +4687,7 @@ static void io_poll_queue_proc(struct file *file, struct wait_queue_head *head,
46824687{
46834688 struct io_poll_table * pt = container_of (p , struct io_poll_table , pt );
46844689
4685- __io_queue_proc (& pt -> req -> poll , pt , head );
4690+ __io_queue_proc (& pt -> req -> poll , pt , head , ( struct io_poll_iocb * * ) & pt -> req -> io );
46864691}
46874692
46884693static int io_poll_add_prep (struct io_kiocb * req , const struct io_uring_sqe * sqe )
0 commit comments