Better memtoll() error checking.

Related to PR #2357.
This commit is contained in:
antirez 2015-02-12 16:40:41 +01:00
parent a1d9ec0d44
commit 29b54db320

View File

@ -38,6 +38,7 @@
#include <sys/time.h> #include <sys/time.h>
#include <float.h> #include <float.h>
#include <stdint.h> #include <stdint.h>
#include <errno.h>
#include "util.h" #include "util.h"
#include "sha1.h" #include "sha1.h"
@ -170,11 +171,12 @@ int stringmatch(const char *pattern, const char *string, int nocase) {
} }
/* Convert a string representing an amount of memory into the number of /* Convert a string representing an amount of memory into the number of
* bytes, so for instance memtoll("1Gi") will return 1073741824 that is * bytes, so for instance memtoll("1Gb") will return 1073741824 that is
* (1024*1024*1024). * (1024*1024*1024).
* *
* On parsing error, if *err is not NULL, it's set to 1, otherwise it's * On parsing error, if *err is not NULL, it's set to 1, otherwise it's
* set to 0 */ * set to 0. On error the function return value is 0, regardless of the
* fact 'err' is NULL or not. */
long long memtoll(const char *p, int *err) { long long memtoll(const char *p, int *err) {
const char *u; const char *u;
char buf[128]; char buf[128];
@ -183,6 +185,7 @@ long long memtoll(const char *p, int *err) {
unsigned int digits; unsigned int digits;
if (err) *err = 0; if (err) *err = 0;
/* Search the first non digit character. */ /* Search the first non digit character. */
u = p; u = p;
if (*u == '-') u++; if (*u == '-') u++;
@ -203,16 +206,26 @@ long long memtoll(const char *p, int *err) {
mul = 1024L*1024*1024; mul = 1024L*1024*1024;
} else { } else {
if (err) *err = 1; if (err) *err = 1;
mul = 1; return 0;
} }
/* Copy the digits into a buffer, we'll use strtoll() to convert
* the digit (without the unit) into a number. */
digits = u-p; digits = u-p;
if (digits >= sizeof(buf)) { if (digits >= sizeof(buf)) {
if (err) *err = 1; if (err) *err = 1;
return LLONG_MAX; return 0;
} }
memcpy(buf,p,digits); memcpy(buf,p,digits);
buf[digits] = '\0'; buf[digits] = '\0';
val = strtoll(buf,NULL,10);
char *endptr;
errno = 0;
val = strtoll(buf,&endptr,10);
if ((val == 0 && errno == EINVAL) || *endptr != '\0') {
if (err) *err = 1;
return 0;
}
return val*mul; return val*mul;
} }