From edcc07ec3cbeded2f09e300784bd900daf8b9032 Mon Sep 17 00:00:00 2001 From: rzinke Date: Tue, 22 Mar 2022 17:16:38 -0700 Subject: [PATCH 1/4] Start changes --- environment.yml | 2 +- tools/ARIAtools.egg-info/PKG-INFO | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/environment.yml b/environment.yml index 5fab4fcb..0fd7e635 100644 --- a/environment.yml +++ b/environment.yml @@ -3,7 +3,7 @@ # remove environment : conda env remove -n ARIA-tools # enter environment : conda activate ARIA-tools # exit environment : conda deactivate -name: ARIA-tools +name: ariax channels: - conda-forge - defaults diff --git a/tools/ARIAtools.egg-info/PKG-INFO b/tools/ARIAtools.egg-info/PKG-INFO index 30ba7ee9..a8fa4693 100644 --- a/tools/ARIAtools.egg-info/PKG-INFO +++ b/tools/ARIAtools.egg-info/PKG-INFO @@ -2,7 +2,6 @@ Metadata-Version: 2.1 Name: ARIAtools Version: 1.1.0 Summary: This is the ARIA tools package without RelaxIV support -Home-page: UNKNOWN License: UNKNOWN Platform: UNKNOWN License-File: LICENSE From f5c3492a3edac10e4907408c32151f20a8723551 Mon Sep 17 00:00:00 2001 From: rzinke Date: Tue, 22 Mar 2022 17:22:48 -0700 Subject: [PATCH 2/4] Replaced gdalWarp functions in UnwrapOverlap with gdalTranslate and mask layer --- tools/ARIAtools/unwrapStitching.py | 626 +++++++++++++++++------------ 1 file changed, 361 insertions(+), 265 deletions(-) diff --git a/tools/ARIAtools/unwrapStitching.py b/tools/ARIAtools/unwrapStitching.py index d6006899..b1e704e3 100755 --- a/tools/ARIAtools/unwrapStitching.py +++ b/tools/ARIAtools/unwrapStitching.py @@ -35,127 +35,129 @@ solverTypes = ['pulp', 'glpk', 'gurobi'] redarcsTypes = {'MCF':-1, 'REDARC0':0, 'REDARC1':1, 'REDARC2':2} -stitchMethodTypes = ['overlap','2stage'] +stitchMethodTypes = ['overlap', '2stage'] class Stitching: - ''' - This is the parent class of all stiching codes. + """This is the parent class of all stiching codes. It is whatever is shared between the different variants of stitching methods e.g. - setting of the main input arguments, - functions to verify GDAL compatibility, - function to write new connected component of merged product based on a mapping table, - fucntion to write unwrapped phase of merged product based on a mapping table - ''' + """ def __init__(self): - ''' - Setting the default arguments needed by the class. + """Setting the default arguments needed by the class. Parse the filenames and bbox as None as they need to be set by the user, - which will be caught when running the child classes of the respective stitch method - ''' + which will be caught when running the child classes of the + respective stitch method. + """ self.inpFile = None self.ccFile = None self.prodbboxFile = None self.prodbbox = None - self.solver ='pulp' - self.redArcs =-1 + self.solver = 'pulp' + self.redArcs = -1 self.mask = None self.outFileUnw = './unwMerged' self.outFileConnComp = './connCompMerged' - self.outputFormat='ENVI' + self.outputFormat = 'ENVI' # stitching methods, by default leverage product overlap method # other options would be to leverage connected component self.setStitchMethod("overlap") - def setInpFile(self, input): - """ Set the input Filename for stitching/unwrapping """ + """Set the input Filename for stitching/unwrapping.""" # Convert a string (i.e. 
user gave single file) to a list if isinstance(input, np.str): input = [input] self.inpFile = input - # the number of files that needs to be merged/ unwrapped + # The number of files that needs to be merged/ unwrapped self.nfiles = np.shape(self.inpFile)[0] def setOutFile(self, output): - """ Set the output File name """ + """Set the output File name.""" self.outFile = output def setConnCompFile(self, connCompFile): - """ Set the connected Component file """ + """Set the connected Component file.""" # Convert a string (i.e. user gave single file) to a list if isinstance(connCompFile, np.str): connCompFile = [connCompFile] self.ccFile = connCompFile def setProdBBoxFile(self, ProdBBoxFile): - """ Set the product bounding box file(s) """ + """Set the product bounding box file(s).""" # Convert a string (i.e. user gave single file) to a list if isinstance(ProdBBoxFile, np.str): ProdBBoxFile = [ProdBBoxFile] self.prodbboxFile = ProdBBoxFile def setBBoxFile(self, BBoxFile): - """ Set bounds of bbox """ + """Set bounds of bbox.""" self.bbox_file = BBoxFile def setTotProdBBoxFile(self, prods_TOTbbox): - """ Set common track bbox file""" + """Set common track bbox file.""" self.setTotProdBBoxFile = prods_TOTbbox def setStitchMethod(self,stitchMethodType): - """ Set the stitch method to be used to handle parant class internals """ + """Set the stitch method to be used to handle parent class internals.""" if stitchMethodType not in stitchMethodTypes: raise ValueError(stitchMethodType + ' must be in ' + str(stitchMethodTypes)) else: - self.stitchMethodType =stitchMethodType + self.stitchMethodType = stitchMethodType - if self.stitchMethodType=='2stage': + if self.stitchMethodType == '2stage': self.description="Two-stage corrected/stiched Unwrapped Phase" - elif self.stitchMethodType=='overlap': + elif self.stitchMethodType == 'overlap': self.description = "Overlap-based stiched Unwrapped Phase" def setRedArcs(self, redArcs): - """ Set the Redundant Arcs to use for LP unwrapping """ + """Set the Redundant Arcs to use for LP unwrapping.""" self.redArcs = redArcs def setSolver(self, solver): - """ Set the solver to use for unwrapping """ + """Set the solver to use for unwrapping.""" self.solver = solver def setMask(self, mask): - """ Set the mask file """ + """Set the mask file.""" self.mask = mask def setOutputFormat(self,outputFormat): - """ Set the output format of the files to be generated """ - # File must be physically extracted, cannot proceed with VRT format. Defaulting to ENVI format. + """Set the output format of the files to be generated.""" + # File must be physically extracted, cannot proceed with VRT format. + # Defaulting to ENVI format. self.outputFormat = outputFormat if self.outputFormat=='VRT': self.outputFormat='ENVI' def setOutFileUnw(self,outFileUnw): - """ Set the output file name for the unwrapped stiched file to be generated""" + """Set the output file name for the unwrapped stitched file to + be generated. + """ self.outFileUnw = outFileUnw def setOutFileConnComp(self,outFileConnComp): - """ Set the output file name for the connected component stiched file to be generated""" + """Set the output file name for the connected component stitched + file to be generated. + """ self.outFileConnComp = outFileConnComp def setVerboseMode(self,verbose): - """ Set verbose output mode""" + """Set verbose output mode.""" logger.setLevel(logging.DEBUG) def __verifyInputs__(self): - ''' - Verify if the unwrapped and connected component inputs are gdal compatible. 
- That the provided shape files are well-formed. - If not remove them from the list to be stiched. - If a vrt exist and gdalcompatible update the file to be a vrt - ''' - # track a list of files to keep + """Verify if the unwrapped and connected component inputs are + GDAL-compatible and that the provided shape files are well-formed. + If not remove them from the list to be stitched. + If a VRT exists and is GDAL-compatible update the file to be a VRT. + """ + # Track a list of files to keep inpFile_keep = [] ccFile_keep = [] prodbboxFile_keep = [] @@ -164,7 +166,7 @@ def __verifyInputs__(self): # unw and corresponding conncomponent file inFile = self.inpFile[k_file] ccFile = self.ccFile[k_file] - # shape file in case passed through + # Shape file in case passed through if self.stitchMethodType == "overlap": prodbboxFile = self.prodbboxFile[k_file] @@ -197,46 +199,52 @@ def __verifyInputs__(self): bbox_keep.append(bbox_temp) prodbboxFile_keep.append(prodbboxFile_temp) - # update the input files and only keep those being GDAL compatible + # Update the input files and only keep those that are GDAL-compatible self.inpFile=inpFile_keep self.ccFile=ccFile_keep - # shape file in case passed through + + # Shape file in case passed through if self.stitchMethodType == "overlap": self.prodbbox=bbox_keep self.prodbboxFile=prodbboxFile_keep - # update the number of file in case some got removed + # Update the number of files in case some got removed self.nfiles = np.shape(self.inpFile)[0] - if self.nfiles==0: + if self.nfiles == 0: log.info('No files left after GDAL compatibility check') sys.exit(0) + def __createImages__(self): - ''' - This function will write the final merged unw and conencted component file. As intermediate step tiff files are generated with integer values which will represent the shift to be applied to connected componenet and the moduli shift to be applied to the unwrapped phase. - ''' + """This function will write the final merged unw and conencted + component file. As intermediate step tiff files are generated + with integer values which will represent the shift to be applied + to connected componenet and the moduli shift to be applied to + the unwrapped phase. + """ ## Will first make intermediate files in a temp folder. # For each product there will be a connComp file and 3 files related unw files. # The connected component file will show the unique wrt to all merged files # For the unwrapped related files, there will be an integer offset tif file, a vrt file which scale this integer map by 2pi, and a vrt which combines the orginal unw phase file with the scaled map. The latter will be used for merging of the unwrapped phase. 
- tempdir = tempfile.mkdtemp(prefix='IntermediateFiles_',dir='.') + tempdir = tempfile.mkdtemp(prefix='IntermediateFiles_', dir='.') - # will try multi-core version and default to for loop in case of failure + # Will try multi-core version and default to for loop in case of failure try: - # need to combine all inputs together as single argument tuple + # Need to combine all inputs together as single argument tuple all_inputs = () for counter in range(len(self.fileMappingDict)): fileMappingDict = self.fileMappingDict[counter] fileMappingDict['saveDir'] = tempdir fileMappingDict['saveNameID'] = "Product_" + str(counter) fileMappingDict['description'] = self.description - # parse inputs as a tuple + # Parse inputs as a tuple inputs = (fileMappingDict) - # append all tuples in a single tuple + # Append all tuples in a single tuple all_inputs = all_inputs + (inputs,) - # compute the phase value using multi-thread functionality - intermediateFiles = Parallel(n_jobs=-1,max_nbytes=1e6)(delayed(createConnComp_Int)(ii) for ii in all_inputs) + # Compute the phase value using multi-thread functionality + intermediateFiles = Parallel(n_jobs=-1, max_nbytes=1e6)\ + (delayed(createConnComp_Int)(ii) for ii in all_inputs) except: log.info('Multi-core version failed, will try single for loop') @@ -246,20 +254,20 @@ def __createImages__(self): fileMappingDict['saveDir'] = tempdir fileMappingDict['saveNameID'] = "Product_n" + str(counter) fileMappingDict['description'] = self.description - # parse inputs as a tuple + # Parse inputs as a tuple inputs = (fileMappingDict) - # compute the phase value + # Compute the phase value intermediateFiles_temp = createConnComp_Int(inputs) intermediateFiles.append(intermediateFiles_temp) - # combining all conComp and unw files that need to be blended - conCompFiles = [] + # Combine all connComp and unw files that need to be blended + connCompFiles = [] unwFiles = [] for intermediateFile in intermediateFiles: - conCompFiles.append(intermediateFile[0]) + connCompFiles.append(intermediateFile[0]) unwFiles.append(intermediateFile[1]) - # check if the folder exist to which files are being generated. + # Check if the folder to which files are being generated exists outPathUnw = os.path.dirname(os.path.abspath(self.outFileUnw)) outPathConnComp = os.path.dirname(os.path.abspath(self.outFileConnComp)) if not os.path.isdir(outPathUnw): @@ -268,23 +276,31 @@ def __createImages__(self): os.makedirs(outPathConnComp) ## Will now merge the unwrapped and connected component files - # remove existing output file(s) + # Remove existing unwrapped output file(s) for file in glob.glob(self.outFileUnw + "*"): os.remove(file) - gdal.BuildVRT(self.outFileUnw+'.vrt', unwFiles, options=gdal.BuildVRTOptions(srcNodata=0)) - gdal.Warp(self.outFileUnw, self.outFileUnw+'.vrt', options=gdal.WarpOptions(format=self.outputFormat, cutlineDSName=self.setTotProdBBoxFile, outputBounds=self.bbox_file)) + + # Create new output file(s) + vrtName = '{:s}.vrt'.format(self.outFileUnw) + gdal.BuildVRT(vrtName, unwFiles, + options=gdal.BuildVRTOptions(srcNodata=0)) + gdal.Warp(self.outFileUnw, vrtName, + options=gdal.WarpOptions(format=self.outputFormat, + cutlineDSName=self.setTotProdBBoxFile, + outputBounds=self.bbox_file)) # Update VRT - gdal.Translate(self.outFileUnw+'.vrt', self.outFileUnw, options=gdal.TranslateOptions(format="VRT")) - # Apply mask (if specified). 
+ gdal.Translate(vrtName, self.outFileUnw, + options=gdal.TranslateOptions(format="VRT")) + # Apply mask (if specified) if self.mask is not None: - update_file=gdal.Open(self.outFileUnw,gdal.GA_Update) - update_file=update_file.GetRasterBand(1).WriteArray(self.mask.ReadAsArray()*gdal.Open(self.outFileUnw+'.vrt').ReadAsArray()) - update_file=None + update_file = gdal.Open(self.outFileUnw,gdal.GA_Update) + update_file = update_file.GetRasterBand(1).WriteArray(self.mask.ReadAsArray()*gdal.Open(self.outFileUnw+'.vrt').ReadAsArray()) + update_file = None - # remove existing output file(s) + # Remove existing connected component output file(s) for file in glob.glob(self.outFileConnComp + "*"): os.remove(file) - gdal.BuildVRT(self.outFileConnComp+'.vrt', conCompFiles, options=gdal.BuildVRTOptions(srcNodata=-1)) + gdal.BuildVRT(self.outFileConnComp+'.vrt', connCompFiles, options=gdal.BuildVRTOptions(srcNodata=-1)) gdal.Warp(self.outFileConnComp, self.outFileConnComp+'.vrt', options=gdal.WarpOptions(format=self.outputFormat, cutlineDSName=self.setTotProdBBoxFile, outputBounds=self.bbox_file)) # Update VRT gdal.Translate(self.outFileConnComp+'.vrt', self.outFileConnComp, options=gdal.TranslateOptions(format="VRT")) @@ -305,23 +321,21 @@ def __createImages__(self): class UnwrapOverlap(Stitching): - ''' - Stiching/unwrapping using product overlap minimization - ''' + """Stiching/unwrapping using product overlap minimization.""" def __init__(self): - ''' - Inheret properties from the parent class - Parse the filenames and bbox as None as they need to be set by the user, which will be caught when running the class - ''' + """Inherit properties from the parent class. + + Parse the filenames and bbox as None as they need to be set by + the user, which will be caught when running the class. + """ Stitching.__init__(self) def UnwrapOverlap(self): - - ## setting the method + ## Set the method self.setStitchMethod("overlap") - ## check if required inputs are set + ## Check if required inputs are set if self.inpFile is None: log.error("Input unwrapped file(s) is (are) not set.") raise Exception @@ -333,161 +347,231 @@ def UnwrapOverlap(self): raise Exception ## Verify if all the inputs are well-formed/GDAL compatible - # Update files to be vrt if they exist and remove files which failed the gdal compatibility + # Update files to be vrt if they exist and remove files which + # failed the gdal compatibility self.__verifyInputs__() - ## Calculating the number of phase cycles needed to miminize the residual between products + ## Calculating the number of phase cycles needed to miminize + # the residual between products self.__calculateCyclesOverlap__() ## Write out merged phase and connected component files self.__createImages__() - return def __calculateCyclesOverlap__(self): - '''Function that will calculate the number of cycles each component needs to be shifted in order to minimize the two-pi modulu residual between a neighboring component. Outputs a fileMappingDict with as key a file number. Within fileMappingDict with a integer phase shift value for each unique connected component. 
- ''' - - # only need to comptue the minimize the phase offset if the number of files is larger than 2 - if self.nfiles>1: - - # initiate the residuals and design matrix - residualcycles = np.zeros((self.nfiles-1,1)) - residualrange = np.zeros((self.nfiles-1,1)) - A = np.zeros((self.nfiles-1,self.nfiles)) - - # the files are already sorted in the ARIAproduct class, will make consecutive overlaps between these sorted products + """Function that will calculate the number of cycles each + component needs to be shifted in order to minimize the two-pi + modulo residual between a neighboring component. + + Outputs a fileMappingDict with a file number as a key. Within + fileMappingDict with a integer phase shift value for each unique + connected component. + """ + + # Only need to comptue the minimize the phase offset if the number + # of files is larger than 2 + if self.nfiles > 1: + # Initiate the residuals and design matrix + residualcycles = np.zeros((self.nfiles-1, 1)) + residualrange = np.zeros((self.nfiles-1, 1)) + A = np.zeros((self.nfiles-1, self.nfiles)) + + # The files are already sorted in the ARIAproduct class, will + # make consecutive overlaps between these sorted products for counter in range(self.nfiles-1): - # getting the two neighboring frames + # Getting the two neighboring frames bbox_frame1 = self.prodbbox[counter] bbox_frame2 = self.prodbbox[counter+1] - # determining the intersection between the two frames + # Determine the intersection between the two frames if not bbox_frame1.intersects(bbox_frame2): - log.error("Products do not overlap or were not provided in a contigious sorted list.") + log.error("Products do not overlap or were not " \ + "provided in a contigious sorted list.") raise Exception polyOverlap = bbox_frame1.intersection(bbox_frame2) - # will save the geojson under a temp local filename - tmfile = tempfile.NamedTemporaryFile(mode='w+b',suffix='.json', prefix='Overlap_', dir='.') + # Will save the geojson under a temp local filename + # Do this just to get the file outname + tmfile = tempfile.NamedTemporaryFile(mode='w+b', suffix='.json', + prefix='Overlap_', dir='.') outname = tmfile.name - # will remove it as GDAL polygonize function cannot overwrite files + + # Remove it as GDAL polygonize function cannot overwrite files tmfile.close() tmfile = None - # saving the temp geojson - save_shapefile(outname, polyOverlap, 'GeoJSON') - - # calculate the mean of the phase for each product in the overlap region alone - # will first attempt to mask out connected component 0, and default to complete overlap if this fails. - # Cropping the unwrapped phase and connected component to the overlap region alone, inhereting the no-data. 
- # connected component - out_data,connCompNoData1,geoTrans,proj = GDALread(self.ccFile[counter],data_band=1,loadData=False) - out_data,connCompNoData2,geoTrans,proj = GDALread(self.ccFile[counter+1],data_band=1,loadData=False) - connCompFile1 = gdal.Warp('', self.ccFile[counter], options=gdal.WarpOptions(format="MEM", cutlineDSName=outname, outputBounds=polyOverlap.bounds, dstNodata=connCompNoData1)) - connCompFile2 = gdal.Warp('', self.ccFile[counter+1], options=gdal.WarpOptions(format="MEM", cutlineDSName=outname, outputBounds=polyOverlap.bounds, dstNodata=connCompNoData2)) - - - # unwrapped phase - out_data,unwNoData1,geoTrans,proj = GDALread(self.inpFile[counter],data_band=1,loadData=False) - out_data,unwNoData2,geoTrans,proj = GDALread(self.inpFile[counter+1],data_band=1,loadData=False) - unwFile1 = gdal.Warp('', self.inpFile[counter], options=gdal.WarpOptions(format="MEM", cutlineDSName=outname, outputBounds=polyOverlap.bounds, dstNodata=unwNoData1)) - unwFile2 = gdal.Warp('', self.inpFile[counter+1], options=gdal.WarpOptions(format="MEM", cutlineDSName=outname, outputBounds=polyOverlap.bounds, dstNodata=unwNoData2)) + # Save the temp geojson + save_shapefile(outname, polyOverlap, 'GeoJSON') - # finding the component with the largest overlap - connCompData1 =connCompFile1.GetRasterBand(1).ReadAsArray() - connCompData1[(connCompData1==connCompNoData1) | (connCompData1==0)]=np.nan - connCompData2 =connCompFile2.GetRasterBand(1).ReadAsArray() - connCompData2[(connCompData2==connCompNoData2) | (connCompData2==0)]=np.nan + # Calculate the mean of the phase for each product in + # the overlap region alone. + # Will first attempt to mask out connected component 0, + # and default to complete overlap if this fails. + # Cropping the unwrapped phase and connected component + # to the overlap region alone, inhereting the no-data. 
+
+                # Connected component
+                out_data, connCompNoData1, geoTrans, proj = GDALread(
+                    self.ccFile[counter], data_band=1, loadData=False)
+                out_data, connCompNoData2, geoTrans, proj = GDALread(
+                    self.ccFile[counter+1], data_band=1, loadData=False)
+
+                connCompFile1 = gdal.Warp('', self.ccFile[counter],
+                        options=gdal.WarpOptions(format="MEM",
+                        cutlineDSName=outname,
+                        dstAlpha=True,
+                        outputBounds=polyOverlap.bounds,
+                        dstNodata=connCompNoData1))
+                connCompFile2 = gdal.Warp('', self.ccFile[counter+1],
+                        options=gdal.WarpOptions(format="MEM",
+                        cutlineDSName=outname,
+                        dstAlpha=True,
+                        outputBounds=polyOverlap.bounds,
+                        dstNodata=connCompNoData2))
+
+                # Reformat output bounds for GDAL Translate
+                ulx, lry, lrx, uly = polyOverlap.bounds
+                projWin = (ulx, uly, lrx, lry)
+
+                # Unwrapped phase
+                out_data, unwNoData1, geoTrans, proj = GDALread(
+                    self.inpFile[counter], data_band=1, loadData=False)
+                out_data, unwNoData2, geoTrans, proj = GDALread(
+                    self.inpFile[counter+1], data_band=1, loadData=False)
+
+                unwFile1 = gdal.Translate('', self.inpFile[counter],
+                        options=gdal.TranslateOptions(format="MEM",
+                        projWin=projWin,
+                        noData=unwNoData1))
+                unwFile2 = gdal.Translate('', self.inpFile[counter+1],
+                        options=gdal.TranslateOptions(format="MEM",
+                        projWin=projWin,
+                        noData=unwNoData2))
+
+                # Find the component with the largest overlap
+                connCompData1 = connCompFile1.GetRasterBand(1).ReadAsArray()
+                connCompData1[((connCompData1==connCompNoData1)
+                    | (connCompData1==0))] = np.nan
+                connCompData2 = connCompFile2.GetRasterBand(1).ReadAsArray()
+                connCompData2[((connCompData2==connCompNoData2)
+                    | (connCompData2==0))] = np.nan
 
                 connCompData2_temp = (connCompData2*100)
-                temp = connCompData2_temp.astype(np.int)-connCompData1.astype(np.int)
-                temp[(temp<0) | (temp>2000)]=0
-                temp_count = collections.Counter(temp.flatten())
+                connCompDiff = (connCompData2_temp.astype(np.int)
+                    - connCompData1.astype(np.int))
+                connCompDiff[(connCompDiff<0) | (connCompDiff>2000)] = 0
+                temp_count = collections.Counter(connCompDiff.flatten())
 
                 maxKey = 0
                 maxCount = 0
                 for key, keyCount in temp_count.items():
-                    if key!=0:
-                        if keyCount>maxCount:
-                            maxKey =key
-                            maxCount=keyCount
-
-                # if the max key count is 0, this means there is no good overlap region between products.
+                    if key != 0:
+                        if keyCount > maxCount:
+                            maxKey = key
+                            maxCount = keyCount
+                log.debug('Max key, count: %s, %s', maxKey, maxCount)
+
+                # If the max key count is 0, this means there is no good
+                # overlap region between products.
                 # In that scenario default to different stitching approach.
-                if maxKey!=0 and maxCount>75:
-                    # masking the unwrapped phase and only use the largest overlapping connected component
+                if maxKey != 0 and maxCount > 75:
+                    # Mask the unwrapped phase and use only the largest
+                    # overlapping connected component
                     unwData1 = unwFile1.GetRasterBand(1).ReadAsArray()
-                    unwData1[(unwData1==unwNoData1) | (temp!=maxKey)]=np.nan
-                    unwData2 =unwFile2.GetRasterBand(1).ReadAsArray()
-                    unwData2[(unwData2==unwNoData2) | (temp!=maxKey)]=np.nan
+                    unwData2 = unwFile2.GetRasterBand(1).ReadAsArray()
+
+                    cutlineMask1 = connCompFile1.GetRasterBand(2).ReadAsArray()
+                    cutlineMask2 = connCompFile2.GetRasterBand(2).ReadAsArray()
+
+                    unwData1[((unwData1==unwNoData1)
+                        | (connCompDiff!=maxKey)
+                        | (cutlineMask1==0))] = np.nan
+                    unwData2[((unwData2==unwNoData2)
+                        | (connCompDiff!=maxKey)
+                        | (cutlineMask2==0))] = np.nan
 
                     # Calculation of the range correction
-                    unwData1_wrapped = unwData1-np.round(unwData1/(2*np.pi))*(2*np.pi)
-                    unwData2_wrapped =unwData2-np.round(unwData2/(2*np.pi))*(2*np.pi)
-                    arr =unwData1_wrapped-unwData2_wrapped
+                    unwData1_wrapped = (unwData1
+                        - np.round(unwData1/(2*np.pi))*(2*np.pi))
+                    unwData2_wrapped = (unwData2
+                        - np.round(unwData2/(2*np.pi))*(2*np.pi))
+                    arr = unwData1_wrapped - unwData2_wrapped
 
-                    # data is not fully decorrelated
+                    # Data is not fully decorrelated
                     arr = arr - np.round(arr/(2*np.pi))*2*np.pi
-                    range_temp = np.angle(np.nanmean(np.exp(1j*arr)))
+                    range_temp = np.angle(np.nanmean(np.exp(1j*arr)))
 
-                    # calculation of the number of 2 pi cycles accounting for range correction
-                    cycles_temp = np.round((np.nanmean(unwData1-(unwData2+range_temp)))/(2*np.pi))
+                    # Calculation of the number of 2PI cycles accounting
+                    # for range correction
+                    corrected_range = unwData1 - (unwData2+range_temp)
+                    cycles_temp = np.round((np.nanmean(corrected_range))/(2*np.pi))
+                    log.debug('Number of 2PI cycles: %s', cycles_temp)
 
                 else:
-                    # account for the case that no-data was left, e.g. fully decorrelated
-                    # in that scenario use all data and estimate from wrapped, histogram will be broader...
+                    # Account for the case that no data was left, e.g.
+                    # fully decorrelated
+                    # In that scenario use all data and estimate from
+                    # wrapped, histogram will be broader ...
                    unwData1 = unwFile1.GetRasterBand(1).ReadAsArray()
-                    unwData1[(unwData1==unwNoData1)]
-                    unwData2 =unwFile2.GetRasterBand(1).ReadAsArray()
-                    unwData2[(unwData2==unwNoData2)]
+                    unwData2 = unwFile2.GetRasterBand(1).ReadAsArray()
+
+                    cutlineMask1 = connCompFile1.GetRasterBand(2).ReadAsArray()
+                    cutlineMask2 = connCompFile2.GetRasterBand(2).ReadAsArray()
+
+                    unwData1[((unwData1==unwNoData1)
+                        | (cutlineMask1==0))] = np.nan
+                    unwData2[((unwData2==unwNoData2)
+                        | (cutlineMask2==0))] = np.nan
+
                     # Calculation of the range correction
-                    unwData1_wrapped = unwData1-np.round(unwData1/(2*np.pi))*(2*np.pi)
-                    unwData2_wrapped =unwData2-np.round(unwData2/(2*np.pi))*(2*np.pi)
-                    arr =unwData1_wrapped-unwData2_wrapped
+                    unwData1_wrapped = (unwData1
+                        - np.round(unwData1/(2*np.pi))*(2*np.pi))
+                    unwData2_wrapped = (unwData2
+                        - np.round(unwData2/(2*np.pi))*(2*np.pi))
+                    arr = unwData1_wrapped - unwData2_wrapped
 
                     arr = arr - np.round(arr/(2*np.pi))*2*np.pi
-                    range_temp = np.angle(np.nanmean(np.exp(1j*arr)))
+                    range_temp = np.angle(np.nanmean(np.exp(1j*arr)))
 
-                    # data is decorelated assume no 2-pi cycles
+                    # Data is decorrelated; assume no 2-pi cycles
                     cycles_temp = 0
 
-                # closing the files
+                # Closing the files
                 unwFile1 = None
                 unwFile2 = None
                 connCompFile1 = None
                 connCompFile2 = None
 
-                # remove the tempfile
+                # Remove the tempfile
                 shutil.os.remove(outname)
 
-                # store the residual and populate the design matrix
-                residualcycles[counter]=cycles_temp
-                residualrange[counter]=range_temp
-                A[counter,counter]=1
-                A[counter,counter+1]=-1
-
-
-            # invert the offsets with respect to the first product
-            cycles = np.round(np.linalg.lstsq(A[:,1:], residualcycles,rcond=None)[0])
-            rangesoffset = np.linalg.lstsq(A[:,1:], residualrange,rcond=None)[0]
+                # Store the residual and populate the design matrix
+                residualcycles[counter] = cycles_temp
+                residualrange[counter] = range_temp
+                A[counter,counter] = 1
+                A[counter,counter+1] = -1
+
+            # Invert the offsets with respect to the first product
+            cycles = np.round(np.linalg.lstsq(A[:,1:],
+                residualcycles,rcond=None)[0])
+            rangesoffset = np.linalg.lstsq(A[:,1:],
+                residualrange,rcond=None)[0]
 
             #pdb.set_trace()
-            # force first product to have 0 as offset
+            # Force first product to have 0 as offset
             cycles = -1*np.concatenate((np.zeros((1,1)), cycles), axis=0)
             rangesoffset = -1*np.concatenate((np.zeros((1,1)), rangesoffset), axis=0)
 
         else:
-            # nothing to be done, i.e. no phase cycles to be added
+            # Nothing to be done, i.e. no phase cycles to be added
             cycles = np.zeros((1,1))
             rangesoffset = np.zeros((1,1))
 
-        # build the mapping dictionary
+        # Build the mapping dictionary
         fileMappingDict = {}
-        connCompOffset =0
+        connCompOffset = 0
+
         for fileCounter in range(self.nfiles):
-
-            # get the number of connected components
+            # Get the number of connected components
             n_comp = 20
 
             # The original connected components
@@ -501,42 +585,43 @@ def __calculateCyclesOverlap__(self):
             # Generate the mapping of connectedComponents
             # Increment based on unique components of the merged product, 0 is grouped for all products
             connCompMapping = connComp
-            connCompMapping[1:]=connComp[1:]+connCompOffset
+            connCompMapping[1:] = connComp[1:]+connCompOffset
 
             # concatenate and generate a connComponent based mapping matrix
             # [original comp, new component, 2pi unw offset]
-            connCompMapping = np.concatenate((connComp, connCompMapping,cycleMapping,rangeOffsetMapping), axis=1)
+            connCompMapping = np.concatenate(
+                (connComp, connCompMapping, cycleMapping, rangeOffsetMapping),
+                axis=1)
 
             # Increment the count of total number of unique components
-            connCompOffset = connCompOffset+n_comp
+            connCompOffset = connCompOffset + n_comp
 
-            # populate mapping dictionary for each product
+            # Populate mapping dictionary for each product
             fileMappingDict_temp = {}
             fileMappingDict_temp['connCompMapping'] = connCompMapping
-            fileMappingDict_temp['connFile'] =  self.ccFile[fileCounter]
+            fileMappingDict_temp['connFile'] = self.ccFile[fileCounter]
             fileMappingDict_temp['unwFile'] = self.inpFile[fileCounter]
 
-            # store it in the general mapping dictionary
-            fileMappingDict[fileCounter]=fileMappingDict_temp
-            # pass the fileMapping back into self
+            # Store it in the general mapping dictionary
+            fileMappingDict[fileCounter] = fileMappingDict_temp
+
+        # Pass the fileMapping back into self
         self.fileMappingDict = fileMappingDict
 
+
 class UnwrapComponents(Stitching):
-    '''
-    Stiching/unwrapping using 2-Stage Phase Unwrapping
-    '''
+    """Stitching/unwrapping using 2-Stage Phase Unwrapping."""
 
     def __init__(self):
-        '''
-        Inheret properties from the parent class
-        Parse the filenames and bbox as None as they need to be set by the user, which will be caught when running the class
-        '''
+        """Inherit properties from the parent class.
+        Parse the filenames and bbox as None as they need to be set by
+        the user, which will be caught when running the class.
+        """
         Stitching.__init__(self)
 
     def unwrapComponents(self):
-
-        ## setting the method
+        ## Setting the method
         self.setStitchMethod("2stage")
         self.region=5
 
@@ -579,7 +664,6 @@ def unwrapComponents(self):
 
         ## Write out merged phase and connected component files
         self.__createImages__()
-
         return
 
     def __populatePolyTable__(self):
 
@@ -1112,13 +1196,16 @@ def GDALread(filename,data_band=1,loadData=True):
 
 def createConnComp_Int(inputs):
-    '''
-    Function to generate intermediate connected component files and unwrapped VRT files that have with interger 2pi pixel shift applied.
-    Will parse inputs in a single argument as it allows for parallel processing.
-    Return a list of files in a unqiue temp folder
-    '''
+    """Function to generate intermediate connected component files and
+    unwrapped VRT files that have an integer 2pi pixel shift applied.
+
+    Will parse inputs in a single argument as it allows for parallel
+    processing.
+
+    Return a list of files in a unique temp folder.
+ """ + + # Parse the inputs to variables saveDir = inputs['saveDir'] saveNameID = inputs['saveNameID'] connFile = inputs['connFile'] @@ -1127,51 +1214,62 @@ def createConnComp_Int(inputs): ## Generating the intermediate files ## STEP 1: set-up the mapping functions - # loading the connected component - connData,connNoData,connGeoTrans,connProj = GDALread(connFile) + # Load the connected component + connData, connNoData, connGeoTrans, connProj = GDALread(connFile) - ## Defining the mapping tables - # setting up the connected component unique ID mapping + ## Define the mapping tables + # Set up the connected component unique ID mapping connIDMapping = connCompMapping[:,1] - # Will add the no-data to the mapping as well (such we can handle no-data region) - # i.e. num comp + 1 = Nodata connected component which gets mapped to no-data value again + + # Add the no-data to the mapping as well (such that we can handle a + # no-data region) + # I.e. num comp + 1 = Nodata connected component which gets mapped + # to no-data value again NoDataMapping = len(connIDMapping) connIDMapping = np.append(connIDMapping, [connNoData]) - # setting up the connected component integer 2PI shift mapping + # Set up the connected component integer 2PI shift mapping intMapping = connCompMapping[:,2] - # Will add the no-data to the mapping as well (such we can handle no-data region) - # i.e. max comp ID + 1 = Nodata connected component which gets mapped to 0 integer shift such no-data region remains unaffected + + # Add the no-data to the mapping as well (such that we can handle a + # no-data region) + # I.e. max comp ID + 1 = Nodata connected component which gets + # mapped to 0 integer shift such that the no-data region remains + # unaffected intMapping = np.append(intMapping, [0]) - # update the connected component with the new no-data value used in the mapping - connData[connData==connNoData]=NoDataMapping + # Update the connected component with the new no-data value used + # in the mapping + connData[connData==connNoData] = NoDataMapping ## STEP 2: apply the mapping functions - # interger 2PI scaling mapping for unw phase - + # Interger 2PI scaling mapping for unw phase intShift = intMapping[connData.astype('int')] - # connected component mapping to unique ID - connData = connIDMapping[connData.astype('int')] - ## STEP 3: writing out the datasets - # writing out the unqiue ID connected component file - connDataName = os.path.abspath(os.path.join(saveDir,'connComp', saveNameID + '_connComp.tif')) - write_ambiguity(connData,connDataName,connProj,connGeoTrans,connNoData) + # Connected component mapping to unique ID + connData = connIDMapping[connData.astype('int')] - # writing out the integer map as tiff file - intShiftName = os.path.abspath(os.path.join(saveDir,'unw',saveNameID + '_intShift.tif')) - write_ambiguity(intShift,intShiftName,connProj,connGeoTrans) + ## STEP 3: write out the datasets + # Write out the unqiue ID connected component file + connDataName = os.path.abspath(os.path.join(saveDir, 'connComp', + saveNameID+'_connComp.tif')) + write_ambiguity(connData, connDataName, connProj, connGeoTrans, connNoData) + # Write out the integer map as tiff file + intShiftName = os.path.abspath(os.path.join(saveDir, 'unw', + saveNameID+'_intShift.tif')) + write_ambiguity(intShift, intShiftName, connProj, connGeoTrans) - # writing out the scalled vrt => 2PI * integer map + # Write out the scalled vrt => 2PI * integer map length = intShift.shape[0] width = intShift.shape[1] - scaleVRTName = 
os.path.abspath(os.path.join(saveDir,'unw',saveNameID + '_scale.vrt')) - build2PiScaleVRT(scaleVRTName,intShiftName,length=length,width=width) + scaleVRTName = os.path.abspath( + os.path.join(saveDir, 'unw', saveNameID+'_scale.vrt')) + build2PiScaleVRT(scaleVRTName, intShiftName, length=length, width=width) - # Offseting the vrt for the range offset correctiom - unwRangeOffsetVRTName = os.path.abspath(os.path.join(saveDir,'unw',saveNameID + '_rangeOffset.vrt')) + # Offset the VRT for the range offset correction + unwRangeOffsetVRTName = os.path.abspath( + os.path.join(saveDir, 'unw', saveNameID + '_rangeOffset.vrt')) buildScaleOffsetVRT(unwRangeOffsetVRTName,unwFile,connProj,connGeoTrans,File1_offset=connCompMapping[1,3],length=length,width=width) # writing out the corrected unw phase vrt => phase + 2PI * integer map @@ -1180,49 +1278,45 @@ def createConnComp_Int(inputs): return [connDataName, unwVRTName] -def write_ambiguity(data, outName,proj, geoTrans,noData=False): - ''' - Write out an integer mapping in the Int16/Byte data range of values - ''' - +def write_ambiguity(data, outName,proj, geoTrans, noData=False): + """Write out an integer mapping in the Int16/Byte data range of values.""" # GDAL precision support in tif Byte = gdal.GDT_Byte Int16 = gdal.GDT_Int16 - # check if the path to the file needs to be created + # Check if the path to the file needs to be created dirname = os.path.dirname(outName) if not os.path.isdir(dirname): os.makedirs(dirname) - # Getting the GEOTIFF driver + # Get the GEOTIFF driver driver = gdal.GetDriverByName('GTIFF') - # leverage the compression option to ensure small file size + # Leverage the compression option to ensure small file size dst_options = ['COMPRESS=LZW'] - # create the dataset - ds = driver.Create(outName , data.shape[1], data.shape[0], 1, Int16, dst_options) - # setting the proj and transformation + # Create the dataset + ds = driver.Create(outName , data.shape[1], data.shape[0], 1, Int16, + dst_options) + # Set the proj and transformation ds.SetGeoTransform(geoTrans) ds.SetProjection(proj) - # populate the first band with data + # Populate the first band with data bnd = ds.GetRasterBand(1) bnd.WriteArray(data) - # setting the no-data value + # Set the no-data value if noData is not None: bnd.SetNoDataValue(noData) bnd.FlushCache() - # close the file + # Close the file ds = None -def build2PiScaleVRT(output,File,width=False,length=False): - ''' - Building a VRT file which scales a GDAL byte file with 2PI - ''' +def build2PiScaleVRT(output, File, width=False, length=False): + """Build a VRT file which scales a GDAL byte file with 2PI.""" # DBTODO: The datatype should be loaded by default from the source raster to be applied. 
# should be ok for now, but could be an issue for large connected com - # the vrt template with 2-pi scaling functionality + # The VRT template with 2PI scaling functionality vrttmpl = ''' @@ -1235,28 +1329,28 @@ def build2PiScaleVRT(output,File,width=False,length=False): ''' - # the inputs needed to build the vrt - # load the width and length from the GDAL file in case not specified + # The inputs needed to build the VRT + # Load the width and length from the GDAL file in case not specified if not width or not length: ds = gdal.Open(File, gdal.GA_ReadOnly) width = ds.RasterXSize ysize = ds.RasterYSize ds = None - # check if the path to the file needs to be created + # Check if the path to the file needs to be created dirname = os.path.dirname(output) if not os.path.isdir(dirname): os.makedirs(dirname) - # write out the VRT file + # Write out the VRT file with open(output, 'w') as fid: fid.write(vrttmpl.format(width = width, length = length, File = File)) def buildScaleOffsetVRT(output,File1,proj,geoTrans,File1_offset=0, File1_scale = 1, width=False,length=False,description='Scalled and offsetted VRT'): - ''' - Building a VRT file which sums two files together using pixel functionality - ''' + """Building a VRT file which sums two files together using pixel + functionality. + """ # the vrt template with sum pixel functionality vrttmpl = ''' @@ -1275,30 +1369,29 @@ def buildScaleOffsetVRT(output,File1,proj,geoTrans,File1_offset=0, File1_scale = ''' # - - # the inputs needed to build the vrt - # load the width and length from the GDAL file in case not specified + # The inputs needed to build the VRT + # Load the width and length from the GDAL file in case not specified if not width or not length: ds = gdal.Open(File1, gdal.GA_ReadOnly) width = ds.RasterXSize ysize = ds.RasterYSize ds = None - # check if the path to the file needs to be created + # Check if the path to the file needs to be created dirname = os.path.dirname(output) if not os.path.isdir(dirname): os.makedirs(dirname) - # write out the VRT file + # Write out the VRT file with open('{0}'.format(output) , 'w') as fid: fid.write( vrttmpl.format(width=width,length=length,File1=File1,File1_offset=File1_offset, File1_scale = File1_scale, proj=proj,geoTrans=str(geoTrans)[1:-1],description=description)) def buildSumVRT(output,File1,File2,proj,geoTrans,length=False, width=False,description='Unwrapped Phase'): - ''' - Building a VRT file which sums two files together using pixel functionality - ''' + """Building a VRT file which sums two files together using pixel + functionality. 
+ """ - # the vrt template with sum pixel functionality + # The VRT template with sum pixel functionality vrttmpl = ''' {proj} {geoTrans} @@ -1316,20 +1409,20 @@ def buildSumVRT(output,File1,File2,proj,geoTrans,length=False, width=False,descr ''' - # the inputs needed to build the vrt - # load the width and length from the GDAL file in case not specified + # The inputs needed to build the vrt + # Load the width and length from the GDAL file in case not specified if not width or not length: ds = gdal.Open(File1, gdal.GA_ReadOnly) width = ds.RasterXSize ysize = ds.RasterYSize ds = None - # check if the path to the file needs to be created + # Check if the path to the file needs to be created dirname = os.path.dirname(output) if not os.path.isdir(dirname): os.makedirs(dirname) - # write out the VRT file + # Write out the VRT file with open('{0}'.format(output) , 'w') as fid: fid.write( vrttmpl.format(width=width,length=length,File1=File1,File2=File2, proj=proj,geoTrans=str(geoTrans)[1:-1],description=description)) @@ -1432,9 +1525,12 @@ def gdalTest(file, verbose=False): -def product_stitch_overlap(unw_files, conn_files, prod_bbox_files, bbox_file, prods_TOTbbox, outFileUnw = './unwMerged', outFileConnComp = './connCompMerged', outputFormat='ENVI', mask=None, verbose=False): +def product_stitch_overlap(unw_files, conn_files, prod_bbox_files, bbox_file, + prods_TOTbbox, outFileUnw = './unwMerged', + outFileConnComp = './connCompMerged', outputFormat='ENVI', + mask=None, verbose=False): ''' - Stitching of products minimizing overlap betnween products + Stitching of products minimizing overlap betnween products ''' # report method to user From 3b0f30d91f94cb013781b14c725676786e2de5cd Mon Sep 17 00:00:00 2001 From: rzinke Date: Wed, 30 Mar 2022 17:02:15 -0700 Subject: [PATCH 3/4] Replaced gdalWarp with gdalTranslate in unwrapStitching --- tools/ARIAtools/extractProduct.py | 288 +++++++++++++++++++---------- tools/ARIAtools/unwrapStitching.py | 2 +- 2 files changed, 187 insertions(+), 103 deletions(-) diff --git a/tools/ARIAtools/extractProduct.py b/tools/ARIAtools/extractProduct.py index 392d547d..c323cde2 100755 --- a/tools/ARIAtools/extractProduct.py +++ b/tools/ARIAtools/extractProduct.py @@ -12,14 +12,15 @@ from osgeo import gdal, osr import logging import requests -from ARIAtools.logger import logger +from ARIAtools.logger import logger from ARIAtools.shapefile_util import open_shapefile, chunk_area from ARIAtools.mask_util import prep_mask from ARIAtools.unwrapStitching import product_stitch_overlap, product_stitch_2stage gdal.UseExceptions() -#Suppress warnings + +# Suppress warnings gdal.PushErrorHandler('CPLQuietErrorHandler') log = logging.getLogger(__name__) @@ -107,7 +108,7 @@ def __call__(self, line, pix, h): class metadata_qualitycheck: """Metadata quality control function. - + Artifacts recognized based off of covariance of cross-profiles. Bug-fix varies based off of layer of interest. 
Verbose mode generates a series of quality control plots with @@ -133,7 +134,7 @@ def __init__(self, data_array, prod_key, outname, verbose=None): def __truncateArray__(self, data_array_band, Xmask, Ymask): # Mask columns/rows which are entirely made up of 0s - #first must crop all columns with no valid values + # First must crop all columns with no valid values nancols=np.all(data_array_band.mask == True, axis=0) data_array_band=data_array_band[:,~nancols] Xmask=Xmask[:,~nancols] @@ -155,10 +156,10 @@ def __getCovar__(self, prof_direc, profprefix=''): self.data_array_band, Xmask, Ymask = self.__truncateArray__( self.data_array_band, Xmask, Ymask) - #append prefix for plot names + # Append prefix for plot names prof_direc = profprefix + prof_direc - #iterate through transpose of matrix if looking in azimuth + # Iterate through transpose of matrix if looking in azimuth arrT='' if 'azimuth' in prof_direc: arrT='.T' @@ -168,35 +169,36 @@ def __getCovar__(self, prof_direc, profprefix=''): for i in enumerate(eval('self.data_array_band%s'%(arrT))): mid_line=i[1] xarr=np.array(range(len(mid_line))) - #remove masked values from slice + # Remove masked values from slice if mid_line.mask.size!=1: if True in mid_line.mask: xarr=xarr[~mid_line.mask] mid_line=mid_line[~mid_line.mask] - #chunk array to better isolate artifacts + # Chunk array to better isolate artifacts chunk_size= 4 for j in range(0, len(mid_line.tolist()), chunk_size): chunk = mid_line.tolist()[j:j+chunk_size] xarr_chunk = xarr[j:j+chunk_size] - # make sure each iteration contains at least minimum number of elements + # Make sure each iteration contains at least minimum number of elements if j==range(0, len(mid_line.tolist()), chunk_size)[-2] and \ len(mid_line.tolist()) % chunk_size != 0: chunk = mid_line.tolist()[j:] xarr_chunk = xarr[j:] - #linear regression and get covariance + # Linear regression and get covariance slope, bias, rsquared, p_value, std_err = linregress(xarr_chunk,chunk) rsquaredarr.append(abs(rsquared)**2) std_errarr.append(std_err) - #terminate early if last iteration would have small chunk size + # Terminate early if last iteration would have small chunk size if len(chunk)>chunk_size: break - #exit loop/make plots in verbose mode if R^2 and standard error anomalous, or if on last iteration + # Exit loop/make plots in verbose mode if R^2 and standard error + # anomalous, or if on last iteration if (min(rsquaredarr) < 0.9 and max(std_errarr) > 0.01) or \ (i[0]==(len(eval('self.data_array_band%s'%(arrT)))-1)): if self.verbose: - #Make quality-control plots + # Make quality-control plots import matplotlib.pyplot as plt ax0=plt.figure().add_subplot(111) ax0.scatter(xarr, mid_line, c='k', s=7) @@ -284,13 +286,13 @@ def __run__(self): self.data_array_band, Xmask, Ymask) # truncated grid covering the domain of the data - Xmask=Xmask[~self.data_array_band.mask] - Ymask=Ymask[~self.data_array_band.mask] + Xmask = Xmask[~self.data_array_band.mask] + Ymask = Ymask[~self.data_array_band.mask] self.data_array_band = self.data_array_band[~self.data_array_band.mask] XX = Xmask.flatten() YY = Ymask.flatten() A = np.c_[XX, YY, np.ones(len(XX))] - C,_,_,_ = lstsq(A, self.data_array_band.data.flatten()) + C, _, _, _ = lstsq(A, self.data_array_band.data.flatten()) # evaluate it on grid self.data_array_band = C[0]*X + C[1]*Y + C[2] #mask by nodata value @@ -303,7 +305,7 @@ def __run__(self): #update band self.data_array.GetRasterBand(i).WriteArray(self.data_array_band.filled()) # Pass warning and get R^2/standard error across 
range/azimuth (only do for first band) - if i==1: + if i == 1: # make sure appropriate unit is passed to print statement lyrunit = "\N{DEGREE SIGN}" if self.prod_key=='bPerpendicular' or self.prod_key=='bParallel': @@ -441,9 +443,9 @@ def merged_productbbox(metadata_dict, product_dict, workdir='./', bbox_file=None # If specified, check if user's bounding box meets minimum threshold area if bbox_file is not None: - user_bbox=open_shapefile(bbox_file, 0, 0) - overlap_area=shapefile_area(user_bbox) - if overlap_area Date: Wed, 6 Apr 2022 12:39:13 -0700 Subject: [PATCH 4/4] Overhaul of computeMisclosure to fix verbosity and query point bugs, and update style --- tools/ARIAtools/computeMisclosure.py | 1257 ++++++++++++++++---------- tools/ARIAtools/extractProduct.py | 304 +++---- tools/ARIAtools/unwrapStitching.py | 626 ++++++------- 3 files changed, 1138 insertions(+), 1049 deletions(-) mode change 100755 => 100644 tools/ARIAtools/extractProduct.py mode change 100755 => 100644 tools/ARIAtools/unwrapStitching.py diff --git a/tools/ARIAtools/computeMisclosure.py b/tools/ARIAtools/computeMisclosure.py index c8c2a15d..34865f89 100644 --- a/tools/ARIAtools/computeMisclosure.py +++ b/tools/ARIAtools/computeMisclosure.py @@ -70,12 +70,14 @@ def createParser(): formatter_class=argparse.RawTextHelpFormatter, epilog=Examples) # Input data - parser.add_argument('-f', '--file', dest='imgfile', type=str, required=True, + parser.add_argument('-f', '--file', dest='unwFile', type=str, required=True, help='ARIA files. Specify the stack/unwrapStack.vrt file, or a wildcard operator in the unwrappedPhase folder (see EXAMPLES)') parser.add_argument('-w', '--workdir', dest='workdir', type=str, default='./', help='Specify directory to deposit all outputs. Default is local directory where script is launched.') + parser.add_argument('--coherence', dest='cohFile', type=str, default=None, + help='Coherence stack for use in automatic reference point selection.') - parser.add_argument('--startdate', dest='startDate', type=str, default='20140615', + parser.add_argument('--startdate', dest='startDate', type=str, default=None, help='Start date for data series') parser.add_argument('--enddate', dest='endDate', type=str, default=None, help='End date for data series') @@ -112,10 +114,6 @@ def createParser(): parser.add_argument('-v','--verbose', dest='verbose', action='store_true', help='Verbose mode') # Misclosure map formatting - parser.add_argument('--pctmin', dest='pctMinClip', type=float, default=1, - help='Minimum percent clip value for cumulative misclosure plot') - parser.add_argument('--pctmax', dest='pctMaxClip', type=float, default=99, - help='Maximum percent clip value for cumulative misclosure plot') parser.add_argument('--plot-time-intervals', dest='plotTimeIntervals', action='store_true', help='Plot triplet intervals in misclosure analysis figure.') @@ -129,202 +127,319 @@ def cmdLineParse(iargs = None): ### STACK OBJECT --- -class stack: +class MisclosureStack: + '''Class for loading and storing stack data for phase triplet + misclosure analysis. ''' - Class for loading and storing stack data. - ''' - ## Load data - def __init__(self,imgfile,workdir='./', - startDate='20140615',endDate=None,excludePairs=None, - verbose=False): - ''' - Initialize object. Store essential info for posterity. - loadStackData() - Load data from unwrapStack.vrt using gdal. - formatDates() - Determine the IFG pairs and list of unique dates from the data set. - formatExcludePairs() - Load and format pairs to exclude, if provided. 
+ def __init__(self, + unwFile, + cohFile=None, + workdir='./', + startDate=None, endDate=None, + excludePairs=None, + verbose=False): + '''Initialize object. Store essential info for posterity. + loadStackData() - Load data from unwrapStack.vrt using gdal + formatDates() - Determine the IFG pairs and list of unique dates + from the data set + formatExcludePairs() - Load and format pairs to exclude, if provided ''' + if verbose == True: + print('Initializing stack for misclosure analysis') + + # Verbosity + self.verbose = verbose + if self.verbose: + logger.setLevel(logging.DEBUG) + # Files and directories - self.imgfile = os.path.abspath(imgfile) - self.basename = os.path.basename(self.imgfile) - self.imgdir = os.path.dirname(self.imgfile) + self.unwFile = os.path.abspath(unwFile) self.workdir = os.path.abspath(workdir) # Check if output directory exists if not os.path.exists(self.workdir): os.mkdir(self.workdir) - # Dates and pairs - self.startDate = datetime.strptime(startDate,'%Y%m%d') - if not endDate: - self.endDate = datetime.now() - else: - self.endDate = datetime.strptime(endDate,'%Y%m%d') + # Read stack data and retrieve list of dates + self.__loadUnwStack__() - self.excludePairs = excludePairs + # Load coherence file if specified + if cohFile is not None: + self.__loadCohStack__(cohFile) - # Other - if self.verbose: logger.setLevel(logging.DEBUG) + # Format dates + self.__formatDates__(startDate, endDate) + # Format pairs to exclude, if provided + self.__formatExcludePairs__(excludePairs) - # Read stack data and retrieve list of dates - self.__loadStackData__() + def __loadUnwStack__(self): + '''Load data from unwrapStack.vrt file.''' + # Open data set + self.unwStack = gdal.Open(self.unwFile, gdal.GA_ReadOnly) - # Format dates - self.__formatDates__() + # Format extent + self.__formatGeoInfo__() - # Format pairs to exclude, if provided - self.__formatExcludePairs__() + # Report if requested + log.debug('%s bands detected', self.unwStack.RasterCount) + def __loadCohStack__(self, cohFile): + '''Load data from cohStack.vrt file.''' + cohFile = os.path.abspath(cohFile) - # Load data from unwrapStack.vrt - def __loadStackData__(self): - ''' - Load data from unwrapStack.vrt file. 
- ''' - # Open dataset - self.IFGs = gdal.Open(self.imgfile,gdal.GA_ReadOnly) + # Load GDAL data set + cohStack = gdal.Open(cohFile, gdal.GA_ReadOnly) - # Format extent - N = self.IFGs.RasterXSize; M = self.IFGs.RasterYSize - tnsf = self.IFGs.GetGeoTransform() - left = tnsf[0]; xstep = tnsf[1]; right = left+N*xstep - top = tnsf[3]; ystep = tnsf[5]; bottom = top+M*ystep + # Image bands + imgs = np.array([cohStack.GetRasterBand(i+1).ReadAsArray() for i in \ + range(cohStack.RasterCount)]) + + # Mean coherence + self.meanCoh = np.mean(imgs, axis=0) + + log.debug('Coherence stack loaded and averaged') + + + ## Geographic information + def __formatGeoInfo__(self): + '''Parse the spatial metadata associated with the GDAL data set.''' + # Image sizes + self.N = self.unwStack.RasterXSize + self.M = self.unwStack.RasterYSize + + # Get projection for later + self.proj = self.unwStack.GetProjection() + + # Get geographic transform + self.tnsf = self.unwStack.GetGeoTransform() + + # Parse geotransform + left, xstep, xskew, top, yskew, ystep = self.tnsf + + # Re-format geotransform as matrix + self.tnsfMatrix = np.array([[xstep, yskew], + [xskew, ystep]]) + + # Origin coordinates as vector + self.tnsfOrigin = np.array([[left, top]]).T + + # Plot extent + right = left + xstep*self.N + bottom = top + ystep*self.M self.extent = (left, right, bottom, top) - # Report if requested - log.debug('%s bands detected', self.IFGs.RasterCount) + def xy2lola(self, x, y): + '''Convert X/Y to lon/lat.''' + # Reshape points as vector + p = np.array([[x, y]]).T + + # Calculate geographic position + lon, lat = self.tnsfMatrix.dot(p) + self.tnsfOrigin + + # Flatten arrays + lon = lon.flatten() + lat = lat.flatten() + + return lon, lat + + def lola2xy(self, lon, lat): + '''Convert lon/lat coordinates to XY.''' + # Reshape points as vector + L = np.array([[lon, lat]]).T + + # Calculate point coordinates + px, py = np.linalg.inv(self.tnsfMatrix).dot(L - self.tnsfOrigin) + + # Convert pixel coordinates to integers + px = int(px.flatten()) + py = int(py.flatten()) + + return px, py + + + ## Date and date pair formatting + def __formatDates__(self, startDate, endDate): + '''Retrieve list of date pairs and unique dates (epochs). + The "pairs" attribute is a formatted list of interferogram date + pairs **in the order in which they are stored in the unwrapStack.vrt** + file. This list should not be modified. - # Convert date pair to string - def __datePair2strPair__(self,datePair): + Constrain the list of epochs available for triplet determination + using the "startDate" and "endDate" provided in the __init__ + function. 
''' - Convert pair in format [master, slave] to date in format 'master_slave' + log.debug('Formatting dates') + + # PairNames - list of pairs composing the data set in the order + # they are written + pairNames = [os.path.basename(fname) for fname in \ + self.unwStack.GetFileList()] + + # Remove extra file name + pairNames.remove('unwrapStack.vrt') + + # Remove extensions + pairNames = [pairName.strip('.vrt') for pairName in pairNames] + + # Convert pair name strings to datetime objects + self.datePairs = self.__pairNames2datePairs__(pairNames) + + # Get unique dates from date pairs + self.dates = self.__uniqueDatesFromPairs__(self.datePairs) + + # Format start and end dates + self.__applyStartEndDates__(startDate, endDate) + + # Number of dates within start-end range + self.nDates = len(self.dates) + + log.debug('%s unique dates detected', self.nDates) + + def __pairNames2datePairs__(self, pairNames): + '''Convert list of pairs in format ['master_slave','master_slave',...] + to dates in format [[master, slave], [master, slave], ...] ''' - masterStr = datePair[0].strftime('%Y%m%d') - secondaryStr = datePair[1].strftime('%Y%m%d') - strPair = '{}_{}'.format(masterStr, secondaryStr) - return strPair + datePairs = [self.__strPair2datePair__(pairName) for pairName \ + in pairNames] + + return datePairs - # Convert pair string to date list def __strPair2datePair__(self,pair): - ''' - Convert pair in format 'master_slave' to date in format [master, slave] + '''Convert pair in format 'master_slave' to date in format + [master, slave] ''' pair = pair.split('_') - masterDate = datetime.strptime(pair[0],'%Y%m%d') - secondaryDate = datetime.strptime(pair[1],'%Y%m%d') + masterDate = datetime.strptime(pair[0], '%Y%m%d') + secondaryDate = datetime.strptime(pair[1], '%Y%m%d') datePair = [masterDate, secondaryDate] return datePair - # Convert pair list to date lists - def __pairList2dateList__(self,pairList): + def __datePair2strPair__(self, datePair): + '''Convert pair in format [master, slave] to date in format + [master_slave] ''' - Convert list of pairs in format ['master_slave','master_slave',...] to dates in format - [[master, slave], [master, slave], ...] - ''' - pairDates = [] - for pair in pairList: - datePair = self.__strPair2datePair__(pair) - pairDates.append(datePair) - return pairDates - - # Date pairs and unique dates - def __formatDates__(self): - ''' - Retrieve list of date pairs and unique dates (epochs). - The "pairs" attribute is a formatted list of interferogram date pairs **in the order in - which they are stored in the unwrapStack.vrt** file. This list should not be modified. - Constrain the list of epochs available for triplet determination using the "startDate" - and "endDate" provided in the __init__ function. 
- ''' - # Pairs - list of pairs composing the data set, in the order they are written - pairs = [os.path.basename(fname) for fname in self.IFGs.GetFileList()] - pairs = [pair.split('.')[0] for pair in pairs] # remove extensions - pairs.remove('unwrapStack') # remove extra file name - self.pairs = self.__pairList2dateList__(pairs) + masterStr = datePair[0].strftime('%Y%m%d') + secondaryStr = datePair[1].strftime('%Y%m%d') + strPair = '{:s}_{:s}'.format(masterStr, secondaryStr) + return strPair - # Get unique dates from date pairs - self.epochs = [] - [self.epochs.extend(pair) for pair in self.pairs] - self.epochs = list(set(self.epochs)) # unique dates only - self.epochs.sort() # sort oldest-youngest - - self.nEpochs = len(self.epochs) - log.debug('%s unique dates detected', self.nEpochs) - - # Limit dates available for triplet formulation by start and end date - self.tripletEpochs = [epoch for epoch in self.epochs if epoch >= self.startDate] - self.tripletEpochs = [epoch for epoch in self.epochs if epoch <= self.endDate] - self.nTripletEpochs = len(self.tripletEpochs) - - # Update start and end dates - self.startDate = self.tripletEpochs[0] - self.endDate = self.tripletEpochs[-1] - - # Ticks for plots - self.dateTicks = pd.date_range(self.startDate-timedelta(days=30), - self.endDate+timedelta(days=30),freq='MS') - self.dateTickLabels = [date.strftime('%Y-%m') for date in self.dateTicks] - - # Format dates in list to exclude - def __formatExcludePairs__(self): + def __uniqueDatesFromPairs__(self, datePairs): + '''Get a list of unique datetimes representing epochs of + acqusition. ''' - Check that exclude dates are in one of two formats: - 1. a string containing the pairs in YOUNGER_OLDER format, space-separated - 2. a .txt file with lines of the same formatting - Formatting should match "pair" formatting: [[master,slave]] + # List of all dates + dates = [] + [dates.extend(pair) for pair in datePairs] + + # Filter for unique dates + dates = list(set(dates)) + + # Sort oldest-youngest + dates.sort() + + return dates + + def __applyStartEndDates__(self, startDate, endDate): + '''Format the start and end dates as datetime objects. + Trim list of dates to start and end date limits.''' + + # Start date + if startDate: + # Format start date as datetime + startDate = datetime.strptime(startDate, '%Y%m%d') + else: + # Use first date + startDate = self.dates[0] + + # End date + if endDate: + # Format end date as datetime + endDate = datetime.strptime(endDate, '%Y%m%d') + else: + # Use last date + endDate = self.dates[-1] + + log.debug('Start date: %s; end date %s', startDate, endDate) + + # Crop list of dates to start and end + self.dates = [date for date in self.dates if date >= startDate] + self.dates = [date for date in self.dates if date <= endDate] + + # Create array of dates for plotting + self.__createDateAxis__(startDate, endDate) + + def __createDateAxis__(self, startDate, endDate): + '''Create an array of dates for plotting misclosure values.''' + # Plot x-ticks + self.dateTicks = pd.date_range( + startDate - timedelta(days=30), + endDate + timedelta(days=30), + freq='MS') + + # Plot x-tick labels + self.dateTickLabels = [date.strftime('%Y-%m') for date in + self.dateTicks] + + def __formatExcludePairs__(self, excludePairs): + '''Check that exclude dates are in one of two formats: + 1. a string containing the pairs in YOUNGER_OLDER format, + space-separated + 2. 
a .txt file with lines of the same formatting + + Formatting should match "pair" formatting: [[master,slave]] ''' - if self.excludePairs is not None: + if excludePairs is not None: # Determine whether list or text file - if self.excludePairs[-4:] == '.txt': + if excludePairs.endswith('.txt'): # Treat as text file with list - with open(self.excludePairs,'r') as exclFile: + with open(excludePairs, 'r') as exclFile: excludePairs = exclFile.readlines() excludePairs = [pair.strip('\n') for pair in excludePairs] - self.excludePairs = self.__pairList2dateList__(excludePairs) + self.excludePairs = self.__pairNames2datePairs__(excludePairs) else: # Treat as list - split at spaces - excludePairs = self.excludePairs.split(' ') - self.excludePairs = self.__pairList2dateList__(excludePairs) + excludePairs = excludePairs.split() + self.excludePairs = self.__pairNames2datePairs__(excludePairs) else: # Include as empty list self.excludePairs = [] - - ## Plot pairs def plotPairs(self): - ''' - Plot the timespans of interferogram pairs. - ''' + '''Plot the timespans of interferogram pairs.''' # Copy the list of pairs and sort them in time order - pairs = self.pairs[:] # copy to separate object - pairs.sort(key=lambda s: s[1]) # sort by secondary date + datePairs = self.datePairs[:] # copy to separate object + datePairs.sort(key=lambda s: s[1]) # sort by secondary date + + # Spawn figure and axis + pairFig, pairAx = plt.subplots() - # Plot pairs in time - pairFig = plt.figure() - pairAx = pairFig.add_subplot(111) - for n,pair in enumerate(pairs): + # Loop through date pairs and plot them in time + for n, pair in enumerate(datePairs): # Color based on whether included or excluded if pair not in self.excludePairs: color = 'k' label = 'valid pair' linestyle = '-' else: - color = (0.6,0.6,0.6) + color = (0.6, 0.6, 0.6) label = 'excluded pair' linestyle = '--' + # Convert to datetime format - pairAx.plot([pair[0],pair[1]],[n,n],color=color,label=label,linestyle=linestyle) + pairAx.plot([pair[0],pair[1]], [n,n], + color=color, label=label, linestyle=linestyle) # Format x-axis pairAx.set_xticks(self.dateTicks) pairAx.set_xticklabels(self.dateTickLabels, rotation=90) # Legend - handles,labels=pairAx.get_legend_handles_labels() + handles, labels=pairAx.get_legend_handles_labels() uniqueLabels = dict(zip(labels,handles)) - pairAx.legend(uniqueLabels.values(),uniqueLabels.keys(), - bbox_to_anchor=(0.005,0.99),loc='upper left',borderaxespad=0.) + pairAx.legend(uniqueLabels.values(), uniqueLabels.keys(), + bbox_to_anchor=(0.005, 0.99), loc='upper left', borderaxespad=0.) # Other formatting pairAx.set_yticks([]) @@ -333,26 +448,22 @@ def plotPairs(self): ## Create triplet list - def createTriplets(self,minTime=None,maxTime=None,printTriplets=False): - ''' - Create a list of triplets given the date list and user-specified parameters. - First generate a list of all possible triplets based on the available dates. - Then validate that list across the list of existing pairs. + def createTriplets(self, minTime=None, maxTime=None, printTriplets=False): + '''Create a list of triplets given the date list and user-specified + parameters. - The stack object retains an ordered list of dates in both YYYYMMDD format and datetime - format, and pairList, based on the __format_dates__ function. + First generate a list of all possible triplets based on the + available dates. Then validate that list across the list of + existing pairs. 
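For reference, the excluded-pair input handled above can come either as a single space-separated string or as a .txt file with one pair per line. A small sketch of that dual handling; the function name and the extra existence check are illustrative additions, not part of the patch:

import os

def parse_exclude_pairs(exclude):
    """Return 'YYYYMMDD_YYYYMMDD' strings from a space-separated string or a
    text file with one pair per line; an empty list if nothing was given."""
    if exclude is None:
        return []
    if exclude.endswith('.txt') and os.path.exists(exclude):
        with open(exclude) as exclude_file:
            return [line.strip() for line in exclude_file if line.strip()]
    return exclude.split()

print(parse_exclude_pairs('20190129_20190105 20190129_20181212'))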
+ + The stack object retains an ordered list of dates in both + YYYYMMDD format and datetime format, and pairList, based on the + __format_dates__ function. ''' - log.debug('Creating list of all possible triplets') + log.debug('Creating list of triplets') - # Loop through dates to create all possible triplet combinations - self.triplets = [] - for i in range(self.nTripletEpochs-2): - for j in range(i+1,self.nTripletEpochs-1): - for k in range(j+1,self.nTripletEpochs): - epochI = self.tripletEpochs[i] # first date in sequence - epochJ = self.tripletEpochs[j] # second date in sequence - epochK = self.tripletEpochs[k] # third date in sequence - self.triplets.append([[epochJ,epochI],[epochK,epochJ],[epochK,epochI]]) + # Create a list of all possible triplets + self.__createAllTriplets__() # Remove triplets with pairs in "exclude pairs" list self.__checkExcludedTriplets__() @@ -368,102 +479,182 @@ def createTriplets(self,minTime=None,maxTime=None,printTriplets=False): # Finished sorting self.nTriplets = len(self.triplets) + log.debug('%s valid triplets identified', self.nTriplets) - # Print to text file - with open(os.path.join(self.workdir,'ValidTriplets.txt'), 'w') as tripletFile: - for triplet in self.triplets: - strPair = [self.__datePair2strPair__(pair) for pair in triplet] - tripletFile.write('{}\n'.format(strPair)) - tripletFile.close() + # Retrieve triplet reference dates + self.__retreiveTripletReferenceDates__() - # Report if requested + # Save triplets to text file + self.__saveTriplets__() + + # Print triplets to screen if requested if printTriplets == True: - # Print to screen - log.info('Existing triplets:') - for triplet in self.triplets: - log.info([self.__datePair2strPair__(pair) for pair in triplet]) - if self.verbose == True: - log.info('%s existing triplets found based on search criteria', self.nTriplets) + self.__printTriplets__() - # Reference dates - self.tripletDates = [[triplet[0][1],triplet[1][1],triplet[2][0]] for triplet in self.triplets] + def __createAllTriplets__(self): + '''Loop through dates to create all possible triplet combinations.''' + log.debug('Listing all possible triplets') + + # Create empty list + self.triplets = [] + + # Loop through first date in ordered list (date 1) + for i in range(self.nDates-2): + # Loop through second dates (start from first date after date 1) + for j in range(i+1, self.nDates-1): + # Loop through third dates (start from first date after date 2) + for k in range(j+1, self.nDates): + dateI = self.dates[i] # first date in sequence + dateJ = self.dates[j] # second date in sequence + dateK = self.dates[k] # third date in sequence + self.triplets.append([ + [dateJ, dateI], + [dateK, dateJ], + [dateK, dateI] + ]) - # Check triplets against excluded pairs def __checkExcludedTriplets__(self): + '''Check triplet list against excluded pairs list. Remove the + triplet if any of the pairs is listed in "exclude pairs". ''' - Check triplet list against excluded pairs list. Remove the triplet if any of the pairs - is listed in "exclude pairs". 
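The triple nested loop in __createAllTriplets__ enumerates every ordered date triple i < j < k and forms the pairs [J, I], [K, J] and [K, I]. The same enumeration can be expressed with itertools.combinations; a sketch assuming a sorted list of acquisition dates (the sample dates are invented):

from datetime import datetime
from itertools import combinations

dates = sorted(datetime.strptime(d, '%Y%m%d')
               for d in ('20181212', '20190105', '20190129', '20190222'))

# For every i < j < k, form the pairs [J, I], [K, J], [K, I]
triplets = [[[dj, di], [dk, dj], [dk, di]]
            for di, dj, dk in combinations(dates, 3)]

print(len(triplets), 'candidate triplets')   # 4 choose 3 = 4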
- ''' + log.debug('Checking triplets against excluded pairs') + + # Empty list of non-excluded triplets validTriplets = [] + + # Loop through all possible triplets for triplet in self.triplets: - # If no pairs are excluded, append to the valid triplets list - invalidTriplets = 0 # reset counter - for pair in triplet: - if pair in self.excludePairs: + # Invalid date pairs in triplet + invalidTriplets = 0 # reset counter + for datePair in triplet: + # Check against list of pairs to exclude + if datePair in self.excludePairs: invalidTriplets += 1 + + # If no pairs are excluded, append to the valid triplets list if invalidTriplets == 0: validTriplets.append(triplet) - self.triplets = validTriplets # update triplets list + + # Update triplets list + self.triplets = validTriplets # Check triplets against minTime - def __checkTripletsMinTime__(self,minTime): - ''' - Check that all pairs in a triplet are longer in duration than the minimum time interval - specified. + def __checkTripletsMinTime__(self, minTime): + '''Check that all pairs in a triplet are longer in duration + than the minimum time interval specified. ''' + log.debug('Checking triplet pairs are longer than %s days', minTime) + if minTime: validTriplets = [] for triplet in self.triplets: # Determine intervals between dates in days - intervals = [(pair[0]-pair[1]).days for pair in triplet] + intervals = [(datePair[0]-datePair[1]).days for \ + datePair in triplet] + + # Check against minimum allowable time interval if min(intervals) >= minTime: validTriplets.append(triplet) - self.triplets = validTriplets # update triplets list + + # Update triplets list + self.triplets = validTriplets # Check triplets against maxTime - def __checkTripletsMaxTime__(self,maxTime): - ''' - Check that all pairs in a triplet are shorter in duration than the maximum time interval - specified. + def __checkTripletsMaxTime__(self, maxTime): + '''Check that all pairs in a triplet are shorter in duration than + the maximum time interval specified. ''' + log.debug('Checking triplet pairs are shorter than %s days', maxTime) + if maxTime: validTriplets = [] for triplet in self.triplets: # Determine intervals between dates in days - intervals = [(pair[0]-pair[1]).days for pair in triplet] + intervals = [(datePair[0]-datePair[1]).days for \ + datePair in triplet] + + # Check against maximum allowable time interval if max(intervals) <= maxTime: validTriplets.append(triplet) - self.triplets = validTriplets # update triplets list + + # Update triplets list + self.triplets = validTriplets # Check triplets exist def __checkTripletsExist__(self): + '''Check list of all possible triplets against the list of pairs + that actually exist. ''' - Check list of all possible triplets against the list of pairs that actually exist. 
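The min/max time checks above keep a triplet only when all three of its pairs have temporal baselines inside the requested bounds. A compact standalone version of the same filter, with an invented triplet and arbitrary thresholds for demonstration:

from datetime import datetime

def filter_triplets_by_baseline(triplets, min_days=None, max_days=None):
    """Keep only triplets whose three pair durations fall within the bounds (days)."""
    kept = []
    for triplet in triplets:
        spans = [(later - earlier).days for later, earlier in triplet]
        if min_days is not None and min(spans) < min_days:
            continue
        if max_days is not None and max(spans) > max_days:
            continue
        kept.append(triplet)
    return kept

d = lambda s: datetime.strptime(s, '%Y%m%d')
triplet = [[d('20190105'), d('20181212')],
           [d('20190129'), d('20190105')],
           [d('20190129'), d('20181212')]]
print(len(filter_triplets_by_baseline([triplet], min_days=12, max_days=60)))   # -> 1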
- ''' + log.debug('Checking triplets provided in data set.') + existingTriplets = [] for triplet in self.triplets: - existing = 0 # reset count of existing pairs - # Check that each pair of the triplet has a corresponding interferogram - for tripletPair in triplet: - if tripletPair in self.pairs: - existing += 1 # update if ifg exists + # Reset count of existing pairs + existing = 0 + + # Check that each pair of the triplet has a corresponding + # interferogram + for datePair in triplet: + if datePair in self.datePairs: + # Update if IFG exists + existing += 1 if existing == 3: existingTriplets.append(triplet) - self.triplets = existingTriplets # update triplet list + # Update triplet list + self.triplets = existingTriplets + + def __retreiveTripletReferenceDates__(self): + '''Create a list of reference dates with each triplet.''' + log.debug('Retreiving list of reference dates from valid triplet') + + # Reference dates + self.tripletRefDates = [] + + # Loop through triplets + for triplet in self.triplets: + # All dates in triplet + tripletDates = [] + [tripletDates.extend(datePair) for datePair in triplet] + + # Triplet unique dates + tripletDates = list(set(tripletDates)) + + # Sort earliest to latest + tripletDates.sort(key=lambda date: date.strftime("%Y%m%d")) + + # Append to list + self.tripletRefDates.append(tripletDates) + + def __saveTriplets__(self): + '''Save the list of valid triplets to a text file.''' + with open(os.path.join(self.workdir, 'ValidTriplets.txt'), 'w') \ + as tripletFile: + # Loop through valid triplets + for triplet in self.triplets: + strPair = [self.__datePair2strPair__(pair) for pair in triplet] + tripletFile.write('{}\n'.format(strPair)) + + def __printTriplets__(self): + '''Print the list of valid triplets.''' + log.info('Existing triplets:') + + # Loop through triplets + for triplet in self.triplets: + log.info([self.__datePair2strPair__(pair) for pair in triplet]) + + # Final statistic + log.info('%s existing triplets found based on search criteria', + self.nTriplets) - ## Plot triplets def plotTriplets(self): - ''' - Plot triplets. - ''' + '''Plot triplets.''' # Setup figure - tripletFig = plt.figure() - tripletAx = tripletFig.add_subplot(111) + tripletFig, tripletAx = plt.subplots() # Plot triplets for i in range(self.nTriplets): - tripletAx.plot(self.triplets[i],[i,i,i],'k',marker='o') + tripletAx.plot(self.triplets[i], [i,i,i], 'k', marker='o') # Format x-axis tripletAx.set_xticks(self.dateTicks) @@ -476,114 +667,26 @@ def plotTriplets(self): ## Compute misclosure - # Geo to map coordinates - def LoLa2XY(self,lon,lat): - ''' - Convert lon/lat coordinates to XY. - ''' - tnsf = self.IFGs.GetGeoTransform() - x = (lon - tnsf[0])/tnsf[1] - y = (lat - tnsf[3])/tnsf[5] - return x, y + def computeMisclosure(self, refXY=None, refLoLa=None): + '''Compute the misclosure of the phase triplets. - # Map to geo coordinates - def XY2LoLa(self,x,y): + A common reference point is required because the ifgs are not + coregistered. ''' - Convert X/Y to lon/lat. - ''' - tnsf = self.IFGs.GetGeoTransform() - lon = tnsf[0] + tnsf[1]*x - lat = tnsf[3] + tnsf[5]*y - return lon, lat + log.debug('Computing misclosure') - # Reference point formatting - def __referencePoint__(self,refXY,refLoLa): - ''' - Determine the reference point in XY coordinates. The reference point can be - automatically or manually selected by the user and is subtracted - from each interferogram. - The point can be given in pixels or lon/lat coordinates. If given in Lat/Lon, determine - the location in XY. 
- ''' - log.debug('Determining reference point...') - - if refLoLa.count(None) == 0: - # Determine the XY coordinates from the given lon/lat - self.refLon = refLoLa[0] - self.refLat = refLoLa[1] - x,y = self.LoLa2XY(refLoLa[0],refLoLa[1]) - self.refX = int(x) - self.refY = int(y) - log.debug('Reference point given as: X %s / Y %s; Lon %s / Lat %s', - self.refX, self.refY, self.refLon, self.refLat) + # Create background value mask + self.__createMask__() - elif refXY.count(None) == 0: - # Use the provided XY coordinates - self.refX = refXY[0] - self.refY = refXY[1] - self.refLon,self.refLat = self.XY2LoLa(refXY[0],refXY[1]) - log.debug('Reference point given as: X %s / Y %s; Lon %.4f / Lat %.4f', - self.refX, self.refY, self.refLon, self.refLat) - - else: - # Use a random reference point - self.__autoReferencePoint__() - - # Random reference point - def __autoReferencePoint__(self): - ''' - Use the coherence stack to automatically determine a suitable reference point. - ''' - # Load coherence data from cohStack.vrt - cohfile = os.path.join(self.imgdir,'cohStack.vrt') - cohDS = gdal.Open(cohfile, gdal.GA_ReadOnly) - cohMap = np.zeros((cohDS.RasterYSize,cohDS.RasterXSize)) - coh_min = 0.7 - - for n in range(1,cohDS.RasterCount+1): - cohMap += cohDS.GetRasterBand(n).ReadAsArray() - aveCoherence = cohMap/cohDS.RasterCount - cohMask = (aveCoherence >= coh_min) - - # Start with initial guess for reference point - self.refX = np.random.randint(cohDS.RasterXSize) - self.refY = np.random.randint(cohDS.RasterYSize) - - # Loop until suitable reference point is found - n = 0 - while cohMask[self.refY,self.refX] == False: - # Reselect reference points - self.refX = np.random.randint(cohDS.RasterXSize) - self.refY = np.random.randint(cohDS.RasterYSize) - n += 1 # update counter - - # Break loop after 10000 iterations - if n == 10000: - msg = f'No reference point with coherence >= {coh_min} found' - log.error(msg) - raise Exception(msg) - - # Convert to lon/lat - self.refLon,self.refLat = self.XY2LoLa(self.refX,self.refY) - - log.debug('Reference point chosen randomly as: X %s / Y %s; Lon %.4f / Lat %.4f.', - self.refX, self.refY, self.refLon, self.refLat) - - # Compute misclosure - def computeMisclosure(self,refXY=None,refLoLa=None): - ''' - Compute the misclosure of the phase triplets. - A common reference point is required because the ifgs are not coregistered. 
- ''' # Determine reference point - self.__referencePoint__(refXY,refLoLa) + self.__referencePoint__(refXY, refLoLa) # Misclosure placeholders self.netMscStack = [] self.absMscStack = [] # Compute phase triplets - log.debug('Calculating misclosure') + log.debug('Calculating triplet misclosure') for triplet in self.triplets: # Triplet date pairs @@ -592,22 +695,22 @@ def computeMisclosure(self,refXY=None,refLoLa=None): KIdates = triplet[2] # Triplet indices - add 1 because raster bands start at 1 - JIndx = self.pairs.index(JIdates)+1 - KJndx = self.pairs.index(KJdates)+1 - KIndx = self.pairs.index(KIdates)+1 + JIndx = self.datePairs.index(JIdates)+1 + KJndx = self.datePairs.index(KJdates)+1 + KIndx = self.datePairs.index(KIdates)+1 # Interferograms - JI = self.IFGs.GetRasterBand(JIndx).ReadAsArray() - KJ = self.IFGs.GetRasterBand(KJndx).ReadAsArray() - KI = self.IFGs.GetRasterBand(KIndx).ReadAsArray() + JI = self.unwStack.GetRasterBand(JIndx).ReadAsArray() + KJ = self.unwStack.GetRasterBand(KJndx).ReadAsArray() + KI = self.unwStack.GetRasterBand(KIndx).ReadAsArray() # Normalize to reference point - JI -= JI[self.refY,self.refX] - KJ -= KJ[self.refY,self.refX] - KI -= KI[self.refY,self.refX] + JI -= JI[self.refY, self.refX] + KJ -= KJ[self.refY, self.refX] + KI -= KI[self.refY, self.refX] # Compute (abs)misclosure - netMisclosure = JI+KJ-KI + netMisclosure = JI + KJ - KI absMisclosure = np.abs(netMisclosure) # Append to stack @@ -619,102 +722,172 @@ def computeMisclosure(self,refXY=None,refLoLa=None): self.absMscStack = np.array(self.absMscStack) # Cumulative misclosure - self.cumNetMisclosure = np.sum(self.netMscStack,axis=0) - self.cumAbsMisclosure = np.sum(self.absMscStack,axis=0) + self.cumNetMisclosure = np.sum(self.netMscStack, axis=0) + self.cumAbsMisclosure = np.sum(self.absMscStack, axis=0) + # Apply mask + self.cumNetMisclosure[self.mask==0] = 0 + self.cumAbsMisclosure[self.mask==0] = 0 - ## Plot and analyze misclosure - # Plotting miscellaneous functions - def __backgroundDetect__(self,img): - ''' - Detect the background value of an image. - ''' - edges=np.concatenate([img[:,0].flatten(), - img[0,:].flatten(), - img[-1,:].flatten(), - img[:,-1].flatten()]) - backgroundValue=mode(edges)[0][0] + def __createMask__(self): + '''Create a mask based on the nodata value.''' + log.debug('Creating no data mask') - return backgroundValue + # Retrieve first image from stack + img = self.unwStack.GetRasterBand(1).ReadAsArray() - def __imgClipValues__(self,img,percentiles): - ''' - Find values at which to clip the images (min/max) based on histogram percentiles. - ''' - clipValues={} - clipValues['min'],clipValues['max']=np.percentile(img.flatten(),percentiles) + # Mask no data values + self.mask = np.ones((self.M, self.N)) + self.mask[img==0] = 0 - return clipValues + def __referencePoint__(self, refXY, refLoLa): + '''Determine the reference point in XY coordinates. The reference + point can be automatically or manually selected by the user and + is subtracted from each interferogram. - def __plotCumNetMisclosure__(self): - ''' - Plot cumulative misclosure. + The point can be given in pixels or lon/lat coordinates. If + given in Lat/Lon, determine the location in XY, and vice-versa. 
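The heart of computeMisclosure is the loop above: read the three unwrapped interferograms of a triplet, tie each to the common reference pixel, and sum them around the loop as JI + KJ - KI; consistent unwrapping would give zero everywhere. A numpy-only sketch of that arithmetic on synthetic arrays (no GDAL I/O, all values invented):

import numpy as np

rng = np.random.default_rng(0)
shape = (50, 60)
ref_y, ref_x = 10, 12

# Synthetic unwrapped phases for pairs J-I, K-J, K-I (radians)
phi_JI = rng.normal(size=shape)
phi_KJ = rng.normal(size=shape)
phi_KI = phi_JI + phi_KJ + 2*np.pi*(rng.random(size=shape) > 0.95)   # a few 2-pi errors

# Tie each interferogram to the common reference pixel
for phi in (phi_JI, phi_KJ, phi_KI):
    phi -= phi[ref_y, ref_x]

net_misclosure = phi_JI + phi_KJ - phi_KI
abs_misclosure = np.abs(net_misclosure)

print('pixels with |misclosure| > 0.1 rad:', int((abs_misclosure > 0.1).sum()))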
''' - cax = self.netMscAx.imshow(self.cumNetMsc,cmap='plasma', - vmin=self.cumNetMscClips['min'],vmax=self.cumNetMscClips['max'],zorder=1) - self.netMscAx.plot(self.refX,self.refY,'ks',zorder=2) - self.netMscAx.set_xticks([]); self.netMscAx.set_yticks([]) - self.netMscAx.set_title('Cumulative misclosure') + log.debug('Determining reference point...') - return cax + if refLoLa.count(None) == 0: + # Record reference lon/lat + self.refLon = refLoLa[0] + self.refLat = refLoLa[1] - def __plotCumAbsMisclosure__(self): - ''' - Plot cumulative absolute misclosure. - ''' - cax = self.absMscAx.imshow(self.cumAbsMsc,cmap='plasma', - vmin=self.cumAbsMscClips['min'],vmax=self.cumAbsMscClips['max'],zorder=1) - self.absMscAx.plot(self.refX,self.refY,'ks',zorder=2) - self.absMscAx.set_xticks([]); self.absMscAx.set_yticks([]) - self.absMscAx.set_title('Cumulative absolute misclosure') + # Determine the x/y coordinates from the given lon/lat + x, y = self.lola2xy(self.refLon, self.refLat) + self.refX = int(x) + self.refY = int(y) - return cax + log.debug('Reference point given as: X %s / Y %s; Lon %s / Lat %s', + self.refX, self.refY, self.refLon, self.refLat) - def __plotSeries__(self,ax,data,title): - ''' - Plot misclosure timeseries. - ''' - # Plot data - if self.plotTimeIntervals == False: - ax.plot([tripletDate[1] for tripletDate in self.tripletDates],data,'-k.') - else: - for n in range(self.nTriplets): - ax.plot([self.tripletDates[n][0],self.tripletDates[n][2]], - [data[n],data[n]],'k') - ax.plot(self.tripletDates[n][1],data[n],'ko') + elif refXY.count(None) == 0: + # Record reference x/y + self.refX = refXY[0] + self.refY = refXY[1] - # Formatting - ax.set_xticks(self.dateTicks) - ax.set_xticklabels([]) - ax.set_ylabel(title) + # Determine the lon/lat coordinates from the given x/y + self.refLon, self.refLat = self.xy2lola(self.refX, self.refY) + log.debug('Reference point given as: X %s / Y %s; Lon %.4f / Lat %.4f', + self.refX, self.refY, self.refLon, self.refLat) - # Plot misclosure - def plotCumMisclosure(self,queryXY=None,queryLoLa=None,pctmin=1,pctmax=99,plotTimeIntervals=False): + else: + # Use a random reference point + self.__autoReferencePoint__() + + def __autoReferencePoint__(self): + '''Use the coherence stack to automatically determine a suitable + reference point. ''' - Map-view plot of cumulative misclosure. + # Try to determine the reference point using coherence map + if self.__autoReferenceCoherence__() == False: + # Otherwise, resort to picking a random point + self.__autoReferenceRandom__() + + def __autoReferenceCoherence__(self): + '''Attempt to find cohStack.vrt file based on unwStack directory. + If found, determine a random high-coherence point. + If file not found, return False. 
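__autoReferenceCoherence__ below keeps drawing random pixels until one clears the coherence threshold. An equivalent, loop-free alternative is to list the qualifying pixels first and sample one of them; a sketch assuming a mean-coherence array is already in memory (the 0.7 threshold mirrors the one used in this module, the helper name is invented):

import numpy as np

def pick_reference(mean_coh, threshold=0.7, seed=None):
    """Return (y, x) of a random pixel with mean coherence >= threshold,
    or None if no pixel qualifies."""
    candidates = np.argwhere(mean_coh >= threshold)
    if candidates.size == 0:
        return None
    rng = np.random.default_rng(seed)
    y, x = candidates[rng.integers(len(candidates))]
    return int(y), int(x)

mean_coh = np.random.default_rng(1).random((100, 120))
print(pick_reference(mean_coh, threshold=0.9, seed=2))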
''' + log.debug('Attempting to find high-coherence reference point') + + # Check if coherence data have already been loaded + if not hasattr(self, 'meanCoh'): + # Check unwStack folder for cohStack + dirName = os.path.dirname(self.unwFile) + + try: + # Standard coherence file name + cohFile = os.path.join(dirName, 'cohStack.vrt') + + # Load average coherence map + self.__loadCohStack__(cohFile) + + except: + log.debug('Could not automatically find coherence stack') + return False + + # Randomly sample for high-coherence values + nTries = 10000 + + while nTries > 0: + # Pick random points + x = np.random.randint(0, self.N) + y = np.random.randint(0, self.M) + + # Check coherence at those points + if self.meanCoh[y,x] >= 0.7: + # Point is high-coherence, stop trying + break + + # Decrement counter + nTries -= 1 + + # Assign reference x/y + self.refX, self.refY = x, y + + # Convert to lon/lat + self.refLon, self.refLat = self.xy2lola(self.refX, self.refY) + + log.debug('Reference point chosen based on coherence as: X %s / Y %s; Lon %.4f / Lat %.4f.', + self.refX, self.refY, self.refLon, self.refLat) + + return True + + def __autoReferenceRandom__(self): + '''Choose random, non-masked point for reference.''' + log.debug('Choosing random reference point') + + # Number of tries + nTries = 100 + + while nTries > 0: + # Pick random points + x = np.random.randint(0, self.N) + y = np.random.randint(0, self.M) + + # Check if that point is masked + if self.mask[y,x] == 1: + # Point is not masked, stop trying + break + + # Decrement counter + nTries -= 1 + + # Assign reference x/y + self.refX, self.refY = x, y + + # Convert to lon/lat + self.refLon, self.refLat = self.xy2lola(self.refX, self.refY) + + log.debug('Reference point chosen randomly as: X %s / Y %s; Lon %.4f / Lat %.4f.', + self.refX, self.refY, self.refLon, self.refLat) + + + ## Plot misclosure + def plotMisclosure(self, queryXY=None, queryLoLa=None, + plotTimeIntervals=False): + '''Map-view plots of cumulative misclosure.''' log.debug('Begin misclosure analysis') # Parameters self.plotTimeIntervals = plotTimeIntervals # Set up interactive plots - self.netMscFig = plt.figure() - self.netMscAx = self.netMscFig.add_subplot(111) - - self.absMscFig = plt.figure() - self.absMscAx = self.absMscFig.add_subplot(111) + self.netMscFig, self.netMscAx = plt.subplots() + self.absMscFig, self.absMscAx = plt.subplots() - # Auto-detect background and clip values - self.cumNetMscBackground = self.__backgroundDetect__(self.cumNetMisclosure) - self.cumNetMsc = np.ma.array(self.cumNetMisclosure, - mask=(self.cumNetMisclosure==self.cumNetMscBackground)) - self.cumNetMscClips = self.__imgClipValues__(self.cumNetMsc,percentiles=[pctmin,pctmax]) + # Mask arrays to ignore no data values + self.cumNetMisclosure = np.ma.array(self.cumNetMisclosure, + mask=(self.cumNetMisclosure==0)) + self.cumAbsMisclosure = np.ma.array(self.cumAbsMisclosure, + mask=(self.cumAbsMisclosure==0)) - self.cumAbsMscBackground = self.__backgroundDetect__(self.cumAbsMisclosure) - self.cumAbsMsc = np.ma.array(self.cumAbsMisclosure, - mask=(self.cumAbsMisclosure==self.cumAbsMscBackground)) - self.cumAbsMscClips = self.__imgClipValues__(self.cumAbsMsc,percentiles=[pctmin,pctmax]) + # Auto-detect values for clipping color scale + self.cumNetMscClips = self.__imgClipValues__(self.cumNetMisclosure) + self.cumAbsMscClips = self.__imgClipValues__(self.cumAbsMisclosure) # Plot maps cax = self.__plotCumNetMisclosure__() @@ -726,113 +899,194 @@ def 
plotCumMisclosure(self,queryXY=None,queryLoLa=None,pctmin=1,pctmax=99,plotTi cbar.set_label('cum. abs. misclosure (radians)') # Plot timeseries points - self.mscSeriesFig = plt.figure('Misclosure',figsize=(8,8)) + self.mscSeriesFig = plt.figure('Misclosure', figsize=(8,8)) self.netMscSeriesAx = self.mscSeriesFig.add_subplot(411) self.cumNetMscSeriesAx = self.mscSeriesFig.add_subplot(412) self.absMscSeriesAx = self.mscSeriesFig.add_subplot(413) self.cumAbsMscSeriesAx = self.mscSeriesFig.add_subplot(414) # Pre-specified query points - self.__misclosureQuery__(queryXY,queryLoLa) + if (queryLoLa.count(None) == 0) or (queryXY.count(None) == 0): + self.__misclosureQuery__(queryXY, queryLoLa) # Link canvas to plots for interaction - self.netMscFig.canvas.mpl_connect('button_press_event', self.__misclosureAnalysis__) - self.absMscFig.canvas.mpl_connect('button_press_event', self.__misclosureAnalysis__) + self.netMscFig.canvas.mpl_connect('button_press_event', + self.__samplePixel__) + self.absMscFig.canvas.mpl_connect('button_press_event', + self.__samplePixel__) + + def __imgClipValues__(self, img): + '''Find values at which to clip the images (min/max) based on + histogram percentiles. + ''' + # Determine clip values based on percentiles + clipValues={} + clipValues['min'], clipValues['max'] = \ + np.percentile(img.compressed(), (2, 98)) + return clipValues - ## Misclosure analysis - def __misclosureAnalysis__(self,event): - ''' - Show the time history of each pixel based on interactive map. - ''' - px=event.xdata; py=event.ydata - px=int(round(px)); py=int(round(py)) + def __plotCumNetMisclosure__(self): + '''Plot cumulative misclosure map.''' + # Plot map + cax = self.netMscAx.imshow(self.cumNetMisclosure, cmap='plasma', + vmin=self.cumNetMscClips['min'], vmax=self.cumNetMscClips['max'], + zorder=1) - # Report position and cumulative values - log.info('px %s py %s', px, py) # report position - log.info('Cumulative misclosure: %s', self.cumNetMisclosure[py,px]) - log.info('Abs cumulative misclosure: %s', self.cumAbsMisclosure[py,px]) + # Plot reference point + self.netMscAx.plot(self.refX, self.refY, 'ks', zorder=2) - # Plot query points on maps - self.netMscAx.cla() - self.__plotCumNetMisclosure__() - self.netMscAx.plot(px,py,color='k',marker='o',markerfacecolor='w',zorder=3) + # Format axis + self.netMscAx.set_title('Cumulative misclosure') - self.absMscAx.cla() - self.__plotCumAbsMisclosure__() - self.absMscAx.plot(px,py,color='k',marker='o',markerfacecolor='w',zorder=3) + return cax - # Plot misclosure over time - log.info('Misclosure: %s', self.netMscStack[:,py,px]) - self.netMscSeriesAx.cla() # misclosure - self.__plotSeries__(self.netMscSeriesAx, self.netMscStack[:,py,px], 'misclosure') + def __plotCumAbsMisclosure__(self): + '''Plot cumulative absolute misclosure map.''' + # Plot map + cax = self.absMscAx.imshow(self.cumAbsMisclosure, cmap='plasma', + vmin=self.cumAbsMscClips['min'], vmax=self.cumAbsMscClips['max'], + zorder=1) - self.cumNetMscSeriesAx.cla() # cumulative misclosure - self.__plotSeries__(self.cumNetMscSeriesAx, np.cumsum(self.netMscStack[:,py,px]), 'cum. miscl.') + # Plot reference point + self.absMscAx.plot(self.refX, self.refY, 'ks', zorder=2) - self.absMscSeriesAx.cla() # absolute misclosure - self.__plotSeries__(self.absMscSeriesAx, self.absMscStack[:,py,px], 'abs. 
miscl') + # Format axis + self.absMscAx.set_title('Cumulative absolute misclosure') - self.cumAbsMscSeriesAx.cla() # cumulative absolute misclosure - self.__plotSeries__(self.cumAbsMscSeriesAx, np.cumsum(self.absMscStack[:,py,px]), - 'cum. abs. miscl.') + return cax - # Format x-axis - dates = pd.date_range(self.startDate-timedelta(days=30), - self.endDate+timedelta(days=30),freq='MS') - dateLabels = [date.strftime('%Y-%m') for date in dates] - self.cumAbsMscSeriesAx.set_xticklabels(self.dateTickLabels, rotation=90) + def __plotSeries__(self, ax, data, title): + '''Plot misclosure timeseries.''' + # Plot data + if self.plotTimeIntervals == False: + ax.plot([tripletRefDate[1] for tripletRefDate in self.tripletRefDates], + data, '-k.') + else: + for n in range(self.nTriplets): + ax.plot([self.tripletRefDates[n][0], + self.tripletRefDates[n][2]], + [data[n], data[n]],'k') + ax.plot(self.tripletRefDates[n][1], data[n], 'ko') - # Draw outcomes - self.netMscFig.canvas.draw() - self.absMscFig.canvas.draw() - self.mscSeriesFig.canvas.draw() + # Formatting + ax.set_xticks(self.dateTicks) + ax.set_xticklabels([]) + ax.set_ylabel(title) - ## Misclosure query - def __misclosureQuery__(self,queryXY=None,queryLoLa=None): - ''' - Show the time history of each pixel based on pre-specified selection. + ## Misclosure analysis + def __misclosureQuery__(self, queryXY=None, queryLoLa=None): + '''Show the time history of each pixel based on pre-specified + selection. ''' log.debug('Pre-specified query point...') # Convert bewteen lon/lat and image coordinates if queryLoLa.count(None) == 0: # Determine the XY coordinates from the given lon/lat - qLon,qLat = queryLoLa - qx,qy = self.LoLa2XY(queryLoLa[0],queryLoLa[1]) - qx = int(qx); qy = int(qy) # convert to integer values + qLon, qLat = queryLoLa + qx, qy = self.lola2xy(queryLoLa[0], queryLoLa[1]) elif queryXY.count(None) == 0: # Use the provided XY coordinates - qx,qy = queryXY - qLon,qLat = self.XY2LoLa(queryXY[0],queryXY[1]) + qx, qy = queryXY + qLon, qLat = self.xy2lola(queryXY[0], queryXY[1]) - log.debug('Query point: X %s / Y %s; Lon %.4f / Lat %.4f', qx, qy, qLon, qLat) + log.debug('Query point: X %s / Y %s; Lon %.4f / Lat %.4f', + qx, qy, qLon, qLat) # Plot query points on map - self.netMscAx.plot(qx,qy,color='k',marker='o',markerfacecolor='w',zorder=3) - self.absMscAx.plot(qx,qy,color='k',marker='o',markerfacecolor='w',zorder=3) + self.netMscAx.plot(qx, qy, + color='k', marker='o', markerfacecolor='w', zorder=3) + self.absMscAx.plot(qx, qy, + color='k', marker='o', markerfacecolor='w', zorder=3) # Plot misclosure over time - self.__plotSeries__(self.netMscSeriesAx, self.netMscStack[:,qy,qx], 'misclosure') - self.__plotSeries__(self.cumNetMscSeriesAx, np.cumsum(self.netMscStack[:,qy,qx]), 'cum. miscl.') - self.__plotSeries__(self.absMscSeriesAx, self.absMscStack[:,qy,qx], 'abs. miscl') - self.__plotSeries__(self.cumAbsMscSeriesAx, np.cumsum(self.absMscStack[:,qy,qx]), - 'cum. abs. miscl.') + self.__plotSeries__(self.netMscSeriesAx, + self.netMscStack[:,qy,qx], + 'misclosure') + self.__plotSeries__(self.cumNetMscSeriesAx, + np.cumsum(self.netMscStack[:,qy,qx]), + 'cum. miscl.') + self.__plotSeries__(self.absMscSeriesAx, + self.absMscStack[:,qy,qx], + 'abs. miscl') + self.__plotSeries__(self.cumAbsMscSeriesAx, + np.cumsum(self.absMscStack[:,qy,qx]), + 'cum. abs. 
miscl.') # Format x-axis - dates = pd.date_range(self.startDate-timedelta(days=30), - self.endDate+timedelta(days=30),freq='MS') - dateLabels = [date.strftime('%Y-%m') for date in dates] - self.cumAbsMscSeriesAx.set_xticklabels(self.dateTickLabels, rotation=90) + self.cumAbsMscSeriesAx.set_xticks(self.dateTicks) + self.cumAbsMscSeriesAx.set_xticklabels(self.dateTickLabels, + rotation=90) + + def __samplePixel__(self, event): + '''Show the time history of each pixel based on interactive map.''' + log.debug('Sampling point') + + # Retrieve values from map + px = event.xdata + py = event.ydata + px = int(round(px)) + py = int(round(py)) + + # Convert pixels to lat/lon + lon, lat = self.xy2lola(px, py) + + # Report position and cumulative values + log.info('px %s py %s', px, py) + log.info('lon %s lat %s', lon, lat) + log.info('Cumulative misclosure: %s', self.cumNetMisclosure[py,px]) + log.info('Abs cumulative misclosure: %s', self.cumAbsMisclosure[py,px]) + + # Plot query points on maps + self.netMscAx.cla() + self.__plotCumNetMisclosure__() + self.netMscAx.plot(px, py, + color='k', marker='o', markerfacecolor='w', zorder=3) + + self.absMscAx.cla() + self.__plotCumAbsMisclosure__() + self.absMscAx.plot(px, py, + color='k', marker='o', markerfacecolor='w', zorder=3) + + # Plot misclosure over time + log.info('Misclosure: %s', self.netMscStack[:,py,px]) + self.netMscSeriesAx.cla() # misclosure + self.__plotSeries__(self.netMscSeriesAx, + self.netMscStack[:,py,px], + 'misclosure') + + self.cumNetMscSeriesAx.cla() # cumulative misclosure + self.__plotSeries__(self.cumNetMscSeriesAx, + np.cumsum(self.netMscStack[:,py,px]), + 'cum. miscl.') + + self.absMscSeriesAx.cla() # absolute misclosure + self.__plotSeries__(self.absMscSeriesAx, + self.absMscStack[:,py,px], + 'abs. miscl') + + self.cumAbsMscSeriesAx.cla() # cumulative absolute misclosure + self.__plotSeries__(self.cumAbsMscSeriesAx, + np.cumsum(self.absMscStack[:,py,px]), + 'cum. abs. miscl.') + + # Format x-axis + self.cumAbsMscSeriesAx.set_xticks(self.dateTicks) + self.cumAbsMscSeriesAx.set_xticklabels(self.dateTickLabels, + rotation=90) + + # Draw outcomes + self.netMscFig.canvas.draw() + self.absMscFig.canvas.draw() + self.mscSeriesFig.canvas.draw() ## Plot triplet misclosure maps - def plotTripletMaps(self,pctmin=1,pctmax=99): - ''' - Plot the misclosure measurements for each triplet to figures. 
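Both the cumulative maps and the per-triplet maps below clip their color scales to interior percentiles of the valid pixels so that a handful of outliers does not wash out the display. A tiny sketch of that clipping on a masked array, using the same 2nd/98th percentile bounds as __imgClipValues__ (the data and mask are synthetic):

import numpy as np

data = np.random.default_rng(3).normal(size=(80, 90))
masked = np.ma.array(data, mask=(np.abs(data) < 0.05))   # arbitrary mask for illustration

vmin, vmax = np.percentile(masked.compressed(), (2, 98))
print('clip color scale to [%.2f, %.2f]' % (vmin, vmax))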
- ''' + def plotTripletMaps(self): + '''Plot the misclosure measurements for each triplet to figures.''' log.debug('Saving incremental misclosure maps to image files') # Parameters @@ -840,84 +1094,83 @@ def plotTripletMaps(self,pctmin=1,pctmax=99): maxSubplots = subplotDims[0]*subplotDims[1] # Output directory/subdirectory - self.figdir = os.path.join(self.workdir,'MisclosureFigs') + self.figdir = os.path.join(self.workdir, 'MisclosureFigs') try: os.mkdir(self.figdir) except: pass - figNb = 0 # start figure counter - plotNb = 1 # start subplot counter + figNb = 0 # start figure counter + plotNb = 1 # start subplot counter for i in range(self.nTriplets): # Plot number and subplot position if plotNb % maxSubplots == 1: # Spawn new figure - figNb += 1 # update figure counter - plotNb = 1 # reset subplot counter + figNb += 1 # update figure counter + plotNb = 1 # reset subplot counter Fig = plt.figure(figsize=(8,6)) - ax = Fig.add_subplot(subplotDims[0],subplotDims[1],plotNb) + ax = Fig.add_subplot(subplotDims[0], subplotDims[1], plotNb) # Format misclosure map - mscMapBackground = self.__backgroundDetect__(self.netMscStack[i,:,:]) - mscMap = np.ma.array(self.netMscStack[i,:,:],mask=(self.netMscStack[i,:,:]==mscMapBackground)) - mscMapClips = self.__imgClipValues__(mscMap,[pctmin,pctmax]) + mscMap = np.ma.array(self.netMscStack[i,:,:], mask=(self.mask==0)) + mscMapClips = self.__imgClipValues__(mscMap) # Plot misclosure - cax = ax.imshow(mscMap,cmap='plasma',vmin=mscMapClips['min'],vmax=mscMapClips['max']) + cax = ax.imshow(mscMap, cmap='plasma', + vmin=mscMapClips['min'], vmax=mscMapClips['max']) # Format axis - ax.set_xticks([]); ax.set_yticks([]) + ax.set_xticks([]) + ax.set_yticks([]) Fig.colorbar(cax, orientation='horizontal') - dates = [date.strftime('%Y%m%d') for date in self.tripletDates[i]] + dates = [date.strftime('%Y%m%d') for date in self.tripletRefDates[i]] ax.set_title('_'.join(dates)) if plotNb == maxSubplots: # Save old figure Fig.suptitle('Phase misclosure (radians)') - figname = 'MisclosureValues_fig{}.png'.format(figNb) - figpath = os.path.join(self.figdir,figname) - Fig.savefig(figpath,dpi=300) + figname = 'MisclosureValues_fig{:d}.png'.format(figNb) + figpath = os.path.join(self.figdir, figname) + Fig.savefig(figpath, dpi=300) plotNb += 1 # Save final figure Fig.suptitle('Phase misclosure (radians)') - figname = 'MisclosureValues_fig{}.png'.format(figNb) - figpath = os.path.join(self.figdir,figname) - Fig.savefig(figpath,dpi=300) + figname = 'MisclosureValues_fig{:d}.png'.format(figNb) + figpath = os.path.join(self.figdir, figname) + Fig.savefig(figpath, dpi=300) ## Save cumulative misclosure plots to geotiffs def saveCumMisclosure(self): - ''' - Save cumulative (/absolute) misclosure plots to georeferenced tiff files. Use metadata - from unwrapStack.vrt file. + '''Save cumulative (/absolute) misclosure plots to georeferenced + tiff files. Use metadata from unwrapStack.vrt file. 
''' log.debug('Saving misclosure maps to geotiffs') # Fix background values - self.cumNetMisclosure[self.cumNetMisclosure==self.cumNetMscBackground] = 0 - self.cumAbsMisclosure[self.cumAbsMisclosure==self.cumAbsMscBackground] = 0 + self.cumNetMisclosure[self.mask==0] = 0 + self.cumAbsMisclosure[self.mask==0] = 0 # Save cumulative misclosure - cumNetMscSavename = os.path.join(self.figdir,'CumulativeMisclosure.tif') - self.__saveGeoTiff__(cumNetMscSavename,self.cumNetMisclosure) + cumNetMscSavename = os.path.join(self.figdir, 'CumulativeMisclosure.tif') + self.__saveGeoTiff__(cumNetMscSavename, self.cumNetMisclosure) # Save cumulative absolute misclosure - cumAbsMscSavename = os.path.join(self.figdir,'CumulativeAbsoluteMisclosure.tif') - self.__saveGeoTiff__(cumAbsMscSavename,self.cumAbsMisclosure) - - def __saveGeoTiff__(self,savename,img): - ''' - Template for saving geotiffs. - ''' - driver=gdal.GetDriverByName('GTiff') - DSout=driver.Create(savename,self.IFGs.RasterXSize,self.IFGs.RasterYSize,1, - gdal.GDT_Float32) + cumAbsMscSavename = os.path.join(self.figdir, 'CumulativeAbsoluteMisclosure.tif') + self.__saveGeoTiff__(cumAbsMscSavename, self.cumAbsMisclosure) + + def __saveGeoTiff__(self, savename, img): + '''Template for saving geotiffs.''' + driver = gdal.GetDriverByName('GTiff') + DSout = driver.Create(savename, + self.N, self.M, + 1, gdal.GDT_Float32) DSout.GetRasterBand(1).WriteArray(img) DSout.GetRasterBand(1).SetNoDataValue(0) - DSout.SetProjection(self.IFGs.GetProjection()) - DSout.SetGeoTransform(self.IFGs.GetGeoTransform()) + DSout.SetProjection(self.unwStack.GetProjection()) + DSout.SetGeoTransform(self.unwStack.GetGeoTransform()) DSout.FlushCache() @@ -925,15 +1178,17 @@ def __saveGeoTiff__(self,savename,img): ### MAIN CALL --- def main(inps=None): ## Gather arguments - inps=cmdLineParse() + inps = cmdLineParse() ## Load data based on data type - dataStack=stack(imgfile=inps.imgfile, - workdir=inps.workdir, - startDate=inps.startDate, endDate=inps.endDate, - excludePairs=inps.excludePairs, - verbose=inps.verbose) + dataStack = MisclosureStack( + unwFile=inps.unwFile, + cohFile=inps.cohFile, + workdir=inps.workdir, + startDate=inps.startDate, endDate=inps.endDate, + excludePairs=inps.excludePairs, + verbose=inps.verbose) # Plot pairs if requested if inps.plotPairs == True: @@ -941,7 +1196,7 @@ def main(inps=None): ## Create list of triplets - dataStack.createTriplets(minTime=inps.minTime,maxTime=inps.maxTime, + dataStack.createTriplets(minTime=inps.minTime, maxTime=inps.maxTime, printTriplets=inps.printTriplets) # Plot triplets if requested @@ -954,14 +1209,14 @@ def main(inps=None): refLoLa=[inps.refLon, inps.refLat]) # Plot and analyze data - dataStack.plotCumMisclosure(queryXY=[inps.queryX,inps.queryY], - queryLoLa=[inps.queryLon,inps.queryLat], - pctmin=inps.pctMinClip,pctmax=inps.pctMaxClip, + dataStack.plotMisclosure( + queryXY=[inps.queryX, inps.queryY], + queryLoLa=[inps.queryLon, inps.queryLat], plotTimeIntervals=inps.plotTimeIntervals) plt.show() # Save misclosure map for each triplet to figures - dataStack.plotTripletMaps(pctmin=inps.pctMinClip,pctmax=inps.pctMaxClip) + dataStack.plotTripletMaps() # Save misclosure maps to geotiffs dataStack.saveCumMisclosure() diff --git a/tools/ARIAtools/extractProduct.py b/tools/ARIAtools/extractProduct.py old mode 100755 new mode 100644 index c323cde2..25b9973b --- a/tools/ARIAtools/extractProduct.py +++ b/tools/ARIAtools/extractProduct.py @@ -12,20 +12,33 @@ from osgeo import gdal, osr import logging import requests - from 
ARIAtools.logger import logger + from ARIAtools.shapefile_util import open_shapefile, chunk_area from ARIAtools.mask_util import prep_mask from ARIAtools.unwrapStitching import product_stitch_overlap, product_stitch_2stage -gdal.UseExceptions() +import ARIAtools.unwrapStitching +print(ARIAtools.unwrapStitching.__file__) -# Suppress warnings +gdal.UseExceptions() +#Suppress warnings gdal.PushErrorHandler('CPLQuietErrorHandler') log = logging.getLogger(__name__) -_world_dem = "https://portal.opentopography.org/API/globaldem?demtype=SRTMGL1_E&west={}&south={}&east={}&north={}&outputFormat=GTiff" +## Set DEM path +_world_dem = "https://portal.opentopography.org/API/globaldem?demtype="\ + "SRTMGL1_E&west={}&south={}&east={}&north={}&outputFormat=GTiff" +dot_topo = os.path.expanduser('~/.topoapi') +if os.path.exists(dot_topo): + topapi = '&API_Key=' + with open(dot_topo) as f: + topapi = topapi + f.readlines()[0].split('\n')[0] + _world_dem = _world_dem + topapi +else: # your .topoapi does not exist + raise ValueError('Add your Open Topo API key to `~/.topoapi`.' + 'Refer to ARIAtools installation instructions.') def createParser(): """Extract specified product layers. The default will export all layers.""" @@ -108,7 +121,7 @@ def __call__(self, line, pix, h): class metadata_qualitycheck: """Metadata quality control function. - + Artifacts recognized based off of covariance of cross-profiles. Bug-fix varies based off of layer of interest. Verbose mode generates a series of quality control plots with @@ -134,7 +147,7 @@ def __init__(self, data_array, prod_key, outname, verbose=None): def __truncateArray__(self, data_array_band, Xmask, Ymask): # Mask columns/rows which are entirely made up of 0s - # First must crop all columns with no valid values + #first must crop all columns with no valid values nancols=np.all(data_array_band.mask == True, axis=0) data_array_band=data_array_band[:,~nancols] Xmask=Xmask[:,~nancols] @@ -156,10 +169,10 @@ def __getCovar__(self, prof_direc, profprefix=''): self.data_array_band, Xmask, Ymask = self.__truncateArray__( self.data_array_band, Xmask, Ymask) - # Append prefix for plot names + #append prefix for plot names prof_direc = profprefix + prof_direc - # Iterate through transpose of matrix if looking in azimuth + #iterate through transpose of matrix if looking in azimuth arrT='' if 'azimuth' in prof_direc: arrT='.T' @@ -169,36 +182,35 @@ def __getCovar__(self, prof_direc, profprefix=''): for i in enumerate(eval('self.data_array_band%s'%(arrT))): mid_line=i[1] xarr=np.array(range(len(mid_line))) - # Remove masked values from slice + #remove masked values from slice if mid_line.mask.size!=1: if True in mid_line.mask: xarr=xarr[~mid_line.mask] mid_line=mid_line[~mid_line.mask] - # Chunk array to better isolate artifacts + #chunk array to better isolate artifacts chunk_size= 4 for j in range(0, len(mid_line.tolist()), chunk_size): chunk = mid_line.tolist()[j:j+chunk_size] xarr_chunk = xarr[j:j+chunk_size] - # Make sure each iteration contains at least minimum number of elements + # make sure each iteration contains at least minimum number of elements if j==range(0, len(mid_line.tolist()), chunk_size)[-2] and \ len(mid_line.tolist()) % chunk_size != 0: chunk = mid_line.tolist()[j:] xarr_chunk = xarr[j:] - # Linear regression and get covariance + #linear regression and get covariance slope, bias, rsquared, p_value, std_err = linregress(xarr_chunk,chunk) rsquaredarr.append(abs(rsquared)**2) std_errarr.append(std_err) - # Terminate early if last iteration would have small 
chunk size + #terminate early if last iteration would have small chunk size if len(chunk)>chunk_size: break - # Exit loop/make plots in verbose mode if R^2 and standard error - # anomalous, or if on last iteration + #exit loop/make plots in verbose mode if R^2 and standard error anomalous, or if on last iteration if (min(rsquaredarr) < 0.9 and max(std_errarr) > 0.01) or \ (i[0]==(len(eval('self.data_array_band%s'%(arrT)))-1)): if self.verbose: - # Make quality-control plots + #Make quality-control plots import matplotlib.pyplot as plt ax0=plt.figure().add_subplot(111) ax0.scatter(xarr, mid_line, c='k', s=7) @@ -286,13 +298,13 @@ def __run__(self): self.data_array_band, Xmask, Ymask) # truncated grid covering the domain of the data - Xmask = Xmask[~self.data_array_band.mask] - Ymask = Ymask[~self.data_array_band.mask] + Xmask=Xmask[~self.data_array_band.mask] + Ymask=Ymask[~self.data_array_band.mask] self.data_array_band = self.data_array_band[~self.data_array_band.mask] XX = Xmask.flatten() YY = Ymask.flatten() A = np.c_[XX, YY, np.ones(len(XX))] - C, _, _, _ = lstsq(A, self.data_array_band.data.flatten()) + C,_,_,_ = lstsq(A, self.data_array_band.data.flatten()) # evaluate it on grid self.data_array_band = C[0]*X + C[1]*Y + C[2] #mask by nodata value @@ -305,7 +317,7 @@ def __run__(self): #update band self.data_array.GetRasterBand(i).WriteArray(self.data_array_band.filled()) # Pass warning and get R^2/standard error across range/azimuth (only do for first band) - if i == 1: + if i==1: # make sure appropriate unit is passed to print statement lyrunit = "\N{DEGREE SIGN}" if self.prod_key=='bPerpendicular' or self.prod_key=='bParallel': @@ -443,9 +455,9 @@ def merged_productbbox(metadata_dict, product_dict, workdir='./', bbox_file=None # If specified, check if user's bounding box meets minimum threshold area if bbox_file is not None: - user_bbox = open_shapefile(bbox_file, 0, 0) - overlap_area = shapefile_area(user_bbox) - if overlap_area < minimumOverlap: + user_bbox=open_shapefile(bbox_file, 0, 0) + overlap_area=shapefile_area(user_bbox) + if overlap_area 1: - # Initiate the residuals and design matrix - residualcycles = np.zeros((self.nfiles-1, 1)) - residualrange = np.zeros((self.nfiles-1, 1)) - A = np.zeros((self.nfiles-1, self.nfiles)) - - # The files are already sorted in the ARIAproduct class, will - # make consecutive overlaps between these sorted products + '''Function that will calculate the number of cycles each component needs to be shifted in order to minimize the two-pi modulu residual between a neighboring component. Outputs a fileMappingDict with as key a file number. Within fileMappingDict with a integer phase shift value for each unique connected component. 
+ ''' + + # only need to comptue the minimize the phase offset if the number of files is larger than 2 + if self.nfiles>1: + + # initiate the residuals and design matrix + residualcycles = np.zeros((self.nfiles-1,1)) + residualrange = np.zeros((self.nfiles-1,1)) + A = np.zeros((self.nfiles-1,self.nfiles)) + + # the files are already sorted in the ARIAproduct class, will make consecutive overlaps between these sorted products for counter in range(self.nfiles-1): - # Getting the two neighboring frames + # getting the two neighboring frames bbox_frame1 = self.prodbbox[counter] bbox_frame2 = self.prodbbox[counter+1] - # Determine the intersection between the two frames + # determining the intersection between the two frames if not bbox_frame1.intersects(bbox_frame2): - log.error("Products do not overlap or were not " \ - "provided in a contigious sorted list.") + log.error("Products do not overlap or were not provided in a contigious sorted list.") raise Exception polyOverlap = bbox_frame1.intersection(bbox_frame2) - # Will save the geojson under a temp local filename - # Do this just to get the file outname - tmfile = tempfile.NamedTemporaryFile(mode='w+b', suffix='.json', - prefix='Overlap_', dir='.') + # will save the geojson under a temp local filename + tmfile = tempfile.NamedTemporaryFile(mode='w+b',suffix='.json', prefix='Overlap_', dir='.') outname = tmfile.name - - # Remove it as GDAL polygonize function cannot overwrite files + # will remove it as GDAL polygonize function cannot overwrite files tmfile.close() tmfile = None - - # Save the temp geojson + # saving the temp geojson save_shapefile(outname, polyOverlap, 'GeoJSON') - # Calculate the mean of the phase for each product in - # the overlap region alone. - # Will first attempt to mask out connected component 0, - # and default to complete overlap if this fails. - # Cropping the unwrapped phase and connected component - # to the overlap region alone, inhereting the no-data. 
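The overlap handling above intersects the footprints of two neighboring products, writes the intersection to a throwaway GeoJSON, and crops each raster to it in memory. The sketch below mirrors that pattern with shapely boxes standing in for the product bounding boxes and a placeholder raster path; it assumes save_shapefile is importable from ARIAtools.shapefile_util, as used in this module:

import tempfile
from osgeo import gdal
from shapely.geometry import box
from ARIAtools.shapefile_util import save_shapefile

# Two hypothetical frame footprints (lon/lat boxes) that partially overlap
frame1 = box(-118.0, 34.0, -117.0, 35.2)
frame2 = box(-118.0, 35.0, -117.0, 36.2)
overlap = frame1.intersection(frame2)

# Temporary GeoJSON used only as a warp cutline (the closed name is reused,
# matching the pattern in __calculateCyclesOverlap__)
tmp = tempfile.NamedTemporaryFile(mode='w+b', suffix='.json',
                                  prefix='Overlap_', dir='.')
cutline = tmp.name
tmp.close()
save_shapefile(cutline, overlap, 'GeoJSON')

# Crop a product raster (placeholder path) to the overlap footprint, in memory
ds = gdal.Warp('', 'unwrappedPhase_frame1.vrt',
               options=gdal.WarpOptions(format='MEM',
                                        cutlineDSName=cutline,
                                        outputBounds=overlap.bounds,
                                        dstNodata=0))
overlap_phase = ds.GetRasterBand(1).ReadAsArray()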
- - # Connected component - out_data, connCompNoData1, geoTrans, proj = GDALread( - self.ccFile[counter], data_band=1, loadData=False) - out_data, connCompNoData2, geoTrans, proj = GDALread( - self.ccFile[counter+1], data_band=1, loadData=False) - - connCompFile1 = gdal.Warp('', self.ccFile[counter], - options=gdal.WarpOptions(format="MEM", - cutlineDSName=outname, - dstAlpha=True, - outputBounds=polyOverlap.bounds, - dstNodata=connCompNoData1)) - connCompFile2 = gdal.Warp('', self.ccFile[counter+1], - options=gdal.WarpOptions(format="MEM", - cutlineDSName=outname, - dstAlpha=True, - outputBounds=polyOverlap.bounds, - dstNodata=connCompNoData2)) - - # Reformat output bounds for GDAL translate - ulx, lry, lrx, uly = polyOverlap.bounds - projWin = (ulx, uly, lrx, lry) - - # Unwrapped phase - out_data, unwNoData1, geoTrans, proj = GDALread( - self.inpFile[counter], data_band=1, loadData=False) - out_data, unwNoData2, geoTrans, proj = GDALread( - self.inpFile[counter+1], data_band=1, loadData=False) - - unwFile1 = gdal.Translate('', self.inpFile[counter], - options=gdal.TranslateOptions(format="MEM", - projWin=projWin, - noData=unwNoData1)) - unwFile2 = gdal.Translate('', self.inpFile[counter+1], - options=gdal.TranslateOptions(format="MEM", - projWin=projWin, - noData=unwNoData2)) - - # Find the component with the largest overlap - connCompData1 = connCompFile1.GetRasterBand(1).ReadAsArray() - connCompData1[((connCompData1==connCompNoData1) - | (connCompData1==0))] = np.nan - connCompData2 = connCompFile2.GetRasterBand(1).ReadAsArray() - connCompData2[((connCompData2==connCompNoData2) - | (connCompData2==0))] = np.nan + # calculate the mean of the phase for each product in the overlap region alone + # will first attempt to mask out connected component 0, and default to complete overlap if this fails. + # Cropping the unwrapped phase and connected component to the overlap region alone, inhereting the no-data. 
+ # connected component + out_data,connCompNoData1,geoTrans,proj = GDALread(self.ccFile[counter],data_band=1,loadData=False) + out_data,connCompNoData2,geoTrans,proj = GDALread(self.ccFile[counter+1],data_band=1,loadData=False) + connCompFile1 = gdal.Warp('', self.ccFile[counter], options=gdal.WarpOptions(format="MEM", cutlineDSName=outname, outputBounds=polyOverlap.bounds, dstNodata=connCompNoData1)) + connCompFile2 = gdal.Warp('', self.ccFile[counter+1], options=gdal.WarpOptions(format="MEM", cutlineDSName=outname, outputBounds=polyOverlap.bounds, dstNodata=connCompNoData2)) + + + # unwrapped phase + out_data,unwNoData1,geoTrans,proj = GDALread(self.inpFile[counter],data_band=1,loadData=False) + out_data,unwNoData2,geoTrans,proj = GDALread(self.inpFile[counter+1],data_band=1,loadData=False) + unwFile1 = gdal.Warp('', self.inpFile[counter], options=gdal.WarpOptions(format="MEM", cutlineDSName=outname, outputBounds=polyOverlap.bounds, dstNodata=unwNoData1)) + unwFile2 = gdal.Warp('', self.inpFile[counter+1], options=gdal.WarpOptions(format="MEM", cutlineDSName=outname, outputBounds=polyOverlap.bounds, dstNodata=unwNoData2)) + + + # finding the component with the largest overlap + connCompData1 =connCompFile1.GetRasterBand(1).ReadAsArray() + connCompData1[(connCompData1==connCompNoData1) | (connCompData1==0)]=np.nan + connCompData2 =connCompFile2.GetRasterBand(1).ReadAsArray() + connCompData2[(connCompData2==connCompNoData2) | (connCompData2==0)]=np.nan connCompData2_temp = (connCompData2*100) - connCompDiff = (connCompData2_temp.astype(np.int) - - connCompData1.astype(np.int)) - connCompDiff[(connCompDiff<0) | (connCompDiff>2000)] = 0 - temp_count = collections.Counter(connCompDiff.flatten()) + temp = connCompData2_temp.astype(np.int)-connCompData1.astype(np.int) + temp[(temp<0) | (temp>2000)]=0 + temp_count = collections.Counter(temp.flatten()) maxKey = 0 maxCount = 0 for key, keyCount in temp_count.items(): - if key != 0: - if keyCount > maxCount: - maxKey = key - maxCount = keyCount - print('Max key, count:', maxKey, maxCount) - - # If the max key count is 0, this means there is no good - # overlap region between products. + if key!=0: + if keyCount>maxCount: + maxKey =key + maxCount=keyCount + + # if the max key count is 0, this means there is no good overlap region between products. # In that scenario default to different stitching approach. 
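To pick which connected components to tie together, the code above encodes each (frame-2 component, frame-1 component) pairing as the single integer 100*label2 - label1 and keeps the most frequent non-zero code in the overlap, provided it covers enough pixels. A simplified, self-contained restatement with synthetic label arrays (the patch itself additionally handles nodata via NaN and requires at least 75 pixels):

import collections
import numpy as np

rng = np.random.default_rng(4)
labels1 = rng.integers(0, 4, size=(40, 50))   # frame-1 components in the overlap (0 = unreliable)
labels2 = rng.integers(0, 4, size=(40, 50))   # frame-2 components in the overlap

# Encode every pixel where both frames carry a valid (non-zero) component
valid = (labels1 > 0) & (labels2 > 0)
codes = 100*labels2[valid] - labels1[valid]

counts = collections.Counter(codes.tolist())
if counts:
    best_code, best_count = counts.most_common(1)[0]
    comp2 = int(np.ceil(best_code / 100))
    comp1 = comp2*100 - best_code
    print('largest overlap: frame-2 component %d and frame-1 component %d (%d pixels)'
          % (comp2, comp1, best_count))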
- if maxKey != 0 and maxCount > 75: - # Masking the unwrapped phase and only use the - # largest overlapping connected component + if maxKey!=0 and maxCount>75: + # masking the unwrapped phase and only use the largest overlapping connected component unwData1 = unwFile1.GetRasterBand(1).ReadAsArray() - unwData2 = unwFile2.GetRasterBand(1).ReadAsArray() - - cutlineMask1 = connCompFile1.GetRasterBand(2).ReadAsArray() - cutlineMask2 = connCompFile2.GetRasterBand(2).ReadAsArray() - - unwData1[((unwData1==unwNoData1) - | (connCompDiff!=maxKey) - | (cutlineMask1==0))] = np.nan - unwData2[((unwData2==unwNoData2) - | (connCompDiff!=maxKey) - | (cutlineMask2==0))] = np.nan + unwData1[(unwData1==unwNoData1) | (temp!=maxKey)]=np.nan + unwData2 =unwFile2.GetRasterBand(1).ReadAsArray() + unwData2[(unwData2==unwNoData2) | (temp!=maxKey)]=np.nan # Calculation of the range correction - unwData1_wrapped = (unwData1 - - np.round(unwData1/(2*np.pi))*(2*np.pi)) - unwData2_wrapped = (unwData2 - - np.round(unwData2/(2*np.pi))*(2*np.pi)) - arr = unwData1_wrapped - unwData2_wrapped + unwData1_wrapped = unwData1-np.round(unwData1/(2*np.pi))*(2*np.pi) + unwData2_wrapped =unwData2-np.round(unwData2/(2*np.pi))*(2*np.pi) + arr =unwData1_wrapped-unwData2_wrapped - # Data is not fully decorrelated + # data is not fully decorrelated arr = arr - np.round(arr/(2*np.pi))*2*np.pi - range_temp = np.angle(np.nanmean(np.exp(1j*arr))) + range_temp = np.angle(np.nanmean(np.exp(1j*arr))) - # Calculation of the number of 2PI cycles accounting - # for range correction - corrected_range = unwData1 - (unwData2+range_temp) - cycles_temp = np.round((np.nanmean(corrected_range))/(2*np.pi)) - print(cycles_temp) + # calculation of the number of 2 pi cycles accounting for range correction + cycles_temp = np.round((np.nanmean(unwData1-(unwData2+range_temp)))/(2*np.pi)) else: - # Account for the case that no data was left, e.g. - # fully decorrelated - # In that scenario use all data and estimate from - # wrapped, histogram will be broader ... + # account for the case that no-data was left, e.g. fully decorrelated + # in that scenario use all data and estimate from wrapped, histogram will be broader... 
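The arithmetic above estimates, for one overlap, a sub-cycle range offset from the circular mean of the re-wrapped phase difference, and then the integer number of 2-pi cycles separating the two products once that offset is removed. A standalone numpy sketch with synthetic phases in which a known offset is recovered:

import numpy as np

rng = np.random.default_rng(5)
shape = (60, 80)
two_pi = 2*np.pi

# Synthetic overlap: product 2 differs from product 1 by 3 cycles plus 0.4 rad
unw1 = rng.normal(scale=0.1, size=shape) + 12.0
unw2 = unw1 - 3*two_pi - 0.4

# Sub-cycle (range) offset from the circular mean of the re-wrapped difference
w1 = unw1 - np.round(unw1/two_pi)*two_pi
w2 = unw2 - np.round(unw2/two_pi)*two_pi
diff = w1 - w2
diff -= np.round(diff/two_pi)*two_pi
range_offset = np.angle(np.nanmean(np.exp(1j*diff)))

# Integer 2-pi cycles once the range offset has been removed
cycles = np.round(np.nanmean(unw1 - (unw2 + range_offset)) / two_pi)

print('range offset ~ %.3f rad, cycles = %d' % (range_offset, int(cycles)))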
                     unwData1 = unwFile1.GetRasterBand(1).ReadAsArray()
-                    unwData2 = unwFile2.GetRasterBand(1).ReadAsArray()
-
-                    cutlineMask1 = connCompFile1.GetRasterBand(2).ReadAsArray()
-                    cutlineMask2 = connCompFile2.GetRasterBand(2).ReadAsArray()
-
-                    unwData1[((unwData1==unwNoData1)
-                              | (cutlineMask1==0))] = np.nan
-                    unwData2[((unwData2==unwNoData2)
-                              | (cutlineMask2==0))] = np.nan
-
+                    unwData1[(unwData1==unwNoData1)] = np.nan
+                    unwData2 =unwFile2.GetRasterBand(1).ReadAsArray()
+                    unwData2[(unwData2==unwNoData2)] = np.nan
                     # Calculation of the range correction
-                    unwData1_wrapped = (unwData1
-                                        - np.round(unwData1/(2*np.pi))*(2*np.pi))
-                    unwData2_wrapped = (unwData2
-                                        - np.round(unwData2/(2*np.pi))*(2*np.pi))
-                    arr = unwData1_wrapped - unwData2_wrapped
+                    unwData1_wrapped = unwData1-np.round(unwData1/(2*np.pi))*(2*np.pi)
+                    unwData2_wrapped =unwData2-np.round(unwData2/(2*np.pi))*(2*np.pi)
+                    arr =unwData1_wrapped-unwData2_wrapped
                     arr = arr - np.round(arr/(2*np.pi))*2*np.pi
-                    range_temp = np.angle(np.nanmean(np.exp(1j*arr)))
+                    range_temp = np.angle(np.nanmean(np.exp(1j*arr)))
 
-                    # Data is decorelated assume no 2-pi cycles
+                    # data is decorrelated, assume no 2-pi cycles
                     cycles_temp = 0
 
-                # Closing the files
+
+                # closing the files
                 unwFile1 = None
                 unwFile2 = None
                 connCompFile1 = None
                 connCompFile2 = None
 
-                # Remove the tempfile
+                # remove the tempfile
                 shutil.os.remove(outname)
 
-                # Store the residual and populate the design matrix
-                residualcycles[counter] = cycles_temp
-                residualrange[counter] = range_temp
-                A[counter,counter] = 1
-                A[counter,counter+1] = -1
-
-            # Invert the offsets with respect to the first product
-            cycles = np.round(np.linalg.lstsq(A[:,1:],
-                              residualcycles,rcond=None)[0])
-            rangesoffset = np.linalg.lstsq(A[:,1:],
-                              residualrange,rcond=None)[0]
+                # store the residual and populate the design matrix
+                residualcycles[counter]=cycles_temp
+                residualrange[counter]=range_temp
+                A[counter,counter]=1
+                A[counter,counter+1]=-1
+
+
+            # invert the offsets with respect to the first product
+            cycles = np.round(np.linalg.lstsq(A[:,1:], residualcycles,rcond=None)[0])
+            rangesoffset = np.linalg.lstsq(A[:,1:], residualrange,rcond=None)[0]
             #pdb.set_trace()
 
-            # Force first product to have 0 as offset
+            # force first product to have 0 as offset
             cycles = -1*np.concatenate((np.zeros((1,1)), cycles), axis=0)
             rangesoffset = -1*np.concatenate((np.zeros((1,1)), rangesoffset), axis=0)
         else:
-            # Nothing to be done, i.e. no phase cycles to be added
+            # nothing to be done, i.e. no phase cycles to be added
             cycles = np.zeros((1,1))
             rangesoffset = np.zeros((1,1))
 
-        # Build the mapping dictionary
+        # build the mapping dictionary
         fileMappingDict = {}
-        connCompOffset = 0
+        connCompOffset =0
 
         for fileCounter in range(self.nfiles):
-            # Get the number of connected components
+
+            # get the number of connected components
             n_comp = 20
 
             # The original connected components
@@ -585,43 +501,42 @@ def __calculateCyclesOverlap__(self):
 
             # Generate the mapping of connectedComponents
             # Increment based on unique components of the merged product, 0 is grouped for all products
             connCompMapping = connComp
-            connCompMapping[1:] = connComp[1:]+connCompOffset
+            connCompMapping[1:]=connComp[1:]+connCompOffset
 
             # concatenate and generate a connComponent based mapping matrix
             # [original comp, new component, 2pi unw offset]
-            connCompMapping = np.concatenate(
-                (connComp, connCompMapping, cycleMapping, rangeOffsetMapping),
-                axis=1)
+            connCompMapping = np.concatenate((connComp, connCompMapping,cycleMapping,rangeOffsetMapping), axis=1)
 
             # Increment the count of total number of unique components
-            connCompOffset = connCompOffset + n_comp
+            connCompOffset = connCompOffset+n_comp
 
-            # Populate mapping dictionary for each product
+            # populate mapping dictionary for each product
             fileMappingDict_temp = {}
             fileMappingDict_temp['connCompMapping'] = connCompMapping
-            fileMappingDict_temp['connFile'] = self.ccFile[fileCounter]
+            fileMappingDict_temp['connFile'] = self.ccFile[fileCounter]
             fileMappingDict_temp['unwFile'] = self.inpFile[fileCounter]
+            # store it in the general mapping dictionary
+            fileMappingDict[fileCounter]=fileMappingDict_temp
 
-            # Store it in the general mapping dictionary
-            fileMappingDict[fileCounter] = fileMappingDict_temp
-
-        # Pass the fileMapping back into self
+        # pass the fileMapping back into self
         self.fileMappingDict = fileMappingDict
 
-
 class UnwrapComponents(Stitching):
-    """Stiching/unwrapping using 2-Stage Phase Unwrapping."""
+    '''
+        Stitching/unwrapping using 2-Stage Phase Unwrapping
+    '''
 
     def __init__(self):
-        """Inheret properties from the parent class.
-        Parse the filenames and bbox as None as they need to be set by
-        the user, which will be caught when running the class.
-        """
+        '''
+            Inherit properties from the parent class.
+            Parse the filenames and bbox as None as they need to be set by the user, which will be caught when running the class.
+        '''
         Stitching.__init__(self)
 
     def unwrapComponents(self):
+
         ## setting the method
         self.setStitchMethod("2stage")
         self.region=5
@@ -664,6 +579,7 @@ def unwrapComponents(self):
 
         ## Write out merged phase and connected component files
        self.__createImages__()
+
        return
 
     def __populatePolyTable__(self):
@@ -1196,16 +1112,13 @@ def GDALread(filename,data_band=1,loadData=True):
 
 def createConnComp_Int(inputs):
-    """Function to generate intermediate connected component files and
-    unwrapped VRT files that have with interger 2pi pixel shift applied.
-
-    Will parse inputs in a single argument as it allows for parallel
-    processing.
-
-    Return a list of files in a unqiue temp folder.
-    """
+    '''
+    Function to generate intermediate connected component files and unwrapped VRT files that have an integer 2-pi pixel shift applied.
+    Will parse inputs in a single argument as it allows for parallel processing.
+    Return a list of files in a unique temp folder
+    '''
 
-    # Parse the inputs to variables
+    # parsing the inputs to variables
     saveDir = inputs['saveDir']
     saveNameID = inputs['saveNameID']
     connFile = inputs['connFile']
@@ -1214,62 +1127,51 @@ def createConnComp_Int(inputs):
 
     ## Generating the intermediate files
     ## STEP 1: set-up the mapping functions
-    # Load the connected component
-    connData, connNoData, connGeoTrans, connProj = GDALread(connFile)
+    # loading the connected component
+    connData,connNoData,connGeoTrans,connProj = GDALread(connFile)
 
-    ## Define the mapping tables
-    # Set up the connected component unique ID mapping
+    ## Defining the mapping tables
+    # setting up the connected component unique ID mapping
     connIDMapping = connCompMapping[:,1]
-
-    # Add the no-data to the mapping as well (such that we can handle a
-    # no-data region)
-    # I.e. num comp + 1 = Nodata connected component which gets mapped
-    # to no-data value again
+    # Will add the no-data to the mapping as well (such that we can handle the no-data region)
+    # i.e. num comp + 1 = Nodata connected component which gets mapped to the no-data value again
     NoDataMapping = len(connIDMapping)
     connIDMapping = np.append(connIDMapping, [connNoData])
 
-    # Set up the connected component integer 2PI shift mapping
+    # setting up the connected component integer 2PI shift mapping
     intMapping = connCompMapping[:,2]
-
-    # Add the no-data to the mapping as well (such that we can handle a
-    # no-data region)
-    # I.e. max comp ID + 1 = Nodata connected component which gets
-    # mapped to 0 integer shift such that the no-data region remains
-    # unaffected
+    # Will add the no-data to the mapping as well (such that we can handle the no-data region)
+    # i.e. max comp ID + 1 = Nodata connected component which gets mapped to 0 integer shift such that the no-data region remains unaffected
     intMapping = np.append(intMapping, [0])
 
-    # Update the connected component with the new no-data value used
-    # in the mapping
-    connData[connData==connNoData] = NoDataMapping
+    # update the connected component with the new no-data value used in the mapping
+    connData[connData==connNoData]=NoDataMapping
 
     ## STEP 2: apply the mapping functions
-    # Interger 2PI scaling mapping for unw phase
-    intShift = intMapping[connData.astype('int')]
+    # integer 2PI scaling mapping for unw phase
 
-    # Connected component mapping to unique ID
+    intShift = intMapping[connData.astype('int')]
+    # connected component mapping to unique ID
     connData = connIDMapping[connData.astype('int')]
 
-    ## STEP 3: write out the datasets
-    # Write out the unqiue ID connected component file
-    connDataName = os.path.abspath(os.path.join(saveDir, 'connComp',
-                                   saveNameID+'_connComp.tif'))
-    write_ambiguity(connData, connDataName, connProj, connGeoTrans, connNoData)
+    ## STEP 3: writing out the datasets
+    # writing out the unique ID connected component file
+    connDataName = os.path.abspath(os.path.join(saveDir,'connComp', saveNameID + '_connComp.tif'))
+    write_ambiguity(connData,connDataName,connProj,connGeoTrans,connNoData)
+
+    # writing out the integer map as tiff file
+    intShiftName = os.path.abspath(os.path.join(saveDir,'unw',saveNameID + '_intShift.tif'))
+    write_ambiguity(intShift,intShiftName,connProj,connGeoTrans)
 
-    # Write out the integer map as tiff file
-    intShiftName = os.path.abspath(os.path.join(saveDir, 'unw',
-                                   saveNameID+'_intShift.tif'))
-    write_ambiguity(intShift, intShiftName, connProj, connGeoTrans)
 
-    # Write out the scalled vrt => 2PI * integer map
+    # writing out the scaled vrt => 2PI * integer map
     length = intShift.shape[0]
     width = intShift.shape[1]
-    scaleVRTName = os.path.abspath(
-        os.path.join(saveDir, 'unw', saveNameID+'_scale.vrt'))
-    build2PiScaleVRT(scaleVRTName, intShiftName, length=length, width=width)
+    scaleVRTName = os.path.abspath(os.path.join(saveDir,'unw',saveNameID + '_scale.vrt'))
+    build2PiScaleVRT(scaleVRTName,intShiftName,length=length,width=width)
 
-    # Offset the VRT for the range offset correction
-    unwRangeOffsetVRTName = os.path.abspath(
-        os.path.join(saveDir, 'unw', saveNameID + '_rangeOffset.vrt'))
+    # Offsetting the vrt for the range offset correction
+    unwRangeOffsetVRTName = os.path.abspath(os.path.join(saveDir,'unw',saveNameID + '_rangeOffset.vrt'))
     buildScaleOffsetVRT(unwRangeOffsetVRTName,unwFile,connProj,connGeoTrans,File1_offset=connCompMapping[1,3],length=length,width=width)
 
     # writing out the corrected unw phase vrt => phase + 2PI * integer map
@@ -1278,45 +1180,49 @@ def createConnComp_Int(inputs):
 
     return [connDataName, unwVRTName]
 
-def write_ambiguity(data, outName, proj, geoTrans, noData=False):
-    """Write out an integer mapping in the Int16/Byte data range of values."""
+def write_ambiguity(data, outName,proj, geoTrans,noData=False):
+    '''
+    Write out an integer mapping in the Int16/Byte data range of values
+    '''
+
     # GDAL precision support in tif
     Byte = gdal.GDT_Byte
     Int16 = gdal.GDT_Int16
 
-    # Check if the path to the file needs to be created
+    # check if the path to the file needs to be created
     dirname = os.path.dirname(outName)
     if not os.path.isdir(dirname):
         os.makedirs(dirname)
 
-    # Get the GEOTIFF driver
+    # Getting the GEOTIFF driver
     driver = gdal.GetDriverByName('GTIFF')
-    # Leverage the compression option to ensure small file size
+    # leverage the compression option to ensure small file size
     dst_options = ['COMPRESS=LZW']
-    # Create the dataset
-    ds = driver.Create(outName , data.shape[1], data.shape[0], 1, Int16,
-                       dst_options)
-    # Set the proj and transformation
+    # create the dataset
+    ds = driver.Create(outName , data.shape[1], data.shape[0], 1, Int16, dst_options)
+    # setting the proj and transformation
     ds.SetGeoTransform(geoTrans)
     ds.SetProjection(proj)
-    # Populate the first band with data
+    # populate the first band with data
     bnd = ds.GetRasterBand(1)
     bnd.WriteArray(data)
-    # Set the no-data value
+    # setting the no-data value
     if noData is not None:
         bnd.SetNoDataValue(noData)
     bnd.FlushCache()
-    # Close the file
+    # close the file
     ds = None
 
-def build2PiScaleVRT(output, File, width=False, length=False):
-    """Build a VRT file which scales a GDAL byte file with 2PI."""
+def build2PiScaleVRT(output,File,width=False,length=False):
+    '''
+    Building a VRT file which scales a GDAL byte file with 2PI
+    '''
     # DBTODO: The datatype should be loaded by default from the source raster to be applied.
    # should be ok for now, but could be an issue for large connected com
 
-    # The VRT template with 2PI scaling functionality
+    # the vrt template with 2-pi scaling functionality
     vrttmpl = '''
 
 
@@ -1329,28 +1235,28 @@ def build2PiScaleVRT(output, File, width=False, length=False):
 
 '''
 
-    # The inputs needed to build the VRT
-    # Load the width and length from the GDAL file in case not specified
+    # the inputs needed to build the vrt
+    # load the width and length from the GDAL file in case not specified
     if not width or not length:
         ds = gdal.Open(File, gdal.GA_ReadOnly)
         width = ds.RasterXSize
         ysize = ds.RasterYSize
         ds = None
 
-    # Check if the path to the file needs to be created
+    # check if the path to the file needs to be created
     dirname = os.path.dirname(output)
     if not os.path.isdir(dirname):
         os.makedirs(dirname)
 
-    # Write out the VRT file
+    # write out the VRT file
     with open(output, 'w') as fid:
         fid.write(vrttmpl.format(width = width, length = length, File = File))
 
 def buildScaleOffsetVRT(output,File1,proj,geoTrans,File1_offset=0, File1_scale = 1, width=False,length=False,description='Scalled and offsetted VRT'):
-    """Building a VRT file which sums two files together using pixel
-    functionality.
-    """
+    '''
+    Building a VRT file which sums two files together using pixel functionality
+    '''
     # the vrt template with sum pixel functionality
     vrttmpl = '''
 
 
@@ -1369,29 +1275,30 @@ def buildScaleOffsetVRT(output,File1,proj,geoTrans,File1_offset=0, File1_scale =
 
 '''
 #
-    # The inputs needed to build the VRT
-    # Load the width and length from the GDAL file in case not specified
+
+    # the inputs needed to build the vrt
+    # load the width and length from the GDAL file in case not specified
     if not width or not length:
         ds = gdal.Open(File1, gdal.GA_ReadOnly)
         width = ds.RasterXSize
         ysize = ds.RasterYSize
         ds = None
 
-    # Check if the path to the file needs to be created
+    # check if the path to the file needs to be created
     dirname = os.path.dirname(output)
     if not os.path.isdir(dirname):
         os.makedirs(dirname)
 
-    # Write out the VRT file
+    # write out the VRT file
     with open('{0}'.format(output) , 'w') as fid:
         fid.write( vrttmpl.format(width=width,length=length,File1=File1,File1_offset=File1_offset, File1_scale = File1_scale, proj=proj,geoTrans=str(geoTrans)[1:-1],description=description))
 
 def buildSumVRT(output,File1,File2,proj,geoTrans,length=False, width=False,description='Unwrapped Phase'):
-    """Building a VRT file which sums two files together using pixel
-    functionality.
-    """
+    '''
+    Building a VRT file which sums two files together using pixel functionality
+    '''
 
-    # The VRT template with sum pixel functionality
+    # the vrt template with sum pixel functionality
     vrttmpl = '''
 {proj}
 {geoTrans}
 
 
@@ -1409,20 +1316,20 @@ def buildSumVRT(output,File1,File2,proj,geoTrans,length=False, width=False,descr
 
 '''
 
-    # The inputs needed to build the vrt
-    # Load the width and length from the GDAL file in case not specified
+    # the inputs needed to build the vrt
+    # load the width and length from the GDAL file in case not specified
     if not width or not length:
         ds = gdal.Open(File1, gdal.GA_ReadOnly)
         width = ds.RasterXSize
        ysize = ds.RasterYSize
        ds = None
 
-    # Check if the path to the file needs to be created
+    # check if the path to the file needs to be created
    dirname = os.path.dirname(output)
    if not os.path.isdir(dirname):
        os.makedirs(dirname)
 
-    # Write out the VRT file
+    # write out the VRT file
    with open('{0}'.format(output) , 'w') as fid:
        fid.write( vrttmpl.format(width=width,length=length,File1=File1,File2=File2, proj=proj,geoTrans=str(geoTrans)[1:-1],description=description))
 
@@ -1525,12 +1432,9 @@ def gdalTest(file, verbose=False):
 
 
-def product_stitch_overlap(unw_files, conn_files, prod_bbox_files, bbox_file,
-                           prods_TOTbbox, outFileUnw = './unwMerged',
-                           outFileConnComp = './connCompMerged', outputFormat='ENVI',
-                           mask=None, verbose=False):
+def product_stitch_overlap(unw_files, conn_files, prod_bbox_files, bbox_file, prods_TOTbbox, outFileUnw = './unwMerged', outFileConnComp = './connCompMerged', outputFormat='ENVI', mask=None, verbose=False):
     '''
-    Stitching of products minimizing overlap betnween products
+    Stitching of products minimizing overlap between products
     '''
 
     # report method to user
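
The overlap-pairing step in __calculateCyclesOverlap__ above encodes the two connected-component labels of every overlap pixel into a single integer (100*c2 - c1, which is unambiguous while labels stay below 100) and keeps the most frequent non-zero code. The following is a minimal standalone sketch of that bookkeeping, not part of the patch; the label arrays and their shapes are invented for illustration.

import collections
import numpy as np

# invented connected-component labels of two products over their overlap
# region (0 = unlabeled); real rasters would also carry a no-data value
cc1 = np.array([[1, 1, 2],
                [1, 2, 0]])
cc2 = np.array([[3, 3, 3],
                [3, 3, 0]])

# encode the (label-1, label-2) pair of each pixel as one integer
pair = cc2.astype(int)*100 - cc1.astype(int)
pair[(pair < 0) | (pair > 2000)] = 0      # discard invalid encodings

# the most frequent non-zero code is the dominant component pairing
counts = collections.Counter(pair.flatten())
maxKey = max((k for k in counts if k != 0), key=lambda k: counts[k], default=0)
print(maxKey, counts[maxKey])             # 299, 3: label 1 pairs with label 3

The pixel count behind maxKey is what the stitcher compares against its minimum-overlap threshold (75 pixels in the code above) before trusting the overlap.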
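The range correction itself rewraps both unwrapped-phase grids, wraps their difference, and takes the angle of the mean unit phasor; the leftover integer number of 2-pi cycles is then estimated from the corrected difference. A small self-contained sketch of that calculation follows; it is an illustration with synthetic data, not part of the patch, and the function name is hypothetical.

import numpy as np

def estimate_offsets(unw1, unw2):
    # rewrap both unwrapped phases into a single cycle and difference them
    arr = (unw1 - np.round(unw1/(2*np.pi))*(2*np.pi)) \
        - (unw2 - np.round(unw2/(2*np.pi))*(2*np.pi))
    # wrap the difference so residuals stay within one cycle
    arr = arr - np.round(arr/(2*np.pi))*(2*np.pi)
    # circular mean: the angle of the mean unit phasor is robust to wrapping
    range_offset = np.angle(np.nanmean(np.exp(1j*arr)))
    # remaining integer number of 2-pi cycles after the range correction
    cycles = np.round(np.nanmean(unw1 - (unw2 + range_offset))/(2*np.pi))
    return range_offset, cycles

# synthetic check: product 1 sits two full cycles plus 0.3 rad above product 2
rng = np.random.default_rng(0)
base = rng.uniform(-np.pi, np.pi, size=(50, 50))
print(estimate_offsets(base + 2*(2*np.pi) + 0.3, base))   # approximately (0.3, 2.0)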
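Each product pair contributes one row to the design matrix (A[counter,counter]=1, A[counter,counter+1]=-1); dropping the first column pins the first product to a zero offset, and the least-squares solve spreads the pairwise residuals over the remaining products. A worked example with made-up residual values, mirroring the inversion shown above:

import numpy as np

# made-up cycle residuals between neighbouring products
# (product i minus product i+1), one value per overlap
residualcycles = np.array([[2.], [-1.], [0.]])     # 4 products -> 3 pairs
npairs = residualcycles.shape[0]

# design matrix tying each pair of neighbouring products together
A = np.zeros((npairs, npairs + 1))
for k in range(npairs):
    A[k, k] = 1
    A[k, k + 1] = -1

# solve with product 0 as the zero-offset reference, then prepend it back
offsets = np.round(np.linalg.lstsq(A[:, 1:], residualcycles, rcond=None)[0])
offsets = -1*np.concatenate((np.zeros((1, 1)), offsets), axis=0)
print(offsets.ravel())    # cycles to add to products 0..3: 0, 2, 1, 1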
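createConnComp_Int then turns each product's connCompMapping table into two lookup arrays: one renaming every original component to a track-unique ID and one giving its integer 2-pi shift, each padded with one extra entry so that no-data pixels map back to no-data and to a zero shift. A toy version of that mapping step, with invented labels and an assumed no-data value of 255:

import numpy as np

# invented per-product mapping table, one row per component:
# [original comp, new unique comp, integer 2-pi shift, range offset]
connCompMapping = np.array([[0, 0,  0, 0.0],
                            [1, 5,  2, 0.1],
                            [2, 6, -1, 0.1]])
connNoData = 255                      # assumed raster no-data value

# toy connected-component raster containing one no-data pixel
connData = np.array([[1., 1., 2.],
                     [0., 2., 255.]])

# lookup tables indexed by the original component ID, plus a no-data entry
connIDMapping = np.append(connCompMapping[:, 1], [connNoData])
intMapping = np.append(connCompMapping[:, 2], [0])
connData[connData == connNoData] = len(connCompMapping)

intShift = intMapping[connData.astype('int')]      # per-pixel 2-pi shift
connData = connIDMapping[connData.astype('int')]   # per-pixel unique ID
print(connData)    # components renamed to 5, 6; no-data pixel back to 255
print(intShift)    # shifts 2 and -1 applied; zero on unlabeled and no-data

The integer-shift raster is what the 2-pi scale VRT later multiplies by 2*pi and adds to the unwrapped phase.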