From 78340ec75ee72c3fcacf7f874ca710a3b7ae78ec Mon Sep 17 00:00:00 2001 From: Richard Christie Date: Tue, 8 Apr 2025 13:24:32 +1200 Subject: [PATCH] Add field cleanup function --- src/scaffoldfitter/fitter.py | 30 +++++++++++++++++++++++------- tests/test_general.py | 1 - 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/src/scaffoldfitter/fitter.py b/src/scaffoldfitter/fitter.py index 71c18ad..6e5945e 100644 --- a/src/scaffoldfitter/fitter.py +++ b/src/scaffoldfitter/fitter.py @@ -7,7 +7,7 @@ from cmlibs.maths.vectorops import add, mult, sub from cmlibs.utils.zinc.field import ( assignFieldParameters, createFieldFiniteElementClone, getGroupList, findOrCreateFieldFiniteElement, - findOrCreateFieldStoredMeshLocation, getUniqueFieldName, orphanFieldByName, create_jacobian_determinant_field) + find_or_create_field_stored_mesh_location, getUniqueFieldName, orphanFieldByName, create_jacobian_determinant_field) from cmlibs.utils.zinc.finiteelement import ( evaluate_field_nodeset_range, findNodeWithName, get_scalar_field_minimum_in_mesh) from cmlibs.utils.zinc.group import match_fitting_group_names @@ -117,6 +117,15 @@ def __init__(self, zincModelFileName: str=None, zincDataFileName: str=None, regi fitterStep = FitterStepConfig() self.addFitterStep(fitterStep) + def cleanup(self): + self._fitterSteps = [] + self._clearFields() + self._rawDataRegion = None + self._fieldmodule = None + self._region = None + self._logger = None + self._context = None + def decodeSettingsJSON(self, s: str, decoder): """ Define Fitter from JSON serialisation output by encodeSettingsJSON. @@ -410,7 +419,7 @@ def _defineCommonDataFields(self): with ChangeManager(self._fieldmodule): mesh = self.getHighestDimensionMesh() datapoints = self._fieldmodule.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_DATAPOINTS) - self._dataHostLocationField = findOrCreateFieldStoredMeshLocation( + self._dataHostLocationField = find_or_create_field_stored_mesh_location( self._fieldmodule, mesh, "data_location_" + mesh.getName(), managed=False) self._dataHostCoordinatesField = self._fieldmodule.createFieldEmbedded( self._modelCoordinatesField, self._dataHostLocationField) @@ -973,6 +982,8 @@ def _calculateMarkerDataLocations(self): markerDataLocationGroupSize = self._markerDataLocationGroup.getSize() markerNodeGroupSize = self._markerNodeGroup.getSize() if self.getDiagnosticLevel() > 0: + print(str(markerDataLocationGroupSize) + " of " + str(markerDataGroupSize) + + " marker data points have model locations") if markerDataLocationGroupSize < markerDataGroupSize: print("Warning: Only " + str(markerDataLocationGroupSize) + " of " + str(markerDataGroupSize) + " marker data points have model locations") @@ -1259,6 +1270,8 @@ def calculateGroupDataProjections(self, fieldcache, group, dataGroup, meshGroup, nodeIter = dataGroup.createNodeiterator() node = nodeIter.next() dataProportionCounter = 0.5 + pointsProjected = 0 + outlierPointsRemoved = 0 while node.isValid(): dataProportionCounter += dataProportion if dataProportionCounter >= 1.0: @@ -1277,6 +1290,9 @@ def calculateGroupDataProjections(self, fieldcache, group, dataGroup, meshGroup, dataProjectionLengths.append((node.getIdentifier(), projectionLength)) if (outlierLength <= 0.0) or (projectionLength <= outlierLength): dataProjectionNodesetGroup.addNode(node) + pointsProjected += 1 + else: + outlierPointsRemoved += 1 node = nodeIter.next() if outlierLength < 0.0: relativeOutlierLength = (1.0 + outlierLength) * maximumProjectionLength @@ -1284,11 +1300,11 @@ def calculateGroupDataProjections(self, fieldcache, group, dataGroup, meshGroup, if projectionLength > relativeOutlierLength: node = dataGroup.findNodeByIdentifier(nodeIdentifier) dataProjectionNodesetGroup.removeNode(node) - pointsProjected = dataProjectionNodesetGroup.getSize() - sizeBefore - if pointsProjected < dataGroup.getSize(): - if self.getDiagnosticLevel() > 0: - print("Warning: Only " + str(pointsProjected) + " of " + str(dataGroup.getSize()) + - " data points projected for group " + groupName) + pointsProjected -= 1 + outlierPointsRemoved += 1 + if self.getDiagnosticLevel() > 0: + print(str(pointsProjected) + " of " + str(dataGroup.getSize()) + " data points projected for group " + + groupName + "; " + str(outlierPointsRemoved) + " outliers removed") # add to active group self._activeDataNodesetGroup.addNodesConditional(self._dataProjectionNodeGroupFields[meshDimension - 1]) return diff --git a/tests/test_general.py b/tests/test_general.py index f01b544..66d72cf 100644 --- a/tests/test_general.py +++ b/tests/test_general.py @@ -109,7 +109,6 @@ def test_fit_projection_subgroup(self): zinc_data_file_name = os.path.join(here, "resources", "nerve_path_data.exf") if i == 0: # use fitter with model and data files - region = None fitter = Fitter(zinc_model_file_name, zinc_data_file_name) fitter.setDiagnosticLevel(1) fitter.load()