diff --git a/partiql-cli/src/main/kotlin/org/partiql/cli/Main.kt b/partiql-cli/src/main/kotlin/org/partiql/cli/Main.kt index ab79fab3d2..4595a88a04 100644 --- a/partiql-cli/src/main/kotlin/org/partiql/cli/Main.kt +++ b/partiql-cli/src/main/kotlin/org/partiql/cli/Main.kt @@ -17,6 +17,8 @@ package org.partiql.cli import com.amazon.ion.system.IonSystemBuilder import com.amazon.ion.system.IonTextWriterBuilder +import com.amazon.ionelement.api.IonElement +import com.amazon.ionelement.api.ionString import org.partiql.cli.pico.PartiQLCommand import org.partiql.cli.shell.info import org.partiql.lang.eval.EvaluationSession @@ -24,7 +26,7 @@ import org.partiql.parser.PartiQLParser import org.partiql.plan.Statement import org.partiql.plan.debug.PlanPrinter import org.partiql.planner.PartiQLPlanner -import org.partiql.plugins.local.toIon +import org.partiql.types.PType import picocli.CommandLine import java.io.PrintStream import java.nio.file.Paths @@ -92,4 +94,8 @@ object Debug { return "OK" } + + private fun PType.toIon(): IonElement { + return ionString(this.toString()) + } } diff --git a/partiql-eval/api/partiql-eval.api b/partiql-eval/api/partiql-eval.api index 039c81d803..d1574b9fd6 100644 --- a/partiql-eval/api/partiql-eval.api +++ b/partiql-eval/api/partiql-eval.api @@ -84,7 +84,7 @@ public abstract interface class org/partiql/eval/value/Datum : java/lang/Iterabl public fun getString ()Ljava/lang/String; public fun getTime ()Lorg/partiql/value/datetime/Time; public fun getTimestamp ()Lorg/partiql/value/datetime/Timestamp; - public abstract fun getType ()Lorg/partiql/value/PartiQLValueType; + public abstract fun getType ()Lorg/partiql/types/PType; public static fun int32Value (I)Lorg/partiql/eval/value/Datum; public static fun int64Value (J)Lorg/partiql/eval/value/Datum; public fun isMissing ()Z @@ -92,9 +92,9 @@ public abstract interface class org/partiql/eval/value/Datum : java/lang/Iterabl public fun iterator ()Ljava/util/Iterator; public static fun listValue (Ljava/lang/Iterable;)Lorg/partiql/eval/value/Datum; public static fun missingValue ()Lorg/partiql/eval/value/Datum; - public static fun missingValue (Lorg/partiql/value/PartiQLValueType;)Lorg/partiql/eval/value/Datum; + public static fun missingValue (Lorg/partiql/types/PType;)Lorg/partiql/eval/value/Datum; public static fun nullValue ()Lorg/partiql/eval/value/Datum; - public static fun nullValue (Lorg/partiql/value/PartiQLValueType;)Lorg/partiql/eval/value/Datum; + public static fun nullValue (Lorg/partiql/types/PType;)Lorg/partiql/eval/value/Datum; public static fun of (Lorg/partiql/value/PartiQLValue;)Lorg/partiql/eval/value/Datum; public static fun sexpValue (Ljava/lang/Iterable;)Lorg/partiql/eval/value/Datum; public static fun stringValue (Ljava/lang/String;)Lorg/partiql/eval/value/Datum; diff --git a/partiql-eval/src/main/java/org/partiql/eval/value/Datum.java b/partiql-eval/src/main/java/org/partiql/eval/value/Datum.java index 7d24a58868..be3892c1c5 100644 --- a/partiql-eval/src/main/java/org/partiql/eval/value/Datum.java +++ b/partiql-eval/src/main/java/org/partiql/eval/value/Datum.java @@ -3,6 +3,7 @@ import kotlin.NotImplementedError; import kotlin.Pair; import org.jetbrains.annotations.NotNull; +import org.partiql.types.PType; import org.partiql.value.PartiQL; import org.partiql.value.PartiQLValue; import org.partiql.value.PartiQLValueType; @@ -11,7 +12,6 @@ import java.math.BigDecimal; import java.math.BigInteger; -import java.util.BitSet; import java.util.Iterator; import java.util.Objects; @@ -60,7 +60,7 @@ default boolean isMissing() { * @return the type of the data at the cursor. */ @NotNull - PartiQLValueType getType(); + PType getType(); /** * @return the underlying value applicable to the types: @@ -346,26 +346,26 @@ default Datum getInsensitive(@NotNull String name) { @NotNull @Deprecated default PartiQLValue toPartiQLValue() { - PartiQLValueType type = this.getType(); - switch (type) { + PType type = this.getType(); + switch (type.getKind()) { case BOOL: return this.isNull() ? PartiQL.boolValue(null) : PartiQL.boolValue(this.getBoolean()); - case INT8: + case TINYINT: return this.isNull() ? PartiQL.int8Value(null) : PartiQL.int8Value(this.getByte()); - case INT16: + case SMALLINT: return this.isNull() ? PartiQL.int16Value(null) : PartiQL.int16Value(this.getShort()); - case INT32: + case INT: return this.isNull() ? PartiQL.int32Value(null) : PartiQL.int32Value(this.getInt()); - case INT64: + case BIGINT: return this.isNull() ? PartiQL.int64Value(null) : PartiQL.int64Value(this.getLong()); - case INT: + case INT_ARBITRARY: return this.isNull() ? PartiQL.intValue(null) : PartiQL.intValue(this.getBigInteger()); case DECIMAL: case DECIMAL_ARBITRARY: return this.isNull() ? PartiQL.decimalValue(null) : PartiQL.decimalValue(this.getBigDecimal()); - case FLOAT32: + case REAL: return this.isNull() ? PartiQL.float32Value(null) : PartiQL.float32Value(this.getFloat()); - case FLOAT64: + case DOUBLE_PRECISION: return this.isNull() ? PartiQL.float64Value(null) : PartiQL.float64Value(this.getDouble()); case CHAR: return this.isNull() ? PartiQL.charValue(null) : PartiQL.charValue(this.getString().charAt(0)); @@ -373,22 +373,18 @@ default PartiQLValue toPartiQLValue() { return this.isNull() ? PartiQL.stringValue(null) : PartiQL.stringValue(this.getString()); case SYMBOL: return this.isNull() ? PartiQL.symbolValue(null) : PartiQL.symbolValue(this.getString()); - case BINARY: - return this.isNull() ? PartiQL.binaryValue(null) : PartiQL.binaryValue(BitSet.valueOf(this.getBytes())); - case BYTE: - return this.isNull() ? PartiQL.byteValue(null) : PartiQL.byteValue(this.getByte()); case BLOB: return this.isNull() ? PartiQL.blobValue(null) : PartiQL.blobValue(this.getBytes()); case CLOB: return this.isNull() ? PartiQL.clobValue(null) : PartiQL.clobValue(this.getBytes()); case DATE: return this.isNull() ? PartiQL.dateValue(null) : PartiQL.dateValue(this.getDate()); - case TIME: + case TIME_WITH_TZ: + case TIME_WITHOUT_TZ: // TODO return this.isNull() ? PartiQL.timeValue(null) : PartiQL.timeValue(this.getTime()); - case TIMESTAMP: + case TIMESTAMP_WITH_TZ: + case TIMESTAMP_WITHOUT_TZ: return this.isNull() ? PartiQL.timestampValue(null) : PartiQL.timestampValue(this.getTimestamp()); - case INTERVAL: - return this.isNull() ? PartiQL.intervalValue(null) : PartiQL.intervalValue(this.getInterval()); case BAG: return this.isNull() ? PartiQL.bagValue((Iterable) null) : PartiQL.bagValue(new PQLToPartiQLIterable(this)); case LIST: @@ -396,14 +392,17 @@ default PartiQLValue toPartiQLValue() { case SEXP: return this.isNull() ? PartiQL.sexpValue((Iterable) null) : PartiQL.sexpValue(new PQLToPartiQLIterable(this)); case STRUCT: + case ROW: return this.isNull() ? PartiQL.structValue((Iterable>) null) : PartiQL.structValue(new PQLToPartiQLStruct(this)); - case NULL: // TODO: This will probably be deleted very soon due to the deprecation of NULL and MISSING types - return PartiQL.nullValue(); - case MISSING: // TODO: This will probably be deleted very soon due to the deprecation of NULL and MISSING types - return PartiQL.missingValue(); - case ANY: + case DYNAMIC: + case UNKNOWN: + if (this.isNull()) { + return PartiQL.nullValue(); + } else if (this.isMissing()) { + return PartiQL.missingValue(); + } default: - throw new UnsupportedOperationException(); + throw new UnsupportedOperationException("Unsupported datum type: " + type); } } @@ -417,7 +416,7 @@ default PartiQLValue toPartiQLValue() { static Datum of(PartiQLValue value) { PartiQLValueType type = value.getType(); if (value.isNull()) { - return new DatumNull(type); + return new DatumNull(PType.fromPartiQLValueType(type)); } switch (type) { case MISSING: @@ -426,13 +425,13 @@ static Datum of(PartiQLValue value) { return new DatumNull(); case INT8: org.partiql.value.Int8Value int8Value = (org.partiql.value.Int8Value) value; - return new DatumByte(Objects.requireNonNull(int8Value.getValue()), PartiQLValueType.INT8); + return new DatumByte(Objects.requireNonNull(int8Value.getValue()), PType.typeTinyInt()); case STRUCT: @SuppressWarnings("unchecked") org.partiql.value.StructValue STRUCTValue = (org.partiql.value.StructValue) value; return new DatumStruct(new PartiQLToPQLStruct(Objects.requireNonNull(STRUCTValue))); case STRING: org.partiql.value.StringValue STRINGValue = (org.partiql.value.StringValue) value; - return new DatumString(Objects.requireNonNull(STRINGValue.getValue()), PartiQLValueType.STRING); + return new DatumString(Objects.requireNonNull(STRINGValue.getValue()), PType.typeString()); case INT64: org.partiql.value.Int64Value INT64Value = (org.partiql.value.Int64Value) value; return new DatumLong(Objects.requireNonNull(INT64Value.getValue())); @@ -444,10 +443,10 @@ static Datum of(PartiQLValue value) { return new DatumShort(Objects.requireNonNull(INT16Value.getValue())); case SEXP: @SuppressWarnings("unchecked") org.partiql.value.SexpValue sexpValue = (org.partiql.value.SexpValue) value; - return new DatumCollection(new PartiQLToPQLIterable(Objects.requireNonNull(sexpValue)), PartiQLValueType.SEXP); + return new DatumCollection(new PartiQLToPQLIterable(Objects.requireNonNull(sexpValue)), PType.typeSexp()); case LIST: @SuppressWarnings("unchecked") org.partiql.value.ListValue LISTValue = (org.partiql.value.ListValue) value; - return new DatumCollection(new PartiQLToPQLIterable(Objects.requireNonNull(LISTValue)), PartiQLValueType.LIST); + return new DatumCollection(new PartiQLToPQLIterable(Objects.requireNonNull(LISTValue)), PType.typeList()); case BOOL: org.partiql.value.BoolValue BOOLValue = (org.partiql.value.BoolValue) value; return new DatumBoolean(Objects.requireNonNull(BOOLValue.getValue())); @@ -456,10 +455,9 @@ static Datum of(PartiQLValue value) { return new DatumBigInteger(Objects.requireNonNull(INTValue.getValue())); case BAG: @SuppressWarnings("unchecked") org.partiql.value.BagValue BAGValue = (org.partiql.value.BagValue) value; - return new DatumCollection(new PartiQLToPQLIterable(Objects.requireNonNull(BAGValue)), PartiQLValueType.BAG); + return new DatumCollection(new PartiQLToPQLIterable(Objects.requireNonNull(BAGValue)), PType.typeBag()); case BINARY: - org.partiql.value.BinaryValue BINARYValue = (org.partiql.value.BinaryValue) value; - return new DatumBytes(Objects.requireNonNull(Objects.requireNonNull(BINARYValue.getValue()).toByteArray()), PartiQLValueType.BINARY); + throw new UnsupportedOperationException(); case DATE: org.partiql.value.DateValue DATEValue = (org.partiql.value.DateValue) value; return new DatumDate(Objects.requireNonNull(DATEValue.getValue())); @@ -480,25 +478,25 @@ static Datum of(PartiQLValue value) { return new DatumDouble(Objects.requireNonNull(FLOAT64Value.getValue())); case DECIMAL: org.partiql.value.DecimalValue DECIMALValue = (org.partiql.value.DecimalValue) value; - return new DatumDecimal(Objects.requireNonNull(DECIMALValue.getValue()), PartiQLValueType.DECIMAL); + return new DatumDecimal(Objects.requireNonNull(DECIMALValue.getValue()), PType.typeDecimalArbitrary()); case CHAR: org.partiql.value.CharValue CHARValue = (org.partiql.value.CharValue) value; - return new DatumChars(Objects.requireNonNull(Objects.requireNonNull(CHARValue.getValue()).toString())); + String charString = Objects.requireNonNull(CHARValue.getValue()).toString(); + return new DatumChars(charString, charString.length()); case SYMBOL: org.partiql.value.SymbolValue SYMBOLValue = (org.partiql.value.SymbolValue) value; - return new DatumString(Objects.requireNonNull(SYMBOLValue.getValue()), PartiQLValueType.SYMBOL); + return new DatumString(Objects.requireNonNull(SYMBOLValue.getValue()), PType.typeSymbol()); case CLOB: org.partiql.value.ClobValue CLOBValue = (org.partiql.value.ClobValue) value; - return new DatumBytes(Objects.requireNonNull(CLOBValue.getValue()), PartiQLValueType.CLOB); + return new DatumBytes(Objects.requireNonNull(CLOBValue.getValue()), PType.typeClob(Integer.MAX_VALUE)); // TODO case BLOB: org.partiql.value.BlobValue BLOBValue = (org.partiql.value.BlobValue) value; - return new DatumBytes(Objects.requireNonNull(BLOBValue.getValue()), PartiQLValueType.BLOB); + return new DatumBytes(Objects.requireNonNull(BLOBValue.getValue()), PType.typeBlob(Integer.MAX_VALUE)); // TODO case BYTE: - org.partiql.value.ByteValue BYTEValue = (org.partiql.value.ByteValue) value; - return new DatumByte(Objects.requireNonNull(BYTEValue.getValue()), PartiQLValueType.BYTE); + throw new UnsupportedOperationException(); case DECIMAL_ARBITRARY: org.partiql.value.DecimalValue DECIMAL_ARBITRARYValue = (org.partiql.value.DecimalValue) value; - return new DatumDecimal(Objects.requireNonNull(DECIMAL_ARBITRARYValue.getValue()), PartiQLValueType.DECIMAL_ARBITRARY); + return new DatumDecimal(Objects.requireNonNull(DECIMAL_ARBITRARYValue.getValue()), PType.typeDecimalArbitrary()); case ANY: default: throw new NotImplementedError(); @@ -516,7 +514,7 @@ static Datum missingValue() { } @NotNull - static Datum nullValue(@NotNull PartiQLValueType type) { + static Datum nullValue(@NotNull PType type) { return new DatumNull(type); } @@ -529,13 +527,13 @@ static Datum nullValue(@NotNull PartiQLValueType type) { */ @Deprecated @NotNull - static Datum missingValue(@NotNull PartiQLValueType type) { + static Datum missingValue(@NotNull PType type) { return new DatumMissing(type); } @NotNull static Datum bagValue(@NotNull Iterable values) { - return new DatumCollection(values, PartiQLValueType.BAG); + return new DatumCollection(values, PType.typeBag()); } @NotNull @@ -555,12 +553,12 @@ static Datum boolValue(boolean value) { @NotNull static Datum sexpValue(@NotNull Iterable values) { - return new DatumCollection(values, PartiQLValueType.SEXP); + return new DatumCollection(values, PType.typeSexp()); } @NotNull static Datum listValue(@NotNull Iterable values) { - return new DatumCollection(values, PartiQLValueType.LIST); + return new DatumCollection(values, PType.typeList()); } @NotNull @@ -570,6 +568,6 @@ static Datum structValue(@NotNull Iterable values) { @NotNull static Datum stringValue(@NotNull String value) { - return new DatumString(value, PartiQLValueType.STRING); + return new DatumString(value, PType.typeString()); } } diff --git a/partiql-eval/src/main/java/org/partiql/eval/value/DatumBigInteger.java b/partiql-eval/src/main/java/org/partiql/eval/value/DatumBigInteger.java index f75c2b77db..e13ae333a8 100644 --- a/partiql-eval/src/main/java/org/partiql/eval/value/DatumBigInteger.java +++ b/partiql-eval/src/main/java/org/partiql/eval/value/DatumBigInteger.java @@ -1,7 +1,7 @@ package org.partiql.eval.value; import org.jetbrains.annotations.NotNull; -import org.partiql.value.PartiQLValueType; +import org.partiql.types.PType; import java.math.BigInteger; @@ -13,6 +13,8 @@ class DatumBigInteger implements Datum { @NotNull private final BigInteger _value; + private final static PType _type = PType.typeIntArbitrary(); + DatumBigInteger(@NotNull BigInteger value) { _value = value; } @@ -25,7 +27,7 @@ public BigInteger getBigInteger() { @NotNull @Override - public PartiQLValueType getType() { - return PartiQLValueType.INT; + public PType getType() { + return _type; } } diff --git a/partiql-eval/src/main/java/org/partiql/eval/value/DatumBoolean.java b/partiql-eval/src/main/java/org/partiql/eval/value/DatumBoolean.java index 2504abf9fb..ea67238373 100644 --- a/partiql-eval/src/main/java/org/partiql/eval/value/DatumBoolean.java +++ b/partiql-eval/src/main/java/org/partiql/eval/value/DatumBoolean.java @@ -1,7 +1,7 @@ package org.partiql.eval.value; import org.jetbrains.annotations.NotNull; -import org.partiql.value.PartiQLValueType; +import org.partiql.types.PType; /** * This shall always be package-private (internal). @@ -10,6 +10,8 @@ class DatumBoolean implements Datum { private final boolean _value; + private final static PType _type = PType.typeBool(); + DatumBoolean(boolean value) { _value = value; } @@ -21,7 +23,7 @@ public boolean getBoolean() { @NotNull @Override - public PartiQLValueType getType() { - return PartiQLValueType.BOOL; + public PType getType() { + return _type; } } diff --git a/partiql-eval/src/main/java/org/partiql/eval/value/DatumByte.java b/partiql-eval/src/main/java/org/partiql/eval/value/DatumByte.java index 74f2ae231e..a689aab9af 100644 --- a/partiql-eval/src/main/java/org/partiql/eval/value/DatumByte.java +++ b/partiql-eval/src/main/java/org/partiql/eval/value/DatumByte.java @@ -1,24 +1,22 @@ package org.partiql.eval.value; import org.jetbrains.annotations.NotNull; -import org.partiql.value.PartiQLValueType; +import org.partiql.types.PType; /** * This shall always be package-private (internal). *

* This is specifically for: - * {@link PartiQLValueType#BYTE}, - * {@link PartiQLValueType#INT8} + * {@link PType.Kind#TINYINT} */ class DatumByte implements Datum { private final byte _value; @NotNull - private final PartiQLValueType _type; + private final PType _type; - DatumByte(byte value, @NotNull PartiQLValueType type) { - assert(type == PartiQLValueType.BYTE || type == PartiQLValueType.INT8); + DatumByte(byte value, @NotNull PType type) { _value = value; _type = type; } @@ -30,7 +28,7 @@ public byte getByte() { @NotNull @Override - public PartiQLValueType getType() { + public PType getType() { return _type; } } diff --git a/partiql-eval/src/main/java/org/partiql/eval/value/DatumBytes.java b/partiql-eval/src/main/java/org/partiql/eval/value/DatumBytes.java index 695320acfc..153831668e 100644 --- a/partiql-eval/src/main/java/org/partiql/eval/value/DatumBytes.java +++ b/partiql-eval/src/main/java/org/partiql/eval/value/DatumBytes.java @@ -1,15 +1,14 @@ package org.partiql.eval.value; import org.jetbrains.annotations.NotNull; -import org.partiql.value.PartiQLValueType; +import org.partiql.types.PType; /** * This shall always be package-private (internal). *

* This is specifically for: - * {@link PartiQLValueType#BINARY}, - * {@link PartiQLValueType#BLOB}, - * {@link PartiQLValueType#CLOB} + * {@link PType.Kind#BLOB}, + * {@link PType.Kind#CLOB} */ class DatumBytes implements Datum { @@ -17,10 +16,9 @@ class DatumBytes implements Datum { private final byte[] _value; @NotNull - private final PartiQLValueType _type; + private final PType _type; - DatumBytes(@NotNull byte[] value, @NotNull PartiQLValueType type) { - assert(type == PartiQLValueType.BINARY || type == PartiQLValueType.BLOB || type == PartiQLValueType.CLOB); + DatumBytes(@NotNull byte[] value, @NotNull PType type) { _value = value; _type = type; } @@ -33,7 +31,7 @@ public byte[] getBytes() { @NotNull @Override - public PartiQLValueType getType() { + public PType getType() { return _type; } } diff --git a/partiql-eval/src/main/java/org/partiql/eval/value/DatumChars.java b/partiql-eval/src/main/java/org/partiql/eval/value/DatumChars.java index c156237f65..22a6dfa72c 100644 --- a/partiql-eval/src/main/java/org/partiql/eval/value/DatumChars.java +++ b/partiql-eval/src/main/java/org/partiql/eval/value/DatumChars.java @@ -1,7 +1,7 @@ package org.partiql.eval.value; import org.jetbrains.annotations.NotNull; -import org.partiql.value.PartiQLValueType; +import org.partiql.types.PType; /** * This shall always be package-private (internal). @@ -11,8 +11,12 @@ class DatumChars implements Datum { @NotNull private final String _value; - DatumChars(@NotNull String value) { + @NotNull + private final PType _type; + + DatumChars(@NotNull String value, int length) { _value = value; + _type = PType.typeChar(length); } @Override @@ -23,7 +27,7 @@ public String getString() { @NotNull @Override - public PartiQLValueType getType() { - return PartiQLValueType.CHAR; + public PType getType() { + return _type; } } diff --git a/partiql-eval/src/main/java/org/partiql/eval/value/DatumCollection.java b/partiql-eval/src/main/java/org/partiql/eval/value/DatumCollection.java index fd4efe1a77..2fffbbd51b 100644 --- a/partiql-eval/src/main/java/org/partiql/eval/value/DatumCollection.java +++ b/partiql-eval/src/main/java/org/partiql/eval/value/DatumCollection.java @@ -1,7 +1,7 @@ package org.partiql.eval.value; import org.jetbrains.annotations.NotNull; -import org.partiql.value.PartiQLValueType; +import org.partiql.types.PType; import java.util.Iterator; @@ -9,9 +9,9 @@ * This shall always be package-private (internal). *

* This is specifically for: - * {@link PartiQLValueType#LIST}, - * {@link PartiQLValueType#BAG}, - * {@link PartiQLValueType#SEXP} + * {@link PType.Kind#LIST}, + * {@link PType.Kind#BAG}, + * {@link PType.Kind#SEXP} */ class DatumCollection implements Datum { @@ -19,10 +19,9 @@ class DatumCollection implements Datum { private final Iterable _value; @NotNull - private final PartiQLValueType _type; + private final PType _type; - DatumCollection(@NotNull Iterable value, @NotNull PartiQLValueType type) { - assert(type == PartiQLValueType.LIST || type == PartiQLValueType.BAG || type == PartiQLValueType.SEXP); + DatumCollection(@NotNull Iterable value, @NotNull PType type) { _value = value; _type = type; } @@ -34,7 +33,7 @@ public Iterator iterator() { @NotNull @Override - public PartiQLValueType getType() { + public PType getType() { return _type; } } diff --git a/partiql-eval/src/main/java/org/partiql/eval/value/DatumDate.java b/partiql-eval/src/main/java/org/partiql/eval/value/DatumDate.java index 8323779d52..ff9355c718 100644 --- a/partiql-eval/src/main/java/org/partiql/eval/value/DatumDate.java +++ b/partiql-eval/src/main/java/org/partiql/eval/value/DatumDate.java @@ -1,7 +1,7 @@ package org.partiql.eval.value; import org.jetbrains.annotations.NotNull; -import org.partiql.value.PartiQLValueType; +import org.partiql.types.PType; /** * This shall always be package-private (internal). @@ -11,6 +11,8 @@ class DatumDate implements Datum { @NotNull private final org.partiql.value.datetime.Date _value; + private static final PType _type = PType.typeDate(); + DatumDate(@NotNull org.partiql.value.datetime.Date value) { _value = value; } @@ -28,7 +30,7 @@ public org.partiql.value.datetime.Date getDate() { @NotNull @Override - public PartiQLValueType getType() { - return PartiQLValueType.DATE; + public PType getType() { + return _type; } } diff --git a/partiql-eval/src/main/java/org/partiql/eval/value/DatumDecimal.java b/partiql-eval/src/main/java/org/partiql/eval/value/DatumDecimal.java index 50e589b05f..16bd4983c1 100644 --- a/partiql-eval/src/main/java/org/partiql/eval/value/DatumDecimal.java +++ b/partiql-eval/src/main/java/org/partiql/eval/value/DatumDecimal.java @@ -1,7 +1,7 @@ package org.partiql.eval.value; import org.jetbrains.annotations.NotNull; -import org.partiql.value.PartiQLValueType; +import org.partiql.types.PType; import java.math.BigDecimal; @@ -9,8 +9,8 @@ * This shall always be package-private (internal). *

* This is specifically for: - * {@link PartiQLValueType#DECIMAL}, - * {@link PartiQLValueType#DECIMAL_ARBITRARY} + * {@link PType.Kind#DECIMAL}, + * {@link PType.Kind#DECIMAL_ARBITRARY} */ class DatumDecimal implements Datum { @@ -18,10 +18,9 @@ class DatumDecimal implements Datum { private final BigDecimal _value; @NotNull - private final PartiQLValueType _type; + private final PType _type; - DatumDecimal(@NotNull BigDecimal value, @NotNull PartiQLValueType type) { - assert(type == PartiQLValueType.DECIMAL || type == PartiQLValueType.DECIMAL_ARBITRARY); + DatumDecimal(@NotNull BigDecimal value, @NotNull PType type) { _value = value; _type = type; } @@ -34,7 +33,7 @@ public BigDecimal getBigDecimal() { @NotNull @Override - public PartiQLValueType getType() { + public PType getType() { return _type; } } diff --git a/partiql-eval/src/main/java/org/partiql/eval/value/DatumDouble.java b/partiql-eval/src/main/java/org/partiql/eval/value/DatumDouble.java index 6a9e9a362e..c78598d323 100644 --- a/partiql-eval/src/main/java/org/partiql/eval/value/DatumDouble.java +++ b/partiql-eval/src/main/java/org/partiql/eval/value/DatumDouble.java @@ -1,7 +1,7 @@ package org.partiql.eval.value; import org.jetbrains.annotations.NotNull; -import org.partiql.value.PartiQLValueType; +import org.partiql.types.PType; /** * This shall always be package-private (internal). @@ -9,6 +9,7 @@ class DatumDouble implements Datum { private final double _value; + private final static PType _type = PType.typeDoublePrecision(); DatumDouble(double value) { _value = value; @@ -21,7 +22,7 @@ public double getDouble() { @NotNull @Override - public PartiQLValueType getType() { - return PartiQLValueType.FLOAT64; + public PType getType() { + return _type; } } diff --git a/partiql-eval/src/main/java/org/partiql/eval/value/DatumFloat.java b/partiql-eval/src/main/java/org/partiql/eval/value/DatumFloat.java index e3d0d10c54..4e2f4911f0 100644 --- a/partiql-eval/src/main/java/org/partiql/eval/value/DatumFloat.java +++ b/partiql-eval/src/main/java/org/partiql/eval/value/DatumFloat.java @@ -1,7 +1,7 @@ package org.partiql.eval.value; import org.jetbrains.annotations.NotNull; -import org.partiql.value.PartiQLValueType; +import org.partiql.types.PType; /** * This shall always be package-private (internal). @@ -10,6 +10,8 @@ class DatumFloat implements Datum { private final float _value; + private final static PType _type = PType.typeReal(); + DatumFloat(float value) { _value = value; } @@ -21,7 +23,7 @@ public float getFloat() { @NotNull @Override - public PartiQLValueType getType() { - return PartiQLValueType.FLOAT32; + public PType getType() { + return _type; } } diff --git a/partiql-eval/src/main/java/org/partiql/eval/value/DatumInt.java b/partiql-eval/src/main/java/org/partiql/eval/value/DatumInt.java index b53378daf3..2f70fbb83f 100644 --- a/partiql-eval/src/main/java/org/partiql/eval/value/DatumInt.java +++ b/partiql-eval/src/main/java/org/partiql/eval/value/DatumInt.java @@ -1,7 +1,7 @@ package org.partiql.eval.value; import org.jetbrains.annotations.NotNull; -import org.partiql.value.PartiQLValueType; +import org.partiql.types.PType; /** * This shall always be package-private (internal). @@ -9,6 +9,9 @@ class DatumInt implements Datum { private final int _value; + + private final static PType _type = PType.typeInt(); + DatumInt(int value) { _value = value; } @@ -20,7 +23,7 @@ public int getInt() { @NotNull @Override - public PartiQLValueType getType() { - return PartiQLValueType.INT32; + public PType getType() { + return _type; } } diff --git a/partiql-eval/src/main/java/org/partiql/eval/value/DatumInterval.java b/partiql-eval/src/main/java/org/partiql/eval/value/DatumInterval.java index 124cbe6530..2eb96ea909 100644 --- a/partiql-eval/src/main/java/org/partiql/eval/value/DatumInterval.java +++ b/partiql-eval/src/main/java/org/partiql/eval/value/DatumInterval.java @@ -1,7 +1,7 @@ package org.partiql.eval.value; import org.jetbrains.annotations.NotNull; -import org.partiql.value.PartiQLValueType; +import org.partiql.types.PType; /** * This shall always be package-private (internal). @@ -21,7 +21,7 @@ public long getInterval() { @NotNull @Override - public PartiQLValueType getType() { - return PartiQLValueType.INTERVAL; + public PType getType() { + throw new UnsupportedOperationException("NOT YET IMPLEMENTED"); } } diff --git a/partiql-eval/src/main/java/org/partiql/eval/value/DatumLong.java b/partiql-eval/src/main/java/org/partiql/eval/value/DatumLong.java index 52fa345cff..f85665ca60 100644 --- a/partiql-eval/src/main/java/org/partiql/eval/value/DatumLong.java +++ b/partiql-eval/src/main/java/org/partiql/eval/value/DatumLong.java @@ -1,7 +1,7 @@ package org.partiql.eval.value; import org.jetbrains.annotations.NotNull; -import org.partiql.value.PartiQLValueType; +import org.partiql.types.PType; /** * This shall always be package-private (internal). @@ -10,6 +10,8 @@ class DatumLong implements Datum { private final long _value; + private final static PType _type = PType.typeBigInt(); + DatumLong(long value) { _value = value; } @@ -21,7 +23,7 @@ public long getLong() { @NotNull @Override - public PartiQLValueType getType() { - return PartiQLValueType.INT64; + public PType getType() { + return _type; } } diff --git a/partiql-eval/src/main/java/org/partiql/eval/value/DatumMissing.java b/partiql-eval/src/main/java/org/partiql/eval/value/DatumMissing.java index 5afed35ddf..fb167200fe 100644 --- a/partiql-eval/src/main/java/org/partiql/eval/value/DatumMissing.java +++ b/partiql-eval/src/main/java/org/partiql/eval/value/DatumMissing.java @@ -1,7 +1,7 @@ package org.partiql.eval.value; import org.jetbrains.annotations.NotNull; -import org.partiql.value.PartiQLValueType; +import org.partiql.types.PType; /** * This shall always be package-private (internal). @@ -9,14 +9,13 @@ class DatumMissing implements Datum { @NotNull - private final PartiQLValueType _type; + private final PType _type; DatumMissing() { - // TODO: This will likely be UNKNOWN in the future. Potentially something like PostgreSQL's unknown type. - _type = PartiQLValueType.MISSING; + _type = PType.typeUnknown(); } - DatumMissing(@NotNull PartiQLValueType type) { + DatumMissing(@NotNull PType type) { _type = type; } @@ -27,7 +26,7 @@ public boolean isMissing() { @NotNull @Override - public PartiQLValueType getType() { + public PType getType() { return _type; } } diff --git a/partiql-eval/src/main/java/org/partiql/eval/value/DatumNull.java b/partiql-eval/src/main/java/org/partiql/eval/value/DatumNull.java index 1262aafbf3..8cbed42d40 100644 --- a/partiql-eval/src/main/java/org/partiql/eval/value/DatumNull.java +++ b/partiql-eval/src/main/java/org/partiql/eval/value/DatumNull.java @@ -1,7 +1,7 @@ package org.partiql.eval.value; import org.jetbrains.annotations.NotNull; -import org.partiql.value.PartiQLValueType; +import org.partiql.types.PType; import org.partiql.value.datetime.Date; import org.partiql.value.datetime.Time; import org.partiql.value.datetime.Timestamp; @@ -16,13 +16,13 @@ class DatumNull implements Datum { @NotNull - private final PartiQLValueType _type; + private final PType _type; DatumNull() { - this._type = PartiQLValueType.NULL; // TODO: This might eventually be UNKNOWN like PostgreSQL's unknown type. + this._type = PType.typeUnknown(); } - DatumNull(@NotNull PartiQLValueType type) { + DatumNull(@NotNull PType type) { this._type = type; } @@ -33,13 +33,13 @@ public boolean isNull() { @NotNull @Override - public PartiQLValueType getType() { + public PType getType() { return _type; } @Override public boolean getBoolean() { - if (_type == PartiQLValueType.BOOL) { + if (_type.getKind() == PType.Kind.BOOL) { throw new NullPointerException(); } else { throw new UnsupportedOperationException(); @@ -48,7 +48,7 @@ public boolean getBoolean() { @Override public short getShort() { - if (_type == PartiQLValueType.INT16) { + if (_type.getKind() == PType.Kind.SMALLINT) { throw new NullPointerException(); } else { throw new UnsupportedOperationException(); @@ -57,7 +57,7 @@ public short getShort() { @Override public int getInt() { - if (_type == PartiQLValueType.INT32) { + if (_type.getKind() == PType.Kind.INT) { throw new NullPointerException(); } else { throw new UnsupportedOperationException(); @@ -66,7 +66,7 @@ public int getInt() { @Override public long getLong() { - if (_type == PartiQLValueType.INT64) { + if (_type.getKind() == PType.Kind.BIGINT) { throw new NullPointerException(); } else { throw new UnsupportedOperationException(); @@ -76,7 +76,7 @@ public long getLong() { @NotNull @Override public BigInteger getBigInteger() { - if (_type == PartiQLValueType.INT) { + if (_type.getKind() == PType.Kind.INT_ARBITRARY) { throw new NullPointerException(); } else { throw new UnsupportedOperationException(); @@ -86,7 +86,7 @@ public BigInteger getBigInteger() { @NotNull @Override public BigDecimal getBigDecimal() { - if (_type == PartiQLValueType.DECIMAL || _type == PartiQLValueType.DECIMAL_ARBITRARY) { + if (_type.getKind() == PType.Kind.DECIMAL || _type.getKind() == PType.Kind.DECIMAL_ARBITRARY) { throw new NullPointerException(); } else { throw new UnsupportedOperationException(); @@ -95,7 +95,7 @@ public BigDecimal getBigDecimal() { @Override public byte getByte() { - if (_type == PartiQLValueType.BYTE || _type == PartiQLValueType.INT8) { + if (_type.getKind() == PType.Kind.TINYINT) { throw new NullPointerException(); } else { throw new UnsupportedOperationException(); @@ -105,7 +105,7 @@ public byte getByte() { @NotNull @Override public byte[] getBytes() { - if (_type == PartiQLValueType.BINARY || _type == PartiQLValueType.BLOB || _type == PartiQLValueType.CLOB) { + if (_type.getKind() == PType.Kind.BLOB || _type.getKind() == PType.Kind.CLOB) { throw new NullPointerException(); } else { throw new UnsupportedOperationException(); @@ -115,7 +115,7 @@ public byte[] getBytes() { @NotNull @Override public Date getDate() { - if (_type == PartiQLValueType.DATE) { + if (_type.getKind() == PType.Kind.DATE) { throw new NullPointerException(); } else { throw new UnsupportedOperationException(); @@ -124,7 +124,7 @@ public Date getDate() { @Override public double getDouble() { - if (_type == PartiQLValueType.FLOAT64) { + if (_type.getKind() == PType.Kind.DOUBLE_PRECISION) { throw new NullPointerException(); } else { throw new UnsupportedOperationException(); @@ -133,7 +133,7 @@ public double getDouble() { @Override public float getFloat() { - if (_type == PartiQLValueType.FLOAT32) { + if (_type.getKind() == PType.Kind.REAL) { throw new NullPointerException(); } else { throw new UnsupportedOperationException(); @@ -142,7 +142,7 @@ public float getFloat() { @Override public Iterator iterator() { - if (_type == PartiQLValueType.BAG || _type == PartiQLValueType.LIST || _type == PartiQLValueType.SEXP) { + if (_type.getKind() == PType.Kind.BAG || _type.getKind() == PType.Kind.LIST || _type.getKind() == PType.Kind.SEXP) { throw new NullPointerException(); } else { throw new UnsupportedOperationException(); @@ -152,7 +152,7 @@ public Iterator iterator() { @NotNull @Override public Iterator getFields() { - if (_type == PartiQLValueType.STRUCT) { + if (_type.getKind() == PType.Kind.STRUCT || _type.getKind() == PType.Kind.ROW) { throw new NullPointerException(); } else { throw new UnsupportedOperationException(); @@ -162,7 +162,7 @@ public Iterator getFields() { @NotNull @Override public String getString() { - if (_type == PartiQLValueType.STRING || _type == PartiQLValueType.CHAR || _type == PartiQLValueType.SYMBOL) { + if (_type.getKind() == PType.Kind.STRING || _type.getKind() == PType.Kind.CHAR || _type.getKind() == PType.Kind.SYMBOL) { throw new NullPointerException(); } else { throw new UnsupportedOperationException(); @@ -172,7 +172,7 @@ public String getString() { @NotNull @Override public Time getTime() { - if (_type == PartiQLValueType.TIME) { + if (_type.getKind() == PType.Kind.TIME_WITH_TZ || _type.getKind() == PType.Kind.TIME_WITHOUT_TZ) { throw new NullPointerException(); } else { throw new UnsupportedOperationException(); @@ -182,7 +182,7 @@ public Time getTime() { @NotNull @Override public Timestamp getTimestamp() { - if (_type == PartiQLValueType.TIMESTAMP) { + if (_type.getKind() == PType.Kind.TIMESTAMP_WITH_TZ || _type.getKind() == PType.Kind.TIMESTAMP_WITHOUT_TZ) { throw new NullPointerException(); } else { throw new UnsupportedOperationException(); @@ -191,10 +191,6 @@ public Timestamp getTimestamp() { @Override public long getInterval() { - if (_type == PartiQLValueType.INTERVAL) { - throw new NullPointerException(); - } else { - throw new UnsupportedOperationException(); - } + throw new UnsupportedOperationException(); } } diff --git a/partiql-eval/src/main/java/org/partiql/eval/value/DatumShort.java b/partiql-eval/src/main/java/org/partiql/eval/value/DatumShort.java index a9f97c5ae7..3e7be73845 100644 --- a/partiql-eval/src/main/java/org/partiql/eval/value/DatumShort.java +++ b/partiql-eval/src/main/java/org/partiql/eval/value/DatumShort.java @@ -1,7 +1,7 @@ package org.partiql.eval.value; import org.jetbrains.annotations.NotNull; -import org.partiql.value.PartiQLValueType; +import org.partiql.types.PType; /** * This shall always be package-private (internal). @@ -10,6 +10,8 @@ class DatumShort implements Datum { private final short _value; + private final static PType _type = PType.typeSmallInt(); + DatumShort(short value) { _value = value; } @@ -21,7 +23,7 @@ public short getShort() { @NotNull @Override - public PartiQLValueType getType() { - return PartiQLValueType.INT16; + public PType getType() { + return _type; } } diff --git a/partiql-eval/src/main/java/org/partiql/eval/value/DatumString.java b/partiql-eval/src/main/java/org/partiql/eval/value/DatumString.java index d4e24f5cfb..a0318da380 100644 --- a/partiql-eval/src/main/java/org/partiql/eval/value/DatumString.java +++ b/partiql-eval/src/main/java/org/partiql/eval/value/DatumString.java @@ -1,7 +1,7 @@ package org.partiql.eval.value; import org.jetbrains.annotations.NotNull; -import org.partiql.value.PartiQLValueType; +import org.partiql.types.PType; /** * This shall always be package-private (internal). @@ -12,10 +12,9 @@ class DatumString implements Datum { private final String _value; @NotNull - private final PartiQLValueType _type; + private final PType _type; - DatumString(@NotNull String value, @NotNull PartiQLValueType type) { - assert(type == PartiQLValueType.STRING || type == PartiQLValueType.SYMBOL); + DatumString(@NotNull String value, @NotNull PType type) { _value = value; _type = type; } @@ -28,7 +27,7 @@ public String getString() { @NotNull @Override - public PartiQLValueType getType() { + public PType getType() { return _type; } } diff --git a/partiql-eval/src/main/java/org/partiql/eval/value/DatumStruct.java b/partiql-eval/src/main/java/org/partiql/eval/value/DatumStruct.java index b010653dcf..07b18dcdba 100644 --- a/partiql-eval/src/main/java/org/partiql/eval/value/DatumStruct.java +++ b/partiql-eval/src/main/java/org/partiql/eval/value/DatumStruct.java @@ -1,7 +1,7 @@ package org.partiql.eval.value; import org.jetbrains.annotations.NotNull; -import org.partiql.value.PartiQLValueType; +import org.partiql.types.PType; import java.util.ArrayList; import java.util.HashMap; @@ -20,6 +20,8 @@ class DatumStruct implements Datum { @NotNull private final Map> _delegateNormalized; + private final static PType _type = PType.typeStruct(); + DatumStruct(@NotNull Iterable fields) { _delegate = new HashMap<>(); _delegateNormalized = new HashMap<>(); @@ -70,7 +72,7 @@ public Datum getInsensitive(@NotNull String name) { @NotNull @Override - public PartiQLValueType getType() { - return PartiQLValueType.STRUCT; + public PType getType() { + return _type; } } diff --git a/partiql-eval/src/main/java/org/partiql/eval/value/DatumTime.java b/partiql-eval/src/main/java/org/partiql/eval/value/DatumTime.java index b9f9083eb9..81fa4ab001 100644 --- a/partiql-eval/src/main/java/org/partiql/eval/value/DatumTime.java +++ b/partiql-eval/src/main/java/org/partiql/eval/value/DatumTime.java @@ -1,7 +1,7 @@ package org.partiql.eval.value; import org.jetbrains.annotations.NotNull; -import org.partiql.value.PartiQLValueType; +import org.partiql.types.PType; import org.partiql.value.datetime.Time; /** @@ -12,6 +12,10 @@ class DatumTime implements Datum { @NotNull private final Time _value; + // TODO: Pass precision to constructor. + // TODO: Create a variant specifically for without TZ + private final static PType _type = PType.typeTimeWithTZ(6); + DatumTime(@NotNull Time value) { _value = value; } @@ -24,7 +28,7 @@ public Time getTime() { @NotNull @Override - public PartiQLValueType getType() { - return PartiQLValueType.TIME; + public PType getType() { + return _type; } } diff --git a/partiql-eval/src/main/java/org/partiql/eval/value/DatumTimestamp.java b/partiql-eval/src/main/java/org/partiql/eval/value/DatumTimestamp.java index cf39ab64e2..5063668e3b 100644 --- a/partiql-eval/src/main/java/org/partiql/eval/value/DatumTimestamp.java +++ b/partiql-eval/src/main/java/org/partiql/eval/value/DatumTimestamp.java @@ -1,7 +1,7 @@ package org.partiql.eval.value; import org.jetbrains.annotations.NotNull; -import org.partiql.value.PartiQLValueType; +import org.partiql.types.PType; import org.partiql.value.datetime.Timestamp; /** @@ -12,6 +12,10 @@ class DatumTimestamp implements Datum { @NotNull private final Timestamp _value; + // TODO: Pass precision to constructor. + // TODO: Create a variant specifically for without TZ + private final static PType _type = PType.typeTimeWithTZ(6); + DatumTimestamp(@NotNull Timestamp value) { _value = value; } @@ -24,7 +28,7 @@ public Timestamp getTimestamp() { @NotNull @Override - public PartiQLValueType getType() { - return PartiQLValueType.TIMESTAMP; + public PType getType() { + return _type; } } diff --git a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/Compiler.kt b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/Compiler.kt index 8297e18914..8c95fb0784 100644 --- a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/Compiler.kt +++ b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/Compiler.kt @@ -59,26 +59,25 @@ import org.partiql.plan.rexOpErr import org.partiql.plan.visitor.PlanBaseVisitor import org.partiql.spi.fn.Agg import org.partiql.spi.fn.FnExperimental -import org.partiql.types.StaticType +import org.partiql.types.PType import org.partiql.value.PartiQLValueExperimental -import org.partiql.value.PartiQLValueType import java.lang.IllegalStateException internal class Compiler( private val plan: PartiQLPlan, private val session: PartiQLEngine.Session, private val symbols: Symbols, -) : PlanBaseVisitor() { +) : PlanBaseVisitor() { fun compile(): Operator.Expr { return visitPartiQLPlan(plan, null) } - override fun defaultReturn(node: PlanNode, ctx: StaticType?): Operator { + override fun defaultReturn(node: PlanNode, ctx: PType?): Operator { TODO("Not yet implemented") } - override fun visitRexOpErr(node: Rex.Op.Err, ctx: StaticType?): Operator { + override fun visitRexOpErr(node: Rex.Op.Err, ctx: PType?): Operator { val message = buildString { this.appendLine(node.message) PlanPrinter.append(this, plan) @@ -86,31 +85,31 @@ internal class Compiler( throw IllegalStateException(message) } - override fun visitRelOpErr(node: Rel.Op.Err, ctx: StaticType?): Operator { + override fun visitRelOpErr(node: Rel.Op.Err, ctx: PType?): Operator { throw IllegalStateException(node.message) } - override fun visitPartiQLPlan(node: PartiQLPlan, ctx: StaticType?): Operator.Expr { + override fun visitPartiQLPlan(node: PartiQLPlan, ctx: PType?): Operator.Expr { return visitStatement(node.statement, ctx) as Operator.Expr } - override fun visitStatementQuery(node: Statement.Query, ctx: StaticType?): Operator.Expr { + override fun visitStatementQuery(node: Statement.Query, ctx: PType?): Operator.Expr { return visitRex(node.root, ctx).modeHandled() } // REX - override fun visitRex(node: Rex, ctx: StaticType?): Operator.Expr { + override fun visitRex(node: Rex, ctx: PType?): Operator.Expr { return super.visitRexOp(node.op, node.type) as Operator.Expr } - override fun visitRexOpCollection(node: Rex.Op.Collection, ctx: StaticType?): Operator { + override fun visitRexOpCollection(node: Rex.Op.Collection, ctx: PType?): Operator { val values = node.values.map { visitRex(it, ctx).modeHandled() } val type = ctx ?: error("No type provided in ctx") return ExprCollection(values, type) } - override fun visitRexOpStruct(node: Rex.Op.Struct, ctx: StaticType?): Operator { + override fun visitRexOpStruct(node: Rex.Op.Struct, ctx: PType?): Operator { val fields = node.fields.map { val value = visitRex(it.v, ctx).modeHandled() ExprStruct.Field(visitRex(it.k, ctx), value) @@ -118,14 +117,14 @@ internal class Compiler( return ExprStruct(fields) } - override fun visitRexOpSelect(node: Rex.Op.Select, ctx: StaticType?): Operator { + override fun visitRexOpSelect(node: Rex.Op.Select, ctx: PType?): Operator { val rel = visitRel(node.rel, ctx) val ordered = node.rel.type.props.contains(Rel.Prop.ORDERED) val constructor = visitRex(node.constructor, ctx).modeHandled() return ExprSelect(rel, constructor, ordered) } - override fun visitRexOpSubquery(node: Rex.Op.Subquery, ctx: StaticType?): Operator { + override fun visitRexOpSubquery(node: Rex.Op.Subquery, ctx: PType?): Operator { val constructor = visitRex(node.constructor, ctx) val input = visitRel(node.rel, ctx) return when (node.coercion) { @@ -134,7 +133,7 @@ internal class Compiler( } } - override fun visitRexOpPivot(node: Rex.Op.Pivot, ctx: StaticType?): Operator { + override fun visitRexOpPivot(node: Rex.Op.Pivot, ctx: PType?): Operator { val rel = visitRel(node.rel, ctx) val key = visitRex(node.key, ctx) val value = visitRex(node.value, ctx) @@ -144,12 +143,12 @@ internal class Compiler( } } - override fun visitRexOpCoalesce(node: Rex.Op.Coalesce, ctx: StaticType?): Operator { + override fun visitRexOpCoalesce(node: Rex.Op.Coalesce, ctx: PType?): Operator { val args = Array(node.args.size) { visitRex(node.args[it], node.args[it].type) } return ExprCoalesce(args) } - override fun visitRexOpNullif(node: Rex.Op.Nullif, ctx: StaticType?): Operator { + override fun visitRexOpNullif(node: Rex.Op.Nullif, ctx: PType?): Operator { val value = visitRex(node.value, node.value.type) val nullifier = visitRex(node.nullifier, node.value.type) return ExprNullIf(value, nullifier) @@ -161,7 +160,7 @@ internal class Compiler( * All variables coming from the stack have a depth > 0. To slightly minimize computation at execution, we subtract * the depth by 1 to account for the fact that the local scope is not kept on the stack. */ - override fun visitRexOpVar(node: Rex.Op.Var, ctx: StaticType?): Operator { + override fun visitRexOpVar(node: Rex.Op.Var, ctx: PType?): Operator { return when (node.depth) { 0 -> ExprVarLocal(node.ref) else -> { @@ -170,9 +169,9 @@ internal class Compiler( } } - override fun visitRexOpGlobal(node: Rex.Op.Global, ctx: StaticType?): Operator = symbols.getGlobal(node.ref) + override fun visitRexOpGlobal(node: Rex.Op.Global, ctx: PType?): Operator = symbols.getGlobal(node.ref) - override fun visitRelOpAggregate(node: Rel.Op.Aggregate, ctx: StaticType?): Operator.Relation { + override fun visitRelOpAggregate(node: Rel.Op.Aggregate, ctx: PType?): Operator.Relation { val input = visitRel(node.input, ctx) val calls = node.calls.map { visitRelOpAggregateCall(it, ctx) @@ -182,7 +181,7 @@ internal class Compiler( } @OptIn(FnExperimental::class) - override fun visitRelOpAggregateCall(node: Rel.Op.Aggregate.Call, ctx: StaticType?): Operator.Aggregation { + override fun visitRelOpAggregateCall(node: Rel.Op.Aggregate.Call, ctx: PType?): Operator.Aggregation { val args = node.args.map { visitRex(it, it.type).modeHandled() } val setQuantifier: Operator.Aggregation.SetQuantifier = when (node.setQuantifier) { Rel.Op.Aggregate.Call.SetQuantifier.ALL -> Operator.Aggregation.SetQuantifier.ALL @@ -196,30 +195,30 @@ internal class Compiler( } } - override fun visitRexOpPathKey(node: Rex.Op.Path.Key, ctx: StaticType?): Operator { + override fun visitRexOpPathKey(node: Rex.Op.Path.Key, ctx: PType?): Operator { val root = visitRex(node.root, ctx) val key = visitRex(node.key, ctx) return ExprPathKey(root, key) } - override fun visitRexOpPathSymbol(node: Rex.Op.Path.Symbol, ctx: StaticType?): Operator { + override fun visitRexOpPathSymbol(node: Rex.Op.Path.Symbol, ctx: PType?): Operator { val root = visitRex(node.root, ctx) val symbol = node.key return ExprPathSymbol(root, symbol) } - override fun visitRexOpPathIndex(node: Rex.Op.Path.Index, ctx: StaticType?): Operator { + override fun visitRexOpPathIndex(node: Rex.Op.Path.Index, ctx: PType?): Operator { val root = visitRex(node.root, ctx) val index = visitRex(node.key, ctx) return ExprPathIndex(root, index) } - @OptIn(FnExperimental::class, PartiQLValueExperimental::class) - override fun visitRexOpCallStatic(node: Rex.Op.Call.Static, ctx: StaticType?): Operator { + @OptIn(FnExperimental::class) + override fun visitRexOpCallStatic(node: Rex.Op.Call.Static, ctx: PType?): Operator { val fn = symbols.getFn(node.fn) val args = node.args.map { visitRex(it, ctx) }.toTypedArray() val fnTakesInMissing = fn.signature.parameters.any { - it.type == PartiQLValueType.MISSING || it.type == PartiQLValueType.ANY + it.type.kind == PType.Kind.DYNAMIC // TODO: Is this needed? } return when (fnTakesInMissing) { true -> ExprCallStatic(fn, args.map { it.modeHandled() }.toTypedArray()) @@ -228,7 +227,7 @@ internal class Compiler( } @OptIn(FnExperimental::class) - override fun visitRexOpCallDynamic(node: Rex.Op.Call.Dynamic, ctx: StaticType?): Operator { + override fun visitRexOpCallDynamic(node: Rex.Op.Call.Dynamic, ctx: PType?): Operator { val args = node.args.map { visitRex(it, ctx).modeHandled() }.toTypedArray() // Check candidate list size when (node.candidates.size) { @@ -263,11 +262,11 @@ internal class Compiler( return ExprCallDynamic(name, candidates, args) } - override fun visitRexOpCast(node: Rex.Op.Cast, ctx: StaticType?): Operator { + override fun visitRexOpCast(node: Rex.Op.Cast, ctx: PType?): Operator { return ExprCast(visitRex(node.arg, ctx), node.cast) } - override fun visitRexOpMissing(node: Rex.Op.Missing, ctx: StaticType?): Operator { + override fun visitRexOpMissing(node: Rex.Op.Missing, ctx: PType?): Operator { return when (session.mode) { PartiQLEngine.Mode.PERMISSIVE -> { // Make a runtime TypeCheckException. @@ -281,11 +280,11 @@ internal class Compiler( } // REL - override fun visitRel(node: Rel, ctx: StaticType?): Operator.Relation { + override fun visitRel(node: Rel, ctx: PType?): Operator.Relation { return super.visitRelOp(node.op, ctx) as Operator.Relation } - override fun visitRelOpScan(node: Rel.Op.Scan, ctx: StaticType?): Operator { + override fun visitRelOpScan(node: Rel.Op.Scan, ctx: PType?): Operator { val rex = visitRex(node.rex, ctx) return when (session.mode) { PartiQLEngine.Mode.PERMISSIVE -> RelScanPermissive(rex) @@ -293,13 +292,13 @@ internal class Compiler( } } - override fun visitRelOpProject(node: Rel.Op.Project, ctx: StaticType?): Operator { + override fun visitRelOpProject(node: Rel.Op.Project, ctx: PType?): Operator { val input = visitRel(node.input, ctx) val projections = node.projections.map { visitRex(it, ctx).modeHandled() } return RelProject(input, projections) } - override fun visitRelOpScanIndexed(node: Rel.Op.ScanIndexed, ctx: StaticType?): Operator { + override fun visitRelOpScanIndexed(node: Rel.Op.ScanIndexed, ctx: PType?): Operator { val rex = visitRex(node.rex, ctx) return when (session.mode) { PartiQLEngine.Mode.PERMISSIVE -> RelScanIndexedPermissive(rex) @@ -307,7 +306,7 @@ internal class Compiler( } } - override fun visitRelOpUnpivot(node: Rel.Op.Unpivot, ctx: StaticType?): Operator { + override fun visitRelOpUnpivot(node: Rel.Op.Unpivot, ctx: PType?): Operator { val expr = visitRex(node.rex, ctx) return when (session.mode) { PartiQLEngine.Mode.PERMISSIVE -> RelUnpivot.Permissive(expr) @@ -315,7 +314,7 @@ internal class Compiler( } } - override fun visitRelOpSetExcept(node: Rel.Op.Set.Except, ctx: StaticType?): Operator { + override fun visitRelOpSetExcept(node: Rel.Op.Set.Except, ctx: PType?): Operator { val lhs = visitRel(node.lhs, ctx) val rhs = visitRel(node.rhs, ctx) return when (node.quantifier) { @@ -324,7 +323,7 @@ internal class Compiler( } } - override fun visitRelOpSetIntersect(node: Rel.Op.Set.Intersect, ctx: StaticType?): Operator { + override fun visitRelOpSetIntersect(node: Rel.Op.Set.Intersect, ctx: PType?): Operator { val lhs = visitRel(node.lhs, ctx) val rhs = visitRel(node.rhs, ctx) return when (node.quantifier) { @@ -333,7 +332,7 @@ internal class Compiler( } } - override fun visitRelOpSetUnion(node: Rel.Op.Set.Union, ctx: StaticType?): Operator { + override fun visitRelOpSetUnion(node: Rel.Op.Set.Union, ctx: PType?): Operator { val lhs = visitRel(node.lhs, ctx) val rhs = visitRel(node.rhs, ctx) return when (node.quantifier) { @@ -342,24 +341,24 @@ internal class Compiler( } } - override fun visitRelOpLimit(node: Rel.Op.Limit, ctx: StaticType?): Operator { + override fun visitRelOpLimit(node: Rel.Op.Limit, ctx: PType?): Operator { val input = visitRel(node.input, ctx) val limit = visitRex(node.limit, ctx) return RelLimit(input, limit) } - override fun visitRelOpOffset(node: Rel.Op.Offset, ctx: StaticType?): Operator { + override fun visitRelOpOffset(node: Rel.Op.Offset, ctx: PType?): Operator { val input = visitRel(node.input, ctx) val offset = visitRex(node.offset, ctx) return RelOffset(input, offset) } - override fun visitRexOpTupleUnion(node: Rex.Op.TupleUnion, ctx: StaticType?): Operator { + override fun visitRexOpTupleUnion(node: Rex.Op.TupleUnion, ctx: PType?): Operator { val args = node.args.map { visitRex(it, ctx) }.toTypedArray() return ExprTupleUnion(args) } - override fun visitRelOpJoin(node: Rel.Op.Join, ctx: StaticType?): Operator { + override fun visitRelOpJoin(node: Rel.Op.Join, ctx: PType?): Operator { val lhs = visitRel(node.lhs, ctx) val rhs = visitRel(node.rhs, ctx) val condition = visitRex(node.rex, ctx) @@ -371,7 +370,7 @@ internal class Compiler( } } - override fun visitRexOpCase(node: Rex.Op.Case, ctx: StaticType?): Operator { + override fun visitRexOpCase(node: Rex.Op.Case, ctx: PType?): Operator { val branches = node.branches.map { branch -> visitRex(branch.condition, ctx) to visitRex(branch.rex, ctx) } @@ -380,27 +379,27 @@ internal class Compiler( } @OptIn(PartiQLValueExperimental::class) - override fun visitRexOpLit(node: Rex.Op.Lit, ctx: StaticType?): Operator { + override fun visitRexOpLit(node: Rex.Op.Lit, ctx: PType?): Operator { return ExprLiteral(Datum.of(node.value)) } - override fun visitRelOpDistinct(node: Rel.Op.Distinct, ctx: StaticType?): Operator { + override fun visitRelOpDistinct(node: Rel.Op.Distinct, ctx: PType?): Operator { val input = visitRel(node.input, ctx) return RelDistinct(input) } - override fun visitRelOpFilter(node: Rel.Op.Filter, ctx: StaticType?): Operator { + override fun visitRelOpFilter(node: Rel.Op.Filter, ctx: PType?): Operator { val input = visitRel(node.input, ctx) val condition = visitRex(node.predicate, ctx).modeHandled() return RelFilter(input, condition) } - override fun visitRelOpExclude(node: Rel.Op.Exclude, ctx: StaticType?): Operator { + override fun visitRelOpExclude(node: Rel.Op.Exclude, ctx: PType?): Operator { val input = visitRel(node.input, ctx) return RelExclude(input, node.paths) } - override fun visitRelOpSort(node: Rel.Op.Sort, ctx: StaticType?): Operator { + override fun visitRelOpSort(node: Rel.Op.Sort, ctx: PType?): Operator { val input = visitRel(node.input, ctx) val compiledSpecs = node.specs.map { spec -> val expr = visitRex(spec.rex, ctx) diff --git a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/helpers/ValueUtility.kt b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/helpers/ValueUtility.kt index fc47b40cf4..312f9a0a2e 100644 --- a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/helpers/ValueUtility.kt +++ b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/helpers/ValueUtility.kt @@ -2,6 +2,7 @@ package org.partiql.eval.internal.helpers import org.partiql.errors.TypeCheckException import org.partiql.eval.value.Datum +import org.partiql.types.PType import org.partiql.value.PartiQLValue import org.partiql.value.PartiQLValueExperimental import org.partiql.value.PartiQLValueType @@ -15,10 +16,14 @@ internal object ValueUtility { /** * @return whether the value is a boolean and the value itself is not-null and true. */ - @OptIn(PartiQLValueExperimental::class) @JvmStatic fun Datum.isTrue(): Boolean { - return this.type == PartiQLValueType.BOOL && !this.isNull && this.boolean + return this.type.kind == PType.Kind.BOOL && !this.isNull && this.boolean + } + + @OptIn(PartiQLValueExperimental::class) + fun Datum.check(type: PartiQLValueType): Datum { + return this.check(PType.fromPartiQLValueType(type)) } /** @@ -28,8 +33,7 @@ internal object ValueUtility { * @return a [Datum] corresponding to the expected type; this will either be the input value if the value is * already of the expected type, or it will be a null value of the expected type. */ - @OptIn(PartiQLValueExperimental::class) - fun Datum.check(type: PartiQLValueType): Datum { + fun Datum.check(type: PType): Datum { if (this.type == type) { return this } @@ -47,8 +51,8 @@ internal object ValueUtility { */ @OptIn(PartiQLValueExperimental::class) fun Datum.getText(): String { - return when (this.type) { - PartiQLValueType.STRING, PartiQLValueType.SYMBOL, PartiQLValueType.CHAR -> this.string + return when (this.type.kind) { + PType.Kind.STRING, PType.Kind.SYMBOL, PType.Kind.CHAR -> this.string else -> throw TypeCheckException("Expected text, but received ${this.type}.") } } @@ -65,12 +69,12 @@ internal object ValueUtility { */ @OptIn(PartiQLValueExperimental::class) fun Datum.getBigIntCoerced(): BigInteger { - return when (this.type) { - PartiQLValueType.INT8 -> this.byte.toInt().toBigInteger() - PartiQLValueType.INT16 -> this.short.toInt().toBigInteger() - PartiQLValueType.INT32 -> this.int.toBigInteger() - PartiQLValueType.INT64 -> this.long.toBigInteger() - PartiQLValueType.INT -> this.bigInteger + return when (this.type.kind) { + PType.Kind.TINYINT -> this.byte.toInt().toBigInteger() + PType.Kind.SMALLINT -> this.short.toInt().toBigInteger() + PType.Kind.INT -> this.int.toBigInteger() + PType.Kind.BIGINT -> this.long.toBigInteger() + PType.Kind.INT_ARBITRARY -> this.bigInteger else -> throw TypeCheckException() } } @@ -88,12 +92,12 @@ internal object ValueUtility { */ @OptIn(PartiQLValueExperimental::class) fun Datum.getInt32Coerced(): Int { - return when (this.type) { - PartiQLValueType.INT8 -> this.byte.toInt() - PartiQLValueType.INT16 -> this.short.toInt() - PartiQLValueType.INT32 -> this.int - PartiQLValueType.INT64 -> this.long.toInt() - PartiQLValueType.INT -> this.bigInteger.toInt() + return when (this.type.kind) { + PType.Kind.TINYINT -> this.byte.toInt() + PType.Kind.SMALLINT -> this.short.toInt() + PType.Kind.INT -> this.int + PType.Kind.BIGINT -> this.long.toInt() + PType.Kind.INT_ARBITRARY -> this.bigInteger.toInt() else -> throw TypeCheckException() } } diff --git a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelAggregate.kt b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelAggregate.kt index 7c9fe94e3e..ddbe96d0ed 100644 --- a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelAggregate.kt +++ b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelAggregate.kt @@ -64,7 +64,7 @@ internal class RelAggregate( // Initialize the AggregationMap val evaluatedGroupByKeys = keys.map { val key = it.eval(env.push(inputRecord)) - when (key.type == PartiQLValueType.MISSING) { + when (key.isMissing) { true -> nullValue() false -> key.toPartiQLValue() } diff --git a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelExclude.kt b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelExclude.kt index 4d36ca458a..bc02b3bddd 100644 --- a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelExclude.kt +++ b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelExclude.kt @@ -12,6 +12,7 @@ import org.partiql.plan.relOpExcludeTypeCollWildcard import org.partiql.plan.relOpExcludeTypeStructKey import org.partiql.plan.relOpExcludeTypeStructSymbol import org.partiql.plan.relOpExcludeTypeStructWildcard +import org.partiql.types.PType import org.partiql.value.PartiQLValue import org.partiql.value.PartiQLValueExperimental import org.partiql.value.PartiQLValueType @@ -44,7 +45,7 @@ internal class RelExclude( input.close() } - private fun excludeStruct( + private fun excludeFields( structValue: Datum, exclusions: List ): Datum { @@ -107,20 +108,18 @@ internal class RelExclude( * Returns a [PartiQLValue] created from an iterable of [coll]. Requires [type] to be a collection type * (i.e. [PartiQLValueType.LIST], [PartiQLValueType.BAG], or [PartiQLValueType.SEXP]). */ - @OptIn(PartiQLValueExperimental::class) - private fun newCollValue(type: PartiQLValueType, coll: Iterable): Datum { - return when (type) { - PartiQLValueType.LIST -> Datum.listValue(coll) - PartiQLValueType.BAG -> Datum.bagValue(coll) - PartiQLValueType.SEXP -> Datum.sexpValue(coll) + private fun newCollValue(type: PType, coll: Iterable): Datum { + return when (type.kind) { + PType.Kind.LIST -> Datum.listValue(coll) + PType.Kind.BAG -> Datum.bagValue(coll) + PType.Kind.SEXP -> Datum.sexpValue(coll) else -> error("Collection type required") } } - @OptIn(PartiQLValueExperimental::class) private fun excludeCollection( coll: Iterable, - type: PartiQLValueType, + type: PType, exclusions: List ): Datum { val indexesToRemove = mutableSetOf() @@ -155,7 +154,7 @@ internal class RelExclude( } else { // deeper level exclusions var value = element - if (type == PartiQLValueType.LIST || type == PartiQLValueType.SEXP) { + if (type.kind == PType.Kind.LIST || type.kind == PType.Kind.SEXP) { // apply collection index exclusions at deeper levels for lists and sexps val collIndex = relOpExcludeTypeCollIndex(index) branches[collIndex]?.let { @@ -173,11 +172,10 @@ internal class RelExclude( return newCollValue(type, finalColl) } - @OptIn(PartiQLValueExperimental::class) private fun excludeValue(initialPartiQLValue: Datum, exclusions: List): Datum { - return when (initialPartiQLValue.type) { - PartiQLValueType.STRUCT -> excludeStruct(initialPartiQLValue, exclusions) - PartiQLValueType.BAG, PartiQLValueType.LIST, PartiQLValueType.SEXP -> excludeCollection( + return when (initialPartiQLValue.type.kind) { + PType.Kind.ROW, PType.Kind.STRUCT -> excludeFields(initialPartiQLValue, exclusions) + PType.Kind.BAG, PType.Kind.LIST, PType.Kind.SEXP -> excludeCollection( initialPartiQLValue, initialPartiQLValue.type, exclusions diff --git a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelJoinNestedLoop.kt b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelJoinNestedLoop.kt index 1b0bb97e88..e52817db31 100644 --- a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelJoinNestedLoop.kt +++ b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelJoinNestedLoop.kt @@ -7,8 +7,7 @@ import org.partiql.eval.internal.helpers.ValueUtility.isTrue import org.partiql.eval.internal.operator.Operator import org.partiql.eval.value.Datum import org.partiql.eval.value.Field -import org.partiql.value.PartiQLValueExperimental -import org.partiql.value.PartiQLValueType +import org.partiql.types.PType internal abstract class RelJoinNestedLoop : RelPeeking() { @@ -31,7 +30,6 @@ internal abstract class RelJoinNestedLoop : RelPeeking() { abstract fun join(condition: Boolean, lhs: Record, rhs: Record): Record? - @OptIn(PartiQLValueExperimental::class) override fun peek(): Record? { if (lhsRecord == null) { return null @@ -79,10 +77,9 @@ internal abstract class RelJoinNestedLoop : RelPeeking() { } } - @OptIn(PartiQLValueExperimental::class) private fun Datum.padNull(): Datum { - return when (this.type) { - PartiQLValueType.STRUCT -> { + return when (this.type.kind) { + PType.Kind.STRUCT, PType.Kind.ROW -> { val newFields = IteratorSupplier { this.fields }.map { Field.of(it.name, Datum.nullValue()) } diff --git a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelScan.kt b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelScan.kt index f9c15dadb0..b6358b9b8d 100644 --- a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelScan.kt +++ b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelScan.kt @@ -5,10 +5,8 @@ import org.partiql.eval.internal.Environment import org.partiql.eval.internal.Record import org.partiql.eval.internal.helpers.RecordValueIterator import org.partiql.eval.internal.operator.Operator -import org.partiql.value.PartiQLValueExperimental -import org.partiql.value.PartiQLValueType +import org.partiql.types.PType -@OptIn(PartiQLValueExperimental::class) internal class RelScan( private val expr: Operator.Expr ) : Operator.Relation { @@ -17,8 +15,8 @@ internal class RelScan( override fun open(env: Environment) { val r = expr.eval(env.push(Record.empty)) - records = when (r.type) { - PartiQLValueType.LIST, PartiQLValueType.BAG, PartiQLValueType.SEXP -> RecordValueIterator(r.iterator()) + records = when (r.type.kind) { + PType.Kind.LIST, PType.Kind.BAG, PType.Kind.SEXP -> RecordValueIterator(r.iterator()) else -> { close() throw TypeCheckException() diff --git a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelScanIndexed.kt b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelScanIndexed.kt index e4e84048a4..53e7380e3f 100644 --- a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelScanIndexed.kt +++ b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelScanIndexed.kt @@ -5,10 +5,8 @@ import org.partiql.eval.internal.Environment import org.partiql.eval.internal.Record import org.partiql.eval.internal.operator.Operator import org.partiql.eval.value.Datum -import org.partiql.value.PartiQLValueExperimental -import org.partiql.value.PartiQLValueType +import org.partiql.types.PType -@OptIn(PartiQLValueExperimental::class) internal class RelScanIndexed( private val expr: Operator.Expr ) : Operator.Relation { @@ -19,12 +17,12 @@ internal class RelScanIndexed( override fun open(env: Environment) { val r = expr.eval(env.push(Record.empty)) index = 0 - iterator = when (r.type) { - PartiQLValueType.BAG -> { + iterator = when (r.type.kind) { + PType.Kind.BAG -> { close() throw TypeCheckException() } - PartiQLValueType.LIST, PartiQLValueType.SEXP -> r.iterator() + PType.Kind.LIST, PType.Kind.SEXP -> r.iterator() else -> { close() throw TypeCheckException() diff --git a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelScanIndexedPermissive.kt b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelScanIndexedPermissive.kt index 6db479effc..d5441b26f1 100644 --- a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelScanIndexedPermissive.kt +++ b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelScanIndexedPermissive.kt @@ -4,10 +4,8 @@ import org.partiql.eval.internal.Environment import org.partiql.eval.internal.Record import org.partiql.eval.internal.operator.Operator import org.partiql.eval.value.Datum -import org.partiql.value.PartiQLValueExperimental -import org.partiql.value.PartiQLValueType +import org.partiql.types.PType -@OptIn(PartiQLValueExperimental::class) internal class RelScanIndexedPermissive( private val expr: Operator.Expr ) : Operator.Relation { @@ -19,12 +17,12 @@ internal class RelScanIndexedPermissive( override fun open(env: Environment) { val r = expr.eval(env.push(Record.empty)) index = 0 - iterator = when (r.type) { - PartiQLValueType.BAG -> { + iterator = when (r.type.kind) { + PType.Kind.BAG -> { isIndexable = false r.iterator() } - PartiQLValueType.LIST, PartiQLValueType.SEXP -> r.iterator() + PType.Kind.LIST, PType.Kind.SEXP -> r.iterator() else -> { isIndexable = false iterator { yield(r) } diff --git a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelScanPermissive.kt b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelScanPermissive.kt index 13f5516a98..32b5bce98d 100644 --- a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelScanPermissive.kt +++ b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelScanPermissive.kt @@ -4,10 +4,8 @@ import org.partiql.eval.internal.Environment import org.partiql.eval.internal.Record import org.partiql.eval.internal.helpers.RecordValueIterator import org.partiql.eval.internal.operator.Operator -import org.partiql.value.PartiQLValueExperimental -import org.partiql.value.PartiQLValueType +import org.partiql.types.PType -@OptIn(PartiQLValueExperimental::class) internal class RelScanPermissive( private val expr: Operator.Expr ) : Operator.Relation { @@ -16,8 +14,8 @@ internal class RelScanPermissive( override fun open(env: Environment) { val r = expr.eval(env.push(Record.empty)) - records = when (r.type) { - PartiQLValueType.BAG, PartiQLValueType.LIST, PartiQLValueType.SEXP -> RecordValueIterator(r.iterator()) + records = when (r.type.kind) { + PType.Kind.BAG, PType.Kind.LIST, PType.Kind.SEXP -> RecordValueIterator(r.iterator()) else -> iterator { yield(Record.of(r)) } } } diff --git a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelUnpivot.kt b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelUnpivot.kt index babc4407ce..4a492cf9d4 100644 --- a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelUnpivot.kt +++ b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelUnpivot.kt @@ -6,8 +6,8 @@ import org.partiql.eval.internal.Record import org.partiql.eval.internal.operator.Operator import org.partiql.eval.value.Datum import org.partiql.eval.value.Field +import org.partiql.types.PType import org.partiql.value.PartiQLValueExperimental -import org.partiql.value.PartiQLValueType /** * The unpivot operator produces a bag of records from a struct. @@ -60,7 +60,7 @@ internal sealed class RelUnpivot : Operator.Relation { override fun struct(): Datum { val v = expr.eval(env.push(Record.empty)) - if (v.type != PartiQLValueType.STRUCT) { + if (v.type.kind != PType.Kind.STRUCT && v.type.kind != PType.Kind.ROW) { throw TypeCheckException() } return v @@ -80,9 +80,11 @@ internal sealed class RelUnpivot : Operator.Relation { override fun struct(): Datum { val v = expr.eval(env.push(Record.empty)) - return when (v.type) { - PartiQLValueType.STRUCT -> v - PartiQLValueType.MISSING -> Datum.structValue(emptyList()) + if (v.isMissing) { + return Datum.structValue(emptyList()) + } + return when (v.type.kind) { + PType.Kind.STRUCT, PType.Kind.ROW -> v else -> Datum.structValue(listOf(Field.of("_1", v))) } } diff --git a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprCallDynamic.kt b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprCallDynamic.kt index 1b797e5b6f..337f28224a 100644 --- a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprCallDynamic.kt +++ b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprCallDynamic.kt @@ -2,12 +2,12 @@ package org.partiql.eval.internal.operator.rex import org.partiql.errors.TypeCheckException import org.partiql.eval.internal.Environment -import org.partiql.eval.internal.helpers.toNull import org.partiql.eval.internal.operator.Operator import org.partiql.eval.value.Datum import org.partiql.plan.Ref import org.partiql.spi.fn.Fn import org.partiql.spi.fn.FnExperimental +import org.partiql.types.PType import org.partiql.value.PartiQLValue import org.partiql.value.PartiQLValueExperimental import org.partiql.value.PartiQLValueType @@ -61,12 +61,12 @@ internal class ExprCallDynamic( /** * Memoize creation of nulls */ - private val nil = fn.signature.returns.toNull() + private val nil = { Datum.nullValue(fn.signature.returns) } fun eval(originalArgs: Array, env: Environment): Datum { val args = originalArgs.mapIndexed { i, arg -> if (arg.isNull && fn.signature.isNullCall) { - return Datum.of(nil()) + return nil.invoke() } when (val c = coercions[i]) { null -> arg @@ -79,7 +79,7 @@ internal class ExprCallDynamic( private sealed interface CandidateIndex { - public fun get(args: List): Candidate? + public fun get(args: List): Candidate? /** * Preserves the original ordering of the passed-in candidates while making it faster to lookup matching @@ -130,7 +130,7 @@ internal class ExprCallDynamic( init { val lookupsMutable = mutableListOf() - val accumulator = mutableListOf, Candidate>>() + val accumulator = mutableListOf, Candidate>>() // Indicates that we are currently processing dynamic candidates that accept ANY. var activelyProcessingAny = true @@ -143,7 +143,7 @@ internal class ExprCallDynamic( else -> cast.input } } - val parametersIncludeAny = lookupTypes.any { it == PartiQLValueType.ANY } + val parametersIncludeAny = lookupTypes.any { it.kind == PType.Kind.DYNAMIC } // A way to simplify logic further below. If it's empty, add something and set the processing type. if (accumulator.isEmpty()) { activelyProcessingAny = parametersIncludeAny @@ -182,7 +182,7 @@ internal class ExprCallDynamic( this.lookups = lookupsMutable } - override fun get(args: List): Candidate? { + override fun get(args: List): Candidate? { return this.lookups.firstNotNullOfOrNull { it.get(args) } } } @@ -191,17 +191,17 @@ internal class ExprCallDynamic( * An O(1) structure to quickly find directly matching dynamic candidates. This is specifically used for runtime * types that can be matched directly. AKA int32, int64, etc. This does NOT include [PartiQLValueType.ANY]. */ - data class Direct private constructor(val directCandidates: HashMap, Candidate>) : CandidateIndex { + data class Direct private constructor(val directCandidates: HashMap, Candidate>) : CandidateIndex { companion object { - internal fun of(candidates: List, Candidate>>): Direct { - val candidateMap = java.util.HashMap, Candidate>() + internal fun of(candidates: List, Candidate>>): Direct { + val candidateMap = java.util.HashMap, Candidate>() candidateMap.putAll(candidates) return Direct(candidateMap) } } - override fun get(args: List): Candidate? { + override fun get(args: List): Candidate? { return directCandidates[args] } } @@ -210,11 +210,11 @@ internal class ExprCallDynamic( * Holds all candidates that expect a [PartiQLValueType.ANY] on input. This maintains the original * precedence order. */ - data class Indirect(private val candidates: List, Candidate>>) : CandidateIndex { - override fun get(args: List): Candidate? { + data class Indirect(private val candidates: List, Candidate>>) : CandidateIndex { + override fun get(args: List): Candidate? { candidates.forEach { (types, candidate) -> for (i in args.indices) { - if (args[i] != types[i] && types[i] != PartiQLValueType.ANY) { + if (args[i] != types[i] && types[i].kind != PType.Kind.DYNAMIC) { return@forEach } } diff --git a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprCallStatic.kt b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprCallStatic.kt index d52ed3f21c..cfa9347646 100644 --- a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprCallStatic.kt +++ b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprCallStatic.kt @@ -1,7 +1,6 @@ package org.partiql.eval.internal.operator.rex import org.partiql.eval.internal.Environment -import org.partiql.eval.internal.helpers.toNull import org.partiql.eval.internal.operator.Operator import org.partiql.eval.value.Datum import org.partiql.spi.fn.Fn @@ -17,13 +16,13 @@ internal class ExprCallStatic( /** * Memoize creation of nulls */ - private val nil = fn.signature.returns.toNull() + private val nil = { Datum.nullValue(fn.signature.returns) } override fun eval(env: Environment): Datum { // Evaluate arguments val args = inputs.map { input -> val r = input.eval(env) - if (r.isNull && fn.signature.isNullCall) return Datum.of(nil()) + if (r.isNull && fn.signature.isNullCall) return nil.invoke() r.toPartiQLValue() }.toTypedArray() return Datum.of(fn.invoke(args)) diff --git a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprCase.kt b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprCase.kt index 45e5813f6c..a5be4d7f0d 100644 --- a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprCase.kt +++ b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprCase.kt @@ -3,15 +3,13 @@ package org.partiql.eval.internal.operator.rex import org.partiql.eval.internal.Environment import org.partiql.eval.internal.operator.Operator import org.partiql.eval.value.Datum -import org.partiql.value.PartiQLValueExperimental -import org.partiql.value.PartiQLValueType +import org.partiql.types.PType internal class ExprCase( private val branches: List>, private val default: Operator.Expr ) : Operator.Expr { - @OptIn(PartiQLValueExperimental::class) override fun eval(env: Environment): Datum { branches.forEach { branch -> val condition = branch.first.eval(env) @@ -22,8 +20,7 @@ internal class ExprCase( return default.eval(env) } - @OptIn(PartiQLValueExperimental::class) private fun Datum.isTrue(): Boolean { - return this.type == PartiQLValueType.BOOL && !this.isNull && this.boolean + return this.type.kind == PType.Kind.BOOL && !this.isNull && this.boolean } } diff --git a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprCast.kt b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprCast.kt index 86268f2c26..eb67517299 100644 --- a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprCast.kt +++ b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprCast.kt @@ -10,6 +10,7 @@ import org.partiql.eval.internal.Environment import org.partiql.eval.internal.operator.Operator import org.partiql.eval.value.Datum import org.partiql.plan.Ref +import org.partiql.types.PType import org.partiql.value.BagValue import org.partiql.value.BoolValue import org.partiql.value.CollectionValue @@ -26,16 +27,13 @@ import org.partiql.value.NullValue import org.partiql.value.NumericValue import org.partiql.value.PartiQLValue import org.partiql.value.PartiQLValueExperimental -import org.partiql.value.PartiQLValueType import org.partiql.value.SexpValue import org.partiql.value.StringValue import org.partiql.value.SymbolValue import org.partiql.value.TextValue import org.partiql.value.bagValue -import org.partiql.value.binaryValue import org.partiql.value.blobValue import org.partiql.value.boolValue -import org.partiql.value.byteValue import org.partiql.value.charValue import org.partiql.value.clobValue import org.partiql.value.dateValue @@ -48,7 +46,6 @@ import org.partiql.value.int64Value import org.partiql.value.int8Value import org.partiql.value.intValue import org.partiql.value.listValue -import org.partiql.value.missingValue import org.partiql.value.sexpValue import org.partiql.value.stringValue import org.partiql.value.structValue @@ -64,35 +61,35 @@ internal class ExprCast(val arg: Operator.Expr, val cast: Ref.Cast) : Operator.E override fun eval(env: Environment): Datum { val arg = arg.eval(env).toPartiQLValue() try { - val partiqlValue = when (arg.type) { - PartiQLValueType.ANY -> TODO("Not Possible") - PartiQLValueType.BOOL -> castFromBool(arg as BoolValue, cast.target) - PartiQLValueType.INT8 -> castFromNumeric(arg as Int8Value, cast.target) - PartiQLValueType.INT16 -> castFromNumeric(arg as Int16Value, cast.target) - PartiQLValueType.INT32 -> castFromNumeric(arg as Int32Value, cast.target) - PartiQLValueType.INT64 -> castFromNumeric(arg as Int64Value, cast.target) - PartiQLValueType.INT -> castFromNumeric(arg as IntValue, cast.target) - PartiQLValueType.DECIMAL -> castFromNumeric(arg as DecimalValue, cast.target) - PartiQLValueType.DECIMAL_ARBITRARY -> castFromNumeric(arg as DecimalValue, cast.target) - PartiQLValueType.FLOAT32 -> castFromNumeric(arg as Float32Value, cast.target) - PartiQLValueType.FLOAT64 -> castFromNumeric(arg as Float64Value, cast.target) - PartiQLValueType.CHAR -> TODO("Char value implementation is wrong") - PartiQLValueType.STRING -> castFromText(arg as StringValue, cast.target) - PartiQLValueType.SYMBOL -> castFromText(arg as SymbolValue, cast.target) - PartiQLValueType.BINARY -> TODO("Static Type does not support Binary") - PartiQLValueType.BYTE -> TODO("Static Type does not support Byte") - PartiQLValueType.BLOB -> TODO("CAST FROM BLOB not yet implemented") - PartiQLValueType.CLOB -> TODO("CAST FROM CLOB not yet implemented") - PartiQLValueType.DATE -> TODO("CAST FROM DATE not yet implemented") - PartiQLValueType.TIME -> TODO("CAST FROM TIME not yet implemented") - PartiQLValueType.TIMESTAMP -> TODO("CAST FROM TIMESTAMP not yet implemented") - PartiQLValueType.INTERVAL -> TODO("Static Type does not support INTERVAL") - PartiQLValueType.BAG -> castFromCollection(arg as BagValue<*>, cast.target) - PartiQLValueType.LIST -> castFromCollection(arg as ListValue<*>, cast.target) - PartiQLValueType.SEXP -> castFromCollection(arg as SexpValue<*>, cast.target) - PartiQLValueType.STRUCT -> TODO("CAST FROM STRUCT not yet implemented") - PartiQLValueType.NULL -> castFromNull(arg as NullValue, cast.target) - PartiQLValueType.MISSING -> error("cast from MISSING should be handled by Typer") + val partiqlValue = when (PType.fromPartiQLValueType(arg.type).kind) { + PType.Kind.DYNAMIC -> TODO("Not Possible") + PType.Kind.BOOL -> castFromBool(arg as BoolValue, cast.target) + PType.Kind.TINYINT -> castFromNumeric(arg as Int8Value, cast.target) + PType.Kind.SMALLINT -> castFromNumeric(arg as Int16Value, cast.target) + PType.Kind.INT -> castFromNumeric(arg as Int32Value, cast.target) + PType.Kind.BIGINT -> castFromNumeric(arg as Int64Value, cast.target) + PType.Kind.INT_ARBITRARY -> castFromNumeric(arg as IntValue, cast.target) + PType.Kind.DECIMAL -> castFromNumeric(arg as DecimalValue, cast.target) + PType.Kind.DECIMAL_ARBITRARY -> castFromNumeric(arg as DecimalValue, cast.target) + PType.Kind.REAL -> castFromNumeric(arg as Float32Value, cast.target) + PType.Kind.DOUBLE_PRECISION -> castFromNumeric(arg as Float64Value, cast.target) + PType.Kind.CHAR -> TODO("Char value implementation is wrong") + PType.Kind.STRING -> castFromText(arg as StringValue, cast.target) + PType.Kind.SYMBOL -> castFromText(arg as SymbolValue, cast.target) + PType.Kind.BLOB -> TODO("CAST FROM BLOB not yet implemented") + PType.Kind.CLOB -> TODO("CAST FROM CLOB not yet implemented") + PType.Kind.DATE -> TODO("CAST FROM DATE not yet implemented") + PType.Kind.TIME_WITH_TZ -> TODO("CAST FROM TIME not yet implemented") + PType.Kind.TIME_WITHOUT_TZ -> TODO("CAST FROM TIME not yet implemented") + PType.Kind.TIMESTAMP_WITH_TZ -> TODO("CAST FROM TIMESTAMP not yet implemented") + PType.Kind.TIMESTAMP_WITHOUT_TZ -> TODO("CAST FROM TIMESTAMP not yet implemented") + PType.Kind.BAG -> castFromCollection(arg as BagValue<*>, cast.target) + PType.Kind.LIST -> castFromCollection(arg as ListValue<*>, cast.target) + PType.Kind.SEXP -> castFromCollection(arg as SexpValue<*>, cast.target) + PType.Kind.STRUCT -> TODO("CAST FROM STRUCT not yet implemented") + PType.Kind.ROW -> TODO("CAST FROM ROW not yet implemented") + PType.Kind.UNKNOWN -> TODO("CAST FROM UNKNOWN not yet implemented") + PType.Kind.VARCHAR -> TODO("CAST FROM VARCHAR not yet implemented") } return Datum.of(partiqlValue) } catch (e: DataException) { @@ -101,82 +98,82 @@ internal class ExprCast(val arg: Operator.Expr, val cast: Ref.Cast) : Operator.E } @OptIn(PartiQLValueExperimental::class) - private fun castFromNull(value: NullValue, t: PartiQLValueType): PartiQLValue { - return when (t) { - PartiQLValueType.ANY -> value - PartiQLValueType.BOOL -> boolValue(null) - PartiQLValueType.CHAR -> charValue(null) - PartiQLValueType.STRING -> stringValue(null) - PartiQLValueType.SYMBOL -> symbolValue(null) - PartiQLValueType.BINARY -> binaryValue(null) - PartiQLValueType.BYTE -> byteValue(null) - PartiQLValueType.BLOB -> blobValue(null) - PartiQLValueType.CLOB -> clobValue(null) - PartiQLValueType.DATE -> dateValue(null) - PartiQLValueType.TIME -> timeValue(null) - PartiQLValueType.TIMESTAMP -> timestampValue(null) - PartiQLValueType.INTERVAL -> TODO("Not yet supported") - PartiQLValueType.BAG -> bagValue(null) - PartiQLValueType.LIST -> listValue(null) - PartiQLValueType.SEXP -> sexpValue(null) - PartiQLValueType.STRUCT -> structValue(null) - PartiQLValueType.NULL -> value - PartiQLValueType.MISSING -> missingValue() // TODO: Os this allowed - PartiQLValueType.INT8 -> int8Value(null) - PartiQLValueType.INT16 -> int16Value(null) - PartiQLValueType.INT32 -> int32Value(null) - PartiQLValueType.INT64 -> int64Value(null) - PartiQLValueType.INT -> intValue(null) - PartiQLValueType.DECIMAL -> decimalValue(null) - PartiQLValueType.DECIMAL_ARBITRARY -> decimalValue(null) - PartiQLValueType.FLOAT32 -> float32Value(null) - PartiQLValueType.FLOAT64 -> float64Value(null) + private fun castFromNull(value: NullValue, t: PType): PartiQLValue { + return when (t.kind) { + PType.Kind.DYNAMIC -> value + PType.Kind.BOOL -> boolValue(null) + PType.Kind.CHAR -> charValue(null) + PType.Kind.VARCHAR -> TODO("There is no VAR CHAR implementation") + PType.Kind.STRING -> stringValue(null) + PType.Kind.SYMBOL -> symbolValue(null) + PType.Kind.BLOB -> blobValue(null) + PType.Kind.CLOB -> clobValue(null) + PType.Kind.DATE -> dateValue(null) + PType.Kind.TIME_WITH_TZ -> timeValue(null) // TODO + PType.Kind.TIME_WITHOUT_TZ -> timeValue(null) + PType.Kind.TIMESTAMP_WITH_TZ -> timestampValue(null) // TODO + PType.Kind.TIMESTAMP_WITHOUT_TZ -> timestampValue(null) + PType.Kind.BAG -> bagValue(null) + PType.Kind.LIST -> listValue(null) + PType.Kind.SEXP -> sexpValue(null) + PType.Kind.STRUCT -> structValue(null) + PType.Kind.TINYINT -> int8Value(null) + PType.Kind.SMALLINT -> int16Value(null) + PType.Kind.INT -> int32Value(null) + PType.Kind.BIGINT -> int64Value(null) + PType.Kind.INT_ARBITRARY -> intValue(null) + PType.Kind.DECIMAL -> decimalValue(null) + PType.Kind.DECIMAL_ARBITRARY -> decimalValue(null) + PType.Kind.REAL -> float32Value(null) + PType.Kind.DOUBLE_PRECISION -> float64Value(null) + PType.Kind.ROW -> structValue(null) // TODO. PartiQLValue doesn't have rows. + PType.Kind.UNKNOWN -> TODO() } } @OptIn(PartiQLValueExperimental::class) - private fun castFromBool(value: BoolValue, t: PartiQLValueType): PartiQLValue { + private fun castFromBool(value: BoolValue, t: PType): PartiQLValue { val v = value.value - return when (t) { - PartiQLValueType.ANY -> value - PartiQLValueType.BOOL -> value - PartiQLValueType.INT8 -> when (v) { + return when (t.kind) { + PType.Kind.DYNAMIC -> value + PType.Kind.BOOL -> value + PType.Kind.TINYINT -> when (v) { true -> int8Value(1) false -> int8Value(0) null -> int8Value(null) } - PartiQLValueType.INT16 -> when (v) { + PType.Kind.SMALLINT -> when (v) { true -> int16Value(1) false -> int16Value(0) null -> int16Value(null) } - PartiQLValueType.INT32 -> when (v) { + PType.Kind.INT -> when (v) { true -> int32Value(1) false -> int32Value(0) null -> int32Value(null) } - PartiQLValueType.INT64 -> when (v) { + PType.Kind.BIGINT -> when (v) { true -> int64Value(1) false -> int64Value(0) null -> int64Value(null) } - PartiQLValueType.INT -> when (v) { + PType.Kind.INT_ARBITRARY -> when (v) { true -> intValue(BigInteger.valueOf(1)) false -> intValue(BigInteger.valueOf(0)) null -> intValue(null) } - PartiQLValueType.DECIMAL, PartiQLValueType.DECIMAL_ARBITRARY -> when (v) { + PType.Kind.DECIMAL, PType.Kind.DECIMAL_ARBITRARY -> when (v) { true -> decimalValue(BigDecimal.ONE) false -> decimalValue(BigDecimal.ZERO) null -> decimalValue(null) } - PartiQLValueType.FLOAT32 -> { + PType.Kind.REAL -> { when (v) { true -> float32Value(1.0.toFloat()) false -> float32Value(0.0.toFloat()) @@ -184,159 +181,158 @@ internal class ExprCast(val arg: Operator.Expr, val cast: Ref.Cast) : Operator.E } } - PartiQLValueType.FLOAT64 -> when (v) { + PType.Kind.DOUBLE_PRECISION -> when (v) { true -> float64Value(1.0) false -> float64Value(0.0) null -> float64Value(null) } - PartiQLValueType.CHAR -> TODO("Char value implementation is wrong") - PartiQLValueType.STRING -> stringValue(v?.toString()) - PartiQLValueType.SYMBOL -> symbolValue(v?.toString()) - PartiQLValueType.BINARY, PartiQLValueType.BYTE, - PartiQLValueType.BLOB, PartiQLValueType.CLOB, - PartiQLValueType.DATE, PartiQLValueType.TIME, PartiQLValueType.TIMESTAMP, - PartiQLValueType.INTERVAL, - PartiQLValueType.BAG, PartiQLValueType.LIST, - PartiQLValueType.SEXP, - PartiQLValueType.STRUCT -> error("can not perform cast from $value to $t") - PartiQLValueType.NULL -> error("cast to null not supported") - PartiQLValueType.MISSING -> error("cast to missing not supported") + PType.Kind.CHAR -> TODO("Char value implementation is wrong") + PType.Kind.VARCHAR -> TODO("There is no VAR CHAR implementation") + PType.Kind.STRING -> stringValue(v?.toString()) + PType.Kind.SYMBOL -> symbolValue(v?.toString()) + PType.Kind.BLOB, PType.Kind.CLOB, + PType.Kind.DATE, PType.Kind.TIMESTAMP_WITH_TZ, PType.Kind.TIMESTAMP_WITHOUT_TZ, PType.Kind.TIME_WITH_TZ, + PType.Kind.TIME_WITHOUT_TZ, PType.Kind.BAG, PType.Kind.LIST, + PType.Kind.SEXP, + PType.Kind.ROW, + PType.Kind.STRUCT -> error("can not perform cast from $value to $t") + PType.Kind.UNKNOWN -> TODO() } } @OptIn(PartiQLValueExperimental::class) - private fun castFromNumeric(value: NumericValue<*>, t: PartiQLValueType): PartiQLValue { + private fun castFromNumeric(value: NumericValue<*>, t: PType): PartiQLValue { val v = value.value - return when (t) { - PartiQLValueType.ANY -> value - PartiQLValueType.BOOL -> when { + return when (t.kind) { + PType.Kind.DYNAMIC -> value + PType.Kind.BOOL -> when { v == null -> boolValue(null) v == 0.0 -> boolValue(false) else -> boolValue(true) } - PartiQLValueType.INT8 -> value.toInt8() - PartiQLValueType.INT16 -> value.toInt16() - PartiQLValueType.INT32 -> value.toInt32() - PartiQLValueType.INT64 -> value.toInt64() - PartiQLValueType.INT -> value.toInt() - PartiQLValueType.DECIMAL -> value.toDecimal() - PartiQLValueType.DECIMAL_ARBITRARY -> value.toDecimal() - PartiQLValueType.FLOAT32 -> value.toFloat32() - PartiQLValueType.FLOAT64 -> value.toFloat64() - PartiQLValueType.CHAR -> TODO("Char value implementation is wrong") - PartiQLValueType.STRING -> stringValue(v?.toString(), value.annotations) - PartiQLValueType.SYMBOL -> symbolValue(v?.toString(), value.annotations) - PartiQLValueType.BINARY, PartiQLValueType.BYTE, - PartiQLValueType.BLOB, PartiQLValueType.CLOB, - PartiQLValueType.DATE, PartiQLValueType.TIME, PartiQLValueType.TIMESTAMP, - PartiQLValueType.INTERVAL, - PartiQLValueType.BAG, PartiQLValueType.LIST, - PartiQLValueType.SEXP, - PartiQLValueType.STRUCT -> error("can not perform cast from $value to $t") - PartiQLValueType.NULL -> error("cast to null not supported") - PartiQLValueType.MISSING -> error("cast to missing not supported") + PType.Kind.TINYINT -> value.toInt8() + PType.Kind.SMALLINT -> value.toInt16() + PType.Kind.INT -> value.toInt32() + PType.Kind.BIGINT -> value.toInt64() + PType.Kind.INT_ARBITRARY -> value.toInt() + PType.Kind.DECIMAL -> value.toDecimal() + PType.Kind.DECIMAL_ARBITRARY -> value.toDecimal() + PType.Kind.REAL -> value.toFloat32() + PType.Kind.DOUBLE_PRECISION -> value.toFloat64() + PType.Kind.CHAR -> TODO("Char value implementation is wrong") + PType.Kind.VARCHAR -> TODO("There is no VAR CHAR implementation") + PType.Kind.STRING -> stringValue(v?.toString(), value.annotations) + PType.Kind.SYMBOL -> symbolValue(v?.toString(), value.annotations) + PType.Kind.BLOB, PType.Kind.CLOB, + PType.Kind.DATE, PType.Kind.TIME_WITH_TZ, PType.Kind.TIME_WITHOUT_TZ, PType.Kind.TIMESTAMP_WITH_TZ, + PType.Kind.TIMESTAMP_WITHOUT_TZ, + PType.Kind.BAG, PType.Kind.LIST, + PType.Kind.SEXP, + PType.Kind.STRUCT -> error("can not perform cast from $value to $t") + PType.Kind.ROW -> error("can not perform cast from $value to $t") + PType.Kind.UNKNOWN -> TODO() } } @OptIn(PartiQLValueExperimental::class) - private fun castFromText(value: TextValue, t: PartiQLValueType): PartiQLValue { - return when (t) { - PartiQLValueType.ANY -> value - PartiQLValueType.BOOL -> { + private fun castFromText(value: TextValue, t: PType): PartiQLValue { + return when (t.kind) { + PType.Kind.DYNAMIC -> value + PType.Kind.BOOL -> { val str = value.value?.lowercase() ?: return boolValue(null, value.annotations) if (str == "true") return boolValue(true, value.annotations) if (str == "false") return boolValue(false, value.annotations) throw TypeCheckException() } - PartiQLValueType.INT8 -> { + PType.Kind.TINYINT -> { val stringValue = value.value ?: return int8Value(null, value.annotations) when (val number = getNumberValueFromString(stringValue)) { is BigInteger -> intValue(number, value.annotations).toInt8() else -> throw TypeCheckException() } } - PartiQLValueType.INT16 -> { + PType.Kind.SMALLINT -> { val stringValue = value.value ?: return int16Value(null, value.annotations) when (val number = getNumberValueFromString(stringValue)) { is BigInteger -> intValue(number, value.annotations).toInt16() else -> throw TypeCheckException() } } - PartiQLValueType.INT32 -> { + PType.Kind.INT -> { val stringValue = value.value ?: return int32Value(null, value.annotations) when (val number = getNumberValueFromString(stringValue)) { is BigInteger -> intValue(number, value.annotations).toInt32() else -> throw TypeCheckException() } } - PartiQLValueType.INT64 -> { + PType.Kind.BIGINT -> { val stringValue = value.value ?: return int64Value(null, value.annotations) when (val number = getNumberValueFromString(stringValue)) { is BigInteger -> intValue(number, value.annotations).toInt64() else -> throw TypeCheckException() } } - PartiQLValueType.INT -> { + PType.Kind.INT_ARBITRARY -> { val stringValue = value.value ?: return intValue(null, value.annotations) when (val number = getNumberValueFromString(stringValue)) { is BigInteger -> intValue(number, value.annotations).toInt() else -> throw TypeCheckException() } } - PartiQLValueType.DECIMAL -> { + PType.Kind.DECIMAL -> { val stringValue = value.value ?: return int16Value(null, value.annotations) when (val number = getNumberValueFromString(stringValue)) { is Decimal -> decimalValue(number, value.annotations).toDecimal() else -> throw TypeCheckException() } } - PartiQLValueType.DECIMAL_ARBITRARY -> { + PType.Kind.DECIMAL_ARBITRARY -> { val stringValue = value.value ?: return int16Value(null, value.annotations) when (val number = getNumberValueFromString(stringValue)) { is Decimal -> decimalValue(number, value.annotations).toDecimal() else -> throw TypeCheckException() } } - PartiQLValueType.FLOAT32 -> { + PType.Kind.REAL -> { val stringValue = value.value ?: return int16Value(null, value.annotations) when (val number = getNumberValueFromString(stringValue)) { is Double -> float64Value(number, value.annotations).toFloat32() else -> throw TypeCheckException() } } - PartiQLValueType.FLOAT64 -> { + PType.Kind.DOUBLE_PRECISION -> { val stringValue = value.value ?: return int16Value(null, value.annotations) when (val number = getNumberValueFromString(stringValue)) { is Double -> float64Value(number, value.annotations).toFloat32() else -> throw TypeCheckException() } } - PartiQLValueType.CHAR -> TODO("Char value implementation is wrong") - PartiQLValueType.STRING -> stringValue(value.value, value.annotations) - PartiQLValueType.SYMBOL -> symbolValue(value.value, value.annotations) - PartiQLValueType.BINARY, PartiQLValueType.BYTE, - PartiQLValueType.BLOB, PartiQLValueType.CLOB, - PartiQLValueType.DATE, PartiQLValueType.TIME, PartiQLValueType.TIMESTAMP, - PartiQLValueType.INTERVAL, - PartiQLValueType.BAG, PartiQLValueType.LIST, - PartiQLValueType.SEXP, - PartiQLValueType.STRUCT -> error("can not perform cast from struct to $t") - PartiQLValueType.NULL -> error("cast to null not supported") - PartiQLValueType.MISSING -> error("cast to missing not supported") + PType.Kind.CHAR -> TODO("CHAR implementation is wrong.") + PType.Kind.VARCHAR -> TODO("There is no VAR CHAR implementation") + PType.Kind.STRING -> stringValue(value.value, value.annotations) + PType.Kind.SYMBOL -> symbolValue(value.value, value.annotations) + PType.Kind.BLOB, PType.Kind.CLOB, + PType.Kind.DATE, PType.Kind.TIME_WITH_TZ, PType.Kind.TIME_WITHOUT_TZ, PType.Kind.TIMESTAMP_WITH_TZ, + PType.Kind.TIMESTAMP_WITHOUT_TZ, + PType.Kind.BAG, PType.Kind.LIST, + PType.Kind.SEXP, + PType.Kind.STRUCT -> error("can not perform cast from struct to $t") + PType.Kind.ROW -> error("can not perform cast from $value to $t") + PType.Kind.UNKNOWN -> error("can not perform cast from $value to $t") } } // TODO: Fix NULL Collection @OptIn(PartiQLValueExperimental::class) - private fun castFromCollection(value: CollectionValue<*>, t: PartiQLValueType): PartiQLValue { + private fun castFromCollection(value: CollectionValue<*>, t: PType): PartiQLValue { val elements = mutableListOf() value.iterator().forEachRemaining { elements.add(it) } - return when (t) { - PartiQLValueType.BAG -> bagValue(elements) - PartiQLValueType.LIST -> listValue(elements) - PartiQLValueType.SEXP -> sexpValue(elements) + return when (t.kind) { + PType.Kind.BAG -> bagValue(elements) + PType.Kind.LIST -> listValue(elements) + PType.Kind.SEXP -> sexpValue(elements) else -> error("can not perform cast from $value to $t") } } diff --git a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprCoalesce.kt b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprCoalesce.kt index 7b79ca32ca..0b1d9b9a73 100644 --- a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprCoalesce.kt +++ b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprCoalesce.kt @@ -4,7 +4,6 @@ import org.partiql.eval.internal.Environment import org.partiql.eval.internal.operator.Operator import org.partiql.eval.value.Datum import org.partiql.value.PartiQLValueExperimental -import org.partiql.value.PartiQLValueType internal class ExprCoalesce( private val args: Array @@ -14,7 +13,7 @@ internal class ExprCoalesce( override fun eval(env: Environment): Datum { for (arg in args) { val result = arg.eval(env) - if (!result.isNull && result.type != PartiQLValueType.MISSING) { + if (!result.isNull && !result.isMissing) { return result } } diff --git a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprCollection.kt b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprCollection.kt index ccc0008154..4944d865ee 100644 --- a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprCollection.kt +++ b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprCollection.kt @@ -3,23 +3,20 @@ package org.partiql.eval.internal.operator.rex import org.partiql.eval.internal.Environment import org.partiql.eval.internal.operator.Operator import org.partiql.eval.value.Datum -import org.partiql.types.BagType -import org.partiql.types.ListType -import org.partiql.types.SexpType -import org.partiql.types.StaticType +import org.partiql.types.PType import org.partiql.value.PartiQLValueExperimental internal class ExprCollection( private val values: List, - private val type: StaticType + private val type: PType ) : Operator.Expr { @PartiQLValueExperimental override fun eval(env: Environment): Datum { - return when (type) { - is BagType -> Datum.bagValue(values.map { it.eval(env) }) - is SexpType -> Datum.sexpValue(values.map { it.eval(env) }) - is ListType -> Datum.listValue(values.map { it.eval(env) }) + return when (type.kind) { + PType.Kind.BAG -> Datum.bagValue(values.map { it.eval(env) }) + PType.Kind.SEXP -> Datum.sexpValue(values.map { it.eval(env) }) + PType.Kind.LIST -> Datum.listValue(values.map { it.eval(env) }) else -> error("Unsupported type for collection $type") } } diff --git a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprLiteral.kt b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprLiteral.kt index 100ddfd243..7736e7b503 100644 --- a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprLiteral.kt +++ b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprLiteral.kt @@ -5,7 +5,7 @@ import org.partiql.eval.internal.operator.Operator import org.partiql.eval.value.Datum import org.partiql.value.PartiQLValueExperimental -internal class ExprLiteral @OptIn(PartiQLValueExperimental::class) constructor(private val value: Datum) : Operator.Expr { +internal class ExprLiteral(private val value: Datum) : Operator.Expr { @PartiQLValueExperimental override fun eval(env: Environment): Datum { return value diff --git a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprPathIndex.kt b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprPathIndex.kt index 8b29f7b586..c33fc281df 100644 --- a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprPathIndex.kt +++ b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprPathIndex.kt @@ -5,19 +5,17 @@ import org.partiql.eval.internal.Environment import org.partiql.eval.internal.helpers.ValueUtility.getInt32Coerced import org.partiql.eval.internal.operator.Operator import org.partiql.eval.value.Datum -import org.partiql.value.PartiQLValueExperimental -import org.partiql.value.PartiQLValueType +import org.partiql.types.PType internal class ExprPathIndex( @JvmField val root: Operator.Expr, @JvmField val key: Operator.Expr, ) : Operator.Expr { - @OptIn(PartiQLValueExperimental::class) override fun eval(env: Environment): Datum { val input = root.eval(env) - val iterator = when (input.type) { - PartiQLValueType.BAG, PartiQLValueType.LIST, PartiQLValueType.SEXP -> input.iterator() + val iterator = when (input.type.kind) { + PType.Kind.BAG, PType.Kind.LIST, PType.Kind.SEXP -> input.iterator() else -> throw TypeCheckException() } diff --git a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprStruct.kt b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprStruct.kt index b9a35ae9b9..782cd94217 100644 --- a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprStruct.kt +++ b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprStruct.kt @@ -4,11 +4,8 @@ import org.partiql.eval.internal.Environment import org.partiql.eval.internal.helpers.ValueUtility.getText import org.partiql.eval.internal.operator.Operator import org.partiql.eval.value.Datum -import org.partiql.value.PartiQLValueExperimental -import org.partiql.value.PartiQLValueType internal class ExprStruct(private val fields: List) : Operator.Expr { - @OptIn(PartiQLValueExperimental::class) override fun eval(env: Environment): Datum { val fields = fields.mapNotNull { val key = it.key.eval(env) @@ -17,9 +14,9 @@ internal class ExprStruct(private val fields: List) : Operator.Expr { } val keyString = key.getText() val value = it.value.eval(env) - when (value.type) { - PartiQLValueType.MISSING -> null - else -> org.partiql.eval.value.Field.of(keyString, value) + when (value.isMissing) { + true -> null + false -> org.partiql.eval.value.Field.of(keyString, value) } } return Datum.structValue(fields) diff --git a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprTupleUnion.kt b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprTupleUnion.kt index 73031ce55b..119ef69528 100644 --- a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprTupleUnion.kt +++ b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprTupleUnion.kt @@ -4,6 +4,7 @@ import org.partiql.eval.internal.Environment import org.partiql.eval.internal.helpers.ValueUtility.check import org.partiql.eval.internal.operator.Operator import org.partiql.eval.value.Datum +import org.partiql.types.PType import org.partiql.value.PartiQLValueExperimental import org.partiql.value.PartiQLValueType @@ -20,7 +21,7 @@ internal class ExprTupleUnion( // Return NULL if any arguments are NULL tuples.forEach { if (it.isNull) { - return Datum.nullValue(PartiQLValueType.STRUCT) + return Datum.nullValue(PType.typeStruct()) } } diff --git a/partiql-plan/api/partiql-plan.api b/partiql-plan/api/partiql-plan.api index 0afdf62eea..15010e0526 100644 --- a/partiql-plan/api/partiql-plan.api +++ b/partiql-plan/api/partiql-plan.api @@ -68,14 +68,14 @@ public final class org/partiql/plan/Catalog$Item$Fn$Companion { public final class org/partiql/plan/Catalog$Item$Value : org/partiql/plan/Catalog$Item { public static final field Companion Lorg/partiql/plan/Catalog$Item$Value$Companion; public final field path Ljava/util/List; - public final field type Lorg/partiql/types/StaticType; - public fun (Ljava/util/List;Lorg/partiql/types/StaticType;)V + public final field type Lorg/partiql/types/PType; + public fun (Ljava/util/List;Lorg/partiql/types/PType;)V public fun accept (Lorg/partiql/plan/visitor/PlanVisitor;Ljava/lang/Object;)Ljava/lang/Object; public static final fun builder ()Lorg/partiql/plan/builder/CatalogItemValueBuilder; public final fun component1 ()Ljava/util/List; - public final fun component2 ()Lorg/partiql/types/StaticType; - public final fun copy (Ljava/util/List;Lorg/partiql/types/StaticType;)Lorg/partiql/plan/Catalog$Item$Value; - public static synthetic fun copy$default (Lorg/partiql/plan/Catalog$Item$Value;Ljava/util/List;Lorg/partiql/types/StaticType;ILjava/lang/Object;)Lorg/partiql/plan/Catalog$Item$Value; + public final fun component2 ()Lorg/partiql/types/PType; + public final fun copy (Ljava/util/List;Lorg/partiql/types/PType;)Lorg/partiql/plan/Catalog$Item$Value; + public static synthetic fun copy$default (Lorg/partiql/plan/Catalog$Item$Value;Ljava/util/List;Lorg/partiql/types/PType;ILjava/lang/Object;)Lorg/partiql/plan/Catalog$Item$Value; public fun equals (Ljava/lang/Object;)Z public fun getChildren ()Ljava/util/List; public fun hashCode ()I @@ -164,14 +164,14 @@ public final class org/partiql/plan/Plan { public static final fun catalog (Ljava/lang/String;Ljava/util/List;)Lorg/partiql/plan/Catalog; public static final fun catalogItemAgg (Ljava/util/List;Ljava/lang/String;)Lorg/partiql/plan/Catalog$Item$Agg; public static final fun catalogItemFn (Ljava/util/List;Ljava/lang/String;)Lorg/partiql/plan/Catalog$Item$Fn; - public static final fun catalogItemValue (Ljava/util/List;Lorg/partiql/types/StaticType;)Lorg/partiql/plan/Catalog$Item$Value; + public static final fun catalogItemValue (Ljava/util/List;Lorg/partiql/types/PType;)Lorg/partiql/plan/Catalog$Item$Value; public static final fun identifierQualified (Lorg/partiql/plan/Identifier$Symbol;Ljava/util/List;)Lorg/partiql/plan/Identifier$Qualified; public static final fun identifierSymbol (Ljava/lang/String;Lorg/partiql/plan/Identifier$CaseSensitivity;)Lorg/partiql/plan/Identifier$Symbol; public static final fun partiQLPlan (Ljava/util/List;Lorg/partiql/plan/Statement;)Lorg/partiql/plan/PartiQLPlan; public static final fun ref (II)Lorg/partiql/plan/Ref; - public static final fun refCast (Lorg/partiql/value/PartiQLValueType;Lorg/partiql/value/PartiQLValueType;Z)Lorg/partiql/plan/Ref$Cast; + public static final fun refCast (Lorg/partiql/types/PType;Lorg/partiql/types/PType;Z)Lorg/partiql/plan/Ref$Cast; public static final fun rel (Lorg/partiql/plan/Rel$Type;Lorg/partiql/plan/Rel$Op;)Lorg/partiql/plan/Rel; - public static final fun relBinding (Ljava/lang/String;Lorg/partiql/types/StaticType;)Lorg/partiql/plan/Rel$Binding; + public static final fun relBinding (Ljava/lang/String;Lorg/partiql/types/PType;)Lorg/partiql/plan/Rel$Binding; public static final fun relOpAggregate (Lorg/partiql/plan/Rel;Lorg/partiql/plan/Rel$Op$Aggregate$Strategy;Ljava/util/List;Ljava/util/List;)Lorg/partiql/plan/Rel$Op$Aggregate; public static final fun relOpAggregateCall (Lorg/partiql/plan/Ref;Lorg/partiql/plan/Rel$Op$Aggregate$Call$SetQuantifier;Ljava/util/List;)Lorg/partiql/plan/Rel$Op$Aggregate$Call; public static final fun relOpDistinct (Lorg/partiql/plan/Rel;)Lorg/partiql/plan/Rel$Op$Distinct; @@ -198,7 +198,7 @@ public final class org/partiql/plan/Plan { public static final fun relOpSortSpec (Lorg/partiql/plan/Rex;Lorg/partiql/plan/Rel$Op$Sort$Order;)Lorg/partiql/plan/Rel$Op$Sort$Spec; public static final fun relOpUnpivot (Lorg/partiql/plan/Rex;)Lorg/partiql/plan/Rel$Op$Unpivot; public static final fun relType (Ljava/util/List;Ljava/util/Set;)Lorg/partiql/plan/Rel$Type; - public static final fun rex (Lorg/partiql/types/StaticType;Lorg/partiql/plan/Rex$Op;)Lorg/partiql/plan/Rex; + public static final fun rex (Lorg/partiql/types/PType;Lorg/partiql/plan/Rex$Op;)Lorg/partiql/plan/Rex; public static final fun rexOpCallDynamic (Ljava/util/List;Ljava/util/List;)Lorg/partiql/plan/Rex$Op$Call$Dynamic; public static final fun rexOpCallDynamicCandidate (Lorg/partiql/plan/Ref;Ljava/util/List;)Lorg/partiql/plan/Rex$Op$Call$Dynamic$Candidate; public static final fun rexOpCallStatic (Lorg/partiql/plan/Ref;Ljava/util/List;)Lorg/partiql/plan/Rex$Op$Call$Static; @@ -251,17 +251,17 @@ public final class org/partiql/plan/Ref : org/partiql/plan/PlanNode { public final class org/partiql/plan/Ref$Cast : org/partiql/plan/PlanNode { public static final field Companion Lorg/partiql/plan/Ref$Cast$Companion; - public final field input Lorg/partiql/value/PartiQLValueType; + public final field input Lorg/partiql/types/PType; public final field isNullable Z - public final field target Lorg/partiql/value/PartiQLValueType; - public fun (Lorg/partiql/value/PartiQLValueType;Lorg/partiql/value/PartiQLValueType;Z)V + public final field target Lorg/partiql/types/PType; + public fun (Lorg/partiql/types/PType;Lorg/partiql/types/PType;Z)V public fun accept (Lorg/partiql/plan/visitor/PlanVisitor;Ljava/lang/Object;)Ljava/lang/Object; public static final fun builder ()Lorg/partiql/plan/builder/RefCastBuilder; - public final fun component1 ()Lorg/partiql/value/PartiQLValueType; - public final fun component2 ()Lorg/partiql/value/PartiQLValueType; + public final fun component1 ()Lorg/partiql/types/PType; + public final fun component2 ()Lorg/partiql/types/PType; public final fun component3 ()Z - public final fun copy (Lorg/partiql/value/PartiQLValueType;Lorg/partiql/value/PartiQLValueType;Z)Lorg/partiql/plan/Ref$Cast; - public static synthetic fun copy$default (Lorg/partiql/plan/Ref$Cast;Lorg/partiql/value/PartiQLValueType;Lorg/partiql/value/PartiQLValueType;ZILjava/lang/Object;)Lorg/partiql/plan/Ref$Cast; + public final fun copy (Lorg/partiql/types/PType;Lorg/partiql/types/PType;Z)Lorg/partiql/plan/Ref$Cast; + public static synthetic fun copy$default (Lorg/partiql/plan/Ref$Cast;Lorg/partiql/types/PType;Lorg/partiql/types/PType;ZILjava/lang/Object;)Lorg/partiql/plan/Ref$Cast; public fun equals (Ljava/lang/Object;)Z public fun getChildren ()Ljava/util/List; public fun hashCode ()I @@ -296,14 +296,14 @@ public final class org/partiql/plan/Rel : org/partiql/plan/PlanNode { public final class org/partiql/plan/Rel$Binding : org/partiql/plan/PlanNode { public static final field Companion Lorg/partiql/plan/Rel$Binding$Companion; public final field name Ljava/lang/String; - public final field type Lorg/partiql/types/StaticType; - public fun (Ljava/lang/String;Lorg/partiql/types/StaticType;)V + public final field type Lorg/partiql/types/PType; + public fun (Ljava/lang/String;Lorg/partiql/types/PType;)V public fun accept (Lorg/partiql/plan/visitor/PlanVisitor;Ljava/lang/Object;)Ljava/lang/Object; public static final fun builder ()Lorg/partiql/plan/builder/RelBindingBuilder; public final fun component1 ()Ljava/lang/String; - public final fun component2 ()Lorg/partiql/types/StaticType; - public final fun copy (Ljava/lang/String;Lorg/partiql/types/StaticType;)Lorg/partiql/plan/Rel$Binding; - public static synthetic fun copy$default (Lorg/partiql/plan/Rel$Binding;Ljava/lang/String;Lorg/partiql/types/StaticType;ILjava/lang/Object;)Lorg/partiql/plan/Rel$Binding; + public final fun component2 ()Lorg/partiql/types/PType; + public final fun copy (Ljava/lang/String;Lorg/partiql/types/PType;)Lorg/partiql/plan/Rel$Binding; + public static synthetic fun copy$default (Lorg/partiql/plan/Rel$Binding;Ljava/lang/String;Lorg/partiql/types/PType;ILjava/lang/Object;)Lorg/partiql/plan/Rel$Binding; public fun equals (Ljava/lang/Object;)Z public fun getChildren ()Ljava/util/List; public fun hashCode ()I @@ -924,14 +924,14 @@ public final class org/partiql/plan/Rel$Type$Companion { public final class org/partiql/plan/Rex : org/partiql/plan/PlanNode { public static final field Companion Lorg/partiql/plan/Rex$Companion; public final field op Lorg/partiql/plan/Rex$Op; - public final field type Lorg/partiql/types/StaticType; - public fun (Lorg/partiql/types/StaticType;Lorg/partiql/plan/Rex$Op;)V + public final field type Lorg/partiql/types/PType; + public fun (Lorg/partiql/types/PType;Lorg/partiql/plan/Rex$Op;)V public fun accept (Lorg/partiql/plan/visitor/PlanVisitor;Ljava/lang/Object;)Ljava/lang/Object; public static final fun builder ()Lorg/partiql/plan/builder/RexBuilder; - public final fun component1 ()Lorg/partiql/types/StaticType; + public final fun component1 ()Lorg/partiql/types/PType; public final fun component2 ()Lorg/partiql/plan/Rex$Op; - public final fun copy (Lorg/partiql/types/StaticType;Lorg/partiql/plan/Rex$Op;)Lorg/partiql/plan/Rex; - public static synthetic fun copy$default (Lorg/partiql/plan/Rex;Lorg/partiql/types/StaticType;Lorg/partiql/plan/Rex$Op;ILjava/lang/Object;)Lorg/partiql/plan/Rex; + public final fun copy (Lorg/partiql/types/PType;Lorg/partiql/plan/Rex$Op;)Lorg/partiql/plan/Rex; + public static synthetic fun copy$default (Lorg/partiql/plan/Rex;Lorg/partiql/types/PType;Lorg/partiql/plan/Rex$Op;ILjava/lang/Object;)Lorg/partiql/plan/Rex; public fun equals (Ljava/lang/Object;)Z public fun getChildren ()Ljava/util/List; public fun hashCode ()I @@ -1500,15 +1500,15 @@ public final class org/partiql/plan/builder/CatalogItemFnBuilder { public final class org/partiql/plan/builder/CatalogItemValueBuilder { public fun ()V - public fun (Ljava/util/List;Lorg/partiql/types/StaticType;)V - public synthetic fun (Ljava/util/List;Lorg/partiql/types/StaticType;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public fun (Ljava/util/List;Lorg/partiql/types/PType;)V + public synthetic fun (Ljava/util/List;Lorg/partiql/types/PType;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public final fun build ()Lorg/partiql/plan/Catalog$Item$Value; public final fun getPath ()Ljava/util/List; - public final fun getType ()Lorg/partiql/types/StaticType; + public final fun getType ()Lorg/partiql/types/PType; public final fun path (Ljava/util/List;)Lorg/partiql/plan/builder/CatalogItemValueBuilder; public final fun setPath (Ljava/util/List;)V - public final fun setType (Lorg/partiql/types/StaticType;)V - public final fun type (Lorg/partiql/types/StaticType;)Lorg/partiql/plan/builder/CatalogItemValueBuilder; + public final fun setType (Lorg/partiql/types/PType;)V + public final fun type (Lorg/partiql/types/PType;)Lorg/partiql/plan/builder/CatalogItemValueBuilder; } public final class org/partiql/plan/builder/IdentifierQualifiedBuilder { @@ -1558,8 +1558,8 @@ public final class org/partiql/plan/builder/PlanBuilder { public static synthetic fun catalogItemAgg$default (Lorg/partiql/plan/builder/PlanBuilder;Ljava/util/List;Ljava/lang/String;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lorg/partiql/plan/Catalog$Item$Agg; public final fun catalogItemFn (Ljava/util/List;Ljava/lang/String;Lkotlin/jvm/functions/Function1;)Lorg/partiql/plan/Catalog$Item$Fn; public static synthetic fun catalogItemFn$default (Lorg/partiql/plan/builder/PlanBuilder;Ljava/util/List;Ljava/lang/String;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lorg/partiql/plan/Catalog$Item$Fn; - public final fun catalogItemValue (Ljava/util/List;Lorg/partiql/types/StaticType;Lkotlin/jvm/functions/Function1;)Lorg/partiql/plan/Catalog$Item$Value; - public static synthetic fun catalogItemValue$default (Lorg/partiql/plan/builder/PlanBuilder;Ljava/util/List;Lorg/partiql/types/StaticType;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lorg/partiql/plan/Catalog$Item$Value; + public final fun catalogItemValue (Ljava/util/List;Lorg/partiql/types/PType;Lkotlin/jvm/functions/Function1;)Lorg/partiql/plan/Catalog$Item$Value; + public static synthetic fun catalogItemValue$default (Lorg/partiql/plan/builder/PlanBuilder;Ljava/util/List;Lorg/partiql/types/PType;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lorg/partiql/plan/Catalog$Item$Value; public final fun identifierQualified (Lorg/partiql/plan/Identifier$Symbol;Ljava/util/List;Lkotlin/jvm/functions/Function1;)Lorg/partiql/plan/Identifier$Qualified; public static synthetic fun identifierQualified$default (Lorg/partiql/plan/builder/PlanBuilder;Lorg/partiql/plan/Identifier$Symbol;Ljava/util/List;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lorg/partiql/plan/Identifier$Qualified; public final fun identifierSymbol (Ljava/lang/String;Lorg/partiql/plan/Identifier$CaseSensitivity;Lkotlin/jvm/functions/Function1;)Lorg/partiql/plan/Identifier$Symbol; @@ -1568,12 +1568,12 @@ public final class org/partiql/plan/builder/PlanBuilder { public static synthetic fun partiQLPlan$default (Lorg/partiql/plan/builder/PlanBuilder;Ljava/util/List;Lorg/partiql/plan/Statement;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lorg/partiql/plan/PartiQLPlan; public final fun ref (Ljava/lang/Integer;Ljava/lang/Integer;Lkotlin/jvm/functions/Function1;)Lorg/partiql/plan/Ref; public static synthetic fun ref$default (Lorg/partiql/plan/builder/PlanBuilder;Ljava/lang/Integer;Ljava/lang/Integer;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lorg/partiql/plan/Ref; - public final fun refCast (Lorg/partiql/value/PartiQLValueType;Lorg/partiql/value/PartiQLValueType;Ljava/lang/Boolean;Lkotlin/jvm/functions/Function1;)Lorg/partiql/plan/Ref$Cast; - public static synthetic fun refCast$default (Lorg/partiql/plan/builder/PlanBuilder;Lorg/partiql/value/PartiQLValueType;Lorg/partiql/value/PartiQLValueType;Ljava/lang/Boolean;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lorg/partiql/plan/Ref$Cast; + public final fun refCast (Lorg/partiql/types/PType;Lorg/partiql/types/PType;Ljava/lang/Boolean;Lkotlin/jvm/functions/Function1;)Lorg/partiql/plan/Ref$Cast; + public static synthetic fun refCast$default (Lorg/partiql/plan/builder/PlanBuilder;Lorg/partiql/types/PType;Lorg/partiql/types/PType;Ljava/lang/Boolean;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lorg/partiql/plan/Ref$Cast; public final fun rel (Lorg/partiql/plan/Rel$Type;Lorg/partiql/plan/Rel$Op;Lkotlin/jvm/functions/Function1;)Lorg/partiql/plan/Rel; public static synthetic fun rel$default (Lorg/partiql/plan/builder/PlanBuilder;Lorg/partiql/plan/Rel$Type;Lorg/partiql/plan/Rel$Op;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lorg/partiql/plan/Rel; - public final fun relBinding (Ljava/lang/String;Lorg/partiql/types/StaticType;Lkotlin/jvm/functions/Function1;)Lorg/partiql/plan/Rel$Binding; - public static synthetic fun relBinding$default (Lorg/partiql/plan/builder/PlanBuilder;Ljava/lang/String;Lorg/partiql/types/StaticType;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lorg/partiql/plan/Rel$Binding; + public final fun relBinding (Ljava/lang/String;Lorg/partiql/types/PType;Lkotlin/jvm/functions/Function1;)Lorg/partiql/plan/Rel$Binding; + public static synthetic fun relBinding$default (Lorg/partiql/plan/builder/PlanBuilder;Ljava/lang/String;Lorg/partiql/types/PType;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lorg/partiql/plan/Rel$Binding; public final fun relOpAggregate (Lorg/partiql/plan/Rel;Lorg/partiql/plan/Rel$Op$Aggregate$Strategy;Ljava/util/List;Ljava/util/List;Lkotlin/jvm/functions/Function1;)Lorg/partiql/plan/Rel$Op$Aggregate; public static synthetic fun relOpAggregate$default (Lorg/partiql/plan/builder/PlanBuilder;Lorg/partiql/plan/Rel;Lorg/partiql/plan/Rel$Op$Aggregate$Strategy;Ljava/util/List;Ljava/util/List;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lorg/partiql/plan/Rel$Op$Aggregate; public final fun relOpAggregateCall (Lorg/partiql/plan/Ref;Lorg/partiql/plan/Rel$Op$Aggregate$Call$SetQuantifier;Ljava/util/List;Lkotlin/jvm/functions/Function1;)Lorg/partiql/plan/Rel$Op$Aggregate$Call; @@ -1626,8 +1626,8 @@ public final class org/partiql/plan/builder/PlanBuilder { public static synthetic fun relOpUnpivot$default (Lorg/partiql/plan/builder/PlanBuilder;Lorg/partiql/plan/Rex;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lorg/partiql/plan/Rel$Op$Unpivot; public final fun relType (Ljava/util/List;Ljava/util/Set;Lkotlin/jvm/functions/Function1;)Lorg/partiql/plan/Rel$Type; public static synthetic fun relType$default (Lorg/partiql/plan/builder/PlanBuilder;Ljava/util/List;Ljava/util/Set;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lorg/partiql/plan/Rel$Type; - public final fun rex (Lorg/partiql/types/StaticType;Lorg/partiql/plan/Rex$Op;Lkotlin/jvm/functions/Function1;)Lorg/partiql/plan/Rex; - public static synthetic fun rex$default (Lorg/partiql/plan/builder/PlanBuilder;Lorg/partiql/types/StaticType;Lorg/partiql/plan/Rex$Op;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lorg/partiql/plan/Rex; + public final fun rex (Lorg/partiql/types/PType;Lorg/partiql/plan/Rex$Op;Lkotlin/jvm/functions/Function1;)Lorg/partiql/plan/Rex; + public static synthetic fun rex$default (Lorg/partiql/plan/builder/PlanBuilder;Lorg/partiql/types/PType;Lorg/partiql/plan/Rex$Op;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lorg/partiql/plan/Rex; public final fun rexOpCallDynamic (Ljava/util/List;Ljava/util/List;Lkotlin/jvm/functions/Function1;)Lorg/partiql/plan/Rex$Op$Call$Dynamic; public static synthetic fun rexOpCallDynamic$default (Lorg/partiql/plan/builder/PlanBuilder;Ljava/util/List;Ljava/util/List;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lorg/partiql/plan/Rex$Op$Call$Dynamic; public final fun rexOpCallDynamicCandidate (Lorg/partiql/plan/Ref;Ljava/util/List;Lkotlin/jvm/functions/Function1;)Lorg/partiql/plan/Rex$Op$Call$Dynamic$Candidate; @@ -1697,31 +1697,31 @@ public final class org/partiql/plan/builder/RefBuilder { public final class org/partiql/plan/builder/RefCastBuilder { public fun ()V - public fun (Lorg/partiql/value/PartiQLValueType;Lorg/partiql/value/PartiQLValueType;Ljava/lang/Boolean;)V - public synthetic fun (Lorg/partiql/value/PartiQLValueType;Lorg/partiql/value/PartiQLValueType;Ljava/lang/Boolean;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public fun (Lorg/partiql/types/PType;Lorg/partiql/types/PType;Ljava/lang/Boolean;)V + public synthetic fun (Lorg/partiql/types/PType;Lorg/partiql/types/PType;Ljava/lang/Boolean;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public final fun build ()Lorg/partiql/plan/Ref$Cast; - public final fun getInput ()Lorg/partiql/value/PartiQLValueType; - public final fun getTarget ()Lorg/partiql/value/PartiQLValueType; - public final fun input (Lorg/partiql/value/PartiQLValueType;)Lorg/partiql/plan/builder/RefCastBuilder; + public final fun getInput ()Lorg/partiql/types/PType; + public final fun getTarget ()Lorg/partiql/types/PType; + public final fun input (Lorg/partiql/types/PType;)Lorg/partiql/plan/builder/RefCastBuilder; public final fun isNullable ()Ljava/lang/Boolean; public final fun isNullable (Ljava/lang/Boolean;)Lorg/partiql/plan/builder/RefCastBuilder; - public final fun setInput (Lorg/partiql/value/PartiQLValueType;)V + public final fun setInput (Lorg/partiql/types/PType;)V public final fun setNullable (Ljava/lang/Boolean;)V - public final fun setTarget (Lorg/partiql/value/PartiQLValueType;)V - public final fun target (Lorg/partiql/value/PartiQLValueType;)Lorg/partiql/plan/builder/RefCastBuilder; + public final fun setTarget (Lorg/partiql/types/PType;)V + public final fun target (Lorg/partiql/types/PType;)Lorg/partiql/plan/builder/RefCastBuilder; } public final class org/partiql/plan/builder/RelBindingBuilder { public fun ()V - public fun (Ljava/lang/String;Lorg/partiql/types/StaticType;)V - public synthetic fun (Ljava/lang/String;Lorg/partiql/types/StaticType;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public fun (Ljava/lang/String;Lorg/partiql/types/PType;)V + public synthetic fun (Ljava/lang/String;Lorg/partiql/types/PType;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public final fun build ()Lorg/partiql/plan/Rel$Binding; public final fun getName ()Ljava/lang/String; - public final fun getType ()Lorg/partiql/types/StaticType; + public final fun getType ()Lorg/partiql/types/PType; public final fun name (Ljava/lang/String;)Lorg/partiql/plan/builder/RelBindingBuilder; public final fun setName (Ljava/lang/String;)V - public final fun setType (Lorg/partiql/types/StaticType;)V - public final fun type (Lorg/partiql/types/StaticType;)Lorg/partiql/plan/builder/RelBindingBuilder; + public final fun setType (Lorg/partiql/types/PType;)V + public final fun type (Lorg/partiql/types/PType;)Lorg/partiql/plan/builder/RelBindingBuilder; } public final class org/partiql/plan/builder/RelBuilder { @@ -2061,15 +2061,15 @@ public final class org/partiql/plan/builder/RelTypeBuilder { public final class org/partiql/plan/builder/RexBuilder { public fun ()V - public fun (Lorg/partiql/types/StaticType;Lorg/partiql/plan/Rex$Op;)V - public synthetic fun (Lorg/partiql/types/StaticType;Lorg/partiql/plan/Rex$Op;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public fun (Lorg/partiql/types/PType;Lorg/partiql/plan/Rex$Op;)V + public synthetic fun (Lorg/partiql/types/PType;Lorg/partiql/plan/Rex$Op;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public final fun build ()Lorg/partiql/plan/Rex; public final fun getOp ()Lorg/partiql/plan/Rex$Op; - public final fun getType ()Lorg/partiql/types/StaticType; + public final fun getType ()Lorg/partiql/types/PType; public final fun op (Lorg/partiql/plan/Rex$Op;)Lorg/partiql/plan/builder/RexBuilder; public final fun setOp (Lorg/partiql/plan/Rex$Op;)V - public final fun setType (Lorg/partiql/types/StaticType;)V - public final fun type (Lorg/partiql/types/StaticType;)Lorg/partiql/plan/builder/RexBuilder; + public final fun setType (Lorg/partiql/types/PType;)V + public final fun type (Lorg/partiql/types/PType;)Lorg/partiql/plan/builder/RexBuilder; } public final class org/partiql/plan/builder/RexOpCallDynamicBuilder { diff --git a/partiql-plan/src/main/kotlin/org/partiql/plan/debug/PlanPrinter.kt b/partiql-plan/src/main/kotlin/org/partiql/plan/debug/PlanPrinter.kt index cf7084b85c..87c41440f2 100644 --- a/partiql-plan/src/main/kotlin/org/partiql/plan/debug/PlanPrinter.kt +++ b/partiql-plan/src/main/kotlin/org/partiql/plan/debug/PlanPrinter.kt @@ -3,9 +3,8 @@ package org.partiql.plan.debug import org.partiql.plan.PlanNode import org.partiql.plan.Rel import org.partiql.plan.Rex -import org.partiql.plan.debug.PlanPrinter.Visitor.primitives import org.partiql.plan.visitor.PlanBaseVisitor -import org.partiql.types.StaticType +import org.partiql.types.PType import kotlin.reflect.KVisibility import kotlin.reflect.full.isSubclassOf import kotlin.reflect.full.memberProperties @@ -34,7 +33,7 @@ object PlanPrinter { ) { sealed interface TypeInfo { class Rel(val type: org.partiql.plan.Rel.Type) : TypeInfo - class Rex(val type: StaticType) : TypeInfo + class Rex(val type: PType) : TypeInfo object Nil : TypeInfo } @@ -78,7 +77,7 @@ object PlanPrinter { "props" to this.type.props ) is Args.TypeInfo.Rex -> return listOf( - "static_type" to this.type + "static_type" to this.type.toString() ) is Args.TypeInfo.Nil -> emptyList() } diff --git a/partiql-plan/src/main/resources/partiql_plan.ion b/partiql-plan/src/main/resources/partiql_plan.ion index cf723fce0b..675d32de7d 100644 --- a/partiql-plan/src/main/resources/partiql_plan.ion +++ b/partiql-plan/src/main/resources/partiql_plan.ion @@ -1,8 +1,8 @@ imports::{ kotlin: [ partiql_value::'org.partiql.value.PartiQLValue', - partiql_value_type::'org.partiql.value.PartiQLValueType', - static_type::'org.partiql.types.StaticType', + partiql_value_type::'org.partiql.types.PType', + static_type::'org.partiql.types.PType', ], } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/Env.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/Env.kt index 18ca6447a6..aa7d22491d 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/Env.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/Env.kt @@ -2,6 +2,7 @@ package org.partiql.planner.internal import org.partiql.planner.PartiQLPlanner import org.partiql.planner.internal.casts.CastTable +import org.partiql.planner.internal.casts.Coercions import org.partiql.planner.internal.ir.Ref import org.partiql.planner.internal.ir.Rel import org.partiql.planner.internal.ir.Rex @@ -12,21 +13,19 @@ import org.partiql.planner.internal.ir.relOpAggregateCallResolved import org.partiql.planner.internal.ir.rex import org.partiql.planner.internal.ir.rexOpCallDynamic import org.partiql.planner.internal.ir.rexOpCallDynamicCandidate -import org.partiql.planner.internal.ir.rexOpCallStatic import org.partiql.planner.internal.ir.rexOpCastResolved import org.partiql.planner.internal.ir.rexOpVarGlobal +import org.partiql.planner.internal.typer.CompilerType import org.partiql.planner.internal.typer.TypeEnv.Companion.toPath -import org.partiql.planner.internal.typer.toRuntimeType -import org.partiql.planner.internal.typer.toStaticType import org.partiql.spi.BindingCase import org.partiql.spi.BindingName import org.partiql.spi.BindingPath import org.partiql.spi.connector.ConnectorMetadata import org.partiql.spi.fn.AggSignature import org.partiql.spi.fn.FnExperimental -import org.partiql.types.StaticType +import org.partiql.types.PType +import org.partiql.types.PType.Kind import org.partiql.value.PartiQLValueExperimental -import org.partiql.value.PartiQLValueType /** * [Env] is similar to the database type environment from the PartiQL Specification. This includes resolution of @@ -40,11 +39,6 @@ import org.partiql.value.PartiQLValueType */ internal class Env(private val session: PartiQLPlanner.Session) { - /** - * Cast table used for coercion and explicit cast resolution. - */ - private val casts = CastTable.partiql - /** * Current catalog [ConnectorMetadata]. Error if missing from the session. */ @@ -80,7 +74,7 @@ internal class Env(private val session: PartiQLPlanner.Session) { val ref = refObj( catalog = item.catalog, path = item.handle.path.steps, - type = item.handle.entity.getType(), + type = CompilerType(item.handle.entity.getPType()), ) // Rewrite as a path expression. val root = rex(ref.type, rexOpVarGlobal(ref)) @@ -109,7 +103,7 @@ internal class Env(private val session: PartiQLPlanner.Session) { } return ProblemGenerator.missingRex( rexOpCallDynamic(args, candidates), - ProblemGenerator.incompatibleTypesForOp(args.map { it.type }, path.normalized.joinToString(".")) + ProblemGenerator.incompatibleTypesForOp(path.normalized.joinToString("."), args.map { it.type }) ) } return when (match) { @@ -126,7 +120,7 @@ internal class Env(private val session: PartiQLPlanner.Session) { ) } // Rewrite as a dynamic call to be typed by PlanTyper - rex(StaticType.ANY, rexOpCallDynamic(args, candidates)) + Rex(CompilerType(PType.typeDynamic()), Rex.Op.Call.Dynamic(args, candidates)) } is FnMatch.Static -> { // Create an internal typed reference @@ -139,11 +133,11 @@ internal class Env(private val session: PartiQLPlanner.Session) { val coercions: List = args.mapIndexed { i, arg -> when (val cast = match.mapping[i]) { null -> arg - else -> rex(StaticType.ANY, rexOpCastResolved(cast, arg)) + else -> Rex(CompilerType(PType.typeDynamic()), Rex.Op.Cast.Resolved(cast, arg)) } } // Rewrite as a static call to be typed by PlanTyper - rex(StaticType.ANY, rexOpCallStatic(ref, coercions)) + Rex(CompilerType(PType.typeDynamic()), Rex.Op.Call.Static(ref, coercions)) } } } @@ -154,13 +148,7 @@ internal class Env(private val session: PartiQLPlanner.Session) { val path = BindingPath(listOf(BindingName(name, BindingCase.INSENSITIVE))) val item = aggs.lookup(path) ?: return null val candidates = item.handle.entity.getVariants() - var hadMissingArg = false - val parameters = args.mapIndexed { i, arg -> - if (!hadMissingArg && arg.type.isMissable()) { - hadMissingArg = true - } - arg.type.toRuntimeType() - } + val parameters = args.mapIndexed { i, arg -> arg.type } val match = match(candidates, parameters) ?: return null val agg = match.first val mapping = match.second @@ -170,16 +158,15 @@ internal class Env(private val session: PartiQLPlanner.Session) { val coercions: List = args.mapIndexed { i, arg -> when (val cast = mapping[i]) { null -> arg - else -> rex(cast.target.toStaticType(), rexOpCastResolved(cast, arg)) + else -> rex(cast.target, rexOpCastResolved(cast, arg)) } } return relOpAggregateCallResolved(ref, setQuantifier, coercions) } - @OptIn(PartiQLValueExperimental::class) - fun resolveCast(input: Rex, target: PartiQLValueType): Rex.Op.Cast.Resolved? { - val operand = input.type.toRuntimeType() - val cast = casts.get(operand, target) ?: return null + fun resolveCast(input: Rex, target: CompilerType): Rex.Op.Cast.Resolved? { + val operand = input.type + val cast = CastTable.partiql.get(operand, target) ?: return null return rexOpCastResolved(cast, input) } @@ -229,7 +216,7 @@ internal class Env(private val session: PartiQLPlanner.Session) { } @OptIn(FnExperimental::class, PartiQLValueExperimental::class) - private fun match(candidates: List, args: List): Pair>? { + private fun match(candidates: List, args: List): Pair>? { // 1. Check for an exact match for (candidate in candidates) { if (candidate.matches(args)) { @@ -255,11 +242,11 @@ internal class Env(private val session: PartiQLPlanner.Session) { * Check if this function accepts the exact input argument types. Assume same arity. */ @OptIn(FnExperimental::class, PartiQLValueExperimental::class) - private fun AggSignature.matches(args: List): Boolean { + private fun AggSignature.matches(args: List): Boolean { for (i in args.indices) { val a = args[i] val p = parameters[i] - if (p.type != PartiQLValueType.ANY && a != p.type) return false + if (p.type.kind != Kind.DYNAMIC && a != p.type) return false } return true } @@ -271,7 +258,7 @@ internal class Env(private val session: PartiQLPlanner.Session) { * @return */ @OptIn(FnExperimental::class, PartiQLValueExperimental::class) - private fun AggSignature.match(args: List): Pair>? { + private fun AggSignature.match(args: List): Pair>? { val mapping = arrayOfNulls(args.size) for (i in args.indices) { val arg = args[i] @@ -280,11 +267,9 @@ internal class Env(private val session: PartiQLPlanner.Session) { // 1. Exact match arg == p.type -> continue // 2. Match ANY, no coercion needed - p.type == PartiQLValueType.ANY -> continue - // 3. Match NULL argument - arg == PartiQLValueType.NULL -> continue - // 4. Check for a coercion - else -> when (val coercion = PathResolverAgg.casts.lookupCoercion(arg, p.type)) { + p.type.kind == Kind.DYNAMIC -> continue + // 3. Check for a coercion + else -> when (val coercion = Coercions.get(arg, p.type)) { null -> return null // short-circuit else -> mapping[i] = coercion } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/FnComparator.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/FnComparator.kt index 4ad673b473..f94f52c10c 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/FnComparator.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/FnComparator.kt @@ -3,36 +3,9 @@ package org.partiql.planner.internal import org.partiql.spi.fn.FnExperimental import org.partiql.spi.fn.FnParameter import org.partiql.spi.fn.FnSignature +import org.partiql.types.PType +import org.partiql.types.PType.Kind import org.partiql.value.PartiQLValueExperimental -import org.partiql.value.PartiQLValueType -import org.partiql.value.PartiQLValueType.ANY -import org.partiql.value.PartiQLValueType.BAG -import org.partiql.value.PartiQLValueType.BINARY -import org.partiql.value.PartiQLValueType.BLOB -import org.partiql.value.PartiQLValueType.BOOL -import org.partiql.value.PartiQLValueType.BYTE -import org.partiql.value.PartiQLValueType.CHAR -import org.partiql.value.PartiQLValueType.CLOB -import org.partiql.value.PartiQLValueType.DATE -import org.partiql.value.PartiQLValueType.DECIMAL -import org.partiql.value.PartiQLValueType.DECIMAL_ARBITRARY -import org.partiql.value.PartiQLValueType.FLOAT32 -import org.partiql.value.PartiQLValueType.FLOAT64 -import org.partiql.value.PartiQLValueType.INT -import org.partiql.value.PartiQLValueType.INT16 -import org.partiql.value.PartiQLValueType.INT32 -import org.partiql.value.PartiQLValueType.INT64 -import org.partiql.value.PartiQLValueType.INT8 -import org.partiql.value.PartiQLValueType.INTERVAL -import org.partiql.value.PartiQLValueType.LIST -import org.partiql.value.PartiQLValueType.MISSING -import org.partiql.value.PartiQLValueType.NULL -import org.partiql.value.PartiQLValueType.SEXP -import org.partiql.value.PartiQLValueType.STRING -import org.partiql.value.PartiQLValueType.STRUCT -import org.partiql.value.PartiQLValueType.SYMBOL -import org.partiql.value.PartiQLValueType.TIME -import org.partiql.value.PartiQLValueType.TIMESTAMP /** * Function precedence comparator; this is not formally specified. @@ -62,44 +35,46 @@ internal object FnComparator : Comparator { private fun FnParameter.compareTo(other: FnParameter): Int = comparePrecedence(this.type, other.type) - private fun comparePrecedence(t1: PartiQLValueType, t2: PartiQLValueType): Int { + private fun comparePrecedence(t1: PType, t2: PType): Int { if (t1 == t2) return 0 - val p1 = precedence[t1]!! - val p2 = precedence[t2]!! + val p1 = precedence[t1.kind]!! + val p2 = precedence[t2.kind]!! return p1 - p2 } - // This simply describes some precedence for ordering functions. - // This is not explicitly defined in the PartiQL Specification!! - // This does not imply the ability to CAST; this defines function resolution behavior. - private val precedence: Map = listOf( - NULL, - MISSING, - BOOL, - INT8, - INT16, - INT32, - INT64, - INT, - DECIMAL, - FLOAT32, - FLOAT64, - DECIMAL_ARBITRARY, // Arbitrary precision decimal has a higher precedence than FLOAT - CHAR, - STRING, - CLOB, - SYMBOL, - BINARY, - BYTE, - BLOB, - DATE, - TIME, - TIMESTAMP, - INTERVAL, - LIST, - SEXP, - BAG, - STRUCT, - ANY, + /** + * This simply describes some precedence for ordering functions. + * This is not explicitly defined in the PartiQL Specification!! + * This does not imply the ability to CAST; this defines function resolution behavior. + * This excludes [Kind.ROW] and [Kind.UNKNOWN]. + */ + private val precedence: Map = listOf( + Kind.BOOL, + Kind.TINYINT, + Kind.SMALLINT, + Kind.INT, + Kind.BIGINT, + Kind.INT_ARBITRARY, + Kind.DECIMAL, + Kind.REAL, + Kind.DOUBLE_PRECISION, + Kind.DECIMAL_ARBITRARY, // Arbitrary precision decimal has a higher precedence than FLOAT + Kind.CHAR, + Kind.VARCHAR, + Kind.SYMBOL, + Kind.STRING, + Kind.CLOB, + Kind.BLOB, + Kind.DATE, + Kind.TIME_WITHOUT_TZ, + Kind.TIME_WITH_TZ, + Kind.TIMESTAMP_WITHOUT_TZ, + Kind.TIMESTAMP_WITH_TZ, + Kind.LIST, + Kind.SEXP, + Kind.BAG, + Kind.ROW, + Kind.STRUCT, + Kind.DYNAMIC, ).mapIndexed { precedence, type -> type to precedence }.toMap() } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/FnResolver.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/FnResolver.kt index 79fae42136..ceea6f1c53 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/FnResolver.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/FnResolver.kt @@ -1,14 +1,11 @@ package org.partiql.planner.internal -import org.partiql.planner.internal.casts.CastTable +import org.partiql.planner.internal.casts.Coercions import org.partiql.planner.internal.ir.Ref -import org.partiql.planner.internal.typer.toRuntimeTypeOrNull +import org.partiql.planner.internal.typer.CompilerType import org.partiql.spi.fn.FnExperimental import org.partiql.spi.fn.FnSignature -import org.partiql.types.StaticType -import org.partiql.value.PartiQLValueExperimental -import org.partiql.value.PartiQLValueType -import org.partiql.value.PartiQLValueType.ANY +import org.partiql.types.PType.Kind /** * @@ -25,49 +22,62 @@ import org.partiql.value.PartiQLValueType.ANY * * Reference https://www.postgresql.org/docs/current/typeconv-func.html */ -@OptIn(FnExperimental::class, PartiQLValueExperimental::class) +@OptIn(FnExperimental::class) internal object FnResolver { - @JvmStatic - private val casts = CastTable.partiql - /** * Resolution of either a static or dynamic function. * + * TODO: How do we handle DYNAMIC? + * * @param variants * @param args * @return */ - fun resolve(variants: List, args: List): FnMatch? { - + fun resolve(variants: List, args: List): FnMatch? { val candidates = variants .filter { it.parameters.size == args.size } - .sortedWith(FnComparator) .ifEmpty { return null } - val argPermutations = buildArgumentPermutations(args).mapNotNull { argList -> - argList.map { arg -> - // Skip over if we cannot convert type to runtime type. - arg.toRuntimeTypeOrNull() ?: return@mapNotNull null + // 1. Look for exact match + for (candidate in candidates) { + if (candidate.matchesExactly(args)) { + return FnMatch.Static(candidate, arrayOfNulls(args.size)) } } - // Match candidates on all argument permutations - val matches = argPermutations.mapNotNull { match(candidates, it) } + // 2. Discard functions that cannot be matched (via implicit coercion or exact matches) + var matches = match(candidates, args).ifEmpty { return null } + if (matches.size == 1) { + return matches.first().match + } - // Order based on original candidate function ordering - val orderedUniqueMatches = matches.toSet().toList() - val orderedCandidates = candidates.flatMap { candidate -> - orderedUniqueMatches.filter { it.signature == candidate } + // 3. Run through all candidates and keep those with the most exact matches on input types. + matches = matchOn(matches) { it.numberOfExactInputTypes } + if (matches.size == 1) { + return matches.first().match } - // Static call iff only one match for every branch - val n = orderedCandidates.size - return when (n) { - 0 -> null - 1 -> orderedCandidates.first() - else -> FnMatch.Dynamic(orderedCandidates) + // TODO: Do we care about preferred types? This is a PostgreSQL concept. + // 4. Run through all candidates and keep those that accept preferred types (of the input data type's type category) at the most positions where type conversion will be required. + + // 5. If there are DYNAMIC nodes, return all candidates + var isDynamic = false + for (match in matches) { + val params = match.match.signature.parameters + for (index in params.indices) { + if ((args[index].kind == Kind.DYNAMIC) && params[index].type.kind != Kind.DYNAMIC) { + isDynamic = true + } + } } + if (isDynamic) { + return FnMatch.Dynamic(matches.map { it.match }) + } + + // 6. Find the highest precedence one. NOTE: This is a remnant of the previous implementation. Whether we want + // to keep this is up to us. + return matches.sortedWith(MatchResultComparator).first().match } /** @@ -77,33 +87,37 @@ internal object FnResolver { * @param args * @return */ - private fun match(candidates: List, args: List): FnMatch.Static? { - // 1. Check for an exact match + private fun match(candidates: List, args: List): List { + val matches = mutableSetOf() for (candidate in candidates) { - if (candidate.matches(args)) { - return FnMatch.Static(candidate, arrayOfNulls(args.size)) - } + val m = candidate.match(args) ?: continue + matches.add(m) } - // 2. Look for best match (for now, first match). + return matches.toList() + } + + private fun matchOn(candidates: List, toCompare: (MatchResult) -> Int): List { + var mostExactMatches = 0 + val matches = mutableSetOf() for (candidate in candidates) { - val m = candidate.match(args) - if (m != null) { - return m + when (toCompare(candidate).compareTo(mostExactMatches)) { + -1 -> continue + 0 -> matches.add(candidate) + 1 -> { + mostExactMatches = toCompare(candidate) + matches.clear() + matches.add(candidate) + } + else -> error("CompareTo should never return outside of range [-1, 1]") } - // if (match != null && m.exact < match.exact) { - // // already had a better match. - // continue - // } - // match = m } - // 3. No match, return null - return null + return matches.toList() } /** * Check if this function accepts the exact input argument types. Assume same arity. */ - private fun FnSignature.matches(args: List): Boolean { + private fun FnSignature.matchesExactly(args: List): Boolean { for (i in args.indices) { val a = args[i] val p = parameters[i] @@ -118,47 +132,42 @@ internal object FnResolver { * @param args * @return */ - private fun FnSignature.match(args: List): FnMatch.Static? { + private fun FnSignature.match(args: List): MatchResult? { val mapping = arrayOfNulls(args.size) + var exactInputTypes: Int = 0 for (i in args.indices) { val arg = args[i] val p = parameters[i] when { // 1. Exact match - arg == p.type -> continue + arg == p.type -> { + exactInputTypes++ + continue + } // 2. Match ANY, no coercion needed - p.type == ANY -> continue + // TODO: Rewrite args in this scenario + arg.kind == Kind.UNKNOWN || p.type.kind == Kind.DYNAMIC || arg.kind == Kind.DYNAMIC -> continue // 3. Check for a coercion - else -> when (val coercion = casts.lookupCoercion(arg, p.type)) { + else -> when (val coercion = Coercions.get(arg, p.type)) { null -> return null // short-circuit else -> mapping[i] = coercion } } } - return FnMatch.Static(this, mapping) + return MatchResult( + FnMatch.Static(this, mapping), + exactInputTypes, + ) } - private fun buildArgumentPermutations(args: List): List> { - val flattenedArgs = args.map { it.flatten().allTypes } - return buildArgumentPermutations(flattenedArgs, accumulator = emptyList()) - } + private class MatchResult( + val match: FnMatch.Static, + val numberOfExactInputTypes: Int, + ) - private fun buildArgumentPermutations( - args: List>, - accumulator: List, - ): List> { - if (args.isEmpty()) { - return listOf(accumulator) - } - val first = args.first() - val rest = when (args.size) { - 1 -> emptyList() - else -> args.subList(1, args.size) - } - return buildList { - first.forEach { argSubType -> - addAll(buildArgumentPermutations(rest, accumulator + listOf(argSubType))) - } + private object MatchResultComparator : Comparator { + override fun compare(o1: MatchResult, o2: MatchResult): Int { + return FnComparator.reversed().compare(o1.match.signature, o2.match.signature) } } } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/PathResolverAgg.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/PathResolverAgg.kt index c976626784..27d3e3ddec 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/PathResolverAgg.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/PathResolverAgg.kt @@ -1,7 +1,6 @@ package org.partiql.planner.internal import org.partiql.planner.PartiQLPlanner -import org.partiql.planner.internal.casts.CastTable import org.partiql.spi.BindingPath import org.partiql.spi.connector.ConnectorAgg import org.partiql.spi.connector.ConnectorHandle @@ -14,11 +13,6 @@ internal class PathResolverAgg( session: PartiQLPlanner.Session, ) : PathResolver(catalog, session) { - companion object { - @JvmStatic - public val casts = CastTable.partiql - } - override fun get(metadata: ConnectorMetadata, path: BindingPath): ConnectorHandle.Agg? { return metadata.getAggregation(path) } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/PlanningProblemDetails.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/PlanningProblemDetails.kt index c0724d2388..252102c2b3 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/PlanningProblemDetails.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/PlanningProblemDetails.kt @@ -4,6 +4,7 @@ import org.partiql.errors.ProblemDetails import org.partiql.errors.ProblemSeverity import org.partiql.plan.Identifier import org.partiql.planner.internal.utils.PlanUtils +import org.partiql.types.PType import org.partiql.types.StaticType /** @@ -151,15 +152,15 @@ internal open class PlanningProblemDetails( ) data class UnexpectedType( - val actualType: StaticType, - val expectedTypes: Set, + val actualType: PType, + val expectedTypes: Set, ) : PlanningProblemDetails(ProblemSeverity.ERROR, { - "Unexpected type $actualType, expected one of ${expectedTypes.joinToString()}" + "Unexpected type $actualType, expected one of ${expectedTypes.joinToString { it.toString() }}" }) data class UnknownFunction( val identifier: String, - val args: List, + val args: List, ) : PlanningProblemDetails(ProblemSeverity.ERROR, { val types = args.joinToString { "<${it.toString().lowercase()}>" } "Unknown function `$identifier($types)" @@ -194,12 +195,12 @@ internal open class PlanningProblemDetails( ) data class IncompatibleTypesForOp( - val actualTypes: List, + val actualTypes: List, val operator: String, ) : PlanningProblemDetails( severity = ProblemSeverity.ERROR, - messageFormatter = { "${actualTypes.joinToString()} is/are incompatible data types for the '$operator' operator." } + messageFormatter = { "${actualTypes.joinToString { it.toString() }} is/are incompatible data types for the '$operator' operator." } ) data class UnresolvedExcludeExprRoot(val root: String) : diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ProblemGenerator.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ProblemGenerator.kt index 7c1f55ffa9..10d2f93d27 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ProblemGenerator.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ProblemGenerator.kt @@ -10,6 +10,8 @@ import org.partiql.planner.internal.ir.Rex import org.partiql.planner.internal.ir.rex import org.partiql.planner.internal.ir.rexOpErr import org.partiql.planner.internal.ir.rexOpMissing +import org.partiql.planner.internal.typer.CompilerType +import org.partiql.types.PType import org.partiql.types.StaticType import org.partiql.planner.internal.ir.Identifier as InternalIdentifier @@ -39,17 +41,17 @@ internal object ProblemGenerator { ) } - fun missingRex(causes: List, problem: Problem): Rex = - rex(StaticType.ANY, rexOpMissing(problem, causes)) + fun missingRex(causes: List, problem: Problem, type: CompilerType = CompilerType(PType.typeDynamic(), isMissingValue = true)): Rex = + rex(type, rexOpMissing(problem, causes)) - fun missingRex(causes: Rex.Op, problem: Problem): Rex = - rex(StaticType.ANY, rexOpMissing(problem, listOf(causes))) + fun missingRex(causes: Rex.Op, problem: Problem, type: CompilerType = CompilerType(PType.typeDynamic(), isMissingValue = true)): Rex = + rex(type, rexOpMissing(problem, listOf(causes))) fun errorRex(causes: List, problem: Problem): Rex = - rex(StaticType.ANY, rexOpErr(problem, causes)) + rex(CompilerType(PType.typeDynamic(), isMissingValue = true), rexOpErr(problem, causes)) fun errorRex(trace: Rex.Op, problem: Problem): Rex = - rex(StaticType.ANY, rexOpErr(problem, listOf(trace))) + rex(CompilerType(PType.typeDynamic(), isMissingValue = true), rexOpErr(problem, listOf(trace))) private fun InternalIdentifier.debug(): String = when (this) { is InternalIdentifier.Qualified -> (listOf(root.debug()) + steps.map { it.debug() }).joinToString(".") @@ -60,15 +62,32 @@ internal object ProblemGenerator { } fun undefinedFunction(identifier: InternalIdentifier, args: List, location: ProblemLocation = UNKNOWN_PROBLEM_LOCATION): Problem = + problem(location, PlanningProblemDetails.UnknownFunction(identifier.debug(), args.map { PType.fromStaticType(it) })) + + fun undefinedFunction( + args: List, + identifier: org.partiql.planner.internal.ir.Identifier, + location: ProblemLocation = UNKNOWN_PROBLEM_LOCATION + ): Problem = problem(location, PlanningProblemDetails.UnknownFunction(identifier.debug(), args)) fun undefinedFunction(identifier: String, args: List, location: ProblemLocation = UNKNOWN_PROBLEM_LOCATION): Problem = + problem(location, PlanningProblemDetails.UnknownFunction(identifier, args.map { PType.fromStaticType(it) })) + + fun undefinedFunction(args: List, identifier: String, location: ProblemLocation = UNKNOWN_PROBLEM_LOCATION): Problem = problem(location, PlanningProblemDetails.UnknownFunction(identifier, args)) fun undefinedVariable(id: Identifier, inScopeVariables: Set = emptySet(), location: ProblemLocation = UNKNOWN_PROBLEM_LOCATION): Problem = problem(location, PlanningProblemDetails.UndefinedVariable(id, inScopeVariables)) fun incompatibleTypesForOp(actualTypes: List, operator: String, location: ProblemLocation = UNKNOWN_PROBLEM_LOCATION): Problem = + problem(location, PlanningProblemDetails.IncompatibleTypesForOp(actualTypes.map { PType.fromStaticType(it) }, operator)) + + fun incompatibleTypesForOp( + operator: String, + actualTypes: List, + location: ProblemLocation = UNKNOWN_PROBLEM_LOCATION + ): Problem = problem(location, PlanningProblemDetails.IncompatibleTypesForOp(actualTypes, operator)) fun unresolvedExcludedExprRoot(root: InternalIdentifier, location: ProblemLocation = UNKNOWN_PROBLEM_LOCATION): Problem = @@ -81,6 +100,9 @@ internal object ProblemGenerator { problem(location, PlanningProblemDetails.ExpressionAlwaysReturnsMissing(reason)) fun unexpectedType(actualType: StaticType, expectedTypes: Set, location: ProblemLocation = UNKNOWN_PROBLEM_LOCATION): Problem = + problem(location, PlanningProblemDetails.UnexpectedType(PType.fromStaticType(actualType), expectedTypes.map { PType.fromStaticType(it) }.toSet())) + + fun unexpectedType(actualType: PType, expectedTypes: Set, location: ProblemLocation = UNKNOWN_PROBLEM_LOCATION): Problem = problem(location, PlanningProblemDetails.UnexpectedType(actualType, expectedTypes)) fun compilerError(message: String, location: ProblemLocation = UNKNOWN_PROBLEM_LOCATION): Problem = diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/casts/CastTable.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/casts/CastTable.kt index c3d7e9b15f..581c460502 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/casts/CastTable.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/casts/CastTable.kt @@ -2,36 +2,11 @@ package org.partiql.planner.internal.casts import org.partiql.planner.internal.ir.Ref import org.partiql.planner.internal.ir.Ref.Cast +import org.partiql.planner.internal.typer.CompilerType +import org.partiql.types.PType +import org.partiql.types.PType.Kind import org.partiql.value.PartiQLValueExperimental import org.partiql.value.PartiQLValueType -import org.partiql.value.PartiQLValueType.ANY -import org.partiql.value.PartiQLValueType.BAG -import org.partiql.value.PartiQLValueType.BINARY -import org.partiql.value.PartiQLValueType.BLOB -import org.partiql.value.PartiQLValueType.BOOL -import org.partiql.value.PartiQLValueType.BYTE -import org.partiql.value.PartiQLValueType.CHAR -import org.partiql.value.PartiQLValueType.CLOB -import org.partiql.value.PartiQLValueType.DATE -import org.partiql.value.PartiQLValueType.DECIMAL -import org.partiql.value.PartiQLValueType.DECIMAL_ARBITRARY -import org.partiql.value.PartiQLValueType.FLOAT32 -import org.partiql.value.PartiQLValueType.FLOAT64 -import org.partiql.value.PartiQLValueType.INT -import org.partiql.value.PartiQLValueType.INT16 -import org.partiql.value.PartiQLValueType.INT32 -import org.partiql.value.PartiQLValueType.INT64 -import org.partiql.value.PartiQLValueType.INT8 -import org.partiql.value.PartiQLValueType.INTERVAL -import org.partiql.value.PartiQLValueType.LIST -import org.partiql.value.PartiQLValueType.MISSING -import org.partiql.value.PartiQLValueType.NULL -import org.partiql.value.PartiQLValueType.SEXP -import org.partiql.value.PartiQLValueType.STRING -import org.partiql.value.PartiQLValueType.STRUCT -import org.partiql.value.PartiQLValueType.SYMBOL -import org.partiql.value.PartiQLValueType.TIME -import org.partiql.value.PartiQLValueType.TIMESTAMP /** * A place to model type relationships (for now this is to answer CAST inquiries). @@ -41,47 +16,38 @@ import org.partiql.value.PartiQLValueType.TIMESTAMP */ @OptIn(PartiQLValueExperimental::class) internal class CastTable private constructor( - private val types: Array, - private val graph: Array>, + private val types: Array, + private val graph: Array>, ) { - private fun relationships(): Sequence = sequence { - for (t1 in types) { - for (t2 in types) { - val r = graph[t1][t2] - if (r != null) { - yield(r) - } - } + fun get(operand: PType, target: PType): Cast? { + val i = operand.kind.ordinal + val j = target.kind.ordinal + return when (graph[i][j]) { + Status.YES, Status.MODIFIED -> Cast(CompilerType(operand), CompilerType(target), Ref.Cast.Safety.COERCION, isNullable = true) + Status.NO -> null } } - fun get(operand: PartiQLValueType, target: PartiQLValueType): Cast? { - val i = operand.ordinal - val j = target.ordinal - return graph[i][j] - } + private operator fun Array.get(t: PartiQLValueType): T = get(t.ordinal) /** - * Returns the CAST function if exists, else null. + * This represents the Y, M, and N in the table listed in SQL:1999 Section 6.22. */ - fun lookupCoercion(operand: PartiQLValueType, target: PartiQLValueType): Cast? { - val i = operand.ordinal - val j = target.ordinal - val cast = graph[i][j] ?: return null - return if (cast.safety == Cast.Safety.COERCION) cast else null + internal enum class Status { + YES, + NO, + MODIFIED } - private operator fun Array.get(t: PartiQLValueType): T = get(t.ordinal) - companion object { - private val N = PartiQLValueType.values().size + private val N = Kind.values().size private operator fun Array.set(t: PartiQLValueType, value: T): Unit = this.set(t.ordinal, value) - private fun PartiQLValueType.relationships(block: RelationshipBuilder.() -> Unit): Array { - return with(RelationshipBuilder(this)) { + private fun relationships(block: RelationshipBuilder.() -> Unit): Array { + return with(RelationshipBuilder()) { block() build() } @@ -94,236 +60,247 @@ internal class CastTable private constructor( */ @JvmStatic val partiql: CastTable = run { - val types = PartiQLValueType.values() - val graph = arrayOfNulls>(N) + val types = Kind.values() + val graph = arrayOfNulls>(N) for (type in types) { // initialize all with empty relationships - graph[type] = arrayOfNulls(N) - } - graph[ANY] = ANY.relationships { - coercion(ANY) - PartiQLValueType.values().filterNot { it == ANY }.forEach { - unsafe(it) - } + graph[type.ordinal] = Array(N) { Status.NO } } - graph[NULL] = NULL.relationships { - PartiQLValueType.values().filterNot { it == ANY || it == MISSING }.forEach { - coercion(it, isNullable = true) + graph[Kind.DYNAMIC.ordinal] = relationships { + cast(Kind.DYNAMIC) + Kind.values().filterNot { it == Kind.DYNAMIC }.forEach { + cast(it) } } - graph[MISSING] = MISSING.relationships { - coercion(MISSING) + graph[Kind.BOOL.ordinal] = relationships { + cast(Kind.BOOL) + cast(Kind.TINYINT) + cast(Kind.SMALLINT) + cast(Kind.INT) + cast(Kind.BIGINT) + cast(Kind.INT_ARBITRARY) + cast(Kind.DECIMAL) + cast(Kind.DECIMAL_ARBITRARY) + cast(Kind.REAL) + cast(Kind.DOUBLE_PRECISION) + cast(Kind.CHAR) + cast(Kind.STRING) + cast(Kind.VARCHAR) + cast(Kind.SYMBOL) } - graph[BOOL] = BOOL.relationships { - coercion(BOOL) - explicit(INT8) - explicit(INT16) - explicit(INT32) - explicit(INT64) - explicit(INT) - explicit(DECIMAL) - explicit(DECIMAL_ARBITRARY) - explicit(FLOAT32) - explicit(FLOAT64) - explicit(CHAR) - explicit(STRING) - explicit(SYMBOL) + graph[Kind.TINYINT.ordinal] = relationships { + cast(Kind.BOOL) + cast(Kind.TINYINT) + cast(Kind.SMALLINT) + cast(Kind.INT) + cast(Kind.BIGINT) + cast(Kind.INT_ARBITRARY) + cast(Kind.DECIMAL) + cast(Kind.DECIMAL_ARBITRARY) + cast(Kind.REAL) + cast(Kind.DOUBLE_PRECISION) + cast(Kind.STRING) + cast(Kind.VARCHAR) + cast(Kind.SYMBOL) } - graph[INT8] = INT8.relationships { - explicit(BOOL) - coercion(INT8) - coercion(INT16) - coercion(INT32) - coercion(INT64) - coercion(INT) - explicit(DECIMAL) - coercion(DECIMAL_ARBITRARY) - coercion(FLOAT32) - coercion(FLOAT64) - explicit(STRING) - explicit(SYMBOL) + graph[Kind.SMALLINT.ordinal] = relationships { + cast(Kind.BOOL) + cast(Kind.TINYINT) + cast(Kind.SMALLINT) + cast(Kind.INT) + cast(Kind.BIGINT) + cast(Kind.INT_ARBITRARY) + cast(Kind.DECIMAL) + cast(Kind.DECIMAL_ARBITRARY) + cast(Kind.REAL) + cast(Kind.DOUBLE_PRECISION) + cast(Kind.STRING) + cast(Kind.VARCHAR) + cast(Kind.SYMBOL) } - graph[INT16] = INT16.relationships { - explicit(BOOL) - unsafe(INT8) - coercion(INT16) - coercion(INT32) - coercion(INT64) - coercion(INT) - explicit(DECIMAL) - coercion(DECIMAL_ARBITRARY) - coercion(FLOAT32) - coercion(FLOAT64) - explicit(STRING) - explicit(SYMBOL) + graph[Kind.INT.ordinal] = relationships { + cast(Kind.BOOL) + cast(Kind.TINYINT) + cast(Kind.SMALLINT) + cast(Kind.INT) + cast(Kind.BIGINT) + cast(Kind.INT_ARBITRARY) + cast(Kind.DECIMAL) + cast(Kind.DECIMAL_ARBITRARY) + cast(Kind.REAL) + cast(Kind.DOUBLE_PRECISION) + cast(Kind.STRING) + cast(Kind.VARCHAR) + cast(Kind.SYMBOL) } - graph[INT32] = INT32.relationships { - explicit(BOOL) - unsafe(INT8) - unsafe(INT16) - coercion(INT32) - coercion(INT64) - coercion(INT) - explicit(DECIMAL) - coercion(DECIMAL_ARBITRARY) - coercion(FLOAT32) - coercion(FLOAT64) - explicit(STRING) - explicit(SYMBOL) + graph[Kind.BIGINT.ordinal] = relationships { + cast(Kind.BOOL) + cast(Kind.TINYINT) + cast(Kind.SMALLINT) + cast(Kind.INT) + cast(Kind.BIGINT) + cast(Kind.INT_ARBITRARY) + cast(Kind.DECIMAL) + cast(Kind.DECIMAL_ARBITRARY) + cast(Kind.REAL) + cast(Kind.DOUBLE_PRECISION) + cast(Kind.STRING) + cast(Kind.VARCHAR) + cast(Kind.SYMBOL) } - graph[INT64] = INT64.relationships { - explicit(BOOL) - unsafe(INT8) - unsafe(INT16) - unsafe(INT32) - coercion(INT64) - coercion(INT) - explicit(DECIMAL) - coercion(DECIMAL_ARBITRARY) - coercion(FLOAT32) - coercion(FLOAT64) - explicit(STRING) - explicit(SYMBOL) + graph[Kind.INT_ARBITRARY.ordinal] = relationships { + cast(Kind.BOOL) + cast(Kind.TINYINT) + cast(Kind.SMALLINT) + cast(Kind.INT) + cast(Kind.BIGINT) + cast(Kind.INT_ARBITRARY) + cast(Kind.DECIMAL) + cast(Kind.DECIMAL_ARBITRARY) + cast(Kind.REAL) + cast(Kind.DOUBLE_PRECISION) + cast(Kind.STRING) + cast(Kind.VARCHAR) + cast(Kind.SYMBOL) } - graph[INT] = INT.relationships { - explicit(BOOL) - unsafe(INT8) - unsafe(INT16) - unsafe(INT32) - unsafe(INT64) - coercion(INT) - explicit(DECIMAL) - coercion(DECIMAL_ARBITRARY) - coercion(FLOAT32) - coercion(FLOAT64) - explicit(STRING) - explicit(SYMBOL) + graph[Kind.DECIMAL.ordinal] = relationships { + cast(Kind.BOOL) + cast(Kind.TINYINT) + cast(Kind.SMALLINT) + cast(Kind.INT) + cast(Kind.BIGINT) + cast(Kind.INT_ARBITRARY) + cast(Kind.DECIMAL) + cast(Kind.DECIMAL_ARBITRARY) + cast(Kind.REAL) + cast(Kind.DOUBLE_PRECISION) + cast(Kind.STRING) + cast(Kind.VARCHAR) + cast(Kind.SYMBOL) } - graph[DECIMAL] = DECIMAL.relationships { - explicit(BOOL) - unsafe(INT8) - unsafe(INT16) - unsafe(INT32) - unsafe(INT64) - unsafe(INT) - coercion(DECIMAL) - coercion(DECIMAL_ARBITRARY) - explicit(FLOAT32) - explicit(FLOAT64) - explicit(STRING) - explicit(SYMBOL) + graph[Kind.DECIMAL_ARBITRARY.ordinal] = relationships { + cast(Kind.BOOL) + cast(Kind.TINYINT) + cast(Kind.SMALLINT) + cast(Kind.INT) + cast(Kind.BIGINT) + cast(Kind.INT_ARBITRARY) + cast(Kind.DECIMAL) + cast(Kind.DECIMAL_ARBITRARY) + cast(Kind.REAL) + cast(Kind.DOUBLE_PRECISION) + cast(Kind.STRING) + cast(Kind.VARCHAR) + cast(Kind.SYMBOL) } - graph[DECIMAL_ARBITRARY] = DECIMAL_ARBITRARY.relationships { - explicit(BOOL) - unsafe(INT8) - unsafe(INT16) - unsafe(INT32) - unsafe(INT64) - unsafe(INT) - coercion(DECIMAL) - coercion(DECIMAL_ARBITRARY) - explicit(FLOAT32) - explicit(FLOAT64) - explicit(STRING) - explicit(SYMBOL) + graph[Kind.REAL.ordinal] = relationships { + cast(Kind.BOOL) + cast(Kind.TINYINT) + cast(Kind.SMALLINT) + cast(Kind.INT) + cast(Kind.BIGINT) + cast(Kind.INT_ARBITRARY) + cast(Kind.DECIMAL) + cast(Kind.DECIMAL_ARBITRARY) + cast(Kind.REAL) + cast(Kind.DOUBLE_PRECISION) + cast(Kind.STRING) + cast(Kind.VARCHAR) + cast(Kind.SYMBOL) } - graph[FLOAT32] = FLOAT32.relationships { - explicit(BOOL) - unsafe(INT8) - unsafe(INT16) - unsafe(INT32) - unsafe(INT64) - unsafe(INT) - unsafe(DECIMAL) - coercion(DECIMAL_ARBITRARY) - coercion(FLOAT32) - coercion(FLOAT64) - explicit(STRING) - explicit(SYMBOL) + graph[Kind.DOUBLE_PRECISION.ordinal] = relationships { + cast(Kind.BOOL) + cast(Kind.TINYINT) + cast(Kind.SMALLINT) + cast(Kind.INT) + cast(Kind.BIGINT) + cast(Kind.INT_ARBITRARY) + cast(Kind.DECIMAL) + cast(Kind.DECIMAL_ARBITRARY) + cast(Kind.REAL) + cast(Kind.DOUBLE_PRECISION) + cast(Kind.STRING) + cast(Kind.VARCHAR) + cast(Kind.SYMBOL) } - graph[FLOAT64] = FLOAT64.relationships { - explicit(BOOL) - unsafe(INT8) - unsafe(INT16) - unsafe(INT32) - unsafe(INT64) - unsafe(INT) - unsafe(DECIMAL) - coercion(DECIMAL_ARBITRARY) - unsafe(FLOAT32) - coercion(FLOAT64) - explicit(STRING) - explicit(SYMBOL) + graph[Kind.CHAR.ordinal] = relationships { + cast(Kind.BOOL) + cast(Kind.CHAR) + cast(Kind.STRING) + cast(Kind.VARCHAR) + cast(Kind.SYMBOL) } - graph[CHAR] = CHAR.relationships { - explicit(BOOL) - coercion(CHAR) - coercion(STRING) - coercion(SYMBOL) + graph[Kind.STRING.ordinal] = relationships { + cast(Kind.BOOL) + cast(Kind.TINYINT) + cast(Kind.SMALLINT) + cast(Kind.INT) + cast(Kind.BIGINT) + cast(Kind.INT_ARBITRARY) + cast(Kind.STRING) + cast(Kind.VARCHAR) + cast(Kind.SYMBOL) + cast(Kind.CLOB) } - graph[STRING] = STRING.relationships { - explicit(BOOL) - unsafe(INT8) - unsafe(INT16) - unsafe(INT32) - unsafe(INT64) - unsafe(INT) - coercion(STRING) - explicit(SYMBOL) - coercion(CLOB) + graph[Kind.VARCHAR.ordinal] = relationships { + cast(Kind.BOOL) + cast(Kind.TINYINT) + cast(Kind.SMALLINT) + cast(Kind.INT) + cast(Kind.BIGINT) + cast(Kind.INT_ARBITRARY) + cast(Kind.STRING) + cast(Kind.VARCHAR) + cast(Kind.SYMBOL) + cast(Kind.CLOB) } - graph[SYMBOL] = SYMBOL.relationships { - explicit(BOOL) - coercion(STRING) - coercion(SYMBOL) - coercion(CLOB) + graph[Kind.SYMBOL.ordinal] = relationships { + cast(Kind.BOOL) + cast(Kind.STRING) + cast(Kind.VARCHAR) + cast(Kind.SYMBOL) + cast(Kind.CLOB) } - graph[CLOB] = CLOB.relationships { - coercion(CLOB) + graph[Kind.CLOB.ordinal] = relationships { + cast(Kind.CLOB) } - graph[BINARY] = arrayOfNulls(N) - graph[BYTE] = arrayOfNulls(N) - graph[BLOB] = arrayOfNulls(N) - graph[DATE] = arrayOfNulls(N) - graph[TIME] = arrayOfNulls(N) - graph[TIMESTAMP] = arrayOfNulls(N) - graph[INTERVAL] = arrayOfNulls(N) - graph[BAG] = BAG.relationships { - coercion(BAG) + graph[Kind.BLOB.ordinal] = Array(N) { Status.NO } + graph[Kind.DATE.ordinal] = Array(N) { Status.NO } + graph[Kind.TIME_WITH_TZ.ordinal] = Array(N) { Status.NO } + graph[Kind.TIME_WITHOUT_TZ.ordinal] = Array(N) { Status.NO } + graph[Kind.TIMESTAMP_WITH_TZ.ordinal] = Array(N) { Status.NO } + graph[Kind.TIMESTAMP_WITHOUT_TZ.ordinal] = Array(N) { Status.NO } + graph[Kind.BAG.ordinal] = relationships { + cast(Kind.BAG) } - graph[LIST] = LIST.relationships { - coercion(BAG) - coercion(SEXP) - coercion(LIST) + graph[Kind.LIST.ordinal] = relationships { + cast(Kind.BAG) + cast(Kind.SEXP) + cast(Kind.LIST) } - graph[SEXP] = SEXP.relationships { - coercion(BAG) - coercion(SEXP) - coercion(LIST) + graph[Kind.SEXP.ordinal] = relationships { + cast(Kind.BAG) + cast(Kind.SEXP) + cast(Kind.LIST) } - graph[STRUCT] = STRUCT.relationships { - coercion(STRUCT) + graph[Kind.STRUCT.ordinal] = relationships { + cast(Kind.STRUCT) } CastTable(types, graph.requireNoNulls()) } } - private class RelationshipBuilder(val operand: PartiQLValueType) { + /** + * TODO: Add another method to support [Status.MODIFIED]. See the cast table at SQL:1999 Section 6.22 + */ + private class RelationshipBuilder { - private val relationships = arrayOfNulls(N) + private val relationships = Array(N) { Status.NO } fun build() = relationships - fun coercion(target: PartiQLValueType, isNullable: Boolean = false) { - relationships[target] = Cast(operand, target, Ref.Cast.Safety.COERCION, isNullable) - } - - fun explicit(target: PartiQLValueType, isNullable: Boolean = false) { - relationships[target] = Cast(operand, target, Ref.Cast.Safety.EXPLICIT, isNullable) - } - - fun unsafe(target: PartiQLValueType, isNullable: Boolean = false) { - relationships[target] = Cast(operand, target, Ref.Cast.Safety.UNSAFE, isNullable) + fun cast(target: Kind) { + relationships[target.ordinal] = Status.YES } } } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/casts/Coercions.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/casts/Coercions.kt new file mode 100644 index 0000000000..5c8a944e02 --- /dev/null +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/casts/Coercions.kt @@ -0,0 +1,220 @@ +package org.partiql.planner.internal.casts + +import org.partiql.planner.internal.ir.Ref +import org.partiql.planner.internal.typer.CompilerType +import org.partiql.types.Field +import org.partiql.types.PType +import org.partiql.types.PType.Kind + +/** + * Important SQL Definitions: + * - assignable: The characteristic of a data type that permits a value of that data type to be + * assigned to a site of a specified data type. + */ +internal object Coercions { + + fun get(input: PType, target: PType): Ref.Cast? { + return getCoercion(input, target) + } + + /** + * Remaining coercions from SQL:1999: + * - Values corresponding to the binary data type are mutually assignable. + * - Values corresponding to the data types BIT and BIT VARYING are always mutually comparable and + * are mutually assignable. + * - Values of type interval are mutually assignable only if the source and target of the assignment are + * both year-month intervals or if they are both day-time intervals. + * - Values corresponding to user-defined types are discussed in Subclause 4.8.4, ‘‘User-defined type + * comparison and assignment’’. + */ + private fun getCoercion(input: PType, target: PType): Ref.Cast? { + return when { + isAssignable(input, target) -> coercion(input, target) + else -> null + } + } + + private val TYPES_NUMBER = setOf( + Kind.TINYINT, + Kind.SMALLINT, + Kind.INT, + Kind.BIGINT, + Kind.INT_ARBITRARY, + Kind.REAL, + Kind.DOUBLE_PRECISION, + Kind.DECIMAL, + Kind.DECIMAL_ARBITRARY + ) + + private val TYPES_TEXT = setOf( + Kind.CHAR, + Kind.VARCHAR, + Kind.STRING, + Kind.CLOB, + Kind.SYMBOL + ) + + private val TYPES_COLLECTION = setOf( + Kind.LIST, + Kind.SEXP, + Kind.BAG + ) + + private fun isAssignable(input: PType, target: PType): Boolean { + return areAssignableNumberTypes(input, target) || + areAssignableTextTypes(input, target) || + areAssignableBooleanTypes(input, target) || + areAssignableDateTimeTypes(input, target) || + areAssignableCollectionTypes(input, target) || + areAssignableStructuralTypes(input, target) || + areAssignableDynamicTypes(target) + } + + /** + * NOT specified by SQL:1999. We assume that we can coerce a collection of one type to another if the subtype + * of each collection is assignable. + */ + private fun areAssignableCollectionTypes(input: PType, target: PType): Boolean { + return input.kind in TYPES_COLLECTION && target.kind in TYPES_COLLECTION && isAssignable(input.typeParameter, target.typeParameter) + } + + /** + * NOT specified by SQL:1999. We assume that we can statically coerce anything to DYNAMIC. However, note that + * CAST( AS DYNAMIC) is NEVER inserted. We check for the use of DYNAMIC at function resolution. This is merely + * for the [PType.getTypeParameter] and [PType.getFields] + */ + private fun areAssignableDynamicTypes(target: PType): Boolean { + return target.kind == Kind.DYNAMIC + } + + /** + * NOT completely specified by SQL:1999. + * + * From SQL:1999: + * ``` + * Values corresponding to row types are mutually assignable if and only if both have the same degree + * and every field in one row type is mutually assignable to the field in the same ordinal position of + * the other row type. Values corresponding to row types are mutually comparable if and only if both + * have the same degree and every field in one row type is mutually comparable to the field in the + * same ordinal position of the other row type. + * ``` + */ + private fun areAssignableStructuralTypes(input: PType, target: PType): Boolean { + return when { + input.kind == Kind.ROW && target.kind == Kind.ROW -> fieldsAreAssignable(input.fields!!.toList(), target.fields!!.toList()) + input.kind == Kind.STRUCT && target.kind == Kind.ROW -> when (input.fields) { + null -> true + else -> namedFieldsAreAssignableUnordered(input.fields!!.toList(), target.fields!!.toList()) + } + input.kind == Kind.ROW && target.kind == Kind.STRUCT -> when (target.fields) { + null -> true + else -> namedFieldsAreAssignableUnordered(input.fields!!.toList(), target.fields!!.toList()) + } + input.kind == Kind.STRUCT && target.kind == Kind.STRUCT -> when { + input.fields == null || target.fields == null -> true + else -> fieldsAreAssignable(input.fields!!.toList(), target.fields!!.toList()) + } + else -> false + } + } + + private fun fieldsAreAssignable(input: List, target: List): Boolean { + if (input.size != target.size) { return false } + val iIter = input.iterator() + val tIter = target.iterator() + while (iIter.hasNext()) { + val iField = iIter.next() + val tField = tIter.next() + if (!isAssignable(iField.type, tField.type)) { + return false + } + } + return true + } + + /** + * This is a PartiQL extension. We assume that structs/rows with the same field names may be assignable + * if all names match AND types are assignable. + */ + private fun namedFieldsAreAssignableUnordered(input: List, target: List): Boolean { + if (input.size != target.size) { return false } + val inputSorted = input.sortedBy { it.name } + val targetSorted = target.sortedBy { it.name } + val iIter = inputSorted.iterator() + val tIter = targetSorted.iterator() + while (iIter.hasNext()) { + val iField = iIter.next() + val tField = tIter.next() + if (iField.name != tField.name) { + return false + } + if (!isAssignable(iField.type, tField.type)) { + return false + } + } + return true + } + + /** + * From SQL:1999: + * ``` + * Values of the data types NUMERIC, DECIMAL, INTEGER, SMALLINT, FLOAT, REAL, and + * DOUBLE PRECISION are numbers and are all mutually comparable and mutually assignable. + * ``` + */ + private fun areAssignableNumberTypes(input: PType, target: PType): Boolean { + return input.kind in TYPES_NUMBER && target.kind in TYPES_NUMBER + } + + /** + * From SQL:1999: + * ``` + * Values corresponding to the data type boolean are always mutually comparable and are mutually + * assignable. + * ``` + */ + private fun areAssignableBooleanTypes(input: PType, target: PType): Boolean { + return input.kind == Kind.BOOL && target.kind == Kind.BOOL + } + + /** + * From SQL:1999: + * ``` + * Values corresponding to the data types CHARACTER, CHARACTER VARYING, and CHARACTER + * LARGE OBJECT are mutually assignable if and only if they are taken from the same character + * repertoire. (For this implementation, we shall assume that all text types share the same + * character repertoire.) + * ``` + */ + private fun areAssignableTextTypes(input: PType, target: PType): Boolean { + return input.kind in TYPES_TEXT && target.kind in TYPES_TEXT + } + + /** + * From SQL:1999: + * ``` + * Values of type datetime are mutually assignable only if the source and target of the assignment are + * both of type DATE, or both of type TIME (regardless whether WITH TIME ZONE or WITHOUT + * TIME ZONE is specified or implicit), or both of type TIMESTAMP (regardless whether WITH TIME + * ZONE or WITHOUT TIME ZONE is specified or implicit) + * ``` + */ + private fun areAssignableDateTimeTypes(input: PType, target: PType): Boolean { + val i = input.kind + val t = target.kind + return when { + i == Kind.DATE && t == Kind.DATE -> true + (i == Kind.TIME_WITH_TZ || i == Kind.TIME_WITHOUT_TZ) && (t == Kind.TIME_WITH_TZ || t == Kind.TIME_WITHOUT_TZ) -> true + (i == Kind.TIMESTAMP_WITH_TZ || i == Kind.TIMESTAMP_WITHOUT_TZ) && (t == Kind.TIMESTAMP_WITH_TZ || t == Kind.TIMESTAMP_WITHOUT_TZ) -> true + else -> false + } + } + + private fun explicit(input: PType, target: PType): Ref.Cast { + return Ref.Cast(CompilerType(input), CompilerType(target), Ref.Cast.Safety.EXPLICIT, isNullable = true) + } + + private fun coercion(input: PType, target: PType): Ref.Cast { + return Ref.Cast(CompilerType(input), CompilerType(target), Ref.Cast.Safety.EXPLICIT, isNullable = true) + } +} diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/Nodes.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/Nodes.kt index 7c4b5230e0..9860a132f4 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/Nodes.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/Nodes.kt @@ -70,13 +70,12 @@ import org.partiql.planner.internal.ir.builder.RexOpVarLocalBuilder import org.partiql.planner.internal.ir.builder.RexOpVarUnresolvedBuilder import org.partiql.planner.internal.ir.builder.StatementQueryBuilder import org.partiql.planner.internal.ir.visitor.PlanVisitor +import org.partiql.planner.internal.typer.CompilerType import org.partiql.spi.fn.AggSignature import org.partiql.spi.fn.FnExperimental import org.partiql.spi.fn.FnSignature -import org.partiql.types.StaticType import org.partiql.value.PartiQLValue import org.partiql.value.PartiQLValueExperimental -import org.partiql.value.PartiQLValueType import kotlin.Boolean import kotlin.Char import kotlin.Int @@ -124,7 +123,7 @@ internal sealed class Ref : PlanNode() { internal data class Obj( @JvmField internal val catalog: String, @JvmField internal val path: List, - @JvmField internal val type: StaticType, + @JvmField internal val type: CompilerType, ) : Ref() { public override val children: List = emptyList() @@ -167,8 +166,8 @@ internal sealed class Ref : PlanNode() { } internal data class Cast( - @JvmField internal val input: PartiQLValueType, - @JvmField internal val target: PartiQLValueType, + @JvmField internal val input: CompilerType, + @JvmField internal val target: CompilerType, @JvmField internal val safety: Safety, @JvmField internal val isNullable: Boolean, ) : PlanNode() { @@ -258,7 +257,7 @@ internal sealed class Identifier : PlanNode() { } internal data class Rex( - @JvmField internal val type: StaticType, + @JvmField internal val type: CompilerType, @JvmField internal val op: Op, ) : PlanNode() { public override val children: List by lazy { @@ -442,7 +441,7 @@ internal data class Rex( } internal data class Unresolved( - @JvmField internal val target: PartiQLValueType, + @JvmField internal val target: CompilerType, @JvmField internal val arg: Rex, ) : Cast() { public override val children: List by lazy { @@ -1438,7 +1437,7 @@ internal data class Rel( internal data class Binding( @JvmField internal val name: String, - @JvmField internal val type: StaticType, + @JvmField internal val type: CompilerType, ) : PlanNode() { public override val children: List = emptyList() diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RelConverter.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RelConverter.kt index 5d82f4acbc..c1c62bb092 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RelConverter.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RelConverter.kt @@ -69,7 +69,8 @@ import org.partiql.planner.internal.ir.rexOpSelect import org.partiql.planner.internal.ir.rexOpStruct import org.partiql.planner.internal.ir.rexOpStructField import org.partiql.planner.internal.ir.rexOpVarLocal -import org.partiql.types.StaticType +import org.partiql.planner.internal.typer.CompilerType +import org.partiql.types.PType import org.partiql.value.PartiQLValueExperimental import org.partiql.value.boolValue import org.partiql.value.int32Value @@ -95,7 +96,7 @@ internal object RelConverter { is Select.Pivot -> { val key = projection.key.toRex(env) val value = projection.value.toRex(env) - val type = (StaticType.STRUCT) + val type = (STRUCT) val op = rexOpPivot(key, value, rel) rex(type, op) } @@ -105,11 +106,11 @@ internal object RelConverter { "Expected SELECT VALUE's input to have a single binding. " + "However, it contained: ${rel.type.schema.map { it.name }}." } - val constructor = rex(StaticType.ANY, rexOpVarLocal(0, 0)) + val constructor = rex(ANY, rexOpVarLocal(0, 0)) val op = rexOpSelect(constructor, rel) val type = when (rel.type.props.contains(Rel.Prop.ORDERED)) { - true -> (StaticType.LIST) - else -> (StaticType.BAG) + true -> (LIST) + else -> (BAG) } rex(type, op) } @@ -222,7 +223,7 @@ internal object RelConverter { else -> { val index = relBinding( name = i.symbol, - type = (StaticType.INT) + type = (INT) ) convertScanIndexed(rex, binding, index) } @@ -233,7 +234,7 @@ internal object RelConverter { null -> error("AST not normalized, missing AT alias on UNPIVOT $node") else -> relBinding( name = at.symbol, - type = (StaticType.STRING) + type = (STRING) ) } convertUnpivot(rex, k = atAlias, v = binding) @@ -252,7 +253,7 @@ internal object RelConverter { val rhs = visitFrom(node.rhs, nil) val schema = lhs.type.schema + rhs.type.schema // Note: This gets more specific in PlanTyper. It is only used to find binding names here. val props = emptySet() - val condition = node.condition?.let { RexConverter.apply(it, env) } ?: rex(StaticType.BOOL, rexOpLit(boolValue(true))) + val condition = node.condition?.let { RexConverter.apply(it, env) } ?: rex(BOOL, rexOpLit(boolValue(true))) val joinType = when (node.type) { From.Join.Type.LEFT_OUTER, From.Join.Type.LEFT -> Rel.Op.Join.Type.LEFT From.Join.Type.RIGHT_OUTER, From.Join.Type.RIGHT -> Rel.Op.Join.Type.RIGHT @@ -360,7 +361,7 @@ internal object RelConverter { val calls = aggregations.mapIndexed { i, expr -> val binding = relBinding( name = syntheticAgg(i), - type = (StaticType.ANY), + type = (ANY), ) schema.add(binding) val args = expr.args.map { arg -> arg.toRex(env) } @@ -388,15 +389,15 @@ internal object RelConverter { // Add GROUP_AS aggregation groupBy?.let { gb -> gb.asAlias?.let { groupAs -> - val binding = relBinding(groupAs.symbol, StaticType.ANY) + val binding = relBinding(groupAs.symbol, ANY) schema.add(binding) val fields = input.type.schema.mapIndexed { bindingIndex, currBinding -> rexOpStructField( - k = rex(StaticType.STRING, rexOpLit(stringValue(currBinding.name))), - v = rex(StaticType.ANY, rexOpVarLocal(0, bindingIndex)) + k = rex(STRING, rexOpLit(stringValue(currBinding.name))), + v = rex(ANY, rexOpVarLocal(0, bindingIndex)) ) } - val arg = listOf(rex(StaticType.ANY, rexOpStruct(fields))) + val arg = listOf(rex(ANY, rexOpStruct(fields))) calls.add(relOpAggregateCallUnresolved("group_as", Rel.Op.Aggregate.SetQuantifier.ALL, arg)) } } @@ -408,7 +409,7 @@ internal object RelConverter { } val binding = relBinding( name = it.asAlias!!.symbol, - type = (StaticType.ANY) + type = (ANY) ) schema.add(binding) it.expr.toRex(env) @@ -572,17 +573,17 @@ internal object RelConverter { // private fun convertGroupAs(name: String, from: From): Binding { // val fields = from.bindings().map { n -> // Plan.field( - // name = Plan.rexLit(ionString(n), StaticType.STRING), - // value = Plan.rexId(n, Case.SENSITIVE, Rex.Id.Qualifier.UNQUALIFIED, type = StaticType.STRUCT) + // name = Plan.rexLit(ionString(n), STRING), + // value = Plan.rexId(n, Case.SENSITIVE, Rex.Id.Qualifier.UNQUALIFIED, type = STRUCT) // ) // } // return Plan.binding( // name = name, // value = Plan.rexAgg( // id = "group_as", - // args = listOf(Plan.rexTuple(fields, StaticType.STRUCT)), + // args = listOf(Plan.rexTuple(fields, STRUCT)), // modifier = Rex.Agg.Modifier.ALL, - // type = StaticType.STRUCT + // type = STRUCT // ) // ) // } @@ -649,4 +650,12 @@ internal object RelConverter { } private fun syntheticAgg(i: Int) = "\$agg_$i" + + private val ANY: CompilerType = CompilerType(PType.typeDynamic()) + private val BOOL: CompilerType = CompilerType(PType.typeBool()) + private val STRING: CompilerType = CompilerType(PType.typeString()) + private val STRUCT: CompilerType = CompilerType(PType.typeStruct()) + private val BAG: CompilerType = CompilerType(PType.typeBag()) + private val LIST: CompilerType = CompilerType(PType.typeList()) + private val INT: CompilerType = CompilerType(PType.typeIntArbitrary()) } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RexConverter.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RexConverter.kt index e2b6b1a24b..330397672e 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RexConverter.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RexConverter.kt @@ -54,10 +54,11 @@ import org.partiql.planner.internal.ir.rexOpSubquery import org.partiql.planner.internal.ir.rexOpTupleUnion import org.partiql.planner.internal.ir.rexOpVarLocal import org.partiql.planner.internal.ir.rexOpVarUnresolved -import org.partiql.planner.internal.typer.toStaticType -import org.partiql.types.StaticType +import org.partiql.planner.internal.typer.CompilerType +import org.partiql.planner.internal.typer.PlanTyper.Companion.toCType +import org.partiql.types.PType +import org.partiql.value.MissingValue import org.partiql.value.PartiQLValueExperimental -import org.partiql.value.PartiQLValueType import org.partiql.value.StringValue import org.partiql.value.boolValue import org.partiql.value.int32Value @@ -83,7 +84,11 @@ internal object RexConverter { throw IllegalArgumentException("unsupported rex $node") override fun visitExprLit(node: Expr.Lit, context: Env): Rex { - val type = node.value.type.toStaticType() + val type = CompilerType( + _delegate = PType.fromPartiQLValueType(node.value.type), + isNullValue = node.value.isNull, + isMissingValue = node.value is MissingValue + ) val op = rexOpLit(node.value) return rex(type, op) } @@ -92,7 +97,7 @@ internal object RexConverter { val value = PartiQLValueIonReaderBuilder .standard().build(node.value).read() - val type = value.type.toStaticType() + val type = CompilerType(PType.fromPartiQLValueType(value.type)) return rex(type, rexOpLit(value)) } @@ -119,7 +124,7 @@ internal object RexConverter { true -> { val select = rex.op as Rex.Op.Select rex( - StaticType.ANY, + CompilerType(PType.typeDynamic()), rexOpSubquery( constructor = select.constructor, rel = select.rel, @@ -132,7 +137,7 @@ internal object RexConverter { } override fun visitExprVar(node: Expr.Var, context: Env): Rex { - val type = (StaticType.ANY) + val type = (ANY) val identifier = AstToPlan.convert(node.identifier) val scope = when (node.scope) { Expr.Var.Scope.DEFAULT -> Rex.Op.Var.Scope.DEFAULT @@ -143,7 +148,7 @@ internal object RexConverter { } override fun visitExprUnary(node: Expr.Unary, context: Env): Rex { - val type = (StaticType.ANY) + val type = (ANY) // Args val arg = visitExprCoerce(node.expr, context) val args = listOf(arg) @@ -154,7 +159,7 @@ internal object RexConverter { } override fun visitExprBinary(node: Expr.Binary, context: Env): Rex { - val type = (StaticType.ANY) + val type = (ANY) val args = when (node.op) { Expr.Binary.Op.LT, Expr.Binary.Op.GT, Expr.Binary.Op.LTE, Expr.Binary.Op.GTE, @@ -245,7 +250,7 @@ internal object RexConverter { 0 -> root to node.steps else -> { val newRoot = rex( - StaticType.ANY, + ANY, rexOpVarUnresolved(mergeIdentifiers(op.identifier, identifierSteps), op.scope) ) val newSteps = node.steps.subList(identifierSteps.size, node.steps.size) @@ -340,7 +345,7 @@ internal object RexConverter { op } } - rex(StaticType.ANY, path) + rex(ANY, path) } if (fromList.size == 0) return pathNavi @@ -348,7 +353,7 @@ internal object RexConverter { val schema = acc.type.schema + scan.type.schema val props = emptySet() val type = relType(schema, props) - rel(type, relOpJoin(acc, scan, rex(StaticType.BOOL, rexOpLit(boolValue(true))), Rel.Op.Join.Type.INNER)) + rel(type, relOpJoin(acc, scan, rex(BOOL, rexOpLit(boolValue(true))), Rel.Op.Join.Type.INNER)) } // compute the ref used by select construct @@ -363,7 +368,7 @@ internal object RexConverter { else -> throw IllegalStateException() } val op = rexOpSelect(constructor, fromNode) - return rex(StaticType.ANY, op) + return rex(ANY, op) } /** @@ -392,7 +397,7 @@ internal object RexConverter { val schema = listOf( relBinding( name = "_k$index", // fresh variable - type = StaticType.STRING + type = STRING ), relBinding( name = "_v$index", // fresh variable @@ -404,10 +409,10 @@ internal object RexConverter { return rel(relType, relOpUnpivot(path)) } - private fun rexString(str: String) = rex(StaticType.STRING, rexOpLit(stringValue(str))) + private fun rexString(str: String) = rex(STRING, rexOpLit(stringValue(str))) override fun visitExprCall(node: Expr.Call, context: Env): Rex { - val type = (StaticType.ANY) + val type = (ANY) // Fn val id = AstToPlan.convert(node.function) if (id is Identifier.Symbol && id.symbol.equals("TUPLEUNION", ignoreCase = true)) { @@ -421,14 +426,14 @@ internal object RexConverter { } private fun visitExprCallTupleUnion(node: Expr.Call, context: Env): Rex { - val type = (StaticType.STRUCT) + val type = (STRUCT) val args = node.args.map { visitExprCoerce(it, context) }.toMutableList() val op = rexOpTupleUnion(args) return rex(type, op) } override fun visitExprCase(node: Expr.Case, context: Env) = plan { - val type = (StaticType.ANY) + val type = (ANY) val rex = when (node.expr) { null -> null else -> visitExprCoerce(node.expr!!, context) // match `rex @@ -451,7 +456,7 @@ internal object RexConverter { }.toMutableList() val defaultRex = when (val default = node.default) { - null -> rex(type = StaticType.ANY, op = rexOpLit(value = nullValue())) + null -> rex(type = ANY, op = rexOpLit(value = nullValue())) else -> visitExprCoerce(default, context) } val op = rexOpCase(branches = branches, default = defaultRex) @@ -460,11 +465,11 @@ internal object RexConverter { override fun visitExprCollection(node: Expr.Collection, context: Env): Rex { val type = when (node.type) { - Expr.Collection.Type.BAG -> StaticType.BAG - Expr.Collection.Type.ARRAY -> StaticType.LIST - Expr.Collection.Type.VALUES -> StaticType.LIST - Expr.Collection.Type.LIST -> StaticType.LIST - Expr.Collection.Type.SEXP -> StaticType.SEXP + Expr.Collection.Type.BAG -> BAG + Expr.Collection.Type.ARRAY -> LIST + Expr.Collection.Type.VALUES -> LIST + Expr.Collection.Type.LIST -> LIST + Expr.Collection.Type.SEXP -> SEXP } val values = node.values.map { visitExprCoerce(it, context) } val op = rexOpCollection(values) @@ -472,7 +477,7 @@ internal object RexConverter { } override fun visitExprStruct(node: Expr.Struct, context: Env): Rex { - val type = (StaticType.STRUCT) + val type = (STRUCT) val fields = node.fields.map { val k = visitExprCoerce(it.name, context) val v = visitExprCoerce(it.value, context) @@ -488,7 +493,7 @@ internal object RexConverter { * NOT? LIKE ( ESCAPE )? */ override fun visitExprLike(node: Expr.Like, ctx: Env): Rex { - val type = StaticType.BOOL + val type = BOOL // Args val arg0 = visitExprCoerce(node.value, ctx) val arg1 = visitExprCoerce(node.pattern, ctx) @@ -509,7 +514,7 @@ internal object RexConverter { * NOT? BETWEEN AND */ override fun visitExprBetween(node: Expr.Between, ctx: Env): Rex = plan { - val type = StaticType.BOOL + val type = BOOL // Args val arg0 = visitExprCoerce(node.value, ctx) val arg1 = visitExprCoerce(node.from, ctx) @@ -536,7 +541,7 @@ internal object RexConverter { * */ override fun visitExprInCollection(node: Expr.InCollection, ctx: Env): Rex { - val type = StaticType.BOOL + val type = BOOL // Args val arg0 = visitExprCoerce(node.lhs, ctx) val arg1 = visitExpr(node.rhs, ctx) // !! don't insert scalar subquery coercions @@ -554,7 +559,7 @@ internal object RexConverter { * IS ? */ override fun visitExprIsType(node: Expr.IsType, ctx: Env): Rex { - val type = StaticType.BOOL + val type = BOOL // arg val arg0 = visitExprCoerce(node.value, ctx) @@ -608,7 +613,7 @@ internal object RexConverter { } override fun visitExprCoalesce(node: Expr.Coalesce, ctx: Env): Rex { - val type = StaticType.ANY + val type = ANY val args = node.args.map { arg -> visitExprCoerce(arg, ctx) } @@ -617,7 +622,7 @@ internal object RexConverter { } override fun visitExprNullIf(node: Expr.NullIf, ctx: Env): Rex { - val type = StaticType.ANY + val type = ANY val value = visitExprCoerce(node.value, ctx) val nullifier = visitExprCoerce(node.nullifier, ctx) val op = rexOpNullif(value, nullifier) @@ -628,10 +633,10 @@ internal object RexConverter { * SUBSTRING( (FROM (FOR )?)? ) */ override fun visitExprSubstring(node: Expr.Substring, ctx: Env): Rex { - val type = StaticType.ANY + val type = ANY // Args val arg0 = visitExprCoerce(node.value, ctx) - val arg1 = node.start?.let { visitExprCoerce(it, ctx) } ?: rex(StaticType.INT, rexOpLit(int64Value(1))) + val arg1 = node.start?.let { visitExprCoerce(it, ctx) } ?: rex(INT, rexOpLit(int64Value(1))) val arg2 = node.length?.let { visitExprCoerce(it, ctx) } // Call Variants val call = when (arg2) { @@ -645,7 +650,7 @@ internal object RexConverter { * POSITION( IN ) */ override fun visitExprPosition(node: Expr.Position, ctx: Env): Rex { - val type = StaticType.ANY + val type = ANY // Args val arg0 = visitExprCoerce(node.lhs, ctx) val arg1 = visitExprCoerce(node.rhs, ctx) @@ -658,7 +663,7 @@ internal object RexConverter { * TRIM([LEADING|TRAILING|BOTH]? ( FROM)? ) */ override fun visitExprTrim(node: Expr.Trim, ctx: Env): Rex { - val type = StaticType.TEXT + val type = STRING // Args val arg0 = visitExprCoerce(node.value, ctx) val arg1 = node.chars?.let { visitExprCoerce(it, ctx) } @@ -703,75 +708,92 @@ internal object RexConverter { val cv = visitExprCoerce(node.value, ctx) val sp = visitExprCoerce(node.start, ctx) val rs = visitExprCoerce(node.overlay, ctx) - val sl = node.length?.let { visitExprCoerce(it, ctx) } ?: rex(StaticType.ANY, call("char_length", rs)) + val sl = node.length?.let { visitExprCoerce(it, ctx) } ?: rex(ANY, call("char_length", rs)) val p1 = rex( - StaticType.ANY, + ANY, call( "substring", cv, - rex(StaticType.INT4, rexOpLit(int32Value(1))), - rex(StaticType.ANY, call("minus", sp, rex(StaticType.INT4, rexOpLit(int32Value(1))))) + rex(INT4, rexOpLit(int32Value(1))), + rex(ANY, call("minus", sp, rex(INT4, rexOpLit(int32Value(1))))) ) ) - val p2 = rex(StaticType.ANY, call("concat", p1, rs)) + val p2 = rex(ANY, call("concat", p1, rs)) return rex( - StaticType.ANY, + ANY, call( "concat", p2, - rex(StaticType.ANY, call("substring", cv, rex(StaticType.ANY, call("plus", sp, sl)))) + rex(ANY, call("substring", cv, rex(ANY, call("plus", sp, sl)))) ) ) } override fun visitExprExtract(node: Expr.Extract, ctx: Env): Rex { val call = call("extract_${node.field.name.lowercase()}", visitExprCoerce(node.source, ctx)) - return rex(StaticType.ANY, call) + return rex(ANY, call) } - // TODO: Ignoring type parameter now override fun visitExprCast(node: Expr.Cast, ctx: Env): Rex { - val type = node.asType + val type = visitType(node.asType) val arg = visitExprCoerce(node.value, ctx) - val target = when (type) { - is Type.NullType -> error("Cannot cast any value to NULL") - is Type.Missing -> error("Cannot cast any value to MISSING") - is Type.Bool -> PartiQLValueType.BOOL - is Type.Tinyint -> PartiQLValueType.INT8 - is Type.Smallint, is Type.Int2 -> PartiQLValueType.INT16 - is Type.Int4 -> PartiQLValueType.INT32 - is Type.Bigint, is Type.Int8 -> PartiQLValueType.INT64 - is Type.Int -> PartiQLValueType.INT - is Type.Real -> PartiQLValueType.FLOAT64 - is Type.Float32 -> PartiQLValueType.FLOAT32 - is Type.Float64 -> PartiQLValueType.FLOAT64 - is Type.Decimal -> if (type.scale != null) PartiQLValueType.DECIMAL else PartiQLValueType.DECIMAL_ARBITRARY - is Type.Numeric -> if (type.scale != null) PartiQLValueType.DECIMAL else PartiQLValueType.DECIMAL_ARBITRARY - is Type.Char -> PartiQLValueType.CHAR - is Type.Varchar -> PartiQLValueType.STRING - is Type.String -> PartiQLValueType.STRING - is Type.Symbol -> PartiQLValueType.SYMBOL - is Type.Bit -> PartiQLValueType.BINARY - is Type.BitVarying -> PartiQLValueType.BINARY - is Type.ByteString -> PartiQLValueType.BINARY - is Type.Blob -> PartiQLValueType.BLOB - is Type.Clob -> PartiQLValueType.CLOB - is Type.Date -> PartiQLValueType.DATE - is Type.Time -> PartiQLValueType.TIME - is Type.TimeWithTz -> PartiQLValueType.TIME - is Type.Timestamp -> PartiQLValueType.TIMESTAMP - is Type.TimestampWithTz -> PartiQLValueType.TIMESTAMP - is Type.Interval -> PartiQLValueType.INTERVAL - is Type.Bag -> PartiQLValueType.BAG - is Type.Sexp -> PartiQLValueType.SEXP - is Type.Any -> PartiQLValueType.ANY + return rex(ANY, rexOpCastUnresolved(type, arg)) + } + + private fun visitType(type: Type): CompilerType { + return when (type) { + is Type.NullType -> error("Casting to NULL is not supported.") + is Type.Missing -> error("Casting to MISSING is not supported.") + is Type.Bool -> PType.typeBool() + is Type.Tinyint -> PType.typeTinyInt() + is Type.Smallint, is Type.Int2 -> PType.typeSmallInt() + is Type.Int4 -> PType.typeInt() + is Type.Bigint, is Type.Int8 -> PType.typeBigInt() + is Type.Int -> PType.typeIntArbitrary() + is Type.Real -> PType.typeReal() + is Type.Float32 -> PType.typeReal() + is Type.Float64 -> PType.typeDoublePrecision() + is Type.Decimal -> when { + type.precision == null && type.scale == null -> PType.typeDecimalArbitrary() + type.precision != null && type.scale != null -> PType.typeDecimal(type.precision!!, type.scale!!) + type.precision != null && type.scale == null -> PType.typeDecimal(type.precision!!, 0) + else -> error("Precision can never be null while scale is specified.") + } + + is Type.Numeric -> when { + type.precision == null && type.scale == null -> PType.typeDecimalArbitrary() + type.precision != null && type.scale != null -> PType.typeDecimal(type.precision!!, type.scale!!) + type.precision != null && type.scale == null -> PType.typeDecimal(type.precision!!, 0) + else -> error("Precision can never be null while scale is specified.") + } + + is Type.Char -> PType.typeChar(type.length ?: 255) // TODO: What is default? + is Type.Varchar -> error("VARCHAR is not supported yet.") + is Type.String -> PType.typeString() + is Type.Symbol -> PType.typeSymbol() + is Type.Bit -> error("BIT is not supported yet.") + is Type.BitVarying -> error("BIT VARYING is not supported yet.") + is Type.ByteString -> error("BINARY is not supported yet.") + is Type.Blob -> PType.typeBlob(type.length ?: Int.MAX_VALUE) + is Type.Clob -> PType.typeClob(type.length ?: Int.MAX_VALUE) + is Type.Date -> PType.typeDate() + is Type.Time -> PType.typeTimeWithoutTZ(type.precision ?: 6) + is Type.TimeWithTz -> PType.typeTimeWithTZ(type.precision ?: 6) + is Type.Timestamp -> PType.typeTimestampWithoutTZ(type.precision ?: 6) + is Type.TimestampWithTz -> PType.typeTimestampWithTZ(type.precision ?: 6) + is Type.Interval -> error("INTERVAL is not supported yet.") + is Type.Bag -> PType.typeBag() + is Type.Sexp -> PType.typeSexp() + is Type.Any -> PType.typeDynamic() is Type.Custom -> TODO("Custom type not supported ") - is Type.List -> PartiQLValueType.LIST - is Type.Tuple -> PartiQLValueType.STRUCT - is Type.Array -> PartiQLValueType.LIST - is Type.Struct -> PartiQLValueType.STRUCT - } - return rex(StaticType.ANY, rexOpCastUnresolved(target, arg)) + is Type.List -> PType.typeList() + is Type.Tuple -> PType.typeStruct() + is Type.Array -> when (type.type) { + null -> PType.typeList() + else -> PType.typeList(visitType(type.type!!)) + } + is Type.Struct -> PType.typeStruct() + }.toCType() } override fun visitExprCanCast(node: Expr.CanCast, ctx: Env): Rex { @@ -783,7 +805,7 @@ internal object RexConverter { } override fun visitExprDateAdd(node: Expr.DateAdd, ctx: Env): Rex { - val type = StaticType.TIMESTAMP + val type = TIMESTAMP // Args val arg0 = visitExprCoerce(node.lhs, ctx) val arg1 = visitExprCoerce(node.rhs, ctx) @@ -797,7 +819,7 @@ internal object RexConverter { } override fun visitExprDateDiff(node: Expr.DateDiff, ctx: Env): Rex { - val type = StaticType.TIMESTAMP + val type = TIMESTAMP // Args val arg0 = visitExprCoerce(node.lhs, ctx) val arg1 = visitExprCoerce(node.rhs, ctx) @@ -811,7 +833,7 @@ internal object RexConverter { } override fun visitExprSessionAttribute(node: Expr.SessionAttribute, ctx: Env): Rex { - val type = StaticType.ANY + val type = ANY val fn = node.attribute.name.lowercase() val call = call(fn) return rex(type, call) @@ -821,11 +843,11 @@ internal object RexConverter { override fun visitExprBagOp(node: Expr.BagOp, ctx: Env): Rex { val lhs = Rel( - type = Rel.Type(listOf(Rel.Binding("_0", StaticType.ANY)), props = emptySet()), + type = Rel.Type(listOf(Rel.Binding("_0", ANY)), props = emptySet()), op = Rel.Op.Scan(visitExpr(node.lhs, ctx)) ) val rhs = Rel( - type = Rel.Type(listOf(Rel.Binding("_1", StaticType.ANY)), props = emptySet()), + type = Rel.Type(listOf(Rel.Binding("_1", ANY)), props = emptySet()), op = Rel.Op.Scan(visitExpr(node.rhs, ctx)) ) val quantifier = when (node.type.setq) { @@ -839,14 +861,14 @@ internal object RexConverter { SetOp.Type.INTERSECT -> Rel.Op.Set.Intersect(quantifier, lhs, rhs, isOuter) } val rel = Rel( - type = Rel.Type(listOf(Rel.Binding("_0", StaticType.ANY)), props = emptySet()), + type = Rel.Type(listOf(Rel.Binding("_0", ANY)), props = emptySet()), op = op ) return Rex( - type = StaticType.ANY, + type = ANY, op = Rex.Op.Select( constructor = Rex( - StaticType.ANY, + ANY, Rex.Op.Var.Unresolved(Identifier.Symbol("_0", Identifier.CaseSensitivity.SENSITIVE), Rex.Op.Var.Scope.LOCAL) ), rel = rel @@ -860,7 +882,7 @@ internal object RexConverter { val name = Expr.Unary.Op.NOT.name val id = identifierSymbol(name.lowercase(), Identifier.CaseSensitivity.SENSITIVE) // wrap - val arg = rex(StaticType.BOOL, call) + val arg = rex(BOOL, call) // rewrite call return rexOpCallUnresolved(id, listOf(arg)) } @@ -874,6 +896,17 @@ internal object RexConverter { return rexOpCallUnresolved(id, args.toList()) } - private fun Int?.toRex() = rex(StaticType.INT4, rexOpLit(int32Value(this))) + private fun Int?.toRex() = rex(INT4, rexOpLit(int32Value(this))) + + private val ANY: CompilerType = CompilerType(PType.typeDynamic()) + private val BOOL: CompilerType = CompilerType(PType.typeBool()) + private val STRING: CompilerType = CompilerType(PType.typeString()) + private val STRUCT: CompilerType = CompilerType(PType.typeStruct()) + private val BAG: CompilerType = CompilerType(PType.typeBag()) + private val LIST: CompilerType = CompilerType(PType.typeList()) + private val SEXP: CompilerType = CompilerType(PType.typeSexp()) + private val INT: CompilerType = CompilerType(PType.typeIntArbitrary()) + private val INT4: CompilerType = CompilerType(PType.typeInt()) + private val TIMESTAMP: CompilerType = CompilerType(PType.typeTimestampWithoutTZ(6)) } } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/CompilerType.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/CompilerType.kt new file mode 100644 index 0000000000..722db8d4e8 --- /dev/null +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/CompilerType.kt @@ -0,0 +1,77 @@ +package org.partiql.planner.internal.typer + +import org.partiql.types.PType +import org.partiql.types.PType.Kind + +/** + * This is largely just to show that the planner does not need to use [_delegate] ([PType]) directly. Using an + * internal representation, we can leverage the APIs of [PType] while carrying some additional information such + * as [isMissingValue]. + * + * @property isNullValue denotes that the expression will always return the null value. + * @property isMissingValue denotes that the expression will always return the missing value. + */ +internal class CompilerType( + private val _delegate: PType, + // Note: This is an experimental property. + internal val isNullValue: Boolean = false, + // Note: This is an experimental property. + internal val isMissingValue: Boolean = false +) : PType { + override fun getKind(): Kind = _delegate.kind + override fun getFields(): MutableCollection { + return _delegate.fields.map { field -> + when (field) { + is Field -> field + else -> Field(field.name, CompilerType(field.type)) + } + }.toMutableList() + } + + override fun getLength(): Int { + return _delegate.length + } + + override fun getPrecision(): Int = _delegate.precision + override fun getScale(): Int = _delegate.scale + override fun getTypeParameter(): CompilerType { + return when (val p = _delegate.typeParameter) { + is CompilerType -> p + else -> CompilerType(p) + } + } + + override fun equals(other: Any?): Boolean { + return _delegate == other + } + + override fun hashCode(): Int { + return _delegate.hashCode() + } + + override fun toString(): String { + return _delegate.toString() + } + + internal class Field( + private val _name: String, + private val _type: CompilerType + ) : org.partiql.types.Field { + override fun getName(): String = _name + override fun getType(): CompilerType = _type + + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (other !is org.partiql.types.Field) return false + val nameMatches = _name == other.name + val typeMatches = _type == other.type + return nameMatches && typeMatches + } + + override fun hashCode(): Int { + var result = _name.hashCode() + result = 31 * result + _type.hashCode() + return result + } + } +} diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/DynamicTyper.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/DynamicTyper.kt index 2adf627763..61473b917c 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/DynamicTyper.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/DynamicTyper.kt @@ -1,44 +1,41 @@ -@file:OptIn(PartiQLValueExperimental::class) - package org.partiql.planner.internal.typer import org.partiql.planner.internal.ir.Rex -import org.partiql.types.StaticType +import org.partiql.planner.internal.typer.PlanTyper.Companion.anyOf +import org.partiql.planner.internal.typer.PlanTyper.Companion.toCType +import org.partiql.types.PType +import org.partiql.types.PType.Kind import org.partiql.value.MissingValue -import org.partiql.value.NullValue +import org.partiql.value.PartiQLValue import org.partiql.value.PartiQLValueExperimental -import org.partiql.value.PartiQLValueType -import org.partiql.value.PartiQLValueType.ANY -import org.partiql.value.PartiQLValueType.BAG -import org.partiql.value.PartiQLValueType.BINARY -import org.partiql.value.PartiQLValueType.BLOB -import org.partiql.value.PartiQLValueType.BOOL -import org.partiql.value.PartiQLValueType.BYTE -import org.partiql.value.PartiQLValueType.CHAR -import org.partiql.value.PartiQLValueType.CLOB -import org.partiql.value.PartiQLValueType.DATE -import org.partiql.value.PartiQLValueType.DECIMAL -import org.partiql.value.PartiQLValueType.DECIMAL_ARBITRARY -import org.partiql.value.PartiQLValueType.FLOAT32 -import org.partiql.value.PartiQLValueType.FLOAT64 -import org.partiql.value.PartiQLValueType.INT -import org.partiql.value.PartiQLValueType.INT16 -import org.partiql.value.PartiQLValueType.INT32 -import org.partiql.value.PartiQLValueType.INT64 -import org.partiql.value.PartiQLValueType.INT8 -import org.partiql.value.PartiQLValueType.INTERVAL -import org.partiql.value.PartiQLValueType.LIST -import org.partiql.value.PartiQLValueType.SEXP -import org.partiql.value.PartiQLValueType.STRING -import org.partiql.value.PartiQLValueType.STRUCT -import org.partiql.value.PartiQLValueType.SYMBOL -import org.partiql.value.PartiQLValueType.TIME -import org.partiql.value.PartiQLValueType.TIMESTAMP +import org.partiql.value.bagValue +import org.partiql.value.blobValue +import org.partiql.value.boolValue +import org.partiql.value.charValue +import org.partiql.value.clobValue +import org.partiql.value.dateValue +import org.partiql.value.decimalValue +import org.partiql.value.float32Value +import org.partiql.value.float64Value +import org.partiql.value.int16Value +import org.partiql.value.int32Value +import org.partiql.value.int64Value +import org.partiql.value.int8Value +import org.partiql.value.intValue +import org.partiql.value.listValue +import org.partiql.value.missingValue +import org.partiql.value.nullValue +import org.partiql.value.sexpValue +import org.partiql.value.stringValue +import org.partiql.value.structValue +import org.partiql.value.symbolValue +import org.partiql.value.timeValue +import org.partiql.value.timestampValue /** * Graph of super types for quick lookup because we don't have a tree. */ -internal typealias SuperGraph = Array> +internal typealias SuperGraph = Array> /** * For lack of a better name, this is the "dynamic typer" which implements the typing rules of SQL-99 9.3. @@ -50,30 +47,38 @@ internal typealias SuperGraph = Array> * To calculate the type of an "aggregation" create a new instance and "accumulate" each possible type. * This is a pain with StaticType... */ -@OptIn(PartiQLValueExperimental::class) internal class DynamicTyper { - private var supertype: PartiQLValueType? = null - private var args = mutableListOf() - - private val types = mutableSetOf() + private var supertype: CompilerType? = null + private var args = mutableListOf() + private val types = mutableListOf() /** * Adds the [rex]'s [Rex.type] to the typing accumulator (if the [rex] is not a literal NULL/MISSING). */ fun accumulate(rex: Rex) { - when (rex.isLiteralAbsent()) { - true -> accumulateUnknown() - false -> accumulate(rex.type) + when { + rex.isLiteralNull() || rex.isLiteralMissing() -> accumulateUnknown(rex) + else -> accumulateConcrete(rex) } } /** - * Checks for literal NULL or MISSING. + * Checks for literal NULL + */ + @OptIn(PartiQLValueExperimental::class) + private fun Rex.isLiteralNull(): Boolean { + val op = this.op + return op is Rex.Op.Lit && op.value.isNull + } + + /** + * Checks for literal MISSING */ - private fun Rex.isLiteralAbsent(): Boolean { + @OptIn(PartiQLValueExperimental::class) + private fun Rex.isLiteralMissing(): Boolean { val op = this.op - return op is Rex.Op.Lit && (op.value is MissingValue || op.value is NullValue) + return op is Rex.Op.Lit && op.value is MissingValue } /** @@ -81,69 +86,58 @@ internal class DynamicTyper { * inferred. This function ignores literal null/missing values, yet adds their indices to know how to return the * mapping. */ - private fun accumulateUnknown() { - args.add(ANY) + private fun accumulateUnknown(rex: Rex) { + args.add(rex) } /** * This adds non-absent types (aka not NULL / MISSING literals) to the typing accumulator. * @param type */ - private fun accumulate(type: StaticType) { - val flatType = type.flatten() - if (flatType == StaticType.ANY) { - types.add(flatType) - args.add(ANY) - calculate(ANY) - return - } - val allTypes = flatType.allTypes - when (allTypes.size) { - 0 -> { - error("This should not have happened.") - } - 1 -> { - // Had single type - val single = allTypes.first() - val singleRuntime = single.toRuntimeType() - types.add(single) - args.add(singleRuntime) - calculate(singleRuntime) - } - else -> { - // Had a union; use ANY runtime - types.addAll(allTypes) - args.add(ANY) - calculate(ANY) - } - } + private fun accumulateConcrete(rex: Rex) { + types.add(rex.type) + args.add(rex) + calculate(rex.type) } /** - * Returns a pair of the return StaticType and the coercion. + * Returns a pair of the return type and the coercions. * * If the list is null, then no mapping is required. * * @return */ - fun mapping(): Pair>?> { + @OptIn(PartiQLValueExperimental::class) + fun mapping(): Pair?> { + val s = supertype ?: return CompilerType(PType.typeDynamic()) to null + val superTypeBase = s.kind // If at top supertype, then return union of all accumulated types - if (supertype == ANY) { - return StaticType.unionOf(types).flatten() to null + if (superTypeBase == Kind.DYNAMIC) { + return anyOf(types)!!.toCType() to null } // If a collection, then return union of all accumulated types as these coercion rules are not defined by SQL. - if (supertype == STRUCT || supertype == BAG || supertype == LIST || supertype == SEXP) { - return StaticType.unionOf(types) to null + if (superTypeBase in setOf(Kind.ROW, Kind.STRUCT, Kind.BAG, Kind.LIST, Kind.SEXP)) { + return anyOf(types)!!.toCType() to null } // If not initialized, then return null, missing, or null|missing. - val s = supertype ?: return StaticType.ANY to null // Otherwise, return the supertype along with the coercion mapping - val type = s.toStaticType() - val mapping = args.map { it to s } - return type to mapping + val mapping = args.map { + when { + it.isLiteralNull() -> Mapping.Replacement(Rex(s, Rex.Op.Lit(nullValue(s.kind)))) + it.isLiteralMissing() -> Mapping.Replacement(Rex(s, Rex.Op.Lit(missingValue()))) + it.type == s -> Mapping.Coercion(s) + else -> null + } + } + return s to mapping + } + + internal sealed interface Mapping { + class Replacement(val replacement: Rex) : Mapping + class Coercion(val target: CompilerType) : Mapping } - private fun calculate(type: PartiQLValueType) { + private fun calculate(type: CompilerType) { val s = supertype // Initialize if (s == null) { @@ -151,17 +145,15 @@ internal class DynamicTyper { return } // Don't bother calculating the new supertype, we've already hit `dynamic`. - if (s == ANY) return + if (s.kind == Kind.DYNAMIC) return // Lookup and set the new minimum common supertype supertype = when { - type == ANY -> type + type.kind == Kind.DYNAMIC -> type s == type -> return // skip - else -> graph[s][type] ?: ANY // lookup, if missing then go to top. + else -> graph[s.kind.ordinal][type.kind.ordinal]?.toPType() ?: CompilerType(PType.typeDynamic()) // lookup, if missing then go to top. } } - private operator fun Array.get(t: PartiQLValueType): T = get(t.ordinal) - /** * !! IMPORTANT !! * @@ -170,16 +162,14 @@ internal class DynamicTyper { */ companion object { - private operator fun Array.set(t: PartiQLValueType, value: T): Unit = this.set(t.ordinal, value) - @JvmStatic - private val N = PartiQLValueType.values().size + private val N = Kind.values().size @JvmStatic - private fun edges(vararg edges: Pair): Array { - val arr = arrayOfNulls(N) + private fun edges(vararg edges: Pair): Array { + val arr = arrayOfNulls(N) for (type in edges) { - arr[type.first] = type.second + arr[type.first.ordinal] = type.second } return arr } @@ -192,180 +182,258 @@ internal class DynamicTyper { */ @JvmStatic internal val graph: SuperGraph = run { - val graph = arrayOfNulls>(N) - for (type in PartiQLValueType.values()) { + val graph = arrayOfNulls>(N) + for (type in Kind.values()) { // initialize all with empty edges - graph[type] = arrayOfNulls(N) + graph[type.ordinal] = arrayOfNulls(N) } - graph[ANY] = edges() - graph[BOOL] = edges( - BOOL to BOOL + graph[Kind.DYNAMIC.ordinal] = edges() + graph[Kind.BOOL.ordinal] = edges( + Kind.BOOL to Kind.BOOL + ) + graph[Kind.TINYINT.ordinal] = edges( + Kind.TINYINT to Kind.TINYINT, + Kind.SMALLINT to Kind.SMALLINT, + Kind.INT to Kind.INT, + Kind.BIGINT to Kind.BIGINT, + Kind.INT_ARBITRARY to Kind.INT_ARBITRARY, + Kind.DECIMAL to Kind.DECIMAL, + Kind.DECIMAL_ARBITRARY to Kind.DECIMAL_ARBITRARY, + Kind.REAL to Kind.REAL, + Kind.DOUBLE_PRECISION to Kind.DOUBLE_PRECISION, ) - graph[INT8] = edges( - INT8 to INT8, - INT16 to INT16, - INT32 to INT32, - INT64 to INT64, - INT to INT, - DECIMAL to DECIMAL, - DECIMAL_ARBITRARY to DECIMAL_ARBITRARY, - FLOAT32 to FLOAT32, - FLOAT64 to FLOAT64, + graph[Kind.SMALLINT.ordinal] = edges( + Kind.TINYINT to Kind.SMALLINT, + Kind.SMALLINT to Kind.SMALLINT, + Kind.INT to Kind.INT, + Kind.BIGINT to Kind.BIGINT, + Kind.INT_ARBITRARY to Kind.INT_ARBITRARY, + Kind.DECIMAL to Kind.DECIMAL, + Kind.DECIMAL_ARBITRARY to Kind.DECIMAL_ARBITRARY, + Kind.REAL to Kind.REAL, + Kind.DOUBLE_PRECISION to Kind.DOUBLE_PRECISION, ) - graph[INT16] = edges( - INT8 to INT16, - INT16 to INT16, - INT32 to INT32, - INT64 to INT64, - INT to INT, - DECIMAL to DECIMAL, - DECIMAL_ARBITRARY to DECIMAL_ARBITRARY, - FLOAT32 to FLOAT32, - FLOAT64 to FLOAT64, + graph[Kind.INT.ordinal] = edges( + Kind.TINYINT to Kind.INT, + Kind.SMALLINT to Kind.INT, + Kind.INT to Kind.INT, + Kind.BIGINT to Kind.BIGINT, + Kind.INT_ARBITRARY to Kind.INT_ARBITRARY, + Kind.DECIMAL to Kind.DECIMAL, + Kind.DECIMAL_ARBITRARY to Kind.DECIMAL_ARBITRARY, + Kind.REAL to Kind.REAL, + Kind.DOUBLE_PRECISION to Kind.DOUBLE_PRECISION, ) - graph[INT32] = edges( - INT8 to INT32, - INT16 to INT32, - INT32 to INT32, - INT64 to INT64, - INT to INT, - DECIMAL to DECIMAL, - DECIMAL_ARBITRARY to DECIMAL_ARBITRARY, - FLOAT32 to FLOAT32, - FLOAT64 to FLOAT64, + graph[Kind.BIGINT.ordinal] = edges( + Kind.TINYINT to Kind.BIGINT, + Kind.SMALLINT to Kind.BIGINT, + Kind.INT to Kind.BIGINT, + Kind.BIGINT to Kind.BIGINT, + Kind.INT_ARBITRARY to Kind.INT_ARBITRARY, + Kind.DECIMAL to Kind.DECIMAL, + Kind.DECIMAL_ARBITRARY to Kind.DECIMAL_ARBITRARY, + Kind.REAL to Kind.REAL, + Kind.DOUBLE_PRECISION to Kind.DOUBLE_PRECISION, ) - graph[INT64] = edges( - INT8 to INT64, - INT16 to INT64, - INT32 to INT64, - INT64 to INT64, - INT to INT, - DECIMAL to DECIMAL, - DECIMAL_ARBITRARY to DECIMAL_ARBITRARY, - FLOAT32 to FLOAT32, - FLOAT64 to FLOAT64, + graph[Kind.INT_ARBITRARY.ordinal] = edges( + Kind.TINYINT to Kind.INT_ARBITRARY, + Kind.SMALLINT to Kind.INT_ARBITRARY, + Kind.INT to Kind.INT_ARBITRARY, + Kind.BIGINT to Kind.INT_ARBITRARY, + Kind.INT_ARBITRARY to Kind.INT_ARBITRARY, + Kind.DECIMAL to Kind.DECIMAL, + Kind.DECIMAL_ARBITRARY to Kind.DECIMAL_ARBITRARY, + Kind.REAL to Kind.REAL, + Kind.DOUBLE_PRECISION to Kind.DOUBLE_PRECISION, ) - graph[INT] = edges( - INT8 to INT, - INT16 to INT, - INT32 to INT, - INT64 to INT, - INT to INT, - DECIMAL to DECIMAL, - DECIMAL_ARBITRARY to DECIMAL_ARBITRARY, - FLOAT32 to FLOAT32, - FLOAT64 to FLOAT64, + graph[Kind.DECIMAL.ordinal] = edges( + Kind.TINYINT to Kind.DECIMAL, + Kind.SMALLINT to Kind.DECIMAL, + Kind.INT to Kind.DECIMAL, + Kind.BIGINT to Kind.DECIMAL, + Kind.INT_ARBITRARY to Kind.DECIMAL, + Kind.DECIMAL to Kind.DECIMAL, + Kind.DECIMAL_ARBITRARY to Kind.DECIMAL_ARBITRARY, + Kind.REAL to Kind.REAL, + Kind.DOUBLE_PRECISION to Kind.DOUBLE_PRECISION, ) - graph[DECIMAL] = edges( - INT8 to DECIMAL, - INT16 to DECIMAL, - INT32 to DECIMAL, - INT64 to DECIMAL, - INT to DECIMAL, - DECIMAL to DECIMAL, - DECIMAL_ARBITRARY to DECIMAL_ARBITRARY, - FLOAT32 to FLOAT32, - FLOAT64 to FLOAT64, + graph[Kind.DECIMAL_ARBITRARY.ordinal] = edges( + Kind.TINYINT to Kind.DECIMAL_ARBITRARY, + Kind.SMALLINT to Kind.DECIMAL_ARBITRARY, + Kind.INT to Kind.DECIMAL_ARBITRARY, + Kind.BIGINT to Kind.DECIMAL_ARBITRARY, + Kind.INT_ARBITRARY to Kind.DECIMAL_ARBITRARY, + Kind.DECIMAL to Kind.DECIMAL_ARBITRARY, + Kind.DECIMAL_ARBITRARY to Kind.DECIMAL_ARBITRARY, + Kind.REAL to Kind.REAL, + Kind.DOUBLE_PRECISION to Kind.DOUBLE_PRECISION, ) - graph[DECIMAL_ARBITRARY] = edges( - INT8 to DECIMAL_ARBITRARY, - INT16 to DECIMAL_ARBITRARY, - INT32 to DECIMAL_ARBITRARY, - INT64 to DECIMAL_ARBITRARY, - INT to DECIMAL_ARBITRARY, - DECIMAL to DECIMAL_ARBITRARY, - DECIMAL_ARBITRARY to DECIMAL_ARBITRARY, - FLOAT32 to FLOAT32, - FLOAT64 to FLOAT64, + graph[Kind.REAL.ordinal] = edges( + Kind.TINYINT to Kind.REAL, + Kind.SMALLINT to Kind.REAL, + Kind.INT to Kind.REAL, + Kind.BIGINT to Kind.REAL, + Kind.INT_ARBITRARY to Kind.REAL, + Kind.DECIMAL to Kind.REAL, + Kind.DECIMAL_ARBITRARY to Kind.REAL, + Kind.REAL to Kind.REAL, + Kind.DOUBLE_PRECISION to Kind.DOUBLE_PRECISION, ) - graph[FLOAT32] = edges( - INT8 to FLOAT32, - INT16 to FLOAT32, - INT32 to FLOAT32, - INT64 to FLOAT32, - INT to FLOAT32, - DECIMAL to FLOAT32, - DECIMAL_ARBITRARY to FLOAT32, - FLOAT32 to FLOAT32, - FLOAT64 to FLOAT64, + graph[Kind.DOUBLE_PRECISION.ordinal] = edges( + Kind.TINYINT to Kind.DOUBLE_PRECISION, + Kind.SMALLINT to Kind.DOUBLE_PRECISION, + Kind.INT to Kind.DOUBLE_PRECISION, + Kind.BIGINT to Kind.DOUBLE_PRECISION, + Kind.INT_ARBITRARY to Kind.DOUBLE_PRECISION, + Kind.DECIMAL to Kind.DOUBLE_PRECISION, + Kind.DECIMAL_ARBITRARY to Kind.DOUBLE_PRECISION, + Kind.REAL to Kind.DOUBLE_PRECISION, + Kind.DOUBLE_PRECISION to Kind.DOUBLE_PRECISION, ) - graph[FLOAT64] = edges( - INT8 to FLOAT64, - INT16 to FLOAT64, - INT32 to FLOAT64, - INT64 to FLOAT64, - INT to FLOAT64, - DECIMAL to FLOAT64, - DECIMAL_ARBITRARY to FLOAT64, - FLOAT32 to FLOAT64, - FLOAT64 to FLOAT64, + graph[Kind.CHAR.ordinal] = edges( + Kind.CHAR to Kind.CHAR, + Kind.STRING to Kind.STRING, + Kind.VARCHAR to Kind.STRING, + Kind.SYMBOL to Kind.STRING, + Kind.CLOB to Kind.CLOB, ) - graph[CHAR] = edges( - CHAR to CHAR, - STRING to STRING, - SYMBOL to STRING, - CLOB to CLOB, + graph[Kind.STRING.ordinal] = edges( + Kind.CHAR to Kind.STRING, + Kind.STRING to Kind.STRING, + Kind.VARCHAR to Kind.STRING, + Kind.SYMBOL to Kind.STRING, + Kind.CLOB to Kind.CLOB, ) - graph[STRING] = edges( - CHAR to STRING, - STRING to STRING, - SYMBOL to STRING, - CLOB to CLOB, + graph[Kind.VARCHAR.ordinal] = edges( + Kind.CHAR to Kind.VARCHAR, + Kind.STRING to Kind.STRING, + Kind.VARCHAR to Kind.VARCHAR, + Kind.SYMBOL to Kind.STRING, + Kind.CLOB to Kind.CLOB, ) - graph[SYMBOL] = edges( - CHAR to SYMBOL, - STRING to STRING, - SYMBOL to SYMBOL, - CLOB to CLOB, + graph[Kind.SYMBOL.ordinal] = edges( + Kind.CHAR to Kind.SYMBOL, + Kind.STRING to Kind.STRING, + Kind.VARCHAR to Kind.STRING, + Kind.SYMBOL to Kind.SYMBOL, + Kind.CLOB to Kind.CLOB, ) - graph[BINARY] = edges( - BINARY to BINARY, + graph[Kind.BLOB.ordinal] = edges( + Kind.BLOB to Kind.BLOB, ) - graph[BYTE] = edges( - BYTE to BYTE, - BLOB to BLOB, + graph[Kind.DATE.ordinal] = edges( + Kind.DATE to Kind.DATE, ) - graph[BLOB] = edges( - BYTE to BLOB, - BLOB to BLOB, + graph[Kind.CLOB.ordinal] = edges( + Kind.CHAR to Kind.CLOB, + Kind.STRING to Kind.CLOB, + Kind.VARCHAR to Kind.CLOB, + Kind.SYMBOL to Kind.CLOB, + Kind.CLOB to Kind.CLOB, ) - graph[DATE] = edges( - DATE to DATE, + graph[Kind.TIME_WITHOUT_TZ.ordinal] = edges( + Kind.TIME_WITHOUT_TZ to Kind.TIME_WITHOUT_TZ, ) - graph[CLOB] = edges( - CHAR to CLOB, - STRING to CLOB, - SYMBOL to CLOB, - CLOB to CLOB, + graph[Kind.TIME_WITH_TZ.ordinal] = edges( + Kind.TIME_WITH_TZ to Kind.TIME_WITH_TZ, ) - graph[TIME] = edges( - TIME to TIME, + graph[Kind.TIMESTAMP_WITHOUT_TZ.ordinal] = edges( + Kind.TIMESTAMP_WITHOUT_TZ to Kind.TIMESTAMP_WITHOUT_TZ, ) - graph[TIMESTAMP] = edges( - TIMESTAMP to TIMESTAMP, + graph[Kind.TIMESTAMP_WITH_TZ.ordinal] = edges( + Kind.TIMESTAMP_WITH_TZ to Kind.TIMESTAMP_WITH_TZ, ) - graph[INTERVAL] = edges( - INTERVAL to INTERVAL, + graph[Kind.LIST.ordinal] = edges( + Kind.LIST to Kind.LIST, + Kind.SEXP to Kind.SEXP, + Kind.BAG to Kind.BAG, ) - graph[LIST] = edges( - LIST to LIST, - SEXP to SEXP, - BAG to BAG, + graph[Kind.SEXP.ordinal] = edges( + Kind.LIST to Kind.SEXP, + Kind.SEXP to Kind.SEXP, + Kind.BAG to Kind.BAG, ) - graph[SEXP] = edges( - LIST to SEXP, - SEXP to SEXP, - BAG to BAG, + graph[Kind.BAG.ordinal] = edges( + Kind.LIST to Kind.BAG, + Kind.SEXP to Kind.BAG, + Kind.BAG to Kind.BAG, ) - graph[BAG] = edges( - LIST to BAG, - SEXP to BAG, - BAG to BAG, + graph[Kind.STRUCT.ordinal] = edges( + Kind.STRUCT to Kind.STRUCT, ) - graph[STRUCT] = edges( - STRUCT to STRUCT, + graph[Kind.ROW.ordinal] = edges( + Kind.ROW to Kind.ROW, ) graph.requireNoNulls() } + + /** + * TODO: We need to update the logic of this whole file. We are currently limited by not using parameters + * of types. + */ + private fun Kind.toPType(): CompilerType = when (this) { + Kind.BOOL -> PType.typeBool() + Kind.DYNAMIC -> PType.typeDynamic() + Kind.TINYINT -> PType.typeTinyInt() + Kind.SMALLINT -> PType.typeSmallInt() + Kind.INT -> PType.typeInt() + Kind.BIGINT -> PType.typeBigInt() + Kind.INT_ARBITRARY -> PType.typeIntArbitrary() + Kind.DECIMAL -> PType.typeDecimalArbitrary() // TODO: To be updated. + Kind.DECIMAL_ARBITRARY -> PType.typeDecimalArbitrary() + Kind.REAL -> PType.typeReal() + Kind.DOUBLE_PRECISION -> PType.typeDoublePrecision() + Kind.CHAR -> PType.typeChar(255) // TODO: To be updated + Kind.VARCHAR -> PType.typeVarChar(255) // TODO: To be updated + Kind.STRING -> PType.typeString() + Kind.SYMBOL -> PType.typeSymbol() + Kind.BLOB -> PType.typeBlob(Int.MAX_VALUE) // TODO: To be updated + Kind.CLOB -> PType.typeClob(Int.MAX_VALUE) // TODO: To be updated + Kind.DATE -> PType.typeDate() + Kind.TIME_WITH_TZ -> PType.typeTimeWithTZ(6) // TODO: To be updated + Kind.TIME_WITHOUT_TZ -> PType.typeTimeWithoutTZ(6) // TODO: To be updated + Kind.TIMESTAMP_WITH_TZ -> PType.typeTimestampWithTZ(6) // TODO: To be updated + Kind.TIMESTAMP_WITHOUT_TZ -> PType.typeTimestampWithoutTZ(6) // TODO: To be updated + Kind.BAG -> PType.typeBag() // TODO: To be updated + Kind.LIST -> PType.typeList() // TODO: To be updated + Kind.ROW -> PType.typeRow(emptyList()) // TODO: To be updated + Kind.SEXP -> PType.typeSexp() // TODO: To be updated + Kind.STRUCT -> PType.typeStruct() // TODO: To be updated + Kind.UNKNOWN -> PType.typeUnknown() // TODO: To be updated + }.toCType() + + @OptIn(PartiQLValueExperimental::class) + private fun nullValue(kind: Kind): PartiQLValue { + return when (kind) { + Kind.DYNAMIC -> nullValue() + Kind.BOOL -> boolValue(null) + Kind.TINYINT -> int8Value(null) + Kind.SMALLINT -> int16Value(null) + Kind.INT -> int32Value(null) + Kind.BIGINT -> int64Value(null) + Kind.INT_ARBITRARY -> intValue(null) + Kind.DECIMAL -> decimalValue(null) + Kind.DECIMAL_ARBITRARY -> decimalValue(null) + Kind.REAL -> float32Value(null) + Kind.DOUBLE_PRECISION -> float64Value(null) + Kind.CHAR -> charValue(null) + Kind.VARCHAR -> TODO("No implementation of VAR CHAR") + Kind.STRING -> stringValue(null) + Kind.SYMBOL -> symbolValue(null) + Kind.BLOB -> blobValue(null) + Kind.CLOB -> clobValue(null) + Kind.DATE -> dateValue(null) + Kind.TIME_WITH_TZ, + Kind.TIME_WITHOUT_TZ -> timeValue(null) + Kind.TIMESTAMP_WITH_TZ, + Kind.TIMESTAMP_WITHOUT_TZ -> timestampValue(null) + Kind.BAG -> bagValue(null) + Kind.LIST -> listValue(null) + Kind.ROW -> structValue(null) + Kind.SEXP -> sexpValue(null) + Kind.STRUCT -> structValue() + Kind.UNKNOWN -> nullValue() + } + } } } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt index 4b77eb5a32..33383eeef3 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt @@ -24,10 +24,8 @@ import org.partiql.planner.internal.ir.PlanNode import org.partiql.planner.internal.ir.Rel import org.partiql.planner.internal.ir.Rex import org.partiql.planner.internal.ir.Statement -import org.partiql.planner.internal.ir.identifierSymbol import org.partiql.planner.internal.ir.rel import org.partiql.planner.internal.ir.relOpAggregate -import org.partiql.planner.internal.ir.relOpAggregateCallUnresolved import org.partiql.planner.internal.ir.relOpDistinct import org.partiql.planner.internal.ir.relOpExclude import org.partiql.planner.internal.ir.relOpExcludePath @@ -42,22 +40,15 @@ import org.partiql.planner.internal.ir.relOpSort import org.partiql.planner.internal.ir.relOpUnpivot import org.partiql.planner.internal.ir.relType import org.partiql.planner.internal.ir.rex -import org.partiql.planner.internal.ir.rexOpCallDynamic -import org.partiql.planner.internal.ir.rexOpCallStatic -import org.partiql.planner.internal.ir.rexOpCaseBranch import org.partiql.planner.internal.ir.rexOpCoalesce import org.partiql.planner.internal.ir.rexOpCollection -import org.partiql.planner.internal.ir.rexOpLit import org.partiql.planner.internal.ir.rexOpNullif import org.partiql.planner.internal.ir.rexOpPathIndex import org.partiql.planner.internal.ir.rexOpPathKey -import org.partiql.planner.internal.ir.rexOpPathSymbol import org.partiql.planner.internal.ir.rexOpPivot -import org.partiql.planner.internal.ir.rexOpSelect import org.partiql.planner.internal.ir.rexOpStruct import org.partiql.planner.internal.ir.rexOpStructField import org.partiql.planner.internal.ir.rexOpSubquery -import org.partiql.planner.internal.ir.rexOpTupleUnion import org.partiql.planner.internal.ir.statementQuery import org.partiql.planner.internal.ir.util.PlanRewriter import org.partiql.planner.internal.utils.PlanUtils @@ -65,28 +56,13 @@ import org.partiql.spi.BindingCase import org.partiql.spi.BindingName import org.partiql.spi.BindingPath import org.partiql.spi.fn.FnExperimental -import org.partiql.spi.fn.FnSignature -import org.partiql.types.AnyOfType -import org.partiql.types.AnyType -import org.partiql.types.BagType -import org.partiql.types.BoolType -import org.partiql.types.CollectionType -import org.partiql.types.IntType -import org.partiql.types.ListType -import org.partiql.types.SexpType -import org.partiql.types.StaticType -import org.partiql.types.StaticType.Companion.ANY -import org.partiql.types.StaticType.Companion.BOOL -import org.partiql.types.StaticType.Companion.STRING -import org.partiql.types.StaticType.Companion.unionOf -import org.partiql.types.StringType -import org.partiql.types.StructType -import org.partiql.types.TupleConstraint +import org.partiql.types.Field +import org.partiql.types.PType +import org.partiql.types.PType.Kind import org.partiql.value.BoolValue import org.partiql.value.MissingValue import org.partiql.value.PartiQLValueExperimental import org.partiql.value.TextValue -import org.partiql.value.boolValue import org.partiql.value.stringValue import kotlin.math.max @@ -109,9 +85,101 @@ internal class PlanTyper(private val env: Env) { val root = statement.root.type(emptyList(), emptyList(), Scope.GLOBAL) return statementQuery(root) } + internal companion object { + fun PType.static(): CompilerType = CompilerType(this) + + fun anyOf(types: Collection): PType? { + val unique = types.toSet() + return when (unique.size) { + 0 -> null + 1 -> unique.first() + else -> PType.typeDynamic() + } + } + + /** + * This is specifically to collapse literals. + * + * TODO: Can this be merged with [anyOf]? Should we even allow this? + */ + fun anyOfLiterals(types: Collection): PType? { + // Grab unique + var unique: Collection = types.toSet() + if (unique.size == 0) { + return null + } else if (unique.size == 1) { + return unique.first() + } + + // Filter out UNKNOWN + unique = unique.filter { it.kind != Kind.UNKNOWN } + if (unique.size == 0) { + return PType.typeUnknown() + } else if (unique.size == 1) { + return unique.first() + } + + // Collapse Collections + if (unique.all { it.kind == Kind.LIST } || + unique.all { it.kind == Kind.BAG } || + unique.all { it.kind == Kind.SEXP } + ) { + return collapseCollection(unique, unique.first().kind) + } + // Collapse Structs + if (unique.all { it.kind == Kind.ROW }) { + return collapseRows(unique) + } + return PType.typeDynamic() + } - private companion object { - private val FUNCTIONS_HANDLING_MISSING = setOf("is_null", "is_missing", "eq", "and", "or", "not") + private fun collapseCollection(collections: Iterable, type: Kind): PType { + val typeParam = anyOfLiterals(collections.map { it.typeParameter })!! + return when (type) { + Kind.LIST -> PType.typeList(typeParam) + Kind.BAG -> PType.typeList(typeParam) + Kind.SEXP -> PType.typeList(typeParam) + else -> error("This shouldn't have happened.") + } + } + + private fun collapseRows(rows: Iterable): PType { + val firstFields = rows.first().fields!! + val fieldNames = firstFields.map { it.name } + val fieldTypes = firstFields.map { mutableListOf(it.type) } + rows.map { struct -> + val fields = struct.fields!! + if (fields.map { it.name } != fieldNames) { + return PType.typeStruct() + } + fields.forEachIndexed { index, field -> fieldTypes[index].add(field.type) } + } + val newFields = fieldTypes.mapIndexed { i, types -> Field.of(fieldNames[i], anyOfLiterals(types)!!) } + return PType.typeRow(newFields) + } + + fun anyOf(vararg types: PType): PType? { + val unique = types.toSet() + return anyOf(unique) + } + + fun PType.toCType(): CompilerType = CompilerType(this) + + fun List.toCType(): List = this.map { it.toCType() } + + fun CompilerType.isNumeric(): Boolean { + return this.kind in setOf( + Kind.INT, + Kind.INT_ARBITRARY, + Kind.BIGINT, + Kind.TINYINT, + Kind.SMALLINT, + Kind.REAL, + Kind.DOUBLE_PRECISION, + Kind.DECIMAL, + Kind.DECIMAL_ARBITRARY + ) + } } /** @@ -149,8 +217,8 @@ internal class PlanTyper(private val env: Env) { val rex = node.rex.type(emptyList(), outer, Scope.GLOBAL) // compute rel type val valueT = getElementTypeForFromSource(rex.type) - val indexT = StaticType.INT8 - val type = ctx!!.copyWithSchema(listOf(valueT, indexT)) + val indexT = PType.typeBigInt() + val type = ctx!!.copyWithSchema(listOf(valueT, indexT).toCType()) // rewrite val op = relOpScanIndexed(rex) return rel(type, op) @@ -161,25 +229,24 @@ internal class PlanTyper(private val env: Env) { */ override fun visitRelOpUnpivot(node: Rel.Op.Unpivot, ctx: Rel.Type?): Rel { val rex = node.rex.type(emptyList(), outer, Scope.GLOBAL) + val op = relOpUnpivot(rex) + val kType = PType.typeString() - val kType = STRING - val vTypes = rex.type.allTypes.map { type -> - when (type) { - is StructType -> { - if ((type.contentClosed || type.constraints.contains(TupleConstraint.Open(false))) && type.fields.isNotEmpty()) { - unionOf(type.fields.map { it.value }.toSet()).flatten() - } else { - ANY - } - } - else -> type - } + // Check Root (Dynamic) + if (rex.type.kind == Kind.DYNAMIC) { + val type = ctx!!.copyWithSchema(listOf(kType, PType.typeDynamic()).toCType()) + return rel(type, op) + } + + // Check Root + val vType = when (rex.type.kind) { + Kind.ROW -> anyOf(rex.type.fields!!.map { it.type }) ?: PType.typeDynamic() + Kind.STRUCT -> PType.typeDynamic() + else -> rex.type } - val vType = unionOf(vTypes.toSet()).flatten() // rewrite - val type = ctx!!.copyWithSchema(listOf(kType, vType)) - val op = relOpUnpivot(rex) + val type = ctx!!.copyWithSchema(listOf(kType, vType).toCType()) return rel(type, op) } @@ -263,13 +330,13 @@ internal class PlanTyper(private val env: Env) { // Compute Schema val size = max(lhs.type.schema.size, rhs.type.schema.size) val schema = List(size) { - val lhsBinding = lhs.type.schema.getOrNull(it) ?: Rel.Binding("_$it", ANY) - val rhsBinding = rhs.type.schema.getOrNull(it) ?: Rel.Binding("_$it", ANY) + val lhsBinding = lhs.type.schema.getOrNull(it) ?: Rel.Binding("_$it", CompilerType(PType.typeDynamic(), isMissingValue = true)) + val rhsBinding = rhs.type.schema.getOrNull(it) ?: Rel.Binding("_$it", CompilerType(PType.typeDynamic(), isMissingValue = true)) val bindingName = when (lhsBinding.name == rhsBinding.name) { true -> lhsBinding.name false -> "_$it" } - Rel.Binding(bindingName, unionOf(lhsBinding.type, rhsBinding.type)) + Rel.Binding(bindingName, CompilerType(anyOf(lhsBinding.type, rhsBinding.type)!!)) } val type = Rel.Type(schema, props = emptySet()) return Rel(type, node.copy(lhs = lhs, rhs = rhs)) @@ -321,7 +388,7 @@ internal class PlanTyper(private val env: Env) { if (limit.type.isNumeric().not()) { val err = ProblemGenerator.missingRex( causes = listOf(limit.op), - problem = ProblemGenerator.unexpectedType(limit.type, setOf(StaticType.INT)) + problem = ProblemGenerator.unexpectedType(limit.type, setOf(PType.typeIntArbitrary())) ) return rel(input.type, relOpLimit(input, err)) } @@ -341,7 +408,7 @@ internal class PlanTyper(private val env: Env) { if (offset.type.isNumeric().not()) { val err = ProblemGenerator.missingRex( causes = listOf(offset.op), - problem = ProblemGenerator.unexpectedType(offset.type, setOf(StaticType.INT)) + problem = ProblemGenerator.unexpectedType(offset.type, setOf(PType.typeIntArbitrary())) ) return rel(input.type, relOpLimit(input, err)) } @@ -482,7 +549,7 @@ internal class PlanTyper(private val env: Env) { val groups = node.groups.map { typer.visitRex(it, null) } // Compute schema using order (calls...groups...) - val schema = mutableListOf() + val schema = mutableListOf() schema += calls.map { it.second } schema += groups.map { it.type } @@ -502,7 +569,7 @@ internal class PlanTyper(private val env: Env) { * Types a PartiQL expression tree. For now, we ignore the pre-existing type. We assume all existing types * are simply the `any`, so we keep the new type. Ideally we can programmatically calculate the most specific type. * - * We should consider making the StaticType? parameter non-nullable. + * We should consider making the PType? parameter non-nullable. * * @property locals TypeEnv in which this rex tree is evaluated. */ @@ -510,16 +577,16 @@ internal class PlanTyper(private val env: Env) { private inner class RexTyper( private val locals: TypeEnv, private val strategy: Scope, - ) : PlanRewriter() { + ) : PlanRewriter() { - override fun visitRex(node: Rex, ctx: StaticType?): Rex = visitRexOp(node.op, node.type) as Rex + override fun visitRex(node: Rex, ctx: CompilerType?): Rex = visitRexOp(node.op, node.type) as Rex - override fun visitRexOpLit(node: Rex.Op.Lit, ctx: StaticType?): Rex { + override fun visitRexOpLit(node: Rex.Op.Lit, ctx: CompilerType?): Rex { // type comes from RexConverter return rex(ctx!!, node) } - override fun visitRexOpVarLocal(node: Rex.Op.Var.Local, ctx: StaticType?): Rex { + override fun visitRexOpVarLocal(node: Rex.Op.Var.Local, ctx: CompilerType?): Rex { val scope = locals.getScope(node.depth) assert(node.ref < scope.schema.size) { "Invalid resolved variable (var ${node.ref}, stack frame ${node.depth}) in env: $locals" @@ -528,12 +595,12 @@ internal class PlanTyper(private val env: Env) { return rex(type, node) } - override fun visitRexOpMissing(node: Rex.Op.Missing, ctx: StaticType?): PlanNode { - val type = ctx ?: ANY + override fun visitRexOpMissing(node: Rex.Op.Missing, ctx: CompilerType?): PlanNode { + val type = ctx ?: CompilerType(PType.typeDynamic(), isMissingValue = true) return rex(type, node) } - override fun visitRexOpVarUnresolved(node: Rex.Op.Var.Unresolved, ctx: StaticType?): Rex { + override fun visitRexOpVarUnresolved(node: Rex.Op.Var.Unresolved, ctx: CompilerType?): Rex { val path = node.identifier.toBindingPath() val scope = when (node.scope) { Rex.Op.Var.Scope.DEFAULT -> strategy @@ -555,110 +622,102 @@ internal class PlanTyper(private val env: Env) { return visitRex(resolvedVar, null) } - override fun visitRexOpVarGlobal(node: Rex.Op.Var.Global, ctx: StaticType?): Rex = rex(node.ref.type, node) + override fun visitRexOpVarGlobal(node: Rex.Op.Var.Global, ctx: CompilerType?): Rex = rex(node.ref.type, node) - override fun visitRexOpPathIndex(node: Rex.Op.Path.Index, ctx: StaticType?): Rex { + /** + * TODO: Create a function signature for the Rex.Op.Path.Index to get automatic coercions. + */ + override fun visitRexOpPathIndex(node: Rex.Op.Path.Index, ctx: CompilerType?): Rex { val root = visitRex(node.root, node.root.type) val key = visitRex(node.key, node.key.type) - // Check Index Type - if (!key.type.mayBeType()) { + // Check Key Type (INT or coercible to INT). TODO: Allow coercions to INT + if (key.type.kind !in setOf(Kind.TINYINT, Kind.SMALLINT, Kind.INT, Kind.BIGINT, Kind.INT_ARBITRARY)) { return ProblemGenerator.missingRex( rexOpPathIndex(root, key), ProblemGenerator.expressionAlwaysReturnsMissing("Collections must be indexed with integers, found ${key.type}") ) } - // Get Element Type(s) - val elementTypes = root.type.allTypes.mapNotNull { type -> - when (type) { - is ListType -> type.elementType - is SexpType -> type.elementType - else -> null - } + // Check if Root is DYNAMIC + if (root.type.kind == Kind.DYNAMIC) { + return Rex(CompilerType(PType.typeDynamic()), Rex.Op.Path.Index(root, key)) } - // Check that root is not literal missing - if (root.isLiteralMissing()) { + // Check Root Type (LIST/SEXP) + if (root.type.kind != Kind.LIST && root.type.kind != Kind.SEXP) { return ProblemGenerator.missingRex( rexOpPathIndex(root, key), - ProblemGenerator.expressionAlwaysReturnsMissing() + ProblemGenerator.expressionAlwaysReturnsMissing("Path indexing must occur only on LIST/SEXP.") ) } - // Check that Root was LIST or SEXP by checking accumuated element types - if (elementTypes.isEmpty()) { + // Check that root is not literal missing + if (root.isLiteralMissing()) { return ProblemGenerator.missingRex( rexOpPathIndex(root, key), - ProblemGenerator.expressionAlwaysReturnsMissing("Only lists and s-expressions can be indexed with integers, found ${root.type}") + ProblemGenerator.expressionAlwaysReturnsMissing() ) } - return rex(unionOf(elementTypes), rexOpPathIndex(root, key)) + + return rex(root.type.typeParameter, rexOpPathIndex(root, key)) } private fun Rex.isLiteralMissing(): Boolean = this.op is Rex.Op.Lit && this.op.value is MissingValue - override fun visitRexOpPathKey(node: Rex.Op.Path.Key, ctx: StaticType?): Rex { + override fun visitRexOpPathKey(node: Rex.Op.Path.Key, ctx: CompilerType?): Rex { val root = visitRex(node.root, node.root.type) val key = visitRex(node.key, node.key.type) - // Check Key Type - if (!key.type.mayBeType()) { + // Check Key Type (STRING). TODO: Allow coercions to STRING + if (key.type.kind != Kind.STRING) { return ProblemGenerator.missingRex( rexOpPathKey(root, key), ProblemGenerator.expressionAlwaysReturnsMissing("Expected string but found: ${key.type}.") ) } - // Check Root Type - if (!root.type.mayBeType()) { - return ProblemGenerator.missingRex( - rexOpPathKey(root, key), - ProblemGenerator.expressionAlwaysReturnsMissing("Key lookup may only occur on structs, not ${root.type}.") - ) + // Check if Root is DYNAMIC + if (root.type.kind == Kind.DYNAMIC) { + return Rex(CompilerType(PType.typeDynamic()), Rex.Op.Path.Key(root, key)) } - // Check that root is not literal missing - if (root.isLiteralMissing()) { + // Check Root Type (STRUCT) + if (root.type.kind != Kind.STRUCT && root.type.kind != Kind.ROW) { return ProblemGenerator.missingRex( rexOpPathKey(root, key), - ProblemGenerator.expressionAlwaysReturnsMissing() + ProblemGenerator.expressionAlwaysReturnsMissing("Key lookup may only occur on structs, not ${root.type}.") ) } - // Get Element Type - val elementType = root.type.inferListNotNull { type -> - val struct = type as? StructType ?: return@inferListNotNull null - if (key.op is Rex.Op.Lit) { - val lit = key.op.value - if (lit is TextValue<*> && !lit.isNull) { - val id = identifierSymbol(lit.string!!, Identifier.CaseSensitivity.SENSITIVE) - inferStructLookup(struct, id)?.first - } else { - return@inferListNotNull ANY - } - } else { - // cannot infer type of non-literal path step because we don't know its value - // we might improve upon this with some constant folding prior to typing - ANY - } - } - if (elementType.isEmpty()) { - return ProblemGenerator.missingRex( - rexOpPathKey(root, key), - ProblemGenerator.expressionAlwaysReturnsMissing("Key lookup did not result in any element types.") - ) + // Get Literal Key + val keyOp = key.op + val keyLiteral = when (keyOp is Rex.Op.Lit && keyOp.value is TextValue<*> && !keyOp.value.isNull) { + true -> keyOp.value.string!! + false -> return rex(CompilerType(PType.typeDynamic()), rexOpPathKey(root, key)) } - return rex(unionOf(elementType), rexOpPathKey(root, key)) + + // Find Type + val elementType = root.type.getField(keyLiteral, false) ?: return ProblemGenerator.missingRex( + Rex.Op.Path.Key(root, key), + ProblemGenerator.expressionAlwaysReturnsMissing("Path key does not exist.") + ) + + return rex(elementType, rexOpPathKey(root, key)) } - override fun visitRexOpPathSymbol(node: Rex.Op.Path.Symbol, ctx: StaticType?): Rex { + override fun visitRexOpPathSymbol(node: Rex.Op.Path.Symbol, ctx: CompilerType?): Rex { val root = visitRex(node.root, node.root.type) - // Check Root Type - if (!root.type.mayBeType()) { + // Check if Root is DYNAMIC + if (root.type.kind == Kind.DYNAMIC) { + return Rex(CompilerType(PType.typeDynamic()), Rex.Op.Path.Symbol(root, node.key)) + } + + // Check Root Type (STRUCT) + if (root.type.kind != Kind.STRUCT && root.type.kind != Kind.ROW) { return ProblemGenerator.missingRex( - rexOpPathSymbol(root, node.key), + Rex.Op.Path.Symbol(root, node.key), ProblemGenerator.expressionAlwaysReturnsMissing("Symbol lookup may only occur on structs, not ${root.type}.") ) } @@ -666,66 +725,82 @@ internal class PlanTyper(private val env: Env) { // Check that root is not literal missing if (root.isLiteralMissing()) { return ProblemGenerator.missingRex( - rexOpPathSymbol(root, node.key), + Rex.Op.Path.Symbol(root, node.key), ProblemGenerator.expressionAlwaysReturnsMissing() ) } - // Get Element Types - val paths = root.type.inferRexListNotNull { type -> - val struct = type as? StructType ?: return@inferRexListNotNull null - val (pathType, replacementId) = inferStructLookup( - struct, - identifierSymbol(node.key, Identifier.CaseSensitivity.INSENSITIVE) - ) ?: return@inferRexListNotNull null - when (replacementId.caseSensitivity) { - Identifier.CaseSensitivity.INSENSITIVE -> rex(pathType, rexOpPathSymbol(root, replacementId.symbol)) - Identifier.CaseSensitivity.SENSITIVE -> rex( - pathType, rexOpPathKey(root, rexString(replacementId.symbol)) + // Find Type + val field = root.type.getSymbol(node.key) ?: run { + val inScopeVariables = locals.schema.map { it.name }.toSet() + return ProblemGenerator.missingRex( + Rex.Op.Path.Symbol(root, node.key), + ProblemGenerator.undefinedVariable( + org.partiql.plan.Identifier.Symbol(node.key, org.partiql.plan.Identifier.CaseSensitivity.INSENSITIVE), + inScopeVariables ) - } + ) } - // Determine output type - val type = when (paths.size) { - // Escape early since no inference could be made - 0 -> { - val key = org.partiql.plan.Identifier.Symbol(node.key, org.partiql.plan.Identifier.CaseSensitivity.SENSITIVE) - val inScopeVariables = locals.schema.map { it.name }.toSet() - return ProblemGenerator.missingRex( - rexOpPathSymbol(root, node.key), - ProblemGenerator.undefinedVariable(key, inScopeVariables) - ) - } - else -> unionOf(paths.map { it.type }.toSet()) + return when (field.first.caseSensitivity) { + Identifier.CaseSensitivity.INSENSITIVE -> Rex(field.second, Rex.Op.Path.Symbol(root, node.key)) + Identifier.CaseSensitivity.SENSITIVE -> Rex(field.second, Rex.Op.Path.Key(root, rexString(field.first.symbol))) } + } - // replace step only if all are disambiguated - val allElementsInferred = paths.size == root.type.allTypes.size - val firstPathOp = paths.first().op - val replacementOp = when (allElementsInferred && paths.map { it.op }.all { it == firstPathOp }) { - true -> firstPathOp - false -> rexOpPathSymbol(root, node.key) + /** + * Assumes that the type is either a struct of row. + * @return null when the field definitely does not exist; dynamic when the type cannot be determined + */ + private fun CompilerType.getField(field: String, ignoreCase: Boolean): CompilerType? { + if (this.kind == Kind.STRUCT) { + return CompilerType(PType.typeDynamic()) + } + val fields = this.fields!!.filter { it.name.equals(field, ignoreCase) }.map { it.type }.toSet() + return when (fields.size) { + 0 -> return null + 1 -> fields.first() + else -> CompilerType(PType.typeDynamic()) } - return rex(type, replacementOp) } - private fun rexString(str: String) = rex(STRING, rexOpLit(stringValue(str))) + private fun rexString(str: String) = rex(CompilerType(PType.typeString()), Rex.Op.Lit(stringValue(str))) - override fun visitRexOpCastUnresolved(node: Rex.Op.Cast.Unresolved, ctx: StaticType?): Rex { + /** + * Assumes that the type is either a struct or row. + * @return null when the field definitely does not exist; dynamic when the type cannot be determined + */ + private fun CompilerType.getSymbol(field: String): Pair? { + if (this.kind == Kind.STRUCT) { + return Identifier.Symbol(field, Identifier.CaseSensitivity.INSENSITIVE) to CompilerType(PType.typeDynamic()) + } + val fields = this.fields!!.mapNotNull { + when (it.name.equals(field, true)) { + true -> it.name to it.type + false -> null + } + }.ifEmpty { return null } + val type = anyOf(fields.map { it.second }) ?: PType.typeDynamic() + val ids = fields.map { it.first }.toSet() + return when (ids.size > 1) { + true -> Identifier.Symbol(field, Identifier.CaseSensitivity.INSENSITIVE) to type.toCType() + false -> Identifier.Symbol(ids.first(), Identifier.CaseSensitivity.SENSITIVE) to type.toCType() + } + } + + override fun visitRexOpCastUnresolved(node: Rex.Op.Cast.Unresolved, ctx: CompilerType?): Rex { val arg = visitRex(node.arg, null) val cast = env.resolveCast(arg, node.target) ?: return ProblemGenerator.errorRex( node.copy(node.target, arg), - ProblemGenerator.undefinedFunction("CAST( AS ${node.target})", listOf(arg.type)) + ProblemGenerator.undefinedFunction(listOf(arg.type), "CAST( AS ${node.target})") ) return visitRexOpCastResolved(cast, null) } - override fun visitRexOpCastResolved(node: Rex.Op.Cast.Resolved, ctx: StaticType?): Rex { - val type = node.cast.target.toStaticType() - return rex(type, node) + override fun visitRexOpCastResolved(node: Rex.Op.Cast.Resolved, ctx: CompilerType?): Rex { + return rex(node.cast.target, node) } - override fun visitRexOpCallUnresolved(node: Rex.Op.Call.Unresolved, ctx: StaticType?): Rex { + override fun visitRexOpCallUnresolved(node: Rex.Op.Call.Unresolved, ctx: CompilerType?): Rex { // Type the arguments val args = node.args.map { visitRex(it, null) } // Attempt to resolve in the environment @@ -734,7 +809,7 @@ internal class PlanTyper(private val env: Env) { if (rex == null) { return ProblemGenerator.errorRex( causes = args.map { it.op }, - problem = ProblemGenerator.undefinedFunction(node.identifier, args.map { it.type }), + problem = ProblemGenerator.undefinedFunction(args.map { it.type }, node.identifier), ) } // Pass off to Rex.Op.Call.Static or Rex.Op.Call.Dynamic for typing. @@ -749,7 +824,7 @@ internal class PlanTyper(private val env: Env) { * @return */ @OptIn(FnExperimental::class) - override fun visitRexOpCallStatic(node: Rex.Op.Call.Static, ctx: StaticType?): Rex { + override fun visitRexOpCallStatic(node: Rex.Op.Call.Static, ctx: CompilerType?): Rex { // Apply the coercions as explicit casts val args: List = node.args.map { // Type the coercions @@ -758,11 +833,19 @@ internal class PlanTyper(private val env: Env) { else -> it } } - val type = inferFnType(node.fn.signature, args) ?: return ProblemGenerator.missingRex( - rexOpCallStatic(node.fn, args), - ProblemGenerator.expressionAlwaysReturnsMissing("Static function always receives MISSING arguments.") - ) - return rex(type, node) + + // Check if any arg is always missing + val argIsAlwaysMissing = args.any { it.type.isMissingValue } + if (node.fn.signature.isMissingCall && argIsAlwaysMissing) { + return ProblemGenerator.missingRex( + node, + ProblemGenerator.expressionAlwaysReturnsMissing("Static function always receives MISSING arguments."), + CompilerType(node.fn.signature.returns, isMissingValue = true) + ) + } + + // Infer fn return type + return rex(CompilerType(node.fn.signature.returns), Rex.Op.Call.Static(node.fn, args)) } /** @@ -773,22 +856,13 @@ internal class PlanTyper(private val env: Env) { * @return */ @OptIn(FnExperimental::class) - override fun visitRexOpCallDynamic(node: Rex.Op.Call.Dynamic, ctx: StaticType?): Rex { - var isMissingCall = false - val types = node.candidates.mapNotNull { candidate -> - isMissingCall = isMissingCall || candidate.fn.signature.isMissingCall - inferFnType(candidate.fn.signature, node.args) - }.toMutableSet() - if (types.isEmpty()) { - return ProblemGenerator.missingRex( - rexOpCallDynamic(node.args, node.candidates), - ProblemGenerator.expressionAlwaysReturnsMissing("Function argument is always the missing value.") - ) - } - return rex(type = unionOf(types).flatten(), op = node) + override fun visitRexOpCallDynamic(node: Rex.Op.Call.Dynamic, ctx: CompilerType?): Rex { + val types = node.candidates.map { candidate -> candidate.fn.signature.returns }.toMutableSet() + // TODO: Should this always be DYNAMIC? + return Rex(type = CompilerType(anyOf(types) ?: PType.typeDynamic()), op = node) } - override fun visitRexOpCase(node: Rex.Op.Case, ctx: StaticType?): Rex { + override fun visitRexOpCase(node: Rex.Op.Case, ctx: CompilerType?): Rex { // Rewrite CASE-WHEN branches val oldBranches = node.branches.toTypedArray() val newBranches = mutableListOf() @@ -800,11 +874,10 @@ internal class PlanTyper(private val env: Env) { branch = visitRexOpCaseBranch(branch, branch.rex.type) // Emit typing error if a branch condition is never a boolean (prune) - if (!branch.condition.type.mayBeType()) { - return ProblemGenerator.missingRex( - node, - ProblemGenerator.incompatibleTypesForOp(branch.condition.type.allTypes, "CASE_WHEN"), - ) + if (!canBeBoolean(branch.condition.type)) { + // prune, always false + // TODO: Error probably + continue } // Accumulate typing information, but skip if literal NULL or MISSING @@ -826,18 +899,15 @@ internal class PlanTyper(private val env: Env) { assert(msize == bsize) { "Coercion mappings `len $msize` did not match the number of CASE-WHEN branches `len $bsize`" } // Rewrite branches for (i in newBranches.indices) { - val (operand, target) = mapping[i] - if (operand == target) continue // skip - val branch = newBranches[i] - val cast = env.resolveCast(branch.rex, target)!! - val rex = rex(type, cast) - newBranches[i] = branch.copy(rex = rex) + when (val function = mapping[i]) { + null -> continue + else -> newBranches[i] = newBranches[i].copy(rex = replaceCaseBranch(newBranches[i].rex, type, function)) + } } // Rewrite default - val (operand, target) = mapping.last() - if (operand != target) { - val cast = env.resolveCast(newDefault, target)!! - newDefault = rex(type, cast) + val function = mapping.last() + if (function != null) { + newDefault = replaceCaseBranch(newDefault, type, function) } } @@ -851,6 +921,18 @@ internal class PlanTyper(private val env: Env) { return rex(type, op) } + private fun replaceCaseBranch(originalRex: Rex, outputType: CompilerType, function: DynamicTyper.Mapping): Rex { + return when (function) { + is DynamicTyper.Mapping.Coercion -> { + val cast = env.resolveCast(originalRex, function.target)!! + Rex(outputType, cast) + } + is DynamicTyper.Mapping.Replacement -> { + function.replacement + } + } + } + // COALESCE(v1, v2,..., vN) // == // CASE @@ -860,7 +942,7 @@ internal class PlanTyper(private val env: Env) { // ELSE vN // END // --> minimal common supertype of(, , ..., ) - override fun visitRexOpCoalesce(node: Rex.Op.Coalesce, ctx: StaticType?): Rex { + override fun visitRexOpCoalesce(node: Rex.Op.Coalesce, ctx: CompilerType?): Rex { val args = node.args.map { visitRex(it, it.type) }.toMutableList() val typer = DynamicTyper() args.forEach { v -> typer.accumulate(v) } @@ -868,12 +950,9 @@ internal class PlanTyper(private val env: Env) { if (mapping != null) { assert(mapping.size == args.size) { "Coercion mappings `len ${mapping.size}` did not match the number of COALESCE arguments `len ${args.size}`" } for (i in args.indices) { - val (operand, target) = mapping[i] - if (operand == target) continue // skip; no coercion needed - val cast = env.resolveCast(args[i], target) - if (cast != null) { - val rex = rex(type, cast) - args[i] = rex + when (val function = mapping[i]) { + null -> continue + else -> args[i] = replaceCaseBranch(args[i], type, function) } } } @@ -888,7 +967,7 @@ internal class PlanTyper(private val env: Env) { // ELSE v1 // END // --> minimal common supertype of (NULL, ) - override fun visitRexOpNullif(node: Rex.Op.Nullif, ctx: StaticType?): Rex { + override fun visitRexOpNullif(node: Rex.Op.Nullif, ctx: CompilerType?): Rex { val value = visitRex(node.value, node.value.type) val nullifier = visitRex(node.nullifier, node.nullifier.type) val typer = DynamicTyper() @@ -900,6 +979,14 @@ internal class PlanTyper(private val env: Env) { return rex(type, op) } + /** + * In this context, Boolean means PartiQLValueType Bool, which can be nullable. + * Hence, we permit Static Type BOOL, Static Type NULL, Static Type Missing here. + */ + private fun canBeBoolean(type: CompilerType): Boolean { + return type.kind == Kind.DYNAMIC || type.kind == Kind.BOOL + } + /** * Returns the boolean value of the expression. For now, only handle literals. */ @@ -920,149 +1007,70 @@ internal class PlanTyper(private val env: Env) { * then when we see the top-level `a IS STRUCT`, then we can assume that the `a` on the RHS is definitely a * struct. We handle this by using [foldCaseBranch]. */ - override fun visitRexOpCaseBranch(node: Rex.Op.Case.Branch, ctx: StaticType?): Rex.Op.Case.Branch { + override fun visitRexOpCaseBranch(node: Rex.Op.Case.Branch, ctx: CompilerType?): Rex.Op.Case.Branch { val visitedCondition = visitRex(node.condition, node.condition.type) val visitedReturn = visitRex(node.rex, node.rex.type) - return foldCaseBranch(visitedCondition, visitedReturn) - } - - /** - * This takes in a branch condition and its result expression. - * - * 1. If the condition is a type check T (ie ` IS T`), then this function will be typed as T. - * 2. If a branch condition is known to be false, it will be removed. - * - * TODO: Currently, this only folds type checking for STRUCTs. We need to add support for all other types. - * - * TODO: I added a check for [Rex.Op.Var.Outer] as it seemed odd to replace a general expression like: - * `WHEN { 'a': { 'b': 1} }.a IS STRUCT THEN { 'a': { 'b': 1} }.a.b`. We can discuss this later, but I'm - * currently limiting the scope of this intentionally. - */ - @OptIn(FnExperimental::class) - private fun foldCaseBranch(condition: Rex, result: Rex): Rex.Op.Case.Branch { - return when (val call = condition.op) { - is Rex.Op.Call.Dynamic -> { - val rex = call.candidates.map { candidate -> - val fn = candidate.fn - if (fn.signature.name.equals("is_struct", ignoreCase = true).not()) { - return rexOpCaseBranch(condition, result) - } - val ref = call.args.getOrNull(0) ?: error("IS STRUCT requires an argument.") - // Replace the result's type - val type = unionOf(ref.type.allTypes.filterIsInstance().toSet()) - val replacementVal = ref.copy(type = type) - when (ref.op is Rex.Op.Var.Local) { - true -> RexReplacer.replace(result, ref, replacementVal) - false -> result - } - } - val type = rex.toUnionType().flatten() - return rexOpCaseBranch(condition, result.copy(type)) - } - is Rex.Op.Call.Static -> { - val fn = call.fn - if (fn.signature.name.equals("is_struct", ignoreCase = true).not()) { - return rexOpCaseBranch(condition, result) - } - val ref = call.args.getOrNull(0) ?: error("IS STRUCT requires an argument.") - val simplifiedCondition = when { - ref.type.allTypes.all { it is StructType } -> rex(BOOL, rexOpLit(boolValue(true))) - ref.type.allTypes.none { it is StructType } -> rex(BOOL, rexOpLit(boolValue(false))) - else -> condition - } - - // Replace the result's type - val type = unionOf(ref.type.allTypes.filterIsInstance().toSet()).flatten() - val replacementVal = ref.copy(type = type) - val rex = when (ref.op is Rex.Op.Var.Local) { - true -> RexReplacer.replace(result, ref, replacementVal) - false -> result - } - return rexOpCaseBranch(simplifiedCondition, rex) - } - else -> rexOpCaseBranch(condition, result) - } + return Rex.Op.Case.Branch(visitedCondition, visitedReturn) } - override fun visitRexOpCollection(node: Rex.Op.Collection, ctx: StaticType?): Rex { - // Check Type - if (ctx!! !is CollectionType) { + override fun visitRexOpCollection(node: Rex.Op.Collection, ctx: CompilerType?): Rex { + if (ctx!!.kind !in setOf(Kind.LIST, Kind.SEXP, Kind.BAG)) { return ProblemGenerator.missingRex( node, - ProblemGenerator.unexpectedType(ctx, setOf(StaticType.LIST, StaticType.BAG, StaticType.SEXP)) + ProblemGenerator.unexpectedType(ctx, setOf(PType.typeList(), PType.typeBag(), PType.typeSexp())) ) } val values = node.values.map { visitRex(it, it.type) } val t = when (values.size) { - 0 -> ANY - else -> values.toUnionType() + 0 -> PType.typeDynamic() + else -> anyOfLiterals(values.map { it.type })!! } - val type = when (ctx as CollectionType) { - is BagType -> BagType(t) - is ListType -> ListType(t) - is SexpType -> SexpType(t) + val type = when (ctx.kind) { + Kind.BAG -> PType.typeBag(t) + Kind.LIST -> PType.typeList(t) + Kind.SEXP -> PType.typeSexp(t) + else -> error("This is impossible.") } - return rex(type, rexOpCollection(values)) + return rex(CompilerType(type), rexOpCollection(values)) } @OptIn(PartiQLValueExperimental::class) - override fun visitRexOpStruct(node: Rex.Op.Struct, ctx: StaticType?): Rex { - val fields = node.fields.mapNotNull { + override fun visitRexOpStruct(node: Rex.Op.Struct, ctx: CompilerType?): Rex { + val fields = node.fields.map { val k = visitRex(it.k, it.k.type) val v = visitRex(it.v, it.v.type) rexOpStructField(k, v) } var structIsClosed = true - val structTypeFields = mutableListOf() - val structKeysSeent = mutableSetOf() + val structTypeFields = mutableListOf() for (field in fields) { - when (field.k.op) { - is Rex.Op.Lit -> { - // A field is only included in the StructType if its key is a text literal - val key = field.k.op - if (key.value is TextValue<*>) { - val name = key.value.string!! - val type = field.v.type - structKeysSeent.add(name) - structTypeFields.add(StructType.Field(name, type)) - } - } - else -> { - if (field.k.type.allTypes.any { it.isText() }) { - // If the non-literal could be text, StructType will have open content. - structIsClosed = false - } else { - // A field with a non-literal key name is not included in the StructType. - } - } + val keyOp = field.k.op + // TODO: Check key type + if (keyOp !is Rex.Op.Lit || keyOp.value !is TextValue<*>) { + structIsClosed = false + continue } + structTypeFields.add(CompilerType.Field(keyOp.value.string!!, field.v.type)) + } + val type = when (structIsClosed) { + true -> CompilerType(PType.typeRow(structTypeFields as Collection)) + false -> CompilerType(PType.typeStruct()) } - val type = StructType( - fields = structTypeFields, - contentClosed = structIsClosed, - constraints = setOf( - TupleConstraint.Open(!structIsClosed), - TupleConstraint.UniqueAttrs(structKeysSeent.size == fields.size) - ), - ) return rex(type, rexOpStruct(fields)) } - override fun visitRexOpPivot(node: Rex.Op.Pivot, ctx: StaticType?): Rex { + override fun visitRexOpPivot(node: Rex.Op.Pivot, ctx: CompilerType?): Rex { val stack = locals.outer + listOf(locals) val rel = node.rel.type(stack) val typeEnv = TypeEnv(rel.type.schema, stack) val typer = RexTyper(typeEnv, Scope.LOCAL) val key = typer.visitRex(node.key, null) val value = typer.visitRex(node.value, null) - val type = StructType( - contentClosed = false, constraints = setOf(TupleConstraint.Open(true)) - ) val op = rexOpPivot(key, value, rel) - return rex(type, op) + return rex(CompilerType(PType.typeStruct()), op) } - override fun visitRexOpSubquery(node: Rex.Op.Subquery, ctx: StaticType?): Rex { + override fun visitRexOpSubquery(node: Rex.Op.Subquery, ctx: CompilerType?): Rex { val rel = node.rel.type(locals.outer + listOf(locals)) val newTypeEnv = TypeEnv(schema = rel.type.schema, outer = locals.outer + listOf(locals)) val constructor = node.constructor.type(newTypeEnv) @@ -1076,8 +1084,8 @@ internal class PlanTyper(private val env: Env) { /** * Calculate output type of a row-value subquery. */ - private fun visitRexOpSubqueryRow(subquery: Rex.Op.Subquery, cons: StaticType): Rex { - if (cons !is StructType) { + private fun visitRexOpSubqueryRow(subquery: Rex.Op.Subquery, cons: CompilerType): Rex { + if (cons.kind != Kind.ROW) { error("Subquery with non-SQL SELECT cannot be coerced to a row-value expression. Found constructor type: $cons") } // Do a simple cardinality check for the moment. @@ -1088,84 +1096,120 @@ internal class PlanTyper(private val env: Env) { // return rexErr("Cannot coercion subquery with $m attributes to a row-value-expression with $n attributes") // } // If we made it this far, then we can coerce this subquery to the desired complex value - val type = StaticType.LIST - val op = subquery - return rex(type, op) + val type = CompilerType(PType.typeList()) + return rex(type, subquery) } /** * Calculate output type of a scalar subquery. */ - private fun visitRexOpSubqueryScalar(subquery: Rex.Op.Subquery, cons: StaticType): Rex { - if (cons !is StructType) { + private fun visitRexOpSubqueryScalar(subquery: Rex.Op.Subquery, cons: CompilerType): Rex { + if (cons.kind != Kind.ROW) { error("Subquery with non-SQL SELECT cannot be coerced to a scalar. Found constructor type: $cons") } - val n = cons.fields.size + val n = cons.fields!!.size if (n != 1) { error("SELECT constructor with $n attributes cannot be coerced to a scalar. Found constructor type: $cons") } // If we made it this far, then we can coerce this subquery to a scalar - val type = cons.fields.first().value - val op = subquery - return rex(type, op) + val type = cons.fields!!.first().type + return Rex(type, subquery) } - override fun visitRexOpSelect(node: Rex.Op.Select, ctx: StaticType?): Rex { + // TODO: Should we support the ROW type? + override fun visitRexOpSelect(node: Rex.Op.Select, ctx: CompilerType?): Rex { val rel = node.rel.type(locals.outer + listOf(locals)) val newTypeEnv = TypeEnv(schema = rel.type.schema, outer = locals.outer + listOf(locals)) - var constructor = node.constructor.type(newTypeEnv) - var constructorType = constructor.type - // add the ordered property to the constructor - if (constructorType is StructType) { - // TODO: We shouldn't need to copy the ordered constraint. - constructorType = constructorType.copy( - constraints = constructorType.constraints + setOf(TupleConstraint.Ordered) - ) - constructor = rex(constructorType, constructor.op) - } + val constructor = node.constructor.type(newTypeEnv) val type = when (rel.isOrdered()) { - true -> ListType(constructor.type) - else -> BagType(constructor.type) + true -> PType.typeList(constructor.type) + false -> PType.typeBag(constructor.type) } - return rex(type, rexOpSelect(constructor, rel)) + return Rex(CompilerType(type), Rex.Op.Select(constructor, rel)) } - override fun visitRexOpTupleUnion(node: Rex.Op.TupleUnion, ctx: StaticType?): Rex { + override fun visitRexOpTupleUnion(node: Rex.Op.TupleUnion, ctx: CompilerType?): Rex { val args = node.args.map { visitRex(it, ctx) } + val result = Rex.Op.TupleUnion(args) + + // Replace Generated Tuple Union if Schema Present + // This should occur before typing, however, we don't type on the AST or have an appropriate IR + replaceGeneratedTupleUnion(result)?.let { return it } + + // Calculate Type val type = when (args.size) { - 0 -> { - // empty struct - StructType( - fields = emptyMap(), - contentClosed = true, - constraints = setOf( - TupleConstraint.Open(false), - TupleConstraint.UniqueAttrs(true), - TupleConstraint.Ordered, - ) - ) - } + 0 -> CompilerType(PType.typeRow(emptyList())) else -> { val argTypes = args.map { it.type } - val anyArgIsNotStruct = argTypes.any { argType -> !argType.mayBeType() } - if (anyArgIsNotStruct) { - return ProblemGenerator.missingRex( - rexOpTupleUnion(args), - ProblemGenerator.expressionAlwaysReturnsMissing("TUPLEUNION always receives a non-struct argumnent.") - ) - } - val potentialTypes = buildArgumentPermutations(argTypes).mapNotNull { argumentList -> - calculateTupleUnionOutputType(argumentList) + calculateTupleUnionOutputType(argTypes) ?: return ProblemGenerator.missingRex( + args.map { it.op }, + ProblemGenerator.undefinedFunction(args.map { it.type }, "TUPLEUNION"), + PType.typeStruct().toCType() + ) + } + } + return Rex(type, result) + } + + /** + * This is a hack to replace the generated "TUPLEUNION" that is the result of a SELECT *. In my + * opinion, this should actually occur prior to PlanTyper. That being said, we currently don't type on the AST, + * and we don't have an appropriate IR to type on. + * + * @return null if the [node] is NOT a generated tuple union; return the replacement if the [node] is a tuple union + * and there is sufficient schema to replace the tuple union + */ + private fun replaceGeneratedTupleUnion(node: Rex.Op.TupleUnion): Rex? { + val args = node.args.map { replaceGeneratedTupleUnionArg(it) } + if (args.any { it == null }) { + return null + } + // Infer Type + val type = PType.typeRow(args.flatMap { it!!.type.fields }) + val fields = args.flatMap { arg -> + val op = arg!!.op + when (op is Rex.Op.Struct) { + true -> op.fields + false -> { + arg.type.fields.map { + Rex.Op.Struct.Field( + rexString(it.name), + Rex(it.type, Rex.Op.Path.Key(arg, rexString(it.name))) + ) + } } - unionOf(potentialTypes.toSet()).flatten() } } - val op = rexOpTupleUnion(args) - return rex(type, op) + // Create struct + return Rex(type.toCType(), Rex.Op.Struct(fields)) + } + + @OptIn(FnExperimental::class) + private fun replaceGeneratedTupleUnionArg(node: Rex): Rex? { + if (node.op is Rex.Op.Struct && node.type.kind == Kind.ROW) { + return node + } + val case = node.op as? Rex.Op.Case ?: return null + if (case.branches.size != 1) { + return null + } + val firstBranch = case.branches.first() + val firstBranchCondition = case.branches.first().condition.op + if (firstBranchCondition !is Rex.Op.Call.Static) { + return null + } + if (!firstBranchCondition.fn.signature.name.equals("is_struct", ignoreCase = true)) { + return null + } + val firstBranchResultType = firstBranch.rex.type + if (firstBranchResultType.kind != Kind.ROW) { + return null + } + return Rex(firstBranchResultType, firstBranch.rex.op) } - override fun visitRexOpErr(node: Rex.Op.Err, ctx: StaticType?): PlanNode { - val type = ctx ?: ANY + override fun visitRexOpErr(node: Rex.Op.Err, ctx: CompilerType?): PlanNode { + val type = ctx ?: CompilerType(PType.typeDynamic()) return rex(type, node) } @@ -1190,168 +1234,32 @@ internal class PlanTyper(private val env: Env) { * If all arguments contain unique attributes AND all arguments are closed AND no fields clash, the output has * unique attributes. */ - private fun calculateTupleUnionOutputType(args: List): StaticType? { - val structFields = mutableListOf() - var structAmount = 0 - var structIsClosed = true - var structIsOrdered = true - var uniqueAttrs = true + private fun calculateTupleUnionOutputType(args: List): CompilerType? { + val fields = mutableListOf() + var structIsOpen = false + var containsDynamic = false + var containsNonStruct = false args.forEach { arg -> - when (arg) { - is StructType -> { - structAmount += 1 - structFields.addAll(arg.fields) - structIsClosed = structIsClosed && arg.constraints.contains(TupleConstraint.Open(false)) - structIsOrdered = structIsOrdered && arg.constraints.contains(TupleConstraint.Ordered) - uniqueAttrs = uniqueAttrs && arg.constraints.contains(TupleConstraint.UniqueAttrs(true)) - } - is AnyOfType -> { - error("TupleUnion wasn't normalized to exclude union types.") - } - else -> { - return null - } + if (arg.kind == Kind.UNKNOWN) { + return@forEach } - } - uniqueAttrs = when { - structIsClosed.not() && structAmount > 1 -> false - else -> uniqueAttrs - } - uniqueAttrs = uniqueAttrs && (structFields.size == structFields.distinctBy { it.key }.size) - val orderedConstraint = when (structIsOrdered) { - true -> TupleConstraint.Ordered - false -> null - } - val constraints = setOfNotNull( - TupleConstraint.Open(!structIsClosed), TupleConstraint.UniqueAttrs(uniqueAttrs), orderedConstraint - ) - return StructType( - fields = structFields.map { it }, contentClosed = structIsClosed, constraints = constraints - ) - } - - /** - * We are essentially making permutations of arguments that maintain the same initial ordering. For example, - * consider the following args: - * ``` - * [ 0 = UNION(INT, STRING), 1 = (DECIMAL, TIMESTAMP) ] - * ``` - * This function will return: - * ``` - * [ - * [ 0 = INT, 1 = DECIMAL ], - * [ 0 = INT, 1 = TIMESTAMP ], - * [ 0 = STRING, 1 = DECIMAL ], - * [ 0 = STRING, 1 = TIMESTAMP ] - * ] - * ``` - * - * Essentially, this becomes useful specifically in the case of TUPLEUNION, since we can make sure that - * the ordering of argument's attributes remains the same. For example: - * ``` - * TUPLEUNION( UNION(STRUCT(a, b), STRUCT(c)), UNION(STRUCT(d, e), STRUCT(f)) ) - * ``` - * - * Then, the output of the tupleunion will have the output types of all of the below: - * ``` - * TUPLEUNION(STRUCT(a,b), STRUCT(d,e)) --> STRUCT(a, b, d, e) - * TUPLEUNION(STRUCT(a,b), STRUCT(f)) --> STRUCT(a, b, f) - * TUPLEUNION(STRUCT(c), STRUCT(d,e)) --> STRUCT(c, d, e) - * TUPLEUNION(STRUCT(c), STRUCT(f)) --> STRUCT(c, f) - * ``` - */ - private fun buildArgumentPermutations(args: List): Sequence> { - val flattenedArgs = args.map { it.flatten().allTypes } - return buildArgumentPermutations(flattenedArgs, accumulator = emptyList()) - } - - private fun buildArgumentPermutations( - args: List>, - accumulator: List, - ): Sequence> { - if (args.isEmpty()) { - return sequenceOf(accumulator) - } - val first = args.first() - val rest = when (args.size) { - 1 -> emptyList() - else -> args.subList(1, args.size) - } - return sequence { - first.forEach { argSubType -> - yieldAll(buildArgumentPermutations(rest, accumulator + listOf(argSubType))) + when (arg.kind) { + Kind.ROW -> fields.addAll(arg.fields!!) + Kind.STRUCT -> structIsOpen = true + Kind.DYNAMIC -> containsDynamic = true + Kind.UNKNOWN -> structIsOpen = true + else -> containsNonStruct = true } } - } - - // Helpers - - /** - * Logic is as follows: - * 1. If [struct] is closed and ordered: - * - If no item is found, return null - * - Else, grab first matching item and make sensitive. - * 2. If [struct] is closed - * - AND no item is found, return null - * - AND only one item is present -> grab item and make sensitive. - * - AND more than one item is present, keep sensitivity and grab item. - * 3. If [struct] is open, return [AnyType] - * - * @return a [Pair] where the [Pair.first] represents the type of the [step] and the [Pair.second] represents - * the disambiguated [key]. - */ - private fun inferStructLookup(struct: StructType, key: Identifier.Symbol): Pair? { - val binding = key.toBindingName() - val isClosed = struct.constraints.contains(TupleConstraint.Open(false)) - val isOrdered = struct.constraints.contains(TupleConstraint.Ordered) - val (name, type) = when { - // 1. Struct is closed and ordered - isClosed && isOrdered -> { - struct.fields.firstOrNull { entry -> binding.matches(entry.key) }?.let { - (sensitive(it.key) to it.value) - } ?: return null - } - // 2. Struct is closed - isClosed -> { - val matches = struct.fields.filter { entry -> binding.matches(entry.key) } - when (matches.size) { - 0 -> { - return null - } - 1 -> matches.first().let { (sensitive(it.key) to it.value) } - else -> { - val firstKey = matches.first().key - val sharedKey = when (matches.all { it.key == firstKey }) { - true -> sensitive(firstKey) - false -> key - } - sharedKey to unionOf(matches.map { it.value }.toSet()).flatten() - } - } - } - // 3. Struct is open - else -> key to ANY + return when { + containsNonStruct -> null + containsDynamic -> CompilerType(PType.typeDynamic()) + structIsOpen -> CompilerType(PType.typeStruct()) + else -> CompilerType(PType.typeRow(fields as Collection)) } - return type to name } - private fun sensitive(str: String): Identifier.Symbol = - identifierSymbol(str, Identifier.CaseSensitivity.SENSITIVE) - - /** - * Returns NULL when the function is a missing call and always has an argument that is the missing value - */ - @OptIn(FnExperimental::class) - private fun inferFnType(fn: FnSignature, args: List): StaticType? { - val argAlwaysMissing = args.any { - val op = it.op as? Rex.Op.Lit ?: return@any false - op.value is MissingValue - } - if (fn.isMissingCall && argAlwaysMissing) { - return null - } - return fn.returns.toStaticType() - } + // Helpers /** * Resolution and typing of aggregation function calls. @@ -1368,15 +1276,14 @@ internal class PlanTyper(private val env: Env) { * to each row of T and eliminating null values <--- all NULL values are eliminated as inputs */ @OptIn(FnExperimental::class) - fun resolveAgg(node: Rel.Op.Aggregate.Call.Unresolved): Pair { + fun resolveAgg(node: Rel.Op.Aggregate.Call.Unresolved): Pair { // Type the arguments val args = node.args.map { visitRex(it, null) } - val argsResolved = relOpAggregateCallUnresolved(node.name, node.setQuantifier, args) + val argsResolved = Rel.Op.Aggregate.Call.Unresolved(node.name, node.setQuantifier, args) // Resolve the function - val call = env.resolveAgg(node.name, node.setQuantifier, args) ?: return argsResolved to ANY - val returns = call.agg.signature.returns - return call to returns.toStaticType() + val call = env.resolveAgg(node.name, node.setQuantifier, args) ?: return argsResolved to CompilerType(PType.typeDynamic()) + return call to CompilerType(call.agg.signature.returns) } } @@ -1410,7 +1317,7 @@ internal class PlanTyper(private val env: Env) { * We may be able to eliminate this issue by keeping everything internal and running the typing pass first. * This is simple enough for now. */ - private fun Rel.Type.copyWithSchema(types: List): Rel.Type { + private fun Rel.Type.copyWithSchema(types: List): Rel.Type { assert(types.size == schema.size) { "Illegal copy, types size does not matching bindings list size" } return this.copy(schema = schema.mapIndexed { i, binding -> binding.copy(type = types[i]) }) } @@ -1436,37 +1343,23 @@ internal class PlanTyper(private val env: Env) { /** * Produce a union type from all the */ - private fun List.toUnionType(): StaticType = unionOf(map { it.type }.toSet()).flatten() - - private fun getElementTypeForFromSource(fromSourceType: StaticType): StaticType = when (fromSourceType) { - is BagType -> fromSourceType.elementType - is ListType -> fromSourceType.elementType - is AnyType -> ANY - is AnyOfType -> unionOf(fromSourceType.types.map { getElementTypeForFromSource(it) }.toSet()) - // All the other types coerce into a bag of themselves (including null/missing/sexp). - else -> fromSourceType - } - - // HELPERS + private fun List.toUnionType(): PType = anyOf(map { it.type }.toSet()) ?: PType.typeDynamic() - private fun Identifier.debug(): String = when (this) { - is Identifier.Qualified -> (listOf(root.debug()) + steps.map { it.debug() }).joinToString(".") - is Identifier.Symbol -> when (caseSensitivity) { - Identifier.CaseSensitivity.SENSITIVE -> "\"$symbol\"" - Identifier.CaseSensitivity.INSENSITIVE -> symbol - } + private fun getElementTypeForFromSource(fromSourceType: CompilerType): CompilerType = when (fromSourceType.kind) { + Kind.DYNAMIC -> CompilerType(PType.typeDynamic()) + Kind.BAG, Kind.LIST, Kind.SEXP -> fromSourceType.typeParameter + // TODO: Should we emit a warning? + else -> fromSourceType } private fun excludeBindings(input: List, item: Rel.Op.Exclude.Path): List { - var matchedRoot = false val output = input.map { when (val root = item.root) { is Rex.Op.Var.Unresolved -> { when (val id = root.identifier) { is Identifier.Symbol -> { if (id.isEquivalentTo(it.name)) { - matchedRoot = true - // recompute the StaticType of this binding after applying the exclusions + // recompute the PType of this binding after applying the exclusions val type = it.type.exclude(item.steps, lastStepOptional = false) it.copy(type = type) } else { diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/TypeEnv.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/TypeEnv.kt index c39a770306..17a7acb3c3 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/TypeEnv.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/TypeEnv.kt @@ -10,11 +10,9 @@ import org.partiql.planner.internal.ir.rexOpVarLocal import org.partiql.spi.BindingCase import org.partiql.spi.BindingName import org.partiql.spi.BindingPath -import org.partiql.types.AnyOfType -import org.partiql.types.AnyType +import org.partiql.types.PType +import org.partiql.types.PType.Kind import org.partiql.types.StaticType -import org.partiql.types.StructType -import org.partiql.types.TupleConstraint import org.partiql.value.PartiQLValueExperimental import org.partiql.value.stringValue @@ -132,27 +130,6 @@ internal data class TypeEnv( return c } - /** - * Searches for the [BindingName] within the given [StructType]. - * - * Returns - * - true iff known to contain key - * - false iff known to NOT contain key - * - null iff NOT known to contain key - * - * @param name - * @return - */ - private fun StructType.containsKey(name: BindingName): Boolean? { - for (f in fields) { - if (name.matches(f.key)) { - return true - } - } - val closed = constraints.contains(TupleConstraint.Open(false)) - return if (closed) false else null - } - /** * Searches for the [BindingName] within the given [StaticType]. * @@ -164,20 +141,11 @@ internal data class TypeEnv( * @param name * @return */ - private fun StaticType.containsKey(name: BindingName): Boolean? { - return when (val type = this.flatten()) { - is StructType -> type.containsKey(name) - is AnyOfType -> { - val anyKnownToContainKey = type.allTypes.any { it.containsKey(name) == true } - val anyKnownToNotContainKey = type.allTypes.any { it.containsKey(name) == false } - val anyNotKnownToContainKey = type.allTypes.any { it.containsKey(name) == null } - when { - anyKnownToNotContainKey.not() && anyNotKnownToContainKey.not() -> true - anyKnownToContainKey.not() && anyNotKnownToContainKey -> false - else -> null - } - } - is AnyType -> null + private fun CompilerType.containsKey(name: BindingName): Boolean? { + return when (this.kind) { + Kind.ROW -> this.fields!!.any { name.matches(it.name) } + Kind.STRUCT -> null + Kind.DYNAMIC -> null else -> false } } @@ -197,10 +165,10 @@ internal data class TypeEnv( @OptIn(PartiQLValueExperimental::class) internal fun Rex.toPath(steps: List): Rex = steps.fold(this) { curr, step -> val op = when (step.case) { - BindingCase.SENSITIVE -> rexOpPathKey(curr, rex(StaticType.STRING, rexOpLit(stringValue(step.name)))) + BindingCase.SENSITIVE -> rexOpPathKey(curr, rex(CompilerType(PType.typeString()), rexOpLit(stringValue(step.name)))) BindingCase.INSENSITIVE -> rexOpPathSymbol(curr, step.name) } - rex(StaticType.ANY, op) + rex(CompilerType(PType.typeDynamic()), op) } } } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/TypeUtils.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/TypeUtils.kt index 3a6ecd8a8f..f6019b3bda 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/TypeUtils.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/TypeUtils.kt @@ -2,13 +2,13 @@ package org.partiql.planner.internal.typer import org.partiql.planner.internal.ir.Rel import org.partiql.planner.internal.ir.Rex +import org.partiql.planner.internal.typer.PlanTyper.Companion.toCType import org.partiql.types.AnyOfType import org.partiql.types.AnyType import org.partiql.types.BagType import org.partiql.types.BlobType import org.partiql.types.BoolType import org.partiql.types.ClobType -import org.partiql.types.CollectionType import org.partiql.types.DateType import org.partiql.types.DecimalType import org.partiql.types.FloatType @@ -17,6 +17,8 @@ import org.partiql.types.IntType import org.partiql.types.ListType import org.partiql.types.MissingType import org.partiql.types.NullType +import org.partiql.types.PType +import org.partiql.types.PType.Kind import org.partiql.types.SexpType import org.partiql.types.StaticType import org.partiql.types.StringType @@ -165,17 +167,16 @@ private fun StaticType.asRuntimeType(): PartiQLValueType = when (this) { * @param lastStepOptional * @return */ -internal fun StaticType.exclude(steps: List, lastStepOptional: Boolean = false): StaticType { +internal fun CompilerType.exclude(steps: List, lastStepOptional: Boolean = false): CompilerType { val type = this return steps.fold(type) { acc, step -> - when (acc) { - is StructType -> acc.exclude(step, lastStepOptional) - is CollectionType -> acc.exclude(step, lastStepOptional) - is AnyOfType -> StaticType.unionOf( - acc.types.map { it.exclude(steps, lastStepOptional) }.toSet() - ) + when (acc.kind) { + Kind.DYNAMIC -> CompilerType(PType.typeDynamic()) + Kind.ROW -> acc.excludeStruct(step, lastStepOptional) + Kind.STRUCT -> acc + Kind.LIST, Kind.BAG, Kind.SEXP -> acc.excludeCollection(step, lastStepOptional) else -> acc - }.flatten() + } } } @@ -186,24 +187,24 @@ internal fun StaticType.exclude(steps: List, lastStepOption * @param lastStepOptional * @return */ -internal fun StructType.exclude(step: Rel.Op.Exclude.Step, lastStepOptional: Boolean = false): StaticType { +internal fun CompilerType.excludeStruct(step: Rel.Op.Exclude.Step, lastStepOptional: Boolean = false): CompilerType { val type = step.type val substeps = step.substeps val output = fields.mapNotNull { field -> val newField = if (substeps.isEmpty()) { if (lastStepOptional) { - StructType.Field(field.key, field.value) + CompilerType.Field(field.name, field.type) } else { null } } else { - val k = field.key - val v = field.value.exclude(substeps, lastStepOptional) - StructType.Field(k, v) + val k = field.name + val v = field.type.exclude(substeps, lastStepOptional) + CompilerType.Field(k, v) } when (type) { is Rel.Op.Exclude.Type.StructSymbol -> { - if (type.symbol.equals(field.key, ignoreCase = true)) { + if (type.symbol.equals(field.name, ignoreCase = true)) { newField } else { field @@ -211,7 +212,7 @@ internal fun StructType.exclude(step: Rel.Op.Exclude.Step, lastStepOptional: Boo } is Rel.Op.Exclude.Type.StructKey -> { - if (type.key == field.key) { + if (type.key == field.name) { newField } else { field @@ -221,7 +222,7 @@ internal fun StructType.exclude(step: Rel.Op.Exclude.Step, lastStepOptional: Boo else -> field } } - return this.copy(fields = output) + return CompilerType(PType.typeRow(output)) } /** @@ -231,8 +232,8 @@ internal fun StructType.exclude(step: Rel.Op.Exclude.Step, lastStepOptional: Boo * @param lastStepOptional * @return */ -internal fun CollectionType.exclude(step: Rel.Op.Exclude.Step, lastStepOptional: Boolean = false): StaticType { - var e = this.elementType +internal fun CompilerType.excludeCollection(step: Rel.Op.Exclude.Step, lastStepOptional: Boolean = false): CompilerType { + var e = this.typeParameter val substeps = step.substeps when (step.type) { is Rel.Op.Exclude.Type.CollIndex -> { @@ -240,6 +241,7 @@ internal fun CollectionType.exclude(step: Rel.Op.Exclude.Step, lastStepOptional: e = e.exclude(substeps, lastStepOptional = true) } } + is Rel.Op.Exclude.Type.CollWildcard -> { if (substeps.isNotEmpty()) { e = e.exclude(substeps, lastStepOptional) @@ -247,14 +249,16 @@ internal fun CollectionType.exclude(step: Rel.Op.Exclude.Step, lastStepOptional: // currently no change to elementType if collection wildcard is last element; this behavior could // change based on RFC definition } + else -> { // currently no change to elementType and no error thrown; could consider an error/warning in // the future } } - return when (this) { - is BagType -> this.copy(e) - is ListType -> this.copy(e) - is SexpType -> this.copy(e) + return when (this.kind) { + Kind.LIST -> PType.typeList(e).toCType() + Kind.BAG -> PType.typeBag(e).toCType() + Kind.SEXP -> PType.typeSexp(e).toCType() + else -> throw IllegalStateException() } } diff --git a/partiql-planner/src/main/resources/partiql_plan_internal.ion b/partiql-planner/src/main/resources/partiql_plan_internal.ion index 3c0027ca46..0c1cfdde1c 100644 --- a/partiql-planner/src/main/resources/partiql_plan_internal.ion +++ b/partiql-planner/src/main/resources/partiql_plan_internal.ion @@ -1,8 +1,8 @@ imports::{ kotlin: [ partiql_value::'org.partiql.value.PartiQLValue', - partiql_value_type::'org.partiql.value.PartiQLValueType', - static_type::'org.partiql.types.StaticType', + partiql_value_type::'org.partiql.planner.internal.typer.CompilerType', + static_type::'org.partiql.planner.internal.typer.CompilerType', fn_signature::'org.partiql.spi.fn.FnSignature', agg_signature::'org.partiql.spi.fn.AggSignature', problem::'org.partiql.errors.Problem' diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/PlannerErrorReportingTests.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/PlannerErrorReportingTests.kt index 8994aa5a8c..4d8f33a87f 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/PlannerErrorReportingTests.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/PlannerErrorReportingTests.kt @@ -7,14 +7,18 @@ import org.partiql.errors.Problem import org.partiql.errors.ProblemSeverity import org.partiql.parser.PartiQLParserBuilder import org.partiql.plan.debug.PlanPrinter +import org.partiql.planner.internal.typer.CompilerType +import org.partiql.planner.internal.typer.PlanTyper.Companion.toCType import org.partiql.planner.util.ProblemCollector import org.partiql.plugins.memory.MemoryCatalog import org.partiql.plugins.memory.MemoryConnector import org.partiql.spi.connector.ConnectorSession import org.partiql.types.BagType +import org.partiql.types.PType import org.partiql.types.StaticType import org.partiql.types.StructType import org.partiql.types.TupleConstraint +import kotlin.test.assertEquals internal class PlannerErrorReportingTests { val catalogName = "mode_test" @@ -83,8 +87,15 @@ internal class PlannerErrorReportingTests { val query: String, val isSignal: Boolean, val assertion: (List) -> List<() -> Boolean>, - val expectedType: StaticType = StaticType.ANY - ) + val expectedType: CompilerType + ) { + constructor( + query: String, + isSignal: Boolean, + assertion: (List) -> List<() -> Boolean>, + expectedType: StaticType = StaticType.ANY + ) : this(query, isSignal, assertion, PType.fromStaticType(expectedType).toCType()) + } companion object { fun closedStruct(vararg field: StructType.Field): StructType = @@ -115,12 +126,14 @@ internal class PlannerErrorReportingTests { TestCase( "MISSING", false, - assertOnProblemCount(0, 0) + assertOnProblemCount(0, 0), + expectedType = PType.typeUnknown().toCType() ), TestCase( "MISSING", true, - assertOnProblemCount(0, 0) + assertOnProblemCount(0, 0), + expectedType = PType.typeUnknown().toCType() ), // Unresolved variable always signals (10.1.3) TestCase( @@ -133,7 +146,8 @@ internal class PlannerErrorReportingTests { TestCase( "1 + MISSING", false, - assertOnProblemCount(1, 0) + assertOnProblemCount(1, 0), + expectedType = PType.typeInt().toCType() ), // This will be a non-resolved function error. // As plus does not contain a function that match argument type with @@ -142,7 +156,8 @@ internal class PlannerErrorReportingTests { TestCase( "1 + MISSING", true, - assertOnProblemCount(0, 1) + assertOnProblemCount(0, 1), + expectedType = PType.typeInt().toCType() ), // Attempting to do path navigation(symbol) on missing(which is not tuple) // returns missing in quite mode, and error out in signal mode @@ -263,14 +278,14 @@ internal class PlannerErrorReportingTests { TestCase( "1 + not_a_function(1)", false, - assertOnProblemCount(0, 1), - StaticType.unionOf(StaticType.INT4, StaticType.INT8, StaticType.INT, StaticType.FLOAT, StaticType.DECIMAL), + assertOnProblemCount(1, 1), + StaticType.INT4, ), TestCase( "1 + not_a_function(1)", true, - assertOnProblemCount(0, 1), - StaticType.unionOf(StaticType.INT4, StaticType.INT8, StaticType.INT, StaticType.FLOAT, StaticType.DECIMAL), + assertOnProblemCount(0, 2), + StaticType.INT4, ), TestCase( @@ -395,7 +410,7 @@ internal class PlannerErrorReportingTests { plan, problems, *tc.assertion(problems).toTypedArray() ) - tc.expectedType.assertStaticTypeEqual((plan.statement as org.partiql.plan.Statement.Query).root.type) + assertEquals(tc.expectedType, (plan.statement as org.partiql.plan.Statement.Query).root.type) } @ParameterizedTest diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/FnResolverTest.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/FnResolverTest.kt index 763033fabf..e271429308 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/FnResolverTest.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/FnResolverTest.kt @@ -4,10 +4,11 @@ import org.junit.jupiter.api.Test import org.junit.jupiter.api.fail import org.partiql.planner.internal.FnMatch import org.partiql.planner.internal.FnResolver +import org.partiql.planner.internal.typer.PlanTyper.Companion.toCType import org.partiql.spi.fn.FnExperimental import org.partiql.spi.fn.FnParameter import org.partiql.spi.fn.FnSignature -import org.partiql.types.StaticType +import org.partiql.types.PType import org.partiql.value.PartiQLValueExperimental import org.partiql.value.PartiQLValueType @@ -32,7 +33,7 @@ class FnResolverTest { ), ) ) - val args = listOf(StaticType.INT4, StaticType.FLOAT) + val args = listOf(PType.typeInt().toCType(), PType.typeDoublePrecision().toCType()) val expectedImplicitCasts = listOf(true, false) val case = Case.Success(variants, args, expectedImplicitCasts) case.assert() @@ -51,7 +52,7 @@ class FnResolverTest { isNullable = false, ) ) - val args = listOf(StaticType.STRING, StaticType.STRING) + val args = listOf(PType.typeString().toCType(), PType.typeString().toCType()) val expectedImplicitCasts = listOf(false, false) val case = Case.Success(variants, args, expectedImplicitCasts) case.assert() @@ -63,7 +64,7 @@ class FnResolverTest { class Success( private val variants: List, - private val inputs: List, + private val inputs: List, private val expectedImplicitCast: List, ) : Case() { diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PartiQLTyperTestBase.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PartiQLTyperTestBase.kt index bb08a3c107..d8cff4fc04 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PartiQLTyperTestBase.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PartiQLTyperTestBase.kt @@ -19,6 +19,8 @@ import org.partiql.spi.BindingName import org.partiql.spi.BindingPath import org.partiql.spi.connector.ConnectorMetadata import org.partiql.spi.connector.ConnectorSession +import org.partiql.types.PType +import org.partiql.types.PType.Kind import org.partiql.types.StaticType import org.partiql.value.PartiQLValueExperimental import java.util.Random @@ -26,7 +28,10 @@ import java.util.stream.Stream abstract class PartiQLTyperTestBase { sealed class TestResult { - data class Success(val expectedType: StaticType) : TestResult() { + data class Success(val expectedType: PType) : TestResult() { + + constructor(expectedType: StaticType) : this(PType.fromStaticType(expectedType)) + override fun toString(): String = "Success_$expectedType" } @@ -127,9 +132,9 @@ abstract class PartiQLTyperTestBase { val result = testingPipeline(statement, testName, metadata, pc) val root = (result.plan.statement as Statement.Query).root val actualType = root.type - assert(actualType == StaticType.ANY) { + assert(actualType.kind == Kind.DYNAMIC) { buildString { - this.appendLine(" expected Type is : ANY") + this.appendLine("expected Type is : DYNAMIC") this.appendLine("actual Type is : $actualType") PlanPrinter.append(this, result.plan) } diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTest.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTest.kt index 4108f18c42..2fe976d566 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTest.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTest.kt @@ -1,10 +1,12 @@ package org.partiql.planner.internal.typer +import org.junit.jupiter.api.Disabled import org.junit.jupiter.api.Test import org.partiql.planner.PartiQLPlanner import org.partiql.planner.internal.Env import org.partiql.planner.internal.ir.Identifier import org.partiql.planner.internal.ir.Rex +import org.partiql.planner.internal.ir.Statement import org.partiql.planner.internal.ir.identifierSymbol import org.partiql.planner.internal.ir.refObj import org.partiql.planner.internal.ir.rex @@ -16,23 +18,15 @@ import org.partiql.planner.internal.ir.rexOpStructField import org.partiql.planner.internal.ir.rexOpVarGlobal import org.partiql.planner.internal.ir.rexOpVarUnresolved import org.partiql.planner.internal.ir.statementQuery +import org.partiql.planner.internal.typer.PlanTyper.Companion.toCType import org.partiql.planner.util.ProblemCollector import org.partiql.plugins.local.LocalConnector -import org.partiql.types.StaticType -import org.partiql.types.StaticType.Companion.ANY -import org.partiql.types.StaticType.Companion.DECIMAL -import org.partiql.types.StaticType.Companion.FLOAT -import org.partiql.types.StaticType.Companion.INT2 -import org.partiql.types.StaticType.Companion.INT4 -import org.partiql.types.StaticType.Companion.STRING -import org.partiql.types.StructType -import org.partiql.types.TupleConstraint +import org.partiql.types.PType import org.partiql.value.PartiQLValueExperimental import org.partiql.value.int32Value import org.partiql.value.stringValue import java.util.Random import kotlin.io.path.toPath -import kotlin.test.assertEquals class PlanTyperTest { @@ -40,6 +34,12 @@ class PlanTyperTest { private val root = this::class.java.getResource("/catalogs/default/pql")!!.toURI().toPath() + private val ANY = PType.typeDynamic().toCType() + private val STRING = PType.typeString().toCType() + private val INT4 = PType.typeInt().toCType() + private val DOUBLE_PRECISION = PType.typeDoublePrecision().toCType() + private val DECIMAL = PType.typeDecimalArbitrary().toCType() + @OptIn(PartiQLValueExperimental::class) private val LITERAL_STRUCT_1 = rex( ANY, @@ -63,30 +63,16 @@ class PlanTyperTest { ) ) - private val LITERAL_STRUCT_1_FIRST_KEY_TYPE = StructType( - fields = mapOf( - "sEcoNd_KEY" to INT4 - ), - contentClosed = true, - constraints = setOf( - TupleConstraint.UniqueAttrs(true), - TupleConstraint.Open(false) - ) - ) + private val LITERAL_STRUCT_1_FIRST_KEY_TYPE = PType.typeRow( + listOf(CompilerType.Field("sEcoNd_KEY", INT4)), + ).toCType() @OptIn(PartiQLValueExperimental::class) private val LITERAL_STRUCT_1_TYPED: Rex get() { - val topLevelStruct = StructType( - fields = mapOf( - "FiRsT_KeY" to LITERAL_STRUCT_1_FIRST_KEY_TYPE - ), - contentClosed = true, - constraints = setOf( - TupleConstraint.UniqueAttrs(true), - TupleConstraint.Open(false) - ) - ) + val topLevelStruct = PType.typeRow( + listOf(CompilerType.Field("FiRsT_KeY", LITERAL_STRUCT_1_FIRST_KEY_TYPE)), + ).toCType() return rex( type = topLevelStruct, rexOpStruct( @@ -110,65 +96,25 @@ class PlanTyperTest { ) } - private val ORDERED_DUPLICATES_STRUCT = StructType( - fields = listOf( - StructType.Field("definition", StaticType.STRING), - StructType.Field("definition", StaticType.FLOAT), - StructType.Field("DEFINITION", StaticType.DECIMAL), + private val ORDERED_DUPLICATES_STRUCT = PType.typeRow( + listOf( + CompilerType.Field("definition", STRING), + CompilerType.Field("definition", DOUBLE_PRECISION), + CompilerType.Field("DEFINITION", DECIMAL), ), - contentClosed = true, - constraints = setOf( - TupleConstraint.Open(false), - TupleConstraint.Ordered - ) - ) + ).toCType() - private val DUPLICATES_STRUCT = StructType( - fields = listOf( - StructType.Field("definition", StaticType.STRING), - StructType.Field("definition", StaticType.FLOAT), - StructType.Field("DEFINITION", StaticType.DECIMAL), + private val DUPLICATES_STRUCT = PType.typeRow( + listOf( + CompilerType.Field("definition", STRING), + CompilerType.Field("definition", DOUBLE_PRECISION), + CompilerType.Field("DEFINITION", DECIMAL), ), - contentClosed = true, - constraints = setOf( - TupleConstraint.Open(false) - ) - ) + ).toCType() - private val CLOSED_UNION_DUPLICATES_STRUCT = StaticType.unionOf( - StructType( - fields = listOf( - StructType.Field("definition", StaticType.STRING), - StructType.Field("definition", StaticType.FLOAT), - StructType.Field("DEFINITION", StaticType.DECIMAL), - ), - contentClosed = true, - constraints = setOf( - TupleConstraint.Open(false) - ) - ), - StructType( - fields = listOf( - StructType.Field("definition", StaticType.INT2), - StructType.Field("definition", StaticType.INT4), - StructType.Field("DEFINITION", StaticType.INT8), - ), - contentClosed = true, - constraints = setOf( - TupleConstraint.Open(false), - TupleConstraint.Ordered - ) - ), - ) + private val CLOSED_UNION_DUPLICATES_STRUCT = ANY - private val OPEN_DUPLICATES_STRUCT = StructType( - fields = listOf( - StructType.Field("definition", StaticType.STRING), - StructType.Field("definition", StaticType.FLOAT), - StructType.Field("DEFINITION", StaticType.DECIMAL), - ), - contentClosed = false - ) + private val OPEN_DUPLICATES_STRUCT = PType.typeStruct().toCType() private fun getTyper(): PlanTyperWrapper { ProblemCollector() @@ -217,6 +163,7 @@ class PlanTyperTest { } @Test + @Disabled("PartiQL doesn't have the concept of ordered structs (yet)") fun testOrderedDuplicates() { val wrapper = getTyper() val typer = wrapper.typer @@ -235,6 +182,7 @@ class PlanTyperTest { } @Test + @Disabled("PartiQL doesn't have the concept of ordered structs (yet)") fun testOrderedDuplicatesWithSensitivity() { val wrapper = getTyper() val typer = wrapper.typer @@ -261,7 +209,7 @@ class PlanTyperTest { path = listOf("main", "closed_duplicates_struct"), ).pathSymbol( "DEFINITION", - StaticType.unionOf(STRING, FLOAT, DECIMAL) + PType.typeDynamic().toCType() ) ) @@ -296,7 +244,7 @@ class PlanTyperTest { path = listOf("main", "closed_duplicates_struct"), ).pathKey( "definition", - StaticType.unionOf(StaticType.STRING, StaticType.FLOAT) + PType.typeDynamic().toCType() ) ) @@ -331,7 +279,7 @@ class PlanTyperTest { path = listOf("main", "closed_union_duplicates_struct"), ).pathSymbol( "definition", - StaticType.unionOf(STRING, FLOAT, DECIMAL, INT2) + PType.typeDynamic().toCType() ) ) @@ -350,7 +298,7 @@ class PlanTyperTest { path = listOf("main", "closed_union_duplicates_struct"), ).pathKey( "definition", - StaticType.unionOf(STRING, FLOAT, INT2) + PType.typeDynamic().toCType() ) ) @@ -361,11 +309,11 @@ class PlanTyperTest { @OptIn(PartiQLValueExperimental::class) private fun rexString(str: String) = rex(STRING, rexOpLit(stringValue(str))) - private fun Rex.pathKey(key: String, type: StaticType = ANY): Rex = Rex(type, rexOpPathKey(this, rexString(key))) + private fun Rex.pathKey(key: String, type: CompilerType = ANY): Rex = Rex(type, rexOpPathKey(this, rexString(key))) - private fun Rex.pathSymbol(key: String, type: StaticType = ANY): Rex = Rex(type, rexOpPathSymbol(this, key)) + private fun Rex.pathSymbol(key: String, type: CompilerType = ANY): Rex = Rex(type, rexOpPathSymbol(this, key)) - private fun unresolvedSensitiveVar(name: String, type: StaticType = ANY): Rex { + private fun unresolvedSensitiveVar(name: String, type: CompilerType = ANY): Rex { return rex( type, rexOpVarUnresolved( @@ -375,10 +323,19 @@ class PlanTyperTest { ) } - private fun global(type: StaticType, path: List): Rex { + private fun global(type: CompilerType, path: List): Rex { return rex( type, rexOpVarGlobal(refObj(catalog = "pql", path = path, type)) ) } + + private fun assertEquals(expected: Statement, actual: Statement) { + return assert(expected == actual) { + buildString { + appendLine("Expected : $expected") + appendLine("Actual : $actual") + } + } + } } diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTestsPorted.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTestsPorted.kt index 1ba3024610..481a819cad 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTestsPorted.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTestsPorted.kt @@ -20,6 +20,7 @@ import org.partiql.plan.Statement import org.partiql.plan.debug.PlanPrinter import org.partiql.planner.PartiQLPlanner import org.partiql.planner.internal.ProblemGenerator +import org.partiql.planner.internal.typer.PlanTyper.Companion.toCType import org.partiql.planner.internal.typer.PlanTyperTestsPorted.TestCase.ErrorTestCase import org.partiql.planner.internal.typer.PlanTyperTestsPorted.TestCase.SuccessTestCase import org.partiql.planner.internal.typer.PlanTyperTestsPorted.TestCase.ThrowingExceptionTestCase @@ -37,12 +38,16 @@ import org.partiql.spi.connector.ConnectorMetadata import org.partiql.spi.connector.ConnectorSession import org.partiql.types.BagType import org.partiql.types.ListType +import org.partiql.types.PType import org.partiql.types.SexpType import org.partiql.types.StaticType import org.partiql.types.StaticType.Companion.ANY +import org.partiql.types.StaticType.Companion.DECIMAL import org.partiql.types.StaticType.Companion.INT import org.partiql.types.StaticType.Companion.INT4 import org.partiql.types.StaticType.Companion.INT8 +import org.partiql.types.StaticType.Companion.STRING +import org.partiql.types.StaticType.Companion.STRUCT import org.partiql.types.StaticType.Companion.unionOf import org.partiql.types.StructType import org.partiql.types.TupleConstraint @@ -53,7 +58,7 @@ import kotlin.test.assertEquals import kotlin.test.assertNotNull import kotlin.test.assertTrue -class PlanTyperTestsPorted { +internal class PlanTyperTestsPorted { sealed class TestCase { class SuccessTestCase( @@ -62,9 +67,19 @@ class PlanTyperTestsPorted { val query: String? = null, val catalog: String = "pql", val catalogPath: List = emptyList(), - val expected: StaticType, + val expected: CompilerType, val warnings: ProblemHandler? = null, ) : TestCase() { + constructor( + name: String? = null, + key: PartiQLTest.Key? = null, + query: String? = null, + catalog: String = "pql", + catalogPath: List = emptyList(), + expected: StaticType, + warnings: ProblemHandler? = null, + ) : this(name, key, query, catalog, catalogPath, PType.fromStaticType(expected).toCType(), warnings) + override fun toString(): String { if (key != null) { return "${key.group} : ${key.name}" @@ -75,15 +90,26 @@ class PlanTyperTestsPorted { class ErrorTestCase( val name: String, - val key: PartiQLTest.Key? = null, - val query: String? = null, + val key: PartiQLTest.Key?, + val query: String?, val catalog: String = "pql", val catalogPath: List = emptyList(), - val note: String? = null, - val expected: StaticType? = null, - val problemHandler: ProblemHandler? = null, + val note: String?, + val expected: CompilerType?, + val problemHandler: ProblemHandler?, ) : TestCase() { + constructor( + name: String, + key: PartiQLTest.Key? = null, + query: String? = null, + catalog: String = "pql", + catalogPath: List = emptyList(), + note: String? = null, + expected: StaticType? = null, + problemHandler: ProblemHandler? = null, + ) : this(name, key, query, catalog, catalogPath, note, expected?.let { PType.fromStaticType(it).toCType() }, problemHandler) + override fun toString(): String = "$name : ${query ?: key}" } @@ -107,11 +133,18 @@ class PlanTyperTestsPorted { private val planner = PartiQLPlanner.builder().signalMode().build() private fun assertProblemExists(problem: Problem) = ProblemHandler { problems, ignoreSourceLocation -> - when (ignoreSourceLocation) { - true -> assertTrue("Expected to find $problem in $problems") { - problems.any { it.details == problem.details } + val message = buildString { + appendLine("Expected problems to include: $problem") + appendLine("Received: [") + problems.forEach { + append("\t") + appendLine(it) } - false -> assertTrue("Expected to find $problem in $problems") { problems.any { it == problem } } + appendLine("]") + } + when (ignoreSourceLocation) { + true -> assertTrue(message) { problems.any { it.details == problem.details } } + false -> assertTrue(message) { problems.any { it == problem } } } } @@ -562,11 +595,14 @@ class PlanTyperTestsPorted { catalog = "pql", expected = StaticType.BOOL, ), - SuccessTestCase( + ErrorTestCase( name = "MISSING IS NULL", key = key("is-type-04"), catalog = "pql", expected = StaticType.BOOL, + problemHandler = assertProblemExists( + ProblemGenerator.expressionAlwaysReturnsMissing("Static function always receives MISSING arguments.") + ) ), SuccessTestCase( name = "NULL IS NULL", @@ -755,14 +791,14 @@ class PlanTyperTestsPorted { SuccessTestCase( name = "BITWISE_AND_NULL_OPERAND", query = "1 & NULL", - expected = unionOf(INT4, INT8, INT), + expected = INT4, ), ErrorTestCase( name = "BITWISE_AND_MISSING_OPERAND", query = "1 & MISSING", - expected = ANY, // TODO: Is this unionOf(INT4, INT8, INT) ? + expected = INT4, problemHandler = assertProblemExists( - ProblemGenerator.expressionAlwaysReturnsMissing("Function argument is always the missing value.") + ProblemGenerator.expressionAlwaysReturnsMissing("Static function always receives MISSING arguments.") ) ), ErrorTestCase( @@ -1733,31 +1769,22 @@ class PlanTyperTestsPorted { expected = BagType( StructType( fields = mapOf( - "t" to StaticType.unionOf( - StructType( - fields = mapOf( - "a" to StructType( - fields = mapOf( - "c" to StaticType.STRING - ), - contentClosed = true, - constraints = setOf( - TupleConstraint.Open(false), - TupleConstraint.UniqueAttrs(true) - ) + "t" to StructType( + fields = mapOf( + "a" to StructType( + fields = mapOf( + "c" to StaticType.STRING + ), + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(true) ) - ), - contentClosed = true, - constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) - ), - StructType( - fields = mapOf( - "a" to ANY - ), - contentClosed = true, - constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) + ) ), - ) + contentClosed = true, + constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) + ), ), contentClosed = true, constraints = setOf( @@ -1774,40 +1801,22 @@ class PlanTyperTestsPorted { expected = BagType( StructType( fields = mapOf( - "t" to StaticType.unionOf( - StructType( - fields = mapOf( - "a" to StructType( - fields = mapOf( - "c" to StaticType.STRING - ), - contentClosed = true, - constraints = setOf( - TupleConstraint.Open(false), - TupleConstraint.UniqueAttrs(true) - ) - ) - ), - contentClosed = true, - constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) - ), - StructType( - fields = mapOf( - "a" to StructType( - fields = mapOf( - "c" to ANY - ), - contentClosed = true, - constraints = setOf( - TupleConstraint.Open(false), - TupleConstraint.UniqueAttrs(true) - ) + "t" to StructType( + fields = mapOf( + "a" to StructType( + fields = mapOf( + "c" to StaticType.STRING + ), + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(true) ) - ), - contentClosed = true, - constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) + ) ), - ) + contentClosed = true, + constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true)) + ), ), contentClosed = true, constraints = setOf( @@ -2033,41 +2042,17 @@ class PlanTyperTestsPorted { elementType = StructType( fields = mapOf( "a" to ListType( - elementType = StaticType.unionOf( - StructType( - fields = mapOf( - "b" to StaticType.INT4, - "c" to StaticType.INT4 - ), - contentClosed = true, - constraints = setOf( - TupleConstraint.Open(false), - TupleConstraint.UniqueAttrs(true) - ) - ), - StructType( - fields = mapOf( - "b" to StaticType.INT4, - "c" to ANY - ), - contentClosed = true, - constraints = setOf( - TupleConstraint.Open(false), - TupleConstraint.UniqueAttrs(true) - ) + elementType = StructType( + fields = mapOf( + "b" to StaticType.INT4, + "c" to ANY ), - StructType( - fields = mapOf( - "b" to StaticType.INT4, - "c" to StaticType.DECIMAL - ), - contentClosed = true, - constraints = setOf( - TupleConstraint.Open(false), - TupleConstraint.UniqueAttrs(true) - ) + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(true) ) - ) + ), ) ), contentClosed = true, @@ -2218,19 +2203,7 @@ class PlanTyperTestsPorted { { 'a': 1 } >> AS t """, - expected = BagType( - StructType( - fields = listOf( - StructType.Field("b", INT4), - ), - contentClosed = true, - constraints = setOf( - TupleConstraint.Open(false), - TupleConstraint.UniqueAttrs(true), - TupleConstraint.Ordered - ) - ) - ), + expected = BagType(ANY), ), SuccessTestCase( name = "Tuple Union with Heterogeneous Data (2)", @@ -2651,20 +2624,7 @@ class PlanTyperTestsPorted { SuccessTestCase( key = PartiQLTest.Key("basics", "case-when-29"), catalog = "pql", - expected = unionOf( - StructType( - fields = listOf( - StructType.Field("x", StaticType.INT4), - StructType.Field("y", StaticType.INT4), - ), - ), - StructType( - fields = listOf( - StructType.Field("x", StaticType.INT8), - StructType.Field("y", StaticType.INT8), - ), - ), - ), + expected = STRUCT ), SuccessTestCase( name = "CASE-WHEN always MISSING", @@ -3053,7 +3013,7 @@ class PlanTyperTestsPorted { """, expected = ANY, problemHandler = assertProblemExists( - ProblemGenerator.expressionAlwaysReturnsMissing("Collections must be indexed with integers, found string") + ProblemGenerator.expressionAlwaysReturnsMissing("Collections must be indexed with integers, found STRING") ) ), // The reason this is ANY is because we do not have support for constant-folding. We don't know what @@ -3349,7 +3309,7 @@ class PlanTyperTestsPorted { expected = BagType( StructType( fields = mapOf( - "c" to StaticType.DECIMAL, + "c" to ANY, ), contentClosed = true, constraints = setOf( @@ -3379,10 +3339,9 @@ class PlanTyperTestsPorted { ) ) ), - ErrorTestCase( + SuccessTestCase( name = """ - unary plus on non-compatible union type -- this cannot resolve to a dynamic call since no function - will ever be invoked. + unary plus on dynamic types """.trimIndent(), query = """ SELECT VALUE +t.a @@ -3392,12 +3351,6 @@ class PlanTyperTestsPorted { >> AS t """.trimIndent(), expected = BagType(ANY), - problemHandler = assertProblemExists( - ProblemGenerator.incompatibleTypesForOp( - listOf(StaticType.unionOf(StaticType.STRING, StaticType.BAG)), - "POS", - ) - ) ), ErrorTestCase( name = """ @@ -3410,10 +3363,10 @@ class PlanTyperTestsPorted { { 'NOT_A': 1 } >> AS t """.trimIndent(), - expected = BagType(unionOf(StaticType.INT2, INT4, INT8, INT, StaticType.FLOAT, StaticType.DECIMAL)), + expected = BagType(ANY), problemHandler = assertProblemExists( ProblemGenerator.undefinedVariable( - Identifier.Symbol("a", Identifier.CaseSensitivity.SENSITIVE), + Identifier.Symbol("a", Identifier.CaseSensitivity.INSENSITIVE), setOf("t"), ) ) @@ -3426,9 +3379,9 @@ class PlanTyperTestsPorted { query = """ +MISSING """.trimIndent(), - expected = StaticType.ANY, + expected = StaticType.DECIMAL, // This is due to it being the highest precedence type problemHandler = assertProblemExists( - ProblemGenerator.expressionAlwaysReturnsMissing("Function argument is always the missing value.") + ProblemGenerator.expressionAlwaysReturnsMissing("Static function always receives MISSING arguments.") ) ), ) @@ -3516,6 +3469,33 @@ class PlanTyperTestsPorted { // Parameterized Tests // + @Test + @Disabled("We currently don't have the concept of an ordered struct.") + fun orderedTuple() { + val tc = + SuccessTestCase( + name = "Duplicate fields in ordered STRUCT. NOTE: b.b.d is an ordered struct with two attributes (e). First is INT4.", + query = """ + SELECT d.e AS e + FROM << b.b.d >> AS d + """, + expected = BagType( + StructType( + fields = listOf( + StructType.Field("e", StaticType.INT4) + ), + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(true), + TupleConstraint.Ordered + ) + ) + ) + ) + runTest(tc) + } + @Test @Disabled("The planner doesn't support heterogeneous input to aggregation functions (yet?).") fun failingTest() { @@ -3569,6 +3549,38 @@ class PlanTyperTestsPorted { runTest(tc) } + @Test + fun developmentTest() { + val tc = SuccessTestCase( + name = "DEV TEST", + query = "CAST('' AS STRING) < CAST('' AS SYMBOL);", + expected = PType.typeBool().toCType() + ) + runTest(tc) + } + + @Test + fun developmentTest3() { + val tc = + SuccessTestCase( + name = "MISSING IS MISSING", + key = key("is-type-06"), + catalog = "pql", + expected = StaticType.BOOL, + ) + runTest(tc) + } + + @Test + fun developmentTest4() { + val tc = SuccessTestCase( + name = "DEV TEST 4", + query = "NULLIF([], [])", + expected = PType.typeList().toCType() + ) + runTest(tc) + } + @ParameterizedTest @ArgumentsSource(TestProvider::class) fun test(tc: TestCase) = runTest(tc) @@ -4368,26 +4380,6 @@ class PlanTyperTestsPorted { ) ) ), - SuccessTestCase( - name = "Duplicate fields in ordered STRUCT. NOTE: b.b.d is an ordered struct with two attributes (e). First is INT4.", - query = """ - SELECT d.e AS e - FROM << b.b.d >> AS d - """, - expected = BagType( - StructType( - fields = listOf( - StructType.Field("e", StaticType.INT4) - ), - contentClosed = true, - constraints = setOf( - TupleConstraint.Open(false), - TupleConstraint.UniqueAttrs(true), - TupleConstraint.Ordered - ) - ) - ) - ), SuccessTestCase( name = "Duplicate fields in struct", query = """ diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/TypeEnvTest.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/TypeEnvTest.kt index 9d9758b4c0..a5274f20da 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/TypeEnvTest.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/TypeEnvTest.kt @@ -6,13 +6,11 @@ import org.junit.jupiter.params.ParameterizedTest import org.junit.jupiter.params.provider.MethodSource import org.partiql.planner.internal.ir.Rex import org.partiql.planner.internal.ir.relBinding +import org.partiql.planner.internal.typer.PlanTyper.Companion.toCType import org.partiql.spi.BindingCase import org.partiql.spi.BindingName import org.partiql.spi.BindingPath -import org.partiql.types.BoolType -import org.partiql.types.StaticType -import org.partiql.types.StructType -import org.partiql.types.TupleConstraint +import org.partiql.types.PType import kotlin.test.assertEquals import kotlin.test.fail @@ -33,21 +31,21 @@ internal class TypeEnvTest { @JvmStatic val locals = TypeEnv( listOf( - relBinding("A", struct("B" to BoolType())), - relBinding("a", struct("b" to BoolType())), + relBinding("A", struct("B" to PType.typeBool().toCType())), + relBinding("a", struct("b" to PType.typeBool().toCType())), relBinding("X", struct(open = true)), - relBinding("x", struct("Y" to BoolType(), open = true)), + relBinding("x", struct("Y" to PType.typeBool().toCType(), open = false)), // We currently don't allow for partial schema structs relBinding("y", struct(open = true)), - relBinding("T", struct("x" to BoolType(), "x" to BoolType())), + relBinding("T", struct("x" to PType.typeBool().toCType(), "x" to PType.typeBool().toCType())), ), outer = emptyList() ) - private fun struct(vararg fields: Pair, open: Boolean = false): StructType { - return StructType( - fields = fields.map { StructType.Field(it.first, it.second) }, - constraints = setOf(TupleConstraint.Open(open)), - ) + private fun struct(vararg fields: Pair, open: Boolean = false): CompilerType { + return when (open) { + true -> PType.typeStruct().toCType() + false -> PType.typeRow(fields.map { CompilerType.Field(it.first, it.second) }).toCType() + } } @JvmStatic diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/operator/OpBitwiseAndTest.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/operator/OpBitwiseAndTest.kt index 38078da0fa..6973cf5de4 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/operator/OpBitwiseAndTest.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/operator/OpBitwiseAndTest.kt @@ -6,6 +6,7 @@ import org.partiql.planner.internal.typer.PartiQLTyperTestBase import org.partiql.planner.internal.typer.accumulateSuccessNullCall import org.partiql.planner.util.CastType import org.partiql.planner.util.allIntType +import org.partiql.planner.util.allNumberType import org.partiql.planner.util.allSupportedType import org.partiql.planner.util.cartesianProduct import org.partiql.planner.util.castTable @@ -20,7 +21,7 @@ class OpBitwiseAndTest : PartiQLTyperTestBase() { ).map { inputs.get("basics", it)!! } val argsMap = buildMap { - val successArgs = allIntType.let { cartesianProduct(it, it) } + val successArgs = allNumberType.let { cartesianProduct(it, it) } val failureArgs = cartesianProduct( allSupportedType, allSupportedType @@ -32,6 +33,9 @@ class OpBitwiseAndTest : PartiQLTyperTestBase() { val arg0 = args.first() val arg1 = args[1] val output = when { + arg0 !in allIntType && arg1 !in allIntType -> StaticType.INT + arg0 in allIntType && arg1 !in allIntType -> arg0 + arg0 !in allIntType && arg1 in allIntType -> arg1 arg0 == arg1 -> arg1 castTable(arg1, arg0) == CastType.COERCION -> arg0 castTable(arg0, arg1) == CastType.COERCION -> arg1 diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/predicate/OpComparisonTest.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/predicate/OpComparisonTest.kt index 7a27ca9593..0d19562f65 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/predicate/OpComparisonTest.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/predicate/OpComparisonTest.kt @@ -48,8 +48,8 @@ class OpComparisonTest : PartiQLTyperTestBase() { StaticType.NUMERIC.allTypes, StaticType.NUMERIC.allTypes ) + cartesianProduct( - StaticType.TEXT.allTypes, - StaticType.TEXT.allTypes + StaticType.TEXT.allTypes + listOf(StaticType.CLOB), + StaticType.TEXT.allTypes + listOf(StaticType.CLOB) ) + cartesianProduct( listOf(StaticType.BOOL), listOf(StaticType.BOOL) diff --git a/partiql-spi/api/partiql-spi.api b/partiql-spi/api/partiql-spi.api index a5208ea8da..c99f5ac93d 100644 --- a/partiql-spi/api/partiql-spi.api +++ b/partiql-spi/api/partiql-spi.api @@ -124,9 +124,14 @@ public abstract interface class org/partiql/spi/connector/ConnectorMetadata { } public abstract interface class org/partiql/spi/connector/ConnectorObject { + public abstract fun getPType ()Lorg/partiql/types/PType; public abstract fun getType ()Lorg/partiql/types/StaticType; } +public final class org/partiql/spi/connector/ConnectorObject$DefaultImpls { + public static fun getPType (Lorg/partiql/spi/connector/ConnectorObject;)Lorg/partiql/types/PType; +} + public final class org/partiql/spi/connector/ConnectorPath : java/lang/Iterable, kotlin/jvm/internal/markers/KMappedMarker { public static final field Companion Lorg/partiql/spi/connector/ConnectorPath$Companion; public fun (Ljava/util/List;)V @@ -498,7 +503,9 @@ public final class org/partiql/spi/fn/AggSignature { public final field isNullable Z public final field name Ljava/lang/String; public final field parameters Ljava/util/List; - public final field returns Lorg/partiql/value/PartiQLValueType; + public final field returns Lorg/partiql/types/PType; + public fun (Ljava/lang/String;Lorg/partiql/types/PType;Ljava/util/List;Ljava/lang/String;ZZ)V + public synthetic fun (Ljava/lang/String;Lorg/partiql/types/PType;Ljava/util/List;Ljava/lang/String;ZZILkotlin/jvm/internal/DefaultConstructorMarker;)V public fun (Ljava/lang/String;Lorg/partiql/value/PartiQLValueType;Ljava/util/List;Ljava/lang/String;ZZ)V public synthetic fun (Ljava/lang/String;Lorg/partiql/value/PartiQLValueType;Ljava/util/List;Ljava/lang/String;ZZILkotlin/jvm/internal/DefaultConstructorMarker;)V public fun equals (Ljava/lang/Object;)Z @@ -516,14 +523,15 @@ public abstract interface annotation class org/partiql/spi/fn/FnExperimental : j } public final class org/partiql/spi/fn/FnParameter { + public fun (Ljava/lang/String;Lorg/partiql/types/PType;)V public fun (Ljava/lang/String;Lorg/partiql/value/PartiQLValueType;)V public final fun component1 ()Ljava/lang/String; - public final fun component2 ()Lorg/partiql/value/PartiQLValueType; - public final fun copy (Ljava/lang/String;Lorg/partiql/value/PartiQLValueType;)Lorg/partiql/spi/fn/FnParameter; - public static synthetic fun copy$default (Lorg/partiql/spi/fn/FnParameter;Ljava/lang/String;Lorg/partiql/value/PartiQLValueType;ILjava/lang/Object;)Lorg/partiql/spi/fn/FnParameter; + public final fun component2 ()Lorg/partiql/types/PType; + public final fun copy (Ljava/lang/String;Lorg/partiql/types/PType;)Lorg/partiql/spi/fn/FnParameter; + public static synthetic fun copy$default (Lorg/partiql/spi/fn/FnParameter;Ljava/lang/String;Lorg/partiql/types/PType;ILjava/lang/Object;)Lorg/partiql/spi/fn/FnParameter; public fun equals (Ljava/lang/Object;)Z public final fun getName ()Ljava/lang/String; - public final fun getType ()Lorg/partiql/value/PartiQLValueType; + public final fun getType ()Lorg/partiql/types/PType; public fun hashCode ()I public fun toString ()Ljava/lang/String; } @@ -537,11 +545,13 @@ public final class org/partiql/spi/fn/FnSignature { public final field isNullable Z public final field name Ljava/lang/String; public final field parameters Ljava/util/List; - public final field returns Lorg/partiql/value/PartiQLValueType; + public final field returns Lorg/partiql/types/PType; + public fun (Ljava/lang/String;Lorg/partiql/types/PType;Ljava/util/List;Ljava/lang/String;ZZZZZ)V + public synthetic fun (Ljava/lang/String;Lorg/partiql/types/PType;Ljava/util/List;Ljava/lang/String;ZZZZZILkotlin/jvm/internal/DefaultConstructorMarker;)V public fun (Ljava/lang/String;Lorg/partiql/value/PartiQLValueType;Ljava/util/List;Ljava/lang/String;ZZZZZ)V public synthetic fun (Ljava/lang/String;Lorg/partiql/value/PartiQLValueType;Ljava/util/List;Ljava/lang/String;ZZZZZILkotlin/jvm/internal/DefaultConstructorMarker;)V public final fun component1 ()Ljava/lang/String; - public final fun component2 ()Lorg/partiql/value/PartiQLValueType; + public final fun component2 ()Lorg/partiql/types/PType; public final fun component3 ()Ljava/util/List; public final fun component4 ()Ljava/lang/String; public final fun component5 ()Z @@ -549,8 +559,8 @@ public final class org/partiql/spi/fn/FnSignature { public final fun component7 ()Z public final fun component8 ()Z public final fun component9 ()Z - public final fun copy (Ljava/lang/String;Lorg/partiql/value/PartiQLValueType;Ljava/util/List;Ljava/lang/String;ZZZZZ)Lorg/partiql/spi/fn/FnSignature; - public static synthetic fun copy$default (Lorg/partiql/spi/fn/FnSignature;Ljava/lang/String;Lorg/partiql/value/PartiQLValueType;Ljava/util/List;Ljava/lang/String;ZZZZZILjava/lang/Object;)Lorg/partiql/spi/fn/FnSignature; + public final fun copy (Ljava/lang/String;Lorg/partiql/types/PType;Ljava/util/List;Ljava/lang/String;ZZZZZ)Lorg/partiql/spi/fn/FnSignature; + public static synthetic fun copy$default (Lorg/partiql/spi/fn/FnSignature;Ljava/lang/String;Lorg/partiql/types/PType;Ljava/util/List;Ljava/lang/String;ZZZZZILjava/lang/Object;)Lorg/partiql/spi/fn/FnSignature; public fun equals (Ljava/lang/Object;)Z public final fun getSpecific ()Ljava/lang/String; public fun hashCode ()I diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/ConnectorObject.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/ConnectorObject.kt index c2eb134594..e267d41d93 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/ConnectorObject.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/ConnectorObject.kt @@ -14,6 +14,7 @@ package org.partiql.spi.connector +import org.partiql.types.PType import org.partiql.types.StaticType /** @@ -30,5 +31,14 @@ public interface ConnectorObject { * * @return */ + @Deprecated("This is subject to removal in a future release.") public fun getType(): StaticType + + /** + * @return the type of the object in a catalog. + */ + public fun getPType(): PType { + // TODO: Remove the default prior to release + return PType.fromStaticType(getType()) + } } diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/SqlBuiltins.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/SqlBuiltins.kt index b3e305b930..6678cb0840 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/SqlBuiltins.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/SqlBuiltins.kt @@ -22,10 +22,7 @@ internal object SqlBuiltins { Fn_ABS__FLOAT32__FLOAT32, Fn_ABS__FLOAT64__FLOAT64, Fn_ABS__DECIMAL_ARBITRARY__DECIMAL_ARBITRARY, - Fn_AND__MISSING_BOOL__BOOL, - Fn_AND__MISSING_MISSING__BOOL, Fn_AND__BOOL_BOOL__BOOL, - Fn_AND__BOOL_MISSING__BOOL, Fn_BETWEEN__INT8_INT8_INT8__BOOL, Fn_BETWEEN__INT16_INT16_INT16__BOOL, Fn_BETWEEN__INT32_INT32_INT32__BOOL, @@ -370,10 +367,7 @@ internal object SqlBuiltins { Fn_NEG__DECIMAL_ARBITRARY__DECIMAL_ARBITRARY, Fn_NOT__MISSING__BOOL, Fn_NOT__BOOL__BOOL, - Fn_OR__MISSING_BOOL__BOOL, - Fn_OR__MISSING_MISSING__BOOL, Fn_OR__BOOL_BOOL__BOOL, - Fn_OR__BOOL_MISSING__BOOL, Fn_OCTET_LENGTH__STRING__INT32, Fn_OCTET_LENGTH__CLOB__INT32, Fn_OCTET_LENGTH__SYMBOL__INT32, diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnAnd.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnAnd.kt index 32ed12ea0a..f5c552cfe9 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnAnd.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnAnd.kt @@ -11,7 +11,6 @@ import org.partiql.value.BoolValue import org.partiql.value.PartiQLValue import org.partiql.value.PartiQLValueExperimental import org.partiql.value.PartiQLValueType.BOOL -import org.partiql.value.PartiQLValueType.MISSING import org.partiql.value.boolValue import org.partiql.value.check @@ -45,72 +44,3 @@ internal object Fn_AND__BOOL_BOOL__BOOL : Fn { return boolValue(toReturn) } } - -@OptIn(PartiQLValueExperimental::class, FnExperimental::class) -internal object Fn_AND__MISSING_BOOL__BOOL : Fn { - - override val signature = FnSignature( - name = "and", - returns = BOOL, - parameters = listOf( - FnParameter("lhs", MISSING), - FnParameter("rhs", BOOL), - ), - isNullable = true, - isNullCall = false, - isMissable = false, - isMissingCall = false, - ) - - override fun invoke(args: Array): PartiQLValue { - return when (args[1].check().value!!) { - false -> boolValue(false) - else -> boolValue(null) - } - } -} - -@OptIn(PartiQLValueExperimental::class, FnExperimental::class) -internal object Fn_AND__BOOL_MISSING__BOOL : Fn { - - override val signature = FnSignature( - name = "and", - returns = BOOL, - parameters = listOf( - FnParameter("lhs", BOOL), - FnParameter("rhs", MISSING), - ), - isNullable = true, - isNullCall = false, - isMissable = false, - isMissingCall = false, - ) - - override fun invoke(args: Array): PartiQLValue { - return when (args[0].check().value!!) { - false -> boolValue(false) - else -> boolValue(null) - } - } -} - -@OptIn(PartiQLValueExperimental::class, FnExperimental::class) -internal object Fn_AND__MISSING_MISSING__BOOL : Fn { - - override val signature = FnSignature( - name = "and", - returns = BOOL, - parameters = listOf( - FnParameter("lhs", MISSING), - FnParameter("rhs", MISSING), - ), - isNullable = true, - isNullCall = false, - isMissable = false, - isMissingCall = false, - ) - - override fun invoke(args: Array): PartiQLValue { - return boolValue(null) - } -} diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnEq.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnEq.kt index b3cfa80e7e..dbdee96799 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnEq.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnEq.kt @@ -85,12 +85,12 @@ internal object Fn_EQ__ANY_ANY__BOOL : Fn { // TODO ANY, ANY equals not clearly defined at the moment. override fun invoke(args: Array): PartiQLValue { + if (args[0].type == MISSING || args[1].type == MISSING) { + return boolValue(args[0] == args[1]) + } val lhs = args[0] val rhs = args[1] - return when { - lhs.type == MISSING || rhs.type == MISSING -> boolValue(lhs == rhs) - else -> boolValue(comparator.compare(lhs, rhs) == 0) - } + return boolValue(comparator.compare(lhs, rhs) == 0) } } @@ -111,6 +111,9 @@ internal object Fn_EQ__BOOL_BOOL__BOOL : Fn { ) override fun invoke(args: Array): PartiQLValue { + if (args[0].type == MISSING || args[1].type == MISSING) { + return boolValue(args[0] == args[1]) + } val lhs = args[0].check() val rhs = args[1].check() return boolValue(lhs == rhs) @@ -134,6 +137,9 @@ internal object Fn_EQ__INT8_INT8__BOOL : Fn { ) override fun invoke(args: Array): PartiQLValue { + if (args[0].type == MISSING || args[1].type == MISSING) { + return boolValue(args[0] == args[1]) + } val lhs = args[0].check() val rhs = args[1].check() return boolValue(lhs == rhs) @@ -157,6 +163,9 @@ internal object Fn_EQ__INT16_INT16__BOOL : Fn { ) override fun invoke(args: Array): PartiQLValue { + if (args[0].type == MISSING || args[1].type == MISSING) { + return boolValue(args[0] == args[1]) + } val lhs = args[0].check() val rhs = args[1].check() return boolValue(lhs == rhs) @@ -180,6 +189,9 @@ internal object Fn_EQ__INT32_INT32__BOOL : Fn { ) override fun invoke(args: Array): PartiQLValue { + if (args[0].type == MISSING || args[1].type == MISSING) { + return boolValue(args[0] == args[1]) + } val lhs = args[0].check() val rhs = args[1].check() return boolValue(lhs == rhs) @@ -203,6 +215,9 @@ internal object Fn_EQ__INT64_INT64__BOOL : Fn { ) override fun invoke(args: Array): PartiQLValue { + if (args[0].type == MISSING || args[1].type == MISSING) { + return boolValue(args[0] == args[1]) + } val lhs = args[0].check() val rhs = args[1].check() return boolValue(lhs == rhs) @@ -226,6 +241,9 @@ internal object Fn_EQ__INT_INT__BOOL : Fn { ) override fun invoke(args: Array): PartiQLValue { + if (args[0].type == MISSING || args[1].type == MISSING) { + return boolValue(args[0] == args[1]) + } val lhs = args[0].check() val rhs = args[1].check() return boolValue(lhs == rhs) @@ -249,6 +267,9 @@ internal object Fn_EQ__DECIMAL_DECIMAL__BOOL : Fn { ) override fun invoke(args: Array): PartiQLValue { + if (args[0].type == MISSING || args[1].type == MISSING) { + return boolValue(args[0] == args[1]) + } val lhs = args[0].check() val rhs = args[1].check() return boolValue(lhs == rhs) @@ -272,6 +293,9 @@ internal object Fn_EQ__DECIMAL_ARBITRARY_DECIMAL_ARBITRARY__BOOL : Fn { ) override fun invoke(args: Array): PartiQLValue { + if (args[0].type == MISSING || args[1].type == MISSING) { + return boolValue(args[0] == args[1]) + } val lhs = args[0].check() val rhs = args[1].check() return boolValue(lhs == rhs) @@ -295,6 +319,9 @@ internal object Fn_EQ__FLOAT32_FLOAT32__BOOL : Fn { ) override fun invoke(args: Array): PartiQLValue { + if (args[0].type == MISSING || args[1].type == MISSING) { + return boolValue(args[0] == args[1]) + } val lhs = args[0].check() val rhs = args[1].check() return boolValue(lhs == rhs) @@ -318,6 +345,9 @@ internal object Fn_EQ__FLOAT64_FLOAT64__BOOL : Fn { ) override fun invoke(args: Array): PartiQLValue { + if (args[0].type == MISSING || args[1].type == MISSING) { + return boolValue(args[0] == args[1]) + } val lhs = args[0].check() val rhs = args[1].check() return boolValue(lhs == rhs) @@ -341,6 +371,9 @@ internal object Fn_EQ__CHAR_CHAR__BOOL : Fn { ) override fun invoke(args: Array): PartiQLValue { + if (args[0].type == MISSING || args[1].type == MISSING) { + return boolValue(args[0] == args[1]) + } val lhs = args[0].check() val rhs = args[1].check() return boolValue(lhs == rhs) @@ -364,6 +397,9 @@ internal object Fn_EQ__STRING_STRING__BOOL : Fn { ) override fun invoke(args: Array): PartiQLValue { + if (args[0].type == MISSING || args[1].type == MISSING) { + return boolValue(args[0] == args[1]) + } val lhs = args[0].check() val rhs = args[1].check() return boolValue(lhs == rhs) @@ -387,6 +423,9 @@ internal object Fn_EQ__SYMBOL_SYMBOL__BOOL : Fn { ) override fun invoke(args: Array): PartiQLValue { + if (args[0].type == MISSING || args[1].type == MISSING) { + return boolValue(args[0] == args[1]) + } val lhs = args[0].check() val rhs = args[1].check() return boolValue(lhs == rhs) @@ -410,6 +449,9 @@ internal object Fn_EQ__BINARY_BINARY__BOOL : Fn { ) override fun invoke(args: Array): PartiQLValue { + if (args[0].type == MISSING || args[1].type == MISSING) { + return boolValue(args[0] == args[1]) + } val lhs = args[0].check() val rhs = args[1].check() return boolValue(lhs == rhs) @@ -433,6 +475,9 @@ internal object Fn_EQ__BYTE_BYTE__BOOL : Fn { ) override fun invoke(args: Array): PartiQLValue { + if (args[0].type == MISSING || args[1].type == MISSING) { + return boolValue(args[0] == args[1]) + } val lhs = args[0].check() val rhs = args[1].check() return boolValue(lhs == rhs) @@ -456,6 +501,9 @@ internal object Fn_EQ__BLOB_BLOB__BOOL : Fn { ) override fun invoke(args: Array): PartiQLValue { + if (args[0].type == MISSING || args[1].type == MISSING) { + return boolValue(args[0] == args[1]) + } val lhs = args[0].check() val rhs = args[1].check() return boolValue(lhs == rhs) @@ -479,6 +527,9 @@ internal object Fn_EQ__CLOB_CLOB__BOOL : Fn { ) override fun invoke(args: Array): PartiQLValue { + if (args[0].type == MISSING || args[1].type == MISSING) { + return boolValue(args[0] == args[1]) + } val lhs = args[0].check() val rhs = args[1].check() return boolValue(lhs == rhs) @@ -502,6 +553,9 @@ internal object Fn_EQ__DATE_DATE__BOOL : Fn { ) override fun invoke(args: Array): PartiQLValue { + if (args[0].type == MISSING || args[1].type == MISSING) { + return boolValue(args[0] == args[1]) + } val lhs = args[0].check() val rhs = args[1].check() return boolValue(lhs == rhs) @@ -525,6 +579,9 @@ internal object Fn_EQ__TIME_TIME__BOOL : Fn { ) override fun invoke(args: Array): PartiQLValue { + if (args[0].type == MISSING || args[1].type == MISSING) { + return boolValue(args[0] == args[1]) + } val lhs = args[0].check() val rhs = args[1].check() return boolValue(lhs == rhs) @@ -548,6 +605,9 @@ internal object Fn_EQ__TIMESTAMP_TIMESTAMP__BOOL : Fn { ) override fun invoke(args: Array): PartiQLValue { + if (args[0].type == MISSING || args[1].type == MISSING) { + return boolValue(args[0] == args[1]) + } val lhs = args[0].check() val rhs = args[1].check() return boolValue(lhs == rhs) @@ -571,6 +631,9 @@ internal object Fn_EQ__INTERVAL_INTERVAL__BOOL : Fn { ) override fun invoke(args: Array): PartiQLValue { + if (args[0].type == MISSING || args[1].type == MISSING) { + return boolValue(args[0] == args[1]) + } val lhs = args[0].check() val rhs = args[1].check() return boolValue(lhs == rhs) @@ -594,6 +657,9 @@ internal object Fn_EQ__BAG_BAG__BOOL : Fn { ) override fun invoke(args: Array): PartiQLValue { + if (args[0].type == MISSING || args[1].type == MISSING) { + return boolValue(args[0] == args[1]) + } val lhs = args[0].check>() val rhs = args[1].check>() return boolValue(lhs == rhs) @@ -617,6 +683,9 @@ internal object Fn_EQ__LIST_LIST__BOOL : Fn { ) override fun invoke(args: Array): PartiQLValue { + if (args[0].type == MISSING || args[1].type == MISSING) { + return boolValue(args[0] == args[1]) + } val lhs = args[0].check>() val rhs = args[1].check>() return boolValue(lhs == rhs) @@ -640,6 +709,9 @@ internal object Fn_EQ__SEXP_SEXP__BOOL : Fn { ) override fun invoke(args: Array): PartiQLValue { + if (args[0].type == MISSING || args[1].type == MISSING) { + return boolValue(args[0] == args[1]) + } val lhs = args[0].check>() val rhs = args[1].check>() return boolValue(lhs == rhs) @@ -663,6 +735,9 @@ internal object Fn_EQ__STRUCT_STRUCT__BOOL : Fn { ) override fun invoke(args: Array): PartiQLValue { + if (args[0].type == MISSING || args[1].type == MISSING) { + return boolValue(args[0] == args[1]) + } val lhs = args[0].check>() val rhs = args[1].check>() return boolValue(lhs == rhs) @@ -687,6 +762,9 @@ internal object Fn_EQ__NULL_NULL__BOOL : Fn { // TODO how does null comparison work? ie null.null == null.null or int8.null == null.null ?? override fun invoke(args: Array): PartiQLValue { + if (args[0].type == MISSING || args[1].type == MISSING) { + return boolValue(args[0] == args[1]) + } // According to the conformance tests, NULL = NULL -> NULL return boolValue(null) } diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnIsNull.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnIsNull.kt index 6a040b6fa4..6c1a8ff7cf 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnIsNull.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnIsNull.kt @@ -23,7 +23,7 @@ internal object Fn_IS_NULL__ANY__BOOL : Fn { isNullable = false, isNullCall = false, isMissable = false, - isMissingCall = false, + isMissingCall = true, ) override fun invoke(args: Array): PartiQLValue { diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnOr.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnOr.kt index 53f50b1854..4444d08bcf 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnOr.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/sql/builtins/FnOr.kt @@ -11,7 +11,6 @@ import org.partiql.value.BoolValue import org.partiql.value.PartiQLValue import org.partiql.value.PartiQLValueExperimental import org.partiql.value.PartiQLValueType.BOOL -import org.partiql.value.PartiQLValueType.MISSING import org.partiql.value.boolValue import org.partiql.value.check @@ -42,74 +41,3 @@ internal object Fn_OR__BOOL_BOOL__BOOL : Fn { return boolValue(toReturn) } } - -@OptIn(PartiQLValueExperimental::class, FnExperimental::class) -internal object Fn_OR__MISSING_BOOL__BOOL : Fn { - - override val signature = FnSignature( - name = "or", - returns = BOOL, - parameters = listOf( - FnParameter("lhs", MISSING), - FnParameter("rhs", BOOL), - ), - isNullable = true, - isNullCall = false, - isMissable = false, - isMissingCall = false, - ) - - override fun invoke(args: Array): PartiQLValue { - val rhs = args[1].check().value - return when (rhs) { - true -> boolValue(true) - else -> boolValue(null) - } - } -} - -@OptIn(PartiQLValueExperimental::class, FnExperimental::class) -internal object Fn_OR__BOOL_MISSING__BOOL : Fn { - - override val signature = FnSignature( - name = "or", - returns = BOOL, - parameters = listOf( - FnParameter("lhs", BOOL), - FnParameter("rhs", MISSING), - ), - isNullable = true, - isNullCall = false, - isMissable = false, - isMissingCall = false, - ) - - override fun invoke(args: Array): PartiQLValue { - val lhs = args[0].check().value - return when (lhs) { - true -> boolValue(true) - else -> boolValue(null) - } - } -} - -@OptIn(PartiQLValueExperimental::class, FnExperimental::class) -internal object Fn_OR__MISSING_MISSING__BOOL : Fn { - - override val signature = FnSignature( - name = "or", - returns = BOOL, - parameters = listOf( - FnParameter("lhs", MISSING), - FnParameter("rhs", MISSING), - ), - isNullable = true, - isNullCall = false, - isMissable = false, - isMissingCall = false, - ) - - override fun invoke(args: Array): PartiQLValue { - return boolValue(null) - } -} diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/fn/AggSignature.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/fn/AggSignature.kt index 26a9357396..3e797a60da 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/fn/AggSignature.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/fn/AggSignature.kt @@ -1,5 +1,6 @@ package org.partiql.spi.fn +import org.partiql.types.PType import org.partiql.value.PartiQLValueExperimental import org.partiql.value.PartiQLValueType @@ -13,22 +14,33 @@ import org.partiql.value.PartiQLValueType @OptIn(PartiQLValueExperimental::class) public class AggSignature( @JvmField public val name: String, - @JvmField public val returns: PartiQLValueType, + @JvmField public val returns: PType, @JvmField public val parameters: List, @JvmField public val description: String? = null, @JvmField public val isNullable: Boolean = true, @JvmField public val isDecomposable: Boolean = true, ) { + public constructor( + name: String, + returns: PartiQLValueType, + parameters: List, + description: String? = null, + isNullable: Boolean = true, + isDecomposable: Boolean = true, + ) : this( + name, PType.fromPartiQLValueType(returns), parameters, description, isNullable, isDecomposable + ) + /** * Symbolic name of this operator of the form NAME__INPUTS__RETURNS */ public val specific: String = buildString { append(name.uppercase()) append("__") - append(parameters.joinToString("_") { it.type.name }) + append(parameters.joinToString("_") { it.type.toString() }) append("__") - append(returns.name) + append(returns) } /** diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/fn/FnParameter.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/fn/FnParameter.kt index 479dd9a98f..9905cf1f00 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/fn/FnParameter.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/fn/FnParameter.kt @@ -1,5 +1,6 @@ package org.partiql.spi.fn +import org.partiql.types.PType import org.partiql.value.PartiQLValueExperimental import org.partiql.value.PartiQLValueType @@ -13,5 +14,10 @@ import org.partiql.value.PartiQLValueType @OptIn(PartiQLValueExperimental::class) public data class FnParameter( public val name: String, - public val type: PartiQLValueType, -) + public val type: PType, +) { + public constructor( + name: String, + type: PartiQLValueType, + ) : this(name, PType.fromPartiQLValueType(type)) +} diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/fn/FnSignature.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/fn/FnSignature.kt index c31339604e..d7e4a6ef30 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/fn/FnSignature.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/fn/FnSignature.kt @@ -1,5 +1,6 @@ package org.partiql.spi.fn +import org.partiql.types.PType import org.partiql.value.PartiQLValueExperimental import org.partiql.value.PartiQLValueType @@ -23,7 +24,7 @@ import org.partiql.value.PartiQLValueType @OptIn(PartiQLValueExperimental::class) public data class FnSignature( @JvmField public val name: String, - @JvmField public val returns: PartiQLValueType, + @JvmField public val returns: PType, @JvmField public val parameters: List, @JvmField public val description: String? = null, @JvmField public val isDeterministic: Boolean = true, @@ -33,15 +34,27 @@ public data class FnSignature( @JvmField public val isMissingCall: Boolean = true, ) { + public constructor( + name: String, + returns: PartiQLValueType, + parameters: List, + description: String? = null, + isDeterministic: Boolean = true, + isNullable: Boolean = true, + isNullCall: Boolean = false, + isMissable: Boolean = true, + isMissingCall: Boolean = true, + ) : this(name, PType.fromPartiQLValueType(returns), parameters, description, isDeterministic, isNullable, isNullCall, isMissable, isMissingCall) + /** * Symbolic name of this operator of the form NAME__INPUTS__RETURNS */ public val specific: String = buildString { append(name.uppercase()) append("__") - append(parameters.joinToString("_") { it.type.name }) + append(parameters.joinToString("_") { it.type.toString() }) append("__") - append(returns.name) + append(returns) } /** @@ -79,7 +92,7 @@ public data class FnSignature( val p = parameters[i] val ws = (extent - p.name.length) + 1 appendLine() - append(indent).append(p.name.uppercase()).append(" ".repeat(ws)).append(p.type.name) + append(indent).append(p.name.uppercase()).append(" ".repeat(ws)).append(p.type.toString()) if (i != parameters.size - 1) append(",") } } diff --git a/partiql-spi/src/test/kotlin/org/partiql/spi/connector/sql/HeaderCodeGen.kt b/partiql-spi/src/test/kotlin/org/partiql/spi/connector/sql/HeaderCodeGen.kt index 6807bf8854..8326b19289 100644 --- a/partiql-spi/src/test/kotlin/org/partiql/spi/connector/sql/HeaderCodeGen.kt +++ b/partiql-spi/src/test/kotlin/org/partiql/spi/connector/sql/HeaderCodeGen.kt @@ -168,9 +168,9 @@ class HeaderCodeGen { @OptIn(FnExperimental::class) private fun toParams(clazz: String, fn: FnSignature): Array { val snake = fn.name - val returns = fn.returns.name + val returns = fn.returns.toString() val parameters = fn.parameters.mapIndexed { i, p -> - "FnParameter(\"${p.name}\", ${p.type.name})" + "FnParameter(\"${p.name}\", ${p.type})" }.joinToString(",\n", postfix = ",") return arrayOf(clazz, snake, returns, parameters, fn.isNullCall, fn.isNullable, snake) } @@ -178,10 +178,10 @@ class HeaderCodeGen { @OptIn(FnExperimental::class) private fun toParams(clazz: String, agg: AggSignature): Array { val snake = agg.name - val returns = agg.returns.name + val returns = agg.returns.toString() var parameters = "" for (p in agg.parameters) { - parameters += "FnParameter(\"${p.name}\", ${p.type.name}),\n" + parameters += "FnParameter(\"${p.name}\", ${p.type}),\n" } return arrayOf(clazz, snake, returns, parameters, agg.isNullable, agg.isDecomposable, snake) } diff --git a/partiql-types/api/partiql-types.api b/partiql-types/api/partiql-types.api index 69e04aa709..5b4dd34315 100644 --- a/partiql-types/api/partiql-types.api +++ b/partiql-types/api/partiql-types.api @@ -448,6 +448,12 @@ public final class org/partiql/types/DecimalType$PrecisionScaleConstraint$Uncons public fun matches (Ljava/math/BigDecimal;)Z } +public abstract interface class org/partiql/types/Field { + public abstract fun getName ()Ljava/lang/String; + public abstract fun getType ()Lorg/partiql/types/PType; + public static fun of (Ljava/lang/String;Lorg/partiql/types/PType;)Lorg/partiql/types/Field; +} + public final class org/partiql/types/FloatType : org/partiql/types/SingleType { public fun ()V public fun (Ljava/util/Map;)V @@ -572,6 +578,81 @@ public final class org/partiql/types/NumberConstraint$UpTo : org/partiql/types/N public fun toString ()Ljava/lang/String; } +public abstract interface class org/partiql/types/PType { + public static fun fromPartiQLValueType (Lorg/partiql/value/PartiQLValueType;)Lorg/partiql/types/PType; + public static fun fromStaticType (Lorg/partiql/types/StaticType;)Lorg/partiql/types/PType; + public fun getFields ()Ljava/util/Collection; + public abstract fun getKind ()Lorg/partiql/types/PType$Kind; + public fun getLength ()I + public fun getPrecision ()I + public fun getScale ()I + public fun getTypeParameter ()Lorg/partiql/types/PType; + public static fun typeBag ()Lorg/partiql/types/PType; + public static fun typeBag (Lorg/partiql/types/PType;)Lorg/partiql/types/PType; + public static fun typeBigInt ()Lorg/partiql/types/PType; + public static fun typeBlob (I)Lorg/partiql/types/PType; + public static fun typeBool ()Lorg/partiql/types/PType; + public static fun typeChar (I)Lorg/partiql/types/PType; + public static fun typeClob (I)Lorg/partiql/types/PType; + public static fun typeDate ()Lorg/partiql/types/PType; + public static fun typeDecimal (II)Lorg/partiql/types/PType; + public static fun typeDecimalArbitrary ()Lorg/partiql/types/PType; + public static fun typeDoublePrecision ()Lorg/partiql/types/PType; + public static fun typeDynamic ()Lorg/partiql/types/PType; + public static fun typeInt ()Lorg/partiql/types/PType; + public static fun typeIntArbitrary ()Lorg/partiql/types/PType; + public static fun typeList ()Lorg/partiql/types/PType; + public static fun typeList (Lorg/partiql/types/PType;)Lorg/partiql/types/PType; + public static fun typeReal ()Lorg/partiql/types/PType; + public static fun typeRow (Ljava/util/Collection;)Lorg/partiql/types/PType; + public static fun typeSexp ()Lorg/partiql/types/PType; + public static fun typeSexp (Lorg/partiql/types/PType;)Lorg/partiql/types/PType; + public static fun typeSmallInt ()Lorg/partiql/types/PType; + public static fun typeString ()Lorg/partiql/types/PType; + public static fun typeStruct ()Lorg/partiql/types/PType; + public static fun typeSymbol ()Lorg/partiql/types/PType; + public static fun typeTimeWithTZ (I)Lorg/partiql/types/PType; + public static fun typeTimeWithoutTZ (I)Lorg/partiql/types/PType; + public static fun typeTimestampWithTZ (I)Lorg/partiql/types/PType; + public static fun typeTimestampWithoutTZ (I)Lorg/partiql/types/PType; + public static fun typeTinyInt ()Lorg/partiql/types/PType; + public static fun typeUnknown ()Lorg/partiql/types/PType; + public static fun typeVarChar (I)Lorg/partiql/types/PType; +} + +public final class org/partiql/types/PType$Kind : java/lang/Enum { + public static final field BAG Lorg/partiql/types/PType$Kind; + public static final field BIGINT Lorg/partiql/types/PType$Kind; + public static final field BLOB Lorg/partiql/types/PType$Kind; + public static final field BOOL Lorg/partiql/types/PType$Kind; + public static final field CHAR Lorg/partiql/types/PType$Kind; + public static final field CLOB Lorg/partiql/types/PType$Kind; + public static final field DATE Lorg/partiql/types/PType$Kind; + public static final field DECIMAL Lorg/partiql/types/PType$Kind; + public static final field DECIMAL_ARBITRARY Lorg/partiql/types/PType$Kind; + public static final field DOUBLE_PRECISION Lorg/partiql/types/PType$Kind; + public static final field DYNAMIC Lorg/partiql/types/PType$Kind; + public static final field INT Lorg/partiql/types/PType$Kind; + public static final field INT_ARBITRARY Lorg/partiql/types/PType$Kind; + public static final field LIST Lorg/partiql/types/PType$Kind; + public static final field REAL Lorg/partiql/types/PType$Kind; + public static final field ROW Lorg/partiql/types/PType$Kind; + public static final field SEXP Lorg/partiql/types/PType$Kind; + public static final field SMALLINT Lorg/partiql/types/PType$Kind; + public static final field STRING Lorg/partiql/types/PType$Kind; + public static final field STRUCT Lorg/partiql/types/PType$Kind; + public static final field SYMBOL Lorg/partiql/types/PType$Kind; + public static final field TIMESTAMP_WITHOUT_TZ Lorg/partiql/types/PType$Kind; + public static final field TIMESTAMP_WITH_TZ Lorg/partiql/types/PType$Kind; + public static final field TIME_WITHOUT_TZ Lorg/partiql/types/PType$Kind; + public static final field TIME_WITH_TZ Lorg/partiql/types/PType$Kind; + public static final field TINYINT Lorg/partiql/types/PType$Kind; + public static final field UNKNOWN Lorg/partiql/types/PType$Kind; + public static final field VARCHAR Lorg/partiql/types/PType$Kind; + public static fun valueOf (Ljava/lang/String;)Lorg/partiql/types/PType$Kind; + public static fun values ()[Lorg/partiql/types/PType$Kind; +} + public final class org/partiql/types/SexpType : org/partiql/types/CollectionType { public fun ()V public fun (Lorg/partiql/types/StaticType;Ljava/util/Map;Ljava/util/Set;)V diff --git a/partiql-types/src/main/java/org/partiql/types/Field.java b/partiql-types/src/main/java/org/partiql/types/Field.java new file mode 100644 index 0000000000..7ad6ae0b2a --- /dev/null +++ b/partiql-types/src/main/java/org/partiql/types/Field.java @@ -0,0 +1,44 @@ +package org.partiql.types; + +import org.jetbrains.annotations.NotNull; + +/** + * This represents a field of a structured type. + */ +public interface Field { + @NotNull + public String getName(); + + @NotNull + public PType getType(); + + + /** + * Returns a simple implementation of {@link Field}. + * @param name the key of the struct field + * @param type the type of the struct field + * @return a field containing the name and type + */ + static Field of(@NotNull String name, @NotNull PType type) { + return new Field() { + @NotNull + @Override + public String getName() { + return name; + } + + @NotNull + @Override + public PType getType() { + return type; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (!(o instanceof Field)) return false; + return name.equals(((Field) o).getName()) && type.equals(((Field) o).getType()); + } + }; + } +} diff --git a/partiql-types/src/main/java/org/partiql/types/PType.java b/partiql-types/src/main/java/org/partiql/types/PType.java new file mode 100644 index 0000000000..70c15cf616 --- /dev/null +++ b/partiql-types/src/main/java/org/partiql/types/PType.java @@ -0,0 +1,849 @@ +package org.partiql.types; + +import org.jetbrains.annotations.NotNull; +import org.partiql.value.PartiQLValueType; + +import java.util.Collection; +import java.util.HashSet; +import java.util.List; +import java.util.stream.Collectors; + +/** + * This represents a PartiQL type, whether it be a PartiQL primitive or user-defined. + *

+ * This implementation allows for parameterization of the core type ({@link Kind}) while allowing for APIs + * to access their parameters ({@link PType#getPrecision()}, {@link PType#getTypeParameter()}, etc.) + *

+ * Before using these methods, please be careful to read each method's documentation to ensure that it applies to the current + * {@link PType#getKind()}. If one carelessly invokes the wrong method, an {@link UnsupportedOperationException} will be + * thrown. + *

+ * This representation of a PartiQL type is intentionally modeled as a "fat" interface -- holding all methods relevant + * to any of the types. The maintainers of PartiQL have seen an unintentional reliance on Java's type semantics that + * make it cumbersome (with explicit Java casts) to gain access to methods. This modeling makes it simpler for the + * PartiQL planner to have immediate access to the available type's parameters. + *

+ * Users should NOT author their own implementation. The current recommendation is to use the static methods + * (exposed by this interface) to instantiate a type. + */ +public interface PType { + + /** + * Dictates the associates {@link Kind} of this instance. This method should be called and its return should be + * analyzed before calling any other method. For example: + *

+ * {@code + * public int getPrecisionOrNull(PType type) { + * if (type.base == {@link Kind#DECIMAL}) { + * return type.getPrecision(); + * } + * return null; + * } + * } + * @return the corresponding PartiQL {@link Kind}. + */ + @NotNull + Kind getKind(); + + /** + * The fields of the type + * @throws UnsupportedOperationException if this is called on a type whose {@link Kind} is not: + * {@link Kind#ROW} + */ + @NotNull + default Collection getFields() throws UnsupportedOperationException { + throw new UnsupportedOperationException(); + } + + /** + * The decimal precision of the type + * @return decimal precision + * @throws UnsupportedOperationException if this is called on a type whose {@link Kind} is not: + * {@link Kind#DECIMAL}, {@link Kind#TIMESTAMP_WITH_TZ}, {@link Kind#TIMESTAMP_WITHOUT_TZ}, {@link Kind#TIME_WITH_TZ}, + * {@link Kind#TIME_WITHOUT_TZ}, {@link Kind#REAL}, {@link Kind#DOUBLE_PRECISION} + */ + default int getPrecision() throws UnsupportedOperationException { + throw new UnsupportedOperationException(); + } + + /** + * The max length of the type + * @return max length of a type + * @throws UnsupportedOperationException if this is called on a type whose {@link Kind} is not: + * {@link Kind#CHAR}, {@link Kind#CLOB}, {@link Kind#BLOB} + */ + default int getLength() throws UnsupportedOperationException { + throw new UnsupportedOperationException(); + } + + /** + * The scale of the type. Example: DECIMAL(<param>, <scale>) + * @return the scale of the type + * @throws UnsupportedOperationException if this is called on a type whose {@link Kind} is not: + * {@link Kind#DECIMAL} + */ + default int getScale() throws UnsupportedOperationException { + throw new UnsupportedOperationException(); + } + + /** + * The type parameter of the type. Example: BAG(<param>) + * @return type parameter of the type + * @throws UnsupportedOperationException if this is called on a type whose {@link Kind} is not: + * {@link Kind#LIST}, {@link Kind#BAG}, {@link Kind#SEXP} + */ + @NotNull + default PType getTypeParameter() throws UnsupportedOperationException { + throw new UnsupportedOperationException(); + } + + /** + * PartiQL Core Type Kinds + *

+ * Each of these types correspond with a subset of APIs established in {@link PType}. Each of these can be seen as + * a category of types, distinguished only by the APIs available to them. For instance, all instances of {@link Kind#DECIMAL} + * may utilize {@link PType#getPrecision()} (and may return different results), however, they may never return a + * valid value for {@link PType#getFields()}. Consumers of this API should be careful to read the documentation + * for each API exposed in {@link PType} before using them. + *

+ * Future additions may add enums such as INTERVAL_YEAR_MONTH, INTERVAL_DAY_TIME, and more. + * @see PType + */ + enum Kind { + + /** + * PartiQL's dynamic type. + *
+ *
+ * Type Syntax: DYNAMIC + *
+ * Applicable methods: NONE + */ + DYNAMIC, + + /** + * SQL's boolean type. + *
+ *
+ * Type Syntax: BOOL, BOOLEAN + *
+ * Applicable methods: NONE + */ + BOOL, + + /** + * PartiQL's tiny integer type. + *
+ *
+ * Type Syntax: TINYINT + *
+ * Applicable methods: {@link PType#getPrecision()}, {@link PType#getScale()} + */ + TINYINT, + + /** + * SQL's small integer type. + *
+ *
+ * Type Syntax: SMALLINT + *
+ * Applicable methods: {@link PType#getPrecision()}, {@link PType#getScale()} + */ + SMALLINT, + + /** + * SQL's integer type. + *
+ *
+ * Type Syntax: INT, INTEGER + *
+ * Applicable methods: {@link PType#getPrecision()}, {@link PType#getScale()} + */ + INT, + + /** + * PartiQL's big integer type. + *
+ *
+ * Type Syntax: BIGINT + *
+ * Applicable methods: {@link PType#getPrecision()}, {@link PType#getScale()} + */ + BIGINT, + + /** + * PartiQL's big integer type. + *
+ *
+ * Type Syntax: TO_BE_DETERMINED + *
+ * Applicable methods: NONE + * @deprecated this is an experimental API and is subject to modification/deletion without prior notice. + */ + @Deprecated + INT_ARBITRARY, + + /** + * SQL's decimal type. + *
+ *
+ * Type Syntax: DECIMAL(<precision>, <scale>), DECIMAL(<precision>) + *
+ * Applicable methods: {@link PType#getPrecision()}, {@link PType#getScale()} + * @deprecated this is an experimental API and is subject to modification/deletion without prior notice. + */ + DECIMAL, + + /** + * Ion's arbitrary precision and scale decimal type. + *
+ *
+ * Type Syntax: TO_BE_DETERMINED + *
+ * Applicable methods: NONE + * @deprecated this is an experimental API and is subject to modification/deletion without prior notice. + */ + @Deprecated + DECIMAL_ARBITRARY, + + /** + * SQL's real type. + *
+ *
+ * Type Syntax: REAL + *
+ * Applicable methods: {@link PType#getPrecision()} + */ + REAL, + + /** + * SQL's double precision type. + *
+ *
+ * Type Syntax: DOUBLE PRECISION + *
+ * Applicable methods: {@link PType#getPrecision()} + */ + DOUBLE_PRECISION, + + /** + * SQL's character type. + *
+ *
+ * Type Syntax: CHAR(<length>), CHARACTER(<length>), CHAR, CHARACTER + *
+ * Applicable methods: {@link PType#getLength()} + */ + CHAR, + + /** + * SQL's character varying type. + *
+ *
+ * Type Syntax: VARCHAR(<length>), CHAR VARYING(<length>), + * CHARACTER VARYING(<length>), + * VARCHAR, CHAR VARYING, CHARACTER VARYING + *
+ * Applicable methods: {@link PType#getLength()} + */ + VARCHAR, + + /** + * PartiQL's string type. + *
+ *
+ * Type Syntax: TO_BE_DETERMINED + *
+ * Applicable methods: NONE + */ + STRING, + + /** + * Ion's symbol type. + *
+ *
+ * Type Syntax: TO_BE_DETERMINED + *
+ * Applicable methods: NONE + * @deprecated this is an experimental API and is subject to modification/deletion without prior notice. + */ + @Deprecated + SYMBOL, + + /** + * SQL's blob type. + *
+ *
+ * Type Syntax: BLOB, BLOB(<large object length>), + * BINARY LARGE OBJECT, BINARY LARGE OBJECT(<large object length>) + *
+ * Applicable methods: {@link PType#getLength()} + * @deprecated this is an experimental API and is subject to modification/deletion without prior notice. + */ + @Deprecated + BLOB, + + /** + * SQL's clob type. + *
+ *
+ * Type Syntax: CLOB, CLOB(<large object length>), + * CHAR LARGE OBJECT, CHAR LARGE OBJECT(<large object length>) + * CHARACTER LARGE OBJECT, CHARACTER LARGE OBJECT(<large object length>) + *
+ * Applicable methods: {@link PType#getLength()} + * @deprecated this is an experimental API and is subject to modification/deletion without prior notice. + */ + @Deprecated + CLOB, + + /** + * SQL's date type. + *
+ *
+ * Type Syntax: DATE + *
+ * Applicable methods: NONE + */ + DATE, + + /** + * SQL's time with timezone type. + *
+ *
+ * Type Syntax: TIME WITH TIME ZONE, TIME(<precision>) WITH TIME ZONE + *
+ * Applicable methods: NONE + */ + TIME_WITH_TZ, + + /** + * SQL's time without timezone type. + *
+ *
+ * Type Syntax: TIME, TIME WITHOUT TIME ZONE, + * TIME(<precision>), TIME(<precision>) WITHOUT TIME ZONE + *
+ * Applicable methods: NONE + */ + TIME_WITHOUT_TZ, + + /** + * SQL's timestamp with timezone type. + *
+ *
+ * Type Syntax: TIMESTAMP WITH TIME ZONE, TIMESTAMP(<precision>) WITH TIME ZONE + *
+ * Applicable methods: NONE + */ + TIMESTAMP_WITH_TZ, + + /** + * SQL's timestamp without timezone type. + *
+ *
+ * Type Syntax: TIMESTAMP, TIMESTAMP WITHOUT TIME ZONE, + * TIMESTAMP(<precision>), TIMESTAMP(<precision>) WITHOUT TIME ZONE + *
+ * Applicable methods: NONE + */ + TIMESTAMP_WITHOUT_TZ, + + /** + * PartiQL's bag type. There is no size limit. + *
+ *
+ * Type Syntax: BAG, BAG(<type>) + *
+ * Applicable methods: + * {@link PType#getTypeParameter()} + */ + BAG, + + /** + * Ion's list type. There is no size limit. + *
+ *
+ * Type Syntax: LIST, LIST(<type>) + *
+ * Applicable methods: + * {@link PType#getTypeParameter()} + */ + LIST, + + /** + * SQL's row type. Characterized as a closed, ordered collection of fields. + *
+ *
+ * Type Syntax: ROW(<str>: <type>, ...) + *
+ * Applicable methods: + * {@link PType#getFields()} + * + * @deprecated this is an experimental API and is subject to modification/deletion without prior notice. + */ + @Deprecated + ROW, + + /** + * Ion's s-expression type. There is no size limit. + *
+ *
+ * Type Syntax: SEXP, SEXP(<type>) + *
+ * Applicable methods: + * {@link PType#getTypeParameter()} + * + * @deprecated this is an experimental API and is subject to modification/deletion without prior notice. + */ + @Deprecated + SEXP, + + /** + * Ion's struct type. Characterized as an open, unordered collection of fields (duplicates allowed). + *
+ *
+ * Type Syntax: STRUCT + *
+ * Applicable methods: NONE + */ + STRUCT, + + /** + * PartiQL's unknown type. This temporarily represents literal null and missing values. + *
+ *
+ * Type Syntax: NONE + *
+ * Applicable methods: NONE + * @deprecated this is an experimental API and is subject to modification/deletion without prior notice. + */ + @Deprecated + UNKNOWN + } + + /** + * @return a PartiQL dynamic type + */ + @NotNull + static PType typeDynamic() { + return new PTypePrimitive(Kind.DYNAMIC); + } + + /** + * @return a PartiQL list type with a component type of dynamic + */ + @NotNull + static PType typeList() { + return new PTypeCollection(Kind.LIST, PType.typeDynamic()); + } + + /** + * @return a PartiQL list type with a component type of {@code typeParam} + */ + @NotNull + static PType typeList(@NotNull PType typeParam) { + return new PTypeCollection(Kind.LIST, typeParam); + } + + /** + * @return a PartiQL bag type with a component type of dynamic + */ + @NotNull + static PType typeBag() { + return new PTypeCollection(Kind.BAG, PType.typeDynamic()); + } + + /** + * @return a PartiQL bag type with a component type of {@code typeParam} + */ + @NotNull + static PType typeBag(@NotNull PType typeParam) { + return new PTypeCollection(Kind.BAG, typeParam); + } + + /** + * @return a PartiQL sexp type containing a component type of dynamic. + * @deprecated this is an experimental API and is subject to modification/deletion without prior notice. + */ + @Deprecated + @NotNull + static PType typeSexp() { + return new PTypeCollection(Kind.SEXP, PType.typeDynamic()); + } + + /** + * + * @param typeParam the component type to be used + * @return a PartiQL sexp type containing a component type of {@code typeParam}. + * @deprecated this is an experimental API and is subject to modification/deletion without prior notice. + */ + @NotNull + static PType typeSexp(@NotNull PType typeParam) { + return new PTypeCollection(Kind.SEXP, typeParam); + } + + /** + * @return a PartiQL boolean type + */ + @NotNull + static PType typeBool() { + return new PTypePrimitive(Kind.BOOL); + } + + /** + * @return a PartiQL real type. + */ + @NotNull + static PType typeReal() { + return new PTypePrimitive(Kind.REAL); + } + + /** + * @return a PartiQL double precision type + */ + @NotNull + static PType typeDoublePrecision() { + return new PTypePrimitive(Kind.DOUBLE_PRECISION); + } + + /** + * @return a PartiQL tiny integer type + */ + @NotNull + static PType typeTinyInt() { + return new PTypePrimitive(Kind.TINYINT); + } + + /** + * @return a PartiQL small integer type + */ + @NotNull + static PType typeSmallInt() { + return new PTypePrimitive(Kind.SMALLINT); + } + + /** + * @return a PartiQL integer type + */ + @NotNull + static PType typeInt() { + return new PTypePrimitive(Kind.INT); + } + + /** + * @return a PartiQL big integer type + */ + @NotNull + static PType typeBigInt() { + return new PTypePrimitive(Kind.BIGINT); + } + + /** + * @return a PartiQL int (arbitrary precision) type + * @deprecated this API is experimental and is subject to modification/deletion without prior notice. + */ + @NotNull + @Deprecated + static PType typeIntArbitrary() { + return new PTypePrimitive(Kind.INT_ARBITRARY); + } + + /** + * @return a PartiQL decimal (arbitrary precision/scale) type + * @deprecated this API is experimental and is subject to modification/deletion without prior notice. + */ + @NotNull + static PType typeDecimalArbitrary() { + return new PTypePrimitive(Kind.DECIMAL_ARBITRARY); + } + + /** + * @return a PartiQL decimal type + */ + @NotNull + static PType typeDecimal(int precision, int scale) { + return new PTypeDecimal(precision, scale); + } + + /** + * @return a PartiQL row type + * @deprecated this API is experimental and is subject to modification/deletion without prior notice. + */ + @NotNull + static PType typeRow(@NotNull Collection fields) { + return new PTypeRow(fields); + } + + /** + * @return a PartiQL struct type + */ + @NotNull + static PType typeStruct() { + return new PTypePrimitive(Kind.STRUCT); + } + + /** + * @return a PartiQL timestamp with timezone type + */ + @NotNull + static PType typeTimestampWithTZ(int precision) { + return new PTypeWithPrecisionOnly(Kind.TIMESTAMP_WITH_TZ, precision); + } + + /** + * @return a PartiQL timestamp without timezone type + */ + @NotNull + static PType typeTimestampWithoutTZ(int precision) { + return new PTypeWithPrecisionOnly(Kind.TIMESTAMP_WITHOUT_TZ, precision); + } + + /** + * @return a PartiQL time with timezone type + */ + @NotNull + static PType typeTimeWithTZ(int precision) { + return new PTypeWithPrecisionOnly(Kind.TIME_WITH_TZ, precision); + } + + /** + * @return a PartiQL time without timezone type + */ + @NotNull + static PType typeTimeWithoutTZ(int precision) { + return new PTypeWithPrecisionOnly(Kind.TIME_WITHOUT_TZ, precision); + } + + /** + * @return a PartiQL string type + */ + @NotNull + static PType typeString() { + return new PTypePrimitive(Kind.STRING); + } + + /** + * @return a PartiQL string type + * @deprecated this API is experimental and is subject to modification/deletion without prior notice. + */ + @NotNull + @Deprecated + static PType typeSymbol() { + return new PTypePrimitive(Kind.SYMBOL); + } + + /** + * @return a PartiQL blob type + * @deprecated this API is experimental and is subject to modification/deletion without prior notice. + */ + @NotNull + static PType typeBlob(int length) { + return new PTypeWithMaxLength(Kind.BLOB, length); + } + + /** + * @return a PartiQL clob type + * @deprecated this API is experimental and is subject to modification/deletion without prior notice. + */ + @NotNull + static PType typeClob(int length) { + return new PTypeWithMaxLength(Kind.CLOB, length); + } + + /** + * @return a PartiQL char type + */ + @NotNull + static PType typeChar(int length) { + return new PTypeWithMaxLength(Kind.CHAR, length); + } + + /** + * @return a PartiQL char type + */ + @NotNull + static PType typeVarChar(int length) { + return new PTypeWithMaxLength(Kind.CHAR, length); + } + + /** + * @return a PartiQL date type + */ + @NotNull + static PType typeDate() { + return new PTypePrimitive(Kind.DATE); + } + + /** + * @return a PartiQL unknown type + * @deprecated this API is experimental and is subject to modification/deletion without prior notice. + */ + @NotNull + static PType typeUnknown() { + return new PTypePrimitive(Kind.UNKNOWN); + } + + /** + * @return a corresponding PType from a {@link PartiQLValueType} + * @deprecated this API is experimental and is subject to modification/deletion without prior notice. This is + * meant for use internally by the PartiQL library. Public consumers should not use this API. + */ + @NotNull + static PType fromPartiQLValueType(@NotNull PartiQLValueType type) { + switch (type) { + case DECIMAL: + case DECIMAL_ARBITRARY: + return PType.typeDecimalArbitrary(); + case INT8: + return PType.typeTinyInt(); + case CHAR: + return PType.typeChar(255); + case TIMESTAMP: + return PType.typeTimestampWithTZ(6); + case DATE: + return PType.typeDate(); + case BOOL: + return PType.typeBool(); + case SYMBOL: + return PType.typeSymbol(); + case STRING: + return PType.typeString(); + case STRUCT: + return PType.typeStruct(); + case SEXP: + return PType.typeSexp(); + case LIST: + return PType.typeList(); + case BAG: + return PType.typeBag(); + case FLOAT32: + return PType.typeReal(); + case INT: + return PType.typeIntArbitrary(); + case INT64: + return PType.typeBigInt(); + case INT32: + return PType.typeInt(); + case INT16: + return PType.typeSmallInt(); + case TIME: + return PType.typeTimeWithoutTZ(6); + case ANY: + return PType.typeDynamic(); + case FLOAT64: + return PType.typeDoublePrecision(); + case CLOB: + return PType.typeClob(Integer.MAX_VALUE); + case BLOB: + return PType.typeBlob(Integer.MAX_VALUE); + + // TODO: Is this allowed? This is specifically for literals + case NULL: + case MISSING: + return PType.typeUnknown(); + + // Unsupported types + case INTERVAL: + case BYTE: + case BINARY: + return PType.typeDynamic(); // TODO: REMOVE THIS + default: + throw new IllegalStateException(); + } + } + + /** + * @return a corresponding PType from a {@link StaticType} + * @deprecated this API is experimental and is subject to modification/deletion without prior notice. This is + * meant for use internally by the PartiQL library. Public consumers should not use this API. + */ + @NotNull + @Deprecated + static PType fromStaticType(@NotNull StaticType type) { + if (type instanceof AnyType) { + return PType.typeDynamic(); + } else if (type instanceof AnyOfType) { + HashSet allTypes = new HashSet<>(type.flatten().getAllTypes()); + if (allTypes.isEmpty()) { + return PType.typeDynamic(); + } else if (allTypes.size() == 1) { + return fromStaticType(allTypes.stream().findFirst().get()); + } else { + return PType.typeDynamic(); + } +// if (allTypes.stream().allMatch((subType) -> subType instanceof CollectionType)) {} + } else if (type instanceof BagType) { + PType elementType = fromStaticType(((BagType) type).getElementType()); + return PType.typeBag(elementType); + } else if (type instanceof BlobType) { + return PType.typeBlob(Integer.MAX_VALUE); // TODO: Update this + } else if (type instanceof BoolType) { + return PType.typeBool(); + } else if (type instanceof ClobType) { + return PType.typeClob(Integer.MAX_VALUE); // TODO: Update this + } else if (type instanceof DateType) { + return PType.typeDate(); + } else if (type instanceof DecimalType) { + DecimalType.PrecisionScaleConstraint precScale = ((DecimalType) type).getPrecisionScaleConstraint(); + if (precScale instanceof DecimalType.PrecisionScaleConstraint.Unconstrained) { + return PType.typeDecimalArbitrary(); + } else if (precScale instanceof DecimalType.PrecisionScaleConstraint.Constrained) { + DecimalType.PrecisionScaleConstraint.Constrained precisionScaleConstraint = (DecimalType.PrecisionScaleConstraint.Constrained) precScale; + return PType.typeDecimal(precisionScaleConstraint.getPrecision(), precisionScaleConstraint.getScale()); + } else { + throw new IllegalStateException(); + } + } else if (type instanceof FloatType) { + return PType.typeDoublePrecision(); + } else if (type instanceof IntType) { + IntType.IntRangeConstraint cons = ((IntType) type).getRangeConstraint(); + if (cons == IntType.IntRangeConstraint.INT4) { + return PType.typeInt(); + } else if (cons == IntType.IntRangeConstraint.SHORT) { + return PType.typeSmallInt(); + } else if (cons == IntType.IntRangeConstraint.LONG) { + return PType.typeBigInt(); + } else if (cons == IntType.IntRangeConstraint.UNCONSTRAINED) { + return PType.typeIntArbitrary(); + } else { + throw new IllegalStateException(); + } + } else if (type instanceof ListType) { + PType elementType = fromStaticType(((ListType) type).getElementType()); + return PType.typeList(elementType); + } else if (type instanceof SexpType) { + PType elementType = fromStaticType(((SexpType) type).getElementType()); + return PType.typeSexp(elementType); + } else if (type instanceof StringType) { + return PType.typeString(); + } else if (type instanceof StructType) { + boolean isOrdered = ((StructType) type).getConstraints().contains(TupleConstraint.Ordered.INSTANCE); + boolean isClosed = ((StructType) type).getContentClosed(); + List fields = ((StructType) type).getFields().stream().map((field) -> Field.of(field.getKey(), PType.fromStaticType(field.getValue()))).collect(Collectors.toList()); + if (isClosed && isOrdered) { + return PType.typeRow(fields); + } else if (isClosed) { + return PType.typeRow(fields); // TODO: We currently use ROW when closed. + } else { + return PType.typeStruct(); + } + } else if (type instanceof SymbolType) { + return PType.typeSymbol(); + } else if (type instanceof TimeType) { + Integer precision = ((TimeType) type).getPrecision(); + if (precision == null) { + precision = 6; + } + return PType.typeTimeWithoutTZ(precision); + } else if (type instanceof TimestampType) { + Integer precision = ((TimestampType) type).getPrecision(); + if (precision == null) { + precision = 6; + } + return PType.typeTimestampWithoutTZ(precision); + } else { + throw new IllegalStateException("Unsupported type: " + type); + } + } +} diff --git a/partiql-types/src/main/java/org/partiql/types/PTypeCollection.java b/partiql-types/src/main/java/org/partiql/types/PTypeCollection.java new file mode 100644 index 0000000000..ce1df809d4 --- /dev/null +++ b/partiql-types/src/main/java/org/partiql/types/PTypeCollection.java @@ -0,0 +1,48 @@ +package org.partiql.types; + +import org.jetbrains.annotations.NotNull; + +import java.util.Objects; + +class PTypeCollection implements PType { + + @NotNull + final PType _typeParam; + + @NotNull + final Kind _kind; + + PTypeCollection(@NotNull Kind base, @NotNull PType typeParam) { + _kind = base; + _typeParam = typeParam; + } + + @NotNull + @Override + public PType getTypeParameter() { + return _typeParam; + } + + @NotNull + @Override + public Kind getKind() { + return _kind; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (!(o instanceof PType)) return false; + return ((PType) o).getKind() == this._kind && ((PType) o).getTypeParameter().equals(_typeParam); + } + + @Override + public String toString() { + return _kind.name() + "(" + _typeParam + ")"; + } + + @Override + public int hashCode() { + return Objects.hash(_kind, _typeParam); + } +} diff --git a/partiql-types/src/main/java/org/partiql/types/PTypeDecimal.java b/partiql-types/src/main/java/org/partiql/types/PTypeDecimal.java new file mode 100644 index 0000000000..6b81cd7a01 --- /dev/null +++ b/partiql-types/src/main/java/org/partiql/types/PTypeDecimal.java @@ -0,0 +1,48 @@ +package org.partiql.types; + +import org.jetbrains.annotations.NotNull; + +import java.util.Objects; + +class PTypeDecimal implements PType { + final int _precision; + final int _scale; + + PTypeDecimal(int precision, int scale) { + _precision = precision; + _scale = scale; + } + + @NotNull + @Override + public Kind getKind() { + return Kind.DECIMAL; + } + + @Override + public int getPrecision() { + return _precision; + } + + @Override + public int getScale() { + return _scale; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (!(o instanceof PType)) return false; + return ((PType) o).getKind() == Kind.DECIMAL && _precision == ((PType) o).getPrecision() && _scale == ((PType) o).getScale(); + } + + @Override + public String toString() { + return Kind.DECIMAL.name() + "(" + _precision + ", " + "_" + _scale + ")"; + } + + @Override + public int hashCode() { + return Objects.hash(Kind.DECIMAL, _precision, _scale); + } +} diff --git a/partiql-types/src/main/java/org/partiql/types/PTypePrimitive.java b/partiql-types/src/main/java/org/partiql/types/PTypePrimitive.java new file mode 100644 index 0000000000..ea92bb213b --- /dev/null +++ b/partiql-types/src/main/java/org/partiql/types/PTypePrimitive.java @@ -0,0 +1,38 @@ +package org.partiql.types; + +import org.jetbrains.annotations.NotNull; + +import java.util.Objects; + +class PTypePrimitive implements PType { + + @NotNull + final Kind _kind; + + PTypePrimitive(@NotNull Kind type) { + _kind = type; + } + + @NotNull + @Override + public Kind getKind() { + return _kind; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (!(o instanceof PType)) return false; + return _kind == ((PType) o).getKind(); + } + + @Override + public String toString() { + return _kind.name(); + } + + @Override + public int hashCode() { + return Objects.hashCode(_kind); + } +} diff --git a/partiql-types/src/main/java/org/partiql/types/PTypeRow.java b/partiql-types/src/main/java/org/partiql/types/PTypeRow.java new file mode 100644 index 0000000000..419b3b87c2 --- /dev/null +++ b/partiql-types/src/main/java/org/partiql/types/PTypeRow.java @@ -0,0 +1,68 @@ +package org.partiql.types; + +import org.jetbrains.annotations.NotNull; + +import java.util.Collection; +import java.util.Iterator; +import java.util.Objects; +import java.util.stream.Collectors; + +/** + * Applicable to {@link PType.Kind#ROW}. + */ +class PTypeRow implements PType { + + final Collection _fields; + + PTypeRow(@NotNull Collection fields) { + _fields = fields; + } + + @NotNull + @Override + public Kind getKind() { + return Kind.ROW; + } + + @NotNull + @Override + public Collection getFields() { + return _fields; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (!(o instanceof PType)) return false; + if (Kind.ROW != ((PType) o).getKind()) { + return false; + } + Collection otherFields = ((PType) o).getFields(); + int size = _fields.size(); + if (size != otherFields.size()) { + return false; + } + Iterator thisIter = _fields.iterator(); + Iterator otherIter = otherFields.iterator(); + for (int i = 0; i < size; i++) { + Field thisField = thisIter.next(); + Field otherField = otherIter.next(); + if (!thisField.equals(otherField)) { + return false; + } + } + return true; + } + + @Override + public String toString() { + Collection fieldStringList = _fields.stream().map((f) -> f.getName() + ": " + f.getType()).collect(Collectors.toList()); + String fieldStrings = String.join(", ", fieldStringList); + return Kind.ROW.name() + "(" + fieldStrings + ")"; + } + + @Override + public int hashCode() { + return Objects.hash(Kind.ROW, _fields); + } +} diff --git a/partiql-types/src/main/java/org/partiql/types/PTypeWithMaxLength.java b/partiql-types/src/main/java/org/partiql/types/PTypeWithMaxLength.java new file mode 100644 index 0000000000..7c67044227 --- /dev/null +++ b/partiql-types/src/main/java/org/partiql/types/PTypeWithMaxLength.java @@ -0,0 +1,45 @@ +package org.partiql.types; + +import org.jetbrains.annotations.NotNull; + +import java.util.Objects; + +class PTypeWithMaxLength implements PType { + + final int _maxLength; + + final Kind _kind; + + PTypeWithMaxLength(@NotNull Kind type, int maxLength) { + _kind = type; + _maxLength = maxLength; + } + + @NotNull + @Override + public Kind getKind() { + return _kind; + } + + @Override + public int getLength() { + return _maxLength; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (!(o instanceof PType)) return false; + return _kind == ((PType) o).getKind() && _maxLength == ((PType) o).getLength(); + } + + @Override + public String toString() { + return _kind.name() + "(" + _maxLength + ")"; + } + + @Override + public int hashCode() { + return Objects.hash(_kind, _maxLength); + } +} diff --git a/partiql-types/src/main/java/org/partiql/types/PTypeWithPrecisionOnly.java b/partiql-types/src/main/java/org/partiql/types/PTypeWithPrecisionOnly.java new file mode 100644 index 0000000000..05b94dc77a --- /dev/null +++ b/partiql-types/src/main/java/org/partiql/types/PTypeWithPrecisionOnly.java @@ -0,0 +1,46 @@ +package org.partiql.types; + +import org.jetbrains.annotations.NotNull; + +import java.util.Objects; + +class PTypeWithPrecisionOnly implements PType { + + final int _precision; + + @NotNull + final Kind _kind; + + PTypeWithPrecisionOnly(@NotNull Kind base, int precision) { + _precision = precision; + _kind = base; + } + + @NotNull + @Override + public Kind getKind() { + return _kind; + } + + @Override + public int getPrecision() { + return _precision; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (!(o instanceof PType)) return false; + return _kind == ((PType) o).getKind() && _precision == ((PType) o).getPrecision(); + } + + @Override + public String toString() { + return _kind.name() + "(" + _precision + ")"; + } + + @Override + public int hashCode() { + return Objects.hash(_kind, _precision); + } +} diff --git a/test/partiql-tests-runner/src/test/kotlin/org/partiql/runner/executor/EvalExecutor.kt b/test/partiql-tests-runner/src/test/kotlin/org/partiql/runner/executor/EvalExecutor.kt index 6f8d357d28..eab8096dab 100644 --- a/test/partiql-tests-runner/src/test/kotlin/org/partiql/runner/executor/EvalExecutor.kt +++ b/test/partiql-tests-runner/src/test/kotlin/org/partiql/runner/executor/EvalExecutor.kt @@ -25,6 +25,7 @@ import org.partiql.spi.BindingName import org.partiql.spi.BindingPath import org.partiql.spi.connector.Connector import org.partiql.spi.connector.ConnectorSession +import org.partiql.types.PType import org.partiql.types.StaticType import org.partiql.value.PartiQLValue import org.partiql.value.PartiQLValueExperimental @@ -156,7 +157,7 @@ class EvalExecutor( * @return */ private fun infer(env: StructElement): Connector { - val map = mutableMapOf() + val map = mutableMapOf() env.fields.forEach { map[it.name] = inferEnv(it.value) } @@ -165,7 +166,7 @@ class EvalExecutor( return MemoryConnector(catalog) } - private fun inferEnv(env: AnyElement): StaticType { + private fun inferEnv(env: AnyElement): PType { val catalog = MemoryCatalog.PartiQL().name("conformance_test").build() val connector = MemoryConnector(catalog) val session = PartiQLPlanner.Session(