[PATCH rdma-next 5/6] IB/cm: Add lock protection when access av/alt_av's port of a cm_id

[Date Prev][Date Next][Thread Prev][Thread Next][Date Index][Thread Index]

 



From: Mark Zhang <markzhang@xxxxxxxxxx>

Add rwlock protection when accessing the port pointer of av/alt_av.

Signed-off-by: Mark Zhang <markzhang@xxxxxxxxxx>
Signed-off-by: Leon Romanovsky <leonro@xxxxxxxxxx>
---
 drivers/infiniband/core/cm.c | 129 +++++++++++++++++++++++++++--------
 1 file changed, 102 insertions(+), 27 deletions(-)

diff --git a/drivers/infiniband/core/cm.c b/drivers/infiniband/core/cm.c
index 964b7a98ad5e..4e6f4c6d7b2a 100644
--- a/drivers/infiniband/core/cm.c
+++ b/drivers/infiniband/core/cm.c
@@ -261,6 +261,7 @@ struct cm_id_private {
 	/* todo: use alternate port on send failure */
 	struct cm_av av;
 	struct cm_av alt_av;
+	rwlock_t av_rwlock;	/* Do not acquire inside cm.lock */
 
 	void *private_data;
 	__be64 tid;
@@ -303,17 +304,33 @@ static int cm_alloc_msg(struct cm_id_private *cm_id_priv,
 	struct ib_mad_agent *mad_agent;
 	struct ib_mad_send_buf *m;
 	struct ib_ah *ah;
+	int ret;
+
+	read_lock(&cm_id_priv->av_rwlock);
+	if (!cm_id_priv->av.port) {
+		ret = -EINVAL;
+		goto out;
+	}
 
 	mad_agent = cm_id_priv->av.port->mad_agent;
+	if (!mad_agent) {
+		ret = -EINVAL;
+		goto out;
+	}
+
 	ah = rdma_create_ah(mad_agent->qp->pd, &cm_id_priv->av.ah_attr, 0);
-	if (IS_ERR(ah))
-		return PTR_ERR(ah);
+	if (IS_ERR(ah)) {
+		ret = PTR_ERR(ah);
+		goto out;
+	}
 
 	m = ib_create_send_mad(mad_agent, cm_id_priv->id.remote_cm_qpn,
 			       cm_id_priv->av.pkey_index,
 			       0, IB_MGMT_MAD_HDR, IB_MGMT_MAD_DATA,
 			       GFP_ATOMIC,
 			       IB_MGMT_BASE_VERSION);
+
+	read_unlock(&cm_id_priv->av_rwlock);
 	if (IS_ERR(m)) {
 		rdma_destroy_ah(ah, 0);
 		return PTR_ERR(m);
@@ -327,6 +344,10 @@ static int cm_alloc_msg(struct cm_id_private *cm_id_priv,
 	m->context[0] = cm_id_priv;
 	*msg = m;
 	return 0;
+
+out:
+	read_unlock(&cm_id_priv->av_rwlock);
+	return ret;
 }
 
 static struct ib_mad_send_buf *cm_alloc_response_msg_no_ah(struct cm_port *port,
@@ -433,9 +454,6 @@ static int cm_init_av_for_lap(struct cm_port *port, struct ib_wc *wc,
 	struct rdma_ah_attr new_ah_attr;
 	int ret;
 
-	av->port = port;
-	av->pkey_index = wc->pkey_index;
-
 	/*
 	 * av->ah_attr might be initialized based on past wc during incoming
 	 * connect request or while sending out connect request. So initialize
@@ -449,7 +467,11 @@ static int cm_init_av_for_lap(struct cm_port *port, struct ib_wc *wc,
 	if (ret)
 		return ret;
 
+	write_lock(&cm_id_priv->av_rwlock);
+	av->port = port;
+	av->pkey_index = wc->pkey_index;
 	add_cm_id_to_cm_dev_list(cm_id_priv, port->cm_dev);
+	write_unlock(&cm_id_priv->av_rwlock);
 
 	rdma_move_ah_attr(&av->ah_attr, &new_ah_attr);
 	return 0;
@@ -461,8 +483,10 @@ static int cm_init_av_for_response(struct cm_port *port, struct ib_wc *wc,
 {
 	struct cm_av *av = &cm_id_priv->av;
 
+	write_lock(&cm_id_priv->av_rwlock);
 	av->port = port;
 	add_cm_id_to_cm_dev_list(cm_id_priv, port->cm_dev);
+	write_unlock(&cm_id_priv->av_rwlock);
 	av->pkey_index = wc->pkey_index;
 	return ib_init_ah_attr_from_wc(port->cm_dev->ib_device,
 				       port->port_num, wc,
@@ -519,15 +543,21 @@ static int cm_init_av_by_path(struct sa_path_rec *path,
 	struct cm_device *cm_dev;
 	struct cm_port *port;
 	struct cm_av *av;
-	int ret;
+	int ret = 0;
 
 	port = get_cm_port_from_path(path, sgid_attr);
 	if (!port)
 		return -EINVAL;
 	cm_dev = port->cm_dev;
 
-	if (!is_priv_av && cm_dev != cm_id_priv->av.port->cm_dev)
-		return -EINVAL;
+	read_lock(&cm_id_priv->av_rwlock);
+	if (!is_priv_av &&
+	    (!cm_id_priv->av.port || cm_dev != cm_id_priv->av.port->cm_dev))
+		ret = -EINVAL;
+
+	read_unlock(&cm_id_priv->av_rwlock);
+	if (ret)
+		return ret;
 
 	av = is_priv_av ? &cm_id_priv->av : &cm_id_priv->alt_av;
 
@@ -536,8 +566,6 @@ static int cm_init_av_by_path(struct sa_path_rec *path,
 	if (ret)
 		return ret;
 
-	av->port = port;
-
 	/*
 	 * av->ah_attr might be initialized based on wc or during
 	 * request processing time which might have reference to sgid_attr.
@@ -552,11 +580,15 @@ static int cm_init_av_by_path(struct sa_path_rec *path,
 	if (ret)
 		return ret;
 
+	write_lock(&cm_id_priv->av_rwlock);
+	av->port = port;
 	av->timeout = path->packet_life_time + 1;
-	rdma_move_ah_attr(&av->ah_attr, &new_ah_attr);
 	if (is_priv_av)
 		add_cm_id_to_cm_dev_list(cm_id_priv, cm_dev);
 
+	write_unlock(&cm_id_priv->av_rwlock);
+
+	rdma_move_ah_attr(&av->ah_attr, &new_ah_attr);
 	return 0;
 }
 
@@ -838,6 +870,7 @@ static struct cm_id_private *cm_alloc_id_priv(struct ib_device *device,
 	INIT_LIST_HEAD(&cm_id_priv->cm_dev_list);
 	atomic_set(&cm_id_priv->work_count, -1);
 	refcount_set(&cm_id_priv->refcount, 1);
+	rwlock_init(&cm_id_priv->av_rwlock);
 
 	ret = xa_alloc_cyclic(&cm.local_id_table, &id, NULL, xa_limit_32b,
 			      &cm.local_id_next, GFP_KERNEL);
@@ -951,6 +984,26 @@ static u8 cm_ack_timeout(u8 ca_ack_delay, u8 packet_life_time)
 	return min(31, ack_timeout);
 }
 
+static u8 cm_ack_timeout_req(struct cm_id_private *cm_id_priv,
+			     u8 packet_life_time)
+{
+	u8 ack_delay = 0;	/* kept when av.port or its cm_dev is NULL */
+
+	read_lock(&cm_id_priv->av_rwlock);	/* protects av.port */
+	if (cm_id_priv->av.port && cm_id_priv->av.port->cm_dev)
+		ack_delay = cm_id_priv->av.port->cm_dev->ack_delay;
+	read_unlock(&cm_id_priv->av_rwlock);
+
+	return cm_ack_timeout(ack_delay, packet_life_time);
+}
+
+static u8 cm_ack_timeout_rep(struct cm_id_private *cm_id_priv,
+			     u8 packet_life_time)
+{
+	return cm_ack_timeout(cm_id_priv->target_ack_delay,
+			      packet_life_time);
+}
+
 static void cm_remove_remote(struct cm_id_private *cm_id_priv)
 {
 	struct cm_timewait_info *timewait_info = cm_id_priv->timewait_info;
@@ -1285,9 +1338,13 @@ EXPORT_SYMBOL(ib_cm_insert_listen);
 
 static __be64 cm_form_tid(struct cm_id_private *cm_id_priv)
 {
-	u64 hi_tid, low_tid;
+	u64 hi_tid = 0, low_tid;
+
+	read_lock(&cm_id_priv->av_rwlock);
+	if (cm_id_priv->av.port && cm_id_priv->av.port->mad_agent)
+		hi_tid = ((u64) cm_id_priv->av.port->mad_agent->hi_tid) << 32;
+	read_unlock(&cm_id_priv->av_rwlock);
 
-	hi_tid   = ((u64) cm_id_priv->av.port->mad_agent->hi_tid) << 32;
 	low_tid  = (u64)cm_id_priv->id.local_id;
 	return cpu_to_be64(hi_tid | low_tid);
 }
@@ -1391,8 +1448,7 @@ static void cm_format_req(struct cm_req_msg *req_msg,
 	IBA_SET(CM_REQ_PRIMARY_SUBNET_LOCAL, req_msg,
 		(pri_path->hop_limit <= 1));
 	IBA_SET(CM_REQ_PRIMARY_LOCAL_ACK_TIMEOUT, req_msg,
-		cm_ack_timeout(cm_id_priv->av.port->cm_dev->ack_delay,
-			       pri_path->packet_life_time));
+		cm_ack_timeout_req(cm_id_priv, pri_path->packet_life_time));
 
 	if (alt_path) {
 		bool alt_ext = false;
@@ -1443,8 +1499,8 @@ static void cm_format_req(struct cm_req_msg *req_msg,
 		IBA_SET(CM_REQ_ALTERNATE_SUBNET_LOCAL, req_msg,
 			(alt_path->hop_limit <= 1));
 		IBA_SET(CM_REQ_ALTERNATE_LOCAL_ACK_TIMEOUT, req_msg,
-			cm_ack_timeout(cm_id_priv->av.port->cm_dev->ack_delay,
-				       alt_path->packet_life_time));
+			cm_ack_timeout_req(cm_id_priv,
+					   alt_path->packet_life_time));
 	}
 	IBA_SET(CM_REQ_VENDOR_ID, req_msg, param->ece.vendor_id);
 
@@ -1781,7 +1837,12 @@ static void cm_format_req_event(struct cm_work *work,
 	param = &work->cm_event.param.req_rcvd;
 	param->listen_id = listen_id;
 	param->bth_pkey = cm_get_bth_pkey(work);
-	param->port = cm_id_priv->av.port->port_num;
+	read_lock(&cm_id_priv->av_rwlock);
+	if (cm_id_priv->av.port)
+		param->port = cm_id_priv->av.port->port_num;
+	else
+		param->port = 0;
+	read_unlock(&cm_id_priv->av_rwlock);
 	param->primary_path = &work->path[0];
 	cm_opa_to_ib_sgid(work, param->primary_path);
 	if (cm_req_has_alt_path(req_msg)) {
@@ -2211,8 +2272,13 @@ static void cm_format_rep(struct cm_rep_msg *rep_msg,
 	IBA_SET(CM_REP_STARTING_PSN, rep_msg, param->starting_psn);
 	IBA_SET(CM_REP_RESPONDER_RESOURCES, rep_msg,
 		param->responder_resources);
-	IBA_SET(CM_REP_TARGET_ACK_DELAY, rep_msg,
-		cm_id_priv->av.port->cm_dev->ack_delay);
+	read_lock(&cm_id_priv->av_rwlock);
+	if (cm_id_priv->av.port && cm_id_priv->av.port->cm_dev)
+		IBA_SET(CM_REP_TARGET_ACK_DELAY, rep_msg,
+			cm_id_priv->av.port->cm_dev->ack_delay);
+	else
+		IBA_SET(CM_REP_TARGET_ACK_DELAY, rep_msg, 0);
+	read_unlock(&cm_id_priv->av_rwlock);
 	IBA_SET(CM_REP_FAILOVER_ACCEPTED, rep_msg, param->failover_accepted);
 	IBA_SET(CM_REP_RNR_RETRY_COUNT, rep_msg, param->rnr_retry_count);
 	IBA_SET(CM_REP_LOCAL_CA_GUID, rep_msg,
@@ -2525,11 +2591,9 @@ static int cm_rep_handler(struct cm_work *work)
 	cm_id_priv->target_ack_delay =
 		IBA_GET(CM_REP_TARGET_ACK_DELAY, rep_msg);
 	cm_id_priv->av.timeout =
-			cm_ack_timeout(cm_id_priv->target_ack_delay,
-				       cm_id_priv->av.timeout - 1);
+		cm_ack_timeout_rep(cm_id_priv, cm_id_priv->av.timeout - 1);
 	cm_id_priv->alt_av.timeout =
-			cm_ack_timeout(cm_id_priv->target_ack_delay,
-				       cm_id_priv->alt_av.timeout - 1);
+		cm_ack_timeout_rep(cm_id_priv, cm_id_priv->alt_av.timeout - 1);
 
 	ib_cancel_mad(cm_id_priv->msg);
 	cm_queue_work_unlock(cm_id_priv, work);
@@ -4097,7 +4161,10 @@ static int cm_init_qp_init_attr(struct cm_id_private *cm_id_priv,
 			qp_attr->qp_access_flags |= IB_ACCESS_REMOTE_READ |
 						    IB_ACCESS_REMOTE_ATOMIC;
 		qp_attr->pkey_index = cm_id_priv->av.pkey_index;
-		qp_attr->port_num = cm_id_priv->av.port->port_num;
+		read_lock(&cm_id_priv->av_rwlock);
+		qp_attr->port_num = cm_id_priv->av.port ?
+			cm_id_priv->av.port->port_num : 0;
+		read_unlock(&cm_id_priv->av_rwlock);
 		ret = 0;
 		break;
 	default:
@@ -4141,7 +4208,10 @@ static int cm_init_qp_rtr_attr(struct cm_id_private *cm_id_priv,
 		}
 		if (rdma_ah_get_dlid(&cm_id_priv->alt_av.ah_attr)) {
 			*qp_attr_mask |= IB_QP_ALT_PATH;
-			qp_attr->alt_port_num = cm_id_priv->alt_av.port->port_num;
+			read_lock(&cm_id_priv->av_rwlock);
+			qp_attr->alt_port_num = cm_id_priv->alt_av.port ?
+				cm_id_priv->alt_av.port->port_num : 0;
+			read_unlock(&cm_id_priv->av_rwlock);
 			qp_attr->alt_pkey_index = cm_id_priv->alt_av.pkey_index;
 			qp_attr->alt_timeout = cm_id_priv->alt_av.timeout;
 			qp_attr->alt_ah_attr = cm_id_priv->alt_av.ah_attr;
@@ -4200,7 +4270,10 @@ static int cm_init_qp_rts_attr(struct cm_id_private *cm_id_priv,
 			}
 		} else {
 			*qp_attr_mask = IB_QP_ALT_PATH | IB_QP_PATH_MIG_STATE;
-			qp_attr->alt_port_num = cm_id_priv->alt_av.port->port_num;
+			read_lock(&cm_id_priv->av_rwlock);
+			qp_attr->alt_port_num = cm_id_priv->alt_av.port ?
+				cm_id_priv->alt_av.port->port_num : 0;
+			read_unlock(&cm_id_priv->av_rwlock);
 			qp_attr->alt_pkey_index = cm_id_priv->alt_av.pkey_index;
 			qp_attr->alt_timeout = cm_id_priv->alt_av.timeout;
 			qp_attr->alt_ah_attr = cm_id_priv->alt_av.ah_attr;
@@ -4418,10 +4491,12 @@ static void cm_remove_one(struct ib_device *ib_device, void *client_data)
 
 	list_for_each_entry_safe(cm_id_priv, tmp,
 				 &cm_dev->cm_id_priv_list, cm_dev_list) {
+		write_lock(&cm_id_priv->av_rwlock);
 		if (!list_empty(&cm_id_priv->cm_dev_list))
 			list_del(&cm_id_priv->cm_dev_list);
 		cm_id_priv->av.port = NULL;
 		cm_id_priv->alt_av.port = NULL;
+		write_unlock(&cm_id_priv->av_rwlock);
 	}
 
 	rdma_for_each_port (ib_device, i) {
-- 
2.30.2




[Index of Archives]     [Linux USB Devel]     [Video for Linux]     [Linux Audio Users]     [Photo]     [Yosemite News]     [Yosemite Photos]     [Linux Kernel]     [Linux SCSI]     [XFree86]

  Powered by Linux