Skip to content

Commit 16cfec7

Browse files
committed
Implement ZAdd and Zscore
1 parent 93a0093 commit 16cfec7

File tree

12 files changed

+763
-8
lines changed

12 files changed

+763
-8
lines changed

core/src/main/java/dev/keva/core/aof/AOFContainer.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ public List<Command> read() throws IOException {
114114
byte[][] objects = (byte[][]) input.readObject();
115115
commands.add(Command.newInstance(objects, false));
116116
} catch (EOFException e) {
117+
log.error("Error while reading AOF command", e);
117118
fis.close();
118119
return commands;
119120
} catch (ClassNotFoundException e) {
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
package dev.keva.core.command.impl.zset;
2+
3+
import dev.keva.core.command.annotation.CommandImpl;
4+
import dev.keva.core.command.annotation.Execute;
5+
import dev.keva.core.command.annotation.Mutate;
6+
import dev.keva.core.command.annotation.ParamLength;
7+
import dev.keva.ioc.annotation.Autowired;
8+
import dev.keva.ioc.annotation.Component;
9+
import dev.keva.protocol.resp.reply.BulkReply;
10+
import dev.keva.protocol.resp.reply.ErrorReply;
11+
import dev.keva.protocol.resp.reply.IntegerReply;
12+
import dev.keva.protocol.resp.reply.Reply;
13+
import dev.keva.store.KevaDatabase;
14+
import dev.keva.util.DoubleUtil;
15+
import dev.keva.util.hashbytes.BytesKey;
16+
17+
import java.nio.charset.StandardCharsets;
18+
import java.util.AbstractMap.SimpleEntry;
19+
20+
import static dev.keva.util.Constants.FLAG_CH;
21+
import static dev.keva.util.Constants.FLAG_GT;
22+
import static dev.keva.util.Constants.FLAG_INCR;
23+
import static dev.keva.util.Constants.FLAG_LT;
24+
import static dev.keva.util.Constants.FLAG_NX;
25+
import static dev.keva.util.Constants.FLAG_XX;
26+
27+
@Component
28+
@CommandImpl("zadd")
29+
@ParamLength(type = ParamLength.Type.AT_LEAST, value = 3)
30+
@Mutate
31+
public final class ZAdd {
32+
private static final String XX = "xx";
33+
private static final String NX = "nx";
34+
private static final String GT = "gt";
35+
private static final String LT = "lt";
36+
private static final String INCR = "incr";
37+
private static final String CH = "ch";
38+
39+
private final KevaDatabase database;
40+
41+
@Autowired
42+
public ZAdd(KevaDatabase database) {
43+
this.database = database;
44+
}
45+
46+
@Execute
47+
public Reply<?> execute(byte[][] params) {
48+
// Parse the flags, if any
49+
boolean xx = false, nx = false, gt = false, lt = false, incr = false;
50+
int argPos = 1, flags = 0;
51+
String arg;
52+
while (argPos < params.length) {
53+
arg = new String(params[argPos], StandardCharsets.UTF_8);
54+
if (XX.equalsIgnoreCase(arg)) {
55+
xx = true;
56+
flags |= FLAG_XX;
57+
} else if (NX.equalsIgnoreCase(arg)) {
58+
nx = true;
59+
flags |= FLAG_NX;
60+
} else if (GT.equalsIgnoreCase(arg)) {
61+
gt = true;
62+
flags |= FLAG_GT;
63+
} else if (LT.equalsIgnoreCase(arg)) {
64+
lt = true;
65+
flags |= FLAG_LT;
66+
} else if (INCR.equalsIgnoreCase(arg)) {
67+
incr = true;
68+
flags |= FLAG_INCR;
69+
} else if (CH.equalsIgnoreCase(arg)) {
70+
flags |= FLAG_CH;
71+
} else {
72+
break;
73+
}
74+
++argPos;
75+
}
76+
77+
int numMembers = params.length - argPos;
78+
if (numMembers % 2 != 0) {
79+
return ErrorReply.SYNTAX_ERROR;
80+
}
81+
numMembers /= 2;
82+
83+
if (nx && xx) {
84+
return ErrorReply.ZADD_NX_XX_ERROR;
85+
}
86+
if ((gt && nx) || (lt && nx) || (gt && lt)) {
87+
return ErrorReply.ZADD_GT_LT_NX_ERROR;
88+
}
89+
if (incr && numMembers > 1) {
90+
return ErrorReply.ZADD_INCR_ERROR;
91+
}
92+
93+
// Parse the key and value
94+
final SimpleEntry<Double, BytesKey>[] members = new SimpleEntry[numMembers];
95+
double score;
96+
String rawScore;
97+
for (int memberIndex = 0; memberIndex < numMembers; ++memberIndex) {
98+
try {
99+
rawScore = new String(params[argPos++], StandardCharsets.UTF_8);
100+
if (rawScore.equalsIgnoreCase("inf") || rawScore.equalsIgnoreCase("infinity")
101+
|| rawScore.equalsIgnoreCase("+inf") || rawScore.equalsIgnoreCase("+infinity")
102+
) {
103+
score = Double.POSITIVE_INFINITY;
104+
} else if (rawScore.equalsIgnoreCase("-inf") || rawScore.equalsIgnoreCase("-infinity")) {
105+
score = Double.NEGATIVE_INFINITY;
106+
} else {
107+
score = Double.parseDouble(rawScore);
108+
}
109+
} catch (final NumberFormatException ignored) {
110+
// return on first bad input
111+
return ErrorReply.ZADD_SCORE_FLOAT_ERROR;
112+
}
113+
members[memberIndex] = new SimpleEntry<>(score, new BytesKey(params[argPos++]));
114+
}
115+
116+
if (incr) {
117+
Double result = database.zincrby(params[0], members[0].getKey(), members[0].getValue(), flags);
118+
return result == null ? BulkReply.NIL_REPLY : new BulkReply(DoubleUtil.toString(result));
119+
}
120+
int result = database.zadd(params[0], members, flags);
121+
return new IntegerReply(result);
122+
}
123+
}
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package dev.keva.core.command.impl.zset;
2+
3+
import dev.keva.core.command.annotation.CommandImpl;
4+
import dev.keva.core.command.annotation.Execute;
5+
import dev.keva.core.command.annotation.ParamLength;
6+
import dev.keva.ioc.annotation.Autowired;
7+
import dev.keva.ioc.annotation.Component;
8+
import dev.keva.protocol.resp.reply.BulkReply;
9+
import dev.keva.store.KevaDatabase;
10+
11+
@Component
12+
@CommandImpl("zscore")
13+
@ParamLength(type = ParamLength.Type.EXACT, value = 2)
14+
public final class ZScore {
15+
private final KevaDatabase database;
16+
17+
@Autowired
18+
public ZScore(KevaDatabase database) {
19+
this.database = database;
20+
}
21+
22+
@Execute
23+
public BulkReply execute(byte[] key, byte[] member) {
24+
final Double result = database.zscore(key, member);
25+
if(result == null){
26+
return BulkReply.NIL_REPLY;
27+
}
28+
if (result.equals(Double.POSITIVE_INFINITY)) {
29+
return BulkReply.POSITIVE_INFINITY_REPLY;
30+
}
31+
if (result.equals(Double.NEGATIVE_INFINITY)) {
32+
return BulkReply.NEGATIVE_INFINITY_REPLY;
33+
}
34+
return new BulkReply(result.toString());
35+
}
36+
}

core/src/test/java/dev/keva/core/server/AbstractServerTest.java

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,13 @@
99

1010
import java.util.Arrays;
1111
import java.util.Collections;
12+
import java.util.HashMap;
13+
import java.util.Map;
1214
import java.util.concurrent.CompletableFuture;
1315
import java.util.concurrent.ExecutionException;
1416

1517
import lombok.var;
18+
import redis.clients.jedis.params.ZAddParams;
1619

1720
import static org.junit.jupiter.api.Assertions.*;
1821

@@ -827,6 +830,131 @@ void setrange() {
827830
}
828831
}
829832

833+
@Test
834+
void zaddWithXXAndNXErrs() {
835+
try {
836+
assertThrows(JedisDataException.class, () -> {
837+
jedis.zadd("zset", 1.0, "val", new ZAddParams().xx().nx());
838+
});
839+
} finally {
840+
jedis.del("zset");
841+
}
842+
}
843+
844+
@Test
845+
void zaddSingleWithNxAndGtErrs() {
846+
try {
847+
assertThrows(JedisDataException.class, () -> {
848+
jedis.zadd("zset", 1.0, "val", new ZAddParams().gt().nx());
849+
});
850+
} finally {
851+
jedis.del("zset");
852+
}
853+
}
854+
855+
@Test
856+
void zaddSingleWithNxAndLtErrs() {
857+
try {
858+
assertThrows(JedisDataException.class, () -> {
859+
jedis.zadd("zset", 1.0, "val", new ZAddParams().lt().nx());
860+
});
861+
} finally {
862+
jedis.del("zset");
863+
}
864+
}
865+
866+
@Test
867+
void zaddSingleWithGtAndLtErrs() {
868+
try {
869+
assertThrows(JedisDataException.class, () -> {
870+
jedis.zadd("zset", 1.0, "val", new ZAddParams().lt().gt());
871+
});
872+
} finally {
873+
jedis.del("zset");
874+
}
875+
}
876+
877+
@Test
878+
void zaddSingleWithoutOptions() {
879+
try {
880+
var result = jedis.zadd("zset", 1.0, "val");
881+
assertEquals(1, result);
882+
883+
result = jedis.zadd("zset", 1.0, "val");
884+
assertEquals(0, result);
885+
} catch (Exception e) {
886+
fail(e);
887+
} finally {
888+
jedis.del("zset");
889+
}
890+
}
891+
892+
@Test
893+
void zaddMultipleWithoutOptions() {
894+
try {
895+
Map<String, Double> members = new HashMap<>();
896+
int numMembers = 100;
897+
for(int i=0; i<numMembers; ++i) {
898+
members.put(Integer.toString(i), (double) i);
899+
}
900+
var result = jedis.zadd("zset", members);
901+
assertEquals(numMembers, result);
902+
903+
result = jedis.zadd("zset", members);
904+
assertEquals(0, result);
905+
} catch (Exception e) {
906+
fail(e);
907+
} finally {
908+
jedis.del("zset");
909+
}
910+
}
911+
912+
@Test
913+
void zaddCh() {
914+
try {
915+
var result = jedis.zadd("zset", 1.0, "mem", new ZAddParams().ch());
916+
assertEquals(1, result);
917+
918+
result = jedis.zadd("zset", 1.0, "mem", new ZAddParams().ch());
919+
assertEquals(0, result);
920+
921+
result = jedis.zadd("zset", 2.0, "mem", new ZAddParams().ch());
922+
assertEquals(1, result);
923+
} catch (Exception e) {
924+
fail(e);
925+
} finally {
926+
jedis.del("zset");
927+
}
928+
}
929+
930+
@Test
931+
void zscoreNonExistingKey() {
932+
val result = jedis.zscore("key", "mem");
933+
assertNull(result);
934+
}
935+
936+
@Test
937+
void zscoreNonExistingMember() {
938+
try {
939+
jedis.zadd("zset", 1.0, "mem");
940+
val result = jedis.zscore("zset", "foo");
941+
assertNull(result);
942+
} finally {
943+
jedis.del("zset");
944+
}
945+
}
946+
947+
@Test
948+
void zscoreExistingMember() {
949+
try {
950+
jedis.zadd("zset", 1.0, "mem");
951+
val result = jedis.zscore("zset", "mem");
952+
assertEquals(result, 1.0);
953+
} finally {
954+
jedis.del("zset");
955+
}
956+
}
957+
830958
@Test
831959
void dumpAndRestore() {
832960
try {

resp-protocol/src/main/java/dev/keva/protocol/resp/reply/BulkReply.java

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111

1212
public class BulkReply implements Reply<ByteBuf> {
1313
public static final BulkReply NIL_REPLY = new BulkReply();
14+
public static final BulkReply POSITIVE_INFINITY_REPLY = new BulkReply("inf");
15+
public static final BulkReply NEGATIVE_INFINITY_REPLY = new BulkReply("-inf");
1416

1517
public static final char MARKER = '$';
1618
private final ByteBuf bytes;
@@ -22,11 +24,7 @@ private BulkReply() {
2224
}
2325

2426
public BulkReply(byte[] bytes) {
25-
if (bytes.length == 0) {
26-
this.bytes = Unpooled.EMPTY_BUFFER;
27-
} else {
28-
this.bytes = Unpooled.wrappedBuffer(bytes);
29-
}
27+
this.bytes = Unpooled.wrappedBuffer(bytes);
3028
capacity = bytes.length;
3129
}
3230

@@ -59,7 +57,7 @@ public void write(ByteBuf os) throws IOException {
5957
os.writeByte(MARKER);
6058
os.writeBytes(numToBytes(capacity, true));
6159
if (capacity > 0) {
62-
os.writeBytes(bytes);
60+
os.writeBytes(bytes.array());
6361
os.writeBytes(CRLF);
6462
}
6563
if (capacity == 0) {

resp-protocol/src/main/java/dev/keva/protocol/resp/reply/ErrorReply.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,13 @@
77

88
public class ErrorReply implements Reply<String> {
99
public static final char MARKER = '-';
10+
// Pre-defined errors
11+
public static final ErrorReply SYNTAX_ERROR = new ErrorReply("ERR syntax error");
12+
public static final ErrorReply ZADD_NX_XX_ERROR = new ErrorReply("ERR XX and NX options at the same time are not compatible");
13+
public static final ErrorReply ZADD_GT_LT_NX_ERROR = new ErrorReply("GT, LT, and/or NX options at the same time are not compatible");
14+
public static final ErrorReply ZADD_INCR_ERROR = new ErrorReply("INCR option supports a single increment-element pair");
15+
public static final ErrorReply ZADD_SCORE_FLOAT_ERROR = new ErrorReply("value is not a valid float");
16+
1017
private final String error;
1118

1219
public ErrorReply(String error) {

store/src/main/java/dev/keva/store/KevaDatabase.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
package dev.keva.store;
22

3+
import dev.keva.util.hashbytes.BytesKey;
4+
5+
import java.util.AbstractMap;
36
import java.util.concurrent.locks.Lock;
47

58
public interface KevaDatabase {
@@ -69,4 +72,9 @@ public interface KevaDatabase {
6972

7073
byte[][] mget(byte[]... keys);
7174

75+
int zadd(byte[] key, AbstractMap.SimpleEntry<Double, BytesKey>[] members, int flags);
76+
77+
Double zincrby(byte[] key, Double score, BytesKey e, int flags);
78+
79+
Double zscore(byte[] key, byte[] member);
7280
}

0 commit comments

Comments
 (0)