Skip to content

Commit 8bf1640

Browse files
authored
[To dev/1.3] Improve state type validation in CombineRequest deserialization. (#17449) (#17502)
* Improve state type validation in CombineRequest deserialization. (#17449) * Improve state type validation in CombineRequest deserialization. Tighten request deserialization by validating state type before instantiation to reduce unexpected type usage risk, and add targeted tests for accepted and rejected state class names. Made-with: Cursor * spotless (cherry picked from commit c1d16a4) * spotless
1 parent 751c5c5 commit 8bf1640

2 files changed

Lines changed: 67 additions & 1 deletion

File tree

iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/processor/twostage/exchange/payload/CombineRequest.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
package org.apache.iotdb.db.pipe.processor.twostage.exchange.payload;
2121

2222
import org.apache.iotdb.commons.pipe.sink.payload.thrift.request.IoTDBSinkRequestVersion;
23+
import org.apache.iotdb.db.pipe.processor.twostage.state.CountState;
2324
import org.apache.iotdb.db.pipe.processor.twostage.state.State;
2425
import org.apache.iotdb.service.rpc.thrift.TPipeTransferReq;
2526

@@ -109,7 +110,7 @@ private CombineRequest translateFromTPipeTransferReq(TPipeTransferReq transferRe
109110
combineId = ReadWriteIOUtils.readString(transferReq.body);
110111

111112
final String stateClassName = ReadWriteIOUtils.readString(transferReq.body);
112-
state = (State) Class.forName(stateClassName).newInstance();
113+
state = instantiateState(stateClassName);
113114
state.deserialize(transferReq.body);
114115

115116
version = transferReq.version;
@@ -118,6 +119,13 @@ private CombineRequest translateFromTPipeTransferReq(TPipeTransferReq transferRe
118119
return this;
119120
}
120121

122+
private State instantiateState(final String stateClassName) throws Exception {
123+
if (CountState.class.getName().equals(stateClassName)) {
124+
return new CountState();
125+
}
126+
throw new IllegalArgumentException("Unexpected state class: " + stateClassName);
127+
}
128+
121129
@Override
122130
public String toString() {
123131
return "CombineRequest{"

iotdb-core/datanode/src/test/java/org/apache/iotdb/db/pipe/sink/PipeDataNodeThriftRequestTest.java

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121

2222
import org.apache.iotdb.commons.path.PartialPath;
2323
import org.apache.iotdb.commons.pipe.sink.payload.thrift.response.PipeTransferFilePieceResp;
24+
import org.apache.iotdb.db.pipe.processor.twostage.exchange.payload.CombineRequest;
25+
import org.apache.iotdb.db.pipe.processor.twostage.state.CountState;
2426
import org.apache.iotdb.db.pipe.sink.payload.evolvable.request.PipeTransferDataNodeHandshakeV1Req;
2527
import org.apache.iotdb.db.pipe.sink.payload.evolvable.request.PipeTransferPlanNodeReq;
2628
import org.apache.iotdb.db.pipe.sink.payload.evolvable.request.PipeTransferSchemaSnapshotPieceReq;
@@ -37,6 +39,7 @@
3739
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.write.InsertRowNode;
3840
import org.apache.iotdb.db.queryengine.plan.statement.Statement;
3941
import org.apache.iotdb.rpc.RpcUtils;
42+
import org.apache.iotdb.service.rpc.thrift.TPipeTransferReq;
4043

4144
import org.apache.tsfile.common.conf.TSFileConfig;
4245
import org.apache.tsfile.enums.TSDataType;
@@ -62,6 +65,61 @@ public class PipeDataNodeThriftRequestTest {
6265

6366
private static final String TIME_PRECISION = "ms";
6467

68+
@Test
69+
public void testCombineRequest() throws Exception {
70+
final CombineRequest req =
71+
CombineRequest.toTPipeTransferReq("pipe", 1L, 2, "combine", new CountState(123L));
72+
final CombineRequest deserializeReq = CombineRequest.fromTPipeTransferReq(req);
73+
74+
Assert.assertEquals(req.getVersion(), deserializeReq.getVersion());
75+
Assert.assertEquals(req.getType(), deserializeReq.getType());
76+
Assert.assertEquals("pipe", deserializeReq.getPipeName());
77+
Assert.assertEquals(1L, deserializeReq.getCreationTime());
78+
Assert.assertEquals(2, deserializeReq.getRegionId());
79+
Assert.assertEquals("combine", deserializeReq.getCombineId());
80+
Assert.assertTrue(deserializeReq.getState() instanceof CountState);
81+
Assert.assertEquals(123L, ((CountState) deserializeReq.getState()).getCount());
82+
}
83+
84+
@Test
85+
public void testCombineRequestWithUnexpectedStateClassName() throws Exception {
86+
final CombineRequest req =
87+
CombineRequest.toTPipeTransferReq("pipe", 1L, 2, "combine", new CountState(123L));
88+
89+
final ByteBuffer bodyBuffer = req.body.duplicate();
90+
final String pipeName = ReadWriteIOUtils.readString(bodyBuffer);
91+
final long creationTime = ReadWriteIOUtils.readLong(bodyBuffer);
92+
final int regionId = ReadWriteIOUtils.readInt(bodyBuffer);
93+
final String combineId = ReadWriteIOUtils.readString(bodyBuffer);
94+
ReadWriteIOUtils.readString(bodyBuffer);
95+
final long count = ReadWriteIOUtils.readLong(bodyBuffer);
96+
97+
final ByteBuffer tamperedBody;
98+
try (final PublicBAOS byteArrayOutputStream = new PublicBAOS();
99+
final DataOutputStream outputStream = new DataOutputStream(byteArrayOutputStream)) {
100+
ReadWriteIOUtils.write(pipeName, outputStream);
101+
ReadWriteIOUtils.write(creationTime, outputStream);
102+
ReadWriteIOUtils.write(regionId, outputStream);
103+
ReadWriteIOUtils.write(combineId, outputStream);
104+
ReadWriteIOUtils.write("java.lang.String", outputStream);
105+
ReadWriteIOUtils.write(count, outputStream);
106+
tamperedBody =
107+
ByteBuffer.wrap(byteArrayOutputStream.getBuf(), 0, byteArrayOutputStream.size());
108+
}
109+
110+
final TPipeTransferReq tamperedReq = new TPipeTransferReq();
111+
tamperedReq.version = req.version;
112+
tamperedReq.type = req.type;
113+
tamperedReq.body = tamperedBody;
114+
115+
try {
116+
CombineRequest.fromTPipeTransferReq(tamperedReq);
117+
Assert.fail("Expected IllegalArgumentException");
118+
} catch (final IllegalArgumentException e) {
119+
Assert.assertTrue(e.getMessage().contains("Unexpected state class"));
120+
}
121+
}
122+
65123
@Test
66124
public void testPipeTransferDataNodeHandshakeReq() throws IOException {
67125
final PipeTransferDataNodeHandshakeV1Req req =

0 commit comments

Comments
 (0)