Skip to content

Commit 837b01e

Browse files
authored
[FLINK-37599][table] Support column expansion for PTF on_time columns
This closes #27942.
1 parent d49eb62 commit 837b01e

File tree

3 files changed

+216
-70
lines changed

3 files changed

+216
-70
lines changed

flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/calcite/FlinkCalciteSqlValidator.java

Lines changed: 122 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@
2727
import org.apache.flink.table.catalog.Column;
2828
import org.apache.flink.table.catalog.ResolvedSchema;
2929
import org.apache.flink.table.data.TimestampData;
30+
import org.apache.flink.table.functions.FunctionKind;
3031
import org.apache.flink.table.planner.catalog.CatalogSchemaModel;
3132
import org.apache.flink.table.planner.catalog.CatalogSchemaTable;
32-
import org.apache.flink.table.planner.functions.sql.ml.SqlMLTableFunction;
3333
import org.apache.flink.table.planner.plan.FlinkCalciteCatalogReader;
3434
import org.apache.flink.table.planner.plan.utils.FlinkRexUtil;
3535
import org.apache.flink.table.planner.utils.ShortcutUtils;
@@ -65,6 +65,8 @@
6565
import org.apache.calcite.sql.SqlUtil;
6666
import org.apache.calcite.sql.SqlWindowTableFunction;
6767
import org.apache.calcite.sql.parser.SqlParserPos;
68+
import org.apache.calcite.sql.type.SqlOperandMetadata;
69+
import org.apache.calcite.sql.type.SqlOperandTypeChecker;
6870
import org.apache.calcite.sql.type.SqlTypeUtil;
6971
import org.apache.calcite.sql.validate.DelegatingScope;
7072
import org.apache.calcite.sql.validate.IdentifierNamespace;
@@ -92,7 +94,6 @@
9294
import java.util.Optional;
9395
import java.util.Set;
9496
import java.util.stream.Collectors;
95-
import java.util.stream.Stream;
9697

9798
import static org.apache.calcite.sql.type.SqlTypeName.DECIMAL;
9899
import static org.apache.flink.table.expressions.resolver.lookups.FieldReferenceLookup.includeExpandedColumn;
@@ -343,7 +344,7 @@ protected void addToSelectList(
343344
final Column column = resolvedSchema.getColumn(columnName).orElse(null);
344345
if (qualified.suffix().size() == 1 && column != null) {
345346
if (includeExpandedColumn(column, columnExpansionStrategies)
346-
|| declaredDescriptorColumn(scope, column)) {
347+
|| isDeclaredOnTimeColumn(scope, column)) {
347348
super.addToSelectList(
348349
list, aliases, fieldList, exp, scope, includeSystemVars);
349350
}
@@ -360,71 +361,71 @@ protected void addToSelectList(
360361
protected @PolyNull SqlNode performUnconditionalRewrites(
361362
@PolyNull SqlNode node, boolean underFrom) {
362363

363-
// Special case for window TVFs like:
364-
// TUMBLE(TABLE t, DESCRIPTOR(metadata_virtual), INTERVAL '1' MINUTE)) or
365-
// SESSION(TABLE t PARTITION BY a, DESCRIPTOR(metadata_virtual), INTERVAL '1' MINUTE))
364+
// Capture table arguments early:
365+
// TUMBLE(TABLE t, DESCRIPTOR(metadata_virtual), INTERVAL '1' MINUTE) or
366+
// SESSION(TABLE t PARTITION BY a, DESCRIPTOR(metadata_virtual), INTERVAL '1' MINUTE)
367+
// MyPtf(in => TABLE t PARTITION BY a, on_time => DESCRIPTOR(metadata_virtual))
366368
//
367369
// "TABLE t" is translated into an implicit "SELECT * FROM t". This would ignore columns
368-
// that are not expanded by default. However, the descriptor explicitly states the need
369-
// for this column. Therefore, explicit table expressions (for window TVFs at most one)
370-
// are captured before rewriting and replaced with a "marker" SqlSelect that contains the
371-
// descriptor information. The "marker" SqlSelect is considered during column expansion.
370+
// that are not expanded by default. However, the on_time descriptor explicitly states the
371+
// need for time columns. Therefore, explicit table expressions are captured before
372+
// rewriting and replaced with a "marker" SqlSelect that contains the descriptor
373+
// information. The "marker" SqlSelect is considered during column expansion.
372374
final List<SqlIdentifier> tableArgs = getTableOperands(node);
373375

374376
final SqlNode rewritten = super.performUnconditionalRewrites(node, underFrom);
375377

376378
if (!(node instanceof SqlBasicCall)) {
377379
return rewritten;
378380
}
381+
379382
final SqlBasicCall call = (SqlBasicCall) node;
380-
final SqlOperator operator = call.getOperator();
381383

384+
// Special case for MODEL
382385
if (node instanceof SqlExplicitModelCall) {
383386
// Convert it so that model can be accessed in planner. SqlExplicitModelCall
384387
// from parser can't access model.
385-
SqlExplicitModelCall modelCall = (SqlExplicitModelCall) node;
386-
SqlIdentifier modelIdentifier = modelCall.getModelIdentifier();
387-
FlinkCalciteCatalogReader catalogReader =
388+
final SqlExplicitModelCall modelCall = (SqlExplicitModelCall) node;
389+
final SqlIdentifier modelIdentifier = modelCall.getModelIdentifier();
390+
final FlinkCalciteCatalogReader catalogReader =
388391
(FlinkCalciteCatalogReader) getCatalogReader();
389-
CatalogSchemaModel model = catalogReader.getModel(modelIdentifier.names);
392+
final CatalogSchemaModel model = catalogReader.getModel(modelIdentifier.names);
390393
if (model != null) {
391394
return new SqlModelCall(modelCall, model);
392395
}
393396
}
394397

395-
// TODO (FLINK-37819): add test for SqlMLTableFunction
396-
if (operator instanceof SqlWindowTableFunction || operator instanceof SqlMLTableFunction) {
397-
if (tableArgs.stream().allMatch(Objects::isNull)) {
398-
return rewritten;
399-
}
400-
401-
final List<SqlIdentifier> descriptors =
402-
call.getOperandList().stream()
403-
.flatMap(FlinkCalciteSqlValidator::extractDescriptors)
404-
.collect(Collectors.toList());
405-
398+
// Mark rewritten "TABLE t" with on_time columns
399+
if (tableArgs == null || tableArgs.stream().allMatch(Objects::isNull)) {
400+
return rewritten;
401+
}
402+
final List<SqlIdentifier> onTimeColumns = extractOnTime(call);
403+
if (onTimeColumns != null) {
406404
for (int i = 0; i < call.operandCount(); i++) {
407405
final SqlIdentifier tableArg = tableArgs.get(i);
408406
if (tableArg != null) {
409-
final SqlNode opReplacement = new ExplicitTableSqlSelect(tableArg, descriptors);
407+
final SqlNode opReplacement =
408+
new ExplicitTableSqlSelect(tableArg, onTimeColumns);
409+
// for f(TABLE t PARTITION BY c, ...)
410410
if (call.operand(i).getKind() == SqlKind.SET_SEMANTICS_TABLE) {
411411
final SqlCall setSemanticsTable = call.operand(i);
412412
setSemanticsTable.setOperand(0, opReplacement);
413413
} else if (call.operand(i).getKind() == SqlKind.ARGUMENT_ASSIGNMENT) {
414-
// for TUMBLE(DATA => TABLE t3, ...)
415414
final SqlCall assignment = call.operand(i);
415+
// for f(in => TABLE t PARTITION BY c, ...)
416416
if (assignment.operand(0).getKind() == SqlKind.SET_SEMANTICS_TABLE) {
417-
final SqlCall setSemanticsTable = assignment.operand(i);
417+
final SqlCall setSemanticsTable = assignment.operand(0);
418418
setSemanticsTable.setOperand(0, opReplacement);
419419
} else {
420+
// for f(in => TABLE t, ...)
420421
assignment.setOperand(0, opReplacement);
421422
}
422423
} else {
423-
// for TUMBLE(TABLE t3, ...)
424+
// for f(TABLE t, ...)
424425
call.setOperand(i, opReplacement);
425426
}
426427
}
427-
// for TUMBLE([DATA =>] SELECT ..., ...)
428+
// for f([in =>] SELECT ..., ...)
428429
}
429430
}
430431

@@ -446,9 +447,9 @@ public SqlNode maybeCast(SqlNode node, RelDataType currentType, RelDataType desi
446447
*/
447448
static class ExplicitTableSqlSelect extends SqlSelect {
448449

449-
private final List<SqlIdentifier> descriptors;
450+
private final List<SqlIdentifier> onTimeColumns;
450451

451-
public ExplicitTableSqlSelect(SqlIdentifier table, List<SqlIdentifier> descriptors) {
452+
public ExplicitTableSqlSelect(SqlIdentifier table, List<SqlIdentifier> onTimeColumns) {
452453
super(
453454
SqlParserPos.ZERO,
454455
null,
@@ -462,91 +463,150 @@ public ExplicitTableSqlSelect(SqlIdentifier table, List<SqlIdentifier> descripto
462463
null,
463464
null,
464465
null);
465-
this.descriptors = descriptors;
466+
this.onTimeColumns = onTimeColumns;
466467
}
467468
}
468469

469470
/**
470471
* Returns whether the given column has been declared in a {@link SqlKind#DESCRIPTOR} next to a
471472
* {@link SqlKind#EXPLICIT_TABLE} within TVF operands.
472473
*/
473-
private static boolean declaredDescriptorColumn(SelectScope scope, Column column) {
474+
private static boolean isDeclaredOnTimeColumn(SelectScope scope, Column column) {
474475
if (!(scope.getNode() instanceof ExplicitTableSqlSelect)) {
475476
return false;
476477
}
477478
final ExplicitTableSqlSelect select = (ExplicitTableSqlSelect) scope.getNode();
478-
return select.descriptors.stream()
479+
return select.onTimeColumns.stream()
479480
.map(SqlIdentifier::getSimple)
480481
.anyMatch(id -> id.equals(column.getName()));
481482
}
482483

483484
/**
484485
* Returns all {@link SqlKind#EXPLICIT_TABLE} and {@link SqlKind#SET_SEMANTICS_TABLE} operands
485-
* within TVF operands. A list entry is {@code null} if the operand is not an {@link
486+
* within PTF operands. A list entry is {@code null} if the operand is not an {@link
486487
* SqlKind#EXPLICIT_TABLE} or {@link SqlKind#SET_SEMANTICS_TABLE}.
487488
*/
488489
private static List<SqlIdentifier> getTableOperands(SqlNode node) {
489490
if (!(node instanceof SqlBasicCall)) {
490491
return null;
491492
}
493+
492494
final SqlBasicCall call = (SqlBasicCall) node;
493495

494496
if (!(call.getOperator() instanceof SqlFunction)) {
495497
return null;
496498
}
499+
497500
final SqlFunction function = (SqlFunction) call.getOperator();
498501

499502
if (!isTableFunction(function)) {
500503
return null;
501504
}
502505

503506
return call.getOperandList().stream()
504-
.map(FlinkCalciteSqlValidator::extractTableOperand)
507+
.map(FlinkCalciteSqlValidator::extractExplicitTables)
505508
.collect(Collectors.toList());
506509
}
507510

508-
private static @Nullable SqlIdentifier extractTableOperand(SqlNode op) {
511+
/** Extracts "TABLE t" nodes before they get rewritten into "SELECT * FROM t". */
512+
private static @Nullable SqlIdentifier extractExplicitTables(SqlNode op) {
509513
if (op.getKind() == SqlKind.EXPLICIT_TABLE) {
510514
final SqlBasicCall opCall = (SqlBasicCall) op;
511515
if (opCall.operandCount() == 1 && opCall.operand(0) instanceof SqlIdentifier) {
512-
// for TUMBLE(TABLE t3, ...)
516+
// for f(TABLE t, ...)
513517
return opCall.operand(0);
514518
}
515519
} else if (op.getKind() == SqlKind.SET_SEMANTICS_TABLE) {
516-
// for SESSION windows
520+
// for f(TABLE t PARTITION BY x)
517521
final SqlBasicCall opCall = (SqlBasicCall) op;
518-
final SqlCall setSemanticsTable = opCall.operand(0);
519-
if (setSemanticsTable.operand(0) instanceof SqlIdentifier) {
520-
return setSemanticsTable.operand(0);
521-
}
522+
return extractExplicitTables(opCall.operand(0));
522523
} else if (op.getKind() == SqlKind.ARGUMENT_ASSIGNMENT) {
523-
// for TUMBLE(DATA => TABLE t3, ...)
524+
// for f(in => TABLE t, ...)
524525
final SqlBasicCall opCall = (SqlBasicCall) op;
525-
return extractTableOperand(opCall.operand(0));
526+
return extractExplicitTables(opCall.operand(0));
526527
}
527528
return null;
528529
}
529530

530-
private static Stream<SqlIdentifier> extractDescriptors(SqlNode op) {
531+
/** Extracts the on_time argument of a PTF (or TIMECOL for window PTFs for legacy reasons). */
532+
private static @Nullable List<SqlIdentifier> extractOnTime(SqlBasicCall call) {
533+
// Extract from operand from PTF
534+
final SqlNode onTimeOperand;
535+
if (call.getOperator() instanceof SqlWindowTableFunction) {
536+
onTimeOperand = extractOperandByArgName(call, "TIMECOL");
537+
} else if (ShortcutUtils.isFunctionKind(call.getOperator(), FunctionKind.PROCESS_TABLE)) {
538+
onTimeOperand = extractOperandByArgName(call, "on_time");
539+
} else {
540+
onTimeOperand = null;
541+
}
542+
543+
// No operand found
544+
if (onTimeOperand == null) {
545+
return null;
546+
}
547+
548+
return extractDescriptors(onTimeOperand);
549+
}
550+
551+
private static List<SqlIdentifier> extractDescriptors(SqlNode op) {
531552
if (op.getKind() == SqlKind.DESCRIPTOR) {
532-
// for TUMBLE(..., DESCRIPTOR(col), ...)
533553
final SqlBasicCall opCall = (SqlBasicCall) op;
534554
return opCall.getOperandList().stream()
535555
.filter(SqlIdentifier.class::isInstance)
536-
.map(SqlIdentifier.class::cast);
537-
} else if (op.getKind() == SqlKind.SET_SEMANTICS_TABLE) {
538-
// for SESSION windows
539-
final SqlBasicCall opCall = (SqlBasicCall) op;
540-
return ((SqlNodeList) opCall.operand(1))
541-
.stream()
542-
.filter(SqlIdentifier.class::isInstance)
543-
.map(SqlIdentifier.class::cast);
544-
} else if (op.getKind() == SqlKind.ARGUMENT_ASSIGNMENT) {
545-
// for TUMBLE(..., TIMECOL => DESCRIPTOR(col), ...)
546-
final SqlBasicCall opCall = (SqlBasicCall) op;
547-
return extractDescriptors(opCall.operand(0));
556+
.map(SqlIdentifier.class::cast)
557+
.collect(Collectors.toList());
558+
}
559+
return List.of();
560+
}
561+
562+
/**
563+
* Returns the operand for a given argument name from a BasicSqlCall. Supports both positional
564+
* and named arguments. If at least one ARGUMENT_ASSIGNMENT is used, named lookup is performed.
565+
* Otherwise, positional lookup using SqlOperandMetadata is used.
566+
*
567+
* @param call the SQL call to extract the operand from
568+
* @param argumentName the name of the argument to retrieve
569+
* @return the SqlNode for the operand, or null if not found or not supported
570+
*/
571+
private static @Nullable SqlNode extractOperandByArgName(
572+
SqlBasicCall call, String argumentName) {
573+
// Check if operator supports SqlOperandMetadata
574+
final SqlOperator operator = call.getOperator();
575+
final SqlOperandTypeChecker typeChecker = operator.getOperandTypeChecker();
576+
if (!(typeChecker instanceof SqlOperandMetadata)) {
577+
return null;
578+
}
579+
580+
final SqlOperandMetadata operandMetadata = (SqlOperandMetadata) typeChecker;
581+
582+
// Detect if named arguments are used by checking for ARGUMENT_ASSIGNMENT
583+
final List<SqlNode> operands = call.getOperandList();
584+
final boolean hasNamedArguments =
585+
operands.stream().anyMatch(op -> op.getKind() == SqlKind.ARGUMENT_ASSIGNMENT);
586+
587+
if (hasNamedArguments) {
588+
// Named mode: search through ARGUMENT_ASSIGNMENT nodes
589+
for (SqlNode operand : operands) {
590+
if (operand.getKind() == SqlKind.ARGUMENT_ASSIGNMENT) {
591+
final SqlBasicCall assignment = (SqlBasicCall) operand;
592+
// operand(1) contains the parameter name as SqlIdentifier
593+
final SqlIdentifier paramName = assignment.operand(1);
594+
if (paramName.getSimple().equals(argumentName)) {
595+
// operand(0) contains the actual value
596+
return assignment.operand(0);
597+
}
598+
}
599+
}
600+
return null;
601+
} else {
602+
// Positional mode: use SqlOperandMetadata to map name to position
603+
final List<String> paramNames = operandMetadata.paramNames();
604+
final int index = paramNames.indexOf(argumentName);
605+
if (index == -1 || index >= call.operandCount()) {
606+
return null;
607+
}
608+
return call.operand(index);
548609
}
549-
return Stream.empty();
550610
}
551611

552612
private static boolean isTableFunction(SqlFunction function) {

flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/utils/ShortcutUtils.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.apache.flink.table.expressions.CallExpression;
2626
import org.apache.flink.table.expressions.ResolvedExpression;
2727
import org.apache.flink.table.functions.FunctionDefinition;
28+
import org.apache.flink.table.functions.FunctionKind;
2829
import org.apache.flink.table.planner.calcite.FlinkContext;
2930
import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
3031
import org.apache.flink.table.planner.delegation.PlannerBase;
@@ -156,6 +157,18 @@ public static DataTypeFactory unwrapDataTypeFactory(RelBuilder relBuilder) {
156157
return ((BridgingSqlFunction) call.getOperator()).getDefinition();
157158
}
158159

160+
public static @Nullable FunctionDefinition unwrapFunctionDefinition(SqlOperator operator) {
161+
if (!(operator instanceof BridgingSqlFunction)) {
162+
return null;
163+
}
164+
return ((BridgingSqlFunction) operator).getDefinition();
165+
}
166+
167+
public static boolean isFunctionKind(SqlOperator operator, FunctionKind kind) {
168+
final FunctionDefinition functionDefinition = unwrapFunctionDefinition(operator);
169+
return functionDefinition != null && functionDefinition.getKind() == kind;
170+
}
171+
159172
public static @Nullable BridgingSqlFunction unwrapBridgingSqlFunction(RexCall call) {
160173
final SqlOperator operator = call.getOperator();
161174
if (operator instanceof BridgingSqlFunction) {

0 commit comments

Comments
 (0)