[BEAM-14430] Adding a logical type support for Python callables to Ro… · tomstepp/apache-beam@5a5e51e
@@ -18,6 +18,7 @@
1818package org.apache.beam.sdk.extensions.python;
19192020import java.util.Arrays;
21+import java.util.HashMap;
2122import java.util.Map;
2223import java.util.Set;
2324import java.util.SortedMap;
@@ -64,6 +65,7 @@ public class PythonExternalTransform<InputT extends PInput, OutputT extends POut
6465// We preseve the order here since Schema's care about order of fields but the order will not
6566// matter when applying kwargs at the Python side.
6667private SortedMap<String, Object> kwargsMap;
68+private Map<java.lang.Class<?>, Schema.FieldType> typeHints;
67696870private @Nullable Object @NonNull [] argsArray;
6971private @Nullable Row providedKwargsRow;
@@ -72,6 +74,7 @@ private PythonExternalTransform(String fullyQualifiedName, String expansionServi
7274this.fullyQualifiedName = fullyQualifiedName;
7375this.expansionService = expansionService;
7476this.kwargsMap = new TreeMap<>();
77+this.typeHints = new HashMap<>();
7578argsArray = new Object[] {};
7679 }
7780@@ -162,6 +165,26 @@ public PythonExternalTransform<InputT, OutputT> withKwargs(Row kwargs) {
162165return this;
163166 }
164167168+/**
169+ * Specifies the field type of arguments.
170+ *
171+ * <p>Type hints are especially useful for logical types since type inference does not work well
172+ * for logical types.
173+ *
174+ * @param argType A class object for the argument type.
175+ * @param fieldType A schema field type for the argument.
176+ * @return updated wrapper for the cross-language transform.
177+ */
178+public PythonExternalTransform<InputT, OutputT> withTypeHint(
179+java.lang.Class<?> argType, Schema.FieldType fieldType) {
180+if (typeHints.containsKey(argType)) {
181+throw new IllegalArgumentException(
182+String.format("typehint for arg type %s already exists", argType));
183+ }
184+typeHints.put(argType, fieldType);
185+return this;
186+ }
187+165188@VisibleForTesting
166189Row buildOrGetKwargsRow() {
167190if (providedKwargsRow != null) {
@@ -180,15 +203,17 @@ Row buildOrGetKwargsRow() {
180203// * Java primitives
181204// * Type String
182205// * Type Row
183-private static boolean isCustomType(java.lang.Class<?> type) {
206+// * Any Type explicitly annotated by withTypeHint()
207+private boolean isCustomType(java.lang.Class<?> type) {
184208boolean val =
185209 !(ClassUtils.isPrimitiveOrWrapper(type)
186210 || type == String.class
211+ || typeHints.containsKey(type)
187212 || Row.class.isAssignableFrom(type));
188213return val;
189214 }
190215191-// If the custom type has a registered schema, we use that. OTherwise we try to register it using
216+// If the custom type has a registered schema, we use that. Otherwise, we try to register it using
192217// 'JavaFieldSchema'.
193218private Row convertCustomValue(Object value) {
194219SerializableFunction<Object, Row> toRowFunc;
@@ -239,6 +264,8 @@ private Schema generateSchemaDirectly(
239264if (field instanceof Row) {
240265// Rows are used as is but other types are converted to proper field types.
241266builder.addRowField(fieldName, ((Row) field).getSchema());
267+ } else if (typeHints.containsKey(field.getClass())) {
268+builder.addField(fieldName, typeHints.get(field.getClass()));
242269 } else {
243270builder.addField(
244271fieldName,