Re: [PATCH 4/5] staging: vc04_services: use kref + RCU to reference count services

[Date Prev][Date Next][Thread Prev][Thread Next][Date Index][Thread Index]

 



On Wed, Feb 12, 2020 at 01:43:32PM -0500, Marcelo Diop-Gonzalez wrote:
> Currently reference counts are implemented by locking service_spinlock
> and then incrementing the service's ->ref_count field, calling
> kfree() when the last reference has been dropped. But at the same
> time, there's code in multiple places that dereferences pointers
> to services without having a reference, so there could be a race there.
> 
> It should be possible to avoid taking any lock in unlock_service()
> or service_release() because we are setting a single array element
> to NULL, and on service creation, a mutex is locked before looking
> for a NULL spot to put the new service in.
> 
> Using a struct kref and RCU-delaying the freeing of services fixes
> this race condition while still making it possible to skip
> grabbing a reference in many places. Also it avoids the need to
> acquire a single spinlock when e.g. taking a reference on
> state->services[i] when somebody else is in the middle of taking
> a reference on state->services[j].
> 
> Signed-off-by: Marcelo Diop-Gonzalez <marcgonzalez@xxxxxxxxxx>
> ---
>  .../interface/vchiq_arm/vchiq_arm.c           |  25 +-
>  .../interface/vchiq_arm/vchiq_core.c          | 222 +++++++++---------
>  .../interface/vchiq_arm/vchiq_core.h          |  12 +-
>  3 files changed, 140 insertions(+), 119 deletions(-)
> 
> diff --git a/drivers/staging/vc04_services/interface/vchiq_arm/vchiq_arm.c b/drivers/staging/vc04_services/interface/vchiq_arm/vchiq_arm.c
> index c456ced431af..3ed0e4ea7f5c 100644
> --- a/drivers/staging/vc04_services/interface/vchiq_arm/vchiq_arm.c
> +++ b/drivers/staging/vc04_services/interface/vchiq_arm/vchiq_arm.c
> @@ -22,6 +22,7 @@
>  #include <linux/platform_device.h>
>  #include <linux/compat.h>
>  #include <linux/dma-mapping.h>
> +#include <linux/rcupdate.h>
>  #include <soc/bcm2835/raspberrypi-firmware.h>
>  
>  #include "vchiq_core.h"
> @@ -2096,10 +2097,12 @@ int vchiq_dump_platform_instances(void *dump_context)
>  	/* There is no list of instances, so instead scan all services,
>  		marking those that have been dumped. */
>  
> +	rcu_read_lock();
>  	for (i = 0; i < state->unused_service; i++) {
> -		struct vchiq_service *service = state->services[i];
> +		struct vchiq_service *service;
>  		struct vchiq_instance *instance;
>  
> +		service = rcu_dereference(state->services[i]);
>  		if (!service || service->base.callback != service_callback)
>  			continue;
>  
> @@ -2107,18 +2110,26 @@ int vchiq_dump_platform_instances(void *dump_context)
>  		if (instance)
>  			instance->mark = 0;
>  	}
> +	rcu_read_unlock();
>  
>  	for (i = 0; i < state->unused_service; i++) {
> -		struct vchiq_service *service = state->services[i];
> +		struct vchiq_service *service;
>  		struct vchiq_instance *instance;
>  		int err;
>  
> -		if (!service || service->base.callback != service_callback)
> +		rcu_read_lock();
> +		service = rcu_dereference(state->services[i]);
> +		if (!service || service->base.callback != service_callback) {
> +			rcu_read_unlock();
>  			continue;
> +		}
>  
>  		instance = service->instance;
> -		if (!instance || instance->mark)
> +		if (!instance || instance->mark) {
> +			rcu_read_unlock();
>  			continue;
> +		}
> +		rcu_read_unlock();
>  
>  		len = snprintf(buf, sizeof(buf),
>  			       "Instance %pK: pid %d,%s completions %d/%d",
> @@ -2128,7 +2139,6 @@ int vchiq_dump_platform_instances(void *dump_context)
>  			       instance->completion_insert -
>  			       instance->completion_remove,
>  			       MAX_COMPLETIONS);
> -
>  		err = vchiq_dump(dump_context, buf, len + 1);
>  		if (err)
>  			return err;
> @@ -2585,8 +2595,10 @@ vchiq_dump_service_use_state(struct vchiq_state *state)
>  	if (active_services > MAX_SERVICES)
>  		only_nonzero = 1;
>  
> +	rcu_read_lock();
>  	for (i = 0; i < active_services; i++) {
> -		struct vchiq_service *service_ptr = state->services[i];
> +		struct vchiq_service *service_ptr =
> +			rcu_dereference(state->services[i]);
>  
>  		if (!service_ptr)
>  			continue;
> @@ -2604,6 +2616,7 @@ vchiq_dump_service_use_state(struct vchiq_state *state)
>  		if (found >= MAX_SERVICES)
>  			break;
>  	}
> +	rcu_read_unlock();
>  
>  	read_unlock_bh(&arm_state->susp_res_lock);
>  
> diff --git a/drivers/staging/vc04_services/interface/vchiq_arm/vchiq_core.c b/drivers/staging/vc04_services/interface/vchiq_arm/vchiq_core.c
> index b2d9013b7f79..65270a5b29db 100644
> --- a/drivers/staging/vc04_services/interface/vchiq_arm/vchiq_core.c
> +++ b/drivers/staging/vc04_services/interface/vchiq_arm/vchiq_core.c
> @@ -1,6 +1,9 @@
>  // SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
>  /* Copyright (c) 2010-2012 Broadcom. All rights reserved. */
>  
> +#include <linux/kref.h>
> +#include <linux/rcupdate.h>
> +
>  #include "vchiq_core.h"
>  
>  #define VCHIQ_SLOT_HANDLER_STACK 8192
> @@ -54,7 +57,6 @@ int vchiq_core_log_level = VCHIQ_LOG_DEFAULT;
>  int vchiq_core_msg_log_level = VCHIQ_LOG_DEFAULT;
>  int vchiq_sync_log_level = VCHIQ_LOG_DEFAULT;
>  
> -static DEFINE_SPINLOCK(service_spinlock);
>  DEFINE_SPINLOCK(bulk_waiter_spinlock);
>  static DEFINE_SPINLOCK(quota_spinlock);
>  
> @@ -136,44 +138,41 @@ find_service_by_handle(unsigned int handle)
>  {
>  	struct vchiq_service *service;
>  
> -	spin_lock(&service_spinlock);
> +	rcu_read_lock();
>  	service = handle_to_service(handle);
>  	if (service && service->srvstate != VCHIQ_SRVSTATE_FREE &&
> -	    service->handle == handle) {
> -		WARN_ON(service->ref_count == 0);
> -		service->ref_count++;
> -	} else
> -		service = NULL;
> -	spin_unlock(&service_spinlock);
> -
> -	if (!service)
> -		vchiq_log_info(vchiq_core_log_level,
> -			"Invalid service handle 0x%x", handle);
> -
> -	return service;
> +	    service->handle == handle &&
> +	    kref_get_unless_zero(&service->ref_count)) {
> +		service = rcu_pointer_handoff(service);
> +		rcu_read_unlock();
> +		return service;
> +	}
> +	rcu_read_unlock();
> +	vchiq_log_info(vchiq_core_log_level,
> +		       "Invalid service handle 0x%x", handle);
> +	return NULL;
>  }
>  
>  struct vchiq_service *
>  find_service_by_port(struct vchiq_state *state, int localport)
>  {
> -	struct vchiq_service *service = NULL;
>  
>  	if ((unsigned int)localport <= VCHIQ_PORT_MAX) {
> -		spin_lock(&service_spinlock);
> -		service = state->services[localport];
> -		if (service && service->srvstate != VCHIQ_SRVSTATE_FREE) {
> -			WARN_ON(service->ref_count == 0);
> -			service->ref_count++;
> -		} else
> -			service = NULL;
> -		spin_unlock(&service_spinlock);
> -	}
> -
> -	if (!service)
> -		vchiq_log_info(vchiq_core_log_level,
> -			"Invalid port %d", localport);
> +		struct vchiq_service *service;
>  
> -	return service;
> +		rcu_read_lock();
> +		service = rcu_dereference(state->services[localport]);
> +		if (service && service->srvstate != VCHIQ_SRVSTATE_FREE &&
> +		    kref_get_unless_zero(&service->ref_count)) {
> +			service = rcu_pointer_handoff(service);
> +			rcu_read_unlock();
> +			return service;
> +		}
> +		rcu_read_unlock();
> +	}
> +	vchiq_log_info(vchiq_core_log_level,
> +		       "Invalid port %d", localport);
> +	return NULL;
>  }
>  
>  struct vchiq_service *
> @@ -182,22 +181,20 @@ find_service_for_instance(struct vchiq_instance *instance,
>  {
>  	struct vchiq_service *service;
>  
> -	spin_lock(&service_spinlock);
> +	rcu_read_lock();
>  	service = handle_to_service(handle);
>  	if (service && service->srvstate != VCHIQ_SRVSTATE_FREE &&
>  	    service->handle == handle &&
> -	    service->instance == instance) {
> -		WARN_ON(service->ref_count == 0);
> -		service->ref_count++;
> -	} else
> -		service = NULL;
> -	spin_unlock(&service_spinlock);
> -
> -	if (!service)
> -		vchiq_log_info(vchiq_core_log_level,
> -			"Invalid service handle 0x%x", handle);
> -
> -	return service;
> +	    service->instance == instance &&
> +	    kref_get_unless_zero(&service->ref_count)) {
> +		service = rcu_pointer_handoff(service);
> +		rcu_read_unlock();
> +		return service;
> +	}
> +	rcu_read_unlock();
> +	vchiq_log_info(vchiq_core_log_level,
> +		       "Invalid service handle 0x%x", handle);
> +	return NULL;
>  }
>  
>  struct vchiq_service *
> @@ -206,23 +203,21 @@ find_closed_service_for_instance(struct vchiq_instance *instance,
>  {
>  	struct vchiq_service *service;
>  
> -	spin_lock(&service_spinlock);
> +	rcu_read_lock();
>  	service = handle_to_service(handle);
>  	if (service &&
>  	    (service->srvstate == VCHIQ_SRVSTATE_FREE ||
>  	     service->srvstate == VCHIQ_SRVSTATE_CLOSED) &&
>  	    service->handle == handle &&
> -	    service->instance == instance) {
> -		WARN_ON(service->ref_count == 0);
> -		service->ref_count++;
> -	} else
> -		service = NULL;
> -	spin_unlock(&service_spinlock);
> -
> -	if (!service)
> -		vchiq_log_info(vchiq_core_log_level,
> -			"Invalid service handle 0x%x", handle);
> -
> +	    service->instance == instance &&
> +	    kref_get_unless_zero(&service->ref_count)) {
> +		service = rcu_pointer_handoff(service);
> +		rcu_read_unlock();
> +		return service;
> +	}
> +	rcu_read_unlock();
> +	vchiq_log_info(vchiq_core_log_level,
> +		       "Invalid service handle 0x%x", handle);
>  	return service;
>  }
>  
> @@ -233,19 +228,19 @@ next_service_by_instance(struct vchiq_state *state, struct vchiq_instance *insta
>  	struct vchiq_service *service = NULL;
>  	int idx = *pidx;
>  
> -	spin_lock(&service_spinlock);
> +	rcu_read_lock();
>  	while (idx < state->unused_service) {
> -		struct vchiq_service *srv = state->services[idx++];
> +		struct vchiq_service *srv;
>  
> +		srv = rcu_dereference(state->services[idx++]);
>  		if (srv && srv->srvstate != VCHIQ_SRVSTATE_FREE &&
> -		    srv->instance == instance) {
> -			service = srv;
> -			WARN_ON(service->ref_count == 0);
> -			service->ref_count++;
> +		    srv->instance == instance &&
> +		    kref_get_unless_zero(&srv->ref_count)) {
> +			service = rcu_pointer_handoff(srv);
>  			break;
>  		}
>  	}
> -	spin_unlock(&service_spinlock);
> +	rcu_read_unlock();
>  
>  	*pidx = idx;
>  
> @@ -255,43 +250,34 @@ next_service_by_instance(struct vchiq_state *state, struct vchiq_instance *insta
>  void
>  lock_service(struct vchiq_service *service)
>  {
> -	spin_lock(&service_spinlock);
> -	WARN_ON(!service);
> -	if (service) {
> -		WARN_ON(service->ref_count == 0);
> -		service->ref_count++;
> +	if (!service) {
> +		WARN(1, "%s service is NULL\n", __func__);
> +		return;
>  	}
> -	spin_unlock(&service_spinlock);
> +	kref_get(&service->ref_count);
> +}
> +
> +static void service_release(struct kref *kref)
> +{
> +	struct vchiq_service *service =
> +		container_of(kref, struct vchiq_service, ref_count);
> +	struct vchiq_state *state = service->state;
> +
> +	WARN_ON(service->srvstate != VCHIQ_SRVSTATE_FREE);
> +	rcu_assign_pointer(state->services[service->localport], NULL);
> +	if (service->userdata_term)
> +		service->userdata_term(service->base.userdata);
> +	kfree_rcu(service, rcu);
>  }

I think that's the first time I've seen krefs used with RCU.

It looks sane at first glance, but it's a lot of tricky changes, so I'll
assume you tested this and go merge it to see what breaks :)

Thanks for doing this,

greg k-h
_______________________________________________
devel mailing list
devel@xxxxxxxxxxxxxxxxxxxxxx
http://driverdev.linuxdriverproject.org/mailman/listinfo/driverdev-devel



[Index of Archives]     [Linux Driver Backports]     [DMA Engine]     [Linux GPIO]     [Linux SPI]     [Video for Linux]     [Linux USB Devel]     [Linux Coverity]     [Linux Audio Users]     [Linux Kernel]     [Linux SCSI]     [Yosemite Backpacking]
  Powered by Linux