Skip to content

Commit 831e6bf

Browse files
Cherrypick SCC with indirect effects
1 parent 9bc667f commit 831e6bf

File tree

3 files changed

+462
-72
lines changed

3 files changed

+462
-72
lines changed

src/passes/GlobalEffects.cpp

Lines changed: 121 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
#include "ir/effects.h"
2323
#include "ir/module-utils.h"
24+
#include "ir/subtypes.h"
2425
#include "pass.h"
2526
#include "support/strongly_connected_components.h"
2627
#include "wasm.h"
@@ -39,6 +40,9 @@ struct FuncInfo {
3940

4041
// Directly-called functions from this function.
4142
std::unordered_set<Name> calledFunctions;
43+
44+
// Types that are targets of indirect calls.
45+
std::unordered_set<HeapType> indirectCalledTypes;
4246
};
4347

4448
std::map<Function*, FuncInfo> analyzeFuncs(Module& module,
@@ -84,11 +88,19 @@ std::map<Function*, FuncInfo> analyzeFuncs(Module& module,
8488
// Note the direct call.
8589
funcInfo.calledFunctions.insert(call->target);
8690
} else if (effects.calls) {
87-
// This is an indirect call of some sort, so we must assume the
88-
// worst. To do so, clear the effects, which indicates nothing
89-
// is known (so anything is possible).
90-
// TODO: We could group effects by function type etc.
91-
funcInfo.effects = UnknownEffects;
91+
HeapType type;
92+
if (auto* callRef = curr->dynCast<CallRef>()) {
93+
type = callRef->target->type.getHeapType();
94+
} else if (auto* callIndirect = curr->dynCast<CallIndirect>()) {
95+
// nullability doesn't matter here
96+
// call_indirect is always inexact
97+
type = callIndirect->heapType;
98+
} else {
99+
assert(false && "Unexpected type of call");
100+
}
101+
102+
funcInfo.indirectCalledTypes.insert(type);
103+
92104
} else {
93105
// No call here, but update throwing if we see it. (Only do so,
94106
// however, if we have effects; if we cleared it - see before -
@@ -107,14 +119,31 @@ std::map<Function*, FuncInfo> analyzeFuncs(Module& module,
107119
return std::move(analysis.map);
108120
}
109121

110-
std::unordered_map<Function*, std::unordered_set<Function*>>
111-
buildCallGraph(const Module& module,
112-
const std::map<Function*, FuncInfo>& funcInfos) {
113-
std::unordered_map<Function*, std::unordered_set<Function*>> callGraph;
114-
for (const auto& [func, info] : funcInfos) {
115-
for (Name callee : info.calledFunctions) {
116-
callGraph[func].insert(module.getFunction(callee));
122+
using CallGraphNode = std::variant<Function*, HeapType>;
123+
124+
std::unordered_map<CallGraphNode, std::unordered_set<CallGraphNode>>
125+
buildCallGraph(Module& module, std::map<Function*, FuncInfo> funcInfos) {
126+
std::unordered_map<CallGraphNode, std::unordered_set<CallGraphNode>>
127+
callGraph;
128+
std::unordered_set<HeapType> allFunctionTypes;
129+
for (const auto& [caller, callerInfo] : funcInfos) {
130+
allFunctionTypes.insert(caller->type.getHeapType());
131+
for (Name calleeFunction : callerInfo.calledFunctions) {
132+
callGraph[caller].insert(module.getFunction(calleeFunction));
133+
}
134+
for (HeapType calleeType : callerInfo.indirectCalledTypes) {
135+
callGraph[caller].insert(calleeType);
136+
allFunctionTypes.insert(calleeType);
117137
}
138+
callGraph[caller->type.getHeapType()].insert(caller);
139+
}
140+
141+
SubTypes subtypes(module);
142+
for (HeapType type : allFunctionTypes) {
143+
subtypes.iterSubTypes(type, [&callGraph, type](HeapType sub, auto _) {
144+
callGraph[type].insert(sub);
145+
return true;
146+
});
118147
}
119148

120149
return callGraph;
@@ -123,98 +152,127 @@ buildCallGraph(const Module& module,
123152
// Propagate effects from callees to callers transitively
124153
// e.g. if A -> B -> C (A calls B which calls C)
125154
// Then B inherits effects from C and A inherits effects from both B and C.
126-
//
127-
// Generate SCC for the call graph, then traverse it in reverse topological
128-
// order processing each callee before its callers. When traversing:
129-
// - Merge all of the effects of functions within the CC
130-
// - Also merge the (already computed) effects of each callee CC
131-
// - Add trap effects for potentially recursive call chains
132155
void propagateEffects(
133-
const Module& module,
156+
Module& module,
134157
const PassOptions& passOptions,
135158
std::map<Function*, FuncInfo>& funcInfos,
136-
const std::unordered_map<Function*, std::unordered_set<Function*>>
159+
std::unordered_map<CallGraphNode, std::unordered_set<CallGraphNode>>
137160
callGraph) {
161+
std::unordered_set<HeapType> allFunctionTypes;
162+
for (const auto& [caller, callerInfo] : funcInfos) {
163+
allFunctionTypes.insert(caller->type.getHeapType());
164+
for (Name calleeFunction : callerInfo.calledFunctions) {
165+
callGraph[caller].insert(module.getFunction(calleeFunction));
166+
}
167+
for (HeapType calleeType : callerInfo.indirectCalledTypes) {
168+
callGraph[caller].insert(calleeType);
169+
allFunctionTypes.insert(calleeType);
170+
}
171+
callGraph[caller->type.getHeapType()].insert(caller);
172+
}
173+
174+
SubTypes subtypes(module);
175+
for (HeapType type : allFunctionTypes) {
176+
subtypes.iterSubTypes(type, [&callGraph, type](HeapType sub, auto _) {
177+
callGraph[type].insert(sub);
178+
return true;
179+
});
180+
}
181+
138182
struct CallGraphSCCs
139-
: SCCs<std::vector<Function*>::const_iterator, CallGraphSCCs> {
183+
: SCCs<std::vector<CallGraphNode>::const_iterator, CallGraphSCCs> {
140184
const std::map<Function*, FuncInfo>& funcInfos;
141-
const std::unordered_map<Function*, std::unordered_set<Function*>>&
142-
callGraph;
143185
const Module& module;
186+
const std::unordered_map<CallGraphNode, std::unordered_set<CallGraphNode>>&
187+
callGraph;
144188

145189
CallGraphSCCs(
146-
const std::vector<Function*>& funcs,
190+
const std::vector<CallGraphNode>& nodes,
147191
const std::map<Function*, FuncInfo>& funcInfos,
148-
const std::unordered_map<Function*, std::unordered_set<Function*>>&
149-
callGraph,
150-
const Module& module)
151-
: SCCs<std::vector<Function*>::const_iterator, CallGraphSCCs>(
152-
funcs.begin(), funcs.end()),
153-
funcInfos(funcInfos), callGraph(callGraph), module(module) {}
154-
155-
void pushChildren(Function* f) {
156-
auto callees = callGraph.find(f);
192+
Module& module,
193+
const std::unordered_map<CallGraphNode,
194+
std::unordered_set<CallGraphNode>>& callGraph)
195+
: SCCs<std::vector<CallGraphNode>::const_iterator, CallGraphSCCs>(
196+
nodes.begin(), nodes.end()),
197+
funcInfos(funcInfos), module(module), callGraph(callGraph) {}
198+
199+
void pushChildren(CallGraphNode node) {
200+
auto callees = callGraph.find(node);
157201
if (callees == callGraph.end()) {
158202
return;
159203
}
160-
161-
for (auto* callee : callees->second) {
204+
for (const auto& callee : callees->second) {
162205
push(callee);
163206
}
164207
}
165208
};
166209

167-
std::vector<Function*> allFuncs;
210+
std::vector<CallGraphNode> funcs;
168211
for (auto& [func, info] : funcInfos) {
169-
allFuncs.push_back(func);
212+
funcs.push_back(func);
170213
}
171-
CallGraphSCCs sccs(allFuncs, funcInfos, callGraph, module);
172214

173-
std::unordered_map<Function*, int> sccMembers;
215+
CallGraphSCCs sccs(funcs, funcInfos, module, callGraph);
216+
217+
std::unordered_map<CallGraphNode, int> sccMembers;
174218
std::unordered_map<int, std::optional<EffectAnalyzer>> componentEffects;
175219

176220
int ccIndex = 0;
177221
for (auto ccIterator : sccs) {
222+
std::vector<CallGraphNode> cc(ccIterator.begin(), ccIterator.end());
178223
ccIndex++;
179224
std::optional<EffectAnalyzer>& ccEffects = componentEffects[ccIndex];
180-
std::vector<Function*> ccFuncs(ccIterator.begin(), ccIterator.end());
181-
182225
ccEffects.emplace(passOptions, module);
183226

227+
std::vector<Function*> ccFuncs;
228+
std::vector<HeapType> ccTypes;
229+
for (auto v : cc) {
230+
if (auto** func = std::get_if<Function*>(&v)) {
231+
ccFuncs.push_back(*func);
232+
} else {
233+
ccTypes.push_back(std::get<HeapType>(v));
234+
}
235+
}
236+
184237
for (Function* f : ccFuncs) {
185238
sccMembers.emplace(f, ccIndex);
186239
}
240+
for (HeapType t : ccTypes) {
241+
sccMembers.emplace(t, ccIndex);
242+
}
187243

188244
std::unordered_set<int> calleeSccs;
189-
for (Function* caller : ccFuncs) {
245+
for (const auto& caller : cc) {
190246
auto callees = callGraph.find(caller);
191-
if (callees == callGraph.end()) {
192-
continue;
193-
}
194-
for (auto* callee : callees->second) {
195-
calleeSccs.insert(sccMembers.at(callee));
247+
if (callees != callGraph.end()) {
248+
for (const auto& callee : callees->second) {
249+
auto sccIt = sccMembers.find(callee);
250+
if (sccIt != sccMembers.end()) {
251+
calleeSccs.insert(sccIt->second);
252+
}
253+
}
196254
}
197255
}
198256

199-
// Merge in effects from callees
200257
for (int calleeScc : calleeSccs) {
201258
const auto& calleeComponentEffects = componentEffects.at(calleeScc);
202259
if (calleeComponentEffects == UnknownEffects) {
203260
ccEffects = UnknownEffects;
261+
// stop = true;
204262
break;
205263
}
206264

207-
else if (ccEffects != UnknownEffects) {
265+
else if (ccEffects) {
208266
ccEffects->mergeIn(*calleeComponentEffects);
209267
}
210268
}
211269

212-
// Add trap effects for potential cycles.
213-
if (ccFuncs.size() > 1) {
270+
if (cc.size() > 1) {
214271
if (ccEffects != UnknownEffects) {
215272
ccEffects->trap = true;
216273
}
217-
} else {
274+
// A cycle isn't possible for a CC that only contains a type
275+
} else if (ccFuncs.size() == 1) {
218276
auto* func = ccFuncs[0];
219277
if (funcInfos.at(func).calledFunctions.contains(func->name)) {
220278
if (ccEffects != UnknownEffects) {
@@ -223,8 +281,7 @@ void propagateEffects(
223281
}
224282
}
225283

226-
// Aggregate effects within this CC
227-
if (ccEffects) {
284+
if (ccEffects)
228285
for (Function* f : ccFuncs) {
229286
const auto& effects = funcInfos.at(f).effects;
230287
if (effects == UnknownEffects) {
@@ -234,9 +291,7 @@ void propagateEffects(
234291

235292
ccEffects->mergeIn(*effects);
236293
}
237-
}
238294

239-
// Assign each function's effects to its CC effects.
240295
for (Function* f : ccFuncs) {
241296
if (!ccEffects) {
242297
funcInfos.at(f).effects = UnknownEffects;
@@ -247,17 +302,6 @@ void propagateEffects(
247302
}
248303
}
249304

250-
void copyEffectsToFunctions(const std::map<Function*, FuncInfo> funcInfos) {
251-
for (auto& [func, info] : funcInfos) {
252-
func->effects.reset();
253-
if (!info.effects) {
254-
continue;
255-
}
256-
257-
func->effects = std::make_shared<EffectAnalyzer>(*info.effects);
258-
}
259-
}
260-
261305
struct GenerateGlobalEffects : public Pass {
262306
void run(Module* module) override {
263307
std::map<Function*, FuncInfo> funcInfos =
@@ -267,7 +311,16 @@ struct GenerateGlobalEffects : public Pass {
267311

268312
propagateEffects(*module, getPassOptions(), funcInfos, callGraph);
269313

270-
copyEffectsToFunctions(funcInfos);
314+
// Generate the final data, starting from a blank slate where nothing is
315+
// known.
316+
for (auto& [func, info] : funcInfos) {
317+
func->effects.reset();
318+
if (!info.effects) {
319+
continue;
320+
}
321+
322+
func->effects = std::make_shared<EffectAnalyzer>(*info.effects);
323+
}
271324
}
272325
};
273326

src/wasm-type.h

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,8 @@ class Type {
353353
assert(!heapType.isBasic() || exact == Inexact);
354354
}
355355

356+
// Type& operator=(const Type&) = default;
357+
356358
// Predicates
357359
// Compound Concrete
358360
// Type Basic │ Single│
@@ -581,14 +583,18 @@ class Type {
581583
};
582584

583585
Iterator begin() const { return Iterator{{this, 0}}; }
584-
Iterator end() const { return Iterator{{this, size()}}; }
586+
Iterator end() const {
587+
return Iterator{{this, size()}};
588+
}
585589
std::reverse_iterator<Iterator> rbegin() const {
586590
return std::make_reverse_iterator(end());
587591
}
588592
std::reverse_iterator<Iterator> rend() const {
589593
return std::make_reverse_iterator(begin());
590594
}
591-
const Type& operator[](size_t i) const { return *Iterator{{this, i}}; }
595+
const Type& operator[](size_t i) const {
596+
return *Iterator{{this, i}};
597+
}
592598
};
593599

594600
Type Type::asWrittenGivenFeatures(FeatureSet feats) const {
@@ -669,8 +675,12 @@ class RecGroup {
669675
};
670676

671677
Iterator begin() const { return Iterator{{this, 0}}; }
672-
Iterator end() const { return Iterator{{this, size()}}; }
673-
HeapType operator[](size_t i) const { return *Iterator{{this, i}}; }
678+
Iterator end() const {
679+
return Iterator{{this, size()}};
680+
}
681+
HeapType operator[](size_t i) const {
682+
return *Iterator{{this, i}};
683+
}
674684
};
675685

676686
struct Signature {

0 commit comments

Comments
 (0)