Merge pull request #17608 from ihji/BEAM-14430 · tomstepp/apache-beam@2d57753
@@ -18,6 +18,7 @@
1818package org.apache.beam.sdk.extensions.python;
19192020import java.util.Arrays;
21+import java.util.HashMap;
2122import java.util.Map;
2223import java.util.Set;
2324import java.util.SortedMap;
@@ -33,10 +34,12 @@
3334import org.apache.beam.sdk.schemas.Schema;
3435import org.apache.beam.sdk.schemas.SchemaRegistry;
3536import org.apache.beam.sdk.schemas.SchemaTranslation;
37+import org.apache.beam.sdk.schemas.logicaltypes.PythonCallable;
3638import org.apache.beam.sdk.schemas.utils.StaticSchemaInference;
3739import org.apache.beam.sdk.transforms.PTransform;
3840import org.apache.beam.sdk.transforms.SerializableFunction;
3941import org.apache.beam.sdk.util.CoderUtils;
42+import org.apache.beam.sdk.util.PythonCallableSource;
4043import org.apache.beam.sdk.values.PBegin;
4144import org.apache.beam.sdk.values.PCollection;
4245import org.apache.beam.sdk.values.PCollectionTuple;
@@ -64,6 +67,7 @@ public class PythonExternalTransform<InputT extends PInput, OutputT extends POut
6467// We preserve the order here since Schemas care about the order of fields, but the order will
6568// not matter when applying kwargs at the Python side.
6669private SortedMap<String, Object> kwargsMap;
70+private Map<java.lang.Class<?>, Schema.FieldType> typeHints;
67716872private @Nullable Object @NonNull [] argsArray;
6973private @Nullable Row providedKwargsRow;
@@ -72,6 +76,11 @@ private PythonExternalTransform(String fullyQualifiedName, String expansionServi
7276this.fullyQualifiedName = fullyQualifiedName;
7377this.expansionService = expansionService;
7478this.kwargsMap = new TreeMap<>();
79+this.typeHints = new HashMap<>();
80+// TODO(BEAM-14458): remove a default type hint for PythonCallableSource when BEAM-14458 is
81+// resolved
82+this.typeHints.put(
83+PythonCallableSource.class, Schema.FieldType.logicalType(new PythonCallable()));
7584argsArray = new Object[] {};
7685 }
7786@@ -162,6 +171,26 @@ public PythonExternalTransform<InputT, OutputT> withKwargs(Row kwargs) {
162171return this;
163172 }
164173174+/**
175+ * Specifies the field type of arguments.
176+ *
177+ * <p>Type hints are especially useful for logical types since type inference does not work well
178+ * for logical types.
179+ *
180+ * @param argType A class object for the argument type.
181+ * @param fieldType A schema field type for the argument.
182+ * @return updated wrapper for the cross-language transform.
183+ */
184+public PythonExternalTransform<InputT, OutputT> withTypeHint(
185+java.lang.Class<?> argType, Schema.FieldType fieldType) {
186+if (typeHints.containsKey(argType)) {
187+throw new IllegalArgumentException(
188+String.format("typehint for arg type %s already exists", argType));
189+ }
190+typeHints.put(argType, fieldType);
191+return this;
192+ }
193+165194@VisibleForTesting
166195Row buildOrGetKwargsRow() {
167196if (providedKwargsRow != null) {
@@ -179,16 +208,18 @@ Row buildOrGetKwargsRow() {
179208// Types that are not one of following are considered custom types.
180209// * Java primitives
181210// * Type String
211+// * Any Type explicitly annotated by withTypeHint()
182212// * Type Row
183-private static boolean isCustomType(java.lang.Class<?> type) {
213+private boolean isCustomType(java.lang.Class<?> type) {
184214boolean val =
185215 !(ClassUtils.isPrimitiveOrWrapper(type)
186216 || type == String.class
217+ || typeHints.containsKey(type)
187218 || Row.class.isAssignableFrom(type));
188219return val;
189220 }
190221191-// If the custom type has a registered schema, we use that. OTherwise we try to register it using
222+// If the custom type has a registered schema, we use that. Otherwise, we try to register it using
192223// 'JavaFieldSchema'.
193224private Row convertCustomValue(Object value) {
194225SerializableFunction<Object, Row> toRowFunc;
@@ -239,6 +270,8 @@ private Schema generateSchemaDirectly(
239270if (field instanceof Row) {
240271// Rows are used as is but other types are converted to proper field types.
241272builder.addRowField(fieldName, ((Row) field).getSchema());
273+ } else if (typeHints.containsKey(field.getClass())) {
274+builder.addField(fieldName, typeHints.get(field.getClass()));
242275 } else {
243276builder.addField(
244277fieldName,