aboutsummaryrefslogtreecommitdiff
path: root/Tools/CodeGen/CodeGen.m
diff options
context:
space:
mode:
author Erik Schnetter <schnetter@cct.lsu.edu> 2009-04-27 11:35:12 -0500
committer Ian Hinder <ian.hinder@aei.mpg.de> 2009-04-27 21:57:44 +0200
commit a7763387a3e8a4cbbdf91146886ed6b43c56e074 (patch)
tree c2a0ee20105de5d5a8dc59528ef5f8f42c47b01f /Tools/CodeGen/CodeGen.m
parent 7caeabef4d1b9b1fe391ec1b5985f4f9fd21aa8f (diff)
Add CSE (Common Subexpression Elimination)
Diffstat (limited to 'Tools/CodeGen/CodeGen.m')
-rw-r--r-- Tools/CodeGen/CodeGen.m 222
1 files changed, 220 insertions, 2 deletions
diff --git a/Tools/CodeGen/CodeGen.m b/Tools/CodeGen/CodeGen.m
index 9418c98..4c0f9e8 100644
--- a/Tools/CodeGen/CodeGen.m
+++ b/Tools/CodeGen/CodeGen.m
@@ -112,6 +112,7 @@ CommaNewlineSeparated::usage = ""; (* This should not really be in CodeGen *)
CommaSeparated::usage = "";
ReplacePowers::usage = "";
CFormHideStrings::usage = "";
+(* CSE: common subexpression elimination for a list of lhs -> rhs
+   assignments; the definition is further down in this file. *)
+CSE::usage = "";
BoundaryLoop::usage = "";
BoundaryWithGhostsLoop::usage = "";
GenericGridLoop::usage = "";
@@ -270,6 +271,50 @@ AssignVariableInLoop[dest_, src_] :=
TestForNaN[dest]};
*)
+(* TODO: move these into OpenMP loop *)
+(* Build the code fragments that redefine the C macro
+   LC_PRELOOP_STATEMENTS.  The macro body declares GFD_imin and
+   GFD_imax followed by one "CCTK_REAL_VEC <temp>" declaration per
+   entry of temps.
+
+   dests, temps, srcs: lists of equal length (they are transposed
+   together); only the entries of temps appear in the output. *)
+DeclareVariablesInLoopVectorised[dests_, temps_, srcs_] :=
+  Module[{triples, declareTemp},
+    triples = Transpose[{dests, temps, srcs}];
+    (* Each triple is {dest, temp, src}; only temp is declared here *)
+    declareTemp =
+      Function[{dest, temp, src}, {"CCTK_REAL_VEC ", temp, "; \\\n"}];
+    {
+      {"#undef LC_PRELOOP_STATEMENTS", "\n"},
+      {"#define LC_PRELOOP_STATEMENTS", " \\\n"},
+      {"int const GFD_imin = lc_imin + ((lc_imin + cctk_lsh[0] * (j + cctk_lsh[1] * k)) & (CCTK_REAL_VEC_SIZE-1))", "; \\\n"},
+      {"int const GFD_imax = lc_imax + ((lc_imax + cctk_lsh[0] * (j + cctk_lsh[1] * k)) & (CCTK_REAL_VEC_SIZE-1)) - CCTK_REAL_VEC_SIZE", "; \\\n"},
+      declareTemp @@@ triples,
+      {"\n"}
+    }];
+
+(* Build the code fragments for one braced C block that stores every
+   src into its dest.  The emitted C has two branches:
+     - for i outside [GFD_imin, GFD_imax): "dest[index] = src" for
+       each triple;
+     - otherwise: src is written into element index0 of temp, and
+       once index0 reaches CCTK_REAL_VEC_SIZE-1 each temp is written
+       to &dest[index1] with _mm_stream_pd.
+
+   dests, temps, srcs: lists of equal length (they are transposed
+   together). *)
+AssignVariablesInLoopVectorised[dests_, temps_, srcs_] :=
+  Module[{triples, forEachTriple},
+    triples = Transpose[{dests, temps, srcs}];
+    (* Apply a three-argument emitter to every {dest, temp, src} triple *)
+    forEachTriple[emit_] := emit @@@ triples;
+    {
+      {"{\n"},
+      {" if (i < GFD_imin || i >= GFD_imax) {\n"},
+      forEachTriple[
+        Function[{dest, temp, src},
+          {" ", dest, "[index] = ", src, EOL[]}]],
+      {" } else {\n"},
+      {" size_t const index0 = index & (CCTK_REAL_VEC_SIZE-1)", EOL[]},
+      forEachTriple[
+        Function[{dest, temp, src},
+          {" ((CCTK_REAL*)&", temp, ")[index0] = ",
+           src, EOL[]}]],
+      {" if (index0 == CCTK_REAL_VEC_SIZE-1) {\n"},
+      {" size_t const index1 = index - (CCTK_REAL_VEC_SIZE-1)", EOL[]},
+      forEachTriple[
+        Function[{dest, temp, src},
+          {" _mm_stream_pd (&", dest, "[index1], ",
+           temp, ")", EOL[]}]],
+      {" }\n"},
+      {" }\n"},
+      {"}\n"}
+    }];
+
+(* Build the code fragments for a single call
+   GFD_save_and_store(dest,index,&temp,src) followed by EOL[]. *)
+AssignVariableInLoopsVectorised[dest_, temp_, src_] :=
+  Join[
+    {"GFD_save_and_store("},
+    {dest, ",", "index", ","},
+    {"&", temp, ","},
+    {src, ")", EOL[]}];
+
TestForNaN[expr_] :=
{"if (isnan(", expr, ")) {\n",
" CCTK_VInfo(CCTK_THORNSTRING, \"NaN found\");\n",
@@ -771,9 +816,182 @@ ReplacePowers[x_] :=
strings present *)
(* Render x via CForm (forwarding opts), convert it to a string, and
   delete every double-quote character from the result. *)
CFormHideStrings[x_, opts___] :=
  Module[{cText},
    cText = ToString[CForm[x, opts]];
    StringReplace[cText, "\"" -> ""]];
-End[];
-EndPackage[];
+(* Eliminate common subexpressions in a code sequence.
+
+   code: a list of (lhs -> rhs) rules, one per assignment.
+
+   Returns a list of (lhs -> rhs) rules.  The whole sequence is handed
+   to Experimental`OptimizeExpression.  If the optimiser returns a
+   Block, every temporary it introduced becomes an additional rule
+   whose left-hand side is the string "CCTK_REAL const <name>", and
+   the statements are reordered: a statement is emitted once its
+   right-hand side mentions none of the variables still waiting to be
+   assigned; if no such statement exists, the first remaining one is
+   taken.  If the optimiser does not return a Block, code is returned
+   unchanged. *)
+CSE[code_] := Module[
+  {expr, optexpr,
+   decomposed, locals, block,
+   block1, block2, temps1, stmts1, stmts2, stmts3,
+   replacevar,
+   stmts4,
+   stmts5, stmts6, stmts7},
+
+  (* The code is passed in as list of {lhs,rhs} tuples.  Turn this
+     list into a single expression, so that it can be optimised. *)
+  expr = code //. {a_, b__} -> CSequence[a, {b}]
+              //. {a_} -> a
+              //. (a_ -> b_) -> CAssign[a, b];
+
+  (* Optimise this expression *)
+  optexpr = Experimental`OptimizeExpression[expr];
+
+  (* This expression is a Mathematica expression.  Decompose it into
+     the set of newly introduced local variables and the optimised
+     expression itself. *)
+  decomposed =
+    ReleaseHold[(Hold @@ optexpr)
+                /. Verbatim[Block][vars_, seq_] :> {vars, Hold[seq]}];
+
+  If[decomposed[[0]] =!= List,
+    (* If the optimiser didn't create a Block expression, we assume it
+       didn't do anything useful and return the original. *)
+    code,
+
+    {locals, block} = decomposed;
+
+    (* Split the held compound expression into a list of individually
+       held statements, so that the assignments are not executed *)
+    block1 = block /. Hold[CompoundExpression[seq__]] :> Hold[{seq}];
+    block2 = First[block1 //. Hold[{a___Hold, b_, c___}]
+                                /; Head[Unevaluated[b]] =!= Hold
+                              :> Hold[{a, Hold[b], c}]];
+
+    (* Temporaries, including a fake declaration for them *)
+    temps1 = Most[block2] //. Hold[lhs_ = rhs_] -> CAssign[CDeclare[lhs], rhs];
+
+    (* Expression *)
+    stmts1 = ReleaseHold[Last[block2]];
+
+    (* Turn CSequence back into a list *)
+    stmts2 = Flatten[{stmts1} //. CSequence[a_,b_] -> {a,b}];
+
+    (* Combine temporaries and expression *)
+    stmts3 = Join[temps1, stmts2];
+
+    (* Replace the internal names of the newly generated temporaries
+       with legal C names: the context prefix is dropped and every "$"
+       becomes "T" *)
+    replacevar =
+      Rule @@@ Transpose[{locals,
+                          Symbol[
+                            StringReplace[StringReplace[ToString[#], {__ ~~ "`" ~~ a_ :> a}],
+                                          "$" -> "T"]] & /@ locals}];
+
+    stmts4 = stmts3 //. replacevar;
+
+    (* Sort statements topologically *)
+    stmts5 =
+      Module[{debug,
+              tmpVars, newVars, i,
+              stmtsLeft, stmtsDone,
+              (* getVars is listed here so that its definition stays
+                 local instead of leaking into the enclosing context *)
+              lhs, rhs, any, contains, containsAny, getVars,
+              canDoStmts, cannotDoStmts,
+              selfStmts, selfVars, allVars, nonSelfVars},
+        debug = False;
+        stmtsLeft = stmts4;
+        stmtsDone = {};
+
+        (* Small helpers operating on CAssign[lhs, rhs] statements *)
+        lhs[x_] := x /. (CAssign[lhs_, rhs_] -> lhs);
+        rhs[x_] := x /. (CAssign[lhs_, rhs_] -> rhs);
+        any[xs_] := MemberQ[xs, True];
+        contains[e_, x_] := MemberQ[{e}, x, Infinity];
+        containsAny[e_, xs_] := any[Map[contains[e,#]&, xs]];
+        (* Assigned variables, with any CDeclare wrapper stripped *)
+        getVars[stmts_] := Map[lhs, stmts] //. (CDeclare[lhs_] -> lhs);
+
+        (* Rename temporary variables deterministically *)
+        tmpVars = Select[getVars[stmtsLeft],
+                         StringMatchQ[ToString[#], "TT"~~__]&];
+        newVars = Table[Symbol["T"<>ToString[1000000+i]],
+                        {i, 1, Length[tmpVars]}];
+        stmtsLeft = stmtsLeft /. MapThread[(#1->#2)&, {tmpVars, newVars}];
+
+        allVars = getVars[stmtsLeft];
+        While[stmtsLeft =!= {},
+          If[debug, Print["stmtsLeft = \n", stmtsLeft]];
+          If[debug, Print["stmtsDone = \n", stmtsDone]];
+          allVars = getVars[stmtsLeft];
+          If[debug, Print["allVars = \n", allVars]];
+          (* A statement is ready when its right-hand side uses no
+             variable that is still waiting to be assigned *)
+          canDoStmts =
+            Select[stmtsLeft, Not[containsAny[rhs[#], allVars]] &];
+          cannotDoStmts =
+            Select[stmtsLeft, containsAny[rhs[#], allVars] &];
+          If[debug, Print["canDoStmts = \n", canDoStmts]];
+          If[debug, Print["cannotDoStmts = \n", cannotDoStmts]];
+          (* This branch is deliberately disabled by the leading False *)
+          If[False && canDoStmts === {},
+            (* Handle assignment where LHS and RHS access the same variables
+               (hopefully without taking derivatives!) *)
+            selfStmts = Select[stmtsLeft, contains[rhs[#], lhs[#]] &];
+            selfVars = getVars[selfStmts];
+            nonSelfVars = Select[allVars, Not[contains[selfVars, #]] &];
+            canDoStmts =
+              Select[stmtsLeft, Not[containsAny[rhs[#], nonSelfVars]] &];
+            cannotDoStmts =
+              Select[stmtsLeft, containsAny[rhs[#], nonSelfVars] &];
+            If[debug, Print["nonself/canDoStmts = \n", canDoStmts]];
+            If[debug, Print["nonself/cannotDoStmts = \n", cannotDoStmts]];
+          ];
+          If[canDoStmts === {},
+            (* Accept the first statement *)
+            canDoStmts = {First[stmtsLeft]};
+            cannotDoStmts = Rest[stmtsLeft];
+            If[debug, Print["takeone/canDoStmts = \n", canDoStmts]];
+            If[debug, Print["takeone/cannotDoStmts = \n", cannotDoStmts]];
+          ];
+          If[canDoStmts === {}, ThrowError["canDoStmts == {}"]];
+          stmtsDone = Join[stmtsDone, canDoStmts];
+          stmtsLeft = cannotDoStmts;
+        ];
+        If[debug, Print["stmtsLeft\n", stmtsLeft]];
+        If[debug, Print["stmtsDone\n", stmtsDone]];
+        stmtsDone];
+
+    (* Turn CAssign statements back into (->) tuples *)
+    stmts6 = stmts5 //. CAssign[lhs_,rhs_] -> (lhs -> rhs);
+
+    (* Turn CDeclare statements into "faked" declarations *)
+    stmts7 = stmts6
+             //. CDeclare[var_]
+                 :> "CCTK_REAL const " <>
+                    StringReplace[ToString[var], __ ~~ "`" -> ""];
+
+    stmts7
+  ]
+];
+
+
+End[];
+
+EndPackage[];