Skip to content

Implement ZAdd and Zscore #79

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions core/src/main/java/dev/keva/core/aof/AOFContainer.java
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ public List<Command> read() throws IOException {
byte[][] objects = (byte[][]) input.readObject();
commands.add(Command.newInstance(objects, false));
} catch (EOFException e) {
log.error("Error while reading AOF command", e);
fis.close();
return commands;
} catch (ClassNotFoundException e) {
Expand Down
123 changes: 123 additions & 0 deletions core/src/main/java/dev/keva/core/command/impl/zset/ZAdd.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package dev.keva.core.command.impl.zset;

import dev.keva.core.command.annotation.CommandImpl;
import dev.keva.core.command.annotation.Execute;
import dev.keva.core.command.annotation.Mutate;
import dev.keva.core.command.annotation.ParamLength;
import dev.keva.ioc.annotation.Autowired;
import dev.keva.ioc.annotation.Component;
import dev.keva.protocol.resp.reply.BulkReply;
import dev.keva.protocol.resp.reply.ErrorReply;
import dev.keva.protocol.resp.reply.IntegerReply;
import dev.keva.protocol.resp.reply.Reply;
import dev.keva.store.KevaDatabase;
import dev.keva.util.DoubleUtil;
import dev.keva.util.hashbytes.BytesKey;

import java.nio.charset.StandardCharsets;
import java.util.AbstractMap.SimpleEntry;

import static dev.keva.util.Constants.FLAG_CH;
import static dev.keva.util.Constants.FLAG_GT;
import static dev.keva.util.Constants.FLAG_INCR;
import static dev.keva.util.Constants.FLAG_LT;
import static dev.keva.util.Constants.FLAG_NX;
import static dev.keva.util.Constants.FLAG_XX;

@Component
@CommandImpl("zadd")
@ParamLength(type = ParamLength.Type.AT_LEAST, value = 3)
@Mutate
public final class ZAdd {
private static final String XX = "xx";
private static final String NX = "nx";
private static final String GT = "gt";
private static final String LT = "lt";
private static final String INCR = "incr";
private static final String CH = "ch";

private final KevaDatabase database;

@Autowired
public ZAdd(KevaDatabase database) {
this.database = database;
}

@Execute
public Reply<?> execute(byte[][] params) {
// Parse the flags, if any
boolean xx = false, nx = false, gt = false, lt = false, incr = false;
int argPos = 1, flags = 0;
String arg;
while (argPos < params.length) {
arg = new String(params[argPos], StandardCharsets.UTF_8);
if (XX.equalsIgnoreCase(arg)) {
xx = true;
flags |= FLAG_XX;
} else if (NX.equalsIgnoreCase(arg)) {
nx = true;
flags |= FLAG_NX;
} else if (GT.equalsIgnoreCase(arg)) {
gt = true;
flags |= FLAG_GT;
} else if (LT.equalsIgnoreCase(arg)) {
lt = true;
flags |= FLAG_LT;
} else if (INCR.equalsIgnoreCase(arg)) {
incr = true;
flags |= FLAG_INCR;
} else if (CH.equalsIgnoreCase(arg)) {
flags |= FLAG_CH;
} else {
break;
}
++argPos;
}

int numMembers = params.length - argPos;
if (numMembers % 2 != 0) {
return ErrorReply.SYNTAX_ERROR;
}
numMembers /= 2;

if (nx && xx) {
return ErrorReply.ZADD_NX_XX_ERROR;
}
if ((gt && nx) || (lt && nx) || (gt && lt)) {
return ErrorReply.ZADD_GT_LT_NX_ERROR;
}
if (incr && numMembers > 1) {
return ErrorReply.ZADD_INCR_ERROR;
}

// Parse the key and value
final SimpleEntry<Double, BytesKey>[] members = new SimpleEntry[numMembers];
double score;
String rawScore;
for (int memberIndex = 0; memberIndex < numMembers; ++memberIndex) {
try {
rawScore = new String(params[argPos++], StandardCharsets.UTF_8);
if (rawScore.equalsIgnoreCase("inf") || rawScore.equalsIgnoreCase("infinity")
|| rawScore.equalsIgnoreCase("+inf") || rawScore.equalsIgnoreCase("+infinity")
) {
score = Double.POSITIVE_INFINITY;
} else if (rawScore.equalsIgnoreCase("-inf") || rawScore.equalsIgnoreCase("-infinity")) {
score = Double.NEGATIVE_INFINITY;
} else {
score = Double.parseDouble(rawScore);
}
} catch (final NumberFormatException ignored) {
// return on first bad input
return ErrorReply.ZADD_SCORE_FLOAT_ERROR;
}
members[memberIndex] = new SimpleEntry<>(score, new BytesKey(params[argPos++]));
}

if (incr) {
Double result = database.zincrby(params[0], members[0].getKey(), members[0].getValue(), flags);
return result == null ? BulkReply.NIL_REPLY : new BulkReply(DoubleUtil.toString(result));
}
int result = database.zadd(params[0], members, flags);
return new IntegerReply(result);
}
}
36 changes: 36 additions & 0 deletions core/src/main/java/dev/keva/core/command/impl/zset/ZScore.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package dev.keva.core.command.impl.zset;

import dev.keva.core.command.annotation.CommandImpl;
import dev.keva.core.command.annotation.Execute;
import dev.keva.core.command.annotation.ParamLength;
import dev.keva.ioc.annotation.Autowired;
import dev.keva.ioc.annotation.Component;
import dev.keva.protocol.resp.reply.BulkReply;
import dev.keva.store.KevaDatabase;

@Component
@CommandImpl("zscore")
@ParamLength(type = ParamLength.Type.EXACT, value = 2)
public final class ZScore {
private final KevaDatabase database;

@Autowired
public ZScore(KevaDatabase database) {
this.database = database;
}

@Execute
public BulkReply execute(byte[] key, byte[] member) {
final Double result = database.zscore(key, member);
if(result == null){
return BulkReply.NIL_REPLY;
}
if (result.equals(Double.POSITIVE_INFINITY)) {
return BulkReply.POSITIVE_INFINITY_REPLY;
}
if (result.equals(Double.NEGATIVE_INFINITY)) {
return BulkReply.NEGATIVE_INFINITY_REPLY;
}
return new BulkReply(result.toString());
}
}
2 changes: 1 addition & 1 deletion core/src/test/java/dev/keva/core/server/AOFTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ Server startServer(int port) throws Exception {
.persistence(false)
.aof(true)
.aofInterval(1000)
.workDirectory("./")
.workDirectory(System.getProperty("java.io.tmpdir"))
.build();
val server = KevaServer.of(config);
new Thread(() -> {
Expand Down
98 changes: 98 additions & 0 deletions core/src/test/java/dev/keva/core/server/AbstractServerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,13 @@

import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;

import lombok.var;
import redis.clients.jedis.params.ZAddParams;

import static org.junit.jupiter.api.Assertions.*;

Expand Down Expand Up @@ -827,6 +830,101 @@ void setrange() {
}
}

@Test
void zaddWithXXAndNXErrs() {
assertThrows(JedisDataException.class, () -> {
jedis.zadd("zset", 1.0, "val", new ZAddParams().xx().nx());
});
}

@Test
void zaddSingleWithNxAndGtErrs() {
assertThrows(JedisDataException.class, () -> {
jedis.zadd("zset", 1.0, "val", new ZAddParams().gt().nx());
});
}

@Test
void zaddSingleWithNxAndLtErrs() {
assertThrows(JedisDataException.class, () -> {
jedis.zadd("zset", 1.0, "val", new ZAddParams().lt().nx());
});
}

@Test
void zaddSingleWithGtAndLtErrs() {
assertThrows(JedisDataException.class, () -> {
jedis.zadd("zset", 1.0, "val", new ZAddParams().lt().gt());
});
}

@Test
void zaddSingleWithoutOptions() {
try {
var result = jedis.zadd("zset", 1.0, "val");
assertEquals(1, result);

result = jedis.zadd("zset", 1.0, "val");
assertEquals(0, result);
} catch (Exception e) {
fail(e);
}
}

@Test
void zaddMultipleWithoutOptions() {
try {
Map<String, Double> members = new HashMap<>();
int numMembers = 100;
for(int i=0; i<numMembers; ++i) {
members.put(Integer.toString(i), (double) i);
}
var result = jedis.zadd("zset", members);
assertEquals(numMembers, result);

result = jedis.zadd("zset", members);
assertEquals(0, result);
} catch (Exception e) {
fail(e);
}
}

@Test
void zaddCh() {
try {
var result = jedis.zadd("zset", 1.0, "mem", new ZAddParams().ch());
assertEquals(1, result);

result = jedis.zadd("zset", 1.0, "mem", new ZAddParams().ch());
assertEquals(0, result);

result = jedis.zadd("zset", 2.0, "mem", new ZAddParams().ch());
assertEquals(1, result);
} catch (Exception e) {
fail(e);
}
}

@Test
void zscoreNonExistingKey() {
val result = jedis.zscore("key", "mem");
assertNull(result);
}

@Test
void zscoreNonExistingMember() {
jedis.zadd("zset", 1.0, "mem");
val result = jedis.zscore("zset", "foo");
assertNull(result);
}

@Test
void zscoreExistingMember() {
jedis.zadd("zset", 1.0, "mem");
val result = jedis.zscore("zset", "mem");
assertEquals(result, 1.0);
}

@Test
void dumpAndRestore() {
try {
Expand Down
8 changes: 8 additions & 0 deletions docs/src/guide/overview/commands.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,14 @@ Implemented commands:

</details>

<details>
<summary>SortedSet</summary>

- ZADD
- ZSCORE

</details>

<details>
<summary>Pub/Sub</summary>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

public class BulkReply implements Reply<ByteBuf> {
public static final BulkReply NIL_REPLY = new BulkReply();
public static final BulkReply POSITIVE_INFINITY_REPLY = new BulkReply("inf");
public static final BulkReply NEGATIVE_INFINITY_REPLY = new BulkReply("-inf");

public static final char MARKER = '$';
private final ByteBuf bytes;
Expand All @@ -22,11 +24,7 @@ private BulkReply() {
}

public BulkReply(byte[] bytes) {
if (bytes.length == 0) {
this.bytes = Unpooled.EMPTY_BUFFER;
} else {
this.bytes = Unpooled.wrappedBuffer(bytes);
}
this.bytes = Unpooled.wrappedBuffer(bytes);
capacity = bytes.length;
}

Expand Down Expand Up @@ -59,7 +57,7 @@ public void write(ByteBuf os) throws IOException {
os.writeByte(MARKER);
os.writeBytes(numToBytes(capacity, true));
if (capacity > 0) {
os.writeBytes(bytes);
os.writeBytes(bytes.array());
os.writeBytes(CRLF);
}
if (capacity == 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@

public class ErrorReply implements Reply<String> {
public static final char MARKER = '-';
// Pre-defined errors
public static final ErrorReply SYNTAX_ERROR = new ErrorReply("ERR syntax error");
public static final ErrorReply ZADD_NX_XX_ERROR = new ErrorReply("ERR XX and NX options at the same time are not compatible");
public static final ErrorReply ZADD_GT_LT_NX_ERROR = new ErrorReply("GT, LT, and/or NX options at the same time are not compatible");
public static final ErrorReply ZADD_INCR_ERROR = new ErrorReply("INCR option supports a single increment-element pair");
public static final ErrorReply ZADD_SCORE_FLOAT_ERROR = new ErrorReply("value is not a valid float");

private final String error;

public ErrorReply(String error) {
Expand Down
8 changes: 8 additions & 0 deletions store/src/main/java/dev/keva/store/KevaDatabase.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
package dev.keva.store;

import dev.keva.util.hashbytes.BytesKey;

import java.util.AbstractMap;
import java.util.concurrent.locks.Lock;

public interface KevaDatabase {
Expand Down Expand Up @@ -69,4 +72,9 @@ public interface KevaDatabase {

byte[][] mget(byte[]... keys);

int zadd(byte[] key, AbstractMap.SimpleEntry<Double, BytesKey>[] members, int flags);

Double zincrby(byte[] key, Double score, BytesKey e, int flags);

Double zscore(byte[] key, byte[] member);
}
Loading