diff --git a/src/scaffoldfitter/fitter.py b/src/scaffoldfitter/fitter.py index 4641275..ab771e6 100644 --- a/src/scaffoldfitter/fitter.py +++ b/src/scaffoldfitter/fitter.py @@ -6,14 +6,16 @@ from cmlibs.maths.vectorops import add, mult, sub from cmlibs.utils.zinc.field import ( - assignFieldParameters, createFieldFiniteElementClone, getGroupList, findOrCreateFieldFiniteElement, - find_or_create_field_stored_mesh_location, getUniqueFieldName, orphanFieldByName, create_jacobian_determinant_field) + assignFieldParameters, createFieldFiniteElementClone, findOrCreateFieldFiniteElement, + find_or_create_field_stored_mesh_location, getUniqueFieldName, orphanFieldByName, create_jacobian_determinant_field, + get_group_list) from cmlibs.utils.zinc.finiteelement import ( evaluate_field_nodeset_range, findNodeWithName, get_scalar_field_minimum_in_mesh) from cmlibs.utils.zinc.group import match_fitting_group_names from cmlibs.utils.zinc.general import ChangeManager from cmlibs.utils.zinc.mesh import element_or_ancestor_is_in_mesh from cmlibs.utils.zinc.region import copy_fitting_data +from cmlibs.utils.zinc.scene import SELECTION_GROUP_NAME from cmlibs.zinc.context import Context from cmlibs.zinc.element import Elementbasis, Elementfieldtemplate from cmlibs.zinc.field import Field, FieldFindMeshLocation, FieldGroup @@ -26,6 +28,9 @@ from scaffoldfitter.fitterstepfit import FitterStepFit +UNGROUPED_DATAPOINTS_GROUP_NAME = "Ungrouped_datapoints" + + def _next_available_identifier(node_set, candidate): node = node_set.findNodeByIdentifier(candidate) while node.isValid(): @@ -34,6 +39,42 @@ def _next_available_identifier(node_set, candidate): return candidate +def _is_internal_group(group): + if not group.isManaged(): + return True + + if group.getName() in [SELECTION_GROUP_NAME]: + return True + + return False + + +def _create_ungrouped_datapoint_group(field_module, diagnostics_level): + """ + Create a group for ungrouped datapoints, if it does not already exist. + This is used to ensure that all datapoints are in a group, even if they are not + grouped by the data file. + """ + datapoints = field_module.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_DATAPOINTS) + if not datapoints.isValid(): + return + ungrouped_group = field_module.findFieldByName(UNGROUPED_DATAPOINTS_GROUP_NAME) + if not ungrouped_group.isValid(): + or_group = field_module.createFieldGroup() + for group in get_group_list(field_module): + if not _is_internal_group(group): + or_group = field_module.createFieldOr(or_group, group) + + ungrouped_group = field_module.createFieldGroup() + ungrouped_group.setName(UNGROUPED_DATAPOINTS_GROUP_NAME) + ungrouped_group.setManaged(True) + ungrouped_nodeset_group = ungrouped_group.getOrCreateNodesetGroup(datapoints) + ungrouped_nodeset_group.addNodesConditional(field_module.createFieldNot(or_group)) + if diagnostics_level > 0: + print(f"Created a group '{UNGROUPED_DATAPOINTS_GROUP_NAME}' for ungrouped datapoints," + f" of size: {ungrouped_nodeset_group.getSize()}, from datapoints not already assigned into a group.") + + class Fitter: def __init__(self, zincModelFileName: str=None, zincDataFileName: str=None, region: Region=None): @@ -468,6 +509,7 @@ def _loadData(self): assert result == RESULT_OK, "Failed to load data file " + str(self._zincDataFileName) data_fieldmodule = self._rawDataRegion.getFieldmodule() with ChangeManager(data_fieldmodule): + _create_ungrouped_datapoint_group(data_fieldmodule, self.getDiagnosticLevel()) match_fitting_group_names(data_fieldmodule, self._fieldmodule, log_diagnostics=self.getDiagnosticLevel() > 0) copy_fitting_data(self._region, self._rawDataRegion) @@ -616,6 +658,7 @@ def assignDataWeights(self, fitterStepFit: FitterStepFit): # Future: divide by number of data points? coordinatesCount = self._modelCoordinatesField.getNumberOfComponents() with ChangeManager(self._fieldmodule): + fieldassignment = None for groupName in self._dataProjectionGroupNames: group = self._fieldmodule.findFieldByName(groupName).castGroup() if not group.isValid(): @@ -676,7 +719,8 @@ def assignDataWeights(self, fitterStepFit: FitterStepFit): result = fieldassignment.assign() if result != RESULT_OK: print('Incomplete assignment of marker data weight', result) - del fieldassignment + if fieldassignment is not None: + del fieldassignment def assignDeformationPenalties(self, fitterStepFit: FitterStepFit): """ @@ -697,7 +741,7 @@ def assignDeformationPenalties(self, fitterStepFit: FitterStepFit): curvatureComponents = coordinatesCount * meshDimension * meshDimension groups = [] # add None for default group - for group in (getGroupList(self._fieldmodule) + [None]): + for group in (get_group_list(self._fieldmodule) + [None]): if group: meshGroup = group.getMeshGroup(mesh) if (not meshGroup.isValid()) or (meshGroup.getSize() == 0): @@ -1327,24 +1371,24 @@ def getGroupDataProjectionNodesetGroup(self, group: FieldGroup): return dataNodesetGroup return None - def getGroupDataProjectionMeshGroup(self, group: FieldGroup, fitterStep: FitterStep): + def getGroupDataProjectionMeshGroup(self, group: FieldGroup, fitter_step: FitterStep): """ Get mesh group for 2D if not 1D mesh containing elements for projecting data in group, if any. If there is a subgroup set or inherited by the group, finds or builds on demand an equal or lower dimensional intersection mesh group. :group: Zinc annotation group for which to get projection mesh group. - :fitterStep: FitterStep to get config for, with optional subgroup to intersect with. + :fitter_step: FitterStep to get config for, with optional subgroup to intersect with. :return: MeshGroup or None if none or empty, findHighestDimension. The second argument is True if any of the mesh group doesn't have a child-parent[-grandparent] - map to highest dimension elements in the model or model fit group, and hence projections must be + map to the highest dimension elements in the model or model fit group, and hence projections must be re-found on the top-level mesh group. """ groupName = group.getName() - activeFitterStepConfig = self.getActiveFitterStepConfig(fitterStep) + activeFitterStepConfig = self.getActiveFitterStepConfig(fitter_step) subgroup = activeFitterStepConfig.getGroupProjectionSubgroup(groupName)[0] groupProjectionData = self._groupProjectionData.get(groupName) if groupProjectionData: - if (groupProjectionData[0] == subgroup): + if groupProjectionData[0] == subgroup: return groupProjectionData[1], groupProjectionData[2] returnMeshGroup = None findHighestDimension = False @@ -1364,7 +1408,7 @@ def getGroupDataProjectionMeshGroup(self, group: FieldGroup, fitterStep: FitterS subgroupMeshGroup = subgroup.getMeshGroup(mesh) if subgroupMeshGroup.isValid() and (subgroupMeshGroup.getSize() > 0): intersectionGroup = self._fieldmodule.createFieldGroup() - intersectionGroup.setName(groupName + " " + subgroup.getName()) # for debugging + intersectionGroup.setName(groupName + " & " + subgroup.getName()) # for debugging returnMeshGroup = intersectionGroup.createMeshGroup(mesh) returnMeshGroup.addElementsConditional(self._fieldmodule.createFieldAnd(group, subgroup)) if returnMeshGroup.getSize() > 0: @@ -1374,9 +1418,21 @@ def getGroupDataProjectionMeshGroup(self, group: FieldGroup, fitterStep: FitterS else: returnMeshGroup = meshGroup break + elif groupName == UNGROUPED_DATAPOINTS_GROUP_NAME: + meshGroup = group.getOrCreateMeshGroup(mesh) + el_iterator = mesh.createElementiterator() + element = el_iterator.next() + while element.isValid(): + # add all elements in mesh to ungrouped group + meshGroup.addElement(element) + element = el_iterator.next() + if meshGroup.isValid() and (meshGroup.getSize() > 0): + returnMeshGroup = meshGroup + break + if returnMeshGroup: # do any elements in returnMeshGroup NOT have self or an ancestor in modelFitMesh? - # If so, need to find highest dimension mesh location by a second FindMeshLocation field when projecting + # If so, need to find the highest dimension mesh location by a second FindMeshLocation field when projecting modelFitMesh = highestDimensionMesh if self._modelFitGroup: modelFitMesh = self._modelFitGroup.getMeshGroup(highestDimensionMesh) @@ -1411,13 +1467,13 @@ def calculateDataProjections(self, fitterStep: FitterStep): datapoints = self._fieldmodule.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_DATAPOINTS) fieldcache = self._fieldmodule.createFieldcache() - groups = getGroupList(self._fieldmodule) + groups = get_group_list(self._fieldmodule) for group in groups: if not group.isManaged(): continue # skip cmiss_selection, for example groupName = group.getName() dataGroup = self.getGroupDataProjectionNodesetGroup(group) - if not dataGroup: + if dataGroup is None: continue meshGroup, findHighestDimension = self.getGroupDataProjectionMeshGroup(group, fitterStep) if not meshGroup: