/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.sql.calcite;

import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import org.apache.calcite.avatica.util.TimeUnit;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlIntervalQualifier;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.type.SqlTypeUtil;
import org.opensearch.sql.ast.expression.SpanUnit;
import org.opensearch.sql.calcite.type.AbstractExprRelDataType;
import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory;
import org.opensearch.sql.data.type.ExprCoreType;
import org.opensearch.sql.data.type.ExprType;
import org.opensearch.sql.exception.ExpressionEvaluationException;
import org.opensearch.sql.exception.SemanticCheckException;
import org.opensearch.sql.expression.function.PPLBuiltinOperators;
import shaded.com.google.common.collect.ImmutableList;

/**
 * A {@link RexBuilder} that adds a few PPL convenience factories and overrides
 * {@link #makeCast} to special-case literal-to-BOOLEAN casts and casts to OpenSearch
 * user-defined types (DATE, TIME, TIMESTAMP, IP).
 */
public class ExtendedRexBuilder
extends RexBuilder {
    /**
     * Creates a builder backed by the same type factory as {@code rexBuilder}.
     *
     * @param rexBuilder builder whose type factory is reused (nothing else is read from it)
     */
    public ExtendedRexBuilder(RexBuilder rexBuilder) {
        super(rexBuilder.getTypeFactory());
    }

    /**
     * Builds {@code COALESCE(nodes...)}, letting Calcite infer the result type.
     *
     * @param nodes operands in evaluation order
     * @return the COALESCE call
     */
    public RexNode coalesce(RexNode ... nodes) {
        return this.makeCall(SqlStdOperatorTable.COALESCE, nodes);
    }

    /**
     * Builds {@code n1 = n2}, letting Calcite infer the result type.
     *
     * @param n1 left operand
     * @param n2 right operand
     * @return the EQUALS call
     */
    public RexNode equals(RexNode n1, RexNode n2) {
        return this.makeCall(SqlStdOperatorTable.EQUALS, n1, n2);
    }

    /**
     * Builds {@code left AND right} with an explicit BOOLEAN result type.
     *
     * <p>The result is nullable when either operand is nullable, because e.g.
     * {@code TRUE AND NULL} evaluates to NULL. When both operands are NOT NULL, the
     * result type stays NOT NULL BOOLEAN.
     *
     * @param left left operand
     * @param right right operand
     * @return the AND call
     */
    public RexNode and(RexNode left, RexNode right) {
        boolean nullable = left.getType().isNullable() || right.getType().isNullable();
        RelDataType booleanType = this.getTypeFactory().createTypeWithNullability(
                this.getTypeFactory().createSqlType(SqlTypeName.BOOLEAN), nullable);
        return this.makeCall(booleanType, SqlStdOperatorTable.AND, List.of(left, right));
    }

    /**
     * Returns the least restrictive type of the given nodes' types, as computed by the
     * type factory's {@code leastRestrictive}.
     *
     * <p>NOTE(review): presumably {@code null} when the types share no common type, per
     * Calcite's {@code leastRestrictive} contract; callers should confirm they handle that.
     *
     * @param nodes expressions whose types are combined
     * @return the common type
     */
    public RelDataType commonType(RexNode ... nodes) {
        return this.getTypeFactory().leastRestrictive(Arrays.stream(nodes).map(RexNode::getType).toList());
    }

    /**
     * Maps a PPL span unit to a single-unit interval qualifier (start unit == end unit).
     *
     * <p>Long and short spellings map to the same unit (e.g. {@code MINUTE}/{@code m},
     * {@code MONTH}/{@code M}). Any unit not listed falls back to {@link TimeUnit#EPOCH}.
     * NOTE(review): EPOCH is not a conventional interval unit; confirm callers never pass
     * unlisted span units here.
     *
     * @param unit the span unit
     * @return an interval qualifier spanning exactly {@code unit}
     */
    public SqlIntervalQualifier createIntervalUntil(SpanUnit unit) {
        TimeUnit timeUnit = switch (unit) {
            case MILLISECOND, MS -> TimeUnit.MILLISECOND;
            case SECOND, S -> TimeUnit.SECOND;
            case MINUTE, m -> TimeUnit.MINUTE;
            case HOUR, H -> TimeUnit.HOUR;
            case DAY, D -> TimeUnit.DAY;
            case WEEK, W -> TimeUnit.WEEK;
            case MONTH, M -> TimeUnit.MONTH;
            case QUARTER, Q -> TimeUnit.QUARTER;
            case YEAR, Y -> TimeUnit.YEAR;
            default -> TimeUnit.EPOCH;
        };
        return new SqlIntervalQualifier(timeUnit, timeUnit, SqlParserPos.ZERO);
    }

    /**
     * Builds a cast of {@code exp} to {@code type2}, handling two cases itself before
     * delegating to {@link RexBuilder#makeCast}.
     *
     * <ul>
     *   <li><b>Literal to BOOLEAN:</b> the CHAR(1) literals {@code '1'} and {@code '0'} fold
     *       to TRUE and FALSE. An exact-numeric literal becomes {@code exp <> 0}, typed as
     *       {@code type2}. Any other literal falls through to the default cast.</li>
     *   <li><b>Target is an OpenSearch user-defined type:</b> DATE, TIME and TIMESTAMP are
     *       built with the corresponding PPL builtin operators. IP returns {@code exp}
     *       unchanged when it is already IP, and converts STRING via the IP operator.</li>
     * </ul>
     *
     * @throws ExpressionEvaluationException when casting to IP from a type other than STRING or IP
     * @throws SemanticCheckException when the target user-defined type has no conversion here
     */
    @Override
    public RexNode makeCast(SqlParserPos pos, RelDataType type2, RexNode exp, boolean matchNullability, boolean safe, RexLiteral format) {
        SqlTypeName sqlType = type2.getSqlTypeName();
        if (exp instanceof RexLiteral && sqlType == SqlTypeName.BOOLEAN) {
            RelDataType char1 = this.typeFactory.createSqlType(SqlTypeName.CHAR, 1);
            if (exp.equals(this.makeLiteral("1", char1))) {
                return this.makeLiteral(true, type2);
            }
            if (exp.equals(this.makeLiteral("0", char1))) {
                return this.makeLiteral(false, type2);
            }
            if (SqlTypeUtil.isExactNumeric(exp.getType())) {
                // Non-zero numeric literals are treated as TRUE, zero as FALSE.
                return this.makeCall(type2, SqlStdOperatorTable.NOT_EQUALS, List.of(exp, this.makeZeroLiteral(exp.getType())));
            }
        } else if (OpenSearchTypeFactory.isUserDefinedType(type2)) {
            OpenSearchTypeFactory.ExprUDT udt = ((AbstractExprRelDataType) type2).getUdt();
            ExprType argExprType = OpenSearchTypeFactory.convertRelDataTypeToExprType(exp.getType());
            return switch (udt) {
                case EXPR_DATE -> this.makeCall(type2, PPLBuiltinOperators.DATE, List.of(exp));
                case EXPR_TIME -> this.makeCall(type2, PPLBuiltinOperators.TIME, List.of(exp));
                case EXPR_TIMESTAMP -> this.makeCall(type2, PPLBuiltinOperators.TIMESTAMP, List.of(exp));
                case EXPR_IP -> {
                    if (argExprType == ExprCoreType.IP) {
                        // Already an IP: the cast is a no-op.
                        yield exp;
                    }
                    if (argExprType == ExprCoreType.STRING) {
                        yield this.makeCall(type2, PPLBuiltinOperators.IP, List.of(exp));
                    }
                    throw new ExpressionEvaluationException(String.format(Locale.ROOT, "Cannot convert %s to IP, only STRING and IP types are supported", argExprType));
                }
                default -> throw new SemanticCheckException(String.format(Locale.ROOT, "Cannot cast from %s to %s", argExprType, udt.name()));
            };
        }
        return super.makeCast(pos, type2, exp, matchNullability, safe, format);
    }
}

