From: Li Nan <linan122@xxxxxxxxxx>

There is no input check when writing to md/safe_mode_delay, so an
overflow can occur. strict_strtoul_scaled() can overflow as well.

Fix it by parsing the integer part with kstrtoul() instead of
accumulating it digit by digit, and by doing the msec <-> jiffies
conversions in 64-bit arithmetic so that they cannot overflow a 32-bit
unsigned long. Input that was accepted before, such as ".5", is still
accepted, while a newline before the '.' ("1\n.5") is still rejected.

Fixes: 72e02075a33f ("md: factor out parsing of fixed-point numbers")
Signed-off-by: Li Nan <linan122@xxxxxxxxxx>
Reviewed-by: Yu Kuai <yukuai3@xxxxxxxxxx>
---
 drivers/md/md.c | 91 ++++++++++++++++++++++++++++++++++---------------
 1 file changed, 63 insertions(+), 28 deletions(-)

diff --git a/drivers/md/md.c b/drivers/md/md.c
index 8e344b4b3444..5bba071ea907 100644
--- a/drivers/md/md.c
+++ b/drivers/md/md.c
@@ -3767,56 +3767,91 @@ static int analyze_sbs(struct mddev *mddev)
  */
 int strict_strtoul_scaled(const char *cp, unsigned long *res, int scale)
 {
-	unsigned long result = 0;
-	long decimals = -1;
-	while (isdigit(*cp) || (*cp == '.' && decimals < 0)) {
-		if (*cp == '.')
-			decimals = 0;
-		else if (decimals < scale) {
-			unsigned int value;
-			value = *cp - '0';
-			result = result * 10 + value;
-			if (decimals >= 0)
-				decimals++;
-		}
-		cp++;
-	}
-	if (*cp == '\n')
-		cp++;
-	if (*cp)
+	unsigned long result = 0, decimals = 0;
+	char *pos, *str;
+	int rv = 0;
+
+	str = kmemdup_nul(cp, strlen(cp), GFP_KERNEL);
+	if (!str)
+		return -ENOMEM;
+	pos = strchr(str, '.');
+	if (pos) {
+		int cnt = scale;
+
+		*pos = '\0';
+		while (isdigit(*(++pos))) {
+			if (cnt) {
+				decimals = decimals * 10 + *pos - '0';
+				cnt--;
+			}
+		}
+		if (*pos == '\n')
+			pos++;
+		/*
+		 * Only digits may precede the '.': kstrtoul() alone would
+		 * also accept a '+' sign or a newline there, e.g. "1\n.5".
+		 */
+		if (*pos || str[strspn(str, "0123456789")]) {
+			kfree(str);
+			return -EINVAL;
+		}
+		decimals *= int_pow(10, cnt);
+	}
+
+	/* An empty integer part is fine before a fraction: ".5" is 0.5 */
+	if (!pos || *str)
+		rv = kstrtoul(str, 10, &result);
+	kfree(str);
+	if (rv)
+		return rv;
+
+	if (result > div64_u64(ULONG_MAX - decimals, int_pow(10, scale)))
 		return -EINVAL;
-	if (decimals < 0)
-		decimals = 0;
-	*res = result * int_pow(10, scale - decimals);
-	return 0;
+	*res = result * int_pow(10, scale) + decimals;
+
+	return 0;
 }
 
 static ssize_t
 safe_delay_show(struct mddev *mddev, char *page)
 {
-	int msec = (mddev->safemode_delay*1000)/HZ;
-	return sprintf(page, "%d.%03d\n", msec/1000, msec%1000);
+	/* 64-bit math: safemode_delay * 1000 can exceed 32 bits */
+	u64 msec = div_u64((u64)mddev->safemode_delay * 1000, HZ);
+	u32 rem;
+	u64 sec = div_u64_rem(msec, 1000, &rem);
+
+	return sprintf(page, "%llu.%03u\n", sec, rem);
 }
 static ssize_t
 safe_delay_store(struct mddev *mddev, const char *cbuf, size_t len)
 {
 	unsigned long msec;
+	int ret;
 
 	if (mddev_is_clustered(mddev)) {
 		pr_warn("md: Safemode is disabled for clustered mode\n");
 		return -EINVAL;
 	}
 
-	if (strict_strtoul_scaled(cbuf, &msec, 3) < 0)
+	ret = strict_strtoul_scaled(cbuf, &msec, 3);
+	if (ret < 0)
+		return ret;
+	if (msec > UINT_MAX)
 		return -EINVAL;
+
 	if (msec == 0)
 		mddev->safemode_delay = 0;
 	else {
-		unsigned long old_delay = mddev->safemode_delay;
-		unsigned long new_delay = (msec*HZ)/1000;
+		unsigned int old_delay = mddev->safemode_delay;
+		/*
+		 * Round up so a non-zero delay never becomes 0 jiffies; use
+		 * 64-bit math as msec * HZ can overflow a 32-bit unsigned long.
+		 */
+		u64 new_delay = DIV64_U64_ROUND_UP((u64)msec * HZ, 1000);
 
-		if (new_delay == 0)
-			new_delay = 1;
+		/* do not assume HZ <= 1000: keep the result within the field */
+		if (new_delay > UINT_MAX)
+			return -EINVAL;
 		mddev->safemode_delay = new_delay;
 		if (new_delay < old_delay || old_delay == 0)
 			mod_timer(&mddev->safemode_timer, jiffies+1);
-- 
2.31.1