Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 21 additions & 7 deletions src/scaffoldfitter/fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,10 @@ def _loadData(self):
with ChangeManager(data_fieldmodule):
match_fitting_group_names(data_fieldmodule, self._fieldmodule,
log_diagnostics=self.getDiagnosticLevel() > 0)
copy_fitting_data(self._region, self._rawDataRegion)
try:
copy_fitting_data(self._region, self._rawDataRegion)
except:
self.print_log()
self._discoverDataCoordinatesField()
self._discoverMarkerGroup()

Expand Down Expand Up @@ -695,17 +698,21 @@ def assignDeformationPenalties(self, fitterStepFit: FitterStepFit):
coordinatesCount = self._modelCoordinatesField.getNumberOfComponents()
strainComponents = meshDimension * meshDimension
curvatureComponents = coordinatesCount * meshDimension * meshDimension
groups = []
groupInfoList = []
# add None for default group
for group in (getGroupList(self._fieldmodule) + [None]):
if group:
if not group.isManaged():
continue # unmanaged = internal, ignore
meshGroup = group.getMeshGroup(mesh)
if (not meshGroup.isValid()) or (meshGroup.getSize() == 0):
meshGroupSize = meshGroup.getSize()
if (not meshGroup.isValid()) or (meshGroupSize == 0):
continue
groupName = group.getName()
else:
meshGroup = None
groupName = None
meshGroupSize = mesh.getSize()
groupStrainPenalty, setLocally, inheritable = \
fitterStepFit.getGroupStrainPenalty(groupName, strainComponents)
groupStrainPenaltyNonZero = any((s > 0.0) for s in groupStrainPenalty)
Expand All @@ -714,8 +721,14 @@ def assignDeformationPenalties(self, fitterStepFit: FitterStepFit):
fitterStepFit.getGroupCurvaturePenalty(groupName, curvatureComponents)
groupCurvaturePenaltyNonZero = any((s > 0.0) for s in groupCurvaturePenalty)
groupCurvatureSet = setLocally or ((setLocally is False) and inheritable)
groups.append((group, groupName, meshGroup, groupStrainPenalty, groupStrainPenaltyNonZero, groupStrainSet,
groupCurvaturePenalty, groupCurvaturePenaltyNonZero, groupCurvatureSet))
groupInfo = (group, groupName, meshGroup, meshGroupSize, groupStrainPenalty, groupStrainPenaltyNonZero,
groupStrainSet, groupCurvaturePenalty, groupCurvaturePenaltyNonZero, groupCurvatureSet)
for index, tmpGroupInfo in enumerate(groupInfoList):
if meshGroupSize < tmpGroupInfo[3]:
groupInfoList.insert(index, groupInfo)
break
else:
groupInfoList.append(groupInfo)
with ChangeManager(self._fieldmodule):
self._deformActiveMeshGroup.removeAllElements()
self._strainActiveMeshGroup.removeAllElements()
Expand All @@ -729,8 +742,9 @@ def assignDeformationPenalties(self, fitterStepFit: FitterStepFit):
strainPenaltyNonZero = False
curvaturePenalty = None
curvaturePenaltyNonZero = False
for (group, groupName, meshGroup, groupStrainPenalty, groupStrainPenaltyNonZero, groupStrainSet,
groupCurvaturePenalty, groupCurvaturePenaltyNonZero, groupCurvatureSet) in groups:
for (group, groupName, meshGroup, meshGroupSize, groupStrainPenalty, groupStrainPenaltyNonZero,
groupStrainSet, groupCurvaturePenalty, groupCurvaturePenaltyNonZero, groupCurvatureSet) \
in groupInfoList:
if (not group) or meshGroup.containsElement(element):
if (not strainPenalty) and (groupStrainSet or (not group)):
strainPenalty = groupStrainPenalty
Expand Down
126 changes: 126 additions & 0 deletions tests/resources/group_test_line10.exf
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
EX Version: 3
Region: /
!#nodeset nodes
Define node template: node1
Shape. Dimension=0
#Fields=1
1) coordinates, coordinate, rectangular cartesian, real, #Components=3
x. #Values=2 (value,d/ds1)
y. #Values=2 (value,d/ds1)
z. #Values=2 (value,d/ds1)
Node template: node1
Node: 1
0.000000000000000e+00 1.000000000000000e+00
0.000000000000000e+00 0.000000000000000e+00
0.000000000000000e+00 0.000000000000000e+00
Node: 2
1.000000000000000e+00 1.000000000000000e+00
0.000000000000000e+00 0.000000000000000e+00
0.000000000000000e+00 0.000000000000000e+00
Node: 3
2.000000000000000e+00 1.000000000000000e+00
0.000000000000000e+00 0.000000000000000e+00
0.000000000000000e+00 0.000000000000000e+00
Node: 4
3.000000000000000e+00 1.000000000000000e+00
0.000000000000000e+00 0.000000000000000e+00
0.000000000000000e+00 0.000000000000000e+00
Node: 5
4.000000000000000e+00 1.000000000000000e+00
0.000000000000000e+00 0.000000000000000e+00
0.000000000000000e+00 0.000000000000000e+00
Node: 6
5.000000000000000e+00 1.000000000000000e+00
0.000000000000000e+00 0.000000000000000e+00
0.000000000000000e+00 0.000000000000000e+00
Node: 7
6.000000000000000e+00 1.000000000000000e+00
0.000000000000000e+00 0.000000000000000e+00
0.000000000000000e+00 0.000000000000000e+00
Node: 8
7.000000000000000e+00 1.000000000000000e+00
0.000000000000000e+00 0.000000000000000e+00
0.000000000000000e+00 0.000000000000000e+00
Node: 9
8.000000000000000e+00 1.000000000000000e+00
0.000000000000000e+00 0.000000000000000e+00
0.000000000000000e+00 0.000000000000000e+00
Node: 10
9.000000000000000e+00 1.000000000000000e+00
0.000000000000000e+00 0.000000000000000e+00
0.000000000000000e+00 0.000000000000000e+00
Node: 11
1.000000000000000e+01 1.000000000000000e+00
0.000000000000000e+00 0.000000000000000e+00
0.000000000000000e+00 0.000000000000000e+00
!#mesh mesh1d, dimension=1, nodeset=nodes
Define element template: element1
Shape. Dimension=1, line
#Scale factor sets=0
#Nodes=2
#Fields=1
1) coordinates, coordinate, rectangular cartesian, real, #Components=3
x. c.Hermite, no modify, standard node based.
#Nodes=2
1. #Values=2
Value labels: value d/ds1
2. #Values=2
Value labels: value d/ds1
y. c.Hermite, no modify, standard node based.
#Nodes=2
1. #Values=2
Value labels: value d/ds1
2. #Values=2
Value labels: value d/ds1
z. c.Hermite, no modify, standard node based.
#Nodes=2
1. #Values=2
Value labels: value d/ds1
2. #Values=2
Value labels: value d/ds1
Element template: element1
Element: 1
Nodes:
1 2
Element: 2
Nodes:
2 3
Element: 3
Nodes:
3 4
Element: 4
Nodes:
4 5
Element: 5
Nodes:
5 6
Element: 6
Nodes:
6 7
Element: 7
Nodes:
7 8
Element: 8
Nodes:
8 9
Element: 9
Nodes:
9 10
Element: 10
Nodes:
10 11
Group name: all
!#nodeset nodes
Node group:
1..11
!#mesh mesh1d, dimension=1, nodeset=nodes
Element group:
1..10
Group name: marker
Group name: part
!#nodeset nodes
Node group:
6..9
!#mesh mesh1d, dimension=1, nodeset=nodes
Element group:
6..8
106 changes: 100 additions & 6 deletions tests/test_general.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
import logging
import os
import sys
import unittest
from cmlibs.utils.zinc.field import createFieldMeshIntegral
from cmlibs.utils.zinc.field import (
create_field_mesh_integral, find_or_create_field_coordinates, find_or_create_field_group)
from cmlibs.utils.zinc.general import ChangeManager
from cmlibs.utils.zinc.region import copy_fitting_data, read_from_buffer, write_to_buffer
from cmlibs.zinc.context import Context
Expand All @@ -12,6 +9,12 @@
from scaffoldfitter.fitterstepalign import FitterStepAlign
from scaffoldfitter.fitterstepconfig import FitterStepConfig
from scaffoldfitter.fitterstepfit import FitterStepFit
import logging
import math
import os
import sys
import unittest


here = os.path.abspath(os.path.dirname(__file__))

Expand Down Expand Up @@ -74,7 +77,7 @@ def test_fit_1d_outliers(self):
# check number of active data points and length of fitted model
activeDataNodeset = fitter.getActiveDataNodesetGroup()
self.assertEqual(activeDataNodeset.getSize(), expectedActiveDataSize)
lengthField = createFieldMeshIntegral(coordinates, fitter.getMesh(1), number_of_points=4)
lengthField = create_field_mesh_integral(coordinates, fitter.getMesh(1), number_of_points=4)
self.assertTrue(lengthField.isValid())
fieldcache = fieldmodule.createFieldcache()
result, length = lengthField.evaluateReal(fieldcache, 1)
Expand Down Expand Up @@ -302,5 +305,96 @@ def test_fit_projection_subgroup(self):
self.assertEqual(minElementIdentifier, 102)
self.assertAlmostEqual(minJacobian, 0.8073661446227143, delta=TOL)

def test_group_settings(self):
"""
Test that curvature settings on parts of the model are appropriately used.
:return:
"""
zinc_model_file_name = os.path.join(here, "resources", "group_test_line10.exf")

context = Context("Scaffoldfitter test group settings")
region = context.getDefaultRegion()
fitter = Fitter(region=region)
fitter.setDiagnosticLevel(2)

region.readFile(zinc_model_file_name)
fieldmodule = region.getFieldmodule()
mesh1d = fieldmodule.findMeshByDimension(1)
self.assertEqual(mesh1d.getSize(), 10)
coordinates = fieldmodule.findFieldByName("coordinates").castFiniteElement()
self.assertEqual(coordinates.getNumberOfComponents(), 3)
fitter.setModelCoordinatesField(coordinates)
all_group = fieldmodule.findFieldByName("all").castGroup()
part_group = fieldmodule.findFieldByName("part").castGroup()
zero_fibres = fieldmodule.createFieldConstant([0.0, 0.0, 0.0])
zero_fibres.setName("zero fibres")
zero_fibres.setManaged(True)
fitter.setFibreField(zero_fibres)
fitter.defineCommonMeshFields()

data_region = region.createChild("raw_data")
make_test_group_settings_data(data_region.getFieldmodule())
copy_fitting_data(region, data_region)
fitter.setDataCoordinatesField(coordinates)
fitter.defineDataProjectionFields()
fitter.initializeFit()

fit = FitterStepFit()
fitter.addFitterStep(fit)
fit.setGroupCurvaturePenalty(None, [0.01])
fit.setGroupCurvaturePenalty("part", [10.0])
fit.run()

# check correct curvature penalties are applied per element: the smaller "part" group penalties take precedence
fieldcache = fieldmodule.createFieldcache()
curvature_penalty = fieldmodule.findFieldByName("curvature_penalty")
self.assertTrue(curvature_penalty.isValid())
expected_curvature_penalties = [0.01, 0.01, 0.01, 0.01, 0.01, 10.0, 10.0, 10.0, 0.01, 0.01]
expected_coordinates = [
[0.5, 0.4780510817192079, 0.061844877202773374],
[1.5, -0.4798844161908709, 0.17041374538128035],
[2.5, 0.4698595597425159, 0.23730376789946722],
[3.5, -0.5124604803852486, 0.24629542592567483],
[4.5, 0.3546050329712843, 0.19562732940104102],
[5.5, -0.08683610995351383, 0.09241387186720304],
[6.5, 0.033064431519171385, -0.026357114228666476],
[7.5, -0.08741464996124625, -0.1395236301319323],
[8.5, 0.35829911353720506, -0.22552239670783675],
[9.5, -0.49872237642533174, -0.2500220113212905]]
for e in range(10):
element = mesh1d.findElementByIdentifier(e + 1)
fieldcache.setMeshLocation(element, [0.5])
result, cp = curvature_penalty.evaluateReal(fieldcache, 3)
self.assertEqual(result, RESULT_OK)
result, x = coordinates.evaluateReal(fieldcache, 3)
self.assertEqual(result, RESULT_OK)
for c in range(3):
self.assertAlmostEqual(cp[c], expected_curvature_penalties[e], delta=1.0E-8)
self.assertAlmostEqual(x[c], expected_coordinates[e][c], delta=1.0E-8)


def make_test_group_settings_data(fieldmodule):
"""
Make sinusoidal data centred on the x-axis from 0.0 to 10.0.
:param fieldmodule: Field module to create node points in, in group 'all'.
"""
with ChangeManager(fieldmodule):
coordinates = find_or_create_field_coordinates(fieldmodule)
nodes = fieldmodule.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_NODES)
all_group = find_or_create_field_group(fieldmodule, 'all')
all_nodes = all_group.createNodesetGroup(nodes)

nodetemplate = nodes.createNodetemplate()
nodetemplate.defineField(coordinates)
fieldcache = fieldmodule.createFieldcache()
for n in range(1001):
x = 0.01 * n
y = 0.5 * math.sin(x * math.pi)
z = 0.25 * math.sin(0.5 * x)
node = all_nodes.createNode(n + 1, nodetemplate)
fieldcache.setNode(node)
coordinates.assignReal(fieldcache, [x, y, z])


if __name__ == "__main__":
unittest.main()