diff options
Diffstat (limited to 'src/MPIO.fast.cc')
-rw-r--r-- | src/MPIO.fast.cc | 257 |
1 files changed, 257 insertions, 0 deletions
diff --git a/src/MPIO.fast.cc b/src/MPIO.fast.cc new file mode 100644 index 0000000..d9d6bf5 --- /dev/null +++ b/src/MPIO.fast.cc @@ -0,0 +1,257 @@ +#include <stdio.h> +#include "MPIutils.hh" +#include "MPIO.hh" + +void MPIO::requestSlice(int z,float *slicebuffer,MPI_Request *req){ + // skip to z + //loop (foreach p of pe's) + if(!isRoot()) return; + //printf("requestSlice %u",z); + for(int reqindex=0,offset=0,idx=slicemap[z],lastchunk=slicemap[z]+chunkindexcount[z]; + idx<lastchunk;idx++,reqindex++){ + int p = procmap[idx]; + int *idims = gdims+chunkindex[idx]; + int *iorigin = gorigins + chunkindex[idx]; + comm.iRecv(p,z,idims[0]*idims[1],slicebuffer+offset,req[reqindex]); + //printf("reqslice(%u):MPI_Request=%u\n",z,req[reqindex]); + offset+=idims[0]*idims[1]; + } + // break if outside of z-range (dangerous if non-canonical proc layout) +} + +void MPIO::waitForSlice(int z,float *slicebuffer,float *destbuffer,MPI_Request *req){ + if(!isRoot()) return; + // printf("waitforslice %u\n",z); + // could do this 8k at a time + // loop for each p of pe's + // if within z-range, compute nelem based on dims at p + // wait for current request. + // re-copy based on offset & origin + sliceallwait.start(); + //printf("z(%u) slicemap(%u) len(%u)\n",z,slicemap[z],chunkindexcount[z]); + for(int reqindex=0,chunkidx=slicemap[z],lastchunk=slicemap[z]+chunkindexcount[z]; + chunkidx<lastchunk;chunkidx++,reqindex++){ + //int debug=0; + //if(z>=38) debug=1; + int p = procmap[chunkidx]; + //if(debug) printf("\tp=%u\n",p); + int *idims = gdims + chunkindex[chunkidx]; + //if(debug) printf("\tdims(%u)=%u:%u:%u\n",chunkidx,idims[0],idims[1],idims[2]); + int *iorigin = gorigins + chunkindex[chunkidx]; + //if(debug) printf("\torigin=%u:%u:%u\n",iorigin[0],iorigin[1],iorigin[2]); + MPI_Status stat; + //printf("waitfor(%u):MPI_Request[%u]=%u\n",z,reqindex,req[reqindex]); + // slicewait.start(); + comm.wait(req[reqindex],stat); // frees request object too + //slicewait.stop(); + // get information out of the status object!!!! + //slicecollect.start(); + int offset = iorigin[1]*globaldims[0] + iorigin[0]; + for(int idx=0,pos=offset,j=0;j<idims[1];j++,pos+=(globaldims[0]-idims[0])) + for(int i=0;i<idims[0];i++) + destbuffer[pos++]=slicebuffer[idx++]; + // slicecollect.stop(); + } + sliceallwait.stop(); + // proclist array +} + +MPIO::MPIO(IObase *io,MPIcomm &c):file(io),comm(c),slicemap(0),procmap(0),chunkindex(0),chunkindexcount(0){ + for(int i=0;i<3;i++) globaldims[i]=localdims[i]=localorigin[i]=0; + int *rootarray; + myid = comm.rank(); + //printf("my rank = %u\n",myid); + if(!myid) rootarray = new int[comm.numProcs()]; + int rootscale = file?1:0; + //printf("rootscale[pid=%u] = %u\n",myid,rootscale); + comm.gather(0,1,&rootscale,rootarray); // gather to 0 + globaldims[0]=globaldims[1]=globaldims[2]=0; + if(!myid){ + for(int count=0,p=0;p<comm.numProcs();p++){ + //printf("rootarray[%u]=%u\n",p,rootarray[p]); + if(rootarray[p]) count++; + } + if(count<=0 || count>1){ + fprintf(stderr,"MPIO only accepts a single IO node right now.\n"); + fprintf(stderr,"You chose %u io nodes\n",count); + fprintf(stderr,"I will select only the first one\n"); + } + for(root=-1,p=0;p<comm.numProcs() && root<0;p++) + if(rootarray[p]) root=p; + delete rootarray; + } + comm.bcast(0,root); // now everyone knows root + //printf("broadcasting root node = %u\n",root); +} + +MPIO::~MPIO(){ + char buffer[32]; + sprintf(buffer,"PE(%3u)",comm.rank()); + if(myid==root) { + sprintf(buffer+7,"slicewait"); slicewait.print(buffer); + sprintf(buffer+7,"slicewrite"); slicewrite.print(buffer); + sprintf(buffer+7,"slicecollect"); slicecollect.print(buffer); + //sprintf(buffer+7,"sliceselect"); sliceselect.print(buffer); + sprintf(buffer+7,"sliceallwait"); sliceallwait.print(buffer); + } +} + +void MPIO::setLocalDims(int rank,int *origin, int *dims){ + if(rank!=3) perror("MPIO is only for 3D IO for now (sorry)!"); + if(isRoot()){ + gorigins = new int[3*comm.numProcs()]; + gdims = new int[3*comm.numProcs()]; + } + else gdims=gorigins=0; // null for everyone else + for(int i=0;i<3;i++) { + //printf("localdims-n-origin[%u] = %u:%u\n",i,dims[i],origin[i]); + localdims[i]=dims[i]; + localorigin[i]=origin[i]; + } + comm.gather(root,3,localdims,gdims); + comm.gather(root,3,localorigin,gorigins); + //if(isRoot()) for(i=0;i<3*comm.numProcs();i++) + //printf("gdims-n-origin[%u] = %u:%u\n",i,gdims[i],gorigins[i]); + globaldims[0]=globaldims[1]=globaldims[2]=0; + if(isRoot()) + for(int p=0,last=3*comm.numProcs();p<last;p++){ + //printf("globaldims[%u]: gdims[%u]+gorigins[%u] = %u\n", + // p%3,p,p,gdims[p],gorigins[p]); + if(globaldims[p%3] < gdims[p]+gorigins[p]) + globaldims[p%3] = gdims[p]+gorigins[p]; + } + // rebroadcast globaldims to all PE's + comm.bcast(root,3,globaldims); + // do we bcast the layout? + // Now organize the info hierarchially by Z (index into) + if(isRoot()){ + if(slicemap){ + delete slicemap; + delete procmap; + delete chunkindex; + delete chunkindexcount; + } + //puts("(1)"); + slicemap = new int[globaldims[2]+1]; + chunkindexcount = new int[globaldims[2]+1]; + { + for(int thickness=0,chunkcount=0,z=0;z<globaldims[2];z+=thickness){ + for(int p=0;p<comm.numProcs();p++){ + // count slices + int dindex = 3*p; + int *idims = gdims + dindex; + int *iorigin = gorigins + dindex; + int iz = z-iorigin[2]; + if(iz<0 || iz>=idims[2]) continue; + thickness = gdims[2]; // assumes uniform domain decomp in Z + chunkcount++; + } + } + // procmap = new int[chunkcount]; + //chunkindex = new int[chunkcount]; // points into iorigin/idims + // chunkindexcount = chunkcount; + } + procmap = new int[comm.numProcs()]; + chunkindex = new int [comm.numProcs()]; + // printf("(2): chunkindexcount = %u\n",chunkindexcount); + for(int thickness=0,mapcount=0,chunkcount=0,z=0;z<globaldims[2];z+=thickness){ + for(int p=0;p<comm.numProcs();p++){ + // count slices + int dindex = 3*p; + int *idims = gdims + dindex; + int *iorigin = gorigins + dindex; + int iz = z-iorigin[2]; + if(iz<0 || iz>=idims[2]) continue; + thickness = gdims[2]; + procmap[chunkcount] = p; + chunkindex[chunkcount] = dindex; // point into iorogin/idims + //printf("chunkcount=%u chunkindex=%u\n",chunkcount,chunkindex[chunkcount]); + chunkcount++; + } + } + //puts("(3)"); + for(chunkcount=0,z=0;z<globaldims[2];){ + int thickness=0; + for(int levelcount=0,p=0;p<comm.numProcs();p++){ + // count slices + int dindex = 3*p; + int *idims = gdims + dindex; + int *iorigin = gorigins + dindex; + int iz = z-iorigin[2]; + if(iz<0 || iz>=idims[2]) continue; + thickness = gdims[2]; + levelcount++; chunkcount++; + } + for(int pz=0;pz<thickness && z<=globaldims[2];pz++,z++){ + chunkindexcount[z]=levelcount; + slicemap[z] = chunkcount - levelcount; // selects chunkindex from slicenum + //printf("slicemap[%u]= %u\n",z,slicemap[z]); + } + } + //printf("last map count = %u\n",z); + //slicemap[z] = chunkcount; + } +} + +int MPIO::write(IObase::DataType type,int rank,int *dims,void *data){ + int recalc_layout=0; + if(rank!=3) recalc_layout=1; + for(int i=0;i<rank;i++) if(localdims[i]!=dims[i]) recalc_layout=1; + if(recalc_layout) setLocalDims(rank,localorigin,dims); + switch(type){ + case IObase::Float32: + write((float*)data); + break; + case IObase::Float64: + default: + break; + } + return 1; // for now, no error checking +} + +void MPIO::write(float *data){ + MPI_Request *sendreq = new MPI_Request[localdims[2]]; + typedef float *floatP; + int sliceindex=0; + floatP slicebuffer[3]; // double-buffer the slices as they arrive + typedef MPI_Request *MPI_RequestP; + MPI_RequestP recvreq[2]; + // the third one is a scratch buffer for reorganizing the slice + if(isRoot()){ + for(int i=0;i<3;i++) slicebuffer[i] = new float[globaldims[0]*globaldims[1]]; + for(i=0;i<2;i++) recvreq[i] = new MPI_Request[comm.numProcs()]; + } + if(isRoot()){ + // puts("request initial slice"); + requestSlice(0,slicebuffer[0],recvreq[0]); + } + // not good! Everyone is sending a slice here! + sendSlice(0,data,sendreq); + // hide latency behind the costly io->reserveChunk operation + if(isRoot()) file->reserveStream(IObase::Float32,3,globaldims); + int writecount=0; + for(int i=1;i<globaldims[2];i++){ + comm.barrier(); // clearly not! + if(isRoot()) + requestSlice(i,slicebuffer[i%2],recvreq[i%2]); + sendSlice(i,data,sendreq); // jms + if(isRoot()) { + waitForSlice(i-1,slicebuffer[(i-1)%2],slicebuffer[2],recvreq[(i-1)%2]); + writecount+=globaldims[0]*globaldims[1]; + slicewrite.start(); + file->writeStream(slicebuffer[2],globaldims[0]*globaldims[1]); + slicewrite.stop(); + } + } + if(isRoot()){ + waitForSlice(i-1,slicebuffer[(i-1)%2],slicebuffer[2],recvreq[(i-1)%2]); + file->writeStream(slicebuffer[2],globaldims[0]*globaldims[1]); + writecount+=globaldims[0]*globaldims[1]; + // do requestfree (double-buffered request free-s) + for(i=0;i<3;i++) delete slicebuffer[i]; // free everything up + for(i=0;i<2;i++) delete recvreq[i]; + } + MPI_Status stat; + for(i=0;i<localdims[2];i++) comm.wait(sendreq[i],stat); + delete sendreq; +} |