Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 23 additions & 7 deletions src/scaffoldfitter/fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from cmlibs.maths.vectorops import add, mult, sub
from cmlibs.utils.zinc.field import (
assignFieldParameters, createFieldFiniteElementClone, getGroupList, findOrCreateFieldFiniteElement,
findOrCreateFieldStoredMeshLocation, getUniqueFieldName, orphanFieldByName, create_jacobian_determinant_field)
find_or_create_field_stored_mesh_location, getUniqueFieldName, orphanFieldByName, create_jacobian_determinant_field)
from cmlibs.utils.zinc.finiteelement import (
evaluate_field_nodeset_range, findNodeWithName, get_scalar_field_minimum_in_mesh)
from cmlibs.utils.zinc.group import match_fitting_group_names
Expand Down Expand Up @@ -117,6 +117,15 @@ def __init__(self, zincModelFileName: str=None, zincDataFileName: str=None, regi
fitterStep = FitterStepConfig()
self.addFitterStep(fitterStep)

def cleanup(self):
self._fitterSteps = []
self._clearFields()
self._rawDataRegion = None
self._fieldmodule = None
self._region = None
self._logger = None
self._context = None

def decodeSettingsJSON(self, s: str, decoder):
"""
Define Fitter from JSON serialisation output by encodeSettingsJSON.
Expand Down Expand Up @@ -410,7 +419,7 @@ def _defineCommonDataFields(self):
with ChangeManager(self._fieldmodule):
mesh = self.getHighestDimensionMesh()
datapoints = self._fieldmodule.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_DATAPOINTS)
self._dataHostLocationField = findOrCreateFieldStoredMeshLocation(
self._dataHostLocationField = find_or_create_field_stored_mesh_location(
self._fieldmodule, mesh, "data_location_" + mesh.getName(), managed=False)
self._dataHostCoordinatesField = self._fieldmodule.createFieldEmbedded(
self._modelCoordinatesField, self._dataHostLocationField)
Expand Down Expand Up @@ -973,6 +982,8 @@ def _calculateMarkerDataLocations(self):
markerDataLocationGroupSize = self._markerDataLocationGroup.getSize()
markerNodeGroupSize = self._markerNodeGroup.getSize()
if self.getDiagnosticLevel() > 0:
print(str(markerDataLocationGroupSize) + " of " + str(markerDataGroupSize) +
" marker data points have model locations")
if markerDataLocationGroupSize < markerDataGroupSize:
print("Warning: Only " + str(markerDataLocationGroupSize) +
" of " + str(markerDataGroupSize) + " marker data points have model locations")
Expand Down Expand Up @@ -1259,6 +1270,8 @@ def calculateGroupDataProjections(self, fieldcache, group, dataGroup, meshGroup,
nodeIter = dataGroup.createNodeiterator()
node = nodeIter.next()
dataProportionCounter = 0.5
pointsProjected = 0
outlierPointsRemoved = 0
while node.isValid():
dataProportionCounter += dataProportion
if dataProportionCounter >= 1.0:
Expand All @@ -1277,18 +1290,21 @@ def calculateGroupDataProjections(self, fieldcache, group, dataGroup, meshGroup,
dataProjectionLengths.append((node.getIdentifier(), projectionLength))
if (outlierLength <= 0.0) or (projectionLength <= outlierLength):
dataProjectionNodesetGroup.addNode(node)
pointsProjected += 1
else:
outlierPointsRemoved += 1
node = nodeIter.next()
if outlierLength < 0.0:
relativeOutlierLength = (1.0 + outlierLength) * maximumProjectionLength
for nodeIdentifier, projectionLength in dataProjectionLengths:
if projectionLength > relativeOutlierLength:
node = dataGroup.findNodeByIdentifier(nodeIdentifier)
dataProjectionNodesetGroup.removeNode(node)
pointsProjected = dataProjectionNodesetGroup.getSize() - sizeBefore
if pointsProjected < dataGroup.getSize():
if self.getDiagnosticLevel() > 0:
print("Warning: Only " + str(pointsProjected) + " of " + str(dataGroup.getSize()) +
" data points projected for group " + groupName)
pointsProjected -= 1
outlierPointsRemoved += 1
if self.getDiagnosticLevel() > 0:
print(str(pointsProjected) + " of " + str(dataGroup.getSize()) + " data points projected for group " +
groupName + "; " + str(outlierPointsRemoved) + " outliers removed")
# add to active group
self._activeDataNodesetGroup.addNodesConditional(self._dataProjectionNodeGroupFields[meshDimension - 1])
return
Expand Down
1 change: 0 additions & 1 deletion tests/test_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,6 @@ def test_fit_projection_subgroup(self):
zinc_data_file_name = os.path.join(here, "resources", "nerve_path_data.exf")
if i == 0:
# use fitter with model and data files
region = None
fitter = Fitter(zinc_model_file_name, zinc_data_file_name)
fitter.setDiagnosticLevel(1)
fitter.load()
Expand Down