Skip to content

Commit 89e12f3

Browse files
authored
KAFKA-10388; Fix struct conversion logic for tagged structures (#9166)
The message generator was missing conversion logic for tagged structures. This led to casting errors when either `fromStruct` or `toStruct` were invoked. This patch also adds missing null checks in the serialization of tagged byte arrays, which was found from improved test coverage. Reviewers: Colin P. McCabe <[email protected]>
1 parent 7915d5e commit 89e12f3

File tree

2 files changed

+63
-12
lines changed

2 files changed

+63
-12
lines changed

clients/src/test/java/org/apache/kafka/common/message/SimpleExampleMessageTest.java

+42-8
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import org.apache.kafka.common.errors.UnsupportedVersionException;
2020
import org.apache.kafka.common.protocol.ByteBufferAccessor;
2121
import org.apache.kafka.common.protocol.ObjectSerializationCache;
22+
import org.apache.kafka.common.protocol.types.Schema;
2223
import org.apache.kafka.common.protocol.types.Struct;
2324
import org.apache.kafka.common.utils.ByteUtils;
2425
import org.junit.Test;
@@ -321,6 +322,38 @@ public void testMyTaggedStruct() {
321322
message -> assertEquals("abc", message.myString()), (short) 2);
322323
}
323324

325+
private ByteBuffer serialize(SimpleExampleMessageData message, short version) {
326+
ObjectSerializationCache cache = new ObjectSerializationCache();
327+
int size = message.size(cache, version);
328+
ByteBuffer buf = ByteBuffer.allocate(size);
329+
message.write(new ByteBufferAccessor(buf), cache, version);
330+
buf.flip();
331+
assertEquals(size, buf.remaining());
332+
return buf;
333+
}
334+
335+
private SimpleExampleMessageData deserialize(ByteBuffer buf, short version) {
336+
SimpleExampleMessageData message = new SimpleExampleMessageData();
337+
message.read(new ByteBufferAccessor(buf.duplicate()), version);
338+
return message;
339+
}
340+
341+
private ByteBuffer serializeThroughStruct(SimpleExampleMessageData message, short version) {
342+
Struct struct = message.toStruct(version);
343+
int size = struct.sizeOf();
344+
ByteBuffer buf = ByteBuffer.allocate(size);
345+
struct.writeTo(buf);
346+
buf.flip();
347+
assertEquals(size, buf.remaining());
348+
return buf;
349+
}
350+
351+
private SimpleExampleMessageData deserializeThroughStruct(ByteBuffer buf, short version) {
352+
Schema schema = SimpleExampleMessageData.SCHEMAS[version];
353+
Struct struct = schema.read(buf);
354+
return new SimpleExampleMessageData(struct, version);
355+
}
356+
324357
private void testRoundTrip(SimpleExampleMessageData message,
325358
Consumer<SimpleExampleMessageData> validator) {
326359
testRoundTrip(message, validator, (short) 1);
@@ -330,17 +363,18 @@ private void testRoundTrip(SimpleExampleMessageData message,
330363
Consumer<SimpleExampleMessageData> validator,
331364
short version) {
332365
validator.accept(message);
333-
ObjectSerializationCache cache = new ObjectSerializationCache();
334-
int size = message.size(cache, version);
335-
ByteBuffer buf = ByteBuffer.allocate(size);
336-
message.write(new ByteBufferAccessor(buf), cache, version);
337-
buf.flip();
338-
assertEquals(size, buf.remaining());
366+
ByteBuffer buf = serialize(message, version);
339367

340-
SimpleExampleMessageData message2 = new SimpleExampleMessageData();
341-
message2.read(new ByteBufferAccessor(buf.duplicate()), version);
368+
SimpleExampleMessageData message2 = deserialize(buf.duplicate(), version);
342369
validator.accept(message2);
343370
assertEquals(message, message2);
344371
assertEquals(message.hashCode(), message2.hashCode());
372+
373+
// Check struct serialization as well
374+
assertEquals(buf, serializeThroughStruct(message, version));
375+
SimpleExampleMessageData messageFromStruct = deserializeThroughStruct(buf.duplicate(), version);
376+
validator.accept(messageFromStruct);
377+
assertEquals(message, messageFromStruct);
378+
assertEquals(message.hashCode(), messageFromStruct.hashCode());
345379
}
346380
}

generator/src/main/java/org/apache/kafka/message/MessageDataGenerator.java

+21-4
Original file line numberDiff line numberDiff line change
@@ -793,10 +793,25 @@ private void generateClassFromStruct(String className, StructSpec struct,
793793
generateArrayFromStruct(field, presentAndTaggedVersions);
794794
} else if (field.type().isBytes()) {
795795
headerGenerator.addImport(MessageGenerator.BYTE_BUFFER_CLASS);
796-
headerGenerator.addImport(MessageGenerator.MESSAGE_UTIL_CLASS);
797-
buffer.printf("this.%s = MessageUtil.byteBufferToArray(" +
798-
"(ByteBuffer) _taggedFields.remove(%d));%n",
796+
buffer.printf("ByteBuffer _byteBuffer = (ByteBuffer) _taggedFields.remove(%d);%n",
797+
field.tag().get());
798+
799+
IsNullConditional.forName("_byteBuffer").
800+
nullableVersions(field.nullableVersions()).
801+
possibleVersions(field.versions()).
802+
ifNull(() -> {
803+
buffer.printf("this.%s = null;%n", field.camelCaseName());
804+
}).
805+
ifShouldNotBeNull(() -> {
806+
headerGenerator.addImport(MessageGenerator.MESSAGE_UTIL_CLASS);
807+
buffer.printf("this.%s = MessageUtil.byteBufferToArray(_byteBuffer);%n",
808+
field.camelCaseName());
809+
}).
810+
generate(buffer);
811+
} else if (field.type().isStruct()) {
812+
buffer.printf("this.%s = new %s((Struct) _taggedFields.remove(%d), _version);%n",
799813
field.camelCaseName(),
814+
getBoxedJavaType(field.type()),
800815
field.tag().get());
801816
} else {
802817
buffer.printf("this.%s = (%s) _taggedFields.remove(%d);%n",
@@ -1731,10 +1746,12 @@ private void generateTaggedFieldToMap(FieldSpec field, Versions versions) {
17311746
(field.type() instanceof FieldType.Int64FieldType) ||
17321747
(field.type() instanceof FieldType.UUIDFieldType) ||
17331748
(field.type() instanceof FieldType.Float64FieldType) ||
1734-
(field.type() instanceof FieldType.StructType) ||
17351749
(field.type() instanceof FieldType.StringFieldType)) {
17361750
buffer.printf("_taggedFields.put(%d, %s);%n",
17371751
field.tag().get(), field.camelCaseName());
1752+
} else if (field.type().isStruct()) {
1753+
buffer.printf("_taggedFields.put(%d, %s.toStruct(_version));%n",
1754+
field.tag().get(), field.camelCaseName());
17381755
} else if (field.type().isBytes()) {
17391756
headerGenerator.addImport(MessageGenerator.BYTE_BUFFER_CLASS);
17401757
if (field.taggedVersions().intersect(field.nullableVersions()).empty()) {

0 commit comments

Comments
 (0)