Skip to content

Commit 900a03f

Browse files
Improve SimplifyLocals compile-time performance
1 parent 2fa35d6 commit 900a03f

File tree

1 file changed

+179
-36
lines changed

1 file changed

+179
-36
lines changed

src/passes/SimplifyLocals.cpp

Lines changed: 179 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -86,11 +86,96 @@ struct SimplifyLocals
8686
};
8787

8888
// a list of sinkables in a linear execution trace
89-
using Sinkables = std::map<Index, SinkableInfo>;
89+
using Sinkables = std::unordered_map<Index, SinkableInfo>;
9090

9191
// locals in current linear execution trace, which we try to sink
9292
Sinkables sinkables;
9393

94+
// Reverse index: for each local L, tracks which sinkable keys have effects
95+
// that read L. This allows checkInvalidations to find conflicting sinkables
96+
// in O(|current effects|) instead of O(|all sinkables|).
97+
std::unordered_map<Index, std::unordered_set<Index>> localReadBySinkable_;
98+
99+
// Reverse index: for each local L, tracks which sinkable keys have effects
100+
// that write L. A sinkable at key K always writes K, but may also write
101+
// other locals if its value contains nested local.sets.
102+
std::unordered_map<Index, std::unordered_set<Index>> localWrittenBySinkable_;
103+
104+
// Sinkable keys that have non-local ordering-relevant effects (calls,
105+
// memory, control flow, etc.). These need full orderedAfter checks when
106+
// the current expression also has non-local effects.
107+
std::unordered_set<Index> heavySinkables_;
108+
109+
// Adds the sinkable stored under |key| to the lookup structures used by
// checkInvalidations: heavySinkables_ when its effects report non-local
// ordering effects, and the per-local reverse indices for every local its
// effects read or write.
//
// The sinkable must already be present in |sinkables|; at() throws otherwise.
void registerSinkable(Index key) {
  auto& sinkableEffects = sinkables.at(key).effects;
  if (sinkableEffects.hasNonLocalOrderingEffects()) {
    heavySinkables_.insert(key);
  }
  for (auto readLocal : sinkableEffects.localsRead) {
    localReadBySinkable_[readLocal].insert(key);
  }
  for (auto writtenLocal : sinkableEffects.localsWritten) {
    localWrittenBySinkable_[writtenLocal].insert(key);
  }
}
121+
122+
// Removes |key| from the reverse indices and from heavySinkables_, using the
// effects of the sinkable currently stored under |key|. Does nothing at all
// when |key| is not tracked in |sinkables|. The entry in |sinkables| itself
// is left in place; callers erase it afterwards.
void unregisterSinkable(Index key) {
  auto tracked = sinkables.find(key);
  if (tracked == sinkables.end()) {
    return;
  }

  // Drops |key| from the bucket of |local| in the given reverse index, and
  // removes the bucket once it becomes empty so the index does not accumulate
  // dead entries.
  auto forget = [key](auto& reverseIndex, auto local) {
    auto bucket = reverseIndex.find(local);
    if (bucket == reverseIndex.end()) {
      return;
    }
    bucket->second.erase(key);
    if (bucket->second.empty()) {
      reverseIndex.erase(bucket);
    }
  };

  auto& staleEffects = tracked->second.effects;
  for (auto readLocal : staleEffects.localsRead) {
    forget(localReadBySinkable_, readLocal);
  }
  for (auto writtenLocal : staleEffects.localsWritten) {
    forget(localWrittenBySinkable_, writtenLocal);
  }
  heavySinkables_.erase(key);
}
148+
149+
// Forgets every tracked sinkable together with all of the lookup structures
// derived from them, so the reverse indices never outlive the entries they
// describe.
void clearSinkables() {
  heavySinkables_.clear();
  localWrittenBySinkable_.clear();
  localReadBySinkable_.clear();
  sinkables.clear();
}
155+
156+
// Hands the current sinkables over to the caller and resets all tracking.
//
// Returns the map that was in |sinkables|. Afterwards |sinkables| and every
// lookup structure derived from it (the reverse indices and heavySinkables_)
// are empty.
Sinkables takeSinkables() {
  Sinkables taken = std::move(sinkables);
  // A moved-from container is only guaranteed to be in a valid but
  // unspecified state. The reverse indices are emptied below, so make sure
  // the map itself is empty as well rather than relying on the library
  // implementation; otherwise the two could disagree.
  sinkables.clear();
  localReadBySinkable_.clear();
  localWrittenBySinkable_.clear();
  heavySinkables_.clear();
  return taken;
}
162+
163+
// Removes the sinkable that |it| points at. |it| must be a valid,
// dereferenceable iterator into |sinkables|. The lookup structures are
// updated first, while the entry (and therefore its effects) still exists.
void eraseSinkable(typename Sinkables::iterator it) {
  const Index doomedKey = it->first;
  unregisterSinkable(doomedKey);
  sinkables.erase(it);
}
167+
168+
// Removes the sinkable stored under |key|, if there is one, keeping the
// lookup structures in sync. An untracked |key| is ignored.
void eraseSinkable(Index key) {
  auto tracked = sinkables.find(key);
  if (tracked == sinkables.end()) {
    return;
  }
  // Unregister while the entry is still present: unregistering reads the
  // entry's effects. It does not modify |sinkables|, so |tracked| stays valid.
  unregisterSinkable(key);
  sinkables.erase(tracked);
}
172+
173+
// Starts tracking the expression at |currp| as the sinkable for local |key|:
// a SinkableInfo is built for it and stored in |sinkables|, and the entry is
// then registered in the lookup structures.
void addSinkable(Index key, Expression** currp) {
  sinkables.emplace(
    key, SinkableInfo(currp, this->getPassOptions(), *this->getModule()));
  registerSinkable(key);
}
178+
94179
// Information about an exit from a block: the break, and the
95180
// sinkables. For the final exit from a block (falling off)
96181
// exitter is null.
@@ -135,8 +220,7 @@ struct SimplifyLocals
135220
// value means the block already has a return value
136221
self->unoptimizableBlocks.insert(br->name);
137222
} else {
138-
self->blockBreaks[br->name].push_back(
139-
{currp, std::move(self->sinkables)});
223+
self->blockBreaks[br->name].push_back({currp, self->takeSinkables()});
140224
}
141225
} else if (curr->is<Block>()) {
142226
return; // handled in visitBlock
@@ -153,15 +237,15 @@ struct SimplifyLocals
153237
}
154238
// TODO: we could use this info to stop gathering data on these blocks
155239
}
156-
self->sinkables.clear();
240+
self->clearSinkables();
157241
}
158242

159243
static void doNoteIfCondition(
160244
SimplifyLocals<allowTee, allowStructure, allowNesting>* self,
161245
Expression** currp) {
162246
// we processed the condition of this if-else, and now control flow branches
163247
// into either the true or the false sides
164-
self->sinkables.clear();
248+
self->clearSinkables();
165249
}
166250

167251
static void
@@ -170,13 +254,13 @@ struct SimplifyLocals
170254
auto* iff = (*currp)->cast<If>();
171255
if (iff->ifFalse) {
172256
// We processed the ifTrue side of this if-else, save it on the stack.
173-
self->ifStack.push_back(std::move(self->sinkables));
257+
self->ifStack.push_back(self->takeSinkables());
174258
} else {
175259
// This is an if without an else.
176260
if (allowStructure) {
177261
self->optimizeIfReturn(iff, currp);
178262
}
179-
self->sinkables.clear();
263+
self->clearSinkables();
180264
}
181265
}
182266

@@ -191,7 +275,7 @@ struct SimplifyLocals
191275
self->optimizeIfElseReturn(iff, currp, self->ifStack.back());
192276
}
193277
self->ifStack.pop_back();
194-
self->sinkables.clear();
278+
self->clearSinkables();
195279
}
196280

197281
void visitBlock(Block* curr) {
@@ -204,13 +288,13 @@ struct SimplifyLocals
204288
// post-block cleanups
205289
if (curr->name.is()) {
206290
if (unoptimizableBlocks.contains(curr->name)) {
207-
sinkables.clear();
291+
clearSinkables();
208292
unoptimizableBlocks.erase(curr->name);
209293
}
210294

211295
if (hasBreaks) {
212296
// more than one path to here, so nonlinear
213-
sinkables.clear();
297+
clearSinkables();
214298
blockBreaks.erase(curr->name);
215299
}
216300
}
@@ -284,7 +368,7 @@ struct SimplifyLocals
284368
// reuse the local.get that is dying
285369
*found->second.item = curr;
286370
ExpressionManipulator::nop(curr);
287-
sinkables.erase(found);
371+
eraseSinkable(found);
288372
anotherCycle = true;
289373
}
290374
}
@@ -300,15 +384,56 @@ struct SimplifyLocals
300384
}
301385

302386
// Stops tracking every sinkable for which effects.orderedAfter(sinkable
// effects) holds, i.e. the ones that may not be moved past an expression
// with the given |effects|.
//
// Instead of testing every sinkable, a set of candidate keys that *might*
// conflict is gathered from the reverse indices first, and only those are
// verified with the full orderedAfter check.
void checkInvalidations(EffectAnalyzer& effects) {
  if (sinkables.empty()) {
    // Nothing is tracked, so nothing can be invalidated.
    return;
  }

  std::unordered_set<Index> candidates;

  if (effects.transfersControlFlow()) {
    // A transfer of control flow can conflict with anything that has side
    // effects, and every sinkable writes a local, so all of them are
    // candidates. The targeted lookups below could only add keys that are
    // already included, so they are skipped.
    for (auto& [key, _] : sinkables) {
      candidates.insert(key);
    }
  } else {
    // Adds all sinkable keys recorded for |local| in the given reverse index.
    auto addBucket = [&candidates](const auto& reverseIndex, auto local) {
      auto bucket = reverseIndex.find(local);
      if (bucket != reverseIndex.end()) {
        candidates.insert(bucket->second.begin(), bucket->second.end());
      }
    };

    // The current expression reads local L: any sinkable that writes L has a
    // write-read conflict.
    for (auto L : effects.localsRead) {
      addBucket(localWrittenBySinkable_, L);
    }
    // The current expression writes local L: any sinkable that reads L
    // (read-write conflict) or writes L (write-write conflict) is a
    // candidate.
    for (auto L : effects.localsWritten) {
      addBucket(localReadBySinkable_, L);
      addBucket(localWrittenBySinkable_, L);
    }
    // Non-local conflicts: only sinkables that themselves have non-local
    // ordering effects can clash with the current expression's.
    if (effects.hasNonLocalOrderingEffects()) {
      candidates.insert(heavySinkables_.begin(), heavySinkables_.end());
    }
  }

  // Verify each candidate with the full ordering check. |candidates| is a
  // separate container, so erasing from |sinkables| while walking it is safe.
  for (auto key : candidates) {
    auto it = sinkables.find(key);
    if (it != sinkables.end() && effects.orderedAfter(it->second.effects)) {
      eraseSinkable(key);
    }
  }
}
314439

@@ -334,7 +459,7 @@ struct SimplifyLocals
334459
}
335460
}
336461
for (auto index : invalidated) {
337-
self->sinkables.erase(index);
462+
self->eraseSinkable(index);
338463
}
339464
}
340465

@@ -419,7 +544,7 @@ struct SimplifyLocals
419544
Drop* drop = ExpressionManipulator::convert<LocalSet, Drop>(previous);
420545
drop->value = previousValue;
421546
drop->finalize();
422-
self->sinkables.erase(found);
547+
self->eraseSinkable(found);
423548
self->anotherCycle = true;
424549
}
425550
}
@@ -432,9 +557,7 @@ struct SimplifyLocals
432557
if (set && self->canSink(set)) {
433558
Index index = set->index;
434559
assert(!self->sinkables.contains(index));
435-
self->sinkables.emplace(std::pair{
436-
index,
437-
SinkableInfo(currp, self->getPassOptions(), *self->getModule())});
560+
self->addSinkable(index, currp);
438561
}
439562

440563
if (!allowNesting) {
@@ -476,7 +599,13 @@ struct SimplifyLocals
476599
if (sinkables.empty()) {
477600
return;
478601
}
479-
Index goodIndex = sinkables.begin()->first;
602+
// Pick the lowest-index sinkable for deterministic output.
603+
Index goodIndex = std::min_element(sinkables.begin(),
604+
sinkables.end(),
605+
[](const auto& a, const auto& b) {
606+
return a.first < b.first;
607+
})
608+
->first;
480609
// Ensure we have a place to write the return values for, if not, we
481610
// need another cycle.
482611
auto* block = loop->body->dynCast<Block>();
@@ -498,7 +627,7 @@ struct SimplifyLocals
498627
this->replaceCurrent(set);
499628
// We moved things around, clear all tracking; we'll do another cycle
500629
// anyhow.
501-
sinkables.clear();
630+
clearSinkables();
502631
anotherCycle = true;
503632
}
504633

@@ -515,7 +644,8 @@ struct SimplifyLocals
515644
// block does not already have a return value (if one break has one, they
516645
// all do)
517646
assert(!(*breaks[0].brp)->template cast<Break>()->value);
518-
// look for a local.set that is present in them all
647+
// look for a local.set that is present in them all.
648+
// Pick the lowest index for deterministic output.
519649
bool found = false;
520650
Index sharedIndex = -1;
521651
for (auto& [index, _] : sinkables) {
@@ -526,10 +656,9 @@ struct SimplifyLocals
526656
break;
527657
}
528658
}
529-
if (inAll) {
659+
if (inAll && (!found || index < sharedIndex)) {
530660
sharedIndex = index;
531661
found = true;
532-
break;
533662
}
534663
}
535664
if (!found) {
@@ -624,7 +753,7 @@ struct SimplifyLocals
624753
auto* newLocalSet =
625754
Builder(*this->getModule()).makeLocalSet(sharedIndex, block);
626755
this->replaceCurrent(newLocalSet);
627-
sinkables.clear();
756+
clearSinkables();
628757
anotherCycle = true;
629758
block->finalize();
630759
}
@@ -656,27 +785,35 @@ struct SimplifyLocals
656785
Sinkables& ifFalse = sinkables;
657786
Index goodIndex = -1;
658787
bool found = false;
788+
auto pickLowest = [](Sinkables& s) {
789+
return std::min_element(
790+
s.begin(),
791+
s.end(),
792+
[](const auto& a, const auto& b) { return a.first < b.first; })
793+
->first;
794+
};
659795
if (iff->ifTrue->type == Type::unreachable) {
660796
// since the if type is none
661797
assert(iff->ifFalse->type != Type::unreachable);
662798
if (!ifFalse.empty()) {
663-
goodIndex = ifFalse.begin()->first;
799+
goodIndex = pickLowest(ifFalse);
664800
found = true;
665801
}
666802
} else if (iff->ifFalse->type == Type::unreachable) {
667803
// since the if type is none
668804
assert(iff->ifTrue->type != Type::unreachable);
669805
if (!ifTrue.empty()) {
670-
goodIndex = ifTrue.begin()->first;
806+
goodIndex = pickLowest(ifTrue);
671807
found = true;
672808
}
673809
} else {
674-
// Look for a shared index.
810+
// Look for a shared index (pick the lowest for determinism).
675811
for (auto& [index, _] : ifTrue) {
676812
if (ifFalse.contains(index)) {
677-
goodIndex = index;
678-
found = true;
679-
break;
813+
if (!found || index < goodIndex) {
814+
goodIndex = index;
815+
found = true;
816+
}
680817
}
681818
}
682819
}
@@ -799,7 +936,13 @@ struct SimplifyLocals
799936
// element).
800937
//
801938
// TODO investigate more
802-
Index goodIndex = sinkables.begin()->first;
939+
// Pick the lowest-index sinkable for deterministic output.
940+
Index goodIndex = std::min_element(sinkables.begin(),
941+
sinkables.end(),
942+
[](const auto& a, const auto& b) {
943+
return a.first < b.first;
944+
})
945+
->first;
803946
auto localType = this->getFunction()->getLocalType(goodIndex);
804947
if (!localType.isDefaultable()) {
805948
return;
@@ -973,7 +1116,7 @@ struct SimplifyLocals
9731116
anotherCycle = true;
9741117
}
9751118
// clean up
976-
sinkables.clear();
1119+
clearSinkables();
9771120
blockBreaks.clear();
9781121
unoptimizableBlocks.clear();
9791122
return anotherCycle;

0 commit comments

Comments
 (0)