diff --git a/src/module.c b/src/module.c index 8acafccfe..fc5396d0c 100644 --- a/src/module.c +++ b/src/module.c @@ -7308,22 +7308,30 @@ int RM_GetLFU(RedisModuleKey *key, long long *lfu_freq) { * Unlike RM_ModuleTypeSetValue() which will free the old value, this function * simply swaps the old value with the new value. * - * The function returns the old value, or NULL if any of the above conditions is - * not met. + * The function returns REDISMODULE_OK on success, REDISMODULE_ERR on errors + * such as: + * + * 1. Key is not opened for writing. + * 2. Key is not a module data type key. + * 3. Key is a module datatype other than 'mt'. + * + * If old_value is non-NULL, the old value is returned by reference. */ -void *RM_ModuleTypeReplaceValue(RedisModuleKey *key, moduleType *mt, void *new_value) { +int RM_ModuleTypeReplaceValue(RedisModuleKey *key, moduleType *mt, void *new_value, void **old_value) { if (!(key->mode & REDISMODULE_WRITE) || key->iter) - return NULL; + return REDISMODULE_ERR; if (!key->value || key->value->type != OBJ_MODULE) - return NULL; + return REDISMODULE_ERR; moduleValue *mv = key->value->ptr; if (mv->type != mt) - return NULL; + return REDISMODULE_ERR; - void *old_val = mv->value; + if (old_value) + *old_value = mv->value; mv->value = new_value; - return old_val; + + return REDISMODULE_OK; } /* Register all the APIs we export. Keep this function at the end of the diff --git a/src/redismodule.h b/src/redismodule.h index 6214af7e6..02f267273 100644 --- a/src/redismodule.h +++ b/src/redismodule.h @@ -525,7 +525,7 @@ int REDISMODULE_API_FUNC(RedisModule_GetContextFlags)(RedisModuleCtx *ctx); void *REDISMODULE_API_FUNC(RedisModule_PoolAlloc)(RedisModuleCtx *ctx, size_t bytes); RedisModuleType *REDISMODULE_API_FUNC(RedisModule_CreateDataType)(RedisModuleCtx *ctx, const char *name, int encver, RedisModuleTypeMethods *typemethods); int REDISMODULE_API_FUNC(RedisModule_ModuleTypeSetValue)(RedisModuleKey *key, RedisModuleType *mt, void *value); -void *REDISMODULE_API_FUNC(RedisModule_ModuleTypeReplaceValue)(RedisModuleKey *key, RedisModuleType *mt, void *new_value); +int REDISMODULE_API_FUNC(RedisModule_ModuleTypeReplaceValue)(RedisModuleKey *key, RedisModuleType *mt, void *new_value, void **old_value); RedisModuleType *REDISMODULE_API_FUNC(RedisModule_ModuleTypeGetType)(RedisModuleKey *key); void *REDISMODULE_API_FUNC(RedisModule_ModuleTypeGetValue)(RedisModuleKey *key); int REDISMODULE_API_FUNC(RedisModule_IsIOError)(RedisModuleIO *io); diff --git a/tests/modules/datatype.c b/tests/modules/datatype.c index 7c39ab457..6596f9368 100644 --- a/tests/modules/datatype.c +++ b/tests/modules/datatype.c @@ -124,6 +124,28 @@ static int datatype_dump(RedisModuleCtx *ctx, RedisModuleString **argv, int argc return REDISMODULE_OK; } +static int datatype_swap(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { + if (argc != 3) { + RedisModule_WrongArity(ctx); + return REDISMODULE_OK; + } + + RedisModuleKey *a = RedisModule_OpenKey(ctx, argv[1], REDISMODULE_WRITE); + RedisModuleKey *b = RedisModule_OpenKey(ctx, argv[2], REDISMODULE_WRITE); + void *val = RedisModule_ModuleTypeGetValue(a); + + int error = (RedisModule_ModuleTypeReplaceValue(b, datatype, val, &val) == REDISMODULE_ERR || + RedisModule_ModuleTypeReplaceValue(a, datatype, val, NULL) == REDISMODULE_ERR); + if (!error) + RedisModule_ReplyWithSimpleString(ctx, "OK"); + else + RedisModule_ReplyWithError(ctx, "ERR failed"); + + RedisModule_CloseKey(a); + RedisModule_CloseKey(b); + + return REDISMODULE_OK; +} int RedisModule_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { REDISMODULE_NOT_USED(argv); @@ -157,5 +179,8 @@ int RedisModule_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) if (RedisModule_CreateCommand(ctx,"datatype.dump", datatype_dump,"",1,1,1) == REDISMODULE_ERR) return REDISMODULE_ERR; + if (RedisModule_CreateCommand(ctx,"datatype.swap", datatype_swap,"",1,1,1) == REDISMODULE_ERR) + return REDISMODULE_ERR; + return REDISMODULE_OK; } diff --git a/tests/unit/moduleapi/datatype.tcl b/tests/unit/moduleapi/datatype.tcl index c1da696d3..e235462ea 100644 --- a/tests/unit/moduleapi/datatype.tcl +++ b/tests/unit/moduleapi/datatype.tcl @@ -24,4 +24,21 @@ start_server {tags {"modules"}} { catch {r datatype.restore dtkeycopy $truncated} e set e } {*Invalid*} + + test {DataType: ModuleTypeReplaceValue() happy path works} { + r datatype.set key-a 1 AAA + r datatype.set key-b 2 BBB + + assert {[r datatype.swap key-a key-b] eq {OK}} + assert {[r datatype.get key-a] eq {2 BBB}} + assert {[r datatype.get key-b] eq {1 AAA}} + } + + test {DataType: ModuleTypeReplaceValue() fails on non-module keys} { + r datatype.set key-a 1 AAA + r set key-b RedisString + + catch {r datatype.swap key-a key-b} e + set e + } {*ERR*} }