On Wed, Apr 21, 2021 at 02:40:37PM +0300, Leon Romanovsky wrote: > @@ -4396,6 +4439,14 @@ static void cm_remove_one(struct ib_device *ib_device, void *client_data) > cm_dev->going_down = 1; > spin_unlock_irq(&cm.lock); > > + list_for_each_entry_safe(cm_id_priv, tmp, > + &cm_dev->cm_id_priv_list, cm_dev_list) { > + if (!list_empty(&cm_id_priv->cm_dev_list)) > + list_del(&cm_id_priv->cm_dev_list); > + cm_id_priv->av.port = NULL; > + cm_id_priv->alt_av.port = NULL; > + } Ugh, this is in the wrong order, it has to be after the work queue flush.. Hurm, I didn't see an easy way to fix it up, but I did think of a much better design! Generally speaking all we need is the memory of the cm_dev and port to remain active, we don't need to block or fence with cm_remove_one(), so just stick a memory kref on this thing and keep the memory. The only things that need to serialize with cm_remove_one() are on the workqueue or take a spinlock (eg because they touch mad_agent) Try this, I didn't finish every detail, applies on top of your series, but you'll need to reflow it into new commits: diff --git a/drivers/infiniband/core/cm.c b/drivers/infiniband/core/cm.c index 3feff999a5e003..c26367006a4485 100644 --- a/drivers/infiniband/core/cm.c +++ b/drivers/infiniband/core/cm.c @@ -205,8 +205,11 @@ struct cm_port { }; struct cm_device { + struct kref kref; struct list_head list; + struct mutex unregistration_lock; struct ib_device *ib_device; + unsigned int num_ports; u8 ack_delay; int going_down; struct list_head cm_id_priv_list; @@ -262,7 +265,6 @@ struct cm_id_private { /* todo: use alternate port on send failure */ struct cm_av av; struct cm_av alt_av; - rwlock_t av_rwlock; /* Do not acquire inside cm.lock */ void *private_data; __be64 tid; @@ -287,10 +289,23 @@ struct cm_id_private { atomic_t work_count; struct rdma_ucm_ece ece; - - struct list_head cm_dev_list; }; +static void cm_dev_release(struct kref *kref) +{ + struct cm_device *cm_dev = container_of(kref, struct cm_device, 
kref); + unsigned int i; + + for (i = 0; i != cm_dev->num_ports; i++) + kfree(cm_dev->port[i]); + kfree(cm_dev); +} + +static void cm_device_put(struct cm_device *cm_dev) +{ + kref_put(&cm_dev->kref, cm_dev_release); +} + static void cm_work_handler(struct work_struct *work); static inline void cm_deref_id(struct cm_id_private *cm_id_priv) @@ -306,12 +321,12 @@ static struct ib_mad_send_buf *cm_alloc_msg(struct cm_id_private *cm_id_priv) struct ib_ah *ah; int ret; - read_lock(&cm_id_priv->av_rwlock); if (!cm_id_priv->av.port) { ret = -EINVAL; goto out; } + spin_lock(&cm_id_priv->av.port.cm_dev->unregistration_lock); mad_agent = cm_id_priv->av.port->mad_agent; if (!mad_agent) { ret = -EINVAL; @@ -330,7 +345,6 @@ static struct ib_mad_send_buf *cm_alloc_msg(struct cm_id_private *cm_id_priv) GFP_ATOMIC, IB_MGMT_BASE_VERSION); - read_unlock(&cm_id_priv->av_rwlock); if (IS_ERR(m)) { rdma_destroy_ah(ah, 0); ret = PTR_ERR(m); @@ -346,7 +360,7 @@ static struct ib_mad_send_buf *cm_alloc_msg(struct cm_id_private *cm_id_priv) return m; out: - read_unlock(&cm_id_priv->av_rwlock); + spin_unlock(&cm_id_priv->av.port.cm_dev->unregistration_lock); return ERR_PTR(ret); } @@ -465,20 +479,18 @@ static void cm_set_private_data(struct cm_id_private *cm_id_priv, cm_id_priv->private_data_len = private_data_len; } -static void add_cm_id_to_cm_dev_list(struct cm_id_private *cm_id_priv, - struct cm_device *cm_dev) +static void cm_set_av_port(struct cm_av *av, struct cm_port *port) { - unsigned long flags; + struct cm_port *old_port = av->port; - spin_lock_irqsave(&cm.lock, flags); - if (cm_dev->going_down) - goto out; + if (old_port == port) + return; - if (!list_empty(&cm_id_priv->cm_dev_list)) - list_del(&cm_id_priv->cm_dev_list); - list_add_tail(&cm_id_priv->cm_dev_list, &cm_dev->cm_id_priv_list); -out: - spin_unlock_irqrestore(&cm.lock, flags); + av->port = port; + if (old_port) + cm_device_put(old_port->cm_dev); + if (port) + kref_get(&old_port->cm_dev->kref); } static int 
cm_init_av_for_lap(struct cm_port *port, struct ib_wc *wc, @@ -505,11 +517,8 @@ static int cm_init_av_for_lap(struct cm_port *port, struct ib_wc *wc, if (ret) return ret; - write_lock(&cm_id_priv->av_rwlock); - av->port = port; + cm_set_av_port(av, port); av->pkey_index = wc->pkey_index; - add_cm_id_to_cm_dev_list(cm_id_priv, port->cm_dev); - write_unlock(&cm_id_priv->av_rwlock); rdma_move_ah_attr(&av->ah_attr, &new_ah_attr); return 0; @@ -521,10 +530,7 @@ static int cm_init_av_for_response(struct cm_port *port, struct ib_wc *wc, { struct cm_av *av = &cm_id_priv->av; - write_lock(&cm_id_priv->av_rwlock); - av->port = port; - add_cm_id_to_cm_dev_list(cm_id_priv, port->cm_dev); - write_unlock(&cm_id_priv->av_rwlock); + cm_set_av_port(av, port); av->pkey_index = wc->pkey_index; return ib_init_ah_attr_from_wc(port->cm_dev->ib_device, port->port_num, wc, @@ -588,12 +594,9 @@ static int cm_init_av_by_path(struct sa_path_rec *path, return -EINVAL; cm_dev = port->cm_dev; - read_lock(&cm_id_priv->av_rwlock); if (!is_priv_av && (!cm_id_priv->av.port || cm_dev != cm_id_priv->av.port->cm_dev)) ret = -EINVAL; - - read_unlock(&cm_id_priv->av_rwlock); if (ret) return ret; @@ -618,13 +621,8 @@ static int cm_init_av_by_path(struct sa_path_rec *path, if (ret) return ret; - write_lock(&cm_id_priv->av_rwlock); - av->port = port; + cm_set_av_port(av, port); av->timeout = path->packet_life_time + 1; - if (is_priv_av) - add_cm_id_to_cm_dev_list(cm_id_priv, cm_dev); - - write_unlock(&cm_id_priv->av_rwlock); rdma_move_ah_attr(&av->ah_attr, &new_ah_attr); return 0; @@ -905,10 +903,8 @@ static struct cm_id_private *cm_alloc_id_priv(struct ib_device *device, spin_lock_init(&cm_id_priv->lock); init_completion(&cm_id_priv->comp); INIT_LIST_HEAD(&cm_id_priv->work_list); - INIT_LIST_HEAD(&cm_id_priv->cm_dev_list); atomic_set(&cm_id_priv->work_count, -1); refcount_set(&cm_id_priv->refcount, 1); - rwlock_init(&cm_id_priv->av_rwlock); ret = xa_alloc_cyclic(&cm.local_id_table, &id, NULL, 
xa_limit_32b, &cm.local_id_next, GFP_KERNEL); @@ -1027,10 +1023,8 @@ static u8 cm_ack_timeout_req(struct cm_id_private *cm_id_priv, { u8 ack_delay = 0; - read_lock(&cm_id_priv->av_rwlock); - if (cm_id_priv->av.port && cm_id_priv->av.port->cm_dev) + if (cm_id_priv->av.port) ack_delay = cm_id_priv->av.port->cm_dev->ack_delay; - read_unlock(&cm_id_priv->av_rwlock); return cm_ack_timeout(ack_delay, packet_life_time); } @@ -1228,8 +1222,6 @@ static void cm_destroy_id(struct ib_cm_id *cm_id, int err) cm_id_priv->timewait_info = NULL; } - if (!list_empty(&cm_id_priv->cm_dev_list)) - list_del(&cm_id_priv->cm_dev_list); WARN_ON(cm_id_priv->listen_sharecount); WARN_ON(!RB_EMPTY_NODE(&cm_id_priv->service_node)); if (!RB_EMPTY_NODE(&cm_id_priv->sidr_id_node)) @@ -1246,6 +1238,8 @@ static void cm_destroy_id(struct ib_cm_id *cm_id, int err) rdma_destroy_ah_attr(&cm_id_priv->av.ah_attr); rdma_destroy_ah_attr(&cm_id_priv->alt_av.ah_attr); kfree(cm_id_priv->private_data); + cm_set_av_port(&cm_id_priv->av, NULL); + cm_set_av_port(&cm_id_priv->alt_av, NULL); kfree_rcu(cm_id_priv, rcu); } @@ -1378,10 +1372,13 @@ static __be64 cm_form_tid(struct cm_id_private *cm_id_priv) { u64 hi_tid = 0, low_tid; - read_lock(&cm_id_priv->av_rwlock); - if (cm_id_priv->av.port && cm_id_priv->av.port->mad_agent) - hi_tid = ((u64) cm_id_priv->av.port->mad_agent->hi_tid) << 32; - read_unlock(&cm_id_priv->av_rwlock); + if (cm_id_priv->av.port) { + spin_lock(&cm_id_priv->av.port->cm_dev->unregistration_lock); + if (cm_id_priv->av.port->mad_agent) + hi_tid = ((u64)cm_id_priv->av.port->mad_agent->hi_tid) + << 32; + spin_unlock(&cm_id_priv->av.port->cm_dev->unregistration_lock); + } low_tid = (u64)cm_id_priv->id.local_id; return cpu_to_be64(hi_tid | low_tid); @@ -1879,12 +1876,10 @@ static void cm_format_req_event(struct cm_work *work, param = &work->cm_event.param.req_rcvd; param->listen_id = listen_id; param->bth_pkey = cm_get_bth_pkey(work); - read_lock(&cm_id_priv->av_rwlock); if (cm_id_priv->av.port) 
param->port = cm_id_priv->av.port->port_num; else param->port = 0; - read_unlock(&cm_id_priv->av_rwlock); param->primary_path = &work->path[0]; cm_opa_to_ib_sgid(work, param->primary_path); if (cm_req_has_alt_path(req_msg)) { @@ -2311,13 +2306,11 @@ static void cm_format_rep(struct cm_rep_msg *rep_msg, IBA_SET(CM_REP_STARTING_PSN, rep_msg, param->starting_psn); IBA_SET(CM_REP_RESPONDER_RESOURCES, rep_msg, param->responder_resources); - read_lock(&cm_id_priv->av_rwlock); if (cm_id_priv->av.port && cm_id_priv->av.port->cm_dev) IBA_SET(CM_REP_TARGET_ACK_DELAY, rep_msg, cm_id_priv->av.port->cm_dev->ack_delay); else IBA_SET(CM_REP_TARGET_ACK_DELAY, rep_msg, 0); - read_unlock(&cm_id_priv->av_rwlock); IBA_SET(CM_REP_FAILOVER_ACCEPTED, rep_msg, param->failover_accepted); IBA_SET(CM_REP_RNR_RETRY_COUNT, rep_msg, param->rnr_retry_count); IBA_SET(CM_REP_LOCAL_CA_GUID, rep_msg, @@ -4187,10 +4180,8 @@ static int cm_init_qp_init_attr(struct cm_id_private *cm_id_priv, qp_attr->qp_access_flags |= IB_ACCESS_REMOTE_READ | IB_ACCESS_REMOTE_ATOMIC; qp_attr->pkey_index = cm_id_priv->av.pkey_index; - read_lock(&cm_id_priv->av_rwlock); qp_attr->port_num = cm_id_priv->av.port ? cm_id_priv->av.port->port_num : 0; - read_unlock(&cm_id_priv->av_rwlock); ret = 0; break; default: @@ -4234,10 +4225,8 @@ static int cm_init_qp_rtr_attr(struct cm_id_private *cm_id_priv, } if (rdma_ah_get_dlid(&cm_id_priv->alt_av.ah_attr)) { *qp_attr_mask |= IB_QP_ALT_PATH; - read_lock(&cm_id_priv->av_rwlock); qp_attr->alt_port_num = cm_id_priv->alt_av.port ? 
cm_id_priv->alt_av.port->port_num : 0; - read_unlock(&cm_id_priv->av_rwlock); qp_attr->alt_pkey_index = cm_id_priv->alt_av.pkey_index; qp_attr->alt_timeout = cm_id_priv->alt_av.timeout; qp_attr->alt_ah_attr = cm_id_priv->alt_av.ah_attr; @@ -4296,10 +4285,8 @@ static int cm_init_qp_rts_attr(struct cm_id_private *cm_id_priv, } } else { *qp_attr_mask = IB_QP_ALT_PATH | IB_QP_PATH_MIG_STATE; - read_lock(&cm_id_priv->av_rwlock); qp_attr->alt_port_num = cm_id_priv->alt_av.port ? cm_id_priv->alt_av.port->port_num : 0; - read_unlock(&cm_id_priv->av_rwlock); qp_attr->alt_pkey_index = cm_id_priv->alt_av.pkey_index; qp_attr->alt_timeout = cm_id_priv->alt_av.timeout; qp_attr->alt_ah_attr = cm_id_priv->alt_av.ah_attr; @@ -4417,9 +4404,11 @@ static int cm_add_one(struct ib_device *ib_device) if (!cm_dev) return -ENOMEM; + kref_init(&cm_dev->kref); cm_dev->ib_device = ib_device; cm_dev->ack_delay = ib_device->attrs.local_ca_ack_delay; cm_dev->going_down = 0; + cm_dev->num_ports = ib_device->phys_port_cnt; INIT_LIST_HEAD(&cm_dev->cm_id_priv_list); set_bit(IB_MGMT_METHOD_SEND, reg_req.method_mask); @@ -4489,10 +4478,9 @@ static int cm_add_one(struct ib_device *ib_device) ib_modify_port(ib_device, port->port_num, 0, &port_modify); ib_unregister_mad_agent(port->mad_agent); cm_remove_port_fs(port); - kfree(port); } free: - kfree(cm_dev); + cm_device_put(cm_dev); return ret; } @@ -4515,21 +4503,15 @@ static void cm_remove_one(struct ib_device *ib_device, void *client_data) cm_dev->going_down = 1; spin_unlock_irq(&cm.lock); - list_for_each_entry_safe(cm_id_priv, tmp, - &cm_dev->cm_id_priv_list, cm_dev_list) { - write_lock(&cm_id_priv->av_rwlock); - if (!list_empty(&cm_id_priv->cm_dev_list)) - list_del(&cm_id_priv->cm_dev_list); - cm_id_priv->av.port = NULL; - cm_id_priv->alt_av.port = NULL; - write_unlock(&cm_id_priv->av_rwlock); - } - rdma_for_each_port (ib_device, i) { + struct ib_mad_agent *mad_agent; + if (!rdma_cap_ib_cm(ib_device, i)) continue; port = cm_dev->port[i-1]; + 
mad_agent = port->mad_agent; + ib_modify_port(ib_device, port->port_num, 0, &port_modify); /* * We flush the queue here after the going_down set, this @@ -4537,12 +4519,20 @@ static void cm_remove_one(struct ib_device *ib_device, void *client_data) * after that we can call the unregister_mad_agent */ flush_workqueue(cm.wq); - ib_unregister_mad_agent(port->mad_agent); + /* + * The above ensures no call paths from the work are running, + * the remaining paths all take the unregistration lock + */ + spin_lock(&cm_dev->unregistration_lock); + port->mad_agent = NULL; + spin_unlock(&cm_dev->unregistration_lock); + ib_unregister_mad_agent(mad_agent); cm_remove_port_fs(port); - kfree(port); } - kfree(cm_dev); + /* All touches can only be on call path from the work */ + cm_dev->ib_device = NULL; + cm_device_put(cm_dev); } static int __init ib_cm_init(void)