diff options
author | Erik Schnetter <schnetter@cct.lsu.edu> | 2009-04-27 11:35:12 -0500 |
---|---|---|
committer | Ian Hinder <ian.hinder@aei.mpg.de> | 2009-04-27 21:57:44 +0200 |
commit | a7763387a3e8a4cbbdf91146886ed6b43c56e074 (patch) | |
tree | c2a0ee20105de5d5a8dc59528ef5f8f42c47b01f /Tools/CodeGen/CodeGen.m | |
parent | 7caeabef4d1b9b1fe391ec1b5985f4f9fd21aa8f (diff) |
Add CSE (Common Subexpression Elimination)
Diffstat (limited to 'Tools/CodeGen/CodeGen.m')
-rw-r--r-- | Tools/CodeGen/CodeGen.m | 222 |
1 files changed, 220 insertions, 2 deletions
diff --git a/Tools/CodeGen/CodeGen.m b/Tools/CodeGen/CodeGen.m index 9418c98..4c0f9e8 100644 --- a/Tools/CodeGen/CodeGen.m +++ b/Tools/CodeGen/CodeGen.m @@ -112,6 +112,7 @@ CommaNewlineSeparated::usage = ""; (* This should not really be in CodeGen *) CommaSeparated::usage = ""; ReplacePowers::usage = ""; CFormHideStrings::usage = ""; +CSE::usage = ""; BoundaryLoop::usage = ""; BoundaryWithGhostsLoop::usage = ""; GenericGridLoop::usage = ""; @@ -270,6 +271,50 @@ AssignVariableInLoop[dest_, src_] := TestForNaN[dest]}; *) +(* TODO: move these into OpenMP loop *) +DeclareVariablesInLoopVectorised[dests_, temps_, srcs_] := + { + {"#undef LC_PRELOOP_STATEMENTS", "\n"}, + {"#define LC_PRELOOP_STATEMENTS", " \\\n"}, + {"int const GFD_imin = lc_imin + ((lc_imin + cctk_lsh[0] * (j + cctk_lsh[1] * k)) & (CCTK_REAL_VEC_SIZE-1))", "; \\\n"}, + {"int const GFD_imax = lc_imax + ((lc_imax + cctk_lsh[0] * (j + cctk_lsh[1] * k)) & (CCTK_REAL_VEC_SIZE-1)) - CCTK_REAL_VEC_SIZE", "; \\\n"}, + Map[Function[x, Module[{dest, temp, src}, + {dest, temp, src} = x; + {"CCTK_REAL_VEC ", temp, "; \\\n"}]], + Transpose[{dests, temps, srcs}]], + {"\n"} + }; + +AssignVariablesInLoopVectorised[dests_, temps_, srcs_] := + { + {"{\n"}, + {" if (i < GFD_imin || i >= GFD_imax) {\n"}, + Map[Function[x, Module[{dest, temp, src}, + {dest, temp, src} = x; + {" ", dest, "[index] = ", src, EOL[]}]], + Transpose[{dests, temps, srcs}]], + {" } else {\n"}, + {" size_t const index0 = index & (CCTK_REAL_VEC_SIZE-1)", EOL[]}, + Map[Function[x, Module[{dest, temp, src}, + {dest, temp, src} = x; + {" ((CCTK_REAL*)&", temp, ")[index0] = ", + src, EOL[]}]], + Transpose[{dests, temps, srcs}]], + {" if (index0 == CCTK_REAL_VEC_SIZE-1) {\n"}, + {" size_t const index1 = index - (CCTK_REAL_VEC_SIZE-1)", EOL[]}, + Map[Function[x, Module[{dest, temp, src}, + {dest, temp, src} = x; + {" _mm_stream_pd (&", dest, "[index1], ", + temp, ")", EOL[]}]], + Transpose[{dests, temps, srcs}]], + {" }\n"}, + {" }\n"}, + {"}\n"} + }; + +AssignVariableInLoopsVectorised[dest_, temp_, src_] := + {"GFD_save_and_store(", dest, ",", "index", ",", "&", temp, ",", src, ")", EOL[]}; + TestForNaN[expr_] := {"if (isnan(", expr, ")) {\n", " CCTK_VInfo(CCTK_THORNSTRING, \"NaN found\");\n", @@ -771,9 +816,182 @@ ReplacePowers[x_] := strings present *) CFormHideStrings[x_, opts___] := StringReplace[ToString[CForm[x,opts]], "\"" -> ""]; -End[]; -EndPackage[]; +(* Eliminate common subexpressions in a code sequence *) +CSE[code_] := Module[ + {expr, optexpr, + decomposed, locals, block, + block1, block2, temps1, stmts1, stmts2, stmts3, + replacevar, + stmts4, + stmts5, stmts6, stmts7}, + (* Print["code\n", code, "\nendcode\n"]; *) + + (* The code is passed in as list of {lhs,rhs} tuples. Turn this + list into a single expression, so that it can be optimised. *) + expr = code //. {a_, b__} -> CSequence[a, {b}] + //. {a_} -> a + //. (a_ -> b_) -> CAssign[a, b]; + (* Print["expr\n", expr, "\nendexpr\n"]; *) + + (* Optimise this expression *) + optexpr = Experimental`OptimizeExpression[expr]; + (* Print["optexpr\n", optexpr, "\nendoptexpr\n"]; *) + + (* This expression is a Mathematica expression. Decompose it into + the set of newly introduced local variables and the optimised + expression itself. *) + decomposed = + ReleaseHold[(Hold @@ optexpr) + /. Verbatim[Block][vars_, seq_] :> {vars, Hold[seq]}]; + + If[decomposed[[0]] =!= List, + (* If the optimiser didn't create a Block expression, we assume it + didn't do anything useful and return the original. *) + code, + + {locals, block} = decomposed; + (* Print["locals\n", locals, "\nendlocals\n"]; *) + (* Print["block\n", block, "\nendblock\n"]; *) + + block1 = block /. Hold[CompoundExpression[seq__]] :> Hold[{seq}]; + (* Print["block1\n", block1, "\nendblock1\n"]; *) + block2 = First[block1 //. Hold[{a___Hold, b_, c___}] + /; Head[Unevaluated[b]] =!= Hold + :> Hold[{a, Hold[b], c}]]; + (* Print["block2\n", block2, "\nendblock2\n"]; *) + + (* Temporaries, including a fake declaration for them *) + temps1 = Most[block2] //. Hold[lhs_ = rhs_] -> CAssign[CDeclare[lhs], rhs]; + (* Print["temps1\n", temps1, "\nendtemps1\n"]; *) + + (* Expression *) + stmts1 = ReleaseHold[Last[block2]]; + (* Print["stmts1\n", stmts1, "\nendstmts1\n"]; *) + + (* Turn CSequence back into a list *) + stmts2 = Flatten[{stmts1} //. CSequence[a_,b_] -> {a,b}]; + (* Print["stmts2\n", stmts2, "\nendstmts2\n"]; *) + + (* Combine temporaries and expression *) + stmts3 = Join[temps1, stmts2]; + (* Print["stmts3\n", stmts3, "\nendstmts3\n"]; *) + + (* Replace the internal names of the newly generated temporaries + with legal C names *) + replacevar = + Rule @@@ Transpose[{(*ToString[CForm[#]] & /@*) locals, + Symbol[ + StringReplace[StringReplace[ToString[#], {__ ~~ "`" ~~ a_ :> a}], + "$" -> "T"]] & /@ locals}]; + (* Print["replacevar\n", replacevar, "\nendreplacevar\n"]; *) + + stmts4 = stmts3 //. replacevar; + (* Print["stmts4\n", stmts4, "\nendstmts4\n"]; *) + + (* Sort statements topologically *) +(* + stmts5 = stmts4; +*) + (* Print["A\n"]; *) + stmts5 = + Module[{debug, + tmpVars, newVars, i, + stmtsLeft, stmtsDone, + lhs, rhs, any, contains, containsAny, + canDoStmts, cannotDoStmts, + selfStmts, selfVars, allVars, nonSelfVars}, + debug = False; + stmtsLeft = stmts4; + (* Print["B\n"]; *) + stmtsDone = {}; + (* Print["C\n"]; *) + (* lhs[x_] := x[[1]]; *) + lhs[x_] := x /. (CAssign[lhs_, rhs_] -> lhs); + (* Print["D\n"]; *) + (* rhs[x_] := x[[2]]; *) + rhs[x_] := x /. (CAssign[lhs_, rhs_] -> rhs); + (* Print["E\n"]; *) + (* any[xs_] := Fold[Or, False, xs]; *) + any[xs_] := MemberQ[xs, True]; + (* Print["F\n"]; *) + (* contains[e_, x_] := (e /. x -> {}) =!= e; *) + (* contains[e_, x_] := Count[{e}, x, Infinity] > 0; *) + contains[e_, x_] := MemberQ[{e}, x, Infinity]; + (* Print["G\n"]; *) + containsAny[e_, xs_] := any[Map[contains[e,#]&, xs]]; + (* Print["H\n"]; *) + getVars[stmts_] := Map[lhs, stmts] //. (CDeclare[lhs_] -> lhs); + + (* Rename temporary variables deterministically *) + tmpVars = Select[getVars[stmtsLeft], + StringMatchQ[ToString[#], "TT"~~__]&]; + newVars = Table[Symbol["T"<>ToString[1000000+i]], + {i, 1, Length[tmpVars]}]; + stmtsLeft = stmtsLeft /. MapThread[(#1->#2)&, {tmpVars, newVars}]; + + allVars = getVars[stmtsLeft]; + While[stmtsLeft =!= {}, + If[debug, Print["stmtsLeft = \n", stmtsLeft]]; + If[debug, Print["stmtsDone = \n", stmtsDone]]; + allVars = getVars[stmtsLeft]; + If[debug, Print["allVars = \n", allVars]]; + canDoStmts = + Select[stmtsLeft, Not[containsAny[rhs[#], allVars]] &]; + cannotDoStmts = + Select[stmtsLeft, containsAny[rhs[#], allVars] &]; + If[debug, Print["canDoStmts = \n", canDoStmts]]; + If[debug, Print["cannotDoStmts = \n", cannotDoStmts]]; + If[False && canDoStmts == {}, + (* Handle assignment where LHS and RHS access the same variables + (hopefully without taking derivatives!) *) + selfStmts = Select[stmtsLeft, contains[rhs[#], lhs[#]]]; + selfVars = getVars[selfStmts]; + nonSelfVars = Select[allVars, Not[contains[selfVars, #]] &]; + canDoStmts = + Select[stmtsLeft, Not[containsAny[rhs[#], nonSelfVars]] &]; + cannotDoStmts = + Select[stmtsLeft, containsAny[rhs[#], nonSelfVars] &]; + If[debug, Print["nonself/canDoStmts = \n", canDoStmts]]; + If[debug, Print["nonself/cannotDoStmts = \n", cannotDoStmts]]; + ]; + If[canDoStmts == {}, + (* Accept the first statement *) + canDoStmts = {First[stmtsLeft]}; + cannotDoStmts = Rest[stmtsLeft]; + If[debug, Print["takeone/canDoStmts = \n", canDoStmts]]; + If[debug, Print["takeone/cannotDoStmts = \n", cannotDoStmts]]; + ]; + If[canDoStmts == {}, ThrowError["canDoStmts == {}"]]; + stmtsDone = Join[stmtsDone, canDoStmts]; + (* Print["I\n"]; *) + stmtsLeft = cannotDoStmts; + (* Print["J\n"]; *) + ]; + If[debug, Print["stmtsLeft\n", stmtsLeft]]; + If[debug, Print["stmtsDone\n", stmtsDone]]; + stmtsDone]; + (* Print["Z\n"]; *) + + (* Turn CAssign statements back into (->) tuples *) + stmts6 = stmts5 //. CAssign[lhs_,rhs_] -> (lhs -> rhs); + (* Print["stmts6\n", stmts6, "\nendstmts6\n"]; *) + + (* Turn CDeclare statements into "faked" declarations *) + stmts7 = stmts6 + //. CDeclare[var_] + :> "CCTK_REAL const " <> + StringReplace[ToString[var], __ ~~ "`" -> ""]; + (* Print["stmts7\n", stmts7, "\nendstmts7\n"]; *) + + stmts7 + ] +]; + + +End[]; + +EndPackage[]; |