diff --git a/src/scaffoldfitter/fitter.py b/src/scaffoldfitter/fitter.py index 4641275..9f7d312 100644 --- a/src/scaffoldfitter/fitter.py +++ b/src/scaffoldfitter/fitter.py @@ -470,7 +470,10 @@ def _loadData(self): with ChangeManager(data_fieldmodule): match_fitting_group_names(data_fieldmodule, self._fieldmodule, log_diagnostics=self.getDiagnosticLevel() > 0) - copy_fitting_data(self._region, self._rawDataRegion) + try: + copy_fitting_data(self._region, self._rawDataRegion) + except: + self.print_log() self._discoverDataCoordinatesField() self._discoverMarkerGroup() @@ -695,17 +698,21 @@ def assignDeformationPenalties(self, fitterStepFit: FitterStepFit): coordinatesCount = self._modelCoordinatesField.getNumberOfComponents() strainComponents = meshDimension * meshDimension curvatureComponents = coordinatesCount * meshDimension * meshDimension - groups = [] + groupInfoList = [] # add None for default group for group in (getGroupList(self._fieldmodule) + [None]): if group: + if not group.isManaged(): + continue # unmanaged = internal, ignore meshGroup = group.getMeshGroup(mesh) - if (not meshGroup.isValid()) or (meshGroup.getSize() == 0): + meshGroupSize = meshGroup.getSize() + if (not meshGroup.isValid()) or (meshGroupSize == 0): continue groupName = group.getName() else: meshGroup = None groupName = None + meshGroupSize = mesh.getSize() groupStrainPenalty, setLocally, inheritable = \ fitterStepFit.getGroupStrainPenalty(groupName, strainComponents) groupStrainPenaltyNonZero = any((s > 0.0) for s in groupStrainPenalty) @@ -714,8 +721,14 @@ def assignDeformationPenalties(self, fitterStepFit: FitterStepFit): fitterStepFit.getGroupCurvaturePenalty(groupName, curvatureComponents) groupCurvaturePenaltyNonZero = any((s > 0.0) for s in groupCurvaturePenalty) groupCurvatureSet = setLocally or ((setLocally is False) and inheritable) - groups.append((group, groupName, meshGroup, groupStrainPenalty, groupStrainPenaltyNonZero, groupStrainSet, - groupCurvaturePenalty, groupCurvaturePenaltyNonZero, groupCurvatureSet)) + groupInfo = (group, groupName, meshGroup, meshGroupSize, groupStrainPenalty, groupStrainPenaltyNonZero, + groupStrainSet, groupCurvaturePenalty, groupCurvaturePenaltyNonZero, groupCurvatureSet) + for index, tmpGroupInfo in enumerate(groupInfoList): + if meshGroupSize < tmpGroupInfo[3]: + groupInfoList.insert(index, groupInfo) + break + else: + groupInfoList.append(groupInfo) with ChangeManager(self._fieldmodule): self._deformActiveMeshGroup.removeAllElements() self._strainActiveMeshGroup.removeAllElements() @@ -729,8 +742,9 @@ def assignDeformationPenalties(self, fitterStepFit: FitterStepFit): strainPenaltyNonZero = False curvaturePenalty = None curvaturePenaltyNonZero = False - for (group, groupName, meshGroup, groupStrainPenalty, groupStrainPenaltyNonZero, groupStrainSet, - groupCurvaturePenalty, groupCurvaturePenaltyNonZero, groupCurvatureSet) in groups: + for (group, groupName, meshGroup, meshGroupSize, groupStrainPenalty, groupStrainPenaltyNonZero, + groupStrainSet, groupCurvaturePenalty, groupCurvaturePenaltyNonZero, groupCurvatureSet) \ + in groupInfoList: if (not group) or meshGroup.containsElement(element): if (not strainPenalty) and (groupStrainSet or (not group)): strainPenalty = groupStrainPenalty diff --git a/tests/resources/group_test_line10.exf b/tests/resources/group_test_line10.exf new file mode 100644 index 0000000..5238dfc --- /dev/null +++ b/tests/resources/group_test_line10.exf @@ -0,0 +1,126 @@ +EX Version: 3 +Region: / +!#nodeset nodes +Define node template: node1 +Shape. Dimension=0 +#Fields=1 +1) coordinates, coordinate, rectangular cartesian, real, #Components=3 + x. #Values=2 (value,d/ds1) + y. #Values=2 (value,d/ds1) + z. #Values=2 (value,d/ds1) +Node template: node1 +Node: 1 + 0.000000000000000e+00 1.000000000000000e+00 + 0.000000000000000e+00 0.000000000000000e+00 + 0.000000000000000e+00 0.000000000000000e+00 +Node: 2 + 1.000000000000000e+00 1.000000000000000e+00 + 0.000000000000000e+00 0.000000000000000e+00 + 0.000000000000000e+00 0.000000000000000e+00 +Node: 3 + 2.000000000000000e+00 1.000000000000000e+00 + 0.000000000000000e+00 0.000000000000000e+00 + 0.000000000000000e+00 0.000000000000000e+00 +Node: 4 + 3.000000000000000e+00 1.000000000000000e+00 + 0.000000000000000e+00 0.000000000000000e+00 + 0.000000000000000e+00 0.000000000000000e+00 +Node: 5 + 4.000000000000000e+00 1.000000000000000e+00 + 0.000000000000000e+00 0.000000000000000e+00 + 0.000000000000000e+00 0.000000000000000e+00 +Node: 6 + 5.000000000000000e+00 1.000000000000000e+00 + 0.000000000000000e+00 0.000000000000000e+00 + 0.000000000000000e+00 0.000000000000000e+00 +Node: 7 + 6.000000000000000e+00 1.000000000000000e+00 + 0.000000000000000e+00 0.000000000000000e+00 + 0.000000000000000e+00 0.000000000000000e+00 +Node: 8 + 7.000000000000000e+00 1.000000000000000e+00 + 0.000000000000000e+00 0.000000000000000e+00 + 0.000000000000000e+00 0.000000000000000e+00 +Node: 9 + 8.000000000000000e+00 1.000000000000000e+00 + 0.000000000000000e+00 0.000000000000000e+00 + 0.000000000000000e+00 0.000000000000000e+00 +Node: 10 + 9.000000000000000e+00 1.000000000000000e+00 + 0.000000000000000e+00 0.000000000000000e+00 + 0.000000000000000e+00 0.000000000000000e+00 +Node: 11 + 1.000000000000000e+01 1.000000000000000e+00 + 0.000000000000000e+00 0.000000000000000e+00 + 0.000000000000000e+00 0.000000000000000e+00 +!#mesh mesh1d, dimension=1, nodeset=nodes +Define element template: element1 +Shape. Dimension=1, line +#Scale factor sets=0 +#Nodes=2 +#Fields=1 +1) coordinates, coordinate, rectangular cartesian, real, #Components=3 + x. c.Hermite, no modify, standard node based. + #Nodes=2 + 1. #Values=2 + Value labels: value d/ds1 + 2. #Values=2 + Value labels: value d/ds1 + y. c.Hermite, no modify, standard node based. + #Nodes=2 + 1. #Values=2 + Value labels: value d/ds1 + 2. #Values=2 + Value labels: value d/ds1 + z. c.Hermite, no modify, standard node based. + #Nodes=2 + 1. #Values=2 + Value labels: value d/ds1 + 2. #Values=2 + Value labels: value d/ds1 +Element template: element1 +Element: 1 + Nodes: + 1 2 +Element: 2 + Nodes: + 2 3 +Element: 3 + Nodes: + 3 4 +Element: 4 + Nodes: + 4 5 +Element: 5 + Nodes: + 5 6 +Element: 6 + Nodes: + 6 7 +Element: 7 + Nodes: + 7 8 +Element: 8 + Nodes: + 8 9 +Element: 9 + Nodes: + 9 10 +Element: 10 + Nodes: + 10 11 +Group name: all +!#nodeset nodes +Node group: +1..11 +!#mesh mesh1d, dimension=1, nodeset=nodes +Element group: +1..10 +Group name: marker +Group name: part +!#nodeset nodes +Node group: +6..9 +!#mesh mesh1d, dimension=1, nodeset=nodes +Element group: +6..8 diff --git a/tests/test_general.py b/tests/test_general.py index 66d72cf..ca36c7a 100644 --- a/tests/test_general.py +++ b/tests/test_general.py @@ -1,8 +1,5 @@ -import logging -import os -import sys -import unittest -from cmlibs.utils.zinc.field import createFieldMeshIntegral +from cmlibs.utils.zinc.field import ( + create_field_mesh_integral, find_or_create_field_coordinates, find_or_create_field_group) from cmlibs.utils.zinc.general import ChangeManager from cmlibs.utils.zinc.region import copy_fitting_data, read_from_buffer, write_to_buffer from cmlibs.zinc.context import Context @@ -12,6 +9,12 @@ from scaffoldfitter.fitterstepalign import FitterStepAlign from scaffoldfitter.fitterstepconfig import FitterStepConfig from scaffoldfitter.fitterstepfit import FitterStepFit +import logging +import math +import os +import sys +import unittest + here = os.path.abspath(os.path.dirname(__file__)) @@ -74,7 +77,7 @@ def test_fit_1d_outliers(self): # check number of active data points and length of fitted model activeDataNodeset = fitter.getActiveDataNodesetGroup() self.assertEqual(activeDataNodeset.getSize(), expectedActiveDataSize) - lengthField = createFieldMeshIntegral(coordinates, fitter.getMesh(1), number_of_points=4) + lengthField = create_field_mesh_integral(coordinates, fitter.getMesh(1), number_of_points=4) self.assertTrue(lengthField.isValid()) fieldcache = fieldmodule.createFieldcache() result, length = lengthField.evaluateReal(fieldcache, 1) @@ -302,5 +305,96 @@ def test_fit_projection_subgroup(self): self.assertEqual(minElementIdentifier, 102) self.assertAlmostEqual(minJacobian, 0.8073661446227143, delta=TOL) + def test_group_settings(self): + """ + Test that curvature settings on parts of the model are appropriately used. + :return: + """ + zinc_model_file_name = os.path.join(here, "resources", "group_test_line10.exf") + + context = Context("Scaffoldfitter test group settings") + region = context.getDefaultRegion() + fitter = Fitter(region=region) + fitter.setDiagnosticLevel(2) + + region.readFile(zinc_model_file_name) + fieldmodule = region.getFieldmodule() + mesh1d = fieldmodule.findMeshByDimension(1) + self.assertEqual(mesh1d.getSize(), 10) + coordinates = fieldmodule.findFieldByName("coordinates").castFiniteElement() + self.assertEqual(coordinates.getNumberOfComponents(), 3) + fitter.setModelCoordinatesField(coordinates) + all_group = fieldmodule.findFieldByName("all").castGroup() + part_group = fieldmodule.findFieldByName("part").castGroup() + zero_fibres = fieldmodule.createFieldConstant([0.0, 0.0, 0.0]) + zero_fibres.setName("zero fibres") + zero_fibres.setManaged(True) + fitter.setFibreField(zero_fibres) + fitter.defineCommonMeshFields() + + data_region = region.createChild("raw_data") + make_test_group_settings_data(data_region.getFieldmodule()) + copy_fitting_data(region, data_region) + fitter.setDataCoordinatesField(coordinates) + fitter.defineDataProjectionFields() + fitter.initializeFit() + + fit = FitterStepFit() + fitter.addFitterStep(fit) + fit.setGroupCurvaturePenalty(None, [0.01]) + fit.setGroupCurvaturePenalty("part", [10.0]) + fit.run() + + # check correct curvature penalties are applied per element: the smaller "part" group penalties take precedence + fieldcache = fieldmodule.createFieldcache() + curvature_penalty = fieldmodule.findFieldByName("curvature_penalty") + self.assertTrue(curvature_penalty.isValid()) + expected_curvature_penalties = [0.01, 0.01, 0.01, 0.01, 0.01, 10.0, 10.0, 10.0, 0.01, 0.01] + expected_coordinates = [ + [0.5, 0.4780510817192079, 0.061844877202773374], + [1.5, -0.4798844161908709, 0.17041374538128035], + [2.5, 0.4698595597425159, 0.23730376789946722], + [3.5, -0.5124604803852486, 0.24629542592567483], + [4.5, 0.3546050329712843, 0.19562732940104102], + [5.5, -0.08683610995351383, 0.09241387186720304], + [6.5, 0.033064431519171385, -0.026357114228666476], + [7.5, -0.08741464996124625, -0.1395236301319323], + [8.5, 0.35829911353720506, -0.22552239670783675], + [9.5, -0.49872237642533174, -0.2500220113212905]] + for e in range(10): + element = mesh1d.findElementByIdentifier(e + 1) + fieldcache.setMeshLocation(element, [0.5]) + result, cp = curvature_penalty.evaluateReal(fieldcache, 3) + self.assertEqual(result, RESULT_OK) + result, x = coordinates.evaluateReal(fieldcache, 3) + self.assertEqual(result, RESULT_OK) + for c in range(3): + self.assertAlmostEqual(cp[c], expected_curvature_penalties[e], delta=1.0E-8) + self.assertAlmostEqual(x[c], expected_coordinates[e][c], delta=1.0E-8) + + +def make_test_group_settings_data(fieldmodule): + """ + Make sinusoidal data centred on the x-axis from 0.0 to 10.0. + :param fieldmodule: Field module to create node points in, in group 'all'. + """ + with ChangeManager(fieldmodule): + coordinates = find_or_create_field_coordinates(fieldmodule) + nodes = fieldmodule.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_NODES) + all_group = find_or_create_field_group(fieldmodule, 'all') + all_nodes = all_group.createNodesetGroup(nodes) + + nodetemplate = nodes.createNodetemplate() + nodetemplate.defineField(coordinates) + fieldcache = fieldmodule.createFieldcache() + for n in range(1001): + x = 0.01 * n + y = 0.5 * math.sin(x * math.pi) + z = 0.25 * math.sin(0.5 * x) + node = all_nodes.createNode(n + 1, nodetemplate) + fieldcache.setNode(node) + coordinates.assignReal(fieldcache, [x, y, z]) + + if __name__ == "__main__": unittest.main()