diff --git a/src/db.c b/src/db.c index 9c0349bd0..32d539f6b 100644 --- a/src/db.c +++ b/src/db.c @@ -319,6 +319,106 @@ void keysCommand(redisClient *c) { setDeferredMultiBulkLength(c,replylen,numkeys); } +void scanCallback(void *privdata, const dictEntry *de) { + list *keys = (list *)privdata; + sds key = dictGetKey(de); + robj *kobj = createStringObject(key, sdslen(key)); + listAddNodeTail(keys, kobj); +} + +void scanCommand(redisClient *c) { + int rv; + int i, j; + char buf[32]; + list *keys = listCreate(); + listNode *ln, *ln_; + unsigned long cursor = 0; + long count = 1; + sds pat; + int patlen, patnoop = 1; + + /* Use sscanf because we need an *unsigned* long */ + rv = sscanf(c->argv[1]->ptr, "%lu", &cursor); + if (rv != 1) { + addReplyError(c, "invalid cursor"); + goto cleanup; + } + + i = 2; + while (i < c->argc) { + j = c->argc - i; + if (!strcasecmp(c->argv[i]->ptr, "count") && j >= 2) { + if (getLongFromObjectOrReply(c, c->argv[i+1], &count, NULL) != REDIS_OK) { + goto cleanup; + } + + if (count < 1) { + addReply(c,shared.syntaxerr); + goto cleanup; + } + + i += 2; + } else if (!strcasecmp(c->argv[i]->ptr, "pattern") && j >= 2) { + pat = c->argv[i+1]->ptr; + patlen = sdslen(pat); + + /* The pattern is a no-op iff == "*" */ + patnoop = (pat[0] == '*' && patlen == 1); + + i += 2; + } else { + addReply(c,shared.syntaxerr); + goto cleanup; + } + } + + do { + cursor = dictScan(c->db->dict, cursor, scanCallback, keys); + } while (cursor && listLength(keys) < count); + + /* Filter keys */ + ln = listFirst(keys); + while (ln) { + robj *kobj = listNodeValue(ln); + ln_ = listNextNode(ln); + + /* Keep key iff pattern matches and it hasn't expired */ + if ((patnoop || stringmatchlen(pat, patlen, kobj->ptr, sdslen(kobj->ptr), 0)) && + (expireIfNeeded(c->db, kobj) == 0)) + { + /* Keep */ + } else { + decrRefCount(kobj); + listDelNode(keys, ln); + } + + ln = ln_; + } + + addReplyMultiBulkLen(c, 2); + + rv = snprintf(buf, sizeof(buf), "%lu", cursor); + redisAssert(rv < sizeof(buf)); + addReplyBulkCBuffer(c, buf, rv); + + addReplyMultiBulkLen(c, listLength(keys)); + while ((ln = listFirst(keys)) != NULL) { + robj *kobj = listNodeValue(ln); + addReplyBulk(c, kobj); + decrRefCount(kobj); + listDelNode(keys, ln); + } + +cleanup: + while ((ln = listFirst(keys)) != NULL) { + robj *kobj = listNodeValue(ln); + decrRefCount(kobj); + listDelNode(keys, ln); + } + + listRelease(keys); +} + void dbsizeCommand(redisClient *c) { addReplyLongLong(c,dictSize(c->db->dict)); } diff --git a/src/dict.c b/src/dict.c index 97a2bca43..f4a44cf2e 100644 --- a/src/dict.c +++ b/src/dict.c @@ -648,6 +648,98 @@ dictEntry *dictGetRandomKey(dict *d) return he; } +/* Function to reverse bits. Algorithm from: + * http://graphics.stanford.edu/~seander/bithacks.html#ReverseParallel */ +static unsigned long rev(unsigned long v) { + unsigned long s = 8 * sizeof(v); // bit size; must be power of 2 + unsigned long mask = ~0; + while ((s >>= 1) > 0) { + mask ^= (mask << s); + v = ((v >> s) & mask) | ((v << s) & ~mask); + } + return v; +} + +unsigned long dictScan(dict *d, + unsigned long v, + dictScanFunction *fn, + void *privdata) +{ + dictht *t0, *t1; + const dictEntry *de; + unsigned long s0, s1; + unsigned long m0, m1; + + if (!dictIsRehashing(d)) { + t0 = &(d->ht[0]); + m0 = t0->sizemask; + + /* Emit entries at cursor */ + de = t0->table[v & m0]; + while (de) { + fn(privdata, de); + de = de->next; + } + + } else { + t0 = &d->ht[0]; + t1 = &d->ht[1]; + + /* Make sure t0 is the smaller and t1 is the bigger table */ + if (t0->size > t1->size) { + t0 = &d->ht[1]; + t1 = &d->ht[0]; + } + + s0 = t0->size; + s1 = t1->size; + m0 = t0->sizemask; + m1 = t1->sizemask; + + /* Emit entries at cursor */ + de = t0->table[v & m0]; + while (de) { + fn(privdata, de); + de = de->next; + } + + /* Iterate over indices in larger table that are the expansion + * of the index pointed to by the cursor in the smaller table */ + do { + /* Emit entries at cursor */ + de = t1->table[v & m1]; + while (de) { + fn(privdata, de); + de = de->next; + } + + /* Increment bits not covered by the smaller mask */ + v = (((v | m0) + 1) & ~m0) | (v & m0); + + /* Continue while bits covered by mask difference is non-zero */ + } while (v & (m0 ^ m1)); + } + + /* Set unmasked bits so incrementing the reversed cursor + * operates on the masked bits of the smaller table */ + v |= ~m0; + + /* Increment the reverse cursor */ + v = rev(v); + v++; + v = rev(v); + + /* Only preprare cursor for the next iteration when it is non-zero, + * so that 0 can be used as end-of-scan sentinel. */ + if (v) { + /* Set unmasked bits so the cursor will keep its position + * regardless of the mask in the next iterations */ + v |= ~m0; + } + + return v; +} + /* ------------------------- private functions ------------------------------ */ /* Expand the hash table if needed */ diff --git a/src/dict.h b/src/dict.h index 4d750ae85..11e1b97ee 100644 --- a/src/dict.h +++ b/src/dict.h @@ -91,6 +91,8 @@ typedef struct dictIterator { long long fingerprint; /* unsafe iterator fingerprint for misuse detection */ } dictIterator; +typedef void (dictScanFunction)(void *privdata, const dictEntry *de); + /* This is the initial size of every hash table */ #define DICT_HT_INITIAL_SIZE 4 @@ -165,6 +167,7 @@ int dictRehash(dict *d, int n); int dictRehashMilliseconds(dict *d, int ms); void dictSetHashFunctionSeed(unsigned int initval); unsigned int dictGetHashFunctionSeed(void); +unsigned long dictScan(dict *d, unsigned long v, dictScanFunction *fn, void *privdata); /* Hash table types */ extern dictType dictTypeHeapStringCopyKey; diff --git a/src/redis.c b/src/redis.c index 30348a674..6e5181e1f 100644 --- a/src/redis.c +++ b/src/redis.c @@ -210,6 +210,7 @@ struct redisCommand redisCommandTable[] = { {"pexpire",pexpireCommand,3,"w",0,NULL,1,1,1,0,0}, {"pexpireat",pexpireatCommand,3,"w",0,NULL,1,1,1,0,0}, {"keys",keysCommand,2,"rS",0,NULL,0,0,0,0,0}, + {"scan",scanCommand,-1,"RS",0,NULL,0,0,0,0,0}, {"dbsize",dbsizeCommand,1,"r",0,NULL,0,0,0,0,0}, {"auth",authCommand,2,"rslt",0,NULL,0,0,0,0,0}, {"ping",pingCommand,1,"rt",0,NULL,0,0,0,0,0}, diff --git a/src/redis.h b/src/redis.h index a8a68bced..7b643017a 100644 --- a/src/redis.h +++ b/src/redis.h @@ -1250,6 +1250,7 @@ void incrbyfloatCommand(redisClient *c); void selectCommand(redisClient *c); void randomkeyCommand(redisClient *c); void keysCommand(redisClient *c); +void scanCommand(redisClient *c); void dbsizeCommand(redisClient *c); void lastsaveCommand(redisClient *c); void saveCommand(redisClient *c); diff --git a/tests/unit/basic.tcl b/tests/unit/basic.tcl index 1f46ba666..a4a0e791a 100644 --- a/tests/unit/basic.tcl +++ b/tests/unit/basic.tcl @@ -761,4 +761,58 @@ start_server {tags {"basic"}} { r keys * r keys * } {dlskeriewrioeuwqoirueioqwrueoqwrueqw} + + test "SCAN basic" { + r flushdb + r debug populate 1000 + + set cur 0 + set keys {} + while 1 { + set res [r scan $cur] + set cur [lindex $res 0] + set k [lindex $res 1] + lappend keys $k + if {$cur == 0} break + } + + set keys [lsort -unique [concat {*}$keys]] + assert_equal 1000 [llength $keys] + } + + test "SCAN COUNT" { + r flushdb + r debug populate 1000 + + set cur 0 + set keys {} + while 1 { + set res [r scan $cur count 5] + set cur [lindex $res 0] + set k [lindex $res 1] + lappend keys $k + if {$cur == 0} break + } + + set keys [lsort -unique [concat {*}$keys]] + assert_equal 1000 [llength $keys] + } + + test "SCAN PATTERN" { + r flushdb + r debug populate 1000 + + set cur 0 + set keys {} + while 1 { + set res [r scan $cur pattern "key:1??"] + set cur [lindex $res 0] + set k [lindex $res 1] + lappend keys $k + if {$cur == 0} break + } + + set keys [lsort -unique [concat {*}$keys]] + assert_equal 100 [llength $keys] + } }