aboutsummaryrefslogtreecommitdiff
path: root/Tools/CodeGen
diff options
context:
space:
mode:
authorianhin <ianhin>2006-02-23 04:28:02 +0000
committerianhin <ianhin>2006-02-23 04:28:02 +0000
commitfc4e962ea23b96a19bc792e102975f5f275314a2 (patch)
treeaa75a4c94d103ffc4b8f692e16c20e716b3345f1 /Tools/CodeGen
parentb6f018e362bd120168ed99b69c344842bce5b32b (diff)
Added feature so that the user can loop over more points in a
calculation than the stencil_width will allow. This is useful if you know that your difference operators are actually safe to use there. No longer "clean" the calculation five times in a row. In fact, no longer clean it at all. What was this for? It looks like it was for removing unused shorthands from a calculation. But only those shorthands which are used are declared anyway, so why is this necessary? The logic used in cleancalc doesn't find the shorthands which are used inside derivative operators, and this was stopping my code from working. Only precompute those derivatives that do not contain shorthands. If a derivative contains a shorthand, the shorthands needs to be computed before the derivative, so precomputation is not really possible.
Diffstat (limited to 'Tools/CodeGen')
-rw-r--r--Tools/CodeGen/CalculationFunction.m129
1 files changed, 66 insertions, 63 deletions
diff --git a/Tools/CodeGen/CalculationFunction.m b/Tools/CodeGen/CalculationFunction.m
index e324c86..d8818de 100644
--- a/Tools/CodeGen/CalculationFunction.m
+++ b/Tools/CodeGen/CalculationFunction.m
@@ -24,7 +24,7 @@ BeginPackage["sym`"];
{GridFunctions, Shorthands, Equations, t, DeclarationIncludes,
LoopPreIncludes, GroupImplementations, PartialDerivatives, Dplus1,
-Dplus2, Dplus3, Boundary, Interior, Where}
+Dplus2, Dplus3, Boundary, Interior, Where, AddToStencilWidth}
{INV, SQR, CUB, QAD, dot, pow, exp}
@@ -62,24 +62,20 @@ removeRHS[x_] := Module[{string = ToString[x]},
(* collect and simplify terms *)
simpCollect[collectList_, eqrhs_, localvar_, debug_] :=
Module[{rhs, collectCoeff, all, localCollectList},
-
- If[debug,
- Print[];
- Print[localvar];
- ];
+ InfoMessage[Full, localvar];
rhs = eqrhs;
rhs = rhs /. Abs[MathTensor`Detg] -> MathTensor`Detg;
- If[debug, Print["ByteCount[rhs]: ", ByteCount@rhs];];
+ InfoMessage[Full, "ByteCount[rhs]: ", ByteCount@rhs];
localCollectList = collectList /. VAR :> removeRHS@localvar;
collectCoeff = Collect[rhs, localCollectList];
- If[debug, Print["ByteCount[terms collected]: ", ByteCount@collectCoeff];];
+ InfoMessage[Full, "ByteCount[terms collected]: ", ByteCount@collectCoeff];
all = Collect[rhs, localCollectList, Simplify];
- If[debug, Print["ByteCount[simplified rhs]: ", ByteCount@all];];
+ InfoMessage[Full, "ByteCount[simplified rhs]: ", ByteCount@all];
all
];
@@ -111,7 +107,6 @@ hideDerivatives[x_] :=
unhide = Map[Unique[] -> # &, derivatives];
hide = invertMap[unhide];
-(* Print["Expression with derivatives hidden: ", x /. hide]; *)
{x /. hide, unhide}];
(* Apply the map (list of rules) to the expression x, but avoid replacing
@@ -126,15 +121,7 @@ replaceWithDerivativesHidden[x_, map_] :=
replaceDerivatives[x_, derivRules_] :=
Module[{},
replaceStandard = (d_ ? (MemberQ[derivativeHeads, #] &)[f_] -> d[f,Symbol["i"],Symbol["j"],Symbol["k"]]);
-
replaceCustom = Flatten[derivRules,1];
-(* Print["derivRules == ", derivRules//FullForm];
- Print["replaceCustom == ", replaceCustom];*)
-
-(* If[Length[derivRules] != 0,
- Print["Before replace: ", x];
- Print["After replace: ",x /. replaceCustom]];*)
-
x /. replaceStandard];
(* Return a CodeGen block which assigns dest by evaluating expr *)
@@ -163,7 +150,8 @@ assignVariableFromExpression[dest_, expr_] := Module[{tSym, cleanExpr, code},
code = StringReplace[code, "==" -> " = "];
code = StringReplace[code, "BesselJ"-> "gsl_sf_bessel_Jn"];
code = StringReplace[code, ToString@tSym -> "cctk_time"];
-
+ code = StringReplace[code, "\"" -> ""];
+
{code}];
(* This flag determines whether you want to generate debugging code to
@@ -197,8 +185,6 @@ declareVariablesForCalculation[calc_] :=
(* Derivative precomputation *)
oldDerivativesUsed[x_] :=
-(* Print["Possible derivatives (new): ", Map[PDsFromDefinition, pddefs]];*)
-
Union[Cases[x, _ ? (MemberQ[derivativeHeads, #] &)[_],Infinity]];
(* Expects a list of the form {D11[h22], ...} *)
@@ -225,7 +211,7 @@ printEq[eq_] :=
rhsString = ToString@CForm[rhsSplit[[1]]] <> rhsSplit[[2]];
- Print[" " <> ToString@lhs <> " = " <> rhsString]];
+ InfoMessage[Full, " " <> ToString@lhs <> " -> " <> rhsString]];
(* Return the names of any gridfunctions used in the calculation *)
calculationUsedGFs[calc_] :=
@@ -272,7 +258,7 @@ calculationSymbolsLHS[calc_] :=
calculationSymbolsRHS[calc_] :=
Module[{allAtoms},
- allAtoms = Union[Map[Last, Flatten@lookup[calc, Equations] ]];
+ allAtoms = Union[Map[Last, Flatten@{lookup[calc, Equations], lookup[calc,PartialDerivatives]} ]];
allAtoms = Union[Level[allAtoms, {-1}]];
Cases[allAtoms, x_Symbol]];
@@ -325,7 +311,7 @@ cleanCalculation[calc_] := Module[
shorthands = calculationUsedShorthands[calc];
- Print["Deleted unused shorthands: ",
+ InfoMessage[Info, "Deleted unused shorthands: ",
Complement[lookupDefault[calc, Shorthands, {}], shorthands]];
assignedGFs = calculationUsedGFsLHS[calc];
@@ -386,23 +372,27 @@ GrepSyncGroups[x_, func_] := Module[{pick},
CreateCalculationFunction[calc_, debug_] :=
Module[{gfs, allSymbols, knownSymbols,
shorts, eqs, syncGroups, parameters,
- functionName, dsUsed, groups, pddefs, cleancalc, numeq, eqLoop, GrepSYNC, where},
+ functionName, dsUsed, groups, pddefs, cleancalc, numeq, eqLoop, GrepSYNC, where, addToStencilWidth},
- cleancalc = cleanCalculation[calc];
- cleancalc = cleanCalculation[cleancalc];
+(* cleancalc = cleanCalculation[calc];
cleancalc = cleanCalculation[cleancalc];
cleancalc = cleanCalculation[cleancalc];
cleancalc = cleanCalculation[cleancalc];
+ cleancalc = cleanCalculation[cleancalc]; *)
+
+ cleancalc = calc;
+ shorts = lookupDefault[cleancalc, Shorthands, {}];
+ eqs = lookup[cleancalc, Equations];
+ syncGroups = lookupDefault[cleancalc, SyncGroups, {}];
+ parameters = lookupDefault[cleancalc, Parameters, {}];
+ groups = lookup[cleancalc, Groups];
+ pddefs = lookupDefault[cleancalc, PartialDerivatives, {}];
+ where = lookupDefault[cleancalc, Where, Everywhere];
+ addToStencilWidth = lookupDefault[cleancalc, AddToStencilWidth, 0];
+ numeq = Length@eqs;
- shorts = lookupDefault[cleancalc, Shorthands, {}];
- eqs = lookup[cleancalc, Equations];
- syncGroups = lookupDefault[cleancalc, SyncGroups, {}];
- parameters = lookupDefault[cleancalc, Parameters, {}];
- groups = lookup[cleancalc, Groups];
- pddefs = lookupDefault[cleancalc, PartialDerivatives, {}];
- where = lookupDefault[cleancalc, Where, Everywhere];
- Print["number of equations in calculation: ", numeq = Length@eqs];
+ InfoMessage[Full, "number of equations in calculation: ", numeq];
VerifyCalculation[cleancalc];
@@ -410,20 +400,15 @@ CreateCalculationFunction[calc_, debug_] :=
functionName = ToString@lookup[cleancalc, Name];
dsUsed = oldDerivativesUsed[eqs];
- Print["Creating Calculation Function: " <> functionName];
+ InfoMessage[Terse, "Creating calculation function: " <> functionName];
- Print[" ", Length@shorts, " shorthands / ",
- Length@gfs, " grid functions / ",
- Length@groups, " groups"];
+ InfoMessage[Full, " ", Length@shorts, " shorthands"];
+ InfoMessage[Full, " ", Length@gfs, " grid functions"];
+ InfoMessage[Full, " ", Length@groups, " groups"];
- Print[];
-
- Print[" shorthands:"];
- Print[" ", shorts];
- Print[" groups:"];
- Print[" ", Map[groupName, groups]];
-
- If[debug, Print[" grid functions:", gfs]];
+ InfoMessage[Full, "Shorthands: ", shorts];
+ InfoMessage[Full, "Grid functions: ", gfs];
+ InfoMessage[Full, "Groups: ", Map[groupName, groups]];
If[Length@lookupDefault[cleancalc, CollectList, {}] > 0,
@@ -433,15 +418,13 @@ CreateCalculationFunction[calc_, debug_] :=
eqs[[i]] ], {i, 1, Length@eqs}]
];
- Print["\n\nEquations:"];
+ InfoMessage[Full, "Equations:"];
Map[Map[printEq, #]&, eqs];
- Print[];
(* Check all the function names *)
functionsPresent = functionsInCalculation[cleancalc];
-(* Print["Functions in calculation: ", functionsPresent];*)
(* FIXME: Sascha does not understand the next lines and commented
it out in order to avod problems with using the exp function in BSSN *)
@@ -467,9 +450,7 @@ CreateCalculationFunction[calc_, debug_] :=
If[unknownSymbols != {},
Module[{},
- Print["Unknown symbols in calculation: ", unknownSymbols];
- Print["Failed verification of calculation: ", cleancalc];
- Throw["Unknown symbols in calculation"]]];
+ ThrowError["Unknown symbols in calculation. Symbols are:", unknownSymbols, "Calculation is:", cleancalc]]];
DefineCCTKSubroutine[lookup[cleancalc, Name],
{ DeclareGridLoopVariables[],
@@ -495,17 +476,17 @@ CreateCalculationFunction[calc_, debug_] :=
(* Have removed ability to include external header files here.
Can be put back when we need it. *)
- eqLoop = Map[equationLoop[#, gfs, shorts, {}, groups, syncGroups, pddefs, where] &, eqs]};
+ eqLoop = Map[equationLoop[#, gfs, shorts, {}, groups, syncGroups, pddefs, where, addToStencilWidth] &, eqs]};
(* search for SYNCs *)
If[numeq <= 1,
GrepSYNC = GrepSyncGroups[eqLoop],
GrepSYNC = {};
eqLoop = UncommentSourceSync[eqLoop];
- Print["> 1 loop in thorn -> scheduling in source code, incompatible with Multipatch!"];
+ InfoMessage[Warning, "> 1 loop in thorn -> scheduling in source code, incompatible with Multipatch!"];
];
- Print["grepSync from eqLoop: ",GrepSyncGroups[eqLoop] ];
+ InfoMessage[Full, "grepSync from eqLoop: ",GrepSyncGroups[eqLoop]];
InsertSyncFuncName[eqLoop, lookup[cleancalc, Name]],
{}]}]];
@@ -545,7 +526,7 @@ checkShorthandAssignmentOrder[eqs_, shorthand_] :=
assignments = Position[lhss, shorthand];
If[Length[uses] == 0 && Length[assignments] >= 1,
- Print["WARNING: Shorthand ",shorthand," is defined but not used in this equation list."]];
+ InfoMessage[Warning, "WARNING: Shorthand ", shorthand, " is defined but not used in this equation list."]];
If[Length[uses] == 0, Return[]];
@@ -553,7 +534,7 @@ checkShorthandAssignmentOrder[eqs_, shorthand_] :=
firstUse = First[uses];
If[Length[assignments] > 1,
- Print["WARNING: Shorthand ", shorthand, " is defined more than once."]];
+ InfoMessage[Warning, "WARNING: Shorthand ", shorthand, " is defined more than once."]];
If[Length[assignments] == 0,
ThrowError["Shorthand", shorthand, "is not defined in this equation list", eqs]];
@@ -561,9 +542,27 @@ checkShorthandAssignmentOrder[eqs_, shorthand_] :=
If[assignments[[1]] >= firstUse,
ThrowError["Shorthand", shorthand, "is used before it is defined in this equation list", eqs]]];
+defContainsShorthand[def_, shorthands_] :=
+Module[{allAtoms, c},
+ allAtoms = Union[Level[def, {-1}]];
+ c = Intersection[shorthands, allAtoms];
+ c != {}];
+
+
+(* Split the list of partial derivative definitions into those
+containing shorthands, and those that do not. *)
+
+splitPDDefsWithShorthands[pddefs_, shorthands_] :=
+ Module[{defsWithShorts, defsWithoutShorts},
+ defsWithShorts = Select[pddefs, defContainsShorthand[#, shorthands] &];
+ defsWithoutShorts = Select[pddefs, ! defContainsShorthand[#, shorthands] &];
+ Return[{defsWithoutShorts, defsWithShorts}]];
-equationLoop[eqs_, gfs_, shorts_, incs_, groups_, syncGroups_, pddefs_, where_] :=
+
+
+
+equationLoop[eqs_, gfs_, shorts_, incs_, groups_, syncGroups_, pddefs_, where_, addToStencilWidth_] :=
Module[{rhss, lhss, gfsInRHS, gfsInLHS, localGFs, localMap, eqs2,
derivSwitch, actualSyncGroups, code, syncCode, loopFunction},
@@ -580,11 +579,15 @@ equationLoop[eqs_, gfs_, shorts_, incs_, groups_, syncGroups_, pddefs_, where_]
(* Replace the partial derivatives *)
+ {defsWithoutShorts, defsWithShorts} = splitPDDefsWithShorthands[pddefs, shorts];
+
+
(* This is for the custom derivative operators pddefs *)
- eqs2 = ReplaceDerivatives[pddefs, eqs];
+ eqs2 = ReplaceDerivatives[defsWithoutShorts, eqs, True];
+ eqs2 = ReplaceDerivatives[defsWithShorts, eqs2, False];
checkEquationAssignmentOrder[eqs2, shorts];
- code = {InitialiseGridLoopVariables[derivSwitch],
+ code = {InitialiseGridLoopVariables[derivSwitch, addToStencilWidth],
loopFunction = Switch[where,
Boundary, BoundaryLoop,
@@ -599,7 +602,7 @@ equationLoop[eqs_, gfs_, shorts_, incs_, groups_, syncGroups_, pddefs_, where_]
Map[IncludeFile, incs]],
CommentedBlock["Precompute derivatives (new style)",
- PrecomputeDerivatives[pddefs, eqs]],
+ PrecomputeDerivatives[defsWithoutShorts, eqs]],
CommentedBlock["Precompute derivatives (old style)",
Map[precomputeDerivative, oldDerivativesUsed[eqs]]],
@@ -628,7 +631,7 @@ equationLoop[eqs_, gfs_, shorts_, incs_, groups_, syncGroups_, pddefs_, where_]
(* If[Not@derivSwitch, actualSyncGroups = {}]; only sync when derivs are taken *)
If[Length@actualSyncGroups > 0,
- Print["Synchronizing groups: ", actualSyncGroups];
+ InfoMessage[Full, "Synchronizing groups: ", actualSyncGroups];
syncCode = Map[syncGroup, actualSyncGroups];
AppendTo[code, CommentedBlock["Synchronize the groups that have just been set", syncCode]];