From: Mark Zhang <markzhang@xxxxxxxxxx>

Add rwlock protection when accessing the av/alt_av port pointers.

cm_remove_one() clears cm_id_priv->av.port and cm_id_priv->alt_av.port,
while many other paths dereference these pointers directly. Guard both
pointers with a per-cm_id rwlock and make every reader cope with a NULL
port.

Signed-off-by: Mark Zhang <markzhang@xxxxxxxxxx>
Signed-off-by: Leon Romanovsky <leonro@xxxxxxxxxx>
---
Changes in this revision:
 - cm_alloc_msg(): av_rwlock was released twice when ib_create_send_mad()
   failed (once before the IS_ERR() check and again at "out"); return
   directly on that path instead.
 - Drop an unrelated blank-line removal in add_cm_id_to_cm_dev_list() and
   the cm_ack_timeout_rep() wrapper, which does not touch any port pointer.

Open question: cm_remove_one() takes av_rwlock while walking
cm_dev->cm_id_priv_list. If that walk runs under cm.lock, it inverts the
av_rwlock -> cm.lock order used around add_cm_id_to_cm_dev_list() (see the
"Do not acquire inside cm.lock" rule on av_rwlock) - please confirm.

 drivers/infiniband/core/cm.c | 118 +++++++++++++++++++++++++++++-------
 1 file changed, 96 insertions(+), 22 deletions(-)

diff --git a/drivers/infiniband/core/cm.c b/drivers/infiniband/core/cm.c
index f1a24492924f..28eb8a5ee54e 100644
--- a/drivers/infiniband/core/cm.c
+++ b/drivers/infiniband/core/cm.c
@@ -262,6 +262,7 @@ struct cm_id_private {
 	/* todo: use alternate port on send failure */
 	struct cm_av av;
 	struct cm_av alt_av;
+	rwlock_t av_rwlock; /* Do not acquire inside cm.lock */
 
 	void *private_data;
 	__be64 tid;
@@ -303,20 +304,37 @@ static struct ib_mad_send_buf *cm_alloc_msg(struct cm_id_private *cm_id_priv)
 	struct ib_mad_agent *mad_agent;
 	struct ib_mad_send_buf *m;
 	struct ib_ah *ah;
+	int ret;
+
+	read_lock(&cm_id_priv->av_rwlock);
+	if (!cm_id_priv->av.port) {
+		ret = -EINVAL;
+		goto out;
+	}
 
 	mad_agent = cm_id_priv->av.port->mad_agent;
+	if (!mad_agent) {
+		ret = -EINVAL;
+		goto out;
+	}
+
 	ah = rdma_create_ah(mad_agent->qp->pd, &cm_id_priv->av.ah_attr, 0);
-	if (IS_ERR(ah))
-		return (void *)ah;
+	if (IS_ERR(ah)) {
+		ret = PTR_ERR(ah);
+		goto out;
+	}
 
 	m = ib_create_send_mad(mad_agent, cm_id_priv->id.remote_cm_qpn,
 			       cm_id_priv->av.pkey_index,
 			       0, IB_MGMT_MAD_HDR, IB_MGMT_MAD_DATA,
 			       GFP_ATOMIC,
 			       IB_MGMT_BASE_VERSION);
+
+	/* av_rwlock is released here, so later errors must not goto out. */
+	read_unlock(&cm_id_priv->av_rwlock);
 	if (IS_ERR(m)) {
 		rdma_destroy_ah(ah, 0);
 		return m;
 	}
 
 	/* Timeout set by caller if response is expected. */
@@ -326,6 +344,10 @@ static struct ib_mad_send_buf *cm_alloc_msg(struct cm_id_private *cm_id_priv)
 	refcount_inc(&cm_id_priv->refcount);
 	m->context[0] = cm_id_priv;
 	return m;
+
+out:
+	read_unlock(&cm_id_priv->av_rwlock);
+	return ERR_PTR(ret);
 }
 
 static void cm_free_msg(struct ib_mad_send_buf *msg)
@@ -468,8 +490,8 @@ static int cm_init_av_for_lap(struct cm_port *port, struct ib_wc *wc,
 	struct rdma_ah_attr new_ah_attr;
 	int ret;
 
-	av->port = port;
-	av->pkey_index = wc->pkey_index;
+	if (!port)
+		return -EINVAL;
 
 	/*
 	 * av->ah_attr might be initialized based on past wc during incoming
@@ -484,7 +506,11 @@ static int cm_init_av_for_lap(struct cm_port *port, struct ib_wc *wc,
 	if (ret)
 		return ret;
 
+	write_lock(&cm_id_priv->av_rwlock);
+	av->port = port;
+	av->pkey_index = wc->pkey_index;
 	add_cm_id_to_cm_dev_list(cm_id_priv, port->cm_dev);
+	write_unlock(&cm_id_priv->av_rwlock);
 	rdma_move_ah_attr(&av->ah_attr, &new_ah_attr);
 	return 0;
 }
@@ -496,8 +522,10 @@ static int cm_init_av_for_response(struct cm_port *port, struct ib_wc *wc,
 {
 	struct cm_av *av = &cm_id_priv->av;
 
+	write_lock(&cm_id_priv->av_rwlock);
 	av->port = port;
 	add_cm_id_to_cm_dev_list(cm_id_priv, port->cm_dev);
+	write_unlock(&cm_id_priv->av_rwlock);
 	av->pkey_index = wc->pkey_index;
 	return ib_init_ah_attr_from_wc(port->cm_dev->ib_device,
 				       port->port_num, wc,
@@ -554,15 +582,21 @@ static int cm_init_av_by_path(struct sa_path_rec *path,
 	struct cm_device *cm_dev;
 	struct cm_port *port;
 	struct cm_av *av;
-	int ret;
+	int ret = 0;
 
 	port = get_cm_port_from_path(path, sgid_attr);
 	if (!port)
 		return -EINVAL;
 	cm_dev = port->cm_dev;
 
-	if (!is_priv_av && cm_dev != cm_id_priv->av.port->cm_dev)
-		return -EINVAL;
+	read_lock(&cm_id_priv->av_rwlock);
+	if (!is_priv_av &&
+	    (!cm_id_priv->av.port || cm_dev != cm_id_priv->av.port->cm_dev))
+		ret = -EINVAL;
+
+	read_unlock(&cm_id_priv->av_rwlock);
+	if (ret)
+		return ret;
 
 	av = is_priv_av ? &cm_id_priv->av : &cm_id_priv->alt_av;
 
@@ -571,8 +605,6 @@ static int cm_init_av_by_path(struct sa_path_rec *path,
 	if (ret)
 		return ret;
 
-	av->port = port;
-
 	/*
 	 * av->ah_attr might be initialized based on wc or during
 	 * request processing time which might have reference to sgid_attr.
@@ -587,11 +619,15 @@ static int cm_init_av_by_path(struct sa_path_rec *path,
 	if (ret)
 		return ret;
 
+	write_lock(&cm_id_priv->av_rwlock);
+	av->port = port;
 	av->timeout = path->packet_life_time + 1;
-	rdma_move_ah_attr(&av->ah_attr, &new_ah_attr);
 
 	if (is_priv_av)
 		add_cm_id_to_cm_dev_list(cm_id_priv, cm_dev);
+	write_unlock(&cm_id_priv->av_rwlock);
+
+	rdma_move_ah_attr(&av->ah_attr, &new_ah_attr);
 	return 0;
 }
 
@@ -873,6 +909,7 @@ static struct cm_id_private *cm_alloc_id_priv(struct ib_device *device,
 	INIT_LIST_HEAD(&cm_id_priv->cm_dev_list);
 	atomic_set(&cm_id_priv->work_count, -1);
 	refcount_set(&cm_id_priv->refcount, 1);
+	rwlock_init(&cm_id_priv->av_rwlock);
 
 	ret = xa_alloc_cyclic(&cm.local_id_table, &id, NULL, xa_limit_32b,
 			      &cm.local_id_next, GFP_KERNEL);
@@ -986,6 +1023,19 @@ static u8 cm_ack_timeout(u8 ca_ack_delay, u8 packet_life_time)
 	return min(31, ack_timeout);
 }
 
+static u8 cm_ack_timeout_req(struct cm_id_private *cm_id_priv,
+			     u8 packet_life_time)
+{
+	u8 ack_delay = 0;
+
+	read_lock(&cm_id_priv->av_rwlock);
+	if (cm_id_priv->av.port && cm_id_priv->av.port->cm_dev)
+		ack_delay = cm_id_priv->av.port->cm_dev->ack_delay;
+	read_unlock(&cm_id_priv->av_rwlock);
+
+	return cm_ack_timeout(ack_delay, packet_life_time);
+}
+
 static void cm_remove_remote(struct cm_id_private *cm_id_priv)
 {
 	struct cm_timewait_info *timewait_info = cm_id_priv->timewait_info;
@@ -1320,9 +1370,13 @@ EXPORT_SYMBOL(ib_cm_insert_listen);
 
 static __be64 cm_form_tid(struct cm_id_private *cm_id_priv)
 {
-	u64 hi_tid, low_tid;
+	u64 hi_tid = 0, low_tid;
+
+	read_lock(&cm_id_priv->av_rwlock);
+	if (cm_id_priv->av.port && cm_id_priv->av.port->mad_agent)
+		hi_tid = ((u64) cm_id_priv->av.port->mad_agent->hi_tid) << 32;
+	read_unlock(&cm_id_priv->av_rwlock);
 
-	hi_tid = ((u64) cm_id_priv->av.port->mad_agent->hi_tid) << 32;
 	low_tid = (u64)cm_id_priv->id.local_id;
 	return cpu_to_be64(hi_tid | low_tid);
 }
@@ -1426,8 +1480,7 @@ static void cm_format_req(struct cm_req_msg *req_msg,
 	IBA_SET(CM_REQ_PRIMARY_SUBNET_LOCAL, req_msg,
 		(pri_path->hop_limit <= 1));
 	IBA_SET(CM_REQ_PRIMARY_LOCAL_ACK_TIMEOUT, req_msg,
-		cm_ack_timeout(cm_id_priv->av.port->cm_dev->ack_delay,
-			       pri_path->packet_life_time));
+		cm_ack_timeout_req(cm_id_priv, pri_path->packet_life_time));
 
 	if (alt_path) {
 		bool alt_ext = false;
@@ -1478,8 +1531,8 @@ static void cm_format_req(struct cm_req_msg *req_msg,
 		IBA_SET(CM_REQ_ALTERNATE_SUBNET_LOCAL, req_msg,
 			(alt_path->hop_limit <= 1));
 		IBA_SET(CM_REQ_ALTERNATE_LOCAL_ACK_TIMEOUT, req_msg,
-			cm_ack_timeout(cm_id_priv->av.port->cm_dev->ack_delay,
-				       alt_path->packet_life_time));
+			cm_ack_timeout_req(cm_id_priv,
+					   alt_path->packet_life_time));
 	}
 
 	IBA_SET(CM_REQ_VENDOR_ID, req_msg, param->ece.vendor_id);
@@ -1820,7 +1873,12 @@ static void cm_format_req_event(struct cm_work *work,
 	param = &work->cm_event.param.req_rcvd;
 	param->listen_id = listen_id;
 	param->bth_pkey = cm_get_bth_pkey(work);
-	param->port = cm_id_priv->av.port->port_num;
+	read_lock(&cm_id_priv->av_rwlock);
+	if (cm_id_priv->av.port)
+		param->port = cm_id_priv->av.port->port_num;
+	else
+		param->port = 0;
+	read_unlock(&cm_id_priv->av_rwlock);
 	param->primary_path = &work->path[0];
 	cm_opa_to_ib_sgid(work, param->primary_path);
 	if (cm_req_has_alt_path(req_msg)) {
@@ -2247,8 +2305,13 @@ static void cm_format_rep(struct cm_rep_msg *rep_msg,
 	IBA_SET(CM_REP_STARTING_PSN, rep_msg, param->starting_psn);
 	IBA_SET(CM_REP_RESPONDER_RESOURCES, rep_msg,
 		param->responder_resources);
-	IBA_SET(CM_REP_TARGET_ACK_DELAY, rep_msg,
-		cm_id_priv->av.port->cm_dev->ack_delay);
+	read_lock(&cm_id_priv->av_rwlock);
+	if (cm_id_priv->av.port && cm_id_priv->av.port->cm_dev)
+		IBA_SET(CM_REP_TARGET_ACK_DELAY, rep_msg,
+			cm_id_priv->av.port->cm_dev->ack_delay);
+	else
+		IBA_SET(CM_REP_TARGET_ACK_DELAY, rep_msg, 0);
+	read_unlock(&cm_id_priv->av_rwlock);
 	IBA_SET(CM_REP_FAILOVER_ACCEPTED, rep_msg, param->failover_accepted);
 	IBA_SET(CM_REP_RNR_RETRY_COUNT, rep_msg, param->rnr_retry_count);
 	IBA_SET(CM_REP_LOCAL_CA_GUID, rep_msg,
@@ -4120,7 +4183,10 @@ static int cm_init_qp_init_attr(struct cm_id_private *cm_id_priv,
 		qp_attr->qp_access_flags |= IB_ACCESS_REMOTE_READ |
 					    IB_ACCESS_REMOTE_ATOMIC;
 		qp_attr->pkey_index = cm_id_priv->av.pkey_index;
-		qp_attr->port_num = cm_id_priv->av.port->port_num;
+		read_lock(&cm_id_priv->av_rwlock);
+		qp_attr->port_num = cm_id_priv->av.port ?
+			cm_id_priv->av.port->port_num : 0;
+		read_unlock(&cm_id_priv->av_rwlock);
 		ret = 0;
 		break;
 	default:
@@ -4164,7 +4230,10 @@ static int cm_init_qp_rtr_attr(struct cm_id_private *cm_id_priv,
 		}
 		if (rdma_ah_get_dlid(&cm_id_priv->alt_av.ah_attr)) {
 			*qp_attr_mask |= IB_QP_ALT_PATH;
-			qp_attr->alt_port_num = cm_id_priv->alt_av.port->port_num;
+			read_lock(&cm_id_priv->av_rwlock);
+			qp_attr->alt_port_num = cm_id_priv->alt_av.port ?
+				cm_id_priv->alt_av.port->port_num : 0;
+			read_unlock(&cm_id_priv->av_rwlock);
 			qp_attr->alt_pkey_index = cm_id_priv->alt_av.pkey_index;
 			qp_attr->alt_timeout = cm_id_priv->alt_av.timeout;
 			qp_attr->alt_ah_attr = cm_id_priv->alt_av.ah_attr;
@@ -4223,7 +4292,10 @@ static int cm_init_qp_rts_attr(struct cm_id_private *cm_id_priv,
 			}
 		} else {
 			*qp_attr_mask = IB_QP_ALT_PATH | IB_QP_PATH_MIG_STATE;
-			qp_attr->alt_port_num = cm_id_priv->alt_av.port->port_num;
+			read_lock(&cm_id_priv->av_rwlock);
+			qp_attr->alt_port_num = cm_id_priv->alt_av.port ?
+				cm_id_priv->alt_av.port->port_num : 0;
+			read_unlock(&cm_id_priv->av_rwlock);
 			qp_attr->alt_pkey_index = cm_id_priv->alt_av.pkey_index;
 			qp_attr->alt_timeout = cm_id_priv->alt_av.timeout;
 			qp_attr->alt_ah_attr = cm_id_priv->alt_av.ah_attr;
@@ -4441,10 +4513,12 @@ static void cm_remove_one(struct ib_device *ib_device, void *client_data)
 
 	list_for_each_entry_safe(cm_id_priv, tmp,
 				 &cm_dev->cm_id_priv_list, cm_dev_list) {
+		write_lock(&cm_id_priv->av_rwlock);
 		if (!list_empty(&cm_id_priv->cm_dev_list))
 			list_del(&cm_id_priv->cm_dev_list);
 		cm_id_priv->av.port = NULL;
 		cm_id_priv->alt_av.port = NULL;
+		write_unlock(&cm_id_priv->av_rwlock);
 	}
 
 	rdma_for_each_port (ib_device, i) {
-- 
2.30.2