diff --git a/docs/content.zh/docs/dev/table/functions/ptfs.md b/docs/content.zh/docs/dev/table/functions/ptfs.md
index 6c613a115e8d3..835aeb0335009 100644
--- a/docs/content.zh/docs/dev/table/functions/ptfs.md
+++ b/docs/content.zh/docs/dev/table/functions/ptfs.md
@@ -708,6 +708,85 @@ class CountingFunction extends ProcessTableFunction<String> {
 {{< /tab >}}
 {{< /tabs >}}
 
+### Large State
+
+Flink's state backends provide different types of state to efficiently handle large state.
+
+Currently, PTFs support three types of state:
+
+- **Value state**: Represents a single value.
+- **List state**: Represents a list of values, supporting operations like appending, removing, and iterating.
+- **Map state**: Represents a map of key-value pairs for efficient lookups, modifications, and removal of individual entries.
+
+By default, state entries in a PTF are represented as value state. This means that every state entry is fully read from
+the state backend when the evaluation method is called, and the value is written back to the state backend once the
+evaluation method finishes.
+
+To optimize state access and avoid unnecessary (de)serialization, state entries can be declared as:
+- `org.apache.flink.table.api.dataview.ListView` (for list state)
+- `org.apache.flink.table.api.dataview.MapView` (for map state)
+
+These provide direct views to the underlying Flink state backend.
+
+For example, when using a `MapView`, accessing a value via `MapView#get` will only deserialize the value associated with
+the specified key. This allows for efficient access to individual entries without needing to load the entire map. This
+approach is particularly useful when the map does not fit entirely into memory.
+
+{{< hint info >}}
+State TTL is applied individually to each entry in a list or map, allowing for fine-grained expiration control over state
+elements.
+{{< /hint >}}
+
+The following example demonstrates how to declare and use a `MapView`. It assumes the PTF processes a table with the
+schema `(userId, eventId, ...)`, partitioned by `userId`, with a high cardinality of distinct `eventId` values. For this
+use case, it is generally recommended to partition the table by both `userId` and `eventId`. For illustration purposes,
+however, this example keeps the partitioning by `userId` and stores the event history as a map state.
+
+{{< tabs "1837eeed-3d13-455c-8e2f-5e164da9f844" >}}
+{{< tab "Java" >}}
+```java
+// Function that uses a map view for storing a large map for an event history per user
+class LargeHistoryFunction extends ProcessTableFunction<String> {
+    public void eval(
+        @StateHint MapView<String, Integer> largeMemory,
+        @ArgumentHint(TABLE_AS_SET) Row input
+    ) throws Exception {
+        String eventId = input.getFieldAs("eventId");
+        Integer count = largeMemory.get(eventId);
+        if (count == null) {
+            largeMemory.put(eventId, 1);
+        } else {
+            if (count > 1000) {
+                collect("Anomaly detected: " + eventId);
+            }
+            largeMemory.put(eventId, count + 1);
+        }
+    }
+}
+```
+{{< /tab >}}
+{{< /tabs >}}
+
+Similar to other data types, reflection is used to extract the necessary type information. If reflection is not
+feasible, such as when a `Row` object is involved, type hints can be provided. Use the `ARRAY` data type for list views
+and the `MAP` data type for map views.
+
+{{< tabs "1937eeed-3d13-455c-8e2f-5e164da9f844" >}}
+{{< tab "Java" >}}
+```java
+// Function that uses a list view of rows
+class LargeHistoryFunction extends ProcessTableFunction<String> {
+    public void eval(
+        @StateHint(type = @DataTypeHint("ARRAY<ROW<s STRING, i INT>>")) ListView<Row> largeMemory,
+        @ArgumentHint(TABLE_AS_SET) Row input
+    ) throws Exception {
+        ...
+    }
+}
+```
+{{< /tab >}}
+{{< /tabs >}}
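+
+For illustration, such a PTF could be invoked as follows. A minimal sketch, assuming the function
+has been registered under the name `LargeHistoryFunction` and a table `Events` with the schema
+above exists:
+
+```sql
+SELECT * FROM LargeHistoryFunction(input => TABLE Events PARTITION BY userId);
+```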
+
 ### Efficiency and Design Principles
 
 A stateful function also means that data layout and data retention should be well thought
diff --git a/docs/content/docs/dev/table/functions/ptfs.md b/docs/content/docs/dev/table/functions/ptfs.md
index 6c613a115e8d3..835aeb0335009 100644
--- a/docs/content/docs/dev/table/functions/ptfs.md
+++ b/docs/content/docs/dev/table/functions/ptfs.md
@@ -708,6 +708,85 @@ class CountingFunction extends ProcessTableFunction<String> {
 {{< /tab >}}
 {{< /tabs >}}
 
+### Large State
+
+Flink's state backends provide different types of state to efficiently handle large state.
+
+Currently, PTFs support three types of state:
+
+- **Value state**: Represents a single value.
+- **List state**: Represents a list of values, supporting operations like appending, removing, and iterating.
+- **Map state**: Represents a map of key-value pairs for efficient lookups, modifications, and removal of individual entries.
+
+By default, state entries in a PTF are represented as value state. This means that every state entry is fully read from
+the state backend when the evaluation method is called, and the value is written back to the state backend once the
+evaluation method finishes.
+
+To optimize state access and avoid unnecessary (de)serialization, state entries can be declared as:
+- `org.apache.flink.table.api.dataview.ListView` (for list state)
+- `org.apache.flink.table.api.dataview.MapView` (for map state)
+
+These provide direct views to the underlying Flink state backend.
+
+For example, when using a `MapView`, accessing a value via `MapView#get` will only deserialize the value associated with
+the specified key. This allows for efficient access to individual entries without needing to load the entire map. This
+approach is particularly useful when the map does not fit entirely into memory.
+
+{{< hint info >}}
+State TTL is applied individually to each entry in a list or map, allowing for fine-grained expiration control over state
+elements.
+{{< /hint >}}
+
+The following example demonstrates how to declare and use a `MapView`. It assumes the PTF processes a table with the
+schema `(userId, eventId, ...)`, partitioned by `userId`, with a high cardinality of distinct `eventId` values. For this
+use case, it is generally recommended to partition the table by both `userId` and `eventId`. For illustration purposes,
+however, this example keeps the partitioning by `userId` and stores the event history as a map state.
+
+{{< tabs "1837eeed-3d13-455c-8e2f-5e164da9f844" >}}
+{{< tab "Java" >}}
+```java
+// Function that uses a map view for storing a large map for an event history per user
+class LargeHistoryFunction extends ProcessTableFunction<String> {
+    public void eval(
+        @StateHint MapView<String, Integer> largeMemory,
+        @ArgumentHint(TABLE_AS_SET) Row input
+    ) throws Exception {
+        String eventId = input.getFieldAs("eventId");
+        Integer count = largeMemory.get(eventId);
+        if (count == null) {
+            largeMemory.put(eventId, 1);
+        } else {
+            if (count > 1000) {
+                collect("Anomaly detected: " + eventId);
+            }
+            largeMemory.put(eventId, count + 1);
+        }
+    }
+}
+```
+{{< /tab >}}
+{{< /tabs >}}
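+
+A state entry that is backed by a data view can also be dropped as a whole once it is no longer
+needed. A minimal sketch, assuming the declaration from the previous example; `Context#clearState`
+takes the name of the state entry, which defaults to the parameter name:
+
+```java
+// Variant of the function above that clears the entire map view for the current user
+class LargeHistoryFunction extends ProcessTableFunction<String> {
+    public void eval(
+        Context ctx,
+        @StateHint MapView<String, Integer> largeMemory,
+        @ArgumentHint(TABLE_AS_SET) Row input
+    ) throws Exception {
+        // ... maintain the map view as shown above ...
+        // Drop the entire per-user history, e.g. once an anomaly has been reported,
+        // instead of waiting for state TTL to expire individual entries.
+        ctx.clearState("largeMemory");
+    }
+}
+```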
+
+Similar to other data types, reflection is used to extract the necessary type information. If reflection is not
+feasible, such as when a `Row` object is involved, type hints can be provided. Use the `ARRAY` data type for list views
+and the `MAP` data type for map views.
+
+{{< tabs "1937eeed-3d13-455c-8e2f-5e164da9f844" >}}
+{{< tab "Java" >}}
+```java
+// Function that uses a list view of rows
+class LargeHistoryFunction extends ProcessTableFunction<String> {
+    public void eval(
+        @StateHint(type = @DataTypeHint("ARRAY<ROW<s STRING, i INT>>")) ListView<Row> largeMemory,
+        @ArgumentHint(TABLE_AS_SET) Row input
+    ) throws Exception {
+        ...
+    }
+}
+```
+{{< /tab >}}
+{{< /tabs >}}
+
 ### Efficiency and Design Principles
 
 A stateful function also means that data layout and data retention should be well thought
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/api/dataview/DataView.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/api/dataview/DataView.java
index 9a49d335faafb..a3873dcfbd9c1 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/api/dataview/DataView.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/api/dataview/DataView.java
@@ -19,11 +19,11 @@
 package org.apache.flink.table.api.dataview;
 
 import org.apache.flink.annotation.PublicEvolving;
-import org.apache.flink.table.functions.ImperativeAggregateFunction;
+import org.apache.flink.table.functions.ProcessTableFunction;
 
 /**
- * A {@link DataView} is a collection type that can be used in the accumulator of an {@link
- * ImperativeAggregateFunction}.
+ * A {@link DataView} is a collection type that can be used in the accumulator of aggregating
+ * functions and as a state entry in {@link ProcessTableFunction}s.
  *
  * <p>Depending on the context in which the function is used, a {@link DataView} can be backed by a
  * Java heap collection or a state backend.
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/api/dataview/ListView.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/api/dataview/ListView.java
index 4249becd126cf..6cd72abd1359c 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/api/dataview/ListView.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/api/dataview/ListView.java
@@ -31,14 +31,21 @@
 import java.util.Objects;
 
 /**
- * A {@link DataView} that provides {@link List}-like functionality in the accumulator of an {@link
- * AggregateFunction} or {@link TableAggregateFunction} when large amounts of data are expected.
+ * A {@link DataView} that provides {@link List}-like functionality in state entries.
  *
 * <p>A {@link ListView} can be backed by a Java {@link ArrayList} or can leverage Flink's state
- * backends depending on the context in which the aggregate function is used. In many unbounded data
- * scenarios, the {@link ListView} delegates all calls to a {@link ListState} instead of the {@link
+ * backends depending on the context. In many unbounded data scenarios, the {@link ListView}
+ * delegates all calls to a {@link ListState} instead of the {@link ArrayList}.
+ *
+ * <p>For aggregating functions, the view can be used as a field in the accumulator of an {@link
+ * AggregateFunction} or {@link TableAggregateFunction} when large amounts of data are expected.
+ * Aggregate functions might be used at various locations (pre-aggregation, combiners, merging of
+ * window slides, etc.); for some of these locations the data view is not backed by state but by {@link
 * ArrayList}.
 *
+ * <p>For process table functions, the view can be used as a top-level state entry. Data views in
+ * PTFs are always backed by state.
+ *
 * <p>Note: Elements of a {@link ListView} must not be null. For heap-based state backends, {@code
 * hashCode/equals} of the original (i.e. external) class are used. However, the serialization
 * format will use internal data structures.
@@ -57,7 +64,7 @@
 *   public ListView<String> list = new ListView<>();
 *
 *   // or explicit:
- *   // {@literal @}DataTypeHint("ARRAY<STRING>")
+ *   // @DataTypeHint("ARRAY<STRING>")
 *   // public ListView<String> list = new ListView<>();
 *
 *   public long count = 0L;
@@ -65,7 +72,7 @@
 * }
 *
 * public class MyAggregateFunction extends AggregateFunction<String, MyAccumulator> {
 *
- *   {@literal @}Override
+ *   @Override
 *   public MyAccumulator createAccumulator() {
 *     return new MyAccumulator();
 *   }
@@ -75,7 +82,7 @@
 *     accumulator.count++;
 *   }
 *
- *   {@literal @}Override
+ *   @Override
 *   public String getValue(MyAccumulator accumulator) {
 *     // return the count and the joined elements
 *     return accumulator.count + ": " + String.join("|", accumulator.list.get());
 *   }
 *
 * }
 *
@@ -84,9 +91,6 @@
 *
 * }
 *
- * <p>{@code ListView(TypeInformation<T> elementType)} method was deprecated and then removed.
- * Please use a {@link DataTypeHint} instead.
- *
 * @param <T> element type
 */
@PublicEvolving
@@ -152,7 +156,7 @@ public boolean remove(T value) throws Exception {
         return list.remove(value);
     }
 
-    /** Removes all of the elements from this list view. */
+    /** Removes all elements from this list view. */
     @Override
     public void clear() {
         list.clear();
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/api/dataview/MapView.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/api/dataview/MapView.java
index 887d9dfe44ff9..2bfd5d8812700 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/api/dataview/MapView.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/api/dataview/MapView.java
@@ -26,19 +26,27 @@
 import org.apache.flink.table.functions.TableAggregateFunction;
 import org.apache.flink.table.types.DataType;
 
+import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.Iterator;
 import java.util.Map;
 import java.util.Objects;
 
 /**
- * A {@link DataView} that provides {@link Map}-like functionality in the accumulator of an {@link
- * AggregateFunction} or {@link TableAggregateFunction} when large amounts of data are expected.
+ * A {@link DataView} that provides {@link Map}-like functionality in state entries.
 *
 * <p>A {@link MapView} can be backed by a Java {@link HashMap} or can leverage Flink's state
- * backends depending on the context in which the aggregate function is used. In many unbounded data
- * scenarios, the {@link MapView} delegates all calls to a {@link MapState} instead of the {@link
- * HashMap}.
+ * backends depending on the context. In many unbounded data scenarios, the {@link MapView}
+ * delegates all calls to a {@link MapState} instead of the {@link HashMap}.
+ *
+ * <p>For aggregating functions, the view can be used as a field in the accumulator of an {@link
+ * AggregateFunction} or {@link TableAggregateFunction} when large amounts of data are expected.
+ * Aggregate functions might be used at various locations (pre-aggregation, combiners, merging of
+ * window slides, etc.); for some of these locations the data view is not backed by state but by {@link
+ * ArrayList}.
+ *
+ * <p>For process table functions, the view can be used as a top-level state entry. Data views in
+ * PTFs are always backed by state.
 *
 * <p>Note: Keys of a {@link MapView} must not be null. Nulls in values are supported. For
 * heap-based state backends, {@code hashCode/equals} of the original (i.e. external) class are
 * used. However, the serialization format will use internal data structures.
@@ -58,7 +66,7 @@
 *   public MapView<String, Integer> map = new MapView<>();
 *
 *   // or explicit:
- *   // {@literal @}DataTypeHint("MAP<STRING, INT>")
+ *   // @DataTypeHint("MAP<STRING, INT>")
 *   // public MapView<String, Integer> map = new MapView<>();
 *
 *   public long count;
@@ -66,7 +74,7 @@
 * }
 *
 * public class MyAggregateFunction extends AggregateFunction<Long, MyAccumulator> {
 *
- *   {@literal @}Override
+ *   @Override
 *   public MyAccumulator createAccumulator() {
 *     return new MyAccumulator();
 *   }
@@ -78,7 +86,7 @@
 *     }
 *   }
 *
- *   {@literal @}Override
+ *   @Override
 *   public Long getValue(MyAccumulator accumulator) {
 *     return accumulator.count;
 *   }
 *
 * }
 *
@@ -86,9 +94,6 @@
 *
 * }
 *
- * <p>{@code MapView(TypeInformation<K> keyType, TypeInformation<V> valueType)} method was
- * deprecated and removed. Please use a {@link DataTypeHint} instead.
- *
 * @param <K> key type
 * @param <V> value type
 */
@@ -119,8 +124,9 @@ public void setMap(Map<K, V> map) {
 
     /**
      * Return the value for the specified key or {@code null} if the key is not in the map view.
      *
-     * @param key The look up key.
-     * @return The value for the specified key.
+     * @param key The key whose associated value is to be returned
+     * @return The value to which the specified key is mapped, or {@code null} if this map contains
+     *     no mapping for the key
      * @throws Exception Thrown if the system cannot get data.
      */
     public V get(K key) throws Exception {
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/ProcessTableFunction.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/ProcessTableFunction.java
index 2bcfeb9befc35..21c92cbfa4c08 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/ProcessTableFunction.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/ProcessTableFunction.java
@@ -24,6 +24,8 @@
 import org.apache.flink.table.annotation.DataTypeHint;
 import org.apache.flink.table.annotation.FunctionHint;
 import org.apache.flink.table.annotation.StateHint;
+import org.apache.flink.table.api.dataview.ListView;
+import org.apache.flink.table.api.dataview.MapView;
 import org.apache.flink.table.catalog.DataTypeFactory;
 import org.apache.flink.table.types.extraction.TypeInferenceExtractor;
 import org.apache.flink.table.types.inference.TypeInference;
@@ -290,6 +292,54 @@
 * }
 * }
 *
+ * <h1>Large State</h1>
+ *
+ * <p>Flink's state backends provide different types of state to efficiently handle large state.
+ *
+ * <p>Currently, PTFs support three types of state:
+ *
+ * <ul>
+ *   <li>Value state: Represents a single value.
+ *   <li>List state: Represents a list of values, supporting operations like appending,
+ *       removing, and iterating.
+ *   <li>Map state: Represents a map of key-value pairs for efficient lookups, modifications,
+ *       and removal of individual entries.
+ * </ul>
+ *
+ * <p>By default, state entries in a PTF are represented as value state. This means that every state
+ * entry is fully read from the state backend when the evaluation method is called, and the value is
+ * written back to the state backend once the evaluation method finishes.
+ *
+ * <p>To optimize state access and avoid unnecessary (de)serialization, state entries can be
+ * declared as {@link ListView} or {@link MapView}. These provide direct views to the underlying
+ * Flink state backend.
+ *
+ * <p>For example, when using a {@link MapView}, accessing a value via {@link MapView#get(Object)}
+ * will only deserialize the value associated with the specified key. This allows for efficient
+ * access to individual entries without needing to load the entire map. This approach is
+ * particularly useful when the map does not fit entirely into memory.
+ *
+ * <p>State TTL is applied individually to each entry in a list or map, allowing for fine-grained
+ * expiration control over state elements.
+ *
+ * <pre>{@code
+ * // Function that uses a map view for storing a large map for an event history per user
+ * class HistoryFunction extends ProcessTableFunction<String> {
+ *   public void eval(@StateHint MapView<String, Integer> largeMemory, @ArgumentHint(TABLE_AS_SET) Row input) throws Exception {
+ *     String eventId = input.getFieldAs("eventId");
+ *     Integer count = largeMemory.get(eventId);
+ *     if (count == null) {
+ *       largeMemory.put(eventId, 1);
+ *     } else {
+ *       if (count > 1000) {
+ *         collect("Anomaly detected: " + eventId);
+ *       }
+ *       largeMemory.put(eventId, count + 1);
+ *     }
+ *   }
+ * }
+ * }</pre>
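+ *
+ * <p>For illustration, assuming the function above is registered under the name {@code
+ * HistoryFunction}, it could be invoked with {@code SELECT * FROM HistoryFunction(input => TABLE
+ * t PARTITION BY userId)}.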
+ *
 * <h1>Time and Timers</h1>
 *
 * <p>A PTF supports event time natively. Time-based services are available via {@link
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecProcessTableFunction.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecProcessTableFunction.java
index 3eea834a3ab3f..13e13b5a80827 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecProcessTableFunction.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecProcessTableFunction.java
@@ -206,7 +206,8 @@ protected Transformation<RowData> translateToPlanInternal(
                         .collect(Collectors.toList());
         final GeneratedHashFunction[] stateHashCode =
                 runtimeStateInfos.stream()
-                        .map(RuntimeStateInfo::getType)
+                        .map(RuntimeStateInfo::getDataType)
+                        .map(DataType::getLogicalType)
                         .map(
                                 t ->
                                         HashCodeGenerator.generateRowHash(
@@ -217,7 +218,8 @@
                         .toArray(GeneratedHashFunction[]::new);
         final GeneratedRecordEqualiser[] stateEquals =
                 runtimeStateInfos.stream()
-                        .map(RuntimeStateInfo::getType)
+                        .map(RuntimeStateInfo::getDataType)
+                        .map(DataType::getLogicalType)
                         .map(t -> EqualiserCodeGenerator.generateRowEquals(ctx, t, "StateEquals"))
                         .toArray(GeneratedRecordEqualiser[]::new);
 
@@ -315,7 +317,7 @@ private static RuntimeStateInfo createRuntimeStateInfo(
             String name, StateInfo stateInfo, ExecNodeConfig config) {
         return new RuntimeStateInfo(
                 name,
-                stateInfo.getDataType().getLogicalType(),
+                stateInfo.getDataType(),
                 deriveStateTimeToLive(
                         stateInfo.getTimeToLive().orElse(null), config.getStateRetentionTime()));
     }
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ProcessTableRunnerGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ProcessTableRunnerGenerator.scala
index c5064961be4c2..42914555c81cb 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ProcessTableRunnerGenerator.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ProcessTableRunnerGenerator.scala
@@ -18,7 +18,9 @@
 package org.apache.flink.table.planner.codegen
 
 import org.apache.flink.api.common.functions.OpenContext
+import org.apache.flink.api.common.state.{ListState, MapState}
 import org.apache.flink.table.api.ValidationException
+import org.apache.flink.table.api.dataview.{DataView, ListView, MapView}
 import org.apache.flink.table.connector.ChangelogMode
 import org.apache.flink.table.data.RowData
 import org.apache.flink.table.data.conversion.RowRowConverter
@@ -33,17 +35,20 @@
 import org.apache.flink.table.planner.codegen.calls.BridgingFunctionGenUtil.{ver
 import org.apache.flink.table.planner.delegation.PlannerBase
 import org.apache.flink.table.planner.functions.bridging.BridgingSqlFunction
 import org.apache.flink.table.planner.functions.inference.OperatorBindingCallContext
-import org.apache.flink.table.planner.plan.utils.RexLiteralUtil
 import org.apache.flink.table.planner.utils.JavaScalaConversionUtil.toScala
+import org.apache.flink.table.runtime.dataview.DataViewUtils
+import org.apache.flink.table.runtime.dataview.StateListView.KeyedStateListView
+import org.apache.flink.table.runtime.dataview.StateMapView.KeyedStateMapViewWithKeysNotNull
 import org.apache.flink.table.runtime.generated.{GeneratedProcessTableRunner, ProcessTableRunner}
 import org.apache.flink.table.types.DataType
 import org.apache.flink.table.types.extraction.ExtractionUtils
 import org.apache.flink.table.types.inference.{StaticArgument, StaticArgumentTrait, SystemTypeInference, TypeInferenceUtil}
 import org.apache.flink.table.types.inference.TypeInferenceUtil.StateInfo
+import org.apache.flink.table.types.logical.LogicalType
 import org.apache.flink.table.types.logical.utils.LogicalTypeChecks
 import org.apache.flink.types.Row
 
-import org.apache.calcite.rex.{RexCall, RexCallBinding, RexLiteral, RexNode, RexUtil}
+import org.apache.calcite.rex.{RexCall, RexCallBinding, RexNode, RexUtil}
 import org.apache.calcite.sql.SqlKind
 
 import java.util
@@ -116,16 +121,19 @@ object ProcessTableRunnerGenerator {
     val stateDataTypes = stateInfos.asScala.values.map(_.getDataType).toSeq
     stateDataTypes.foreach(ExtractionUtils.checkStateDataType)
 
-    val stateToFunctionTerm = "stateToFunction"
+    val stateHandlesTerm = "stateHandles"
+    val valueStateToFunctionTerm = "valueStateToFunction"
     val stateClearedTerm = "stateCleared"
-    val stateFromFunctionTerm = "stateFromFunction"
-    val externalStateOperands = generateStateToFunction(ctx, stateDataTypes, stateToFunctionTerm)
+    val valueStateFromFunctionTerm = "valueStateFromFunction"
+    val stateEntries = stateInfos.asScala.values.zipWithIndex.toSeq
+    val externalStateOperands =
+      generateStateToFunction(ctx, stateEntries, stateHandlesTerm, valueStateToFunctionTerm)
     val stateFromFunctionCode = generateStateFromFunction(
       ctx,
-      stateDataTypes,
+      stateEntries,
      externalStateOperands,
      stateClearedTerm,
-      stateFromFunctionTerm)
+      valueStateFromFunctionTerm)
 
     // Generate result collector
     val resultCollectorTerm =
@@ -238,62 +246,118 @@ object ProcessTableRunnerGenerator {
 
   private def generateStateToFunction(
       ctx: CodeGeneratorContext,
-      stateDataTypes: Seq[DataType],
-      stateToFunctionTerm: String): Seq[GeneratedExpression] = {
-    stateDataTypes.zipWithIndex
-      .map {
-        case (stateDataType, pos) =>
-          val stateEntryTerm = s"$stateToFunctionTerm[$pos]"
-          val externalStateTypeTerm = typeTerm(stateDataType.getConversionClass)
-          val externalStateTerm = newName(ctx, "externalState")
-
-          val converterCode = genToExternalConverter(ctx, stateDataType, stateEntryTerm)
-
-          val constructorCode = stateDataType.getConversionClass match {
-            case rowType if rowType == classOf[Row] =>
-              // This allows us to retrieve the converter term that has been generated
-              // in genToExternalConverter(). The converter is able to created named positions
-              // for row fields.
- val converterTerm = ctx.addReusableConverter(stateDataType) - s"((${className[RowRowConverter]}) $converterTerm).createEmptyRow()" - case rowType if rowType == classOf[RowData] => - val fieldCount = LogicalTypeChecks.getFieldCount(stateDataType.getLogicalType) - s"new $GENERIC_ROW($fieldCount)" - case structuredType @ _ => s"new ${className(structuredType)}()" - } + stateEntries: Seq[(StateInfo, Int)], + stateHandlesTerm: String, + valueStateToFunctionTerm: String): Seq[GeneratedExpression] = { + stateEntries.map { + case (stateInfo, pos) => + val stateDataType = stateInfo.getDataType + val stateType = stateDataType.getLogicalType + val externalStateTypeTerm = typeTerm(stateDataType.getConversionClass) + val externalStateTerm = newName(ctx, "externalState") + + val externalStateCode = if (DataViewUtils.isDataView(stateType, classOf[DataView])) { + generateDataViewStateToFunction( + ctx, + stateHandlesTerm, + pos, + stateType, + externalStateTypeTerm, + externalStateTerm) + NO_CODE + } else { + DataViewUtils.checkForInvalidDataViews(stateType) + generateValueStateToFunction( + ctx, + valueStateToFunctionTerm, + pos, + externalStateTypeTerm, + externalStateTerm, + stateDataType) + } - val externalStateCode = - s""" - |final $externalStateTypeTerm $externalStateTerm; - |if ($stateEntryTerm == null) { - | $externalStateTerm = $constructorCode; - |} else { - | $externalStateTerm = $converterCode; - |} - |""".stripMargin + GeneratedExpression(s"$externalStateTerm", NEVER_NULL, externalStateCode, stateType) + } + } - GeneratedExpression( - s"$externalStateTerm", - NEVER_NULL, - externalStateCode, - stateDataType.getLogicalType) + private def generateDataViewStateToFunction( + ctx: CodeGeneratorContext, + stateHandlesTerm: String, + pos: Int, + stateType: LogicalType, + externalStateTypeTerm: String, + externalStateTerm: String): Unit = { + ctx.addReusableMember(s"private $externalStateTypeTerm $externalStateTerm;") + + val (constructor, stateHandleTypeTerm) = + if (DataViewUtils.isDataView(stateType, classOf[ListView[_]])) { + (className[KeyedStateListView[_, _]], className[ListState[_]]) + } else if (DataViewUtils.isDataView(stateType, classOf[MapView[_, _]])) { + (className[KeyedStateMapViewWithKeysNotNull[_, _, _]], className[MapState[_, _]]) } + + val openCode = + s""" + |$externalStateTerm = new $constructor(($stateHandleTypeTerm) $stateHandlesTerm[$pos]); + """.stripMargin + ctx.addReusableOpenStatement(openCode) + } + + private def generateValueStateToFunction( + ctx: CodeGeneratorContext, + valueStateToFunctionTerm: String, + pos: Int, + externalStateTypeTerm: String, + externalStateTerm: String, + stateDataType: DataType): String = { + val stateEntryTerm = s"$valueStateToFunctionTerm[$pos]" + val converterCode = genToExternalConverter(ctx, stateDataType, stateEntryTerm) + + val constructorCode = stateDataType.getConversionClass match { + case rowType if rowType == classOf[Row] => + // This allows us to retrieve the converter term that has been generated + // in genToExternalConverter(). The converter is able to create named positions + // for row fields. 
+ val converterTerm = ctx.addReusableConverter(stateDataType) + s"((${className[RowRowConverter]}) $converterTerm).createEmptyRow()" + case rowType if rowType == classOf[RowData] => + val fieldCount = LogicalTypeChecks.getFieldCount(stateDataType.getLogicalType) + s"new $GENERIC_ROW($fieldCount)" + case structuredType @ _ => s"new ${className(structuredType)}()" + } + + s""" + |final $externalStateTypeTerm $externalStateTerm; + |if ($stateEntryTerm == null) { + | $externalStateTerm = $constructorCode; + |} else { + | $externalStateTerm = $converterCode; + |} + |""".stripMargin } private def generateStateFromFunction( ctx: CodeGeneratorContext, - stateDataTypes: Seq[DataType], + stateEntries: Seq[(StateInfo, Int)], externalStateOperands: Seq[GeneratedExpression], stateClearedTerm: String, stateFromFunctionTerm: String): String = { - stateDataTypes.zipWithIndex + stateEntries .map { - case (stateDataType, pos) => - val stateEntryTerm = s"$stateFromFunctionTerm[$pos]" - val externalStateOperandTerm = externalStateOperands(pos).resultTerm - s"$stateEntryTerm = $stateClearedTerm[$pos] ? null : " + - s"${genToInternalConverter(ctx, stateDataType)(externalStateOperandTerm)};" + case (stateInfo, pos) => + val stateDataType = stateInfo.getDataType + val stateType = stateDataType.getLogicalType + + if (DataViewUtils.isDataView(stateType, classOf[DataView])) { + NO_CODE + } else { + val stateEntryTerm = s"$stateFromFunctionTerm[$pos]" + val externalStateOperandTerm = externalStateOperands(pos).resultTerm + s"$stateEntryTerm = $stateClearedTerm[$pos] ? null : " + + s"${genToInternalConverter(ctx, stateDataType)(externalStateOperandTerm)};" + } } + .filter(c => c != NO_CODE) .mkString("\n") } diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala index 21925ddc5d53e..8c9404de7b972 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala @@ -38,11 +38,10 @@ import org.apache.flink.table.planner.plan.`trait`.{ModifyKindSetTrait, ModifyKi import org.apache.flink.table.planner.plan.logical.{HoppingWindowSpec, WindowSpec} import org.apache.flink.table.planner.plan.metadata.FlinkRelMetadataQuery import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalRel -import org.apache.flink.table.planner.typeutils.DataViewUtils import org.apache.flink.table.planner.typeutils.LegacyDataViewUtils.useNullSerializerForStateViewFieldsFromAccType import org.apache.flink.table.planner.utils.JavaScalaConversionUtil.toScala import org.apache.flink.table.planner.utils.ShortcutUtils.unwrapTypeFactory -import org.apache.flink.table.runtime.dataview.DataViewSpec +import org.apache.flink.table.runtime.dataview.{DataViewSpec, DataViewUtils} import org.apache.flink.table.runtime.functions.aggregate.BuiltInAggregateFunction import org.apache.flink.table.runtime.groupwindow._ import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.fromDataTypeToLogicalType diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/aggregation.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/aggregation.scala index 76aa5b28085ff..1001e38e645bd 100644 --- 
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/aggregation.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/aggregation.scala @@ -18,8 +18,8 @@ package org.apache.flink.table.planner.plan.utils import org.apache.flink.table.functions.UserDefinedFunction -import org.apache.flink.table.planner.typeutils.DataViewUtils.DistinctViewSpec import org.apache.flink.table.runtime.dataview.DataViewSpec +import org.apache.flink.table.runtime.dataview.DataViewUtils.DistinctViewSpec import org.apache.flink.table.types.DataType import org.apache.calcite.rel.core.AggregateCall diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/serde/LogicalTypeJsonSerdeTest.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/serde/LogicalTypeJsonSerdeTest.java index d571d75388532..e2de3f8dba1f8 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/serde/LogicalTypeJsonSerdeTest.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/serde/LogicalTypeJsonSerdeTest.java @@ -28,7 +28,7 @@ import org.apache.flink.table.catalog.CatalogManager; import org.apache.flink.table.catalog.ObjectIdentifier; import org.apache.flink.table.planner.plan.nodes.exec.serde.DataTypeJsonSerdeTest.PojoClass; -import org.apache.flink.table.planner.typeutils.DataViewUtils; +import org.apache.flink.table.runtime.dataview.DataViewUtils; import org.apache.flink.table.runtime.typeutils.ExternalSerializer; import org.apache.flink.table.types.DataType; import org.apache.flink.table.types.logical.ArrayType; diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/ProcessTableFunctionSemanticTests.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/ProcessTableFunctionSemanticTests.java index ef35001d0deef..8feb59a16789b 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/ProcessTableFunctionSemanticTests.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/ProcessTableFunctionSemanticTests.java @@ -72,6 +72,8 @@ public List programs() { ProcessTableFunctionTestPrograms.PROCESS_CHAINED_TIME_TABLE_API, ProcessTableFunctionTestPrograms.PROCESS_INVALID_TABLE_AS_ROW_TIMERS, ProcessTableFunctionTestPrograms.PROCESS_INVALID_PASS_THROUGH_TIMERS, - ProcessTableFunctionTestPrograms.PROCESS_INVALID_UPDATING_TIMERS); + ProcessTableFunctionTestPrograms.PROCESS_INVALID_UPDATING_TIMERS, + ProcessTableFunctionTestPrograms.PROCESS_LIST_STATE, + ProcessTableFunctionTestPrograms.PROCESS_MAP_STATE); } } diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/ProcessTableFunctionTestPrograms.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/ProcessTableFunctionTestPrograms.java index 4bf600228c98a..a2396a80e7694 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/ProcessTableFunctionTestPrograms.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/ProcessTableFunctionTestPrograms.java @@ -32,6 +32,8 @@ import 
org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.InvalidTableAsRowTimersFunction; import org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.InvalidUpdatingTimersFunction; import org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.LateTimersFunction; +import org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.ListStateFunction; +import org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.MapStateFunction; import org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.MultiStateFunction; import org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.NamedTimersFunction; import org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.OptionalOnTimeFunction; @@ -960,4 +962,38 @@ public class ProcessTableFunctionTestPrograms { TableRuntimeException.class, "Timers are not supported in the current PTF declaration.") .build(); + + public static final TableTestProgram PROCESS_LIST_STATE = + TableTestProgram.of("process-list-state", "list view state entry") + .setupTemporarySystemFunction("f", ListStateFunction.class) + .setupSql(MULTI_VALUES) + .setupTableSink( + SinkTestStep.newBuilder("sink") + .addSchema(KEYED_BASE_SINK_SCHEMA) + .consumedValues( + "+I[Bob, {[], KeyedStateListView, +I[Bob, 12]}]", + "+I[Alice, {[], KeyedStateListView, +I[Alice, 42]}]", + "+I[Bob, {[0], KeyedStateListView, +I[Bob, 99]}]", + "+I[Bob, {[0, 1], KeyedStateListView, +I[Bob, 100]}]", + "+I[Alice, {[0], KeyedStateListView, +I[Alice, 400]}]") + .build()) + .runSql("INSERT INTO sink SELECT * FROM f(r => TABLE t PARTITION BY name)") + .build(); + + public static final TableTestProgram PROCESS_MAP_STATE = + TableTestProgram.of("process-map-state", "map view state entry") + .setupTemporarySystemFunction("f", MapStateFunction.class) + .setupSql(MULTI_VALUES) + .setupTableSink( + SinkTestStep.newBuilder("sink") + .addSchema(KEYED_BASE_SINK_SCHEMA) + .consumedValues( + "+I[Bob, {{}, KeyedStateMapViewWithKeysNotNull, +I[Bob, 12]}]", + "+I[Alice, {{}, KeyedStateMapViewWithKeysNotNull, +I[Alice, 42]}]", + "+I[Bob, {{Bob=2, nullValue=null, oldBob=1}, KeyedStateMapViewWithKeysNotNull, +I[Bob, 99]}]", + "+I[Bob, {{}, KeyedStateMapViewWithKeysNotNull, +I[Bob, 100]}]", + "+I[Alice, {{Alice=2, nullValue=null, oldAlice=1}, KeyedStateMapViewWithKeysNotNull, +I[Alice, 400]}]") + .build()) + .runSql("INSERT INTO sink SELECT * FROM f(r => TABLE t PARTITION BY name)") + .build(); } diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/ProcessTableFunctionTestUtils.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/ProcessTableFunctionTestUtils.java index 036caa43571ea..4e526d07fd954 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/ProcessTableFunctionTestUtils.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/ProcessTableFunctionTestUtils.java @@ -22,6 +22,9 @@ import org.apache.flink.table.annotation.ArgumentTrait; import org.apache.flink.table.annotation.DataTypeHint; import org.apache.flink.table.annotation.StateHint; +import org.apache.flink.table.api.TableRuntimeException; +import org.apache.flink.table.api.dataview.ListView; +import 
org.apache.flink.table.api.dataview.MapView;
 import org.apache.flink.table.functions.ProcessTableFunction;
 import org.apache.flink.table.functions.ScalarFunction;
 import org.apache.flink.table.functions.TableSemantics;
@@ -32,6 +35,11 @@
 import java.time.Instant;
 import java.time.LocalDateTime;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Objects;
+import java.util.stream.Collectors;
 
 import static org.apache.flink.table.annotation.ArgumentTrait.OPTIONAL_PARTITION_BY;
 import static org.apache.flink.table.annotation.ArgumentTrait.PASS_COLUMNS_THROUGH;
@@ -39,6 +47,7 @@
 import static org.apache.flink.table.annotation.ArgumentTrait.SUPPORT_UPDATES;
 import static org.apache.flink.table.annotation.ArgumentTrait.TABLE_AS_ROW;
 import static org.apache.flink.table.annotation.ArgumentTrait.TABLE_AS_SET;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
 
 /** Testing functions for {@link ProcessTableFunction}. */
 @SuppressWarnings("unused")
@@ -344,9 +353,9 @@ public void eval(
             collect(
                     String.format(
                             "s1=%s, s2=%s, s3=%s",
-                            internalContext.getValueStateDescriptor("s1").getTtlConfig(),
-                            internalContext.getValueStateDescriptor("s2").getTtlConfig(),
-                            internalContext.getValueStateDescriptor("s3").getTtlConfig()));
+                            internalContext.getStateDescriptor("s1").getTtlConfig(),
+                            internalContext.getStateDescriptor("s2").getTtlConfig(),
+                            internalContext.getStateDescriptor("s3").getTtlConfig()));
             s0.setField("emitted", true);
         }
     }
@@ -605,6 +614,80 @@ public void eval(
         }
     }
 
+    /** Testing function. */
+    public static class ListStateFunction extends TestProcessTableFunctionBase {
+        public void eval(
+                Context ctx,
+                @StateHint ListView<String> s,
+                @ArgumentHint({TABLE_AS_SET, OPTIONAL_PARTITION_BY}) Row r)
+                throws Exception {
+            collectObjects(s.getList(), s.getClass().getSimpleName(), r);
+
+            // get
+            int count = s.getList().size();
+
+            // create
+            s.add(String.valueOf(count));
+
+            // null behavior
+            assertThatThrownBy(() -> s.add(null))
+                    .isInstanceOf(TableRuntimeException.class)
+                    .hasMessageContaining("List views don't support null values.");
+            assertThatThrownBy(() -> s.addAll(Arrays.asList("item0", null)))
+                    .isInstanceOf(TableRuntimeException.class)
+                    .hasMessageContaining("List views don't support null values.");
+
+            // clear
+            if (count == 2) {
+                ctx.clearState("s");
+            }
+        }
+    }
+
+    /** Testing function. */
+    public static class MapStateFunction extends TestProcessTableFunctionBase {
+        public void eval(
+                Context ctx,
+                @StateHint MapView<String, Integer> s,
+                @ArgumentHint({TABLE_AS_SET, OPTIONAL_PARTITION_BY}) Row r)
+                throws Exception {
+            final String viewToString =
+                    s.getMap().entrySet().stream()
+                            .map(Objects::toString)
+                            .sorted()
+                            .collect(Collectors.joining(", ", "{", "}"));
+            collectObjects(viewToString, s.getClass().getSimpleName(), r);
+
+            // get
+            final String name = r.getFieldAs("name");
+            int count = 1;
+            if (s.contains(name)) {
+                count = s.get(name);
+            }
+
+            // create
+            s.put("old" + name, count);
+            s.put(name, count + 1);
+
+            // null behavior
+            assertThatThrownBy(() -> s.put(null, 42))
+                    .isInstanceOf(TableRuntimeException.class)
+                    .hasMessageContaining("Map views don't support null keys.");
+            final Map<String, Integer> mapWithNull = new HashMap<>();
+            mapWithNull.put("key", 42);
+            mapWithNull.put(null, 42);
+            assertThatThrownBy(() -> s.putAll(mapWithNull))
+                    .isInstanceOf(TableRuntimeException.class)
+                    .hasMessageContaining("Map views don't support null keys.");
+            s.put("nullValue", null);
+
+            // clear
+            if (count == 2) {
+                ctx.clearState("s");
+            }
+        }
+    }
+
     // --------------------------------------------------------------------------------------------
     // Helpers
     // --------------------------------------------------------------------------------------------
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/typeutils/DataViewUtils.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/dataview/DataViewUtils.java
similarity index 81%
rename from flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/typeutils/DataViewUtils.java
rename to flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/dataview/DataViewUtils.java
index db8d29abc48d7..94fa2a1c2b50c 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/typeutils/DataViewUtils.java
+++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/dataview/DataViewUtils.java
@@ -16,20 +16,17 @@
  * limitations under the License.
  */
 
-package org.apache.flink.table.planner.typeutils;
+package org.apache.flink.table.runtime.dataview;
 
 import org.apache.flink.annotation.Internal;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.table.api.DataTypes;
-import org.apache.flink.table.api.TableException;
+import org.apache.flink.table.api.ValidationException;
 import org.apache.flink.table.api.dataview.DataView;
 import org.apache.flink.table.api.dataview.ListView;
 import org.apache.flink.table.api.dataview.MapView;
 import org.apache.flink.table.data.binary.LazyBinaryFormat;
 import org.apache.flink.table.dataview.NullSerializer;
-import org.apache.flink.table.runtime.dataview.DataViewSpec;
-import org.apache.flink.table.runtime.dataview.ListViewSpec;
-import org.apache.flink.table.runtime.dataview.MapViewSpec;
 import org.apache.flink.table.runtime.typeutils.ExternalSerializer;
 import org.apache.flink.table.types.DataType;
 import org.apache.flink.table.types.inference.TypeTransformation;
@@ -51,13 +48,40 @@
 /**
  * Utilities to deal with {@link DataView}s.
  *
- * <p>A {@link DataView} is either represented as a regular {@link StructuredType} or as a {@link
- * RawType} that serializes to {@code null} when backed by a state backend. In the latter case, a
- * {@link DataViewSpec} contains all information necessary to store and retrieve data from state.
+ * <p>For aggregating functions: A {@link DataView} is a field that is either represented as a
+ * regular {@link StructuredType} or as a {@link RawType} that serializes to {@code null} when
+ * backed by a state backend. In the latter case, a {@link DataViewSpec} contains all information
+ * necessary to store and retrieve data from state.
+ *
+ * <p>For process table functions: A {@link DataView} is a top-level instance that is always backed
+ * by a state backend.
 */
@Internal
public final class DataViewUtils {
 
+    /** Returns whether the given {@link LogicalType} qualifies as a {@link DataView}. */
+    public static boolean isDataView(LogicalType viewType, Class<? extends DataView> viewClass) {
+        final boolean isDataView =
+                viewType.is(STRUCTURED_TYPE)
+                        && ((StructuredType) viewType)
+                                .getImplementationClass()
+                                .map(viewClass::isAssignableFrom)
+                                .orElse(false);
+        if (!isDataView) {
+            return false;
+        }
+        viewType.getChildren().forEach(DataViewUtils::checkForInvalidDataViews);
+        return true;
+    }
+
+    /** Checks that the given type and its children do not contain data views. */
+    public static void checkForInvalidDataViews(LogicalType type) {
+        if (hasNested(type, t -> isDataView(t, DataView.class))) {
+            throw new ValidationException(
+                    "Data views are not supported at the declared location. Given type: " + type);
+        }
+    }
+
     /** Searches for data views in the data type of an accumulator and extracts them. */
     public static List<DataViewSpec> extractDataViews(int aggIndex, DataType accumulatorDataType) {
         final LogicalType accumulatorType = accumulatorDataType.getLogicalType();
@@ -85,11 +109,6 @@ public static List<DataViewSpec> extractDataViews(int aggIndex, DataType accumul
                                 fieldDataType.getChildren().get(0),
                                 false));
             }
-            if (fieldType.getChildren().stream()
-                    .anyMatch(c -> hasNested(c, t -> isDataView(t, DataView.class)))) {
-                throw new TableException(
-                        "Data views are only supported in the first level of a composite accumulator type.");
-            }
         }
         return specs;
     }
@@ -138,14 +157,6 @@ private static String createStateId(int fieldIndex, String fieldName) {
         return "agg" + fieldIndex + "$" + fieldName;
     }
 
-    private static boolean isDataView(LogicalType t, Class<? extends DataView> viewClass) {
-        return t.is(STRUCTURED_TYPE)
-                && ((StructuredType) t)
-                        .getImplementationClass()
-                        .map(viewClass::isAssignableFrom)
-                        .orElse(false);
-    }
-
     // --------------------------------------------------------------------------------------------
 
     private static class DataViewsTransformation implements TypeTransformation {
diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/dataview/StateListView.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/dataview/StateListView.java
index 8b2b4ac95ee0c..8aa0d0149c332 100644
--- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/dataview/StateListView.java
+++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/dataview/StateListView.java
@@ -21,6 +21,7 @@
 import org.apache.flink.annotation.Internal;
 import org.apache.flink.api.common.state.ListState;
 import org.apache.flink.runtime.state.internal.InternalListState;
+import org.apache.flink.table.api.TableRuntimeException;
 import org.apache.flink.table.api.dataview.ListView;
 
 import java.util.ArrayList;
@@ -44,7 +45,7 @@ public List<EE> getList() {
         try {
             get().forEach(list::add);
         } catch (Exception e) {
-            throw new RuntimeException(e);
+            throw new RuntimeException("Unable to collect list.", e);
         }
         return list;
     }
@@ -54,8 +55,10 @@ public void setList(List<EE> list) {
         clear();
         try {
             addAll(list);
+        } catch (TableRuntimeException e) {
+            throw e;
         } catch (Exception e) {
-            throw new RuntimeException(e);
+            throw new RuntimeException("Unable to replace list.", e);
         }
     }
 
@@ -67,16 +70,19 @@ public Iterable<EE> get() throws Exception {
 
     @Override
     public void add(EE value) throws Exception {
+        checkValue(value);
         getListState().add(value);
    }
 
     @Override
     public void addAll(List<EE> list) throws Exception {
+        checkList(list);
         getListState().addAll(list);
     }
 
     @Override
     public boolean remove(EE value) throws Exception {
+        checkValue(value);
         Iterable<EE> iterable = getListState().get();
         if (iterable == null) {
             // ListState.get() may return null according to the Javadoc.
@@ -152,4 +158,16 @@ protected ListState<EE> getListState() {
             return listState;
         }
     }
+
+    private static void checkValue(Object value) {
+        if (value == null) {
+            throw new TableRuntimeException("List views don't support null values.");
+        }
+    }
+
+    private static void checkList(List<?> list) {
+        if (list.contains(null)) {
+            throw new TableRuntimeException("List views don't support null values.");
+        }
+    }
 }
diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/dataview/StateMapView.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/dataview/StateMapView.java
index e7df0e53c34ab..b0ac947bd8dac 100644
--- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/dataview/StateMapView.java
+++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/dataview/StateMapView.java
@@ -23,6 +23,7 @@
 import org.apache.flink.api.common.state.ValueState;
 import org.apache.flink.runtime.state.internal.InternalMapState;
 import org.apache.flink.runtime.state.internal.InternalValueState;
+import org.apache.flink.table.api.TableRuntimeException;
 import org.apache.flink.table.api.dataview.MapView;
 import org.apache.flink.util.IterableIterator;
 
@@ -48,7 +49,7 @@ public Map<EK, EV> getMap() {
         try {
             entries().forEach(entry -> map.put(entry.getKey(), entry.getValue()));
         } catch (Exception e) {
-            throw new RuntimeException(e);
+            throw new RuntimeException("Unable to collect map.", e);
         }
         return map;
     }
@@ -58,8 +59,10 @@ public void setMap(Map<EK, EV> map) {
         clear();
         try {
             putAll(map);
+        } catch (TableRuntimeException e) {
+            throw e;
         } catch (Exception e) {
-            throw new RuntimeException(e);
+            throw new RuntimeException("Unable to replace map.", e);
         }
     }
 
@@ -81,26 +84,31 @@ private abstract static class StateMapViewWithKeysNotNull<N, EK, EV>
 
         @Override
         public EV get(EK key) throws Exception {
+            checkKey(key);
             return getMapState().get(key);
         }
 
         @Override
         public void put(EK key, EV value) throws Exception {
+            checkKey(key);
             getMapState().put(key, value);
         }
 
         @Override
         public void putAll(Map<EK, EV> map) throws Exception {
+            checkMap(map);
             getMapState().putAll(map);
         }
 
         @Override
         public void remove(EK key) throws Exception {
+            checkKey(key);
             getMapState().remove(key);
         }
 
         @Override
         public boolean contains(EK key) throws Exception {
+            checkKey(key);
             return getMapState().contains(key);
         }
 
@@ -137,6 +145,18 @@ public boolean isEmpty() throws Exception {
         public void clear() {
             getMapState().clear();
         }
+
+        private static void checkKey(Object key) {
+            if (key == null) {
+                throw new TableRuntimeException("Map views don't support null keys.");
+            }
+        }
+
+        private static void checkMap(Map<?, ?> map) {
+            if (map.containsKey(null)) {
+                throw new TableRuntimeException("Map views don't support null keys.");
+            }
+        }
     }
 
     /**
diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/generated/ProcessTableRunner.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/generated/ProcessTableRunner.java
index 3136cefbcad83..f285f0c3d461a 100644
--- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/generated/ProcessTableRunner.java
+++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/generated/ProcessTableRunner.java
@@ -20,6 +20,7 @@
 import org.apache.flink.annotation.Internal;
 import org.apache.flink.api.common.functions.AbstractRichFunction;
+import org.apache.flink.api.common.state.State;
 import org.apache.flink.api.common.state.ValueState;
 import org.apache.flink.table.data.RowData;
 import org.apache.flink.table.data.StringData;
@@ -45,7 +46,7 @@ public abstract class ProcessTableRunner extends AbstractRichFunction {
 
     // Constant references after initialization
 
-    private ValueState<RowData>[] stateHandles;
+    protected State[] stateHandles;
     private HashFunction[] stateHashCode;
     private RecordEqualiser[] stateEquals;
     private boolean emitRowtime;
@@ -68,7 +69,7 @@ public abstract class ProcessTableRunner extends AbstractRichFunction {
     private @Nullable StringData timerName;
 
     /** State entries to be converted into external data structure; null if state is empty. */
-    protected RowData[] stateToFunction;
+    protected RowData[] valueStateToFunction;
 
     /**
      * Reference to whether the state has been cleared within the function; if yes, a conversion
@@ -77,10 +78,10 @@ public abstract class ProcessTableRunner extends AbstractRichFunction {
     protected boolean[] stateCleared;
 
     /** State ready for persistence; null if {@link #stateCleared} was true during conversion. */
-    protected RowData[] stateFromFunction;
+    protected RowData[] valueStateFromFunction;
 
     public void initialize(
-            ValueState<RowData>[] stateHandles,
+            State[] stateHandles,
             HashFunction[] stateHashCode,
             RecordEqualiser[] stateEquals,
             boolean emitRowtime,
@@ -98,9 +99,9 @@ public void initialize(
         this.runnerOnTimerContext = runnerOnTimerContext;
         this.evalCollector = evalCollector;
         this.onTimerCollector = onTimerCollector;
-        this.stateToFunction = new RowData[stateHandles.length];
+        this.valueStateToFunction = new RowData[stateHandles.length];
         this.stateCleared = new boolean[stateHandles.length];
-        this.stateFromFunction = new RowData[stateHandles.length];
+        this.valueStateFromFunction = new RowData[stateHandles.length];
     }
 
     public void ingestTableEvent(int pos, RowData row, int timeColumn) {
@@ -183,34 +184,52 @@ private void processMethod(RunnableWithException method) throws Exception {
         }
     }
 
+    @SuppressWarnings("unchecked")
     private void moveStateToFunction() throws IOException {
         Arrays.fill(stateCleared, false);
         for (int i = 0; i < stateHandles.length; i++) {
-            final RowData value = stateHandles[i].value();
-            stateToFunction[i] = value;
+            final State stateHandle = stateHandles[i];
+            if (!(stateHandle instanceof ValueState)) {
+                continue;
+            }
+            final ValueState<RowData> valueState = (ValueState<RowData>) stateHandle;
+            final RowData value = valueState.value();
+            valueStateToFunction[i] = value;
         }
     }
 
+    @SuppressWarnings("unchecked")
     private void moveStateFromFunction() throws IOException {
         for (int i = 0; i < stateHandles.length; i++) {
-            final RowData fromFunction = stateFromFunction[i];
-            if (fromFunction == null || isEmpty(fromFunction)) {
-                // Reduce state size
-                stateHandles[i].clear();
+            final State stateHandle = stateHandles[i];
+            if (stateHandle instanceof ValueState) {
+                moveValueStateFromFunction((ValueState<RowData>) stateHandle, i);
             } else {
-                final HashFunction hashCode = stateHashCode[i];
-                final RecordEqualiser equals = stateEquals[i];
-                final RowData toFunction = stateToFunction[i];
-                // Reduce state updates by checking if something has changed
-                if (toFunction == null
-                        || hashCode.hashCode(toFunction) != hashCode.hashCode(fromFunction)
-                        || !equals.equals(toFunction, fromFunction)) {
-                    stateHandles[i].update(fromFunction);
+                if (stateCleared[i]) {
+                    stateHandle.clear();
                 }
             }
         }
     }
 
+    private void moveValueStateFromFunction(ValueState<RowData> valueState, int pos)
+            throws IOException {
+        final RowData fromFunction = valueStateFromFunction[pos];
+        if (fromFunction == null || isEmpty(fromFunction)) {
+            valueState.clear();
+        } else {
+            final HashFunction hashCode = stateHashCode[pos];
+            final RecordEqualiser equals = stateEquals[pos];
+            final RowData toFunction = valueStateToFunction[pos];
+            // Reduce state updates by checking if something has changed
+            if (toFunction == null
+                    || hashCode.hashCode(toFunction) != hashCode.hashCode(fromFunction)
+                    || !equals.equals(toFunction, fromFunction)) {
+                valueState.update(fromFunction);
+            }
+        }
+    }
+
     private static boolean isEmpty(RowData row) {
         for (int i = 0; i < row.getArity(); i++) {
             if (!row.isNullAt(i)) {
diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/process/ProcessTableOperator.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/process/ProcessTableOperator.java
index 30ba25d82aa82..cc3feffcd1ef4 100644
--- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/process/ProcessTableOperator.java
+++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/process/ProcessTableOperator.java
@@ -23,10 +23,12 @@
 import org.apache.flink.api.common.functions.DefaultOpenContext;
 import org.apache.flink.api.common.functions.util.FunctionUtils;
 import org.apache.flink.api.common.state.KeyedStateStore;
+import org.apache.flink.api.common.state.ListStateDescriptor;
 import org.apache.flink.api.common.state.MapState;
 import org.apache.flink.api.common.state.MapStateDescriptor;
+import org.apache.flink.api.common.state.State;
+import org.apache.flink.api.common.state.StateDescriptor;
 import org.apache.flink.api.common.state.StateTtlConfig;
-import org.apache.flink.api.common.state.ValueState;
 import org.apache.flink.api.common.state.ValueStateDescriptor;
 import org.apache.flink.api.common.typeutils.base.LongSerializer;
 import org.apache.flink.runtime.state.VoidNamespace;
@@ -40,20 +42,27 @@
 import org.apache.flink.streaming.api.watermark.Watermark;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.table.api.TableRuntimeException;
+import org.apache.flink.table.api.dataview.ListView;
+import org.apache.flink.table.api.dataview.MapView;
 import org.apache.flink.table.data.RowData;
 import org.apache.flink.table.data.StringData;
 import org.apache.flink.table.functions.ProcessTableFunction;
 import org.apache.flink.table.functions.ProcessTableFunction.TimeContext;
 import org.apache.flink.table.functions.TableSemantics;
+import org.apache.flink.table.runtime.dataview.DataViewUtils;
 import org.apache.flink.table.runtime.generated.HashFunction;
 import org.apache.flink.table.runtime.generated.ProcessTableRunner;
 import org.apache.flink.table.runtime.generated.RecordEqualiser;
 import org.apache.flink.table.runtime.operators.process.TimeConverter.InstantTimeConverter;
 import org.apache.flink.table.runtime.operators.process.TimeConverter.LocalDateTimeConverter;
 import org.apache.flink.table.runtime.operators.process.TimeConverter.LongTimeConverter;
+import org.apache.flink.table.runtime.typeutils.ExternalSerializer;
 import org.apache.flink.table.runtime.typeutils.InternalSerializers;
 import org.apache.flink.table.runtime.typeutils.StringDataSerializer;
diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/process/ProcessTableOperator.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/process/ProcessTableOperator.java
index 30ba25d82aa82..cc3feffcd1ef4 100644
--- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/process/ProcessTableOperator.java
+++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/process/ProcessTableOperator.java
@@ -23,10 +23,12 @@
 import org.apache.flink.api.common.functions.DefaultOpenContext;
 import org.apache.flink.api.common.functions.util.FunctionUtils;
 import org.apache.flink.api.common.state.KeyedStateStore;
+import org.apache.flink.api.common.state.ListStateDescriptor;
 import org.apache.flink.api.common.state.MapState;
 import org.apache.flink.api.common.state.MapStateDescriptor;
+import org.apache.flink.api.common.state.State;
+import org.apache.flink.api.common.state.StateDescriptor;
 import org.apache.flink.api.common.state.StateTtlConfig;
-import org.apache.flink.api.common.state.ValueState;
 import org.apache.flink.api.common.state.ValueStateDescriptor;
 import org.apache.flink.api.common.typeutils.base.LongSerializer;
 import org.apache.flink.runtime.state.VoidNamespace;
@@ -40,20 +42,27 @@
 import org.apache.flink.streaming.api.watermark.Watermark;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.table.api.TableRuntimeException;
+import org.apache.flink.table.api.dataview.ListView;
+import org.apache.flink.table.api.dataview.MapView;
 import org.apache.flink.table.data.RowData;
 import org.apache.flink.table.data.StringData;
 import org.apache.flink.table.functions.ProcessTableFunction;
 import org.apache.flink.table.functions.ProcessTableFunction.TimeContext;
 import org.apache.flink.table.functions.TableSemantics;
+import org.apache.flink.table.runtime.dataview.DataViewUtils;
 import org.apache.flink.table.runtime.generated.HashFunction;
 import org.apache.flink.table.runtime.generated.ProcessTableRunner;
 import org.apache.flink.table.runtime.generated.RecordEqualiser;
 import org.apache.flink.table.runtime.operators.process.TimeConverter.InstantTimeConverter;
 import org.apache.flink.table.runtime.operators.process.TimeConverter.LocalDateTimeConverter;
 import org.apache.flink.table.runtime.operators.process.TimeConverter.LongTimeConverter;
+import org.apache.flink.table.runtime.typeutils.ExternalSerializer;
 import org.apache.flink.table.runtime.typeutils.InternalSerializers;
 import org.apache.flink.table.runtime.typeutils.StringDataSerializer;
 import org.apache.flink.table.runtime.util.StateConfigUtil;
+import org.apache.flink.table.types.CollectionDataType;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.KeyValueDataType;
 import org.apache.flink.table.types.logical.LogicalType;
 import org.apache.flink.types.RowKind;
@@ -79,8 +88,8 @@ public class ProcessTableOperator extends AbstractStreamOperator<RowData>
     private transient ReadableInternalTimeContext internalTimeContext;
     private transient PassThroughCollectorBase evalCollector;
     private transient PassAllCollector onTimerCollector;
-    private transient ValueStateDescriptor<RowData>[] stateDescriptors;
-    private transient ValueState<RowData>[] stateHandles;
+    private transient StateDescriptor<?, ?>[] stateDescriptors;
+    private transient State[] stateHandles;
 
     private transient @Nullable MapState<StringData, Long> namedTimersMapState;
     private transient @Nullable InternalTimerService<VoidNamespace> namedTimerService;
@@ -241,7 +250,7 @@ public void clearAll() {
     }
 
     @VisibleForTesting
-    public ValueStateDescriptor<RowData> getValueStateDescriptor(String stateName) {
+    public StateDescriptor<?, ?> getStateDescriptor(String stateName) {
         final Integer statePos = stateNameToPosMap.get(stateName);
         if (statePos == null) {
             throw new TableRuntimeException("Unknown state entry: " + stateName);
@@ -311,16 +320,37 @@ private void setCollectors() {
         onTimerCollector = new PassAllCollector(output);
     }
 
-    @SuppressWarnings("unchecked")
     private void setStateDescriptors() {
-        final ValueStateDescriptor<RowData>[] stateDescriptors =
-                new ValueStateDescriptor[stateInfos.size()];
+        final StateDescriptor<?, ?>[] stateDescriptors = new StateDescriptor[stateInfos.size()];
         for (int i = 0; i < stateInfos.size(); i++) {
             final RuntimeStateInfo stateInfo = stateInfos.get(i);
-            final LogicalType type = stateInfo.getType();
-            final ValueStateDescriptor<RowData> stateDescriptor =
-                    new ValueStateDescriptor<>(
-                            stateInfo.getStateName(), InternalSerializers.create(type));
+            final DataType dataType = stateInfo.getDataType();
+            final LogicalType type = dataType.getLogicalType();
+            final String stateName = stateInfo.getStateName();
+
+            final StateDescriptor<?, ?> stateDescriptor;
+            if (DataViewUtils.isDataView(type, ListView.class)) {
+                final CollectionDataType arrayDataType =
+                        (CollectionDataType) dataType.getChildren().get(0);
+                final DataType elementDataType = arrayDataType.getElementDataType();
+                stateDescriptor =
+                        new ListStateDescriptor<>(
+                                stateName, ExternalSerializer.of(elementDataType));
+            } else if (DataViewUtils.isDataView(type, MapView.class)) {
+                final KeyValueDataType mapDataType =
+                        (KeyValueDataType) dataType.getChildren().get(0);
+                final DataType keyDataType = mapDataType.getKeyDataType();
+                final DataType valueDataType = mapDataType.getValueDataType();
+                stateDescriptor =
+                        new MapStateDescriptor<>(
+                                stateName,
+                                ExternalSerializer.of(keyDataType),
+                                ExternalSerializer.of(valueDataType));
+            } else {
+                stateDescriptor =
+                        new ValueStateDescriptor<>(stateName, InternalSerializers.create(type));
+            }
+
             final StateTtlConfig ttlConfig =
                     StateConfigUtil.createTtlConfig(stateInfo.getTimeToLive());
             if (ttlConfig.isEnabled()) {
@@ -331,12 +361,24 @@ private void setStateDescriptors() {
         this.stateDescriptors = stateDescriptors;
     }
 
-    @SuppressWarnings("unchecked")
     private void setStateHandles() {
         final KeyedStateStore keyedStateStore = getKeyedStateStore();
-        final ValueState<RowData>[] stateHandles = new ValueState[stateDescriptors.length];
-        for (int i = 0; i < stateInfos.size(); i++) {
-            stateHandles[i] = keyedStateStore.getState(stateDescriptors[i]);
+        final State[] stateHandles = new State[stateDescriptors.length];
+        for (int i = 0; i < stateDescriptors.length; i++) {
+            final StateDescriptor<?, ?> stateDescriptor = stateDescriptors[i];
+            final State stateHandle;
+            if (stateDescriptor instanceof ValueStateDescriptor) {
+                stateHandle = keyedStateStore.getState((ValueStateDescriptor<?>) stateDescriptor);
+            } else if (stateDescriptor instanceof ListStateDescriptor) {
+                stateHandle =
+                        keyedStateStore.getListState((ListStateDescriptor<?>) stateDescriptor);
+            } else if (stateDescriptor instanceof MapStateDescriptor) {
+                stateHandle =
+                        keyedStateStore.getMapState((MapStateDescriptor<?, ?>) stateDescriptor);
+            } else {
+                throw new IllegalStateException("Unknown state descriptor: " + stateDescriptor);
+            }
+            stateHandles[i] = stateHandle;
         }
         this.stateHandles = stateHandles;
     }
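In `setStateDescriptors` above, the descriptor kind is now derived from the declared data type of each state entry: a `ListView` becomes list state, a `MapView` becomes map state, and everything else remains value state, with the TTL configuration applied uniformly to whichever descriptor was chosen. Below is a minimal sketch of that TTL step, assuming a map-typed entry; the state name and the simplified serializers are placeholders, since the patch itself uses `ExternalSerializer.of(...)` on the extracted key and value data types:

```java
import java.time.Duration;

import org.apache.flink.api.common.state.MapStateDescriptor;
import org.apache.flink.api.common.state.StateTtlConfig;
import org.apache.flink.api.common.typeutils.base.IntSerializer;
import org.apache.flink.api.common.typeutils.base.StringSerializer;

// Illustrative sketch (not part of the patch): a map-state descriptor like the
// one derived for a MapView state entry, with TTL enabled on it afterwards.
public class DescriptorTtlSketch {
    public static void main(String[] args) {
        // Hypothetical state name and simplified serializers.
        MapStateDescriptor<String, Integer> descriptor =
                new MapStateDescriptor<>(
                        "eventCounts", StringSerializer.INSTANCE, IntSerializer.INSTANCE);

        // StateConfigUtil.createTtlConfig(...) produces such a config from the
        // configured time-to-live when it is set.
        StateTtlConfig ttlConfig = StateTtlConfig.newBuilder(Duration.ofHours(1)).build();
        descriptor.enableTimeToLive(ttlConfig);
    }
}
```

Because the backend stores each key-value pair of map state as its own entry, the TTL configured on the descriptor takes effect per entry rather than per map, which is what makes fine-grained expiration feasible for large views.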
diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/process/RuntimeStateInfo.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/process/RuntimeStateInfo.java
index e616ecf5dee7f..5f983e32ff96e 100644
--- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/process/RuntimeStateInfo.java
+++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/process/RuntimeStateInfo.java
@@ -18,8 +18,8 @@
 
 package org.apache.flink.table.runtime.operators.process;
 
+import org.apache.flink.table.types.DataType;
 import org.apache.flink.table.types.inference.TypeInferenceUtil.StateInfo;
-import org.apache.flink.table.types.logical.LogicalType;
 
 import java.io.Serializable;
 
@@ -32,12 +32,12 @@ public class RuntimeStateInfo implements Serializable {
     private static final long serialVersionUID = 1L;
 
     private final String stateName;
-    private final LogicalType type;
+    private final DataType dataType;
     private final long timeToLive;
 
-    public RuntimeStateInfo(String stateName, LogicalType type, long timeToLive) {
+    public RuntimeStateInfo(String stateName, DataType dataType, long timeToLive) {
         this.stateName = stateName;
-        this.type = type;
+        this.dataType = dataType;
         this.timeToLive = timeToLive;
     }
 
@@ -45,8 +45,8 @@ public String getStateName() {
         return stateName;
     }
 
-    public LogicalType getType() {
-        return type;
+    public DataType getDataType() {
+        return dataType;
     }
 
     public long getTimeToLive() {