Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 69 additions & 13 deletions src/scaffoldfitter/fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,16 @@

from cmlibs.maths.vectorops import add, mult, sub
from cmlibs.utils.zinc.field import (
assignFieldParameters, createFieldFiniteElementClone, getGroupList, findOrCreateFieldFiniteElement,
find_or_create_field_stored_mesh_location, getUniqueFieldName, orphanFieldByName, create_jacobian_determinant_field)
assignFieldParameters, createFieldFiniteElementClone, findOrCreateFieldFiniteElement,
find_or_create_field_stored_mesh_location, getUniqueFieldName, orphanFieldByName, create_jacobian_determinant_field,
get_group_list)
from cmlibs.utils.zinc.finiteelement import (
evaluate_field_nodeset_range, findNodeWithName, get_scalar_field_minimum_in_mesh)
from cmlibs.utils.zinc.group import match_fitting_group_names
from cmlibs.utils.zinc.general import ChangeManager
from cmlibs.utils.zinc.mesh import element_or_ancestor_is_in_mesh
from cmlibs.utils.zinc.region import copy_fitting_data
from cmlibs.utils.zinc.scene import SELECTION_GROUP_NAME
from cmlibs.zinc.context import Context
from cmlibs.zinc.element import Elementbasis, Elementfieldtemplate
from cmlibs.zinc.field import Field, FieldFindMeshLocation, FieldGroup
Expand All @@ -26,6 +28,9 @@
from scaffoldfitter.fitterstepfit import FitterStepFit


UNGROUPED_DATAPOINTS_GROUP_NAME = "Ungrouped_datapoints"


def _next_available_identifier(node_set, candidate):
node = node_set.findNodeByIdentifier(candidate)
while node.isValid():
Expand All @@ -34,6 +39,42 @@ def _next_available_identifier(node_set, candidate):
return candidate


def _is_internal_group(group):
if not group.isManaged():
return True

if group.getName() in [SELECTION_GROUP_NAME]:
return True

return False


def _create_ungrouped_datapoint_group(field_module, diagnostics_level):
"""
Create a group for ungrouped datapoints, if it does not already exist.
This is used to ensure that all datapoints are in a group, even if they are not
grouped by the data file.
"""
datapoints = field_module.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_DATAPOINTS)
if not datapoints.isValid():
return
ungrouped_group = field_module.findFieldByName(UNGROUPED_DATAPOINTS_GROUP_NAME)
if not ungrouped_group.isValid():
or_group = field_module.createFieldGroup()
for group in get_group_list(field_module):
if not _is_internal_group(group):
or_group = field_module.createFieldOr(or_group, group)

ungrouped_group = field_module.createFieldGroup()
ungrouped_group.setName(UNGROUPED_DATAPOINTS_GROUP_NAME)
ungrouped_group.setManaged(True)
ungrouped_nodeset_group = ungrouped_group.getOrCreateNodesetGroup(datapoints)
ungrouped_nodeset_group.addNodesConditional(field_module.createFieldNot(or_group))
if diagnostics_level > 0:
print(f"Created a group '{UNGROUPED_DATAPOINTS_GROUP_NAME}' for ungrouped datapoints,"
f" of size: {ungrouped_nodeset_group.getSize()}, from datapoints not already assigned into a group.")


class Fitter:

def __init__(self, zincModelFileName: str=None, zincDataFileName: str=None, region: Region=None):
Expand Down Expand Up @@ -468,6 +509,7 @@ def _loadData(self):
assert result == RESULT_OK, "Failed to load data file " + str(self._zincDataFileName)
data_fieldmodule = self._rawDataRegion.getFieldmodule()
with ChangeManager(data_fieldmodule):
_create_ungrouped_datapoint_group(data_fieldmodule, self.getDiagnosticLevel())
match_fitting_group_names(data_fieldmodule, self._fieldmodule,
log_diagnostics=self.getDiagnosticLevel() > 0)
copy_fitting_data(self._region, self._rawDataRegion)
Expand Down Expand Up @@ -616,6 +658,7 @@ def assignDataWeights(self, fitterStepFit: FitterStepFit):
# Future: divide by number of data points?
coordinatesCount = self._modelCoordinatesField.getNumberOfComponents()
with ChangeManager(self._fieldmodule):
fieldassignment = None
for groupName in self._dataProjectionGroupNames:
group = self._fieldmodule.findFieldByName(groupName).castGroup()
if not group.isValid():
Expand Down Expand Up @@ -676,7 +719,8 @@ def assignDataWeights(self, fitterStepFit: FitterStepFit):
result = fieldassignment.assign()
if result != RESULT_OK:
print('Incomplete assignment of marker data weight', result)
del fieldassignment
if fieldassignment is not None:
del fieldassignment

def assignDeformationPenalties(self, fitterStepFit: FitterStepFit):
"""
Expand All @@ -697,7 +741,7 @@ def assignDeformationPenalties(self, fitterStepFit: FitterStepFit):
curvatureComponents = coordinatesCount * meshDimension * meshDimension
groups = []
# add None for default group
for group in (getGroupList(self._fieldmodule) + [None]):
for group in (get_group_list(self._fieldmodule) + [None]):
if group:
meshGroup = group.getMeshGroup(mesh)
if (not meshGroup.isValid()) or (meshGroup.getSize() == 0):
Expand Down Expand Up @@ -1327,24 +1371,24 @@ def getGroupDataProjectionNodesetGroup(self, group: FieldGroup):
return dataNodesetGroup
return None

def getGroupDataProjectionMeshGroup(self, group: FieldGroup, fitterStep: FitterStep):
def getGroupDataProjectionMeshGroup(self, group: FieldGroup, fitter_step: FitterStep):
"""
Get mesh group for 2D if not 1D mesh containing elements for projecting data in group, if any.
If there is a subgroup set or inherited by the group, finds or builds on demand an equal or lower
dimensional intersection mesh group.
:group: Zinc annotation group for which to get projection mesh group.
:fitterStep: FitterStep to get config for, with optional subgroup to intersect with.
:fitter_step: FitterStep to get config for, with optional subgroup to intersect with.
:return: MeshGroup or None if none or empty, findHighestDimension.
The second argument is True if any of the mesh group doesn't have a child-parent[-grandparent]
map to highest dimension elements in the model or model fit group, and hence projections must be
map to the highest dimension elements in the model or model fit group, and hence projections must be
re-found on the top-level mesh group.
"""
groupName = group.getName()
activeFitterStepConfig = self.getActiveFitterStepConfig(fitterStep)
activeFitterStepConfig = self.getActiveFitterStepConfig(fitter_step)
subgroup = activeFitterStepConfig.getGroupProjectionSubgroup(groupName)[0]
groupProjectionData = self._groupProjectionData.get(groupName)
if groupProjectionData:
if (groupProjectionData[0] == subgroup):
if groupProjectionData[0] == subgroup:
return groupProjectionData[1], groupProjectionData[2]
returnMeshGroup = None
findHighestDimension = False
Expand All @@ -1364,7 +1408,7 @@ def getGroupDataProjectionMeshGroup(self, group: FieldGroup, fitterStep: FitterS
subgroupMeshGroup = subgroup.getMeshGroup(mesh)
if subgroupMeshGroup.isValid() and (subgroupMeshGroup.getSize() > 0):
intersectionGroup = self._fieldmodule.createFieldGroup()
intersectionGroup.setName(groupName + " " + subgroup.getName()) # for debugging
intersectionGroup.setName(groupName + " & " + subgroup.getName()) # for debugging
returnMeshGroup = intersectionGroup.createMeshGroup(mesh)
returnMeshGroup.addElementsConditional(self._fieldmodule.createFieldAnd(group, subgroup))
if returnMeshGroup.getSize() > 0:
Expand All @@ -1374,9 +1418,21 @@ def getGroupDataProjectionMeshGroup(self, group: FieldGroup, fitterStep: FitterS
else:
returnMeshGroup = meshGroup
break
elif groupName == UNGROUPED_DATAPOINTS_GROUP_NAME:
meshGroup = group.getOrCreateMeshGroup(mesh)
el_iterator = mesh.createElementiterator()
element = el_iterator.next()
while element.isValid():
# add all elements in mesh to ungrouped group
meshGroup.addElement(element)
element = el_iterator.next()
if meshGroup.isValid() and (meshGroup.getSize() > 0):
returnMeshGroup = meshGroup
break

if returnMeshGroup:
# do any elements in returnMeshGroup NOT have self or an ancestor in modelFitMesh?
# If so, need to find highest dimension mesh location by a second FindMeshLocation field when projecting
# If so, need to find the highest dimension mesh location by a second FindMeshLocation field when projecting
modelFitMesh = highestDimensionMesh
if self._modelFitGroup:
modelFitMesh = self._modelFitGroup.getMeshGroup(highestDimensionMesh)
Expand Down Expand Up @@ -1411,13 +1467,13 @@ def calculateDataProjections(self, fitterStep: FitterStep):

datapoints = self._fieldmodule.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_DATAPOINTS)
fieldcache = self._fieldmodule.createFieldcache()
groups = getGroupList(self._fieldmodule)
groups = get_group_list(self._fieldmodule)
for group in groups:
if not group.isManaged():
continue # skip cmiss_selection, for example
groupName = group.getName()
dataGroup = self.getGroupDataProjectionNodesetGroup(group)
if not dataGroup:
if dataGroup is None:
continue
meshGroup, findHighestDimension = self.getGroupDataProjectionMeshGroup(group, fitterStep)
if not meshGroup:
Expand Down