
Mitigating distribution shift in machine learning-augmented hybrid simulation

Codebase for mitigating distribution shift in machine learning-augmented hybrid simulation (MLHS) using the tangent-space regularized algorithm proposed in the paper [1] by Jiaxi Zhao and Qianxiao Li.

Installing & Getting Started

Install the package and set up the environment:

git clone git@github.com:jiaxi98/ml4dynamics.git
mkdir venv
mkdir venv/ml4dynamics
python -m venv venv/ml4dynamics
source venv/ml4dynamics/bin/activate
cd ml4dynamics
pip install --upgrade pip
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install -e .

For discussions of JAX/PyTorch compatibility, see: link1, link2, link3
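JAX and PyTorch are installed side by side here, each with its own CUDA wheels, so it can help to confirm that both import and see the GPU before generating data. The helper below is a minimal, hypothetical sketch (check_backends is not part of the repository). It only uses importlib, jax.devices() and torch.cuda.is_available(), and it reports failures instead of raising them.

```python
import importlib
import importlib.util


def check_backends() -> dict:
    """Report whether JAX and PyTorch import and which accelerators they see.

    Hypothetical helper, not part of ml4dynamics: it never raises, so it can be
    run in a partially installed environment.
    """
    report = {}
    for name in ("jax", "torch"):
        if importlib.util.find_spec(name) is None:
            report[name] = {"installed": False}
            continue
        try:
            module = importlib.import_module(name)
        except Exception as err:  # e.g. a CUDA/cuDNN library mismatch at import time
            report[name] = {"installed": True, "import_error": repr(err)}
            continue
        info = {"installed": True, "version": getattr(module, "__version__", "unknown")}
        try:
            if name == "jax":
                # Devices of JAX's default backend: GPUs if the CUDA plugin loaded, else CPU.
                info["devices"] = [str(d) for d in module.devices()]
            else:
                info["cuda_available"] = bool(module.cuda.is_available())
        except Exception as err:
            info["device_error"] = repr(err)
        report[name] = info
    return report


if __name__ == "__main__":
    for backend, info in check_backends().items():
        print(f"{backend}: {info}")
```

If jax lists only CPU devices or torch reports cuda_available=False after the install above, the CUDA wheels of one of the two packages were likely not picked up.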

Generate the datasets

The script computes both the filter and the correction SGS stresses. To switch between the supported boundary conditions (periodic and Dirichlet-Neumann), modify the corresponding config file ks.yaml.

python ml4dynamics/dataset_utils/generate_ks.py
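To illustrate what a filter SGS stress is, the sketch below assumes the KS equation is written as $u_t + \frac{1}{2}(u^2)_x + u_{xx} + u_{xxxx} = 0$ on a periodic domain. It also assumes a block-average (top-hat) filter followed by downsampling by an integer ratio. Filtering the nonlinear term leaves the unclosed stress $\tau = \overline{u^2} - \bar{u}^2$. The function name and the filter choice are assumptions for illustration. The filter, the form of the equation, and the correction stress used by generate_ks.py may differ.

```python
import numpy as np


def filter_sgs_stress(u_fine: np.ndarray, ratio: int) -> tuple[np.ndarray, np.ndarray]:
    """Block-average filter and filter SGS stress for a 1D periodic field (hypothetical).

    Assumes the KS nonlinearity is 0.5 * (u**2)_x, so the unclosed term after
    filtering is tau = bar(u**2) - bar(u)**2 on the coarse grid.
    """
    n = u_fine.shape[-1]
    if n % ratio != 0:
        raise ValueError("fine grid size must be divisible by ratio")
    # Group consecutive fine-grid points into coarse cells of `ratio` points each.
    blocks = u_fine.reshape(*u_fine.shape[:-1], n // ratio, ratio)
    u_bar = blocks.mean(axis=-1)          # filtered (coarse) field
    u2_bar = (blocks ** 2).mean(axis=-1)  # filtered square
    tau = u2_bar - u_bar ** 2             # SGS stress: within-cell variance, so >= 0
    return u_bar, tau


if __name__ == "__main__":
    x = np.linspace(0.0, 32.0 * np.pi, 1024, endpoint=False)
    u = np.cos(x / 16.0) + 0.5 * np.sin(x)
    u_bar, tau = filter_sgs_stress(u, ratio=8)
    print(u_bar.shape, tau.shape, float(tau.max()))
```

With a block-average filter, $\tau$ is the variance of $u$ inside each coarse cell. It therefore vanishes for fields that are constant on every cell and grows with unresolved small-scale content.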

Train the model

The model can be trained on either 1 or 10 trajectories. For evaluation, a single trajectory can be used for visualization and 10 trajectories for summary statistics; see ml4dynamics.utils.utils.eval_a_posteriori for details.

python ml4dynamics/trainers/train_jax.py -c ks
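The sketch below shows one way to turn rollouts on several trajectories into summary statistics. It computes a relative L2 error at every time step, averages it over time for each trajectory, and then aggregates over trajectories. a_posteriori_stats is a hypothetical helper. The array layout (n_traj, n_t, n_x) is an assumption, and the metrics actually reported by ml4dynamics.utils.utils.eval_a_posteriori may differ.

```python
import numpy as np


def a_posteriori_stats(pred: np.ndarray, ref: np.ndarray) -> dict:
    """Summarize hybrid-solver rollout error over several trajectories (hypothetical).

    pred, ref: arrays of shape (n_traj, n_t, n_x) holding the hybrid rollout and
    the reference (e.g. filtered fine-grid) solution. ref must be nonzero at
    every time step to avoid division by zero.
    """
    if pred.shape != ref.shape:
        raise ValueError("pred and ref must have the same shape")
    # Relative L2 error in space, for each trajectory and time step: (n_traj, n_t).
    err = np.linalg.norm(pred - ref, axis=-1) / np.linalg.norm(ref, axis=-1)
    per_traj = err.mean(axis=-1)  # time-averaged error of each trajectory
    return {
        "per_trajectory": per_traj,
        "mean": float(per_traj.mean()),
        "std": float(per_traj.std()),
        "final_time_mean": float(err[:, -1].mean()),
    }
```

With a single trajectory, the "std" entry is trivially zero, which is why the statistics are only meaningful on the 10-trajectory evaluation.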

Train a global/local model

This codebase supports training both global (mesh-to-mesh) and local (point-to-point) SGS models. To switch between the two modes, modify the train.input entry of the config file. Two examples, one for global and one for local training, are given below.

  • global
train:
  # input: ["u", "u_x", "u_xx", "x"]
  # input: 4
  input: global
  • local
train:
  input: ["u", "u_x", "u_xx", "x"]
  # input: 4
  # input: global
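The three forms of train.input shown in the configs above can be told apart by their type. The following is a minimal sketch operating on the dict that yaml.safe_load returns for the train section. resolve_input_mode is hypothetical and is not the repository's own parser.

```python
from typing import Optional, Union


def resolve_input_mode(train_cfg: dict) -> tuple[str, Optional[Union[list, int]]]:
    """Classify train.input as global, local feature list, or local stencil size (hypothetical)."""
    inp = train_cfg["input"]
    if inp == "global":
        return "global", None  # mesh-to-mesh model
    # bool is a subclass of int in Python, so reject it explicitly.
    if isinstance(inp, int) and not isinstance(inp, bool):
        if inp < 1:
            raise ValueError("stencil size must be positive")
        return "local_stencil", inp  # number of neighbouring grid values
    if isinstance(inp, list) and inp and all(isinstance(f, str) for f in inp):
        return "local_features", list(inp)  # named features, order preserved
    raise ValueError(f"unsupported train.input: {inp!r}")
```

For example, resolve_input_mode({"input": ["u", "u_x", "u_xx", "x"]}) returns ("local_features", ["u", "u_x", "u_xx", "x"]).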

Note that for local training, two ways of specifying the input features are provided. The first choice, a list such as $(u, u_x, u_{xx})$, directly specifies the ordered input features of the local model. The second choice, an integer, specifies the size of the stencil of input values. For example, $3$ refers to $(u_{i-1}, u_i, u_{i+1})$ and $5$ refers to $(u_{i-2}, u_{i-1}, u_i, u_{i+1}, u_{i+2})$, in that order.
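On a periodic grid, both kinds of local inputs can be assembled with np.roll. The sketch below is an illustration under stated assumptions. It uses second-order central differences for $u_x$ and $u_{xx}$ and only handles odd, symmetric stencils such as 3 and 5. It also does not treat the Dirichlet-Neumann case, where one-sided stencils or ghost values would be needed near the boundary. The function names are hypothetical. If the coordinate feature x is used, it would be appended as an extra column.

```python
import numpy as np


def stencil_features(u: np.ndarray, size: int) -> np.ndarray:
    """Row i is (u_{i-h}, ..., u_i, ..., u_{i+h}) with h = size // 2, periodic (hypothetical)."""
    if size < 1 or size % 2 == 0:
        raise ValueError("this sketch only handles odd, symmetric stencils")
    h = size // 2
    # np.roll(u, s)[i] == u[i - s], so a shift of -offset gives u[i + offset].
    cols = [np.roll(u, -offset) for offset in range(-h, h + 1)]
    return np.stack(cols, axis=-1)  # shape (n_x, size)


def derivative_features(u: np.ndarray, dx: float) -> np.ndarray:
    """Row i is (u_i, u_x, u_xx) from second-order central differences, periodic (hypothetical)."""
    u_plus, u_minus = np.roll(u, -1), np.roll(u, 1)  # u_{i+1}, u_{i-1}
    u_x = (u_plus - u_minus) / (2.0 * dx)
    u_xx = (u_plus - 2.0 * u + u_minus) / dx ** 2
    return np.stack([u, u_x, u_xx], axis=-1)  # shape (n_x, 3)
```

Each row is then one training sample for the point-to-point model. This is why a local model can be trained on many samples drawn from even a single trajectory.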

Citation

If you find this codebase useful for your research, please consider citing:

@article{zhao2025mitigating,
  title={Mitigating Distribution Shift in Machine Learning--Augmented Hybrid Simulation},
  author={Zhao, Jiaxi and Li, Qianxiao},
  journal={SIAM Journal on Scientific Computing},
  volume={47},
  number={2},
  pages={C475--C500},
  year={2025},
  publisher={SIAM}
}

@inproceedings{zhao2025generative,
  title={Generative subgrid-scale modeling},
  author={Zhao, Jiaxi and Arisaka, Sohei and Li, Qianxiao},
  booktitle={ICLR 2025 Workshop on Machine Learning Multiscale Processes},
  year={2025}
}
