diff --git a/src/scripting.c b/src/scripting.c index c75858d31..d36d8bb48 100644 --- a/src/scripting.c +++ b/src/scripting.c @@ -2243,7 +2243,8 @@ int ldbDelBreakpoint(int line) { /* Expect a valid multi-bulk command in the debugging client query buffer. * On success the command is parsed and returned as an array of SDS strings, * otherwise NULL is returned and there is to read more buffer. */ -sds *ldbReplParseCommand(int *argcp) { +sds *ldbReplParseCommand(int *argcp, char** err) { + static char* protocol_error = "protocol error"; sds *argv = NULL; int argc = 0; if (sdslen(ldb.cbuf) == 0) return NULL; @@ -2260,7 +2261,7 @@ sds *ldbReplParseCommand(int *argcp) { /* Seek and parse *\r\n. */ p = strchr(p,'*'); if (!p) goto protoerr; char *plen = p+1; /* Multi bulk len pointer. */ - p = strstr(p,"\r\n"); if (!p) goto protoerr; + p = strstr(p,"\r\n"); if (!p) goto keep_reading; *p = '\0'; p += 2; *argcp = atoi(plen); if (*argcp <= 0 || *argcp > 1024) goto protoerr; @@ -2269,12 +2270,16 @@ sds *ldbReplParseCommand(int *argcp) { argv = zmalloc(sizeof(sds)*(*argcp)); argc = 0; while(argc < *argcp) { + // reached the end but there should be more data to read + if (*p == '\0') goto keep_reading; + if (*p != '$') goto protoerr; plen = p+1; /* Bulk string len pointer. */ - p = strstr(p,"\r\n"); if (!p) goto protoerr; + p = strstr(p,"\r\n"); if (!p) goto keep_reading; *p = '\0'; p += 2; int slen = atoi(plen); /* Length of this arg. */ if (slen <= 0 || slen > 1024) goto protoerr; + if ((size_t)(p + slen + 2 - copy) > sdslen(copy) ) goto keep_reading; argv[argc++] = sdsnewlen(p,slen); p += slen; /* Skip the already parsed argument. */ if (p[0] != '\r' || p[1] != '\n') goto protoerr; @@ -2284,6 +2289,8 @@ sds *ldbReplParseCommand(int *argcp) { return argv; protoerr: + *err = protocol_error; +keep_reading: sdsfreesplitres(argv,argc); sdsfree(copy); return NULL; @@ -2772,12 +2779,17 @@ void ldbMaxlen(sds *argv, int argc) { int ldbRepl(lua_State *lua) { sds *argv; int argc; + char* err = NULL; /* We continue processing commands until a command that should return * to the Lua interpreter is found. */ while(1) { - while((argv = ldbReplParseCommand(&argc)) == NULL) { + while((argv = ldbReplParseCommand(&argc, &err)) == NULL) { char buf[1024]; + if (err) { + lua_pushstring(lua, err); + lua_error(lua); + } int nread = connRead(ldb.conn,buf,sizeof(buf)); if (nread <= 0) { /* Make sure the script runs without user input since the @@ -2787,6 +2799,15 @@ int ldbRepl(lua_State *lua) { return C_ERR; } ldb.cbuf = sdscatlen(ldb.cbuf,buf,nread); + /* after 1M we will exit with an error + * so that the client will not blow the memory + */ + if (sdslen(ldb.cbuf) > 1<<20) { + sdsfree(ldb.cbuf); + ldb.cbuf = sdsempty(); + lua_pushstring(lua, "max client buffer reached"); + lua_error(lua); + } } /* Flush the old buffer. */ diff --git a/tests/unit/scripting.tcl b/tests/unit/scripting.tcl index 25297e767..1613d28d7 100644 --- a/tests/unit/scripting.tcl +++ b/tests/unit/scripting.tcl @@ -935,6 +935,21 @@ start_server {tags {"scripting external:skip"}} { r eval {return 'hello'} 0 } +start_server {tags {"scripting needs:debug external:skip"}} { + test {Test scripting debug protocol parsing} { + r script debug sync + r eval {return 'hello'} 0 + catch {r 'hello\0world'} e + assert_match {*Unknown Redis Lua debugger command*} $e + catch {r 'hello\0'} e + assert_match {*Unknown Redis Lua debugger command*} $e + catch {r '\0hello'} e + assert_match {*Unknown Redis Lua debugger command*} $e + catch {r '\0hello\0'} e + assert_match {*Unknown Redis Lua debugger command*} $e + } +} + start_server {tags {"scripting resp3 needs:debug"}} { r debug set-disable-deny-scripts 1 for {set i 2} {$i <= 3} {incr i} { @@ -1046,4 +1061,4 @@ start_server {tags {"scripting resp3 needs:debug"}} { } {Some real reply following the attribute} r debug set-disable-deny-scripts 0 -} \ No newline at end of file +}