Merge pull request #17608 from ihji/BEAM-14430 · tomstepp/apache-beam@2d57753

@@ -18,6 +18,7 @@

1818

package org.apache.beam.sdk.extensions.python;

19192020

import java.util.Arrays;

21+

import java.util.HashMap;

2122

import java.util.Map;

2223

import java.util.Set;

2324

import java.util.SortedMap;

@@ -33,10 +34,12 @@

3334

import org.apache.beam.sdk.schemas.Schema;

3435

import org.apache.beam.sdk.schemas.SchemaRegistry;

3536

import org.apache.beam.sdk.schemas.SchemaTranslation;

37+

import org.apache.beam.sdk.schemas.logicaltypes.PythonCallable;

3638

import org.apache.beam.sdk.schemas.utils.StaticSchemaInference;

3739

import org.apache.beam.sdk.transforms.PTransform;

3840

import org.apache.beam.sdk.transforms.SerializableFunction;

3941

import org.apache.beam.sdk.util.CoderUtils;

42+

import org.apache.beam.sdk.util.PythonCallableSource;

4043

import org.apache.beam.sdk.values.PBegin;

4144

import org.apache.beam.sdk.values.PCollection;

4245

import org.apache.beam.sdk.values.PCollectionTuple;

@@ -64,6 +67,7 @@ public class PythonExternalTransform<InputT extends PInput, OutputT extends POut

6467

// We preserve the order here since Schemas care about the order of fields, but the order will not

6568

// matter when applying kwargs at the Python side.

6669

private SortedMap<String, Object> kwargsMap;

70+

private Map<java.lang.Class<?>, Schema.FieldType> typeHints;

67716872

private @Nullable Object @NonNull [] argsArray;

6973

private @Nullable Row providedKwargsRow;

@@ -72,6 +76,11 @@ private PythonExternalTransform(String fullyQualifiedName, String expansionServi

7276

this.fullyQualifiedName = fullyQualifiedName;

7377

this.expansionService = expansionService;

7478

this.kwargsMap = new TreeMap<>();

79+

this.typeHints = new HashMap<>();

80+

// TODO(BEAM-14458): remove a default type hint for PythonCallableSource when BEAM-14458 is

81+

// resolved

82+

this.typeHints.put(

83+

PythonCallableSource.class, Schema.FieldType.logicalType(new PythonCallable()));

7584

argsArray = new Object[] {};

7685

}

7786

@@ -162,6 +171,26 @@ public PythonExternalTransform<InputT, OutputT> withKwargs(Row kwargs) {

162171

return this;

163172

}

164173174+

/**

175+

* Specifies the field type of arguments.

176+

*

177+

* <p>Type hints are especially useful for logical types since type inference does not work well

178+

* for logical types.

179+

*

180+

* @param argType A class object for the argument type.

181+

* @param fieldType A schema field type for the argument.

182+

* @return updated wrapper for the cross-language transform.

183+

*/

184+

public PythonExternalTransform<InputT, OutputT> withTypeHint(

185+

java.lang.Class<?> argType, Schema.FieldType fieldType) {

186+

if (typeHints.containsKey(argType)) {

187+

throw new IllegalArgumentException(

188+

String.format("typehint for arg type %s already exists", argType));

189+

}

190+

typeHints.put(argType, fieldType);

191+

return this;

192+

}

193+165194

@VisibleForTesting

166195

Row buildOrGetKwargsRow() {

167196

if (providedKwargsRow != null) {

@@ -179,16 +208,18 @@ Row buildOrGetKwargsRow() {

179208

// Types that are not one of the following are considered custom types.

180209

// * Java primitives

181210

// * Type String

211+

// * Any Type explicitly annotated by withTypeHint()

182212

// * Type Row

183-

private static boolean isCustomType(java.lang.Class<?> type) {

213+

private boolean isCustomType(java.lang.Class<?> type) {

184214

boolean val =

185215

!(ClassUtils.isPrimitiveOrWrapper(type)

186216

|| type == String.class

217+

|| typeHints.containsKey(type)

187218

|| Row.class.isAssignableFrom(type));

188219

return val;

189220

}

190221191-

// If the custom type has a registered schema, we use that. OTherwise we try to register it using

222+

// If the custom type has a registered schema, we use that. Otherwise, we try to register it using

192223

// 'JavaFieldSchema'.

193224

private Row convertCustomValue(Object value) {

194225

SerializableFunction<Object, Row> toRowFunc;

@@ -239,6 +270,8 @@ private Schema generateSchemaDirectly(

239270

if (field instanceof Row) {

240271

// Rows are used as is but other types are converted to proper field types.

241272

builder.addRowField(fieldName, ((Row) field).getSchema());

273+

} else if (typeHints.containsKey(field.getClass())) {

274+

builder.addField(fieldName, typeHints.get(field.getClass()));

242275

} else {

243276

builder.addField(

244277

fieldName,