diff --git a/metatomic-torch/src/outputs.cpp b/metatomic-torch/src/outputs.cpp index 86715665..2586cc0f 100644 --- a/metatomic-torch/src/outputs.cpp +++ b/metatomic-torch/src/outputs.cpp @@ -46,6 +46,23 @@ static std::string join_names(const std::vector& names) { return oss.str(); } +static std::string create_list(const int32_t size) { + std::ostringstream oss; + oss << "["; + if (size > 3) { + oss << "[0], ..., [n]"; + } else { + for (int32_t i = 0; i < size; i++) { + oss << "[" << i << "]"; + if (i + 1 < size) { + oss << ", "; + } + } + } + oss << "]"; + return oss.str(); +} + /// Ensure the TensorMap has a single block with the expected key static void validate_single_block(const std::string& name, const TensorMap& value) { auto expected_label = LabelsHolder::create({"_"}, {{0}}); @@ -132,11 +149,65 @@ static void validate_atomic_samples( } } -/// Ensure the block has no components -static void validate_no_components(const std::string& name, const TensorBlock& block) { - if (block->components().size() != 0) { +static void validate_components(const std::string& name, const std::vector& components, const std::vector& expected_components) { + if (components.size() != expected_components.size()) { + if (expected_components.size() == 0) { + C10_THROW_ERROR(ValueError, + "invalid components for " + name + " output: `components` should be empty" + ); + } + C10_THROW_ERROR(ValueError, + "invalid components for '" + name + "' output: " + "expected" + std::to_string(expected_components.size()) + "component(s)" + ); + } + for (size_t i = 0; i < expected_components.size(); i++){ + if (*components[i] != *expected_components[i]) { + auto label_values = expected_components[i]->values(); + std::string expected_labels = "Labels('" + join_names(expected_components[i]->names()) + "', " + create_list(label_values.size(-1)) + ")`"; + C10_THROW_ERROR(ValueError, + "invalid components for '" + name + "' output: " + "expected `" + expected_labels + "`" + ); + } + } +} + +static void validate_properties(const std::string& name, const TensorBlock& block, const Labels& expected_properties) { + if (*block->properties() != *expected_properties) { + auto label_values = expected_properties->values(); + std::string expected_labels = "Labels('" + join_names(expected_properties->names()) + "', " + create_list(label_values.size(-1)) + ")`"; + C10_THROW_ERROR(ValueError, + "invalid properties for '" + name + "' output: " + "expected `" + expected_labels + "`" + ); + } +} + +static void validate_gradient( + const std::string& name, + const std::string& parameter, + const TensorBlock& gradient, + const std::vector& expected_samples_names, + const std::vector& expected_components +) { + + if (gradient->samples()->names() != expected_samples_names) { + C10_THROW_ERROR(ValueError, + "invalid samples for '" + name + "' output '" + parameter + "' gradients: " + "expected the names to be " + join_names(expected_samples_names) + ", got " + + join_names(gradient->samples()->names()) + ); + } + + validate_components(name + " '" + parameter + "' gradients", gradient->components(), expected_components); +} + +static void validate_no_gradients(const std::string& name, const TensorBlock& block) { + if (block->gradients_list().size() > 0) { C10_THROW_ERROR(ValueError, - "invalid components for " + name + " output: `components` should be empty" + "invalid gradients for '" + name + "' output: " + "expected no gradients, found " + join_names(block->gradients_list()) ); } } @@ -158,17 +229,15 @@ static void check_energy_like( auto energy_block = TensorMapHolder::block_by_id(value, 0); auto tensor_options = torch::TensorOptions().device(value->device()); // Ensure that the block has no components - validate_no_components(name, energy_block); + validate_components(name, energy_block->components(), {}); // The only difference between energy & energy_ensemble is in the properties Labels expected_properties; - std::string expected_properties_str; if (name == "energy" || name == "energy_uncertainty") { expected_properties = torch::make_intrusive( "energy", torch::tensor({{0}}, tensor_options) ); - expected_properties_str = "`Labels(\"energy\", [[0]])`"; } else { assert(name == "energy_ensemble"); const auto n_ensemble_members = energy_block->values().size(-1); @@ -176,14 +245,8 @@ static void check_energy_like( "energy", torch::arange(n_ensemble_members, tensor_options).reshape({-1, 1}) ); - expected_properties_str = "`Labels(\"energy\", [[0], ..., [n]])`"; - } - - if (*energy_block->properties() != *expected_properties) { - C10_THROW_ERROR(ValueError, - "invalid properties for '" + name + " ' output: expected " + expected_properties_str - ); } + validate_properties(name, energy_block, expected_properties); auto gradients = TensorBlockHolder::gradients(energy_block); for (const auto& [parameter, gradient]: gradients) { @@ -193,65 +256,23 @@ static void check_energy_like( auto xyz = torch::tensor({{0}, {1}, {2}}, tensor_options); // strain gradient checks if (parameter == "strain") { - if (gradient->samples()->names() != std::vector{"sample"}) { - C10_THROW_ERROR(ValueError, - "invalid samples for '" + name + "' output 'strain' gradients: " - "expected the names to be ['sample'], got " + - join_names(gradient->samples()->names()) - ); - } - - auto components = gradient->components(); - if (components.size() != 2) { - C10_THROW_ERROR(ValueError, - "invalid components for '" + name + "' output 'strain' " - "gradients: expected two components" - ); - } - - if (*components[0] != *torch::make_intrusive("xyz_1", xyz)) { - C10_THROW_ERROR(ValueError, - "invalid components for '" + name + "' output 'strain' " - "gradients: expected Labels('xyz_1', [[0], [1], [2]]) for " - "the first component" - ); - } + const std::vector expected_samples_names{"sample"}; + std::vector expected_components{ + torch::make_intrusive("xyz_1", xyz), + torch::make_intrusive("xyz_2", xyz) + }; - if (*components[1] != *torch::make_intrusive("xyz_2", xyz)) { - C10_THROW_ERROR( - ValueError, - "invalid components for '" + name + "' output 'strain' " - "gradients: expected Labels('xyz_2', [[0], [1], [2]]) for " - "the first component" - ); - } + validate_gradient(name, parameter, gradient, expected_samples_names, expected_components); } // positions gradient checks if (parameter == "positions") { - if (gradient->samples()->names() != std::vector{"sample", "system", "atom"}) { - C10_THROW_ERROR(ValueError, - "invalid samples for '" + name + "' output 'positions' " - "gradients: expected the names to be ['sample', 'system', 'atom'], " - "got " + join_names(gradient->samples()->names()) - ); - } - - auto components = gradient->components(); - if (components.size() != 1) { - C10_THROW_ERROR(ValueError, - "invalid components for '" + name + "' output 'positions' " - "gradients: expected one component" - ); - } + const std::vector expected_samples_names{"sample", "system", "atom"}; + std::vector expected_components{ + torch::make_intrusive("xyz", xyz) + }; - if (*components[0] != *torch::make_intrusive("xyz", xyz)) { - C10_THROW_ERROR(ValueError, - "invalid components for '" + name + "' output 'positions' " - "gradients: expected Labels('xyz', [[0], [1], [2]]) for the " - "first component" - ); - } + validate_gradient(name, parameter, gradient, expected_samples_names, expected_components); } } } @@ -272,16 +293,10 @@ static void check_features( auto features_block = TensorMapHolder::block_by_id(value, 0); // Check that the block has no components - validate_no_components("features", features_block); + validate_components("features", features_block->components(), {}); // Should not have any explicit gradients - // all gradient calculations are done using autograd - if (features_block->gradients_list().size() > 0) { - C10_THROW_ERROR(ValueError, - "invalid gradients for 'features' output: " - "expected no gradients, found " + join_names(features_block->gradients_list()) - ); - } + validate_no_gradients("features", features_block); } /// Check output metadata for non-conservative forces. @@ -296,37 +311,20 @@ static void check_non_conservative_forces( // Check samples values from systems & selected_atoms validate_atomic_samples("non_conservative_forces", value, systems, request, selected_atoms); - + auto forces_block = TensorMapHolder::block_by_id(value, 0); - - // Check that the block has correct "Cartesian-form" components - auto components = forces_block->components(); - if (components.size() != 1) { - C10_THROW_ERROR(ValueError, - "invalid components for 'non_conservative_forces' output: " - "expected one component" - ); - } auto tensor_options = torch::TensorOptions().device(value->device()); - auto expected_component = torch::make_intrusive( - "xyz", - torch::tensor({{0}, {1}, {2}}, tensor_options) - ); - - if (*components[0] != *expected_component) { - C10_THROW_ERROR(ValueError, - "invalid components for 'non_conservative_forces' output: " - "expected `Labels('xyz', [[0], [1], [2]])`" - ); - } - + std::vector expected_components{ + torch::make_intrusive( + "xyz", + torch::tensor({{0}, {1}, {2}}, tensor_options) + ) + }; + + validate_components("non_conservative_forces", forces_block->components(), expected_components); + // Should not have any gradients - if (forces_block->gradients_list().size() > 0) { - C10_THROW_ERROR(ValueError, - "invalid gradients for 'non_conservative_forces' output: " - "expected no gradients, found " + join_names(forces_block->gradients_list()) - ); - } + validate_no_gradients("non_conservative_forces", forces_block); } /// Check output metadata for the non-conservative stress. @@ -342,40 +340,17 @@ static void check_non_conservative_stress( validate_atomic_samples("non_conservative_stress", value, systems, request, torch::nullopt); auto stress_block = TensorMapHolder::block_by_id(value, 0); - auto components = stress_block->components(); - auto tensor_options = torch::TensorOptions().device(value->device()); auto xyz = torch::tensor({{0}, {1}, {2}}, tensor_options); - - // Check that the block has correct "Cartesian-form" components - if (components.size() != 2) { - C10_THROW_ERROR(ValueError, - "invalid components for 'non_conservative_stress' output: " - "expected two components, got " + std::to_string(stress_block->components().size()) - ); - } - - if (*components[0] != *torch::make_intrusive("xyz_1", xyz)) { - C10_THROW_ERROR(ValueError, - "invalid components for 'non_conservative_stress' output: " - "expected `Labels('xyz_1', [[0], [1], [2]])`" - ); - } - - if (*components[1] != *torch::make_intrusive("xyz_2", xyz)) { - C10_THROW_ERROR(ValueError, - "invalid components for 'non_conservative_stress' output: " - "expected `Labels('xyz_1', [[0], [1], [2]])`" - ); - } - + std::vector expected_components{ + torch::make_intrusive("xyz_1", xyz), + torch::make_intrusive("xyz_2", xyz) + }; + + validate_components("non_conservative_stress", stress_block->components(), expected_components); + // Should not have any gradients - if (stress_block->gradients_list().size() > 0) { - C10_THROW_ERROR(ValueError, - "invalid gradients for 'non_conservative_stress' output: " - "expected no gradients, found " + join_names(stress_block->gradients_list()) - ); - } + validate_no_gradients("non_conservative_stress", stress_block); } /// Check output metadata for positions. @@ -389,50 +364,26 @@ static void check_positions( // Check samples values from systems validate_atomic_samples("positions", value, systems, request, torch::nullopt); - - auto positions_block = TensorMapHolder::block_by_id(value, 0); - auto components = positions_block->components(); - - // Check that the block has correct "Cartesian-form" components - if (components.size() != 1) { - C10_THROW_ERROR(ValueError, - "invalid components for 'positions' output: expected one " - "component, got " + std::to_string(positions_block->components().size()) - ); - } - + auto tensor_options = torch::TensorOptions().device(value->device()); - auto expected_component = torch::make_intrusive( - "xyz", - torch::tensor({{0}, {1}, {2}}, tensor_options) - ); - - if (*components[0] != *expected_component) { - C10_THROW_ERROR(ValueError, - "invalid components for 'positions' output: " - "expected `Labels('xyz', [[0], [1], [2]])`" - ); - } + auto positions_block = TensorMapHolder::block_by_id(value, 0); + std::vector expected_components{ + torch::make_intrusive( + "xyz", + torch::tensor({{0}, {1}, {2}}, tensor_options) + ) + }; + + validate_components("positions", positions_block->components(), expected_components); auto expected_properties = torch::make_intrusive( "positions", torch::tensor({{0}}, tensor_options) ); - - if (*positions_block->properties() != *expected_properties) { - C10_THROW_ERROR(ValueError, - "invalid properties for 'positions' output: " - " expected `Labels('positions', [[0]])`" - ); - } + validate_properties("positions", positions_block, expected_properties); // Should not have any gradients - if (positions_block->gradients_list().size() > 0) { - C10_THROW_ERROR(ValueError, - "invalid gradients for 'positions' output: expected no " - "gradients, found " + join_names(positions_block->gradients_list()) - ); - } + validate_no_gradients("positions", positions_block); } /// Check output metadata for momenta. @@ -447,49 +398,25 @@ static void check_momenta( // Check samples values from systems validate_atomic_samples("momenta", value, systems, request, torch::nullopt); - auto momenta_block = TensorMapHolder::block_by_id(value, 0); - auto components = momenta_block->components(); - - // Check that the block has correct "Cartesian-form" components - if (components.size() != 1) { - C10_THROW_ERROR(ValueError, - "invalid components for 'momenta' output: expected one component, " - "got " + std::to_string(momenta_block->components().size()) - ); - } - + auto tensor_options = torch::TensorOptions().device(value->device()); - auto expected_component = torch::make_intrusive( - "xyz", - torch::tensor({{0}, {1}, {2}}, tensor_options) - ); - - if (*components[0] != *expected_component) { - C10_THROW_ERROR(ValueError, - "invalid components for 'momenta' output: " - "expected `Labels('xyz', [[0], [1], [2]])`" - ); - } + auto momenta_block = TensorMapHolder::block_by_id(value, 0); + std::vector expected_component { + torch::make_intrusive( + "xyz", + torch::tensor({{0}, {1}, {2}}, tensor_options) + ) + }; + validate_components("momenta", momenta_block->components(), expected_component); auto expected_properties = torch::make_intrusive( "momenta", torch::tensor({{0}}, tensor_options) ); - - if (*momenta_block->properties() != *expected_properties) { - C10_THROW_ERROR(ValueError, - "invalid properties for 'momenta' output: expected " - "`Labels('momenta', [[0]])`" - ); - } + validate_properties("momenta", momenta_block, expected_properties); // Should not have any gradients - if (momenta_block->gradients_list().size() > 0) { - C10_THROW_ERROR(ValueError, - "invalid gradients for 'momenta' output: expected no " - "gradients, found " + join_names(momenta_block->gradients_list()) - ); - } + validate_no_gradients("momenta", momenta_block); } void metatomic_torch::check_outputs(