2121
2222#include " ir/effects.h"
2323#include " ir/module-utils.h"
24+ #include " ir/subtypes.h"
2425#include " pass.h"
2526#include " support/strongly_connected_components.h"
2627#include " wasm.h"
@@ -39,6 +40,9 @@ struct FuncInfo {
3940
4041 // Directly-called functions from this function.
4142 std::unordered_set<Name> calledFunctions;
43+
44+ // Types that are targets of indirect calls.
45+ std::unordered_set<HeapType> indirectCalledTypes;
4246};
4347
4448std::map<Function*, FuncInfo> analyzeFuncs (Module& module ,
@@ -84,11 +88,19 @@ std::map<Function*, FuncInfo> analyzeFuncs(Module& module,
8488 // Note the direct call.
8589 funcInfo.calledFunctions .insert (call->target );
8690 } else if (effects.calls ) {
87- // This is an indirect call of some sort, so we must assume the
88- // worst. To do so, clear the effects, which indicates nothing
89- // is known (so anything is possible).
90- // TODO: We could group effects by function type etc.
91- funcInfo.effects = UnknownEffects;
91+ HeapType type;
92+ if (auto * callRef = curr->dynCast <CallRef>()) {
93+ type = callRef->target ->type .getHeapType ();
94+ } else if (auto * callIndirect = curr->dynCast <CallIndirect>()) {
95+ // nullability doesn't matter here
96+ // call_indirect is always inexact
97+ type = callIndirect->heapType ;
98+ } else {
99+ assert (false && " Unexpected type of call" );
100+ }
101+
102+ funcInfo.indirectCalledTypes .insert (type);
103+
92104 } else {
93105 // No call here, but update throwing if we see it. (Only do so,
94106 // however, if we have effects; if we cleared it - see before -
@@ -107,14 +119,31 @@ std::map<Function*, FuncInfo> analyzeFuncs(Module& module,
107119 return std::move (analysis.map );
108120}
109121
110- std::unordered_map<Function*, std::unordered_set<Function*>>
111- buildCallGraph (const Module& module ,
112- const std::map<Function*, FuncInfo>& funcInfos) {
113- std::unordered_map<Function*, std::unordered_set<Function*>> callGraph;
114- for (const auto & [func, info] : funcInfos) {
115- for (Name callee : info.calledFunctions ) {
116- callGraph[func].insert (module .getFunction (callee));
122+ using CallGraphNode = std::variant<Function*, HeapType>;
123+
124+ std::unordered_map<CallGraphNode, std::unordered_set<CallGraphNode>>
125+ buildCallGraph (Module& module , std::map<Function*, FuncInfo> funcInfos) {
126+ std::unordered_map<CallGraphNode, std::unordered_set<CallGraphNode>>
127+ callGraph;
128+ std::unordered_set<HeapType> allFunctionTypes;
129+ for (const auto & [caller, callerInfo] : funcInfos) {
130+ allFunctionTypes.insert (caller->type .getHeapType ());
131+ for (Name calleeFunction : callerInfo.calledFunctions ) {
132+ callGraph[caller].insert (module .getFunction (calleeFunction));
133+ }
134+ for (HeapType calleeType : callerInfo.indirectCalledTypes ) {
135+ callGraph[caller].insert (calleeType);
136+ allFunctionTypes.insert (calleeType);
117137 }
138+ callGraph[caller->type .getHeapType ()].insert (caller);
139+ }
140+
141+ SubTypes subtypes (module );
142+ for (HeapType type : allFunctionTypes) {
143+ subtypes.iterSubTypes (type, [&callGraph, type](HeapType sub, auto _) {
144+ callGraph[type].insert (sub);
145+ return true ;
146+ });
118147 }
119148
120149 return callGraph;
@@ -123,98 +152,127 @@ buildCallGraph(const Module& module,
123152// Propagate effects from callees to callers transitively
124153// e.g. if A -> B -> C (A calls B which calls C)
125154// Then B inherits effects from C and A inherits effects from both B and C.
126- //
127- // Generate SCC for the call graph, then traverse it in reverse topological
128- // order processing each callee before its callers. When traversing:
129- // - Merge all of the effects of functions within the CC
130- // - Also merge the (already computed) effects of each callee CC
131- // - Add trap effects for potentially recursive call chains
132155void propagateEffects (
133- const Module& module ,
156+ Module& module ,
134157 const PassOptions& passOptions,
135158 std::map<Function*, FuncInfo>& funcInfos,
136- const std::unordered_map<Function* , std::unordered_set<Function* >>
159+ std::unordered_map<CallGraphNode , std::unordered_set<CallGraphNode >>
137160 callGraph) {
161+ std::unordered_set<HeapType> allFunctionTypes;
162+ for (const auto & [caller, callerInfo] : funcInfos) {
163+ allFunctionTypes.insert (caller->type .getHeapType ());
164+ for (Name calleeFunction : callerInfo.calledFunctions ) {
165+ callGraph[caller].insert (module .getFunction (calleeFunction));
166+ }
167+ for (HeapType calleeType : callerInfo.indirectCalledTypes ) {
168+ callGraph[caller].insert (calleeType);
169+ allFunctionTypes.insert (calleeType);
170+ }
171+ callGraph[caller->type .getHeapType ()].insert (caller);
172+ }
173+
174+ SubTypes subtypes (module );
175+ for (HeapType type : allFunctionTypes) {
176+ subtypes.iterSubTypes (type, [&callGraph, type](HeapType sub, auto _) {
177+ callGraph[type].insert (sub);
178+ return true ;
179+ });
180+ }
181+
138182 struct CallGraphSCCs
139- : SCCs<std::vector<Function* >::const_iterator, CallGraphSCCs> {
183+ : SCCs<std::vector<CallGraphNode >::const_iterator, CallGraphSCCs> {
140184 const std::map<Function*, FuncInfo>& funcInfos;
141- const std::unordered_map<Function*, std::unordered_set<Function*>>&
142- callGraph;
143185 const Module& module ;
186+ const std::unordered_map<CallGraphNode, std::unordered_set<CallGraphNode>>&
187+ callGraph;
144188
145189 CallGraphSCCs (
146- const std::vector<Function* >& funcs ,
190+ const std::vector<CallGraphNode >& nodes ,
147191 const std::map<Function*, FuncInfo>& funcInfos,
148- const std::unordered_map<Function*, std::unordered_set<Function*>>&
149- callGraph ,
150- const Module& module )
151- : SCCs<std::vector<Function* >::const_iterator, CallGraphSCCs>(
152- funcs .begin(), funcs .end()),
153- funcInfos (funcInfos), callGraph(callGraph ), module ( module ) {}
154-
155- void pushChildren (Function* f ) {
156- auto callees = callGraph.find (f );
192+ Module& module ,
193+ const std::unordered_map<CallGraphNode ,
194+ std::unordered_set<CallGraphNode>>& callGraph )
195+ : SCCs<std::vector<CallGraphNode >::const_iterator, CallGraphSCCs>(
196+ nodes .begin(), nodes .end()),
197+ funcInfos (funcInfos), module ( module ), callGraph(callGraph ) {}
198+
199+ void pushChildren (CallGraphNode node ) {
200+ auto callees = callGraph.find (node );
157201 if (callees == callGraph.end ()) {
158202 return ;
159203 }
160-
161- for (auto * callee : callees->second ) {
204+ for (const auto & callee : callees->second ) {
162205 push (callee);
163206 }
164207 }
165208 };
166209
167- std::vector<Function*> allFuncs ;
210+ std::vector<CallGraphNode> funcs ;
168211 for (auto & [func, info] : funcInfos) {
169- allFuncs .push_back (func);
212+ funcs .push_back (func);
170213 }
171- CallGraphSCCs sccs (allFuncs, funcInfos, callGraph, module );
172214
173- std::unordered_map<Function*, int > sccMembers;
215+ CallGraphSCCs sccs (funcs, funcInfos, module , callGraph);
216+
217+ std::unordered_map<CallGraphNode, int > sccMembers;
174218 std::unordered_map<int , std::optional<EffectAnalyzer>> componentEffects;
175219
176220 int ccIndex = 0 ;
177221 for (auto ccIterator : sccs) {
222+ std::vector<CallGraphNode> cc (ccIterator.begin (), ccIterator.end ());
178223 ccIndex++;
179224 std::optional<EffectAnalyzer>& ccEffects = componentEffects[ccIndex];
180- std::vector<Function*> ccFuncs (ccIterator.begin (), ccIterator.end ());
181-
182225 ccEffects.emplace (passOptions, module );
183226
227+ std::vector<Function*> ccFuncs;
228+ std::vector<HeapType> ccTypes;
229+ for (auto v : cc) {
230+ if (auto ** func = std::get_if<Function*>(&v)) {
231+ ccFuncs.push_back (*func);
232+ } else {
233+ ccTypes.push_back (std::get<HeapType>(v));
234+ }
235+ }
236+
184237 for (Function* f : ccFuncs) {
185238 sccMembers.emplace (f, ccIndex);
186239 }
240+ for (HeapType t : ccTypes) {
241+ sccMembers.emplace (t, ccIndex);
242+ }
187243
188244 std::unordered_set<int > calleeSccs;
189- for (Function* caller : ccFuncs ) {
245+ for (const auto & caller : cc ) {
190246 auto callees = callGraph.find (caller);
191- if (callees == callGraph.end ()) {
192- continue ;
193- }
194- for (auto * callee : callees->second ) {
195- calleeSccs.insert (sccMembers.at (callee));
247+ if (callees != callGraph.end ()) {
248+ for (const auto & callee : callees->second ) {
249+ auto sccIt = sccMembers.find (callee);
250+ if (sccIt != sccMembers.end ()) {
251+ calleeSccs.insert (sccIt->second );
252+ }
253+ }
196254 }
197255 }
198256
199- // Merge in effects from callees
200257 for (int calleeScc : calleeSccs) {
201258 const auto & calleeComponentEffects = componentEffects.at (calleeScc);
202259 if (calleeComponentEffects == UnknownEffects) {
203260 ccEffects = UnknownEffects;
261+ // stop = true;
204262 break ;
205263 }
206264
207- else if (ccEffects != UnknownEffects ) {
265+ else if (ccEffects) {
208266 ccEffects->mergeIn (*calleeComponentEffects);
209267 }
210268 }
211269
212- // Add trap effects for potential cycles.
213- if (ccFuncs.size () > 1 ) {
270+ if (cc.size () > 1 ) {
214271 if (ccEffects != UnknownEffects) {
215272 ccEffects->trap = true ;
216273 }
217- } else {
274+ // A cycle isn't possible for a CC that only contains a type
275+ } else if (ccFuncs.size () == 1 ) {
218276 auto * func = ccFuncs[0 ];
219277 if (funcInfos.at (func).calledFunctions .contains (func->name )) {
220278 if (ccEffects != UnknownEffects) {
@@ -223,8 +281,7 @@ void propagateEffects(
223281 }
224282 }
225283
226- // Aggregate effects within this CC
227- if (ccEffects) {
284+ if (ccEffects)
228285 for (Function* f : ccFuncs) {
229286 const auto & effects = funcInfos.at (f).effects ;
230287 if (effects == UnknownEffects) {
@@ -234,9 +291,7 @@ void propagateEffects(
234291
235292 ccEffects->mergeIn (*effects);
236293 }
237- }
238294
239- // Assign each function's effects to its CC effects.
240295 for (Function* f : ccFuncs) {
241296 if (!ccEffects) {
242297 funcInfos.at (f).effects = UnknownEffects;
@@ -247,17 +302,6 @@ void propagateEffects(
247302 }
248303}
249304
250- void copyEffectsToFunctions (const std::map<Function*, FuncInfo> funcInfos) {
251- for (auto & [func, info] : funcInfos) {
252- func->effects .reset ();
253- if (!info.effects ) {
254- continue ;
255- }
256-
257- func->effects = std::make_shared<EffectAnalyzer>(*info.effects );
258- }
259- }
260-
261305struct GenerateGlobalEffects : public Pass {
262306 void run (Module* module ) override {
263307 std::map<Function*, FuncInfo> funcInfos =
@@ -267,7 +311,16 @@ struct GenerateGlobalEffects : public Pass {
267311
268312 propagateEffects (*module , getPassOptions (), funcInfos, callGraph);
269313
270- copyEffectsToFunctions (funcInfos);
314+ // Generate the final data, starting from a blank slate where nothing is
315+ // known.
316+ for (auto & [func, info] : funcInfos) {
317+ func->effects .reset ();
318+ if (!info.effects ) {
319+ continue ;
320+ }
321+
322+ func->effects = std::make_shared<EffectAnalyzer>(*info.effects );
323+ }
271324 }
272325};
273326
0 commit comments