aboutsummaryrefslogtreecommitdiff
path: root/Tools/CodeGen/Optimize.m
diff options
context:
space:
mode:
Diffstat (limited to 'Tools/CodeGen/Optimize.m')
-rw-r--r--Tools/CodeGen/Optimize.m190
1 files changed, 190 insertions, 0 deletions
diff --git a/Tools/CodeGen/Optimize.m b/Tools/CodeGen/Optimize.m
new file mode 100644
index 0000000..2f791da
--- /dev/null
+++ b/Tools/CodeGen/Optimize.m
@@ -0,0 +1,190 @@
+(* ::Package:: *)
+
+(* Copyright 2011 Barry Wardell
+
+ This file is part of Kranc.
+
+ Kranc is free software; you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation; either version 2 of the License, or
+ (at your option) any later version.
+
+ Kranc is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with Kranc; if not, write to the Free Software
+ Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
+*)
+
+BeginPackage["Optimize`", {"Kranc`", "Errors`"}];
+
+EliminateCommonSubexpressions::usage = "EliminateCommonSubexpressions[calc] identifies common subexpressions in calc and introduces new shorthands for them.";
+TopologicallySortEquations::usage = "TopologicallySortEquations[eqs, v] sorts eqs topologically, including Head v as a possible vertex."
+
+Begin["`Private`"];
+
+CSEPrint[___] = null;
+(* CSEPrint = Print; *)
+
+Options[EliminateCommonSubexpressions] = ThornOptions;
+EliminateCommonSubexpressions[calc_List, OptionsPattern[]] :=
+ Module[{eqs, shorts, name, pdDefs, derivs, newShorts, newEqs, allShorts, newCalc},
+ name = (Name /. calc);
+
+ InfoMessage[InfoFull, "Doing common subexpression elimination for "<>name];
+
+ eqs = (Equations /. calc) /. Equations -> {};
+ shorts = (Shorthands /. calc) /. Shorthands -> {};
+
+ (* Get a list of symbols used for derivatives. We will not eliminate these as subexpressions. *)
+ pdDefs = OptionValue[PartialDerivatives];
+ derivs = DeleteDuplicates[Head/@(First/@pdDefs)];
+
+ (* Generate new equations with subexpressions eliminated. *)
+ {newShorts, newEqs} = cse[eqs, Symbol["csetemp"], derivs];
+
+ If[Length[newShorts]>0,
+ InfoMessage[Info, "Extracted "<>ToString[Length[newShorts]]<>" common subexpressions from "<>name];
+ ];
+
+ allShorts = Join[shorts, newShorts];
+ newCalc = Join[calc /. {(Shorthands->_) -> Sequence[], (Equations->_) -> Sequence[]},
+ {Shorthands->allShorts}, {Equations->newEqs}];
+
+ newCalc
+];
+
+cse[eqs_, v_, exceptions_, minSaving_:0] :=
+ Module[{subexprs, replacements, replace, newEqs, defs, newDefs, i, relabelVars, allEqs, sortedEqs, newVars},
+ (* Find all possible subexpressions and how many times they occur *)
+ CSEPrint["CSE"];
+ CSEPrint["CSE: eqs=", eqs];
+ subexprs = Reap[Scan[If[! AtomQ[#], Sow[#]] &, eqs[[All,2]], Infinity]];
+ CSEPrint["CSE: subexprs=", subexprs];
+ If[subexprs[[2]]=={}, Return[{{}, eqs}]];
+ subexprs = Tally[subexprs[[2, 1]]];
+
+ (* Discard subexpressions which only appear once *)
+ subexprs = Select[subexprs, #[[2]] >= 2 &];
+
+ (* Sort subexpressions in ascending order by size (LeafCount) *)
+ subexprs = Sort[subexprs, LeafCount[#1] < LeafCount[#2] &];
+
+ (* Ony keep subexpressions larger than minSaving=(numoccurances-1)*size *)
+ subexprs = Select[subexprs, (#[[2]]-1) LeafCount[#[[1]]] >= minSaving &][[All,1]];
+
+ (* Discard some specific cases *)
+ subexprs = Cases[subexprs, Except[_?AtomQ]];
+ subexprs = Cases[subexprs, Except[Times[-1, _?AtomQ]]]; (* -x *)
+ subexprs = Cases[subexprs, Except[Alternatives@@(Blank/@exceptions)]]; (* specified exceptions *)
+ subexprs = Cases[subexprs, Except[Times[-1, Alternatives@@(Blank/@exceptions)]]]; (* -exceptions *)
+
+ (* Get the list of replacements for our original expression *)
+ replacements = Thread[subexprs -> Table[v[i], {i, Length[subexprs]}]];
+
+ (* Replace common subexpressions with new variables *)
+ (* Do not replace certain terms, e.g. the first argument of IfThen. *)
+ (* newEqs = eqs //. replacements; *)
+ CSEPrint["CSE: eqs=", eqs];
+ CSEPrint["CSE: replacements=", replacements];
+ replace[expr_] := Replace[Switch[expr,
+ IfThen[_,_,_], IfThen[expr[[1]], replace[expr[[2]]], replace[expr[[3]]]],
+ (* ToReal[_], ToReal[expr[[1]]], *)
+ _?AtomQ, expr,
+ _, Map[replace, expr]],
+ replacements];
+ newEqs = FixedPoint[replace, eqs];
+ CSEPrint["CSE: newEqs=", newEqs];
+
+ (* Build up definitions for the new variables *)
+ defs = Reverse/@replacements;
+ CSEPrint["CSE: defs=", defs];
+ For[i = 2, i <= Length[subexprs], i++,
+ defs[[i,2]] = defs[[i,2]] /. replacements[[1;;i-1]];
+ ];
+ CSEPrint["CSE: defs=", defs];
+
+ (* Select only the definitions which are needed for the new expressions.
+ This accounts for cases where a subexpression appears multiple times,
+ but always as part of the same larger subexpression. For example, in
+ expr = Sqrt[(a+b)(a-b)c]+(a+b)(a-b)c+(a+b)d+Sqrt[(a+b)d+(a+b)c];
+ we would identify the subexpressions
+ {v[1]->a+b,v[2]->d v[1],v[3]->a-b,v[4]->c v[1] v[3]};
+ whereas all we really want it to identify is
+ {v[1]->a+b,v[2]->d v[1],v[4]->(a-b) c v[1]};
+ and the introduction of v[3] is unnecessary. To achieve this, we only
+ keep temporary variables which appear in the expression after substitution
+ or which appear more than once in the definition of the temporary variables.
+ *)
+ newDefs = Select[defs, (Count[newEqs, #[[1]], Infinity] > 0) ||
+ (Count[defs[[All,2]], #[[1]], Infinity] > 1) &];
+ CSEPrint["CSE: newDefs=", newDefs];
+
+ (* Replace any temporaries eliminated by the previous procedure with their definition *)
+ newDefs = newDefs //. Complement[defs, newDefs];
+ CSEPrint["CSE: newDefs2=", newDefs];
+
+ (* Check we actually have subexpressions to eliminate. Otherwise just return the original expression *)
+ If[Length[newDefs]==0, Return[{{}, eqs}]];
+
+ (* This is our new system of equations *)
+ allEqs = Join[newDefs, newEqs];
+ CSEPrint["CSE: allEqs=", allEqs];
+
+ sortedEqs = Fold[InsertNewEquation, newEqs, Reverse[newDefs]];
+ CSEPrint["CSE: sortedEqs=", sortedEqs];
+
+ (* Relabel new temporary variables so that they are sequential and C friendly *)
+ newVars = Select[sortedEqs[[All,1]], MemberQ[newDefs[[All, 1]],#]&];
+ CSEPrint["CSE: newVars=", newVars];
+ i = 0;
+ relabelVars = (# -> Symbol[ToString[v] <> ToString[i++]]) & /@ newVars;
+
+ (* Return the list of new variables and the new equations *)
+ {newVars, sortedEqs} /. relabelVars
+];
+
+TopologicallySortEquations[eqs_] := Module[{lhs, rhs, lhsInrhs, dag, sortedVars, indVars, allVars, sortedEqs},
+ lhs = eqs[[All,1]];
+ rhs = eqs[[All,2]];
+
+ (* Generate an directed acyclic graph for the system of equations *)
+ lhsInrhs = DeleteDuplicates[Cases[{#}, _?(MemberQ[lhs, #] &), Infinity]] & /@ rhs;
+ dag = Graph[ Flatten[MapThread[Thread[Rule[#1, #2]] &, {lhsInrhs, lhs}]] ];
+
+ (* Topologically sort the DAG *)
+ sortedVars = Quiet[TopologicalSort[dag], TopologicalSort::argx];
+
+ (* Check if the topological sorting failed. This can happen if the graph for
+ the equations is cyclic. For example, we could have a->a+1 or {b->a*a, a->b} *)
+ If[SameQ[Head[sortedVars], TopologicalSort],
+ InfoMessage[Info, "Failed to topologically sort equations."];
+ Return[$Failed]
+ ];
+
+ (* Some variables might be independent. Add them back in. *)
+ indVars = Complement[lhs, sortedVars];
+ allVars = Join[indVars, sortedVars];
+
+ sortedEqs = Thread[allVars -> (allVars/.eqs)];
+ sortedEqs
+];
+
+InsertNewEquation[oldEqs_, newEq_] := Module[{before},
+ CSEPrint["InsertNewEquation oldEqs=", oldEqs, " newEq=", newEq];
+ (* For some reason, we can be asked to insert an equation that is
+ not actually needed. This should not be the case. However, handle
+ it gracefully for now. *)
+ (* before = Position[oldEqs[[All,2]], newEq[[1]]][[1,1]]; *)
+ before = Position[oldEqs[[All,2]], newEq[[1]]];
+ If[before=={},
+ oldEqs,
+ Insert[oldEqs, newEq, before[[1,1]]]]
+];
+
+End[];
+
+EndPackage[];