diff --git a/test/testpaint/generate.cpp b/test/testpaint/generate.cpp index 655241c0ba..496272aadc 100644 --- a/test/testpaint/generate.cpp +++ b/test/testpaint/generate.cpp @@ -317,39 +317,44 @@ private: void GenerateCalls(int tabs, std::vector calls[4], int height) { + std::vector commonCalls = TrimCommonCallsEnd(calls); + int totalCalls = 0; for (int direction = 0; direction < 4; direction++) { totalCalls += calls[direction].size(); } - if (totalCalls == 0) + if (totalCalls != 0) { - return; - } - - WriteLine(tabs, "switch (direction) {"); - for (int direction = 0; direction < 4; direction++) - { - if (calls[direction].size() == 0) continue; - - WriteLine(tabs, "case %d:", direction); - for (int d2 = direction + 1; d2 < 4; d2++) + WriteLine(tabs, "switch (direction) {"); + for (int direction = 0; direction < 4; direction++) { - if (CompareFunctionCalls(calls[direction], calls[d2])) + if (calls[direction].size() == 0) continue; + + WriteLine(tabs, "case %d:", direction); + for (int d2 = direction + 1; d2 < 4; d2++) { - // Clear identical other direction calls and add case for it - calls[d2].clear(); - WriteLine(tabs, "case %d:", d2); + if (CompareFunctionCalls(calls[direction], calls[d2])) + { + // Clear identical other direction calls and add case for it + calls[d2].clear(); + WriteLine(tabs, "case %d:", d2); + } } - } - for (auto call : calls[direction]) - { - GenerateCalls(tabs + 1, call, height, direction); + for (auto call : calls[direction]) + { + GenerateCalls(tabs + 1, call, height, direction); + } + WriteLine(tabs + 1, "break;"); } - WriteLine(tabs + 1, "break;"); + WriteLine(tabs, "}"); + } + + for (auto call : commonCalls) + { + GenerateCalls(tabs, call, height, 0); } - WriteLine(tabs, "}"); } void GenerateCalls(int tabs, const function_call &call, int height, int direction) @@ -427,6 +432,31 @@ private: } } + std::vector TrimCommonCallsEnd(std::vector calls[4]) + { + std::vector commonCalls; + + while (calls[0].size() != 0) + { + function_call lastCall = calls[0].back(); + for (int i = 0; i < 4; i++) + { + if (calls[i].size() == 0 || !CompareFunctionCall(calls[i].back(), lastCall)) + { + goto finished; + } + } + for (int i = 0; i < 4; i++) + { + calls[i].pop_back(); + } + commonCalls.push_back(lastCall); + } + + finished: + return commonCalls; + } + bool CompareFunctionCalls(const std::vector &a, const std::vector &b) { if (a.size() != b.size()) return false;