In preparation for allowing seamless reconnects, we need a way to make sure that we don't free the socks array out from underneath ourselves. So add a socks_ref counter to keep track of who is using the socks array, and only free it and change num_connections once the reference count drops to zero. We also need to make sure that somebody calling SET_SOCK doesn't come in before we're done with our socks array, so add a waitqueue to wait on previous users of the socks array before setting up a new socks array. Signed-off-by: Josef Bacik <jbacik@xxxxxx> --- drivers/block/nbd.c | 126 +++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 91 insertions(+), 35 deletions(-) diff --git a/drivers/block/nbd.c b/drivers/block/nbd.c index 1914ba2..3dc2f1d 100644 --- a/drivers/block/nbd.c +++ b/drivers/block/nbd.c @@ -54,19 +54,24 @@ struct nbd_sock { #define NBD_TIMEDOUT 0 #define NBD_DISCONNECT_REQUESTED 1 #define NBD_DISCONNECTED 2 -#define NBD_RUNNING 3 +#define NBD_HAS_SOCKS_REF 3 struct nbd_device { u32 flags; unsigned long runtime_flags; + + struct mutex socks_lock; struct nbd_sock **socks; + atomic_t socks_ref; + wait_queue_head_t socks_wq; + int num_connections; + int magic; struct blk_mq_tag_set tag_set; struct mutex config_lock; struct gendisk *disk; - int num_connections; atomic_t recv_threads; wait_queue_head_t recv_wq; loff_t blksize; @@ -102,7 +107,6 @@ static int part_shift; static int nbd_dev_dbg_init(struct nbd_device *nbd); static void nbd_dev_dbg_close(struct nbd_device *nbd); - static inline struct device *nbd_to_dev(struct nbd_device *nbd) { return disk_to_dev(nbd->disk); @@ -125,6 +129,27 @@ static const char *nbdcmd_to_ascii(int cmd) return "invalid"; } +static int nbd_socks_get_unless_zero(struct nbd_device *nbd) +{ + return atomic_add_unless(&nbd->socks_ref, 1, 0); +} + +static void nbd_socks_put(struct nbd_device *nbd) +{ + if (atomic_dec_and_test(&nbd->socks_ref)) { + mutex_lock(&nbd->socks_lock); + if (nbd->num_connections) { + int i; 
+ for (i = 0; i < nbd->num_connections; i++) + kfree(nbd->socks[i]); + kfree(nbd->socks); + nbd->num_connections = 0; + nbd->socks = NULL; + } + mutex_unlock(&nbd->socks_lock); + } +} + static int nbd_size_clear(struct nbd_device *nbd, struct block_device *bdev) { bdev->bd_inode->i_size = 0; @@ -190,6 +215,7 @@ static void sock_shutdown(struct nbd_device *nbd) mutex_lock(&nsock->tx_lock); kernel_sock_shutdown(nsock->sock, SHUT_RDWR); mutex_unlock(&nsock->tx_lock); + nsock->dead = true; } dev_warn(disk_to_dev(nbd->disk), "shutting down sockets\n"); } @@ -200,6 +226,9 @@ static enum blk_eh_timer_return nbd_xmit_timeout(struct request *req, struct nbd_cmd *cmd = blk_mq_rq_to_pdu(req); struct nbd_device *nbd = cmd->nbd; + if (!nbd_socks_get_unless_zero(nbd)) + return BLK_EH_HANDLED; + if (nbd->num_connections > 1) { dev_err_ratelimited(nbd_to_dev(nbd), "Connection timed out, retrying\n"); @@ -219,6 +248,7 @@ static enum blk_eh_timer_return nbd_xmit_timeout(struct request *req, } mutex_unlock(&nbd->config_lock); blk_mq_requeue_request(req, true); + nbd_socks_put(nbd); return BLK_EH_RESET_TIMER; } mutex_unlock(&nbd->config_lock); @@ -228,10 +258,9 @@ static enum blk_eh_timer_return nbd_xmit_timeout(struct request *req, } set_bit(NBD_TIMEDOUT, &nbd->runtime_flags); req->errors++; - - mutex_lock(&nbd->config_lock); sock_shutdown(nbd); - mutex_unlock(&nbd->config_lock); + nbd_socks_put(nbd); + return BLK_EH_HANDLED; } @@ -523,6 +552,7 @@ static void recv_work(struct work_struct *work) nbd_end_request(cmd); } + nbd_socks_put(nbd); atomic_dec(&nbd->recv_threads); wake_up(&nbd->recv_wq); } @@ -598,9 +628,16 @@ static int nbd_handle_cmd(struct nbd_cmd *cmd, int index) struct nbd_sock *nsock; int ret; + if (!nbd_socks_get_unless_zero(nbd)) { + dev_err_ratelimited(disk_to_dev(nbd->disk), + "Socks array is empty\n"); + return -EINVAL; + } + if (index >= nbd->num_connections) { dev_err_ratelimited(disk_to_dev(nbd->disk), "Attempted send on invalid socket\n"); + nbd_socks_put(nbd); 
return -EINVAL; } req->errors = 0; @@ -608,8 +645,10 @@ static int nbd_handle_cmd(struct nbd_cmd *cmd, int index) nsock = nbd->socks[index]; if (nsock->dead) { index = find_fallback(nbd, index); - if (index < 0) + if (index < 0) { + nbd_socks_put(nbd); return -EIO; + } nsock = nbd->socks[index]; } @@ -627,7 +666,7 @@ static int nbd_handle_cmd(struct nbd_cmd *cmd, int index) goto again; } mutex_unlock(&nsock->tx_lock); - + nbd_socks_put(nbd); return ret; } @@ -656,6 +695,25 @@ static int nbd_queue_rq(struct blk_mq_hw_ctx *hctx, return BLK_MQ_RQ_QUEUE_OK; } +static int nbd_wait_for_socks(struct nbd_device *nbd) +{ + int ret; + + if (!atomic_read(&nbd->socks_ref)) + return 0; + + do { + mutex_unlock(&nbd->socks_lock); + mutex_unlock(&nbd->config_lock); + ret = wait_event_interruptible(nbd->socks_wq, + atomic_read(&nbd->socks_ref) == 0); + mutex_lock(&nbd->config_lock); + mutex_lock(&nbd->socks_lock); + } while (!ret && atomic_read(&nbd->socks_ref)); + + return ret; +} + static int nbd_add_socket(struct nbd_device *nbd, struct block_device *bdev, unsigned long arg) { @@ -668,21 +726,30 @@ static int nbd_add_socket(struct nbd_device *nbd, struct block_device *bdev, if (!sock) return err; - if (!nbd->task_setup) + err = -EINVAL; + mutex_lock(&nbd->socks_lock); + if (!nbd->task_setup) { nbd->task_setup = current; + if (nbd_wait_for_socks(nbd)) + goto out; + atomic_inc(&nbd->socks_ref); + set_bit(NBD_HAS_SOCKS_REF, &nbd->runtime_flags); + } + if (nbd->task_setup != current) { dev_err(disk_to_dev(nbd->disk), "Device being setup by another task"); - return -EINVAL; + goto out; } + err = -ENOMEM; socks = krealloc(nbd->socks, (nbd->num_connections + 1) * sizeof(struct nbd_sock *), GFP_KERNEL); if (!socks) - return -ENOMEM; + goto out; nsock = kzalloc(sizeof(struct nbd_sock), GFP_KERNEL); if (!nsock) - return -ENOMEM; + goto out; nbd->socks = socks; @@ -694,7 +761,10 @@ static int nbd_add_socket(struct nbd_device *nbd, struct block_device *bdev, if (max_part) 
bdev->bd_invalidated = 1; - return 0; + err = 0; +out: + mutex_unlock(&nbd->socks_lock); + return err; } /* Reset all properties of an NBD device */ @@ -750,20 +820,17 @@ static void send_disconnects(struct nbd_device *nbd) static int nbd_disconnect(struct nbd_device *nbd, struct block_device *bdev) { dev_info(disk_to_dev(nbd->disk), "NBD_DISCONNECT\n"); - if (!nbd->socks) + if (!nbd_socks_get_unless_zero(nbd)) return -EINVAL; mutex_unlock(&nbd->config_lock); fsync_bdev(bdev); mutex_lock(&nbd->config_lock); - /* Check again after getting mutex back. */ - if (!nbd->socks) - return -EINVAL; - if (!test_and_set_bit(NBD_DISCONNECT_REQUESTED, &nbd->runtime_flags)) send_disconnects(nbd); + nbd_socks_put(nbd); return 0; } @@ -773,22 +840,9 @@ static int nbd_clear_sock(struct nbd_device *nbd, struct block_device *bdev) nbd_clear_que(nbd); kill_bdev(bdev); nbd_bdev_reset(bdev); - /* - * We want to give the run thread a chance to wait for everybody - * to clean up and then do it's own cleanup. - */ - if (!test_bit(NBD_RUNNING, &nbd->runtime_flags) && - nbd->num_connections) { - int i; - - for (i = 0; i < nbd->num_connections; i++) - kfree(nbd->socks[i]); - kfree(nbd->socks); - nbd->socks = NULL; - nbd->num_connections = 0; - } nbd->task_setup = NULL; - + if (test_and_clear_bit(NBD_HAS_SOCKS_REF, &nbd->runtime_flags)) + nbd_socks_put(nbd); return 0; } @@ -809,7 +863,6 @@ static int nbd_start_device(struct nbd_device *nbd, struct block_device *bdev) goto out_err; } - set_bit(NBD_RUNNING, &nbd->runtime_flags); blk_mq_update_nr_hw_queues(&nbd->tag_set, nbd->num_connections); args = kcalloc(num_connections, sizeof(*args), GFP_KERNEL); if (!args) { @@ -833,6 +886,7 @@ static int nbd_start_device(struct nbd_device *nbd, struct block_device *bdev) for (i = 0; i < num_connections; i++) { sk_set_memalloc(nbd->socks[i]->sock->sk); atomic_inc(&nbd->recv_threads); + atomic_inc(&nbd->socks_ref); INIT_WORK(&args[i].work, recv_work); args[i].nbd = nbd; args[i].index = i; @@ -849,7 +903,6 @@ 
static int nbd_start_device(struct nbd_device *nbd, struct block_device *bdev) mutex_lock(&nbd->config_lock); nbd->task_recv = NULL; out_err: - clear_bit(NBD_RUNNING, &nbd->runtime_flags); nbd_clear_sock(nbd, bdev); /* user requested, ignore socket errors */ @@ -1149,12 +1202,15 @@ static int nbd_dev_add(int index) nbd->magic = NBD_MAGIC; mutex_init(&nbd->config_lock); + mutex_init(&nbd->socks_lock); + atomic_set(&nbd->socks_ref, 0); disk->major = NBD_MAJOR; disk->first_minor = index << part_shift; disk->fops = &nbd_fops; disk->private_data = nbd; sprintf(disk->disk_name, "nbd%d", index); init_waitqueue_head(&nbd->recv_wq); + init_waitqueue_head(&nbd->socks_wq); nbd_reset(nbd); add_disk(disk); return index; -- 2.7.4