aboutsummaryrefslogtreecommitdiff
path: root/src/maple/codegen2.maple
blob: eaa881b6e968dec4937e722b513b256dea6af15d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
# codegen2.maple -- generate C code from Maple expressions
# $Header$
#
# codegen2 - generate C code from Maple expressions
#
# cvt_to_eqnlist - convert an expression to an equation list
# fix_Diff - convert Diff() calls to PARTIAL_*() calls
#   fix_Diff/remap_table - remapping table for fix_Diff()
# temps_in_eqnlist - find temporaries used by an equation list
# is_result - is a name equal to a "result" or an array/table component of it?
# deindex_names - remove indices from indexed names
# unindex_names - convert indexed names to scalars
# fix_rationals - convert numbers to RATIONAL() calls
# print_name_list_dcl - print a C declaration for a list of names
#

################################################################################
################################################################################
################################################################################

#
# This function is a high-level driver to generate C code for a Maple
# expression.  It does the following, in this order:
# - convert the input into a list of equations of the form  name = expression
# - generate an "optimized" form of the computation sequence with
#   codegen[optimize]  (a variant using the  tryhard  option is present
#   but commented out)
# - determine which temporary variables have been introduced by the
#   codegen[optimize] library
# - rewrite Diff(...) calls to PARTIAL_RHO_RHO(...) etc calls
# - determine the (deindexed) input and output names, which are listed
#   in a comment at the top of the generated file
# - "unindex" all array accesses, eg R_dd[2,3] becomes R_dd_23.
# - compute the operation count with codegen[cost]
# - convert rational numbers to RATIONAL(p,q) calls
# - print a C comment with the inputs, outputs, and cost
# - print C declarations (of C type "fp") for the temporary variables
# - print C code for the optimized computation sequence
#
# Note that "codegen" is a Maple package; this function is called "codegen2".
#
# Sample usage:
#	codegen2([R_dd__fnd, R__fnd], ['R_dd', 'R'], "Ricci.c");
#
# Arguments:
# expr_in = (in) The expression or list of expressions for which code is to
#		 be generated.  This will typically be the name of a gridfn
#		 array functional-dependence form, or a list of such names,
#		 but this isn't required.
# lhs_name = (in) The name or list of names to be used for the results in
#		  the generated C code, eg R_dd.
# output_file_name = (in) The name of the file to which the generated
#			  code is to be written.
#
# Arguments (as global variables)
# `saveit/level`
#	= (in) (optional)
#	       If this global variable is assigned, intermediate results
#	       are saved for debugging purposes.  If it's assigned an
#	       integral value, making this value larger may increase the
#	       level of debugging output.  (10 is a good number for typical
#	       debugging.)
#
# Results:
# The function returns NULL; its output is the generated file.
#
codegen2 :=
proc(expr_in::{algebraic, list(algebraic)},
     lhs_name::{name, list(name)},
     output_file_name::string)
global
  @include "../maple/coords.minc",
  @include "../maple/gfa.minc";
local expr, expr_temps, input_set, output_set, expr_cost;

printf("codegen2(%a) --> \"%s\"\n", lhs_name, output_file_name);

expr := expr_in;
saveit(10, procname, "input", expr);

printf("   convert --> equation list\n");
expr := cvt_to_eqnlist(expr, lhs_name);
saveit(10, procname, "eqnlist", expr);

printf("   optimizing computation sequence\n");
expr := [codegen[optimize](expr)];
##expr := [codegen[optimize](expr, tryhard)];
saveit(10, procname, "optimize", expr);

# temporaries = lhs names introduced by codegen[optimize] which are
# neither the results nor components of them
printf("   find temporary variables\n");
expr_temps := temps_in_eqnlist(expr, lhs_name);
saveit(10, procname, "temps", expr_temps);

printf("   convert Diff(expr,rho,sigma) --> PARTIAL_RHO_SIGMA(expr) etc\n");
expr := fix_Diff(expr);
saveit(10, procname, "fix_Diff", expr);

# inputs = names used on right hand sides, excluding the temporaries and
#          the names in  xy_all_set  (presumably the coordinates declared
#          via coords.minc -- TODO confirm);
# outputs = names assigned on left hand sides, excluding the temporaries;
# both with array indices stripped, for the header comment only
input_set := deindex_names( indets(map(rhs,expr),name)
			    minus {op(expr_temps)}
			    minus xy_all_set );
output_set := deindex_names( {op(map(lhs,expr))} minus {op(expr_temps)} );
printf("   convert R_dd[2,3] --> R_dd_23 etc\n");
expr := unindex_names(expr);
saveit(10, procname, "unindex", expr);

# operation count, measured before the RATIONAL() rewrite
expr_cost := codegen[cost](expr);

printf("   convert p/q --> RATIONAL(p,q)\n");
expr := fix_rationals(expr);
saveit(10, procname, "fix_rationals", expr);

#
# write the C code
#
printf("   writing C code\n");
ftruncate(output_file_name);
fprintf(output_file_name, "/*\n");
fprintf(output_file_name, " * inputs = %a\n", input_set);
fprintf(output_file_name, " * outputs = %a\n", output_set);
fprintf(output_file_name, " * cost = %a\n", expr_cost);
fprintf(output_file_name, " */\n");
print_name_list_dcl(expr_temps, "fp", output_file_name);
codegen[C](expr, filename=output_file_name);

NULL;
end proc;

################################################################################
################################################################################
################################################################################

#
# This function converts an expression or list of expressions into an
# equation list.  That is, given an expression  expr , this function
# computes an equation list of the form
#
# if type(expr, algebraic)
#	[ lhs_name = expr ]
#
# if type(expr, array)			# illustrated here for a rank-1 array
#	[
#	lhs_name[1] = expr[1],		# note equations are in lexicographic
#	lhs_name[2] = expr[2],		# order of the indices, and include
#	lhs_name[3] = expr[3],		# only those array elements that are
#	    ...				# explicitly stored (as reported by
#	lhs_name[N] = expr[N]		#  indices() )
#	]
#
# if type(expr, list)
#	then  lhs_name  must be a list of the same length, and the
#	equation lists from corresponding elements are concatenated
#
# In the algebraic and array cases  lhs_name  must be a single name.
#
# Arguments:
# expr = (in) The expression to be converted.
# lhs_name = (in) The unevaluated name or list of names to use for the
#		  left hand side(s) in the equation list.
#
# Results:
# The equation list is returned as the function result.
#
# Errors:
# - if  expr  and  lhs_name  are lists of different lengths
# - if the types of  expr  and  lhs_name  don't fit any case above
#
cvt_to_eqnlist :=
proc(expr::{algebraic, array, list({algebraic, array})},
     lhs_name::{name, list(name)})

# ... test for array first since otherwise expr itself is a "name",
#     which would match type "algebraic" as well
if (type(expr, array) and type(lhs_name, name))
   then return map(
		      proc(ii)
		      return lhs_name[op(ii)] = expr[op(ii)];
		      end
		    ,
		      indices_in_order(expr)
		  );
fi;

if (type(expr, algebraic) and type(lhs_name, name))
   then return [lhs_name = expr];
fi;

if (type(expr, list({algebraic, array})) and type(lhs_name, list(name)))
   then
	# zip() would silently drop the unmatched tail of the longer list,
	# losing equations, so insist on matching lengths
	if (nops(expr) <> nops(lhs_name))
	   then error "lengths of expr and lhs_name lists differ!\n"
		      "   nops(expr)=%1\n"
		      "   nops(lhs_name)=%2\n"
		      ,
		      nops(expr), nops(lhs_name);
	fi;
	# op @ ... splices each element's equation list into the result
	return zip(op @ cvt_to_eqnlist, expr, lhs_name);
fi;

error "unknown type for expression!\n"
      "   expr=%1\n"
      "   whattype(expr)=%2\n"
      "   lhs_name=%3\n"
      ,
      expr, whattype(expr), lhs_name;
end;

################################################################################

#
# This function converts  Diff()  calls into  PARTIAL_*()  calls, eg
# Diff(src, rho, sigma) --> PARTIAL_RHO_SIGMA(src).
#
# It recurses through lists, equation right hand sides, sums, products,
# powers (both base and exponent), and the arguments of non-Diff function
# calls; anything else (names, numbers) is returned unchanged.
#
# The PARTIAL_* name is looked up in the global table `fix_Diff/remap_table`,
# indexed by the sequence of differentiation variables.  Since partial
# derivatives commute, the variables are first sorted with  lexorder , so
# eg  Diff(src, sigma, rho)  and  Diff(src, rho, sigma)  both become
# PARTIAL_RHO_SIGMA(src) .
#
# Arguments:
# expr = (in) The expression, equation, or list of these to convert.
#
# Results:
# The converted expression is returned as the function result.
#
# Errors:
# If a Diff() call's (sorted) variable sequence has no table entry.
#
fix_Diff :=
proc(expr::{algebraic, name = algebraic, list({algebraic, name = algebraic})})
local nn, fn, fn_args_list, Darg, Dvars, remap_name;
global `fix_Diff/remap_table`;

# recurse over lists
if (type(expr, list))
   then return map(fix_Diff, expr);
fi;

# recurse over equation right hand sides
if (type(expr, name = algebraic))
   then return lhs(expr) = fix_Diff(rhs(expr));
fi;

# recurse over sums, products, and powers
# ... map() applies fix_Diff to every operand (for a power, to both the
#     base and the exponent) and rebuilds an expression of the same type
if (type(expr, {`+`, `*`, `^`}))
   then return map(fix_Diff, expr);
fi;

# recurse over the arguments of non-Diff functions,
# keeping the function name itself as-is
if type(expr, function) and (op(0, expr) <> 'Diff')
   then
	fn := op(0, expr);
	fn_args_list := [op(expr)];
	return fn(op(map(fix_Diff, fn_args_list)));
fi;

# remap derivatives
if type(expr, function) and (op(0, expr) = 'Diff')
   then
	nn := nops(expr);
	Darg := op(1, expr);
	# canonical (sorted) order, matching the keys of the remap table
	Dvars := sort([op(2..nn, expr)], lexorder);
	if (assigned(`fix_Diff/remap_table`[op(Dvars)]))
	   then
		remap_name := `fix_Diff/remap_table`[op(Dvars)];
		return remap_name(Darg);
	   else error "don't know how to remap Diff() call!\n"
		      "   Darg = %1\n"
		      "   Dvars = %2\n"
		      ,
		      Darg, Dvars;
	fi;
fi;

# otherwise, the identity function
return expr;
end;

########################################

#
# this table defines the remapping of Diff() calls for  fix_Diff()  (above)
# It is indexed by the sequence of differentiation variables of a Diff()
# call; each entry is the (quoted, hence unevaluated) name of the function
# to call instead, eg  Diff(src, rho, sigma)  becomes  PARTIAL_RHO_SIGMA(src) .
#
# n.b. Diff() should already have canonicalized the order of variables:
#      only one ordering of each mixed derivative is listed here (variables
#      in lexicographic order: rho before sigma, xx before yy before zz),
#      so a Diff() call whose variables are in another order is only found
#      if they are put into this order before the table lookup.
#

# first and second derivatives with respect to rho and sigma
`fix_Diff/remap_table`[rho  ] := 'PARTIAL_RHO';
`fix_Diff/remap_table`[sigma] := 'PARTIAL_SIGMA';
`fix_Diff/remap_table`[rho  , rho  ] := 'PARTIAL_RHO_RHO';
`fix_Diff/remap_table`[rho  , sigma] := 'PARTIAL_RHO_SIGMA';
`fix_Diff/remap_table`[sigma, sigma] := 'PARTIAL_SIGMA_SIGMA';

# first and second derivatives with respect to xx, yy, and zz
`fix_Diff/remap_table`[xx] := 'PARTIAL_X';
`fix_Diff/remap_table`[yy] := 'PARTIAL_Y';
`fix_Diff/remap_table`[zz] := 'PARTIAL_Z';
`fix_Diff/remap_table`[xx,xx] := 'PARTIAL_XX';
`fix_Diff/remap_table`[xx,yy] := 'PARTIAL_XY';
`fix_Diff/remap_table`[xx,zz] := 'PARTIAL_XZ';
`fix_Diff/remap_table`[yy,yy] := 'PARTIAL_YY';
`fix_Diff/remap_table`[yy,zz] := 'PARTIAL_YZ';
`fix_Diff/remap_table`[zz,zz] := 'PARTIAL_ZZ';

################################################################################

#
# Find the temporaries assigned by an equation list.  Here a "temporary"
# is any left hand side name which is neither a result name nor an
# array/table component of one (as decided by  is_result() ).
#
# Arguments:
# expr = (in) The equation list to scan.
# result_name = (in) The result name, or a list/set of result names.
#
# Results:
# The list of temporaries, in the order in which the equation list
# assigns them.
#
temps_in_eqnlist :=
proc(expr::list(name = algebraic),
     result_name::{name, list(name), set(name)})
local assigned_names;

assigned_names := map(lhs, expr);

# keep exactly those assigned names which is_result() rejects
return select(proc(nm, rn)
		  return not is_result(nm, rn);
	      end proc,
	      assigned_names, result_name);
end;

################################################################################

#
# Test whether  try_name  is a "result" name, or an array/table component
# (indexed name) whose base name is a result name.  The result may be given
# as a single name or as a list/set of names; in the latter case the test
# succeeds if  try_name  matches *any* of them.
#
# Arguments:
# try_name = (in) The name to test.
# result_name_in = (in) The result name, or a list/set of result names.
#
# Results:
# true  if  try_name  equals a result name or is one with indices attached,
# false  otherwise.
#
is_result :=
proc(try_name::name,
     result_name_in::{name, list(name), set(name)})
local candidates, candidate;

# treat a single name as a one-element collection
if type(result_name_in, name)
   then candidates := [ result_name_in ];
   else candidates := result_name_in;
fi;

for candidate in candidates
do
	if (try_name = candidate)
	   or (type(try_name, indexed) and (op(0, try_name) = candidate))
	   then return true;
	fi;
end do;

return false;
end;

################################################################################

#
# This function removes all indices from indexed names, eg
#	A[1,2,3] --> A .
# Lists, sets, and the arguments of function calls are processed
# recursively; non-indexed names and numbers are returned unchanged.
# For a set, names which become equal are merged, eg
#	{A[1], A[2], B} --> {A, B} .
#
# Arguments:
# expr = (in) The name, number, function call, or list/set of these to be
#	      converted.
#
# Results:
# The converted expression is returned as the function result.
#
# Errors:
# If  expr  (or a function argument) is of any other type, eg a sum.
#
deindex_names :=
proc(expr::{name, numeric, function,
	    list({name, numeric, function}), set({name, numeric, function})})
local fn, fn_args_list;

# recurse over lists and sets
if (type(expr, {list, set}))
   then return map(deindex_names, expr);
fi;

# recurse over function calls, keeping the function name itself as-is
# ... numeric arguments (eg  f(A[1], 2) ) are accepted by the parameter
#     type, so they reach the "numbers unchanged" case below
if (type(expr, function))
   then
	fn := op(0, expr);
	fn_args_list := [op(expr)];
	return fn(op(map(deindex_names, fn_args_list)));
fi;

# convert indexed names
if (type(expr, indexed))
   then return op(0, expr);
fi;

# return non-indexed names and numbers unchanged
if (type(expr, {name, numeric}))
   then return expr;
fi;

# unknown type
error "expr has unknown type!\n"
      "whattype(expr)=%1\n"
      ,
      whattype(expr);
end;

################################################################################

#
# This function converts all occurrences of indexed names in an expression
# to new non-indexed "scalar names" of the form
#	A[1,2,3] --> A_123 .
# It recurses through lists, sets, equations (both sides), sums, products,
# powers (both base and exponent), and function-call arguments; numbers and
# non-indexed names are returned unchanged.
#
# n.b. The indices are simply concatenated, so the new names are only
#      unique if every index is a single character (eg a single digit):
#      A[1,23] and A[12,3] would both become A_123.
#
# Arguments:
# expr = (in) The expression to be converted.
#
# Results:
# The converted expression is returned as the function result.
#
# Errors:
# If  expr  contains a subexpression of a type not handled here.
#
unindex_names :=
proc(expr::{
	   algebraic, name = algebraic,
	   list({algebraic, name = algebraic}),
	   set({algebraic, name = algebraic})
	   })
local fn, fn_args_list,
      base_name, index_seq;

# recurse over lists and sets
if (type(expr, {list, set}))
   then return map(unindex_names, expr);
fi;

# recurse over equations (both lhs and rhs)
if (type(expr, `=`))
   then return unindex_names(lhs(expr))  =  unindex_names(rhs(expr));
fi;

# recurse over sums, products, and powers
# ... map() applies unindex_names to every operand (for a power, to both
#     the base and the exponent) and rebuilds an expression of the same type
if (type(expr, {`+`, `*`, `^`}))
   then return map(unindex_names, expr);
fi;

# recurse over function calls, keeping the function name itself as-is
if (type(expr, function))
   then
	fn := op(0, expr);
	fn_args_list := [op(expr)];
	return fn(op(map(unindex_names, fn_args_list)));
fi;

# convert indexed names
# ... cat() with a name as its first argument yields a name
if (type(expr, indexed))
   then
	base_name := op(0, expr);
	index_seq := op(expr);
	return cat(base_name,"_",index_seq);
fi;

# return numbers and non-indexed names
if (type(expr, {numeric, name}))
   then return expr;
fi;

# unknown type
error "expr has unknown type!\n"
      "whattype(expr)=%1\n"
      ,
      whattype(expr);
end;

################################################################################

#
# This function converts all integer or rational subexpressions of its
# input except integer exponents and integer factors in products, into
# function calls  RATIONAL(num,den)  with  num  and  den  integers.
# Floating-point numbers are first converted to exact fractions.
# Existing  RATIONAL(...)  calls are left unchanged, so applying this
# function to its own output is harmless.
#
# This is useful in conjunction with the  C() library function, since
#
#	C( (1/3) * foo * bar )
#		t0 = foo*bar/3;
#
# generates a (slow) division (and runs the risk of mixed-mode-arithmetic
# problems).  In contrast, with this function
#
#	fix_rationals((1/3) * foo * bar);
#	     RATIONAL(1,3) foo bar
#	codegen[C](%);
#	     t0 = RATIONAL(1.0,3.0)*foo*bar;
#
# which a C preprocessor macro can easily convert to the desired
#
#	     t0 = (1.0/3.0)*foo*bar;
#
# Arguments:
# expr = (in) The expression, equation (only the right hand side is
#	      converted), or list of these to be converted.
#
# Results:
# The converted expression is returned as the function result.
#
# Errors:
# If  expr  contains a subexpression of a type not handled here.
#
fix_rationals :=
proc(expr::{algebraic, name = algebraic, list({algebraic, name = algebraic})})
local fn, fn_args_list,
      base, power, fbase, fpower,
      num, den, mant, expon;

# recurse over lists
if (type(expr, list))
   then return map(fix_rationals, expr);
fi;

# recurse over equation right hand sides
if (type(expr, name = algebraic))
   then return lhs(expr) = fix_rationals(rhs(expr));
fi;

if (type(expr, function))
   then
	fn := op(0, expr);

	# leave existing RATIONAL() calls alone
	# ... (previously these fell through to the "unknown type" error)
	if (fn = 'RATIONAL')
	   then return expr;
	fi;

	# recurse over the arguments of all other functions,
	# keeping the function name itself as-is
	fn_args_list := [op(expr)];
	return fn(op(map(fix_rationals, fn_args_list)));
fi;

# recurse over sums
if (type(expr, `+`))
   then return map(fix_rationals, expr);
fi;

# recurse over products
# ... leaving integer factors intact.  Each factor is handled on its own:
#     splitting off the integer part with selectremove() and then taking
#     op()s of the remainder went wrong when a single non-integer factor
#     remained, eg  3*x^2  became  3*x*RATIONAL(2,1)  and  2*sin(x)  became
#     2*x .
if (type(expr, `*`))
   then return map(proc(fac)
			if (type(fac, integer))
			   then return fac;
			   else return fix_rationals(fac);
			fi;
		   end proc,
		   expr);
fi;

# recurse over powers
# ... leaving integer exponents intact
if (type(expr, `^`))
   then
	base := op(1, expr);
	power := op(2, expr);

	fbase := fix_rationals(base);
	if (type(power, integer))
	   then fpower := power;
	   else fpower := fix_rationals(power);
	fi;
	return fbase ^ fpower;
fi;

# fix integers and fractions
if (type(expr, integer))
   then return 'RATIONAL'(expr, 1);
fi;
if (type(expr, fraction))
   then
	num := op(1, expr);
	den := op(2, expr);
	return 'RATIONAL'(num, den);
fi;

# turn Maple floating-point (mantissa * 10^exponent) into an exact
# integer or fraction, then recursively fix that
if (type(expr, float))
   then
	mant := op(1, expr);
	expon := op(2, expr);
	return fix_rationals(mant * 10^expon);
fi;

# identity op on names
if (type(expr, name))
   then return expr;
fi;

# unknown type
error "expr has unknown type!\n"
      "whattype(expr)=%1\n"
      "expr=%2\n"
      ,
      whattype(expr), expr;
end;

################################################################################

#
# This function prints C declarations for a list of names, at most 10
# names per declaration, eg
#	fp t1, t2, t3;
# Longer lists are split into several declarations.  An empty list
# prints nothing.
#
# Arguments:
# name_list = (in) A list of the names (names or strings).
# name_type = (in) The C type of the names, eg. "double".
# file_name = (in) The file name to write (append) the declaration(s) to.
#
print_name_list_dcl :=
proc( name_list::list({name,string}),
      name_type::string,
      file_name::string )
local nn, name_strings;

nn := nops(name_list);

# nothing to declare
# ... (previously an empty list produced a declaration with no names in
#     it, which is not valid C)
if (nn = 0)
   then return;
fi;

# print up to 10 declarations on one line
if (nn <= 10)
   then
	name_strings := map(convert, name_list, string);
	fprintf(file_name,
		"%s %s;\n",
		name_type, cat(op(ListTools[Join](name_strings, ", "))));
	return;
fi;

# recurse for larger numbers of declarations
print_name_list_dcl([op(1..10, name_list)], name_type, file_name);
print_name_list_dcl([op(11..nn, name_list)], name_type, file_name);
end proc;