diff --git a/CMakeLists.txt b/CMakeLists.txt index c45fca3b72..f0a657fe94 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -386,8 +386,15 @@ if (BOUT_GENERATE_FIELDOPS) if (NOT ClangFormat_FOUND) message(FATAL_ERROR "clang-format not found, but you have requested to generate code!") endif() + if (BOUT_ENABLE_RAJA) + set(GEN_LOOP_EXEC "raja") + elseif (BOUT_ENABLE_OPENMP) + set(GEN_LOOP_EXEC "openmp") + else() + set(GEN_LOOP_EXEC "serial") + endif() add_custom_command( OUTPUT ${CMAKE_CURRENT_SOURCE_DIR}/src/field/generated_fieldops.cxx - COMMAND ${Python3_EXECUTABLE} gen_fieldops.py --filename generated_fieldops.cxx.tmp + COMMAND ${Python3_EXECUTABLE} gen_fieldops.py --loop-exec ${GEN_LOOP_EXEC} --filename generated_fieldops.cxx.tmp COMMAND ${ClangFormat_BIN} generated_fieldops.cxx.tmp -i COMMAND ${CMAKE_COMMAND} -E rename generated_fieldops.cxx.tmp generated_fieldops.cxx DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/src/field/gen_fieldops.jinja ${CMAKE_CURRENT_SOURCE_DIR}/src/field/gen_fieldops.py diff --git a/include/bout/coordinates_accessor.hxx b/include/bout/coordinates_accessor.hxx index 532351d57a..2376ab5039 100644 --- a/include/bout/coordinates_accessor.hxx +++ b/include/bout/coordinates_accessor.hxx @@ -31,7 +31,7 @@ /// -> If Coordinates data is changed, the cache should be cleared /// by calling CoordinatesAccessor::clear() struct CoordinatesAccessor { - CoordinatesAccessor() = delete; + CoordinatesAccessor() {} /// Constructor from Coordinates /// Copies data from coords, doesn't modify it diff --git a/include/bout/field_accessor.hxx b/include/bout/field_accessor.hxx index 69b58da979..71d0537d9e 100644 --- a/include/bout/field_accessor.hxx +++ b/include/bout/field_accessor.hxx @@ -57,10 +57,17 @@ struct FieldAccessor { /// Constructor from Field3D /// /// @param[in] f The field to access. Must already be allocated - explicit FieldAccessor(FieldType& f) : coords(f.getCoordinates()) { + explicit FieldAccessor(FieldType& f) { ASSERT0(f.getLocation() == location); ASSERT0(f.isAllocated()); + if (auto* Coords = f.getCoordinates()) { + coords = CoordinatesAccessor{Coords}; + } + else { + coords = CoordinatesAccessor{}; + } + data = BoutRealArray{&f(0, 0, 0)}; // Field size @@ -81,15 +88,19 @@ struct FieldAccessor { ddt = BoutRealArray{&(f.timeDeriv()->operator()(0, 0, 0))}; } + explicit FieldAccessor(const FieldType& f) : FieldAccessor(const_cast(f)) {} + /// Provide shorthand for access to field data. /// Does not convert between 3D and 2D indices, /// so fa[i] is equivalent to fa.data[i]. /// BOUT_HOST_DEVICE inline const BoutReal& operator[](int ind) const { return data[ind]; } + BOUT_HOST_DEVICE inline BoutReal& operator[](int ind) { return data[ind]; } BOUT_HOST_DEVICE inline const BoutReal& operator[](const Ind3D& ind) const { return data[ind.ind]; } + BOUT_HOST_DEVICE inline BoutReal& operator[](const Ind3D& ind) { return data[ind.ind]; } // Pointers to the field data arrays // These are wrapped in BoutRealArray types so they can be indexed with Ind3D or int @@ -115,6 +126,9 @@ struct FieldAccessor { template using Field2DAccessor = FieldAccessor; +template +using Field3DAccessor = FieldAccessor; + /// Syntactic sugar for time derivative of a field /// /// Usage: @@ -130,4 +144,28 @@ BOUT_HOST_DEVICE inline BoutRealArray& ddt(const FieldAccessor(fa.ddt); } +struct FieldPerpAccessor { + FieldPerpAccessor() = delete; + + int nx, nz; + int yindex; + BoutReal* data; + + explicit FieldPerpAccessor(const FieldPerp& f) { + ASSERT0(f.isAllocated()); + + data = BoutRealArray{const_cast(&f(0, 0, 0))}; + + // Field size + nx = f.getNx(); + nz = f.getNz(); + + yindex = f.getIndex(); + } + + BOUT_HOST_DEVICE int getIndex() const { return yindex; } + BOUT_HOST_DEVICE inline const BoutReal& operator[](int ind) const { return data[ind]; } + BOUT_HOST_DEVICE inline BoutReal& operator[](int ind) { return data[ind]; } +}; + #endif diff --git a/include/bout/mesh.hxx b/include/bout/mesh.hxx index a1c88a2634..b6553d06ec 100644 --- a/include/bout/mesh.hxx +++ b/include/bout/mesh.hxx @@ -762,6 +762,11 @@ public: return {(indPerp.ind - jz) * LocalNy + LocalNz * jy + jz, LocalNy, LocalNz}; } + BOUT_HOST_DEVICE int flatIndPerpto3D(const int& flatIndPerp, const int nz, int jy = 0) const { + int jz = flatIndPerp % nz; + return (flatIndPerp - jz) * LocalNy + LocalNz * jy + jz; + } + /// Converts an Ind3D to an Ind2D representing a 2D index using a lookup -- to be used with care Ind2D map3Dto2D(const Ind3D& ind3D) { return {indexLookup3Dto2D[ind3D.ind], LocalNy, 1}; diff --git a/include/bout/rajalib.hxx b/include/bout/rajalib.hxx index b9f6913459..92eae68858 100644 --- a/include/bout/rajalib.hxx +++ b/include/bout/rajalib.hxx @@ -23,6 +23,15 @@ #include "RAJA/RAJA.hpp" // using RAJA lib +#if BOUT_HAS_CUDA +// TODO: Make configurable +const int CUDA_BLOCK_SIZE = 256; +using EXEC_POL = RAJA::cuda_exec; +//using EXEC_POL = RAJA::loop_exec; +#else // not BOUT_USE_CUDA +using EXEC_POL = RAJA::loop_exec; +#endif // end BOUT_USE_CUDA + /// Wrapper around RAJA::forall /// Enables computations to be done on CPU or GPU (CUDA). /// @@ -81,7 +90,7 @@ struct RajaForAll { // Note: must be a local variable const int* _ob_i_ind_raw = &_ob_i_ind[0]; RAJA::forall(RAJA::RangeSegment(0, _ob_i_ind.size()), - [=] RAJA_DEVICE(int id) { + [=] RAJA_DEVICE(int id) mutable { // Look up index and call user function f(_ob_i_ind_raw[id]); }); @@ -127,7 +136,7 @@ private: /// to create variables which shadow the class members. /// #define BOUT_FOR_RAJA(index, region, ...) \ - RajaForAll(region) << [ =, ##__VA_ARGS__ ] RAJA_DEVICE(int index) +RajaForAll(region) << [ =, ##__VA_ARGS__ ] RAJA_DEVICE(int index) mutable #else // BOUT_HAS_RAJA diff --git a/include/bout/single_index_ops.hxx b/include/bout/single_index_ops.hxx index 60bd78bc36..c29d1a471f 100644 --- a/include/bout/single_index_ops.hxx +++ b/include/bout/single_index_ops.hxx @@ -7,17 +7,6 @@ #include "field_accessor.hxx" -#if BOUT_HAS_RAJA -//-- RAJA CUDA settings--------------------------------------------------------start -#if BOUT_HAS_CUDA -const int CUDA_BLOCK_SIZE = 256; // TODO: Make configurable -using EXEC_POL = RAJA::cuda_exec; -#else // not BOUT_USE_CUDA -using EXEC_POL = RAJA::loop_exec; -#endif // end BOUT_USE_CUDA -////-----------CUDA settings------------------------------------------------------end -#endif // end BOUT_HAS_RAJA - // Ind3D: i.zp(): BOUT_HOST_DEVICE inline int i_zp(const int id, const int nz) { int jz = id % nz; diff --git a/src/field/gen_fieldops.jinja b/src/field/gen_fieldops.jinja index ecd4e628cc..60f9cbbd7e 100644 --- a/src/field/gen_fieldops.jinja +++ b/src/field/gen_fieldops.jinja @@ -8,6 +8,26 @@ checkData({{lhs.name}}); checkData({{rhs.name}}); + {% if (region_loop == "BOUT_FOR_RAJA") %} + {% if out.field_type == "FieldPerp" %} + auto {{out.name}}_acc = FieldPerpAccessor{ {{out.name}} }; + {% else %} + auto {{out.name}}_acc = FieldAccessor({{out.name}}); + {% endif %} + {% if lhs.field_type == "FieldPerp" %} + auto {{lhs.name}}_acc = FieldPerpAccessor{ {{lhs.name}} }; + {% elif lhs.field_type == "BoutReal" %} + {% else %} + auto {{lhs.name}}_acc = FieldAccessor({{lhs.name}}); + {% endif %} + {% if rhs.field_type == "FieldPerp" %} + auto {{rhs.name}}_acc = FieldPerpAccessor{ {{rhs.name}} }; + {% elif rhs.field_type == "BoutReal" %} + {% else %} + auto {{rhs.name}}_acc = FieldAccessor({{rhs.name}}); + {% endif %} + {% endif %} + {% if out == "Field3D" %} {% if lhs == rhs == "Field3D" %} {{out.name}}.setRegion({{lhs.name}}.getMesh()->getCommonRegion({{lhs.name}}.getRegionID(), @@ -20,45 +40,98 @@ {% endif %} {% if (out == "Field3D") and ((lhs == "Field2D") or (rhs =="Field2D")) %} + {% if (region_loop == "BOUT_FOR_RAJA") %} + int mesh_nz = {{lhs.name if lhs.field_type != "BoutReal" else rhs.name}}_acc.mesh_nz; + {% else %} Mesh *localmesh = {{lhs.name if lhs.field_type != "BoutReal" else rhs.name}}.getMesh(); + {% endif %} {% if (lhs == "Field2D") %} {{region_loop}}({{index_var}}, {{lhs.name}}.getRegion({{region_name}})) { {% else %} {{region_loop}}({{index_var}}, {{rhs.name}}.getRegion({{region_name}})) { {% endif %} - const auto {{mixed_base_ind}} = localmesh->ind2Dto3D({{index_var}}); - {% if (operator == "/") and (rhs == "Field2D") %} - const auto tmp = 1.0 / {{rhs.mixed_index}}; - for (int {{jz_var}} = 0; {{jz_var}} < localmesh->LocalNz; ++{{jz_var}}){ - {{out.mixed_index}} = {{lhs.mixed_index}} * tmp; + {% if (region_loop == "BOUT_FOR_RAJA") %} + const auto {{mixed_base_ind}} = {{index_var}} * mesh_nz; + {% else %} + const auto {{mixed_base_ind}} = localmesh->ind2Dto3D({{index_var}}); + {% endif %} + {% if (operator == "/") and (rhs == "Field2D") %} + {% if (region_loop == "BOUT_FOR_RAJA") %} + const auto tmp = 1.0 / {{rhs.mixed_index_acc}}; {% else %} - for (int {{jz_var}} = 0; {{jz_var}} < localmesh->LocalNz; ++{{jz_var}}){ - {{out.mixed_index}} = {{lhs.mixed_index}} {{operator}} {{rhs.mixed_index}}; + const auto tmp = 1.0 / {{rhs.mixed_index}}; {% endif %} + {% if (region_loop == "BOUT_FOR_RAJA") %} + for (int {{jz_var}} = 0; {{jz_var}} < mesh_nz; ++{{jz_var}}){ + {% else %} + for (int {{jz_var}} = 0; {{jz_var}} < localmesh->LocalNz; ++{{jz_var}}){ + {% endif %} + {% if (region_loop == "BOUT_FOR_RAJA") %} + {{out.mixed_index_acc}} = {{lhs.mixed_index_acc}} * tmp; + {% else %} + {{out.mixed_index}} = {{lhs.mixed_index}} * tmp; + {% endif %} + {% else %} + {% if (region_loop == "BOUT_FOR_RAJA") %} + for (int {{jz_var}} = 0; {{jz_var}} < mesh_nz; ++{{jz_var}}){ + {% else %} + for (int {{jz_var}} = 0; {{jz_var}} < localmesh->LocalNz; ++{{jz_var}}){ + {% endif %} + {% if (region_loop == "BOUT_FOR_RAJA") %} + {{out.mixed_index_acc}} = {{lhs.mixed_index_acc}} {{operator}} {{rhs.mixed_index_acc}}; + {% else %} + {{out.mixed_index}} = {{lhs.mixed_index}} {{operator}} {{rhs.mixed_index}}; + {% endif %} + {% endif %} } - } + }{% if (region_loop == "BOUT_FOR_RAJA") %};{% endif %} {% elif out == "FieldPerp" and (lhs == "Field2D" or lhs == "Field3D" or rhs == "Field2D" or rhs == "Field3D")%} Mesh *localmesh = {{lhs.name if lhs.field_type != "BoutReal" else rhs.name}}.getMesh(); {{region_loop}}({{index_var}}, {{out.name}}.getRegion({{region_name}})) { - int yind = {{lhs.name if lhs == "FieldPerp" else rhs.name}}.getIndex(); - const auto {{mixed_base_ind}} = localmesh->indPerpto3D({{index_var}}, yind); + {% if (region_loop == "BOUT_FOR_RAJA") %} + int yind = {{lhs.name if lhs == "FieldPerp" else rhs.name}}_acc.getIndex(); + {% else %} + int yind = {{lhs.name if lhs == "FieldPerp" else rhs.name}}.getIndex(); + {% endif %} + {% if (region_loop == "BOUT_FOR_RAJA") %} + ; // DONE2 + const auto {{mixed_base_ind}} = localmesh->flatIndPerpto3D({{index_var}}, result_acc.nz, yind); + {% else %} + const auto {{mixed_base_ind}} = localmesh->indPerpto3D({{index_var}}, yind); + {% endif %} {% if lhs != "FieldPerp" %} - {{out.index}} = {{lhs.base_index}} {{operator}} {{rhs.index}}; + {% if (region_loop == "BOUT_FOR_RAJA") %} + {{out.index_acc}} = {{lhs.base_index_acc}} {{operator}} {{rhs.index_acc}}; + {% else %} + {{out.index}} = {{lhs.base_index}} {{operator}} {{rhs.index}}; + {% endif %} {% else %} - {{out.index}} = {{lhs.index}} {{operator}} {{rhs.base_index}}; + {% if (region_loop == "BOUT_FOR_RAJA") %} + {{out.index_acc}} = {{lhs.index_acc}} {{operator}} {{rhs.base_index_acc}}; + {% else %} + {{out.index}} = {{lhs.index}} {{operator}} {{rhs.base_index}}; + {% endif %} {% endif %} - } + }{% if (region_loop == "BOUT_FOR_RAJA") %};{% endif %} {% elif (operator == "/") and (rhs == "BoutReal") %} const auto tmp = 1.0 / {{rhs.index}}; {{region_loop}}({{index_var}}, {{out.name}}.getValidRegionWithDefault({{region_name}})) { - {{out.index}} = {{lhs.index}} * tmp; - } + {% if (region_loop == "BOUT_FOR_RAJA") %} + {{out.index_acc}} = {{lhs.index_acc}} * tmp; + {% else %} + {{out.index}} = {{lhs.index}} * tmp; + {% endif %} + }{% if (region_loop == "BOUT_FOR_RAJA") %};{% endif %} {% else %} {{region_loop}}({{index_var}}, {{out.name}}.getValidRegionWithDefault({{region_name}})) { - {{out.index}} = {{lhs.index}} {{operator}} {{rhs.index}}; - } + {% if (region_loop == "BOUT_FOR_RAJA") %} + {{out.index_acc}} = {{lhs.index_acc}} {{operator}} {{rhs.index_acc}}; + {% else %} + {{out.index}} = {{lhs.index}} {{operator}} {{rhs.index}}; + {% endif %} + }{% if (region_loop == "BOUT_FOR_RAJA") %};{% endif %} {% endif %} checkData({{out.name}}); @@ -84,49 +157,102 @@ checkData(*this); checkData({{rhs.name}}); + {% if (region_loop == "BOUT_FOR_RAJA") %} + {% if lhs.field_type == "FieldPerp" %} + auto this_acc = FieldPerpAccessor{(*this)}; + {% else %} + auto this_acc = FieldAccessor(*this); + {% endif %} + {% if rhs.field_type == "FieldPerp" %} + auto {{rhs.name}}_acc = FieldPerpAccessor{ {{rhs.name}} }; + {% elif rhs.field_type == "BoutReal" %} + {% else %} + auto {{rhs.name}}_acc = FieldAccessor({{rhs.name}}); + {% endif %} + {% endif %} + {% if lhs == rhs == "Field3D" %} regionID = fieldmesh->getCommonRegion(regionID, {{rhs.name}}.regionID); {% endif %} - {% if (lhs == "Field3D") and (rhs =="Field2D") %} + {% if (region_loop == "BOUT_FOR_RAJA") %} + int mesh_nz = fieldmesh->LocalNz; + {% endif %} {{region_loop}}({{index_var}}, {{rhs.name}}.getRegion({{region_name}})) { - const auto {{mixed_base_ind}} = fieldmesh->ind2Dto3D({{index_var}}); - {% if (operator == "/") and (rhs == "Field2D") %} - const auto tmp = 1.0 / {{rhs.mixed_index}}; - for (int {{jz_var}} = 0; {{jz_var}} < fieldmesh->LocalNz; ++{{jz_var}}){ - (*this)[{{mixed_base_ind}} + {{jz_var}}] *= tmp; + {% if (region_loop == "BOUT_FOR_RAJA") %} + const auto {{mixed_base_ind}} = {{index_var}} * mesh_nz; + {% else %} + const auto {{mixed_base_ind}} = fieldmesh->ind2Dto3D({{index_var}}); + {% endif %} + {% if (operator == "/") and (rhs == "Field2D") %} + {% if (region_loop == "BOUT_FOR_RAJA") %} + const auto tmp = 1.0 / {{rhs.mixed_index_acc}}; + for (int {{jz_var}} = 0; {{jz_var}} < mesh_nz; ++{{jz_var}}){ + this_acc[{{mixed_base_ind}} + {{jz_var}}] *= tmp; {% else %} - for (int {{jz_var}} = 0; {{jz_var}} < fieldmesh->LocalNz; ++{{jz_var}}){ - (*this)[{{mixed_base_ind}} + {{jz_var}}] {{operator}}= {{rhs.index}}; + const auto tmp = 1.0 / {{rhs.mixed_index}}; + for (int {{jz_var}} = 0; {{jz_var}} < fieldmesh->LocalNz; ++{{jz_var}}){ + (*this)[{{mixed_base_ind}} + {{jz_var}}] *= tmp; {% endif %} + {% else %} + {% if (region_loop == "BOUT_FOR_RAJA") %} + for (int {{jz_var}} = 0; {{jz_var}} < mesh_nz; ++{{jz_var}}){ + this_acc[{{mixed_base_ind}} + {{jz_var}}] {{operator}}= {{rhs.index_acc}}; + {% else %} + for (int {{jz_var}} = 0; {{jz_var}} < fieldmesh->LocalNz; ++{{jz_var}}){ + (*this)[{{mixed_base_ind}} + {{jz_var}}] {{operator}}= {{rhs.index}}; + {% endif %} + {% endif %} } - } + }{% if (region_loop == "BOUT_FOR_RAJA") %};{% endif %} {% elif lhs == "FieldPerp" and (rhs == "Field3D" or rhs == "Field2D")%} Mesh *localmesh = this->getMesh(); + {% if (region_loop == "BOUT_FOR_RAJA") %} + int yind = this->getIndex(); + {% endif %} {{region_loop}}({{index_var}}, this->getRegion({{region_name}})) { - int yind = this->getIndex(); - const auto {{mixed_base_ind}} = localmesh->indPerpto3D({{index_var}}, yind); - (*this)[{{index_var}}] {{operator}}= {{rhs.base_index}}; - } + {% if (region_loop == "BOUT_FOR_RAJA") %} + const auto {{mixed_base_ind}} = localmesh->flatIndPerpto3D({{index_var}}, yind); + this_acc[{{index_var}}] {{operator}}= {{rhs.base_index_acc}}; + {% else %} + int yind = this->getIndex(); + const auto {{mixed_base_ind}} = localmesh->indPerpto3D({{index_var}}, yind); + (*this)[{{index_var}}] {{operator}}= {{rhs.base_index}}; + {% endif %} + }{% if (region_loop == "BOUT_FOR_RAJA") %};{% endif %} {% elif rhs == "FieldPerp" and (lhs == "Field3D" or lhs == "Field2D")%} Mesh *localmesh = this->getMesh(); {{region_loop}}({{index_var}}, {{rhs.name}}.getRegion({{region_name}})) { - int yind = {{rhs.name}}.getIndex(); - const auto {{mixed_base_ind}} = localmesh->indPerpto3D({{index_var}}, yind); - (*this)[{{base_ind_var}}] {{operator}}= {{rhs.index}}; - } + {% if (region_loop == "BOUT_FOR_RAJA") %} + int yind = {{rhs.name}}.getIndex(); + const auto {{mixed_base_ind}} = localmesh->indPerpto3D({{index_var}}, yind); + this_acc[{{base_ind_var}}] {{operator}}= {{rhs.index}}; + {% else %} + int yind = {{rhs.name}}.getIndex(); + const auto {{mixed_base_ind}} = localmesh->indPerpto3D({{index_var}}, yind); + (*this)[{{base_ind_var}}] {{operator}}= {{rhs.index}}; + {% endif %} + }{% if (region_loop == "BOUT_FOR_RAJA") %};{% endif %} {% elif (operator == "/") and (lhs == "Field3D" or lhs == "Field2D") and (rhs =="BoutReal") %} const auto tmp = 1.0 / {{rhs.index}}; {{region_loop}}({{index_var}}, this->getRegion({{region_name}})) { + {% if (region_loop == "BOUT_FOR_RAJA") %} + this_acc[{{index_var}}] *= tmp; + {% else %} (*this)[{{index_var}}] *= tmp; - } + {% endif %} + }{% if (region_loop == "BOUT_FOR_RAJA") %};{% endif %} {% else %} {{region_loop}}({{index_var}}, this->getRegion({{region_name}})) { - (*this)[{{index_var}}] {{operator}}= {{rhs.index}}; - } + {% if (region_loop == "BOUT_FOR_RAJA") %} + this_acc[{{index_var}}] {{operator}}= {{rhs.index_acc}}; + {% else %} + (*this)[{{index_var}}] {{operator}}= {{rhs.index}}; + {% endif %} + }{% if (region_loop == "BOUT_FOR_RAJA") %};{% endif %} {% endif %} checkData(*this); diff --git a/src/field/gen_fieldops.py b/src/field/gen_fieldops.py index 29631ff7aa..c1ed1fed8e 100755 --- a/src/field/gen_fieldops.py +++ b/src/field/gen_fieldops.py @@ -132,6 +132,17 @@ def index(self): else: return "{self.name}[{self.index_var}]".format(self=self) + @property + def index_acc(self): + """Returns "_acc[{index_var}]" for an accessor-based index, except if + field_type is BoutReal, in which case just returns "" + + """ + if self.field_type == "BoutReal": + return "{self.name}".format(self=self) + else: + return "{self.name}_acc[{self.index_var}]".format(self=self) + @property def mixed_index(self): """Returns "[{index_var} + {jz_var}]" if field_type is Field3D, @@ -147,6 +158,21 @@ def mixed_index(self): else: # Field2D return "{self.name}[{self.index_var}]".format(self=self) + @property + def mixed_index_acc(self): + """Returns "_acc[{index_var} + {jz_var}]" for an accessor if field_type + is Field3D, self.index if Field2D or just returns "" for BoutReal + + """ + if self.field_type == "BoutReal": + return "{self.name}_acc".format(self=self) + elif self.field_type == "Field3D": + return "{self.name}_acc[{self.mixed_base_ind_var} + {self.jz_var}]".format( + self=self + ) + else: # Field2D + return "{self.name}_acc[{self.index_var}]".format(self=self) + @property def base_index(self): """Returns "[{mixed_base_ind_var}]" if field_type is Field3D, Field2D or FieldPerp @@ -158,6 +184,17 @@ def base_index(self): else: return "{self.name}[{self.mixed_base_ind_var}]".format(self=self) + @property + def base_index_acc(self): + """Returns "_acc[{mixed_base_ind_var}]" for an accessor if field_type is + Field3D, Field2D or FieldPerp or just returns "" for BoutReal + + """ + if self.field_type == "BoutReal": + return "{self.name}".format(self=self) + else: + return "{self.name}_acc[{self.mixed_base_ind_var}]".format(self=self) + def __eq__(self, other): try: return self.field_type == other.field_type @@ -198,11 +235,11 @@ def returnType(f1, f2): ) # By default use OpenMP enabled loops but allow to disable parser.add_argument( - "--no-openmp", - action="store_false", - default=False, - dest="noOpenMP", - help="Don't use OpenMP compatible loops", + "--loop-exec", + default="openmp", + dest="loop_exec", + choices=["serial", "openmp", "raja"], + help="Choose the loop execution method. Default is OpenMP", ) args = parser.parse_args() @@ -213,10 +250,16 @@ def returnType(f1, f2): mixed_base_ind_var = "base_ind" region_name = '"RGN_ALL"' - if args.noOpenMP: + if args.loop_exec == "openmp": + region_loop = "BOUT_FOR" + elif args.loop_exec == "raja": + region_loop = "BOUT_FOR_RAJA" + header += "#include \n" + header += "#include \n" + elif args.loop_exec == "serial": region_loop = "BOUT_FOR_SERIAL" else: - region_loop = "BOUT_FOR" + raise ValueError("Unknown loop execution method") # Declare what fields we currently support: # Field perp is currently missing diff --git a/src/mesh/coordinates.cxx b/src/mesh/coordinates.cxx index 34c524d1e7..8123720144 100644 --- a/src/mesh/coordinates.cxx +++ b/src/mesh/coordinates.cxx @@ -23,6 +23,8 @@ #include "parallel/fci.hxx" #include "parallel/shiftedmetricinterp.hxx" +#include "bout/coordinates_accessor.hxx" + // use anonymous namespace so this utility function is not available outside this file namespace { template @@ -1203,10 +1205,11 @@ int Coordinates::geometry(bool recalculate_staggered, localmesh->recalculateStaggeredCoordinates(); } - // Invalidate and recalculate cached variables + // Invalidate and recalculate cached variables and any accessor zlength_cache.reset(); Grad2_par2_DDY_invSgCache.clear(); invSgCache.reset(); + CoordinatesAccessor::clear(this); return 0; } diff --git a/src/mesh/coordinates_accessor.cxx b/src/mesh/coordinates_accessor.cxx index aff546c2b0..196234d999 100644 --- a/src/mesh/coordinates_accessor.cxx +++ b/src/mesh/coordinates_accessor.cxx @@ -40,8 +40,9 @@ CoordinatesAccessor::CoordinatesAccessor(const Coordinates* coords) { // Copy data from Coordinates variable into data array // Uses the symbol to look up the corresponding Offset -#define COPY_STRIPE1(symbol) \ - data[stripe_size * ind.ind + static_cast(Offset::symbol)] = coords->symbol[ind]; +#define COPY_STRIPE1(symbol) \ + if (coords->symbol.isAllocated()) \ + data[stripe_size * ind.ind + static_cast(Offset::symbol)] = coords->symbol[ind]; // Implement copy for each argument #define COPY_STRIPE(...) \ @@ -54,10 +55,15 @@ CoordinatesAccessor::CoordinatesAccessor(const Coordinates* coords) { COPY_STRIPE(d1_dx, d1_dy, d1_dz); COPY_STRIPE(J); - data[stripe_size * ind.ind + static_cast(Offset::B)] = coords->Bxy[ind]; - data[stripe_size * ind.ind + static_cast(Offset::Byup)] = coords->Bxy.yup()[ind]; - data[stripe_size * ind.ind + static_cast(Offset::Bydown)] = - coords->Bxy.ydown()[ind]; + if (coords->Bxy.isAllocated()) { + data[stripe_size * ind.ind + static_cast(Offset::B)] = coords->Bxy[ind]; + if (coords->Bxy.yup().isAllocated()) + data[stripe_size * ind.ind + static_cast(Offset::Byup)] = + coords->Bxy.yup()[ind]; + if (coords->Bxy.ydown().isAllocated()) + data[stripe_size * ind.ind + static_cast(Offset::Bydown)] = + coords->Bxy.ydown()[ind]; + } COPY_STRIPE(G1, G3); COPY_STRIPE(g11, g12, g13, g22, g23, g33); diff --git a/tests/integrated/test_suite b/tests/integrated/test_suite index 307a8d84b3..77ad7882c4 100755 --- a/tests/integrated/test_suite +++ b/tests/integrated/test_suite @@ -188,7 +188,7 @@ class Test(threading.Thread): self.output += "\n(It is likely that a timeout occured)" else: # ❌ Failed - print("\u274C", end="") # No newline + print("\u274c", end="") # No newline print(" %7.3f s" % (time.time() - self.local.start_time), flush=True) def _cost(self):