diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 18b2061e1..013ba7805 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -40,3 +40,6 @@ jobs: run: DATA_REPO_GIT="" python -m pytest --cov=FIAT/ test/FIAT - name: Test FInAT run: DATA_REPO_GIT="" python -m pytest --cov=finat/ --cov=gem/ test/finat + - name: Test FInAT with FUSE + run: | + FIREDRAKE_USE_FUSE=1 DATA_REPO_GIT="" python -m pytest test/finat diff --git a/FIAT/discontinuous_pc.py b/FIAT/discontinuous_pc.py index dbb566906..8973a58f6 100644 --- a/FIAT/discontinuous_pc.py +++ b/FIAT/discontinuous_pc.py @@ -7,29 +7,17 @@ # Modified by David A. Ham (david.ham@imperial.ac.uk), 2018 from FIAT import finite_element, polynomial_set, dual_set, functional -from FIAT.reference_element import (Point, - DefaultLine, - UFCInterval, - UFCQuadrilateral, - UFCHexahedron, - UFCTriangle, - UFCTetrahedron, - make_affine_mapping, - flatten_reference_cube) +from FIAT.reference_element import (make_affine_mapping, + flatten_reference_cube, + cell_to_simplex) from FIAT.P0 import P0Dual import numpy as np -hypercube_simplex_map = {Point(): Point(), - DefaultLine(): DefaultLine(), - UFCInterval(): UFCInterval(), - UFCQuadrilateral(): UFCTriangle(), - UFCHexahedron(): UFCTetrahedron()} - class DPC0(finite_element.CiarletElement): def __init__(self, ref_el): flat_el = flatten_reference_cube(ref_el) - poly_set = polynomial_set.ONPolynomialSet(hypercube_simplex_map[flat_el], 0) + poly_set = polynomial_set.ONPolynomialSet(cell_to_simplex(flat_el), 0) dual = P0Dual(ref_el) # Implement entity_permutations when we handle that for HigherOrderDPC. # Currently, orientation_tuples in P0Dual(ref_el).entity_permutations @@ -58,7 +46,7 @@ def __init__(self, ref_el, flat_el, degree): # Change coordinates here. # Vertices of the simplex corresponding to the reference element. - v_simplex = hypercube_simplex_map[flat_el].get_vertices() + v_simplex = cell_to_simplex(flat_el).get_vertices() # Vertices of the reference element. v_hypercube = flat_el.get_vertices() # For the mapping, first two vertices are unchanged in all dimensions. @@ -74,12 +62,12 @@ def __init__(self, ref_el, flat_el, degree): # make nodes by getting points # need to do this dimension-by-dimension, facet-by-facet - top = hypercube_simplex_map[flat_el].get_topology() + top = cell_to_simplex(flat_el).get_topology() cur = 0 for dim in sorted(top): for entity in sorted(top[dim]): - pts_cur = hypercube_simplex_map[flat_el].make_points(dim, entity, degree) + pts_cur = cell_to_simplex(flat_el).make_points(dim, entity, degree) pts_cur = [tuple(np.matmul(A, np.array(x)) + b) for x in pts_cur] nodes_cur = [functional.PointEvaluation(flat_el, x) for x in pts_cur] @@ -102,7 +90,7 @@ class HigherOrderDPC(finite_element.CiarletElement): def __init__(self, ref_el, degree): flat_el = flatten_reference_cube(ref_el) - poly_set = polynomial_set.ONPolynomialSet(hypercube_simplex_map[flat_el], degree) + poly_set = polynomial_set.ONPolynomialSet(cell_to_simplex(flat_el), degree) dual = DPCDualSet(ref_el, flat_el, degree) formdegree = flat_el.get_spatial_dimension() # n-form super().__init__(poly_set=poly_set, diff --git a/FIAT/lagrange.py b/FIAT/lagrange.py index eead0cfa6..dab845446 100644 --- a/FIAT/lagrange.py +++ b/FIAT/lagrange.py @@ -40,7 +40,7 @@ def __init__(self, ref_el, degree, point_variant="equispaced", sort_entities=Fal entities = [(dim, entity) for dim in sorted(top) for entity in sorted(top[dim])] if sort_entities: # sort the entities by support vertex ids - support = [top[dim][entity] for dim, entity in entities] + support = [sorted(top[dim][entity]) for dim, entity in entities] entities = [entity for verts, entity in sorted(zip(support, entities))] # make nodes by getting points diff --git a/FIAT/nedelec.py b/FIAT/nedelec.py index 192507116..9b338f80c 100644 --- a/FIAT/nedelec.py +++ b/FIAT/nedelec.py @@ -22,7 +22,7 @@ def NedelecSpace2D(ref_el, degree): raise Exception("NedelecSpace2D requires 2d reference element") k = degree - 1 - vec_Pkp1 = polynomial_set.ONPolynomialSet(ref_el, k + 1, (sd,)) + vec_Pkp1 = polynomial_set.ONPolynomialSet(ref_el, k + 1, (sd,), scale="orthonormal") dimPkp1 = expansions.polynomial_dimension(ref_el, k + 1) dimPk = expansions.polynomial_dimension(ref_el, k) @@ -32,7 +32,7 @@ def NedelecSpace2D(ref_el, degree): for i in range(sd)))) vec_Pk_from_Pkp1 = vec_Pkp1.take(vec_Pk_indices) - Pkp1 = polynomial_set.ONPolynomialSet(ref_el, k + 1) + Pkp1 = polynomial_set.ONPolynomialSet(ref_el, k + 1, scale="orthonormal") PkH = Pkp1.take(list(range(dimPkm1, dimPk))) Q = create_quadrature(ref_el, 2 * (k + 1)) diff --git a/FIAT/quadrature.py b/FIAT/quadrature.py index 51e4f45e2..a0837e25b 100644 --- a/FIAT/quadrature.py +++ b/FIAT/quadrature.py @@ -196,10 +196,12 @@ class CollapsedQuadratureTetrahedronRule(CollapsedQuadratureSimplexRule): class FacetQuadratureRule(QuadratureRule): """A quadrature rule on a facet mapped from a reference quadrature rule. """ - def __init__(self, ref_el, entity_dim, entity_id, Q_ref): + def __init__(self, ref_el, entity_dim, entity_id, Q_ref, facet_orientation=None): # Construct the facet of interest facet = ref_el.construct_subelement(entity_dim) facet_topology = ref_el.get_topology()[entity_dim][entity_id] + if facet_orientation: + facet_topology = tuple(facet_orientation.permute(list(facet_topology))) facet.vertices = ref_el.get_vertices_of_subcomplex(facet_topology) # Map reference points and weights on the appropriate facet diff --git a/FIAT/raviart_thomas.py b/FIAT/raviart_thomas.py index 80ac13d9a..e0ce5de4a 100644 --- a/FIAT/raviart_thomas.py +++ b/FIAT/raviart_thomas.py @@ -20,7 +20,7 @@ def RTSpace(ref_el, degree): sd = ref_el.get_spatial_dimension() k = degree - 1 - vec_Pkp1 = polynomial_set.ONPolynomialSet(ref_el, k + 1, (sd,)) + vec_Pkp1 = polynomial_set.ONPolynomialSet(ref_el, k + 1, (sd,), scale="orthonormal") dimPkp1 = expansions.polynomial_dimension(ref_el, k + 1) dimPk = expansions.polynomial_dimension(ref_el, k) @@ -30,7 +30,7 @@ def RTSpace(ref_el, degree): for i in range(sd)))) vec_Pk_from_Pkp1 = vec_Pkp1.take(vec_Pk_indices) - Pkp1 = polynomial_set.ONPolynomialSet(ref_el, k + 1) + Pkp1 = polynomial_set.ONPolynomialSet(ref_el, k + 1, scale="orthonormal") PkH = Pkp1.take(list(range(dimPkm1, dimPk))) Q = create_quadrature(ref_el, 2 * (k + 1)) diff --git a/FIAT/reference_element.py b/FIAT/reference_element.py index 87794b506..321b3cad1 100644 --- a/FIAT/reference_element.py +++ b/FIAT/reference_element.py @@ -133,7 +133,7 @@ class Cell: """Abstract class for a reference cell. Provides accessors for geometry (vertex coordinates) as well as topology (orderings of vertices that make up edges, faces, etc.""" - def __init__(self, shape, vertices, topology): + def __init__(self, shape, vertices, topology, sub_entities=None): """The constructor takes a shape code, the physical vertices expressed as a list of tuples of numbers, and the topology of a cell. @@ -145,24 +145,42 @@ def __init__(self, shape, vertices, topology): self.vertices = vertices self.topology = topology - # Given the topology, work out for each entity in the cell, - # which other entities it contains. - self.sub_entities = {} - for dim, entities in topology.items(): - self.sub_entities[dim] = {} - - for e, v in entities.items(): - vertices = frozenset(v) - sub_entities = [] - - for dim_, entities_ in topology.items(): - for e_, vertices_ in entities_.items(): - if vertices.issuperset(vertices_): - sub_entities.append((dim_, e_)) - - # Sort for the sake of determinism and by UFC conventions - self.sub_entities[dim][e] = sorted(sub_entities) - + if sub_entities: + self.sub_entities = sub_entities + else: + # If sub entity list not provided + # Given the topology, work out for each entity in the cell, + # which other entities it contains. + self.sub_entities = {} + self.sub_entities_old = {} + for dim, entities in topology.items(): + self.sub_entities[dim] = {} + self.sub_entities_old[dim] = {} + + for e, v in entities.items(): + vertices = frozenset(v) + sub_entities = [] + sub_entities_old = [] + for dim_, entities_ in topology.items(): + for e_, vertices_ in entities_.items(): + if vertices.issuperset(vertices_): + sub_entities_old.append((dim_, e_)) + + # in order to maintain ordering, extract subentities from vertex numbering + entities_of_dim_ = list(entities_.values()) + + from itertools import permutations + # generate all possible sub entities + sub_list = permutations(v, len(entities_of_dim_[0])) + for s in sub_list: + # add the sub entities in the same order as in topology + for i, val in entities_.items(): + if set(s) == set(val) and (dim_, i) not in sub_entities: + sub_entities.append((dim_, i)) + + self.sub_entities[dim][e] = list(sub_entities) + self.sub_entities_old[dim][e] = list(sub_entities_old) + self.sub_entities = self.sub_entities_old # Build super-entity dictionary by inverting the sub-entity dictionary self.super_entities = {dim: {entity: [] for entity in topology[dim]} for dim in topology} for dim0 in topology: @@ -183,7 +201,6 @@ def __init__(self, shape, vertices, topology): neighbors = children if dim1 < dim0 else parents d01_entities = tuple(e for d, e in neighbors if d == dim1) self.connectivity[(dim0, dim1)].append(d01_entities) - # Dictionary with derived cells self._split_cache = {} @@ -387,14 +404,14 @@ class SimplicialComplex(Cell): This consists of list of vertex locations and a topology map defining facets. """ - def __init__(self, shape, vertices, topology): + def __init__(self, shape, vertices, topology, sub_ents=None): # Make sure that every facet has the right number of vertices to be # a simplex. for dim in topology: for entity in topology[dim]: assert len(topology[dim][entity]) == dim + 1 - super().__init__(shape, vertices, topology) + super().__init__(shape, vertices, topology, sub_ents) def compute_normal(self, facet_i, cell=None): """Returns the unit normal vector to facet i of codimension 1.""" @@ -528,7 +545,8 @@ def make_points(self, dim, entity_id, order, variant=None, interior=1): facet of dimension dim. Order indicates how many points to include in each direction.""" if dim == 0: - return (self.get_vertices()[entity_id], ) + return (self.get_vertices()[self.get_topology()[dim][entity_id][0]],) + # return (self.get_vertices()[entity_id], ) elif 0 < dim <= self.get_spatial_dimension(): entity_verts = \ self.get_vertices_of_subcomplex( @@ -1144,7 +1162,7 @@ def compute_normal(self, i): class TensorProductCell(Cell): """A cell that is the product of FIAT cells.""" - def __init__(self, *cells): + def __init__(self, *cells, sub_entities=None): # Vertices vertices = tuple(tuple(chain(*coords)) for coords in product(*[cell.get_vertices() @@ -1167,7 +1185,7 @@ def __init__(self, *cells): topology[dim] = dict(enumerate(topology[dim][key] for key in sorted(topology[dim]))) - super().__init__(TENSORPRODUCT, vertices, topology) + super().__init__(TENSORPRODUCT, vertices, topology, sub_entities) self.cells = tuple(cells) def __repr__(self): @@ -1317,10 +1335,14 @@ def compare(self, op, other): This is done dimension by dimension.""" if hasattr(other, "product"): other = other.product - if isinstance(other, type(self)): + if isinstance(other, TensorProductCell): return all(op(a, b) for a, b in zip(self.cells, other.cells)) else: - return op(self, other) + if op == operator.gt or op == operator.lt or operator.ne: + return not op(other, self) + if op == operator.ge or op == operator.le: + return not op(other, self) or operator.eq(self, other) + raise ValueError("Unknown operator in cell comparison") def __gt__(self, other): return self.compare(operator.gt, other) @@ -1410,7 +1432,7 @@ def is_macrocell(self): class Hypercube(Cell): """Abstract class for a reference hypercube""" - def __init__(self, dimension, product): + def __init__(self, dimension, product, sub_entities=None): self.dimension = dimension self.shape = hypercube_shapes[dimension] @@ -1418,7 +1440,7 @@ def __init__(self, dimension, product): verts = product.get_vertices() topology = flatten_entities(pt) - super().__init__(self.shape, verts, topology) + super().__init__(self.shape, verts, topology, sub_entities) self.product = product self.unflattening_map = compute_unflattening_map(pt) @@ -1845,3 +1867,10 @@ def max_complex(complexes): return max_cell else: raise ValueError("Cannot find the maximal complex") + + +def cell_to_simplex(cell): + if cell.is_simplex(): + return cell + else: + return ufc_simplex(cell.get_dimension()) diff --git a/finat/__init__.py b/finat/__init__.py index e531ea695..b9d49cdb3 100644 --- a/finat/__init__.py +++ b/finat/__init__.py @@ -8,7 +8,8 @@ Lagrange, Real, Serendipity, # noqa: F401 TrimmedSerendipityCurl, TrimmedSerendipityDiv, # noqa: F401 TrimmedSerendipityEdge, TrimmedSerendipityFace, # noqa: F401 - Nedelec, NedelecSecondKind, RaviartThomas, Regge) # noqa: F401 + Nedelec, NedelecSecondKind, RaviartThomas, Regge, # noqa: F401 + FuseElement) # noqa: F401 from .spectral import (GaussLobattoLegendre, GaussLegendre, KongMulderVeldhuizen, # noqa: F401 Legendre, IntegratedLegendre, # noqa: F401 FDMLagrange, FDMQuadrature, FDMDiscontinuousLagrange, # noqa: F401 diff --git a/finat/cube.py b/finat/cube.py index b1517e03f..7c85947ce 100644 --- a/finat/cube.py +++ b/finat/cube.py @@ -1,12 +1,11 @@ from __future__ import absolute_import, division, print_function -from FIAT.reference_element import (UFCHexahedron, UFCQuadrilateral, - compute_unflattening_map, flatten_entities, +from FIAT.reference_element import (compute_unflattening_map, flatten_entities, flatten_permutations) from FIAT.tensor_product import FlattenedDimensions as FIAT_FlattenedDimensions from gem.utils import cached_property - from finat.finiteelementbase import FiniteElementBase +from finat.element_factory import as_fiat_cell class FlattenedDimensions(FiniteElementBase): @@ -23,9 +22,9 @@ def __init__(self, element): def cell(self): dim = self.product.cell.get_spatial_dimension() if dim == 2: - return UFCQuadrilateral() + return as_fiat_cell("quadrilateral") elif dim == 3: - return UFCHexahedron() + return as_fiat_cell("hexahedron") else: raise NotImplementedError("Cannot guess cell for spatial dimension %s" % dim) diff --git a/finat/element_factory.py b/finat/element_factory.py index b688ba09e..3ca2ed2aa 100644 --- a/finat/element_factory.py +++ b/finat/element_factory.py @@ -27,6 +27,7 @@ import ufl from FIAT import ufc_cell +from FIAT.reference_element import TensorProductCell __all__ = ("as_fiat_cell", "create_base_element", "create_element", "supported_elements") @@ -110,9 +111,18 @@ def as_fiat_cell(cell): """Convert a ufl cell to a FIAT cell. :arg cell: the :class:`ufl.Cell` to convert.""" + if isinstance(cell, str): + cell = finat.ufl.as_cell(cell) if not isinstance(cell, ufl.AbstractCell): raise ValueError("Expecting a UFL Cell") - return ufc_cell(cell) + if isinstance(cell, ufl.TensorProductCell) and any([hasattr(c, "to_fiat") for c in cell._cells]): + if not all([hasattr(c, "to_fiat") for c in cell._cells]): + raise NotImplementedError("FUSE defined cells cannot be tensor producted with FIAT defined cells") + return TensorProductCell(*[c.to_fiat() for c in cell._cells]) + try: + return cell.to_fiat() + except AttributeError: + return ufc_cell(cell) @singledispatch @@ -320,8 +330,17 @@ def convert_restrictedelement(element, **kwargs): return finat.RestrictedElement(finat_elem, element.restriction_domain()), deps -hexahedron_tpc = ufl.TensorProductCell(ufl.interval, ufl.interval, ufl.interval) -quadrilateral_tpc = ufl.TensorProductCell(ufl.interval, ufl.interval) +@convert.register(finat.ufl.FuseElement) +def convert_fuse_element(element, **kwargs): + if element.triple.flat: + new_elem = element.triple.unflatten() + finat_elem, deps = _create_element(new_elem.to_ufl(), **kwargs) + return finat.FlattenedDimensions(finat_elem), deps + return finat.fiat_elements.FuseElement(element.triple), set() + + +hexahedron_tpc = ufl.TensorProductCell(finat.ufl.as_cell("interval"), finat.ufl.as_cell("interval"), finat.ufl.as_cell("interval")) +quadrilateral_tpc = ufl.TensorProductCell(finat.ufl.as_cell("interval"), finat.ufl.as_cell("interval")) _cache = weakref.WeakKeyDictionary() diff --git a/finat/fiat_elements.py b/finat/fiat_elements.py index e78d41bc0..a9312fe57 100644 --- a/finat/fiat_elements.py +++ b/finat/fiat_elements.py @@ -467,3 +467,8 @@ def __init__(self, cell, degree, **kwargs): class NedelecSecondKind(VectorFiatElement): def __init__(self, cell, degree, **kwargs): super().__init__(FIAT.NedelecSecondKind(cell, degree, **kwargs)) + + +class FuseElement(FiatElement): + def __init__(self, triple): + super(FuseElement, self).__init__(triple.to_fiat()) diff --git a/finat/ufl/__init__.py b/finat/ufl/__init__.py index 21a7d13d1..1cb7f3dca 100644 --- a/finat/ufl/__init__.py +++ b/finat/ufl/__init__.py @@ -14,8 +14,9 @@ from finat.ufl.brokenelement import BrokenElement # noqa: F401 from finat.ufl.enrichedelement import EnrichedElement, NodalEnrichedElement # noqa: F401 from finat.ufl.finiteelement import FiniteElement # noqa: F401 -from finat.ufl.finiteelementbase import FiniteElementBase # noqa: F401 +from finat.ufl.finiteelementbase import FiniteElementBase, as_cell # noqa: F401 from finat.ufl.hdivcurl import HCurlElement, HDivElement, WithMapping, HDiv, HCurl # noqa: F401 from finat.ufl.mixedelement import MixedElement, TensorElement, VectorElement # noqa: F401 from finat.ufl.restrictedelement import RestrictedElement # noqa: F401 from finat.ufl.tensorproductelement import TensorProductElement # noqa: F401 +from finat.ufl.fuseelement import FuseElement # noqa: F401 diff --git a/finat/ufl/finiteelement.py b/finat/ufl/finiteelement.py index f27942269..c68bf4355 100644 --- a/finat/ufl/finiteelement.py +++ b/finat/ufl/finiteelement.py @@ -11,9 +11,9 @@ # Modified by Massimiliano Leoni, 2016 # Modified by Matthew Scroggs, 2023 -from ufl.cell import TensorProductCell, as_cell +from ufl.cell import TensorProductCell from finat.ufl.elementlist import canonical_element_description, simplices -from finat.ufl.finiteelementbase import FiniteElementBase +from finat.ufl.finiteelementbase import FiniteElementBase, as_cell from ufl.utils.formatting import istr diff --git a/finat/ufl/finiteelementbase.py b/finat/ufl/finiteelementbase.py index 3a84f908e..1b698b056 100644 --- a/finat/ufl/finiteelementbase.py +++ b/finat/ufl/finiteelementbase.py @@ -15,7 +15,7 @@ from hashlib import md5 from ufl import pullback -from ufl.cell import AbstractCell, as_cell +from ufl.cell import AbstractCell, as_cell as as_cell_ufl from ufl.finiteelement import AbstractFiniteElement from ufl.utils.sequences import product @@ -266,3 +266,18 @@ def pullback(self): return pullback.physical_pullback raise ValueError(f"Unsupported mapping: {self.mapping()}") + + +def as_cell(cell: AbstractCell | str | tuple[AbstractCell, ...]) -> AbstractCell: + import os + try: + import fuse + except ModuleNotFoundError: + fuse = None + if isinstance(cell, str) and bool(os.getenv("FIREDRAKE_USE_FUSE", 0)): + if fuse: + return fuse.constructCellComplex(cell) + else: + raise ModuleNotFoundError("FIREDRAKE_USE_FUSE is active but fuse is not installed") + else: + return as_cell_ufl(cell) diff --git a/finat/ufl/fuseelement.py b/finat/ufl/fuseelement.py new file mode 100644 index 000000000..ca79d8494 --- /dev/null +++ b/finat/ufl/fuseelement.py @@ -0,0 +1,44 @@ +"""Element.""" +# -*- coding: utf-8 -*- +# Copyright (C) 2025 India Marsden +# +# SPDX-License-Identifier: LGPL-3.0-or-later + +from finat.ufl.finiteelementbase import FiniteElementBase + + +class FuseElement(FiniteElementBase): + """ + A finite element defined using FUSE. + + TODO: Need to deal with cases where value shape and reference value shape are different + """ + + def __init__(self, triple, cell=None): + self.triple = triple + if not cell: + cell = self.triple.cell.to_ufl() + + degree = self.triple.degree() + self.sobolev_space = self.triple.spaces[1] + super(FuseElement, self).__init__("IT", cell, degree, None, triple.get_value_shape()) + + def __repr__(self): + return repr(self.triple) + + def __str__(self): + return "" % (self.triple.spaces[0], self.triple.cell) + + def mapping(self): + if str(self.sobolev_space) == "HCurl": + return "covariant Piola" + elif str(self.sobolev_space) == "HDiv": + return "contravariant Piola" + else: + return "identity" + + def sobolev_space(self): + return self.triple.spaces[1] + + def reconstruct(self, family=None, cell=None, degree=None, quad_scheme=None, variant=None): + return FuseElement(self.triple, cell=cell) diff --git a/finat/ufl/mixedelement.py b/finat/ufl/mixedelement.py index 20c253535..24fdd7e0a 100644 --- a/finat/ufl/mixedelement.py +++ b/finat/ufl/mixedelement.py @@ -16,7 +16,7 @@ from ufl.cell import CellSequence, as_cell from ufl.domain import MeshSequence from finat.ufl.finiteelement import FiniteElement -from finat.ufl.finiteelementbase import FiniteElementBase +from finat.ufl.finiteelementbase import FiniteElementBase, as_cell from ufl.permutation import compute_indices from ufl.pullback import MixedPullback, SymmetricPullback from ufl.utils.indexflattening import flatten_multiindex, shape_to_strides, unflatten_index diff --git a/finat/ufl/tensorproductelement.py b/finat/ufl/tensorproductelement.py index a580c8c76..8d33d5725 100644 --- a/finat/ufl/tensorproductelement.py +++ b/finat/ufl/tensorproductelement.py @@ -13,8 +13,8 @@ from itertools import chain -from ufl.cell import TensorProductCell, as_cell -from finat.ufl.finiteelementbase import FiniteElementBase +from ufl.cell import TensorProductCell +from finat.ufl.finiteelementbase import FiniteElementBase, as_cell from ufl.sobolevspace import DirectionalSobolevSpace diff --git a/pyproject.toml b/pyproject.toml index e0212994c..9214223b1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,7 @@ dependencies = [ "scipy", "symengine", "sympy", + "fuse-element @ git+https://github.com/firedrakeproject/fuse.git", ] requires-python = ">=3.10" authors = [ diff --git a/test/FIAT/regression/scripts/getreferencerepo b/test/FIAT/regression/scripts/getreferencerepo index a68e1aa0b..0dc1bfaea 100755 --- a/test/FIAT/regression/scripts/getreferencerepo +++ b/test/FIAT/regression/scripts/getreferencerepo @@ -39,7 +39,7 @@ if [ ! -d "$DATA_DIR" ]; then else pushd $DATA_DIR echo "Found existing reference data repository, pulling new data" - git checkout main + git checkout master if [ $? -ne 0 ]; then echo "Failed to checkout main, check state of reference data directory." exit 1