Skip to content

Commit 48388f3

Browse files
authored
fix: pattern_match TVF crashes and parameter validation (#17471)
1 parent e6c4775 commit 48388f3

4 files changed

Lines changed: 52 additions & 29 deletions

File tree

integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/IoTDBWindowTVFIT.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -859,5 +859,27 @@ public void testPatternMatchFunction() {
859859
expectedHeader,
860860
retArray,
861861
DATABASE_NAME);
862+
863+
// test flat pattern with smooth=0.0 should not crash (was IndexOutOfBoundsException due to
864+
// NaN from 0/0)
865+
// pattern '1,1,1,1,1,2,3,4,3' is flat->up->down, no data segment matches this sign sequence
866+
retArray = new String[] {};
867+
tableResultSetEqualByDataTypeTest(
868+
"select * from pattern_match(data => t1 ORDER BY time, time_col => 'time', data_col => 'value', pattern => '1.0,1.0,1.0,1.0,1.0,2.0,3.0,4.0,3.0', smooth => 0.0, threshold => 1.0, smooth_on_pattern => false)",
869+
expectedHeader,
870+
retArray,
871+
DATABASE_NAME);
872+
873+
// test negative smooth should be rejected
874+
tableAssertTestFail(
875+
"select * from pattern_match(data => t1 ORDER BY time, time_col => 'time', data_col => 'value', pattern => '1.0,2.0,1.0', smooth => -0.5, threshold => 10.0, width => 1000.0, height => 500.0, smooth_on_pattern => false)",
876+
"smooth must be a non-negative number",
877+
DATABASE_NAME);
878+
879+
// test negative threshold should be rejected
880+
tableAssertTestFail(
881+
"select * from pattern_match(data => t1 ORDER BY time, time_col => 'time', data_col => 'value', pattern => '1.0,2.0,1.0', smooth => 0.5, threshold => -1.1, width => 1000.0, height => 500.0, smooth_on_pattern => false)",
882+
"threshold must be a non-negative number",
883+
DATABASE_NAME);
862884
}
863885
}

iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/PatternMatchTableFunction.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,15 @@ public TableFunctionAnalysis analyze(Map<String, Argument> arguments) throws UDF
105105
expectedDataName,
106106
ImmutableSet.of(Type.INT32, Type.INT64, Type.FLOAT, Type.DOUBLE));
107107

108+
Double smoothValue = (Double) ((ScalarArgument) arguments.get(SMOOTH_PARAM)).getValue();
109+
Double thresholdValue = (Double) ((ScalarArgument) arguments.get(THRESHOLD_PARAM)).getValue();
110+
if (smoothValue < 0) {
111+
throw new UDFException("smooth must be a non-negative number, but got: " + smoothValue);
112+
}
113+
if (thresholdValue < 0) {
114+
throw new UDFException("threshold must be a non-negative number, but got: " + thresholdValue);
115+
}
116+
108117
// outputColumnSchema description
109118
DescribedSchema properColumnSchema =
110119
new DescribedSchema.Builder()

iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/match/model/MatchState.java

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,9 @@ public boolean calcOneSectionMatchValue(Section section, double smoothValue, dou
165165
double localHeightUp = Math.max(section.getHeightBound(), smoothValue);
166166
double localHeightDown =
167167
Math.max(patternSectionNow.getHeightBound() * globalHeightRadio, smoothValue);
168-
double localHeightRadio = localHeightUp / localHeightDown;
168+
// When both are 0 (both sections are flat with smooth=0), ratio is 1.0 (perfect match)
169+
double localHeightRadio =
170+
(localHeightUp == 0 && localHeightDown == 0) ? 1.0 : localHeightUp / localHeightDown;
169171

170172
double led = Math.pow(Math.log(localWidthRadio), 2) + Math.pow(Math.log(localHeightRadio), 2);
171173

@@ -195,12 +197,10 @@ public boolean calcOneSectionMatchValue(Section section, double smoothValue, dou
195197
- section.getPoints().get(i).y);
196198
}
197199

198-
shapeError =
199-
shapeError
200-
/ (((dataMaxHeight - dataMinHeight) == 0
201-
? smoothValue
202-
: (dataMaxHeight - dataMinHeight))
203-
* (section.getPoints().size() - 1));
200+
double heightNorm =
201+
(dataMaxHeight - dataMinHeight) == 0 ? smoothValue : (dataMaxHeight - dataMinHeight);
202+
double seDenominator = heightNorm * (section.getPoints().size() - 1);
203+
shapeError = seDenominator == 0 ? 0 : shapeError / seDenominator;
204204

205205
// calc the match value for a section
206206
matchValue = matchValue + led + shapeError;
@@ -212,7 +212,7 @@ public boolean calcOneSectionMatchValue(Section section, double smoothValue, dou
212212
patternSectionNow = null;
213213
}
214214

215-
if (isFinish || matchValue > threshold) {
215+
if (isFinish || matchValue > threshold || Double.isNaN(matchValue)) {
216216
return true;
217217
} else {
218218
patternSectionNow = patternSectionNow.getNextSectionList().get(0);

iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/match/model/RegexMatchState.java

Lines changed: 13 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -320,19 +320,22 @@ public Boolean calcOneSectionMatchValue(
320320
double localHeightUp = Math.max(dataSection.getHeightBound(), smoothValue);
321321
double localHeightDown =
322322
Math.max(patternSection.getHeightBound() * globalHeightRadio, smoothValue);
323-
double localHeightRadio = localHeightUp / localHeightDown;
323+
// When both are 0 (both sections are flat with smooth=0), ratio is 1.0 (perfect match)
324+
double localHeightRadio =
325+
(localHeightUp == 0 && localHeightDown == 0) ? 1.0 : localHeightUp / localHeightDown;
324326

325327
double led = Math.pow(Math.log(localWidthRadio), 2) + Math.pow(Math.log(localHeightRadio), 2);
326328

327329
// different way
328330
double shapeError = 0.0;
331+
double heightNorm =
332+
(dataMaxHeight - dataMinHeight) == 0 ? smoothValue : (dataMaxHeight - dataMinHeight);
329333
if (CALC_SE_USING_MORE_MEMORY
330334
&& dataSection.getCalcResult().get(patternSection.getId()) != null) {
331335
shapeError =
332-
dataSection.getCalcResult().get(patternSection.getId())
333-
/ ((dataMaxHeight - dataMinHeight) == 0
334-
? smoothValue
335-
: (dataMaxHeight - dataMinHeight));
336+
heightNorm == 0
337+
? 0
338+
: dataSection.getCalcResult().get(patternSection.getId()) / heightNorm;
336339
} else {
337340
// calc the SE
338341
// align the first point or the centroid, it's same because the calculation is just an avg
@@ -348,7 +351,7 @@ public Boolean calcOneSectionMatchValue(
348351

349352
double numRadio = ((double) patternPointNum) / ((double) dataPointNum);
350353

351-
for (int i = 1; i < dataPointNum; i++) {
354+
for (int i = 1; i <= dataPointNum; i++) {
352355
double patternIndex = i * numRadio;
353356
int leftIndex = (int) patternIndex;
354357
double leftRadio = patternIndex - leftIndex;
@@ -366,27 +369,16 @@ public Boolean calcOneSectionMatchValue(
366369
- dataSection.getPoints().get(i).y);
367370
}
368371

369-
shapeError =
370-
shapeError
371-
/ (((dataMaxHeight - dataMinHeight) == 0
372-
? smoothValue
373-
: (dataMaxHeight - dataMinHeight))
374-
* (dataSection.getPoints().size() - 1));
372+
double seDenominator = heightNorm * (dataSection.getPoints().size() - 1);
373+
shapeError = seDenominator == 0 ? 0 : shapeError / seDenominator;
375374

376375
if (CALC_SE_USING_MORE_MEMORY) {
377-
dataSection
378-
.getCalcResult()
379-
.put(
380-
patternSection.getId(),
381-
shapeError
382-
* (((dataMaxHeight - dataMinHeight) == 0
383-
? smoothValue
384-
: (dataMaxHeight - dataMinHeight))));
376+
dataSection.getCalcResult().put(patternSection.getId(), shapeError * heightNorm);
385377
}
386378
}
387379

388380
matchValue = matchValue + led + shapeError;
389-
return matchValue > threshold;
381+
return matchValue > threshold || Double.isNaN(matchValue);
390382
}
391383

392384
public double getMatchValue() {

0 commit comments

Comments
 (0)