aboutsummaryrefslogtreecommitdiff
path: root/Tools/CodeGen/Calculation.m
diff options
context:
space:
mode:
Diffstat (limited to 'Tools/CodeGen/Calculation.m')
-rw-r--r--Tools/CodeGen/Calculation.m35
1 files changed, 35 insertions, 0 deletions
diff --git a/Tools/CodeGen/Calculation.m b/Tools/CodeGen/Calculation.m
index c74c44d..6c929a6 100644
--- a/Tools/CodeGen/Calculation.m
+++ b/Tools/CodeGen/Calculation.m
@@ -30,6 +30,7 @@ GetCalculationParameters;
CalculationStencilSize;
CalculationOnDevice;
GetCalculationWhere;
+SplitCalculations;
Begin["`Private`"];
@@ -101,6 +102,40 @@ DefFn[
CalculationOnDevice[calc_List] :=
lookupDefault[calc, ExecuteOn, Automatic] === Device];
+partialCalculation[calc_, suffix_, updates_, evolVars_] :=
+Module[
+ {name, calc1, replaces, calc2, vars, patterns, eqs, calc3},
+ (* Add suffix to name *)
+ name = lookup[calc, Name] <> suffix;
+ calc1 = mapReplace[calc, Name, name];
+ (* Replace some entries in the calculation *)
+ replaces = updates //. (lhs_ -> rhs_) -> (mapReplace[#, lhs, rhs]&);
+ calc2 = Apply[Composition, replaces][calc1];
+ (* Remove unnecessary equations *)
+ vars = Join[evolVars, lookup[calc2, Shorthands]];
+ patterns = Replace[vars, { Tensor[n_,__] -> Tensor[n,__] ,
+ dot[Tensor[n_,__]] -> dot[Tensor[n,__]]}, 1];
+ eqs = FilterRules[lookup[calc, Equations], patterns];
+ calc3 = mapReplace[calc2, Equations, eqs];
+ calc3
+];
+
+DefFn[
+ SplitCalculations[calcs_List] :=
+ Flatten[SplitCalculation/@calcs,1]];
+
+DefFn[
+ SplitCalculation[calc_] :=
+ Module[
+ {splitBy = lookup[calc,SplitBy, {}]},
+ If[splitBy === {},
+ {calc},
+ Table[partialCalculation[calc,
+ "_"<>StringReplace[ToString[var],{"["->"","]"->"",","->""}],
+ {},
+ {var}]~Join~{CachedVariables -> {(* var[[1]] *)}}, (* This is not general *)
+ {var, splitBy}]]]];
+
End[];
EndPackage[];