diff --git a/docs/src/guide/overview/commands.md b/docs/src/guide/overview/commands.md index 3a01c9fb..0a3fcc50 100644 --- a/docs/src/guide/overview/commands.md +++ b/docs/src/guide/overview/commands.md @@ -65,6 +65,21 @@ Implemented commands: +
+ Set + +- SADD +- SMEMBERS +- SISMEMBER +- SCARD +- SDIFF +- SINTER +- SUNION +- SMOVE +- SREM + +
+
Pub/Sub diff --git a/server/src/main/java/dev/keva/server/command/impl/connection/Info.java b/server/src/main/java/dev/keva/server/command/impl/connection/Info.java index ee99d313..adc4d64b 100644 --- a/server/src/main/java/dev/keva/server/command/impl/connection/Info.java +++ b/server/src/main/java/dev/keva/server/command/impl/connection/Info.java @@ -19,7 +19,7 @@ public BulkReply execute() { val threads = ManagementFactory.getThreadMXBean().getThreadCount(); String infoStr = "# Server\r\n" + "keva_version: 1.0.0\r\n" - + "io_threads_active: " + threads; + + "io_threads_active: " + threads + "\r\n"; return new BulkReply(infoStr); } } diff --git a/server/src/main/java/dev/keva/server/command/impl/set/SAdd.java b/server/src/main/java/dev/keva/server/command/impl/set/SAdd.java new file mode 100644 index 00000000..03668813 --- /dev/null +++ b/server/src/main/java/dev/keva/server/command/impl/set/SAdd.java @@ -0,0 +1,29 @@ +package dev.keva.server.command.impl.set; + +import dev.keva.ioc.annotation.Autowired; +import dev.keva.ioc.annotation.Component; +import dev.keva.protocol.resp.reply.IntegerReply; +import dev.keva.server.command.annotation.CommandImpl; +import dev.keva.server.command.annotation.Execute; +import dev.keva.server.command.annotation.ParamLength; +import dev.keva.store.KevaDatabase; + +import java.util.Arrays; + +@Component +@CommandImpl("sadd") +@ParamLength(type = ParamLength.Type.AT_LEAST, value = 2) +public class SAdd { + private final KevaDatabase database; + + @Autowired + public SAdd(KevaDatabase database) { + this.database = database; + } + + @Execute + public IntegerReply execute(byte[][] params) { + int added = database.sadd(params[0], Arrays.copyOfRange(params, 1, params.length)); + return new IntegerReply(added); + } +} diff --git a/server/src/main/java/dev/keva/server/command/impl/set/SCard.java b/server/src/main/java/dev/keva/server/command/impl/set/SCard.java new file mode 100644 index 00000000..d2955036 --- /dev/null +++ b/server/src/main/java/dev/keva/server/command/impl/set/SCard.java @@ -0,0 +1,27 @@ +package dev.keva.server.command.impl.set; + +import dev.keva.ioc.annotation.Autowired; +import dev.keva.ioc.annotation.Component; +import dev.keva.protocol.resp.reply.IntegerReply; +import dev.keva.server.command.annotation.CommandImpl; +import dev.keva.server.command.annotation.Execute; +import dev.keva.server.command.annotation.ParamLength; +import dev.keva.store.KevaDatabase; + +@Component +@CommandImpl("scard") +@ParamLength(1) +public class SCard { + private final KevaDatabase database; + + @Autowired + public SCard(KevaDatabase database) { + this.database = database; + } + + @Execute + public IntegerReply execute(byte[] key) { + int num = database.scard(key); + return new IntegerReply(num); + } +} diff --git a/server/src/main/java/dev/keva/server/command/impl/set/SDiff.java b/server/src/main/java/dev/keva/server/command/impl/set/SDiff.java new file mode 100644 index 00000000..e1f68cc9 --- /dev/null +++ b/server/src/main/java/dev/keva/server/command/impl/set/SDiff.java @@ -0,0 +1,32 @@ +package dev.keva.server.command.impl.set; + +import dev.keva.ioc.annotation.Autowired; +import dev.keva.ioc.annotation.Component; +import dev.keva.protocol.resp.reply.BulkReply; +import dev.keva.protocol.resp.reply.MultiBulkReply; +import dev.keva.server.command.annotation.CommandImpl; +import dev.keva.server.command.annotation.Execute; +import dev.keva.server.command.annotation.ParamLength; +import dev.keva.store.KevaDatabase; + +@Component +@CommandImpl("sdiff") +@ParamLength(2) +public class SDiff { + private final KevaDatabase database; + + @Autowired + public SDiff(KevaDatabase database) { + this.database = database; + } + + @Execute + public MultiBulkReply execute(byte[]... keys) { + byte[][] diff = database.sdiff(keys); + BulkReply[] replies = new BulkReply[diff.length]; + for (int i = 0; i < diff.length; i++) { + replies[i] = new BulkReply(diff[i]); + } + return new MultiBulkReply(replies); + } +} diff --git a/server/src/main/java/dev/keva/server/command/impl/set/SInter.java b/server/src/main/java/dev/keva/server/command/impl/set/SInter.java new file mode 100644 index 00000000..411236a1 --- /dev/null +++ b/server/src/main/java/dev/keva/server/command/impl/set/SInter.java @@ -0,0 +1,32 @@ +package dev.keva.server.command.impl.set; + +import dev.keva.ioc.annotation.Autowired; +import dev.keva.ioc.annotation.Component; +import dev.keva.protocol.resp.reply.BulkReply; +import dev.keva.protocol.resp.reply.MultiBulkReply; +import dev.keva.server.command.annotation.CommandImpl; +import dev.keva.server.command.annotation.Execute; +import dev.keva.server.command.annotation.ParamLength; +import dev.keva.store.KevaDatabase; + +@Component +@CommandImpl("sinter") +@ParamLength(2) +public class SInter { + private final KevaDatabase database; + + @Autowired + public SInter(KevaDatabase database) { + this.database = database; + } + + @Execute + public MultiBulkReply execute(byte[]... keys) { + byte[][] diff = database.sinter(keys); + BulkReply[] replies = new BulkReply[diff.length]; + for (int i = 0; i < diff.length; i++) { + replies[i] = new BulkReply(diff[i]); + } + return new MultiBulkReply(replies); + } +} diff --git a/server/src/main/java/dev/keva/server/command/impl/set/SIsMember.java b/server/src/main/java/dev/keva/server/command/impl/set/SIsMember.java new file mode 100644 index 00000000..2506bb15 --- /dev/null +++ b/server/src/main/java/dev/keva/server/command/impl/set/SIsMember.java @@ -0,0 +1,27 @@ +package dev.keva.server.command.impl.set; + +import dev.keva.ioc.annotation.Autowired; +import dev.keva.ioc.annotation.Component; +import dev.keva.protocol.resp.reply.IntegerReply; +import dev.keva.server.command.annotation.CommandImpl; +import dev.keva.server.command.annotation.Execute; +import dev.keva.server.command.annotation.ParamLength; +import dev.keva.store.KevaDatabase; + +@Component +@CommandImpl("sismember") +@ParamLength(2) +public class SIsMember { + private final KevaDatabase database; + + @Autowired + public SIsMember(KevaDatabase database) { + this.database = database; + } + + @Execute + public IntegerReply execute(byte[] key, byte[] value) { + boolean isMember = database.sismember(key, value); + return new IntegerReply(isMember ? 1 : 0); + } +} diff --git a/server/src/main/java/dev/keva/server/command/impl/set/SMembers.java b/server/src/main/java/dev/keva/server/command/impl/set/SMembers.java new file mode 100644 index 00000000..976638aa --- /dev/null +++ b/server/src/main/java/dev/keva/server/command/impl/set/SMembers.java @@ -0,0 +1,36 @@ +package dev.keva.server.command.impl.set; + +import dev.keva.ioc.annotation.Autowired; +import dev.keva.ioc.annotation.Component; +import dev.keva.protocol.resp.reply.BulkReply; +import dev.keva.protocol.resp.reply.MultiBulkReply; +import dev.keva.protocol.resp.reply.Reply; +import dev.keva.server.command.annotation.CommandImpl; +import dev.keva.server.command.annotation.Execute; +import dev.keva.server.command.annotation.ParamLength; +import dev.keva.store.KevaDatabase; + +@Component +@CommandImpl("smembers") +@ParamLength(1) +public class SMembers { + private final KevaDatabase database; + + @Autowired + public SMembers(KevaDatabase database) { + this.database = database; + } + + @Execute + public MultiBulkReply execute(byte[] key) { + byte[][] result = database.smembers(key); + if (result == null) { + return new MultiBulkReply(new Reply[0]); + } + BulkReply[] replies = new BulkReply[result.length]; + for (int i = 0; i < result.length; i++) { + replies[i] = new BulkReply(result[i]); + } + return new MultiBulkReply(replies); + } +} diff --git a/server/src/main/java/dev/keva/server/command/impl/set/SMove.java b/server/src/main/java/dev/keva/server/command/impl/set/SMove.java new file mode 100644 index 00000000..1ee4afa3 --- /dev/null +++ b/server/src/main/java/dev/keva/server/command/impl/set/SMove.java @@ -0,0 +1,27 @@ +package dev.keva.server.command.impl.set; + +import dev.keva.ioc.annotation.Autowired; +import dev.keva.ioc.annotation.Component; +import dev.keva.protocol.resp.reply.IntegerReply; +import dev.keva.server.command.annotation.CommandImpl; +import dev.keva.server.command.annotation.Execute; +import dev.keva.server.command.annotation.ParamLength; +import dev.keva.store.KevaDatabase; + +@Component +@CommandImpl("smove") +@ParamLength(3) +public class SMove { + private final KevaDatabase database; + + @Autowired + public SMove(KevaDatabase database) { + this.database = database; + } + + @Execute + public IntegerReply execute(byte[] source, byte[] destination, byte[] member) { + int count = database.smove(source, destination, member); + return new IntegerReply(count); + } +} diff --git a/server/src/main/java/dev/keva/server/command/impl/set/SRem.java b/server/src/main/java/dev/keva/server/command/impl/set/SRem.java new file mode 100644 index 00000000..7e95d1ba --- /dev/null +++ b/server/src/main/java/dev/keva/server/command/impl/set/SRem.java @@ -0,0 +1,29 @@ +package dev.keva.server.command.impl.set; + +import dev.keva.ioc.annotation.Autowired; +import dev.keva.ioc.annotation.Component; +import dev.keva.protocol.resp.reply.IntegerReply; +import dev.keva.server.command.annotation.CommandImpl; +import dev.keva.server.command.annotation.Execute; +import dev.keva.server.command.annotation.ParamLength; +import dev.keva.store.KevaDatabase; + +import java.util.Arrays; + +@Component +@CommandImpl("srem") +@ParamLength(type = ParamLength.Type.AT_LEAST, value = 2) +public class SRem { + private final KevaDatabase database; + + @Autowired + public SRem(KevaDatabase database) { + this.database = database; + } + + @Execute + public IntegerReply execute(byte[][] params) { + int removed = database.srem(params[0], Arrays.copyOfRange(params, 1, params.length)); + return new IntegerReply(removed); + } +} diff --git a/server/src/main/java/dev/keva/server/command/impl/set/SUnion.java b/server/src/main/java/dev/keva/server/command/impl/set/SUnion.java new file mode 100644 index 00000000..e66eeeaf --- /dev/null +++ b/server/src/main/java/dev/keva/server/command/impl/set/SUnion.java @@ -0,0 +1,32 @@ +package dev.keva.server.command.impl.set; + +import dev.keva.ioc.annotation.Autowired; +import dev.keva.ioc.annotation.Component; +import dev.keva.protocol.resp.reply.BulkReply; +import dev.keva.protocol.resp.reply.MultiBulkReply; +import dev.keva.server.command.annotation.CommandImpl; +import dev.keva.server.command.annotation.Execute; +import dev.keva.server.command.annotation.ParamLength; +import dev.keva.store.KevaDatabase; + +@Component +@CommandImpl("sunion") +@ParamLength(2) +public class SUnion { + private final KevaDatabase database; + + @Autowired + public SUnion(KevaDatabase database) { + this.database = database; + } + + @Execute + public MultiBulkReply execute(byte[]... keys) { + byte[][] diff = database.sunion(keys); + BulkReply[] replies = new BulkReply[diff.length]; + for (int i = 0; i < diff.length; i++) { + replies[i] = new BulkReply(diff[i]); + } + return new MultiBulkReply(replies); + } +} diff --git a/server/src/main/java/dev/keva/server/command/mapping/CommandMapper.java b/server/src/main/java/dev/keva/server/command/mapping/CommandMapper.java index 21c2aeac..1ac2113f 100644 --- a/server/src/main/java/dev/keva/server/command/mapping/CommandMapper.java +++ b/server/src/main/java/dev/keva/server/command/mapping/CommandMapper.java @@ -72,6 +72,9 @@ public void init() { } catch (Exception e) { log.error("", e); if (e instanceof InvocationTargetException) { + if (e.getCause() instanceof ClassCastException) { + return new ErrorReply("ERR WRONGTYPE Operation against a key holding the wrong kind of value"); + } return new ErrorReply("ERR " + e.getCause().getMessage()); } return new ErrorReply("ERR " + e.getMessage()); diff --git a/server/src/test/java/dev/keva/server/core/AbstractServerTest.java b/server/src/test/java/dev/keva/server/core/AbstractServerTest.java index 436f36be..801fa72d 100644 --- a/server/src/test/java/dev/keva/server/core/AbstractServerTest.java +++ b/server/src/test/java/dev/keva/server/core/AbstractServerTest.java @@ -543,4 +543,114 @@ void lrem() { fail(e); } } + + @Test + void sadd() { + try { + val sadd = jedis.sadd("test", "val"); + assertEquals(1, sadd); + } catch (Exception e) { + fail(e); + } + } + + @Test + void smembers() { + try { + val sadd = jedis.sadd("test", "val"); + val smembers = jedis.smembers("test"); + assertEquals(1, sadd); + assertEquals(1, smembers.size()); + assertEquals("val", smembers.toArray()[0]); + } catch (Exception e) { + fail(e); + } + } + + @Test + void sismember() { + try { + val sadd = jedis.sadd("test", "val"); + val sismember = jedis.sismember("test", "val"); + assertEquals(1, sadd); + assertEquals(true, sismember); + } catch (Exception e) { + fail(e); + } + } + + @Test + void scard() { + try { + val sadd = jedis.sadd("test", "val"); + val scard = jedis.scard("test"); + assertEquals(1, sadd); + assertEquals(1, scard); + } catch (Exception e) { + fail(e); + } + } + + @Test + void sdiff() { + try { + val sadd = jedis.sadd("test", "val"); + val sdiff = jedis.sdiff("test", "test2"); + assertEquals(1, sadd); + assertEquals(1, sdiff.size()); + } catch (Exception e) { + fail(e); + } + } + + @Test + void sinter() { + try { + val sadd = jedis.sadd("test", "val"); + val sadd2 = jedis.sadd("test2", "val"); + val sinter = jedis.sinter("test", "test2"); + assertEquals(1, sadd); + assertEquals(1, sadd2); + assertEquals(1, sinter.size()); + } catch (Exception e) { + fail(e); + } + } + + @Test + void sunion() { + try { + val sadd = jedis.sadd("test", "val"); + val sunion = jedis.sunion("test", "test2"); + assertEquals(1, sadd); + assertEquals(1, sunion.size()); + assertEquals("val", sunion.toArray()[0]); + } catch (Exception e) { + fail(e); + } + } + + @Test + void smove() { + try { + val sadd = jedis.sadd("test", "val"); + val smove = jedis.smove("test", "test2", "val"); + assertEquals(1, sadd); + assertEquals(1, smove); + } catch (Exception e) { + fail(e); + } + } + + @Test + void srem() { + try { + val sadd = jedis.sadd("test", "val"); + val srem = jedis.srem("test", "val"); + assertEquals(1, sadd); + assertEquals(1, srem); + } catch (Exception e) { + fail(e); + } + } } diff --git a/store/src/main/java/dev/keva/store/KevaDatabase.java b/store/src/main/java/dev/keva/store/KevaDatabase.java index 8d96101b..2e69417b 100644 --- a/store/src/main/java/dev/keva/store/KevaDatabase.java +++ b/store/src/main/java/dev/keva/store/KevaDatabase.java @@ -45,4 +45,22 @@ public interface KevaDatabase { void lset(byte[] key, int index, byte[] value); int lrem(byte[] key, int count, byte[] value); + + int sadd(byte[] key, byte[]... values); + + byte[][] smembers(byte[] key); + + boolean sismember(byte[] key, byte[] value); + + int scard(byte[] key); + + byte[][] sdiff(byte[]... keys); + + byte[][] sinter(byte[]... keys); + + byte[][] sunion(byte[]... keys); + + int smove(byte[] source, byte[] destination, byte[] value); + + int srem(byte[] key, byte[]... values); } diff --git a/store/src/main/java/dev/keva/store/impl/ChronicleMapImpl.java b/store/src/main/java/dev/keva/store/impl/ChronicleMapImpl.java index ef761354..13021d4d 100644 --- a/store/src/main/java/dev/keva/store/impl/ChronicleMapImpl.java +++ b/store/src/main/java/dev/keva/store/impl/ChronicleMapImpl.java @@ -453,4 +453,211 @@ public int lrem(byte[] key, int count, byte[] value) { lock.unlock(); } } + + @Override + @SuppressWarnings("unchecked") + public int sadd(byte[] key, byte[]... values) { + lock.lock(); + try { + byte[] value = chronicleMap.get(key); + HashSet set; + set = value == null ? new HashSet<>() : (HashSet) SerializationUtils.deserialize(value); + int count = 0; + for (byte[] v : values) { + boolean isNewElement = set.add(new BytesKey(v)); + if (isNewElement) { + count++; + } + } + chronicleMap.put(key, SerializationUtils.serialize(set)); + return count; + } finally { + lock.unlock(); + } + } + + @Override + @SuppressWarnings("unchecked") + public byte[][] smembers(byte[] key) { + lock.lock(); + try { + byte[] value = chronicleMap.get(key); + if (value == null) { + return null; + } + HashSet set = (HashSet) SerializationUtils.deserialize(value); + byte[][] result = new byte[set.size()][]; + int i = 0; + for (BytesKey v : set) { + result[i++] = v.getBytes(); + } + return result; + } finally { + lock.unlock(); + } + } + + @Override + @SuppressWarnings("unchecked") + public boolean sismember(byte[] key, byte[] value) { + lock.lock(); + try { + byte[] got = chronicleMap.get(key); + if (got == null) { + return false; + } + HashSet set = (HashSet) SerializationUtils.deserialize(got); + return set.contains(new BytesKey(value)); + } finally { + lock.unlock(); + } + } + + @Override + @SuppressWarnings("unchecked") + public int scard(byte[] key) { + lock.lock(); + try { + byte[] value = chronicleMap.get(key); + if (value == null) { + return 0; + } + HashSet set = (HashSet) SerializationUtils.deserialize(value); + return set.size(); + } finally { + lock.unlock(); + } + } + + @Override + @SuppressWarnings("unchecked") + public byte[][] sdiff(byte[]... keys) { + lock.lock(); + try { + HashSet set = new HashSet<>(); + for (byte[] key : keys) { + byte[] value = chronicleMap.get(key); + if (set.isEmpty() && value != null) { + set.addAll((HashSet) SerializationUtils.deserialize(value)); + } else if (value != null) { + HashSet set1 = (HashSet) SerializationUtils.deserialize(value); + set.removeAll(set1); + } + } + byte[][] result = new byte[set.size()][]; + int i = 0; + for (BytesKey v : set) { + result[i++] = v.getBytes(); + } + return result; + } finally { + lock.unlock(); + } + } + + @Override + @SuppressWarnings("unchecked") + public byte[][] sinter(byte[]... keys) { + lock.lock(); + try { + HashSet set = new HashSet<>(); + for (byte[] key : keys) { + byte[] value = chronicleMap.get(key); + if (set.isEmpty() && value != null) { + set.addAll((HashSet) SerializationUtils.deserialize(value)); + } else if (value != null) { + HashSet set1 = (HashSet) SerializationUtils.deserialize(value); + set.retainAll(set1); + } + } + byte[][] result = new byte[set.size()][]; + int i = 0; + for (BytesKey v : set) { + result[i++] = v.getBytes(); + } + return result; + } finally { + lock.unlock(); + } + } + + @Override + @SuppressWarnings("unchecked") + public byte[][] sunion(byte[]... keys) { + lock.lock(); + try { + HashSet set = new HashSet<>(); + for (byte[] key : keys) { + byte[] value = chronicleMap.get(key); + if (value != null) { + HashSet set1 = (HashSet) SerializationUtils.deserialize(value); + set.addAll(set1); + } + } + byte[][] result = new byte[set.size()][]; + int i = 0; + for (BytesKey v : set) { + result[i++] = v.getBytes(); + } + return result; + } finally { + lock.unlock(); + } + } + + @Override + @SuppressWarnings("unchecked") + public int smove(byte[] source, byte[] destination, byte[] value) { + lock.lock(); + try { + byte[] sourceValue = chronicleMap.get(source); + if (sourceValue == null) { + return 0; + } + HashSet set = (HashSet) SerializationUtils.deserialize(sourceValue); + if (set.remove(new BytesKey(value))) { + byte[] destinationValue = chronicleMap.get(destination); + HashSet set1; + if (destinationValue == null) { + set1 = new HashSet<>(); + } else { + set1 = (HashSet) SerializationUtils.deserialize(destinationValue); + } + boolean result = set1.add(new BytesKey(value)); + chronicleMap.put(source, SerializationUtils.serialize(set)); + chronicleMap.put(destination, SerializationUtils.serialize(set1)); + return result ? 1 : 0; + } + return 0; + } finally { + lock.unlock(); + } + } + + @Override + @SuppressWarnings("unchecked") + public int srem(byte[] key, byte[]... values) { + lock.lock(); + try { + byte[] value = chronicleMap.get(key); + if (value == null) { + return 0; + } + HashSet set = (HashSet) SerializationUtils.deserialize(value); + int result = 0; + for (byte[] v : values) { + if (set.remove(new BytesKey(v))) { + result++; + } + } + if (set.isEmpty()) { + chronicleMap.remove(key); + } else { + chronicleMap.put(key, SerializationUtils.serialize(set)); + } + return result; + } finally { + lock.unlock(); + } + } } diff --git a/store/src/main/java/dev/keva/store/impl/HashMapImpl.java b/store/src/main/java/dev/keva/store/impl/HashMapImpl.java index 3d131472..a2df9755 100644 --- a/store/src/main/java/dev/keva/store/impl/HashMapImpl.java +++ b/store/src/main/java/dev/keva/store/impl/HashMapImpl.java @@ -411,4 +411,211 @@ public int lrem(byte[] key, int count, byte[] value) { lock.unlock(); } } + + @Override + @SuppressWarnings("unchecked") + public int sadd(byte[] key, byte[]... values) { + lock.lock(); + try { + byte[] value = map.get(new BytesKey(key)).getBytes(); + HashSet set; + set = value == null ? new HashSet<>() : (HashSet) SerializationUtils.deserialize(value); + int count = 0; + for (byte[] v : values) { + boolean isNewElement = set.add(new BytesKey(v)); + if (isNewElement) { + count++; + } + } + map.put(new BytesKey(key), new BytesValue(SerializationUtils.serialize(set))); + return count; + } finally { + lock.unlock(); + } + } + + @Override + @SuppressWarnings("unchecked") + public byte[][] smembers(byte[] key) { + lock.lock(); + try { + byte[] value = map.get(new BytesKey(key)).getBytes(); + if (value == null) { + return null; + } + HashSet set = (HashSet) SerializationUtils.deserialize(value); + byte[][] result = new byte[set.size()][]; + int i = 0; + for (BytesKey v : set) { + result[i++] = v.getBytes(); + } + return result; + } finally { + lock.unlock(); + } + } + + @Override + @SuppressWarnings("unchecked") + public boolean sismember(byte[] key, byte[] value) { + lock.lock(); + try { + byte[] got = map.get(new BytesKey(key)).getBytes(); + if (got == null) { + return false; + } + HashSet set = (HashSet) SerializationUtils.deserialize(got); + return set.contains(new BytesKey(value)); + } finally { + lock.unlock(); + } + } + + @Override + @SuppressWarnings("unchecked") + public int scard(byte[] key) { + lock.lock(); + try { + byte[] value = map.get(new BytesKey(key)).getBytes(); + if (value == null) { + return 0; + } + HashSet set = (HashSet) SerializationUtils.deserialize(value); + return set.size(); + } finally { + lock.unlock(); + } + } + + @Override + @SuppressWarnings("unchecked") + public byte[][] sdiff(byte[]... keys) { + lock.lock(); + try { + HashSet set = new HashSet<>(); + for (byte[] key : keys) { + byte[] value = map.get(new BytesKey(key)).getBytes(); + if (set.isEmpty() && value != null) { + set.addAll((HashSet) SerializationUtils.deserialize(value)); + } else if (value != null) { + HashSet set1 = (HashSet) SerializationUtils.deserialize(value); + set.removeAll(set1); + } + } + byte[][] result = new byte[set.size()][]; + int i = 0; + for (BytesKey v : set) { + result[i++] = v.getBytes(); + } + return result; + } finally { + lock.unlock(); + } + } + + @Override + @SuppressWarnings("unchecked") + public byte[][] sinter(byte[]... keys) { + lock.lock(); + try { + HashSet set = new HashSet<>(); + for (byte[] key : keys) { + byte[] value = map.get(new BytesKey(key)).getBytes(); + if (set.isEmpty() && value != null) { + set.addAll((HashSet) SerializationUtils.deserialize(value)); + } else if (value != null) { + HashSet set1 = (HashSet) SerializationUtils.deserialize(value); + set.retainAll(set1); + } + } + byte[][] result = new byte[set.size()][]; + int i = 0; + for (BytesKey v : set) { + result[i++] = v.getBytes(); + } + return result; + } finally { + lock.unlock(); + } + } + + @Override + @SuppressWarnings("unchecked") + public byte[][] sunion(byte[]... keys) { + lock.lock(); + try { + HashSet set = new HashSet<>(); + for (byte[] key : keys) { + byte[] value = map.get(new BytesKey(key)).getBytes(); + if (value != null) { + HashSet set1 = (HashSet) SerializationUtils.deserialize(value); + set.addAll(set1); + } + } + byte[][] result = new byte[set.size()][]; + int i = 0; + for (BytesKey v : set) { + result[i++] = v.getBytes(); + } + return result; + } finally { + lock.unlock(); + } + } + + @Override + @SuppressWarnings("unchecked") + public int smove(byte[] source, byte[] destination, byte[] value) { + lock.lock(); + try { + byte[] sourceValue = map.get(new BytesKey(source)).getBytes(); + if (sourceValue == null) { + return 0; + } + HashSet set = (HashSet) SerializationUtils.deserialize(sourceValue); + if (set.remove(new BytesKey(value))) { + byte[] destinationValue = map.get(new BytesKey(destination)).getBytes(); + HashSet set1; + if (destinationValue == null) { + set1 = new HashSet<>(); + } else { + set1 = (HashSet) SerializationUtils.deserialize(destinationValue); + } + set1.add(new BytesKey(value)); + map.put(new BytesKey(destination), new BytesKey(SerializationUtils.serialize(set1))); + map.put(new BytesKey(source), new BytesKey(SerializationUtils.serialize(set))); + return 1; + } + return 0; + } finally { + lock.unlock(); + } + } + + @Override + @SuppressWarnings("unchecked") + public int srem(byte[] key, byte[]... values) { + lock.lock(); + try { + byte[] value = map.get(new BytesKey(key)).getBytes(); + if (value == null) { + return 0; + } + HashSet set = (HashSet) SerializationUtils.deserialize(value); + int count = 0; + for (byte[] v : values) { + if (set.remove(new BytesKey(v))) { + count++; + } + } + if (set.isEmpty()) { + map.remove(new BytesKey(key)); + } else { + map.put(new BytesKey(key), new BytesKey(SerializationUtils.serialize(set))); + } + return count; + } finally { + lock.unlock(); + } + } }