diff --git a/core/src/main/java/org/opensearch/sql/calcite/utils/UserDefinedFunctionUtils.java b/core/src/main/java/org/opensearch/sql/calcite/utils/UserDefinedFunctionUtils.java index 5baf998345..f220c2265b 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/utils/UserDefinedFunctionUtils.java +++ b/core/src/main/java/org/opensearch/sql/calcite/utils/UserDefinedFunctionUtils.java @@ -115,7 +115,10 @@ public static SqlTypeName convertRelDataTypeToSqlTypeName(RelDataType type) { case EXPR_DATE -> SqlTypeName.DATE; case EXPR_TIME -> SqlTypeName.TIME; case EXPR_TIMESTAMP -> SqlTypeName.TIMESTAMP; - case EXPR_IP -> SqlTypeName.VARCHAR; + // EXPR_IP is mapped to SqlTypeName.NULL since there is no + // corresponding SqlTypeName in Calcite. This is a workaround to allow + // type checking for IP types in UDFs. + case EXPR_IP -> SqlTypeName.NULL; case EXPR_BINARY -> SqlTypeName.VARBINARY; default -> type.getSqlTypeName(); }; diff --git a/core/src/main/java/org/opensearch/sql/expression/function/PPLBuiltinOperators.java b/core/src/main/java/org/opensearch/sql/expression/function/PPLBuiltinOperators.java index 1b3a2b7dba..f295b07022 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/PPLBuiltinOperators.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/PPLBuiltinOperators.java @@ -71,6 +71,7 @@ import org.opensearch.sql.expression.function.udf.datetime.WeekdayFunction; import org.opensearch.sql.expression.function.udf.datetime.YearweekFunction; import org.opensearch.sql.expression.function.udf.ip.CidrMatchFunction; +import org.opensearch.sql.expression.function.udf.ip.CompareIpFunction; import org.opensearch.sql.expression.function.udf.math.CRC32Function; import org.opensearch.sql.expression.function.udf.math.ConvFunction; import org.opensearch.sql.expression.function.udf.math.DivideFunction; @@ -102,6 +103,15 @@ public class PPLBuiltinOperators extends ReflectiveSqlOperatorTable { public static final SqlOperator SHA2 = CryptographicFunction.sha2().toUDF("SHA2"); public static final SqlOperator CIDRMATCH = new CidrMatchFunction().toUDF("CIDRMATCH"); + // IP comparing functions + public static final SqlOperator NOT_EQUALS_IP = + CompareIpFunction.notEquals().toUDF("NOT_EQUALS_IP"); + public static final SqlOperator EQUALS_IP = CompareIpFunction.equals().toUDF("EQUALS_IP"); + public static final SqlOperator GREATER_IP = CompareIpFunction.greater().toUDF("GREATER_IP"); + public static final SqlOperator GTE_IP = CompareIpFunction.greaterOrEquals().toUDF("GTE_IP"); + public static final SqlOperator LESS_IP = CompareIpFunction.less().toUDF("LESS_IP"); + public static final SqlOperator LTE_IP = CompareIpFunction.lessOrEquals().toUDF("LTE_IP"); + // Condition function public static final SqlOperator EARLIEST = new EarliestFunction().toUDF("EARLIEST"); public static final SqlOperator LATEST = new LatestFunction().toUDF("LATEST"); diff --git a/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java b/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java index 0019e6819d..981f3be585 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java @@ -614,6 +614,14 @@ public PPLTypeChecker getTypeChecker() { } void populate() { + // register operators for IP comparing + registerOperator(NOTEQUAL, PPLBuiltinOperators.NOT_EQUALS_IP); + registerOperator(EQUAL, PPLBuiltinOperators.EQUALS_IP); + registerOperator(GREATER, PPLBuiltinOperators.GREATER_IP); + registerOperator(GTE, PPLBuiltinOperators.GTE_IP); + registerOperator(LESS, PPLBuiltinOperators.LESS_IP); + registerOperator(LTE, PPLBuiltinOperators.LTE_IP); + // Register std operator registerOperator(AND, SqlStdOperatorTable.AND); registerOperator(OR, SqlStdOperatorTable.OR); diff --git a/core/src/main/java/org/opensearch/sql/expression/function/PPLTypeChecker.java b/core/src/main/java/org/opensearch/sql/expression/function/PPLTypeChecker.java index d3443ab850..17e9e290ad 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/PPLTypeChecker.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/PPLTypeChecker.java @@ -400,6 +400,10 @@ private static List getExprTypes(SqlTypeFamily family) { OpenSearchTypeFactory.TYPE_FACTORY.createSqlType(SqlTypeName.INTEGER)); case ANY, IGNORE -> List.of( OpenSearchTypeFactory.TYPE_FACTORY.createSqlType(SqlTypeName.ANY)); + // We borrow SqlTypeFamily.NULL to represent EXPR_IP. This is a workaround + // since there is no corresponding IP type family in Calcite. + case NULL -> List.of( + OpenSearchTypeFactory.TYPE_FACTORY.createUDT(OpenSearchTypeFactory.ExprUDT.EXPR_IP)); default -> { RelDataType type = family.getDefaultConcreteType(OpenSearchTypeFactory.TYPE_FACTORY); if (type == null) { diff --git a/core/src/main/java/org/opensearch/sql/expression/function/udf/ip/CidrMatchFunction.java b/core/src/main/java/org/opensearch/sql/expression/function/udf/ip/CidrMatchFunction.java index d7879d881d..4c8e532cbb 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/udf/ip/CidrMatchFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/udf/ip/CidrMatchFunction.java @@ -12,9 +12,11 @@ import org.apache.calcite.linq4j.tree.Expression; import org.apache.calcite.linq4j.tree.Expressions; import org.apache.calcite.rex.RexCall; +import org.apache.calcite.sql.type.CompositeOperandTypeChecker; import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.ReturnTypes; import org.apache.calcite.sql.type.SqlReturnTypeInference; +import org.apache.calcite.sql.type.SqlTypeFamily; import org.opensearch.sql.data.model.ExprIpValue; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; @@ -44,9 +46,12 @@ public SqlReturnTypeInference getReturnTypeInference() { @Override public UDFOperandMetadata getOperandMetadata() { - // EXPR_IP is mapped to SqlTypeFamily.VARCHAR in + // EXPR_IP is mapped to SqlTypeFamily.NULL in // UserDefinedFunctionUtils.convertRelDataTypeToSqlTypeName - return UDFOperandMetadata.wrap(OperandTypes.STRING_STRING); + return UDFOperandMetadata.wrap( + (CompositeOperandTypeChecker) + OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.STRING) + .or(OperandTypes.family(SqlTypeFamily.NULL, SqlTypeFamily.STRING))); } public static class CidrMatchImplementor implements NotNullImplementor { diff --git a/core/src/main/java/org/opensearch/sql/expression/function/udf/ip/CompareIpFunction.java b/core/src/main/java/org/opensearch/sql/expression/function/udf/ip/CompareIpFunction.java new file mode 100644 index 0000000000..12a6a42516 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/expression/function/udf/ip/CompareIpFunction.java @@ -0,0 +1,134 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression.function.udf.ip; + +import inet.ipaddr.IPAddress; +import java.util.List; +import org.apache.calcite.adapter.enumerable.NotNullImplementor; +import org.apache.calcite.adapter.enumerable.NullPolicy; +import org.apache.calcite.adapter.enumerable.RexToLixTranslator; +import org.apache.calcite.linq4j.tree.Expression; +import org.apache.calcite.linq4j.tree.Expressions; +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.sql.type.CompositeOperandTypeChecker; +import org.apache.calcite.sql.type.OperandTypes; +import org.apache.calcite.sql.type.ReturnTypes; +import org.apache.calcite.sql.type.SqlReturnTypeInference; +import org.apache.calcite.sql.type.SqlTypeFamily; +import org.opensearch.sql.data.model.ExprIpValue; +import org.opensearch.sql.expression.function.ImplementorUDF; +import org.opensearch.sql.expression.function.UDFOperandMetadata; +import org.opensearch.sql.utils.IPUtils; + +/** + * {@code compare(ip1, ip2)} compares two IP addresses using a provided op. + * + *

Signature: + * + *

+ */ +public class CompareIpFunction extends ImplementorUDF { + + private CompareIpFunction(ComparisonType comparisonType) { + super(new CompareImplementor(comparisonType), NullPolicy.ANY); + } + + public static CompareIpFunction less() { + return new CompareIpFunction(ComparisonType.LESS); + } + + public static CompareIpFunction greater() { + return new CompareIpFunction(ComparisonType.GREATER); + } + + public static CompareIpFunction lessOrEquals() { + return new CompareIpFunction(ComparisonType.LESS_OR_EQUAL); + } + + public static CompareIpFunction greaterOrEquals() { + return new CompareIpFunction(ComparisonType.GREATER_OR_EQUAL); + } + + public static CompareIpFunction equals() { + return new CompareIpFunction(ComparisonType.EQUALS); + } + + public static CompareIpFunction notEquals() { + return new CompareIpFunction(ComparisonType.NOT_EQUALS); + } + + @Override + public SqlReturnTypeInference getReturnTypeInference() { + return ReturnTypes.BOOLEAN_FORCE_NULLABLE; + } + + @Override + public UDFOperandMetadata getOperandMetadata() { + return UDFOperandMetadata.wrap( + (CompositeOperandTypeChecker) + OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.NULL) + .or(OperandTypes.family(SqlTypeFamily.NULL, SqlTypeFamily.STRING)) + .or(OperandTypes.family(SqlTypeFamily.NULL, SqlTypeFamily.NULL))); + } + + public static class CompareImplementor implements NotNullImplementor { + private final ComparisonType comparisonType; + + public CompareImplementor(ComparisonType comparisonType) { + this.comparisonType = comparisonType; + } + + @Override + public Expression implement( + RexToLixTranslator translator, RexCall call, List translatedOperands) { + return Expressions.call( + CompareImplementor.class, + "compare", + translatedOperands.get(0), + translatedOperands.get(1), + Expressions.constant(comparisonType)); + } + + public static boolean compare(Object obj1, Object obj2, ComparisonType comparisonType) { + try { + String ip1 = extractIpString(obj1); + String ip2 = extractIpString(obj2); + IPAddress addr1 = IPUtils.toAddress(ip1); + IPAddress addr2 = IPUtils.toAddress(ip2); + int result = IPUtils.compare(addr1, addr2); + return switch (comparisonType) { + case EQUALS -> result == 0; + case NOT_EQUALS -> result != 0; + case LESS -> result < 0; + case LESS_OR_EQUAL -> result <= 0; + case GREATER -> result > 0; + case GREATER_OR_EQUAL -> result >= 0; + }; + } catch (Exception e) { + return false; + } + } + + private static String extractIpString(Object obj) { + if (obj instanceof String) return (String) obj; + if (obj instanceof ExprIpValue) return ((ExprIpValue) obj).value(); + throw new IllegalArgumentException("Invalid IP type: " + obj); + } + } + + public enum ComparisonType { + EQUALS, + NOT_EQUALS, + LESS, + LESS_OR_EQUAL, + GREATER, + GREATER_OR_EQUAL + } +} diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/pushdown/CalciteNoPushdownIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/pushdown/CalciteNoPushdownIT.java index 14ae021e66..6dbc078acc 100644 --- a/integ-test/src/test/java/org/opensearch/sql/calcite/pushdown/CalciteNoPushdownIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/calcite/pushdown/CalciteNoPushdownIT.java @@ -40,9 +40,7 @@ CalciteGeoPointFormatsIT.class, CalciteHeadCommandIT.class, CalciteInformationSchemaCommandIT.class, - // TODO: Enable after implementing comparison for IP addresses with Calcite - // https://github.com/opensearch-project/sql/issues/3776 - // CalciteIPComparisonIT.class, + CalciteIPComparisonIT.class, CalciteIPFunctionsIT.class, CalciteJsonFunctionsIT.class, CalciteLegacyAPICompatibilityIT.class, diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLFunctionTypeTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLFunctionTypeTest.java index 55f5635048..20447be761 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLFunctionTypeTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLFunctionTypeTest.java @@ -49,7 +49,9 @@ public void testComparisonWithDifferentType() { String ppl = "source=EMP | where ENAME < 6 | fields ENAME"; Throwable t = Assert.assertThrows(ExpressionEvaluationException.class, () -> getRelNode(ppl)); verifyErrorMessageContains( - t, "LESS function expects {[COMPARABLE_TYPE,COMPARABLE_TYPE]}, but got [STRING,INTEGER]"); + t, + "LESS function expects {[STRING,IP],[IP,STRING],[IP,IP],[COMPARABLE_TYPE,COMPARABLE_TYPE]}," + + " but got [STRING,INTEGER]"); } @Test