# codegen2.maple -- generate C code from Maple expressions
# $Header$
#
# codegen2 - generate C code from Maple expressions
#
# cvt_to_eqnlist - convert an expression to an equation list
# fix_Diff - convert Diff() calls to PARTIAL_*() calls
# fix_Diff/remap_table - remapping table for fix_Diff()
# temps_in_eqnlist - find temporaries used by an equation list
# is_result - is a name equal to a "result" or an array/table component of it?
# deindex_names - remove indices from indexed names
# unindex_names - convert indexed names to scalars
# fix_rationals - convert numbers to RATIONAL() calls
# print_name_list_dcl - print a C declaration for a list of names
#

################################################################################
################################################################################
################################################################################

#
# This function is a high-level driver to generate C code for a Maple
# expression. It does the following:
# - convert the input into a list of equations of the form name = expression
# - generate an "optimized" form of the computation sequence with
#   codegen[optimize]  (the  tryhard  variant is currently commented out)
# - determine which temporary variables have been introduced by the
#   codegen[optimize] library
# - rewrite Diff(...) calls to PARTIAL_RHO_RHO(...) etc calls
# - "unindex" all array accesses, eg R_dd[2,3] becomes R_dd_23.
# - convert rational numbers to RATIONAL(p,q) calls
# - print C declarations for the temporary variables
# - print C code for the optimized computation sequence
#
# Note that "codegen" is a Maple package; this function is called "codegen2".
#
# Sample usage:
#	codegen2([R_dd__fnd, R__fnd], ['R_dd', 'R'], "Ricci.c");
#
# Arguments:
# expr = (in) The expression or list of expressions for which code is to
#	      be generated. This will typically be the name of a gridfn
#	      array functional-dependence form, or a list of such names,
#	      but this isn't required.
# lhs_name = (in) The name or list of names to be used for the results in
#		  the generated C code, eg R_dd.
# output_file_name = (in) The name of the file to which the generated
#			  code is to be written.
#
# Arguments (as global variables)
# `saveit/level`
#	= (in) (optional)
#	  If this global variable is assigned, intermediate results
#	  are saved for debugging purposes. If it's assigned an
#	  integral value, making this value larger may increase the
#	  level of debugging output. (10 is a good number for typical
#	  debugging.)
#
codegen2 :=
proc(expr_in::{algebraic, list(algebraic)},
     lhs_name::{name, list(name)},
     output_file_name::string)
global
  @include "../maple/coords.minc",
  @include "../maple/gfa.minc";
local expr, expr_temps, input_set, output_set, expr_cost;

printf("codegen2(%a) --> \"%s\"\n", lhs_name, output_file_name);

expr := expr_in;
saveit(10, procname, "input", expr);

printf(" convert --> equation list\n");
expr := cvt_to_eqnlist(expr, lhs_name);
saveit(10, procname, "eqnlist", expr);

printf(" optimizing computation sequence\n");
expr := [codegen[optimize](expr)];
##expr := [codegen[optimize](expr, tryhard)];
saveit(10, procname, "optimize", expr);

# the temporaries are whatever left-hand sides codegen[optimize] added,
# i.e. every lhs which is not one of the requested result names
printf(" find temporary variables\n");
expr_temps := temps_in_eqnlist(expr, lhs_name);
saveit(10, procname, "temps", expr_temps);

printf(" convert Diff(expr,rho,sigma) --> PARTIAL_RHO_SIGMA(expr) etc\n");
expr := fix_Diff(expr);
saveit(10, procname, "fix_Diff", expr);

# inputs  = names read on some rhs, less temporaries and  xy_all_set
# outputs = names assigned on some lhs, less temporaries
# (both with array indices stripped, for the header comment written below)
input_set := deindex_names(
		indets(map(rhs,expr),name)
		minus {op(expr_temps)}
		minus xy_all_set
	     );
output_set := deindex_names(
		{op(map(lhs,expr))}
		minus {op(expr_temps)}
	      );

printf(" convert R_dd[2,3] --> R_dd_23 etc\n");
expr := unindex_names(expr);
saveit(10, procname, "unindex", expr);

# NOTE(review): the cost is taken before the RATIONAL() rewrite, presumably
# so that the RATIONAL() wrappers are not counted -- confirm
expr_cost := codegen[cost](expr);

printf(" convert p/q --> RATIONAL(p,q)\n");
expr := fix_rationals(expr);
saveit(10, procname, "fix_rationals", expr);

#
# write the C code
#
printf(" writing C code\n");

ftruncate(output_file_name);
fprintf(output_file_name, "/*\n");
fprintf(output_file_name, " * inputs = %a\n", input_set);
fprintf(output_file_name, " * outputs = %a\n", output_set);
fprintf(output_file_name, " * cost = %a\n", expr_cost);
fprintf(output_file_name, " */\n");

print_name_list_dcl(expr_temps, "fp", output_file_name);
codegen[C](expr, filename=output_file_name);

NULL;
end proc;

################################################################################
################################################################################
################################################################################

#
# This function converts an expression or list of expressions into an
# equation list. That is, given an expression  expr , this function
# computes an equation list of the form
#
# if type(expr, algebraic)
#	[ lhs_name = expr ]
#
# if type(expr, array)		# illustrated here for a rank-1 array
#	[
#	lhs_name[1] = expr[1],	# note equations are in lexicographic
#	lhs_name[2] = expr[2],	# order of the indices, and include
#	lhs_name[3] = expr[3],	# only those array elements that are
#	...			# explicitly stored (as reported by
#	lhs_name[N] = expr[N]	# indices() )
#	]
#
# if type(expr, list)
#	then concatenate the equations lists from expr's elements
#
# Arguments:
# expr = (in) The expression to be converted.
# lhs_name = (in) The unevaluated name or list of names to use for the
#		  left hand side(s) in the equation list.
#
# Results:
# The equation list is returned as the function result.
#
# Errors:
# An error is raised if the expr/lhs_name combination is not one of the
# three forms above, or if expr and lhs_name are lists of different lengths.
#
cvt_to_eqnlist :=
proc(expr::{algebraic, array, list({algebraic, array})},
     lhs_name::{name, list(name)})

# ... test for array first since otherwise expr itself is a "name",
#     which would match type "algebraic" as well
if (type(expr, array) and type(lhs_name, name)) then
	return map(
		  proc(ii)
		  return lhs_name[op(ii)] = expr[op(ii)];
		  end
		,
		  indices_in_order(expr)
	       );
fi;

if (type(expr, algebraic) and type(lhs_name, name)) then
	return [lhs_name = expr];
fi;

if (type(expr, list({algebraic, array})) and type(lhs_name, list(name))) then
	# zip() stops at the end of the shorter list, which would silently
	# drop results, so insist that the two lists pair up exactly
	if (nops(expr) <> nops(lhs_name)) then
		error "expression list and lhs_name list have different lengths!\n"
		      " nops(expr)=%1\n"
		      " nops(lhs_name)=%2\n"
		      ,
		      nops(expr), nops(lhs_name);
	fi;
	return zip(op @ cvt_to_eqnlist, expr, lhs_name);
fi;

error "unknown type for expression!\n"
      " expr=%1\n"
      " whattype(expr)=%2\n"
      ,
      expr, whattype(expr);
end;

################################################################################

#
# This function converts Diff() calls into PARTIAL_*() calls, eg
#	Diff(src, rho, sigma) --> PARTIAL_RHO_SIGMA(src).
#
# Lists are processed elementwise, equations have only their right hand
# side processed, and sums, products, power bases, and the arguments of
# non-Diff functions are processed recursively.  Exponents and the first
# argument of Diff() itself are left untouched.  Anything else is returned
# unchanged.
#
# Errors:
# An error is raised for a Diff() call whose variable list has no entry
# in `fix_Diff/remap_table`.
#
fix_Diff :=
proc(expr::{algebraic, name = algebraic, list({algebraic, name = algebraic})})
local nn, k,
      base, power,
      fn, fn_args_list,
      Darg, Dvars;
global `fix_Diff/remap_table`;

# recurse over lists
if (type(expr, list)) then
	return map(fix_Diff, expr);
fi;

# recurse over equation right hand sides
if (type(expr, name = algebraic)) then
	return lhs(expr) = fix_Diff(rhs(expr));
fi;

nn := nops(expr);

# recurse over sums
# ... add()/mul() are the explicit finite forms; they substitute each
#     integer k before evaluating the term, so no quoting is needed
if (type(expr, `+`)) then
	return add(fix_Diff(op(k,expr)), k=1..nn);
fi;

# recurse over products
if (type(expr, `*`)) then
	return mul(fix_Diff(op(k,expr)), k=1..nn);
fi;

# recurse over powers
if (type(expr, `^`)) then
	base := op(1, expr);
	power := op(2, expr);
	return fix_Diff(base) ^ power;
fi;

# recurse over non-Diff functions
# NOTE(review): "fn;" followed by '%'(...) rebuilds the call through the
# ditto operator; presumably this avoids invoking fn on the new arguments
# -- confirm before changing
if type(expr, function) and (op(0, expr) <> 'Diff') then
	fn := op(0, expr);
	fn_args_list := [op(expr)];
	fn; return '%'( op(map(fix_Diff, fn_args_list)) );
fi;

# remap derivatives
if type(expr, function) and (op(0, expr) = 'Diff') then
	Darg := op(1, expr);
	Dvars := [op(2..nn, expr)];
	if (assigned(`fix_Diff/remap_table`[op(Dvars)])) then
		`fix_Diff/remap_table`[op(Dvars)];
		return '%'(Darg);
	else
		error "don't know how to remap Diff() call!\n"
		      " Darg = %1\n"
		      " Dvars = %2\n"
		      ,
		      Darg, Dvars;
	fi;
fi;

# otherwise, the identity function
return expr;
end;

########################################

#
# this table defines the remapping of Diff() calls for fix_Diff() (above)
# n.b. Diff() should already have canonicalized the order of variables
#
`fix_Diff/remap_table`[rho  ] := 'PARTIAL_RHO';
`fix_Diff/remap_table`[sigma] := 'PARTIAL_SIGMA';
`fix_Diff/remap_table`[rho  , rho  ] := 'PARTIAL_RHO_RHO';
`fix_Diff/remap_table`[rho  , sigma] := 'PARTIAL_RHO_SIGMA';
`fix_Diff/remap_table`[sigma, sigma] := 'PARTIAL_SIGMA_SIGMA';

`fix_Diff/remap_table`[xx] := 'PARTIAL_X';
`fix_Diff/remap_table`[yy] := 'PARTIAL_Y';
`fix_Diff/remap_table`[zz] := 'PARTIAL_Z';
`fix_Diff/remap_table`[xx,xx] := 'PARTIAL_XX';
`fix_Diff/remap_table`[xx,yy] := 'PARTIAL_XY';
`fix_Diff/remap_table`[xx,zz] := 'PARTIAL_XZ';
`fix_Diff/remap_table`[yy,yy] := 'PARTIAL_YY';
`fix_Diff/remap_table`[yy,zz] := 'PARTIAL_YZ';
`fix_Diff/remap_table`[zz,zz] := 'PARTIAL_ZZ';

################################################################################

#
# Given an equation list, this function finds all the temporaries
# assigned by it. A "temporary" is defined here to be a name on the
# left hand side of an equation, which isn't the result or a component
# of it.
#
# Arguments:
# expr = (in) The equation list to operate on.
# result_name = (in) The result name or list/set of result names.
#
# Results:
# The function returns the list of temporaries assigned.
#
temps_in_eqnlist :=
proc(expr::list(name = algebraic),
     result_name::{name, list(name), set(name)})

# "temporary" = lhs name which isn't a result
return remove(is_result, map(lhs,expr), result_name);
end;

################################################################################

#
# This function tests whether or not a name is a "result" name or
# an array/table component of it.
# Either a single result name, or a
# list/set of these, may be specified; in the latter case the function
# tests whether or not a name matches *any* of the result names.
#
# Arguments:
# try_name = (in) The name to test.
# result_name = (in) The name or list/set of names of the result to test
#		     against.
#
# Results:
# The function returns true if the name is equal to the result or an
# array/table component of it, false otherwise.
#
is_result :=
proc(try_name::name, result_name_in::{name, list(name), set(name)})
local candidates;

# a lone result name is handled as a one-element set, so that the
# membership tests below cover all three argument forms
candidates := `if`(type(result_name_in, name),
		   { result_name_in },
		   result_name_in);

# the name itself is one of the results
if (member(try_name, candidates)) then
	return true;
fi;

# a non-indexed name has no base name left to try
if (not type(try_name, indexed)) then
	return false;
fi;

# an indexed name matches when what it indexes is one of the results
return member(op(0, try_name), candidates);
end;

################################################################################

#
# This function removes all indices from indexed names, eg
#	A[1,2,3] --> A .
#
# Arguments:
# expr = (in) The expression to be converted.
#
# Results:
# The converted expression is returned as the function result.
#
# Lists and sets are processed elementwise, and function calls have each
# argument processed.  Non-indexed names and numbers are returned unchanged.
# Any other kind of argument is rejected with an error.
#
deindex_names :=
proc(expr::{
	     name, numeric, function,
	     list({name, numeric, function}),
	     set({name, numeric, function})
	   })
local fn, fn_args_list;

# recurse over lists and sets
if (type(expr, {list, set})) then
	return map(deindex_names, expr);
fi;

# recurse over function calls
# ... numeric is accepted in the parameter type above so that numeric
#     function arguments, eg  f(A[1], 3) , reach the pass-through case below
# NOTE(review): "fn;" followed by '%'(...) rebuilds the call through the
# ditto operator; presumably this avoids invoking fn on the new arguments
# -- confirm before changing
if (type(expr, function)) then
	fn := op(0, expr);
	fn_args_list := [op(expr)];
	fn; return '%'(op(map(deindex_names, fn_args_list)));
fi;

# convert indexed names
if (type(expr, indexed)) then
	return op(0, expr);
fi;

# return non-indexed names and numbers unchanged
if (type(expr, {name, numeric})) then
	return expr;
fi;

# unknown type
error "expr has unknown type!\n"
      "whattype(expr)=%1\n"
      ,
      whattype(expr);
end;

################################################################################

#
# This function converts all occurrences of indexed names in an expression
# to new non-indexed "scalar names" of the form
#	A[1,2,3] --> A_123 .
#
# Arguments:
# expr = (in) The expression to be converted.
#
# Results:
# The converted expression is returned as the function result.
#
# Lists and sets are processed elementwise; equations have both sides
# processed; sums, products, powers (base and exponent) and function
# arguments are processed recursively.  Numbers and non-indexed names are
# returned unchanged.
#
# n.b. the indices are concatenated with no separator between them, so
#      eg  A[1,23]  and  A[12,3]  both become  A_123
#
unindex_names :=
proc(expr::{
	     algebraic, name = algebraic,
	     list({algebraic, name = algebraic}),
	     set({algebraic, name = algebraic})
	   })
local nn, k,
      base, power,
      fn, fn_args_list,
      base_name, index_seq;

# recurse over lists and sets
if (type(expr, {list, set})) then
	return map(unindex_names, expr);
fi;

# recurse over equations (both lhs and rhs)
if (type(expr, `=`)) then
	return unindex_names(lhs(expr)) = unindex_names(rhs(expr));
fi;

nn := nops(expr);

# recurse over sums
# ... add()/mul() are the explicit finite forms; they substitute each
#     integer k before evaluating the term, so no quoting is needed
if (type(expr, `+`)) then
	return add(unindex_names(op(k,expr)), k=1..nn);
fi;

# recurse over products
if (type(expr, `*`)) then
	return mul(unindex_names(op(k,expr)), k=1..nn);
fi;

# recurse over powers
# ... the exponent is converted too, so that an indexed name there
#     doesn't survive into the output; a numeric exponent comes back
#     unchanged from the "numbers" case below
if (type(expr, `^`)) then
	base := op(1, expr);
	power := op(2, expr);
	return unindex_names(base) ^ unindex_names(power);
fi;

# recurse over function calls
# NOTE(review): "fn;" followed by '%'(...) rebuilds the call through the
# ditto operator; presumably this avoids invoking fn on the new arguments
# -- confirm before changing
if (type(expr, function)) then
	fn := op(0, expr);
	fn_args_list := [op(expr)];
	fn; return '%'(op(map(unindex_names, fn_args_list)));
fi;

# convert indexed names
if (type(expr, indexed)) then
	base_name := op(0, expr);
	index_seq := op(expr);
	return cat(base_name,"_",index_seq);
fi;

# return numbers and non-indexed names
if (type(expr, {numeric, name})) then
	return expr;
fi;

# unknown type
error "expr has unknown type!\n"
      "whattype(expr)=%1\n"
      ,
      whattype(expr);
end;

################################################################################

#
# This function converts all integer or rational subexpressions of its
# input except integer exponents and integer factors in products, into
# function calls RATIONAL(num,den) with num and den integers.
#
# This is useful in conjunction with the C() library function, since
#
#	C( (1/3) * foo * bar )
#		t0 = foo*bar/3;
#
# generates a (slow) division (and runs the risk of mixed-mode-arithmetic
# problems).
# In contrast, with this function
#
#	fix_rationals((1/3) * foo * bar);
#		RATIONAL(1,3) foo bar
#	codegen[C](%);
#		t0 = RATIONAL(1.0,3.0)*foo*bar;
#
# which a C preprocessor macro can easily convert to the desired
#
#		t0 = (1.0/3.0)*foo*bar;
#
# Existing RATIONAL() calls are returned unchanged, so applying this
# function to its own output is harmless.
#
# Arguments:
# expr = (in) The expression to be converted.
#
# Results:
# The converted expression is returned as the function result.
#
fix_rationals :=
proc(expr::{algebraic, name = algebraic, list({algebraic, name = algebraic})})
local nn, k,
      base, power, fbase, fpower,
      fn, fn_args_list,
      int_factors, nonint_factors,
      num, den,
      mult;

# recurse over lists
if (type(expr, list)) then
	return map(fix_rationals, expr);
fi;

# recurse over equation right hand sides
if (type(expr, name = algebraic)) then
	return lhs(expr) = fix_rationals(rhs(expr));
fi;

# recurse over functions other than RATIONAL()
# ... a RATIONAL() call is already in its final form; return it as is
#     rather than letting it fall through to the "unknown type" error
# NOTE(review): "fn;" followed by '%'(...) rebuilds the call through the
# ditto operator; presumably this avoids invoking fn on the new arguments
# -- confirm before changing
if (type(expr, function)) then
	fn := op(0, expr);
	if (fn = 'RATIONAL') then
		return expr;
	fi;
	fn_args_list := [op(expr)];
	fn; return '%'(op(map(fix_rationals, fn_args_list)));
fi;

nn := nops(expr);

# recurse over sums
if (type(expr, `+`)) then
	return add(fix_rationals(op(k,expr)), k=1..nn);
fi;

# recurse over products
# ... leaving integer factors intact
if (type(expr, `*`)) then
	# compute lists of all integer/non-integer factors
	# ... this must be done on the *list* of factors: applied to the
	#     product itself, selectremove() gives back products, and a
	#     product with a single factor is just that factor, so taking
	#     op() of it would walk inside eg  sin(x)  or  x^2  instead of
	#     over the factors
	int_factors, nonint_factors
		:= selectremove(type, [op(expr)], integer);
	return mul(k, k=int_factors)
	       * mul(fix_rationals(k), k=nonint_factors);
fi;

# recurse over powers
# ... leaving integer exponents intact
if (type(expr, `^`)) then
	base := op(1, expr);
	power := op(2, expr);

	fbase := fix_rationals(base);
	if (type(power, integer)) then
		fpower := power;
	else
		fpower := fix_rationals(power);
	fi;
	return fbase ^ fpower;
fi;

# fix integers and fractions
if (type(expr, integer)) then
	return 'RATIONAL'(expr, 1);
fi;
if (type(expr, fraction)) then
	num := op(1, expr);
	den := op(2, expr);
	return 'RATIONAL'(num, den);
fi;

# turn Maple floating-point into integer fraction, then recursively fix that
if (type(expr, float)) then
	mult := op(1, expr);
	power := op(2, expr);
	return fix_rationals(mult * 10^power);
fi;

# identity op on names
if (type(expr, name)) then
	return expr;
fi;

# unknown type
error "expr has unknown type!\n"
      "whattype(expr)=%1\n"
      "expr=%2\n"
      ,
      whattype(expr), expr;
end;

################################################################################

#
# This function prints C declarations for a list of names.
# At most 10 names are put in each declaration; longer lists are split
# over several declarations.  Nothing is written for an empty list.
#
# Argument:
# name_list = A list of the names.
# name_type = The C type of the names, eg. "double".
# file_name = The file name to write (append) the declaration to.
#
print_name_list_dcl :=
proc( name_list::list({name,string}),
      name_type::string,
      file_name::string )
local nn, name_strings;

nn := nops(name_list);

# nothing to declare
# ... without this an empty list would be written as the empty
#     declaration  "name_type ;"
if (nn = 0) then
	return NULL;
fi;

# print up to 10 declarations on one line
if (nn <= 10) then
	name_strings := map(convert, name_list, string);
	fprintf(file_name,
		"%s %s;\n",
		name_type, cat(op(ListTools[Join](name_strings, ", "))));
	return NULL;
fi;

# recurse for larger numbers of declarations
print_name_list_dcl([op(1..10, name_list)], name_type, file_name);
print_name_list_dcl([op(11..nn, name_list)], name_type, file_name);
end proc;