Paste P69440

Kafka Row Mapper for Spark

Authored by pfischer on Oct 1 2024, 2:19 PM.

package org.wikimedia.eventutilities.spark.kafka;

import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.encoders.RowEncoder;
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema;
import org.apache.spark.sql.functions;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.wikimedia.eventutilities.core.event.EventStream;
import org.wikimedia.eventutilities.core.event.EventStreamFactory;
import org.wikimedia.eventutilities.core.event.JsonEventGenerator;
import org.wikimedia.eventutilities.core.event.JsonEventGenerator.EventNormalizer;
import org.wikimedia.eventutilities.spark.sql.JsonSchemaSparkConverter;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import lombok.SneakyThrows;
import lombok.Value;
import scala.Tuple2;
import scala.collection.JavaConverters;
import scala.collection.Seq;

/**
 * Creates {@link MapFunction}s that map event {@link Row}s of a given event
 * stream into Kafka-ready (key, value, topic) rows.
 */
public class KafkaRowMapperFactory {
    private final EventStreamFactory eventStreamFactory;
    private final JsonEventGenerator jsonEventGenerator;
    private final ObjectMapper objectMapper;

    public KafkaRowMapperFactory(EventStreamFactory eventStreamFactory,
                                 JsonEventGenerator jsonEventGenerator, ObjectMapper objectMapper) {
        this.eventStreamFactory = eventStreamFactory;
        this.jsonEventGenerator = jsonEventGenerator;
        this.objectMapper = objectMapper;
    }

    public KafkaRowMapperFactory(EventStreamFactory eventStreamFactory, ObjectMapper objectMapper) {
        this(eventStreamFactory, JsonEventGenerator.builder()
                .eventStreamConfig(eventStreamFactory.getEventStreamConfig())
                .jsonMapper(objectMapper)
                .schemaLoader(eventStreamFactory.getEventSchemaLoader())
                .build(), objectMapper);
    }
    public KafkaRowMapper createMapper(String streamName, String topic) {
        final EventStream stream = eventStreamFactory.createEventStream(streamName);
        KafkaRowSchema kafkaRowInfo = createSchema(stream);
        return new KafkaRowMapper(topic, kafkaRowInfo.keySelector(),
                jsonEventGenerator.createEventStreamEventGenerator(streamName,
                        stream.schemaUri().toString()));
    }

    public KafkaRowSchema createSchema(EventStream eventStream) {
        final DataType valueSchema = JsonSchemaSparkConverter.toDataType(
                (ObjectNode) eventStream.schema());
        // messageKeyFields maps key-field aliases to dotted column paths
        // within the value schema (e.g. "page_id" -> "page.page_id").
        return KafkaRowSchema.create(valueSchema,
                objectMapper.convertValue(eventStream.messageKeyFields(), Map.class));
    }
    static class KafkaRowSchema {
        final DataType valueSchema;
        final List<KeyFieldAlias> keyFieldAliases;
        final StructType keySchema;

        public KafkaRowSchema(DataType valueSchema, @Nullable List<KeyFieldAlias> keyFieldAliases) {
            this.valueSchema = valueSchema;
            this.keyFieldAliases = keyFieldAliases;
            this.keySchema = keyFieldAliases == null ? null : new StructType(
                    keyFieldAliases.stream().map(KeyFieldAlias::getAliasedField)
                            .toArray(StructField[]::new));
        }

        Dataset<Row> selectKey(Dataset<Row> df) {
            return df.withColumn("key", functions.to_json(functions.struct(
                    keyFieldAliases.stream().map(KeyFieldAlias::getAliasedColumn).toArray(Column[]::new)
            )));
        }

        Row selectKey(Row row) throws Exception {
            return keySelector().call(row);
        }

        Function<Row, Row> keySelector() {
            return new KafkaRowKeyExtractor(keySchema,
                    keyFieldAliases.stream().map(alias -> Arrays.asList(
                            alias.aliasedColumnExpression.split("\\."))).collect(Collectors.toList()));
        }

        private static <T> T selectAliasedValue(Row row, Iterable<String> pathSegments) {
            final Iterator<String> iterator = pathSegments.iterator();
            while (iterator.hasNext()) {
                final String pathSegment = iterator.next();
                if (iterator.hasNext()) {
                    // Intermediate segment: descend into the nested struct.
                    row = row.getAs(pathSegment);
                } else {
                    return row.getAs(pathSegment);
                }
            }
            throw new IllegalStateException("Unable to extract key field " + pathSegments);
        }
        static KafkaRowSchema create(DataType valueSchema, Map<String, String> messageKeyFields) {
            if (valueSchema instanceof StructType) {
                final List<KeyFieldAlias> keyFieldAliases = messageKeyFields.entrySet().stream()
                        .map(entry -> {
                            // Case-sensitive lookup; Option.get() fails fast if the
                            // configured key field is missing from the value schema.
                            final Tuple2<Seq<String>, StructField> aliasedField = ((StructType) valueSchema).findNestedField(
                                    JavaConverters.asScalaBuffer(
                                            Arrays.asList(entry.getValue().split("\\."))),
                                    false, (a, b) -> a.equals(b)).get();
                            final StructField aliasField = new StructField(entry.getKey(),
                                    aliasedField._2().dataType(), false, Metadata.empty());
                            return new KeyFieldAlias(entry.getKey(), aliasField, entry.getValue());
                        }).collect(Collectors.toList());
                return new KafkaRowSchema(valueSchema, keyFieldAliases);
            }
            return new KafkaRowSchema(valueSchema, null);
        }
        @Value
        static class KeyFieldAlias {
            String aliasName;
            StructField aliasedField;
            String aliasedColumnExpression;
            Column aliasedColumn;

            public KeyFieldAlias(String aliasName, StructField aliasedField,
                                 String aliasedColumnExpression) {
                this.aliasName = aliasName;
                this.aliasedField = aliasedField;
                this.aliasedColumnExpression = aliasedColumnExpression;
                this.aliasedColumn = functions.col(aliasedColumnExpression).as(aliasName);
            }
        }

        private static class KafkaRowKeyExtractor implements Function<Row, Row> {
            final StructType keySchema;
            final List<Iterable<String>> pathSegments;

            private KafkaRowKeyExtractor(StructType keySchema, List<Iterable<String>> pathSegments) {
                this.keySchema = keySchema;
                this.pathSegments = pathSegments;
            }

            @Override
            public Row call(Row row) {
                return new GenericRowWithSchema(
                        pathSegments.stream().map(segments -> selectAliasedValue(row, segments)).toArray(),
                        keySchema);
            }
        }
    }
    static class KafkaRowMapper implements MapFunction<Row, Row> {
        public static final String COL_KEY = "key";
        public static final String COL_VALUE = "value";
        public static final String COL_TOPIC = "topic";
        public static final StructType SCHEMA = new StructType(
                new StructField[]{
                        new StructField(COL_KEY, DataTypes.StringType, true, Metadata.empty()),
                        new StructField(COL_VALUE, DataTypes.StringType, false, Metadata.empty()),
                        new StructField(COL_TOPIC, DataTypes.StringType, true, Metadata.empty()),
                }
        );
        public static final Encoder<Row> ENCODER = RowEncoder.apply(SCHEMA);

        final EventNormalizer valueNormalizer;
        final Function<Row, Row> keyExtractor;
        private final String topic;

        public KafkaRowMapper(String topic, Function<Row, Row> keyExtractor,
                              EventNormalizer valueNormalizer) {
            this.valueNormalizer = valueNormalizer;
            this.keyExtractor = keyExtractor;
            this.topic = topic;
        }

        @Override
        public Row call(Row row) throws Exception {
            final ObjectNode normalizedEventData = valueNormalizer.apply(eventData -> {
                mapEventData(row, eventData);
            });
            // Field order must match SCHEMA: key, value, topic.
            return new GenericRowWithSchema(
                    new Object[]{
                            keyExtractor.call(row).json(),
                            valueNormalizer.getObjectMapper().writeValueAsString(normalizedEventData),
                            topic
                    },
                    SCHEMA);
        }

        @SneakyThrows
        private void mapEventData(Row row, ObjectNode consumer) {
            final JsonNode eventData = valueNormalizer.getObjectMapper().readTree(row.json());
            if (!(eventData instanceof ObjectNode)) {
                throw new IllegalArgumentException("row does not map to a JSON object");
            }
            consumer.setAll((ObjectNode) eventData);
        }
    }
}
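
/// USAGE (SKETCH) ///
A minimal sketch of how the factory is meant to be wired into a Kafka write, assuming an already configured EventStreamFactory and a Dataset<Row> (here called sourceDf) whose rows match the event schema; sourceDf, the topic name, and the bootstrap servers are hypothetical placeholders, not part of the paste's tested code.

KafkaRowMapperFactory factory = new KafkaRowMapperFactory(eventStreamFactory, new ObjectMapper());
KafkaRowMapper mapper = factory.createMapper(
        "mediawiki.cirrussearch.page_weighted_tags_change.rc0", "my.kafka.topic");

// The mapped rows carry exactly the key/value/topic columns Spark's Kafka
// sink expects; sourceDf and the bootstrap servers are placeholders.
sourceDf.map(mapper, KafkaRowMapper.ENCODER)
        .write()
        .format("kafka")
        .option("kafka.bootstrap.servers", "localhost:9092")
        .save();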
/// TEST ///
package org.wikimedia.eventutilities.spark.kafka;

import java.net.URISyntaxException;
import java.time.LocalDateTime;
import java.time.ZoneOffset;
import java.time.ZonedDateTime;
import java.time.format.DateTimeFormatter;
import java.time.temporal.ChronoUnit;
import java.util.Arrays;
import java.util.Objects;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.functions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.wikimedia.eventutilities.core.event.EventStreamFactory;
import org.wikimedia.eventutilities.spark.kafka.KafkaRowMapperFactory.KafkaRowMapper;
import com.fasterxml.jackson.databind.ObjectMapper;

class KafkaRowMapperTest {
    private static SparkSession spark;
    private final ObjectMapper objectMapper = new ObjectMapper();

    @BeforeAll
    static void setUp() {
        // Initialize the Spark session for local testing.
        spark = SparkSession.builder()
                .appName("Unit Test")
                .master("local")
                .getOrCreate();
    }
    @Test
    void mapping() throws URISyntaxException {
        final EventStreamFactory eventStreamFactory = EventStreamFactory.from(
                Arrays.asList(KafkaRowMapperTest.class.getResource("/schema_repo").toString()),
                KafkaRowMapperTest.class.getResource("/event-stream-config.json").toString());
        final KafkaRowMapperFactory kafkaRowMapperFactory = new KafkaRowMapperFactory(eventStreamFactory,
                objectMapper);
        final KafkaRowMapper kafkaRowMapper = kafkaRowMapperFactory.createMapper(
                "mediawiki.cirrussearch.page_weighted_tags_change.rc0", "topic");

        final Dataset<Row> source = spark.read().option("multiline", true)
                .json(Objects.requireNonNull(KafkaRowMapperTest.class.getResource("/df.json")).toString());
        source.printSchema();

        final ZonedDateTime now = LocalDateTime.now().truncatedTo(ChronoUnit.MINUTES)
                .atZone(ZoneOffset.UTC);
        // Shape the source frame so it matches the event schema: add dt/meta plus
        // the nested page and weighted_tags structs expected by the stream.
        final Dataset<Row> kafkaValueFrame = source
                .withColumn("dt", functions.lit(DateTimeFormatter.ISO_INSTANT.format(now)))
                .withColumn("meta", functions.struct(
                        functions.lit(DateTimeFormatter.ISO_INSTANT.format(now.minusMinutes(5))).as("dt"),
                        functions.lit("TBD").as("id")
                ))
                .withColumn("page", functions.struct(
                        functions.col("page_id"),
                        functions.lit("arbitrary_page_title").as("page_title"),
                        functions.lit(0).as("namespace_id")
                ))
                .withColumn("weighted_tags", functions.struct(
                        functions.map_from_arrays(functions.array(functions.col("prefix")),
                                functions.array(functions.array())).as("set")
                ))
                .drop("prefix", "page_id");
        kafkaValueFrame.printSchema();

        final Dataset<Row> mappedDataset = kafkaValueFrame.map(kafkaRowMapper, KafkaRowMapper.ENCODER);
        mappedDataset.printSchema();
        mappedDataset.show(100, false);
    }
}