Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@ Learn more about Skala in our [ArXiv paper](https://arxiv.org/abs/2506.14665).

## What's in here

This repository contains two main components:
This repository contains three main components:

1. The Python package `microsoft-skala`, which is also distributed [on PyPI](https://pypi.org/project/microsoft-skala/) and contains a Pytorch implementation of the Skala model, its hookups to quantum chemistry packages [PySCF](https://pyscf.org/) and [ASE](https://ase-lib.org/), and an independent client library for the Skala model served [in Azure AI Foundry](https://ai.azure.com/catalog/models/Skala).
2. A development version of the CPU/GPU C++ library for XC functionals [GauXC](https://github.com/wavefunction91/GauXC) with an add-on supporting Pytorch-based functionals like Skala. GauXC is part of the stack that serves Skala in Azure AI Foundry and can be used to integrate Skala into other third-party DFT codes.
3. An example of using Skala in C++ CPU applications through LibTorch, see [`examples/cpp/cpp_integration`](examples/cpp/cpp_integration).

All information below relates to the Python package, the development version of GauXC including its license and other information can be found in [`third_party/gauxc`](https://github.com/microsoft/skala/tree/main/third_party/gauxc).

Expand Down Expand Up @@ -59,4 +60,3 @@ This project may contain trademarks or logos for projects, products, or services
Authorized use of Microsoft trademarks or logos is subject to and must follow [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).
Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship.
Any use of third-party trademarks or logos are subject to those third-party's policies.

3 changes: 3 additions & 0 deletions examples/cpp/cpp_integration/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
*.pt
*.fun
/_build
22 changes: 22 additions & 0 deletions examples/cpp/cpp_integration/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
cmake_minimum_required(VERSION 3.10 FATAL_ERROR)
project(
"skala_cpp_integration"
VERSION "0.0.1"
LANGUAGES CXX
)

find_package(nlohmann_json REQUIRED)
find_package(Torch REQUIRED)

add_executable(
${PROJECT_NAME}
./main.cpp
)
target_link_libraries(
${PROJECT_NAME}
"${TORCH_LIBRARIES}"
)
set_property(
TARGET ${PROJECT_NAME}
PROPERTY CXX_STANDARD 17
)
52 changes: 52 additions & 0 deletions examples/cpp/cpp_integration/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Integrating Skala in C++ code

This example demonstrates how to use the Skala machine learning functional in C++ CPU applications using LibTorch.

## Setup environment

Set up the conda environment using the provided environment file:

```bash
cd examples/cpp/cpp_integration
conda env create -n skala_cpp_integration -f environment.yml
conda activate skala_cpp_integration
```

## Build library

The example can be built using CMake.
The provided environment is configured for CMake to find the required dependencies.

```bash
cmake -S . -B _build -G Ninja
cmake --build _build
```

For any changes to the code, rebuild using the last command.

## Run example

Download the Skala model, as well as a reference LDA functional from HuggingFace
using the provided download script:

```bash
./download_model.py
```

Prepare the molecular features for a test molecule (H2) using the provided script:

```bash
python ./prepare_inputs.py --output-dir H2
```

Finally, run $E_\text{xc}$ and (partial) $V_\text{xc}$ computations with the C++ example:

```bash
./_build/skala_cpp_integration skala-1.0.fun H2
```

**Note:** You are expected to add D3 dispersion correction (using b3lyp settings) to the final energy of Skala.

## Performance tuning

[This guide](https://intel.github.io/intel-extension-for-pytorch/cpu/latest/tutorials/performance_tuning/tuning_guide.html) from Intel provides useful tips on how to tune performance of PyTorch models on CPU.
63 changes: 63 additions & 0 deletions examples/cpp/cpp_integration/download_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
#!/usr/bin/env python3

"""
This script downloads the Skala model, as well as a reference LDA functional from HuggingFace.

The LDA functional computes -3 / 4 * (3 / math.pi) ** (1 / 3) * density.abs() ** (4 / 3),
and can be used to verify that the C++ integration is working correctly.
"""

import shutil

from huggingface_hub import hf_hub_download

from skala.functional.load import TracedFunctional

GRID_SIZE = "grid_size"
NUM_ATOMS = "num_atoms"

feature_shapes = {
"density": [2, GRID_SIZE],
"kin": [2, GRID_SIZE],
"grad": [2, 3, GRID_SIZE],
"grid_coords": [GRID_SIZE, 3],
"grid_weights": [GRID_SIZE],
"coarse_0_atomic_coords": [NUM_ATOMS, 3],
}

feature_labels = {
"density": "Electron density on grid, two spin channels",
"kin": "Kinetic energy density on grid, two spin channels",
"grad": "Gradient of electron density on grid, two spin channels",
"grid_coords": "Coordinates of grid points",
"grid_weights": "Weights of grid points",
"coarse_0_atomic_coords": "Atomic coordinates",
}


def main() -> None:
huggingface_repo_id = "microsoft/skala"
for filename in ("skala-1.0.fun", "baselines/ldax.fun"):
output_path = filename.split("/")[-1]
download_model(huggingface_repo_id, filename, output_path)


def download_model(huggingface_repo_id: str, filename: str, output_path: str) -> None:
path = hf_hub_download(repo_id=huggingface_repo_id, filename=filename)
shutil.copyfile(path, output_path)

print(f"Downloaded the {filename} functional to {output_path}")

fun = TracedFunctional.load(output_path)

print("\nExpected inputs:")
for feature in fun.features:
print(
f"- {feature} {feature_shapes[feature]} in float64 ({feature_labels[feature]})"
)

print(f"\nExpected D3 dispersion settings: {fun.expected_d3_settings}\n")


if __name__ == "__main__":
main()
12 changes: 12 additions & 0 deletions examples/cpp/cpp_integration/environment.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
name: skala-cpp-integration
channels:
- conda-forge
dependencies:
- cmake
- cxx-compiler
- ninja
- nlohmann_json
- libtorch
- pytorch # Only required to produce example features.
- pip:
- microsoft-skala
163 changes: 163 additions & 0 deletions examples/cpp/cpp_integration/main.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
#include <torch/script.h>
#include <torch/csrc/autograd/autograd.h>
#include <nlohmann/json.hpp>

#include <iostream>
#include <memory>

using json = nlohmann::json;
using IValueList = std::vector<c10::IValue>;
using IValueMap = std::unordered_map<std::string, c10::IValue>;
using FeatureDict = c10::Dict<std::string, at::Tensor>;

at::Tensor
load_feature(const std::string &filename, torch::DeviceType device)
{
std::ifstream input(filename, std::ios::binary);
if (!input.is_open())
{
throw std::runtime_error("Failed to open feature file: " + filename);
}
std::vector<char> bytes(
(std::istreambuf_iterator<char>(input)),
(std::istreambuf_iterator<char>()));

input.close();
return torch::jit::pickle_load(bytes).toTensor().to(device);
}

FeatureDict
load_features(const std::string &prefix, const std::vector<std::string> &keys, torch::DeviceType device)
{
FeatureDict featmap;
for (const auto &key : keys)
{
featmap.insert(key, load_feature(prefix + "/" + key + ".pt", device));
}
return featmap;
}

std::tuple<torch::jit::Method, std::vector<std::string>>
load_model(const std::string &filename, torch::DeviceType device)
{
torch::jit::script::Module mod;
torch::jit::ExtraFilesMap extra_files{{"features", ""}, {"protocol_version", ""}};
std::vector<std::string> keys;

try
{
// Deserialize the ScriptModule from a file using torch::jit::load().
mod = torch::jit::load(filename, device, extra_files);
}
catch (const c10::Error &e)
{
throw std::runtime_error("Error loading the model from " + filename + ": " + e.what());
}

auto version = json::parse(extra_files.at("protocol_version")).get<int>();
if (version != 2)
{
throw std::runtime_error("Unsupported protocol version " + std::to_string(version));
}

auto features = json::parse(extra_files.at("features"));
// check if features is array
if (!features.is_array())
{
throw std::runtime_error("features is not an array");
}
for (const auto &feature : features)
{
if (!feature.is_string())
{
throw std::runtime_error("feature is not a string");
}
keys.push_back(feature.get<std::string>());
}

return std::make_tuple(mod.get_method("get_exc_density"), keys);
}

at::Tensor
get_exc(const torch::jit::Method &exc_func, const FeatureDict &features)
{
IValueList args;
IValueMap kwargs;
kwargs["mol"] = features;
return exc_func(args, kwargs).toTensor();
}

std::tuple<at::Tensor, c10::Dict<std::string, at::Tensor>>
get_exc_and_grad(const torch::jit::Method &exc_func, const FeatureDict &features)
{
// Create a mutable copy only for the tensors that need gradients
FeatureDict features_with_grad;
std::vector<at::Tensor> input_tensors;
std::vector<std::string> tensor_keys;

for (const auto &kv : features)
{
auto tensor_with_grad = kv.value().clone().requires_grad_(true);
features_with_grad.insert(kv.key(), tensor_with_grad);
input_tensors.push_back(tensor_with_grad);
tensor_keys.push_back(kv.key());
}

IValueList args;
IValueMap kwargs;
kwargs["mol"] = features_with_grad;

auto exc_on_grid = exc_func(args, kwargs).toTensor();
auto exc = (exc_on_grid * features_with_grad.at("grid_weights")).sum();

auto gradients = torch::autograd::grad(
{exc}, // outputs
input_tensors, // inputs
/*grad_outputs=*/{}, // grad_outputs (defaults to ones)
/*retain_graph=*/false, // retain_graph, necessary for higher-order grads
/*create_graph=*/false, // create_graph, necessary for higher-order grads
/*allow_unused=*/true // allow_unused
);

c10::Dict<std::string, at::Tensor> grad;
for (size_t i = 0; i < tensor_keys.size(); ++i)
{
grad.insert(tensor_keys[i], gradients[i]);
}

return std::make_tuple(exc_on_grid, grad);
}

int main(int argc, const char *argv[])
{
if (argc != 3)
{
std::cerr << "usage: skala_cpp_integration <path-to-fun-file> <feature-file-directory>\n";
return -1;
}

const torch::DeviceType device = torch::kCPU;

const auto [exc_func, feature_keys] = load_model(std::string(argv[1]), device);
const auto features = load_features(std::string(argv[2]), feature_keys, device);

std::cout << "Compute Exc..." << std::endl;

const auto exc_on_grid = get_exc(exc_func, features);
const auto exc = (exc_on_grid * features.at("grid_weights")).sum();

std::cout << "Exc = " << exc.item() << std::endl;

std::cout << "Compute Exc and dExc/dfeat..." << std::endl;

const auto [exc_on_grid2, grad] = get_exc_and_grad(exc_func, features);
const auto exc2 = (exc_on_grid2 * features.at("grid_weights")).sum();

std::cout << "Exc = " << exc2.item() << std::endl;
for (const auto &kv : grad)
{
std::cout << "|dExc/d(" << kv.key() << ")| = " << kv.value().norm().item() << std::endl;
}

return 0;
}
Loading
Loading