Why SqlDialect's supportsCharSet and getCastSpec don't work?

28 views Asked by At

I defined my own DorisSqlDialect, which inherits from SqlDialect. In DorisSqlDialect, I rewrote the supportsCharSet and getCastSpec functions, but found that it did not work. When my SqlNode(after validate) contains cast, after to sql cast(xxx as VARCHAR CHARACTER SET UTF_8) still appears.

Here is the part of DorisSqlDialect:

public class DorisSqlDialect extends SqlDialect {
        public static final SqlDialect.Context DEFAULT_CONTEXT = SqlDialect.EMPTY_CONTEXT
            .withDatabaseProduct(SqlDialect.DatabaseProduct.MYSQL)
            .withDataTypeSystem(RelDataTypeSystem.DEFAULT)
            .withNullCollation(NullCollation.LOW)
            .withQuotedCasing(Casing.TO_LOWER)
            .withUnquotedCasing(Casing.TO_LOWER)
            .withCaseSensitive(false);

    public static final SqlDialect DEFAULT = new DorisSqlDialect(DEFAULT_CONTEXT);

    /**
     * Creates a DorisSqlDialect.
     */
    public DorisSqlDialect(Context context) {
        super(context);
    }

    @Override
    public boolean supportsCharSet() {
        return false;
    }

    @Override 
    public @Nullable SqlNode getCastSpec(RelDataType type) {
        switch (type.getSqlTypeName()) {
            case VARCHAR:
                return new SqlDataTypeSpec(
                        new SqlBasicTypeNameSpec(SqlTypeName.VARCHAR, SqlParserPos.ZERO), SqlParserPos.ZERO);
            case INTEGER:
                return new SqlDataTypeSpec(
                        new SqlAlienSystemTypeNameSpec(
                                "INT",
                                type.getSqlTypeName(),
                                SqlParserPos.ZERO),
                        SqlParserPos.ZERO);
            default:
                break;
        }
        return super.getCastSpec(type);
    }
}

public class CalciteEnvUtilsTest {

    public static final JavaTypeFactory JAVA_TYPE_FACTORY = new JavaTypeFactoryImpl();
    public static final RelDataTypeFactory SQL_TYPE_FACTORY = new SqlTypeFactoryImpl(RelDataTypeSystem.DEFAULT);
    public static final RelDataTypeSystem TYPE_SYSTEM = RelDataTypeSystem.DEFAULT;

    private static void offCache(CalciteSchema calciteSchema) {
        NavigableMap<String, CalciteSchema> subSchemaMap = calciteSchema.getSubSchemaMap();
        if (subSchemaMap == null || subSchemaMap.size() <= 0) {
            return;
        }
        subSchemaMap.forEach((name, subSchema) -> {
            subSchema.setCache(false);
            offCache(subSchema);
        });
    }

    public static CalciteCatalogReader createCatalogReader(SchemaPlus rootSchema, boolean cache){

        rootSchema = rootSchema(rootSchema);

        if (!cache) {
            // set cache=false
            rootSchema.setCacheEnabled(false);
        }

        CalciteSchema calciteSchema = CalciteSchema.from(rootSchema);
        if (!cache) {
            calciteSchema.setCache(false);
            offCache(calciteSchema);
        }

        // CalciteConnectionConfig
        Properties properties = new Properties();
        properties.put(CalciteConnectionProperty.CASE_SENSITIVE.camelName(), Boolean.FALSE.toString());
        properties.put(CalciteConnectionProperty.UNQUOTED_CASING.camelName(), Casing.TO_LOWER.toString());
        properties.put(CalciteConnectionProperty.QUOTED_CASING.camelName(), Casing.TO_LOWER.toString());
        CalciteConnectionConfig connectionConfig = new CalciteConnectionConfigImpl(properties);

        // create CatalogReader
        return new CalciteCatalogReader(
                calciteSchema,
                Collections.singletonList(calciteSchema.getName()),
                JAVA_TYPE_FACTORY,
                connectionConfig
        );
    }

    private static SchemaPlus rootSchema(SchemaPlus schema) {
        for (;;) {
            SchemaPlus parentSchema = schema.getParentSchema();
            if (parentSchema == null) {
                return schema;
            }
            schema = parentSchema;
        }
    }

    public static SqlOperatorTable createSqlOperatorTable(CalciteCatalogReader catalogReader){

        return SqlOperatorTables.chain(SqlStdOperatorTable.instance(),
                catalogReader);
    }

    public static SqlOperatorTable createSqlOperatorTable(SchemaPlus defaultSchema){

        return SqlOperatorTables.chain(SqlStdOperatorTable.instance(),
                createCatalogReader(defaultSchema, false));
    }

    public static FrameworkConfig createFrameConfig(SchemaPlus defaultSchema, SqlOperatorTable sqlOperatorTable){

        // SqlParser.Config
        SqlParser.Config sqlParserConfig = SqlParser.Config.DEFAULT
                .withConformance(SqlConformanceEnum.MYSQL_5)
                .withQuotedCasing(Casing.TO_LOWER)
                .withUnquotedCasing(Casing.TO_LOWER)
                .withCaseSensitive(false)
                .withIdentifierMaxLength(Integer.MAX_VALUE)
                .withQuoting(Quoting.BACK_TICK_BACKSLASH);

        // SqlValidator.Config
        SqlValidator.Config sqlValidatorConfig = SqlValidator.Config.DEFAULT
                .withTypeCoercionEnabled(true)
                .withIdentifierExpansion(true)
                .withDefaultNullCollation(NullCollation.LOW);

        // SqlToRelConverter.Config
        SqlToRelConverter.Config sqlToRelConverterConfig = SqlToRelConverter.config()
                .withTrimUnusedFields(false)
                .withInSubQueryThreshold(Integer.MAX_VALUE);

        return Frameworks.newConfigBuilder()
                .defaultSchema(defaultSchema)
                .traitDefs(ConventionTraitDef.INSTANCE)
                .operatorTable(sqlOperatorTable)
                .parserConfig(sqlParserConfig)
                .sqlValidatorConfig(sqlValidatorConfig)
                .sqlToRelConverterConfig(sqlToRelConverterConfig)
                .build();
    }

    public static FrameworkConfig createFrameConfig(SchemaPlus defaultSchema){
        return createFrameConfig(defaultSchema, createSqlOperatorTable(defaultSchema));
    }

    public static SqlValidator createSqlValidator(CalciteCatalogReader catalogReader,
                                                  SqlOperatorTable sqlOperatorTable,
                                                  SqlValidator.Config config){
        SqlValidator validator =  new CalciteSqlValidator(sqlOperatorTable, catalogReader, JAVA_TYPE_FACTORY, config);
        validator.setValidatedNodeType(SqlLiteral.createNull(SqlParserPos.ZERO),  JAVA_TYPE_FACTORY.createSqlType(SqlTypeName.NULL));
        return validator;
    }

    public static SqlValidator createSqlValidator(SchemaPlus schema){
        // CatalogReader
        CalciteCatalogReader catalogReader = createCatalogReader(schema, false);

        // SqlOperatorTable
        SqlOperatorTable sqlOperatorTable =  createSqlOperatorTable(catalogReader);

        // FrameworkConfig
        FrameworkConfig config = createFrameConfig(schema, sqlOperatorTable);

        // sqlValidator
        return createSqlValidator(catalogReader, sqlOperatorTable, config.getSqlValidatorConfig());
    }

    public static SchemaPlus getSchemaPlus() {

        SchemaPlus rootSchema = Frameworks.createRootSchema(true);

        // Doris tables
        SimpleTable table = SimpleTable.newBuilder("users")
                .addField("id", SqlTypeName.INTEGER)
                .addField("name", SqlTypeName.VARCHAR)
                .addField("age", SqlTypeName.INTEGER)
                .addField("bitmap_column", SqlTypeName.BINARY)
                .addField("partition_date", SqlTypeName.DATE)
                .withRowCount(60_000L)
                .build();

        SimpleTable tableA = SimpleTable.newBuilder("table_a")
                .addField("id", SqlTypeName.INTEGER)
                .addField("name1", SqlTypeName.VARCHAR)
                .withRowCount(60_000L)
                .build();

        SchemaPlus dorisSchema = Frameworks.createRootSchema(true);
        dorisSchema.add("users", table);
        dorisSchema.add("table_a", tableA);
        SchemaPlus dorisDsnSchema = Frameworks.createRootSchema(true);
        dorisDsnSchema.add("doris_db", dorisSchema);
        SchemaPlus dorisEngineSchema = Frameworks.createRootSchema(true);
        dorisEngineSchema.add("doris_dsn", dorisDsnSchema);
        rootSchema.add("doris", dorisEngineSchema);


        // Hive tables
        SimpleTable tableB = SimpleTable.newBuilder("table_b")
                .addField("id", SqlTypeName.INTEGER)
                .addField("name2", SqlTypeName.VARCHAR)
                .withRowCount(60_000L)
                .build();

        SimpleTable tableC = SimpleTable.newBuilder("table_c")
                .addField("id", SqlTypeName.INTEGER)
                .addField("name3", SqlTypeName.VARCHAR)
                .withRowCount(60_000L)
                .build();
        SchemaPlus hiveSchema = Frameworks.createRootSchema(true);
        hiveSchema.add("table_b", tableB);
        hiveSchema.add("table_c", tableC);
        SchemaPlus hiveDsnSchema = Frameworks.createRootSchema(true);
        hiveDsnSchema.add("hive_db", hiveSchema);
        SchemaPlus hiveEngineSchema = Frameworks.createRootSchema(true);
        hiveEngineSchema.add("hive_dsn", hiveDsnSchema);
        rootSchema.add("hive", hiveEngineSchema);

        return rootSchema;
    }
}

public class SimpleTable extends AbstractTable implements ScannableTable {

    private final String tableName;
    private final List<String> fieldNames;
    private final List<SqlTypeName> fieldTypes;
    private final SimpleTableStatistic statistic;

    private RelDataType rowType;

    private SimpleTable(String tableName, List<String> fieldNames, List<SqlTypeName> fieldTypes, SimpleTableStatistic statistic) {
        this.tableName = tableName;
        this.fieldNames = fieldNames;
        this.fieldTypes = fieldTypes;
        this.statistic = statistic;
    }

    public String getTableName() {
        return tableName;
    }

    @Override
    public RelDataType getRowType(RelDataTypeFactory typeFactory) {
        if (rowType == null) {
            List<RelDataTypeField> fields = new ArrayList<>(fieldNames.size());

            for (int i = 0; i < fieldNames.size(); i++) {
                RelDataType fieldType = typeFactory.createSqlType(fieldTypes.get(i));
                RelDataTypeField field = new RelDataTypeFieldImpl(fieldNames.get(i), i, fieldType);
                fields.add(field);
            }

            rowType = new RelRecordType(StructKind.PEEK_FIELDS, fields, false);
        }

        return rowType;
    }

    @Override
    public Statistic getStatistic() {
        return statistic;
    }

    @Override
    public Enumerable<Object[]> scan(DataContext root) {
        throw new UnsupportedOperationException("Not implemented");
    }

    public static Builder newBuilder(String tableName) {
        return new Builder(tableName);
    }

    public static final class Builder {

        private final String tableName;
        private final List<String> fieldNames = new ArrayList<>();
        private final List<SqlTypeName> fieldTypes = new ArrayList<>();
        private long rowCount;

        private Builder(String tableName) {
            if (tableName == null || tableName.isEmpty()) {
                throw new IllegalArgumentException("Table name cannot be null or empty");
            }

            this.tableName = tableName;
        }

        public Builder addField(String name, SqlTypeName typeName) {
            if (name == null || name.isEmpty()) {
                throw new IllegalArgumentException("Field name cannot be null or empty");
            }

            if (fieldNames.contains(name)) {
                throw new IllegalArgumentException("Field already defined: " + name);
            }

            fieldNames.add(name);
            fieldTypes.add(typeName);

            return this;
        }

        public Builder withRowCount(long rowCount) {
            this.rowCount = rowCount;

            return this;
        }

        public SimpleTable build() {
            if (fieldNames.isEmpty()) {
                throw new IllegalStateException("Table must have at least one field");
            }

            if (rowCount == 0L) {
                throw new IllegalStateException("Table must have positive row count");
            }

            return new SimpleTable(tableName, fieldNames, fieldTypes, new SimpleTableStatistic(rowCount));
        }
    }
}

public class DorisSqlDialectTest {
    @Test
    public void testCastVarchar2(){

        // root schema
        final SchemaPlus rootSchema = CalciteEnvUtilsTest.getSchemaPlus();

        // config and validator
        FrameworkConfig config = CalciteEnvUtilsTest.createFrameConfig(rootSchema);
        SqlValidator validator = CalciteEnvUtilsTest.createSqlValidator(rootSchema);
        RelBuilder builder = RelBuilder.create(config);

        // relnode
        builder.scan("doris", "doris_dsn", "doris_db", "users");
        builder.project(builder.alias(
                builder.call(SqlStdOperatorTable.CONCAT,
                        builder.field("id"),
                        builder.field("age")),
                "concats"));
        RelNode relNode = builder.build();

        // dialect
        DorisSqlDialect dialect = new DorisSqlDialect(DorisSqlDialect.DEFAULT_CONTEXT);

        // RelNode to sql
        final RelToSqlConverter converter = new RelToSqlConverter(dialect);
        SqlNode sqlNode = converter.visitRoot(relNode).asStatement();
        sqlNode = validator.validate(sqlNode);
        String sql = sqlNode.toSqlString(dialect).getSql();

        System.out.println(sql);
        System.out.println("==================================");
    }
}

The output is

SELECT CAST(users.id AS VARCHAR CHARACTER SET UTF-8) || CAST(users.age AS VARCHAR CHARACTER SET UTF-8) AS concats
FROM doris_db.users AS users
0

There are 0 answers