[BEAM-14430] Adding a logical type support for Python callables to Ro… · tomstepp/apache-beam@5a5e51e

@@ -18,6 +18,7 @@

1818

package org.apache.beam.sdk.extensions.python;

19192020

import java.util.Arrays;

21+

import java.util.HashMap;

2122

import java.util.Map;

2223

import java.util.Set;

2324

import java.util.SortedMap;

@@ -64,6 +65,7 @@ public class PythonExternalTransform<InputT extends PInput, OutputT extends POut

6465

// We preseve the order here since Schema's care about order of fields but the order will not

6566

// matter when applying kwargs at the Python side.

6667

private SortedMap<String, Object> kwargsMap;

68+

private Map<java.lang.Class<?>, Schema.FieldType> typeHints;

67696870

private @Nullable Object @NonNull [] argsArray;

6971

private @Nullable Row providedKwargsRow;

@@ -72,6 +74,7 @@ private PythonExternalTransform(String fullyQualifiedName, String expansionServi

7274

this.fullyQualifiedName = fullyQualifiedName;

7375

this.expansionService = expansionService;

7476

this.kwargsMap = new TreeMap<>();

77+

this.typeHints = new HashMap<>();

7578

argsArray = new Object[] {};

7679

}

7780

@@ -162,6 +165,26 @@ public PythonExternalTransform<InputT, OutputT> withKwargs(Row kwargs) {

162165

return this;

163166

}

164167168+

/**

169+

* Specifies the field type of arguments.

170+

*

171+

* <p>Type hints are especially useful for logical types since type inference does not work well

172+

* for logical types.

173+

*

174+

* @param argType A class object for the argument type.

175+

* @param fieldType A schema field type for the argument.

176+

* @return updated wrapper for the cross-language transform.

177+

*/

178+

public PythonExternalTransform<InputT, OutputT> withTypeHint(

179+

java.lang.Class<?> argType, Schema.FieldType fieldType) {

180+

if (typeHints.containsKey(argType)) {

181+

throw new IllegalArgumentException(

182+

String.format("typehint for arg type %s already exists", argType));

183+

}

184+

typeHints.put(argType, fieldType);

185+

return this;

186+

}

187+165188

@VisibleForTesting

166189

Row buildOrGetKwargsRow() {

167190

if (providedKwargsRow != null) {

@@ -180,15 +203,17 @@ Row buildOrGetKwargsRow() {

180203

// * Java primitives

181204

// * Type String

182205

// * Type Row

183-

private static boolean isCustomType(java.lang.Class<?> type) {

206+

// * Any Type explicitly annotated by withTypeHint()

207+

private boolean isCustomType(java.lang.Class<?> type) {

184208

boolean val =

185209

!(ClassUtils.isPrimitiveOrWrapper(type)

186210

|| type == String.class

211+

|| typeHints.containsKey(type)

187212

|| Row.class.isAssignableFrom(type));

188213

return val;

189214

}

190215191-

// If the custom type has a registered schema, we use that. OTherwise we try to register it using

216+

// If the custom type has a registered schema, we use that. Otherwise, we try to register it using

192217

// 'JavaFieldSchema'.

193218

private Row convertCustomValue(Object value) {

194219

SerializableFunction<Object, Row> toRowFunc;

@@ -239,6 +264,8 @@ private Schema generateSchemaDirectly(

239264

if (field instanceof Row) {

240265

// Rows are used as is but other types are converted to proper field types.

241266

builder.addRowField(fieldName, ((Row) field).getSchema());

267+

} else if (typeHints.containsKey(field.getClass())) {

268+

builder.addField(fieldName, typeHints.get(field.getClass()));

242269

} else {

243270

builder.addField(

244271

fieldName,