about summary refs log tree commit diff
path: root/src/MPIO.fast.cc
diff options
context:
space:
mode:
Diffstat (limited to 'src/MPIO.fast.cc')
-rw-r--r--  src/MPIO.fast.cc  257
1 file changed, 257 insertions, 0 deletions
diff --git a/src/MPIO.fast.cc b/src/MPIO.fast.cc
new file mode 100644
index 0000000..d9d6bf5
--- /dev/null
+++ b/src/MPIO.fast.cc
@@ -0,0 +1,257 @@
+#include <stdio.h>
+#include "MPIutils.hh"
+#include "MPIO.hh"
+
+// Post one non-blocking receive per processor chunk contributing to global
+// XY-slice 'z'.  Root-only: every other rank returns immediately.
+//   z           - global z index of the slice being gathered
+//   slicebuffer - root-side staging buffer; chunk payloads land back-to-back
+//   req         - caller-owned array receiving one MPI_Request per chunk
+// Must be paired with waitForSlice(z, ...), which completes the requests and
+// scatters the packed chunks into their (x,y) positions.
+void MPIO::requestSlice(int z,float *slicebuffer,MPI_Request *req){
+  // skip to z
+  //loop (foreach p of pe's)
+  if(!isRoot()) return;
+  //printf("requestSlice %u",z);
+  // slicemap[z] is the first chunk record of this z-level and
+  // chunkindexcount[z] how many chunks the level has (built in setLocalDims).
+  for(int reqindex=0,offset=0,idx=slicemap[z],lastchunk=slicemap[z]+chunkindexcount[z];
+      idx<lastchunk;idx++,reqindex++){
+    int p = procmap[idx];                      // rank that owns this chunk
+    int *idims = gdims+chunkindex[idx];        // chunk extents (x,y,z)
+    int *iorigin = gorigins + chunkindex[idx]; // chunk origin (unused here)
+    // NOTE(review): 'z' is presumably the message tag so slices in flight
+    // can't be confused -- confirm against MPIutils' iRecv signature.
+    comm.iRecv(p,z,idims[0]*idims[1],slicebuffer+offset,req[reqindex]);
+    //printf("reqslice(%u):MPI_Request=%u\n",z,req[reqindex]);
+    offset+=idims[0]*idims[1];                 // chunks packed contiguously
+  }
+  // break if outside of z-range (dangerous if non-canonical proc layout)
+}
+
+// Block until every chunk of global slice 'z' has arrived, then scatter the
+// packed chunks into their proper (x,y) positions of the assembled plane.
+// Root-only; must follow a matching requestSlice(z, ...).
+//   slicebuffer - staging buffer the receives were posted into (chunks packed)
+//   destbuffer  - globaldims[0]*globaldims[1] floats; gets the full XY plane
+//   req         - request array that requestSlice(z, ...) filled in
+void MPIO::waitForSlice(int z,float *slicebuffer,float *destbuffer,MPI_Request *req){
+  if(!isRoot()) return;
+  // printf("waitforslice %u\n",z);
+  // could do this 8k at a time
+  // loop for each p of pe's
+  //  if within z-range, compute nelem based on dims at p
+  //  wait for current request.
+  //  re-copy based on offset & origin
+  sliceallwait.start();
+  //printf("z(%u) slicemap(%u) len(%u)\n",z,slicemap[z],chunkindexcount[z]);
+  // Walk the same chunk records, in the same order, as requestSlice() did so
+  // reqindex and the packed offsets line up with the posted receives.
+  for(int reqindex=0,chunkidx=slicemap[z],lastchunk=slicemap[z]+chunkindexcount[z];
+      chunkidx<lastchunk;chunkidx++,reqindex++){
+    //int debug=0;
+    //if(z>=38) debug=1;
+    int p = procmap[chunkidx];
+    //if(debug) printf("\tp=%u\n",p);
+    int *idims = gdims + chunkindex[chunkidx];      // chunk extents
+    //if(debug) printf("\tdims(%u)=%u:%u:%u\n",chunkidx,idims[0],idims[1],idims[2]);
+    int *iorigin = gorigins + chunkindex[chunkidx]; // chunk origin in the plane
+    //if(debug) printf("\torigin=%u:%u:%u\n",iorigin[0],iorigin[1],iorigin[2]);
+    MPI_Status stat;
+    //printf("waitfor(%u):MPI_Request[%u]=%u\n",z,reqindex,req[reqindex]);
+    // slicewait.start();
+    comm.wait(req[reqindex],stat); // frees request object too
+    //slicewait.stop();
+    // get information out of the status object!!!!
+    //slicecollect.start();
+    // Copy the chunk's idims[0] x idims[1] rows into destbuffer; after each
+    // row, 'pos' skips the part of the global row this chunk does not own.
+    int offset = iorigin[1]*globaldims[0] + iorigin[0];
+    for(int idx=0,pos=offset,j=0;j<idims[1];j++,pos+=(globaldims[0]-idims[0]))
+      for(int i=0;i<idims[0];i++)
+        destbuffer[pos++]=slicebuffer[idx++];
+    // slicecollect.stop();
+  }
+  sliceallwait.stop();
+  // proclist array
+}
+
+// Construct the parallel-IO object.  Exactly one rank should pass a non-NULL
+// IObase* 'io'; that rank becomes the IO root.  Collective: every rank of 'c'
+// must call this constructor (rank 0 gathers the volunteers, then the chosen
+// root id is broadcast to everyone).
+MPIO::MPIO(IObase *io,MPIcomm &c):file(io),comm(c),slicemap(0),procmap(0),chunkindex(0),chunkindexcount(0){
+  for(int i=0;i<3;i++) globaldims[i]=localdims[i]=localorigin[i]=0;
+  int *rootarray=0;              // only rank 0 allocates/reads this
+  myid = comm.rank();
+  //printf("my rank = %u\n",myid);
+  if(!myid) rootarray = new int[comm.numProcs()];
+  int rootscale = file?1:0;      // 1 == "I volunteered to be the IO node"
+  //printf("rootscale[pid=%u] = %u\n",myid,rootscale);
+  comm.gather(0,1,&rootscale,rootarray); // gather to 0
+  globaldims[0]=globaldims[1]=globaldims[2]=0;
+  if(!myid){
+    // Count the volunteers; exactly one is supported right now.
+    // ('count' and 'p' were declared inside for-init in the original and
+    // used after the loop -- legal only under pre-ISO scoping rules.)
+    int count=0;
+    for(int p=0;p<comm.numProcs();p++){
+      //printf("rootarray[%u]=%u\n",p,rootarray[p]);
+      if(rootarray[p]) count++;
+    }
+    if(count!=1){ // same condition as count<=0 || count>1
+      fprintf(stderr,"MPIO only accepts a single IO node right now.\n");
+      fprintf(stderr,"You chose %u io nodes\n",count);
+      fprintf(stderr,"I will select only the first one\n");
+    }
+    // First volunteering rank wins.
+    root=-1;
+    for(int p=0;p<comm.numProcs() && root<0;p++)
+      if(rootarray[p]) root=p;
+    delete [] rootarray; // was scalar delete on a new[] array (UB)
+  }
+  comm.bcast(0,root); // now everyone knows root
+  //printf("broadcasting root node = %u\n",root);
+}
+
+// Report the accumulated timer statistics when the IO object goes away.
+// Only the IO root prints; the "PE(%3u)" prefix identifies the rank.
+MPIO::~MPIO(){
+  char label[32];
+  sprintf(label,"PE(%3u)",comm.rank()); // 7-char prefix; timer name follows it
+  if(myid==root) {
+    char *suffix = label+7; // overwrite just the timer-name part each time
+    sprintf(suffix,"slicewait");    slicewait.print(label);
+    sprintf(suffix,"slicewrite");   slicewrite.print(label);
+    sprintf(suffix,"slicecollect"); slicecollect.print(label);
+    //sprintf(suffix,"sliceselect");  sliceselect.print(label);
+    sprintf(suffix,"sliceallwait"); sliceallwait.print(label);
+  }
+}
+
+// Collective: record this rank's chunk of the global 3D array and (on the
+// root) rebuild the per-z bookkeeping tables used by requestSlice/waitForSlice.
+//   rank   - must be 3 (only 3D IO is supported)
+//   origin - this rank's offset into the global array (x,y,z)
+//   dims   - this rank's local extents (x,y,z)
+// Root-side tables built here:
+//   slicemap[z]        - first entry in procmap/chunkindex for z-level z
+//   chunkindexcount[z] - number of chunks on level z
+//   procmap[c]         - owning rank of chunk c
+//   chunkindex[c]      - offset of chunk c's triple in gdims/gorigins
+void MPIO::setLocalDims(int rank,int *origin, int *dims){
+  // was perror(): perror appends strerror(errno), which is unrelated here
+  if(rank!=3) fprintf(stderr,"MPIO is only for 3D IO for now (sorry)!\n");
+  if(isRoot()){
+    gorigins = new int[3*comm.numProcs()];
+    gdims = new int[3*comm.numProcs()];
+  }
+  else gdims=gorigins=0; // null for everyone else
+  for(int i=0;i<3;i++) {
+    //printf("localdims-n-origin[%u] = %u:%u\n",i,dims[i],origin[i]);
+    localdims[i]=dims[i];
+    localorigin[i]=origin[i];
+  }
+  comm.gather(root,3,localdims,gdims);
+  comm.gather(root,3,localorigin,gorigins);
+  // Global extent along each axis = max of origin+extent over all ranks.
+  globaldims[0]=globaldims[1]=globaldims[2]=0;
+  if(isRoot())
+    for(int p=0,last=3*comm.numProcs();p<last;p++){
+      if(globaldims[p%3] < gdims[p]+gorigins[p])
+        globaldims[p%3] = gdims[p]+gorigins[p];
+    }
+  // rebroadcast globaldims to all PE's
+  comm.bcast(root,3,globaldims);
+  // Now organize the layout info hierarchically by Z.
+  if(isRoot()){
+    if(slicemap){ // was scalar delete on new[] arrays (UB) -- need delete []
+      delete [] slicemap;
+      delete [] procmap;
+      delete [] chunkindex;
+      delete [] chunkindexcount;
+    }
+    slicemap = new int[globaldims[2]+1];        // +1: sentinel slot (see below)
+    chunkindexcount = new int[globaldims[2]+1];
+    // One chunk per processor (uniform decomposition assumed), so the flat
+    // chunk tables need exactly numProcs entries.  (The original also ran a
+    // dead counting pass whose result was discarded; removed.)
+    procmap = new int[comm.numProcs()];
+    chunkindex = new int [comm.numProcs()];
+    // Pass 1: enumerate chunks level-by-level in ascending z so all chunks
+    // of one z-level sit contiguously in procmap/chunkindex.
+    for(int thickness=0,chunkcount=0,z=0;z<globaldims[2];z+=thickness){
+      for(int p=0;p<comm.numProcs();p++){
+        int dindex = 3*p;
+        int *idims = gdims + dindex;
+        int *iorigin = gorigins + dindex;
+        int iz = z-iorigin[2];
+        if(iz<0 || iz>=idims[2]) continue; // p does not touch level z
+        thickness = gdims[2]; // assumes uniform domain decomp in Z
+        procmap[chunkcount] = p;
+        chunkindex[chunkcount] = dindex; // point into iorigin/idims
+        //printf("chunkcount=%u chunkindex=%u\n",chunkcount,chunkindex[chunkcount]);
+        chunkcount++;
+      }
+    }
+    // Pass 2: per z, record how many chunks the level has and where its
+    // records start.  ('chunkcount', 'z' and 'levelcount' leaked out of
+    // for-init scopes in the original pre-ISO code; declared properly here.)
+    int chunkcount=0;
+    for(int z=0;z<globaldims[2];){
+      int thickness=0;
+      int levelcount=0; // chunks on this z-level
+      for(int p=0;p<comm.numProcs();p++){
+        int dindex = 3*p;
+        int *idims = gdims + dindex;
+        int *iorigin = gorigins + dindex;
+        int iz = z-iorigin[2];
+        if(iz<0 || iz>=idims[2]) continue;
+        thickness = gdims[2];
+        levelcount++; chunkcount++;
+      }
+      // '<=' lets the loop also fill the sentinel slot at z==globaldims[2];
+      // both tables were allocated with globaldims[2]+1 entries for this.
+      for(int pz=0;pz<thickness && z<=globaldims[2];pz++,z++){
+        chunkindexcount[z]=levelcount;
+        slicemap[z] = chunkcount - levelcount; // selects chunkindex from slicenum
+        //printf("slicemap[%u]= %u\n",z,slicemap[z]);
+      }
+    }
+  }
+}
+
+// Collective write entry point.  Recomputes the global layout first if the
+// local dims changed since the last call, then dispatches on the data type.
+// Only Float32 is actually written; other types are silently ignored for now.
+// Returns 1 always (no error reporting yet).
+int MPIO::write(IObase::DataType type,int rank,int *dims,void *data){
+  int recalc_layout = (rank!=3) ? 1 : 0;
+  // localdims has only 3 entries: compare only when rank is valid.  (The
+  // original indexed localdims[i] for i<rank, reading out of bounds when
+  // rank>3; when rank!=3 the layout is recomputed anyway.)
+  if(!recalc_layout)
+    for(int i=0;i<3;i++) if(localdims[i]!=dims[i]) recalc_layout=1;
+  if(recalc_layout) setLocalDims(rank,localorigin,dims);
+  switch(type){
+  case IObase::Float32:
+    write((float*)data);
+    break;
+  case IObase::Float64: // not implemented yet
+  default:
+    break; // unsupported type: deliberately a no-op (still returns 1)
+  }
+  return 1; // for now, no error checking
+}
+
+// Stream a Float32 volume to the IO node one global XY slice at a time.
+// Collective: every rank posts a non-blocking send of each of its local
+// z-slices (sendSlice); the root double-buffers the receives, reassembles
+// each plane into a scratch buffer, and appends it to the output stream,
+// overlapping slice i's communication with writing slice i-1.
+void MPIO::write(float *data){
+  MPI_Request *sendreq = new MPI_Request[localdims[2]]; // one send per local slice
+  typedef float *floatP;
+  floatP slicebuffer[3]; // [0],[1]: double-buffered receive; [2]: reassembly scratch
+  typedef MPI_Request *MPI_RequestP;
+  MPI_RequestP recvreq[2];
+  int i; // shared by the loops below; the original relied on pre-ISO
+         // for-scope leakage to reuse 'i' after each loop
+  if(isRoot()){
+    for(i=0;i<3;i++) slicebuffer[i] = new float[globaldims[0]*globaldims[1]];
+    for(i=0;i<2;i++) recvreq[i] = new MPI_Request[comm.numProcs()];
+    // prime the pipeline with slice 0
+    requestSlice(0,slicebuffer[0],recvreq[0]);
+  }
+  // not good! Everyone is sending a slice here!
+  sendSlice(0,data,sendreq);
+  // hide latency behind the costly io->reserveChunk operation
+  if(isRoot()) file->reserveStream(IObase::Float32,3,globaldims);
+  int writecount=0;
+  for(i=1;i<globaldims[2];i++){
+    comm.barrier(); // clearly not!
+    if(isRoot())
+      requestSlice(i,slicebuffer[i%2],recvreq[i%2]);
+    sendSlice(i,data,sendreq); // jms
+    if(isRoot()) {
+      // drain slice i-1 while slice i is still in flight
+      waitForSlice(i-1,slicebuffer[(i-1)%2],slicebuffer[2],recvreq[(i-1)%2]);
+      writecount+=globaldims[0]*globaldims[1];
+      slicewrite.start();
+      file->writeStream(slicebuffer[2],globaldims[0]*globaldims[1]);
+      slicewrite.stop();
+    }
+  }
+  // here i == globaldims[2] (or 1 if the loop never ran): flush the last slice
+  if(isRoot()){
+    waitForSlice(i-1,slicebuffer[(i-1)%2],slicebuffer[2],recvreq[(i-1)%2]);
+    file->writeStream(slicebuffer[2],globaldims[0]*globaldims[1]);
+    writecount+=globaldims[0]*globaldims[1];
+    // do requestfree (double-buffered request free-s)
+    for(i=0;i<3;i++) delete [] slicebuffer[i]; // was scalar delete on new[] (UB)
+    for(i=0;i<2;i++) delete [] recvreq[i];
+  }
+  MPI_Status stat;
+  // make sure all of our own sends completed before the buffers go away
+  for(i=0;i<localdims[2];i++) comm.wait(sendreq[i],stat);
+  delete [] sendreq; // was scalar delete on a new[] array
+}