Skip to content

setup for PyTorch 2.9.1 and Blackwell GPUs #4

@21tesla

Description

@21tesla

Edit the following `environment.yml` as needed, or use it as-is:

# Conda environment for ppiflow targeting PyTorch 2.9.1 + CUDA 12.8
# (Blackwell GPUs, sm_120). Conda supplies the scientific stack; pip
# supplies the PyTorch stack from the cu128 wheel index.
name: ppiflow
channels:
  - pytorch
  - nvidia
  - conda-forge
  - bioconda
  - defaults
dependencies:
  # --- Core System & Scientific Stack ---
  - python=3.10
  - pip
  - wheel
  - setuptools
  - rdkit
  - pymol-open-source
  - openmm
  - pdbfixer
  - biopython
  - numpy
  - pandas
  - scipy
  - scikit-learn
  - matplotlib
  - seaborn
  - jupyter
  - ipython

  # --- PIP Dependencies ---
  - pip:
      # 1. Add the NVIDIA/PyTorch cu128 wheel index for sm_120 support.
      #    NOTE: use --extra-index-url rather than --index-url. --index-url
      #    REPLACES PyPI, and packages below that are not hosted on the
      #    PyTorch index (wandb, hydra-core, fair-esm, ...) would then
      #    fail to resolve. --extra-index-url keeps PyPI as a fallback.
      - --extra-index-url https://download.pytorch.org/whl/cu128

      # 2. PyTorch Stack (cu128 builds; torchvision 0.24.x pairs with torch 2.9.x)
      - torch==2.9.1
      - torchvision==0.24.1
      - torchaudio==2.9.1

      # 3. Deep Learning
      - torch-geometric
      - lightning
      - torchmetrics

      # 4. Utilities
      - fair-esm
      - biotite
      - mdanalysis
      - wandb
      - hydra-core
      - omegaconf
      - tqdm
      - einops
      - py3dmol
      - tensorboard
      - dm-tree
      - GPUtil
      - mdtraj
      - tmtools
      - freesasa
      - optree
      - deepspeed

Create Conda Environment

conda env create -f environment.yml
conda activate ppiflow

Install the CUDA 12.8 compiler inside the environment if your system-wide CUDA toolkit is a different (e.g. 13.0) version:

conda install -c nvidia cuda-toolkit=12.8

Configure the build environment to use the conda-provided CUDA 12.8 compiler and headers:

# Point build tooling at the conda-provided CUDA toolkit so JIT extension
# builds use it instead of any system-wide CUDA installation.
export CUDA_HOME=$CONDA_PREFIX
export PATH=$CONDA_PREFIX/bin:$PATH
# Locate the directory holding the CUDA runtime headers inside the env
# (first match wins if multiple copies exist).
CUDA_HEADER_PATH=$(find $CONDA_PREFIX -name "cuda_runtime_api.h" -type f | head -n 1 | xargs dirname)
# Expose the headers to C/C++ compiles via flags and via CPATH.
export CFLAGS="-I$CUDA_HEADER_PATH $CFLAGS"
export CXXFLAGS="-I$CUDA_HEADER_PATH $CXXFLAGS"
export CPATH=$CUDA_HEADER_PATH:$CPATH
# Link against the environment's CUDA libraries.
export LIBRARY_PATH=$CONDA_PREFIX/lib:$LIBRARY_PATH

edit or replace models/layer_norm/torch_ext_compile.py

# Copyright 2024 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from torch.utils.cpp_extension import load


def compile(name, sources, extra_include_paths, build_directory):
    """JIT-compile the CUDA layer-norm extension as a multi-arch fat binary.

    SASS is emitted for Volta through Blackwell so the extension runs on
    legacy data-center cards as well as sm_120 consumer GPUs; a PTX entry
    for compute_120 lets even newer architectures JIT from intermediate code.

    Args:
        name: Module name handed to torch's JIT extension loader.
        sources: C++/CUDA source file paths.
        extra_include_paths: Additional header search directories.
        build_directory: Where compiled artifacts are written/cached.

    Returns:
        The loaded extension module produced by ``torch.utils.cpp_extension.load``.
    """
    version_defines = [
        "-DVERSION_GE_1_1",
        "-DVERSION_GE_1_3",
        "-DVERSION_GE_1_5",
    ]

    host_flags = ["-O3", *version_defines]

    device_flags = [
        "-O3",
        "--use_fast_math",
        *version_defines,
        "-std=c++17",
        "-maxrregcount=50",
        "-U__CUDA_NO_HALF_OPERATORS__",
        "-U__CUDA_NO_HALF_CONVERSIONS__",
        "--expt-relaxed-constexpr",
        "--expt-extended-lambda",
    ]

    # (virtual arch, code target) pairs; the final compute_120/compute_120
    # entry embeds PTX for forward compatibility with future GPUs.
    gencode_targets = [
        ("compute_70", "sm_70"),          # Volta
        ("compute_80", "sm_80"),          # Ampere (data center)
        ("compute_86", "sm_86"),          # Ampere (consumer)
        ("compute_90", "sm_90"),          # Hopper
        ("compute_120", "sm_120"),        # Blackwell
        ("compute_120", "compute_120"),   # PTX fallback
    ]
    for arch, code in gencode_targets:
        device_flags += ["-gencode", f"arch={arch},code={code}"]

    return load(
        name=name,
        sources=sources,
        extra_include_paths=extra_include_paths,
        extra_cflags=host_flags,
        extra_cuda_cflags=device_flags,
        verbose=True,
        build_directory=build_directory,
    )

Allowlist the classes contained in the checkpoint for `torch.load`'s safe (`weights_only=True`) deserialization by including this block in `sample_binder.py` and other entry scripts. Note this does not disable the security check — it explicitly registers known-safe types so loading succeeds without falling back to `weights_only=False`:

import torch
import omegaconf
import omegaconf.base
import omegaconf.listconfig
import omegaconf.dictconfig
import omegaconf.nodes
import typing
import collections
import pathlib
import numpy as np

# torch >= 2.6 defaults torch.load to weights_only=True, which only
# unpickles explicitly allow-listed classes. Register every type the
# checkpoint contains so it loads without weights_only=False.

# Omegaconf internals serialized inside Hydra/Lightning checkpoints.
_OMEGACONF_GLOBALS = [
    omegaconf.dictconfig.DictConfig,
    omegaconf.listconfig.ListConfig,
    omegaconf.base.ContainerMetadata,
    omegaconf.base.Metadata,
    omegaconf.base.Node,
    omegaconf.nodes.AnyNode,
    omegaconf.nodes.IntegerNode,
    omegaconf.nodes.FloatNode,
    omegaconf.nodes.BooleanNode,
    omegaconf.nodes.StringNode,
    omegaconf.nodes.EnumNode,
]

# Typing constructs and plain built-in types.
_BUILTIN_GLOBALS = [
    typing.Any, typing.List, typing.Dict, typing.Union, typing.Tuple,
    dict, list, set, tuple, str, int, float, bool, type, slice, complex,
]

# Stdlib containers, filesystem paths, and numpy scalar/array types.
_MISC_GLOBALS = [
    collections.defaultdict,
    collections.OrderedDict,
    pathlib.PosixPath,
    pathlib.Path,
    np.ndarray, np.dtype, np.float64, np.float32, np.int64, np.int32,
]

torch.serialization.add_safe_globals(
    _OMEGACONF_GLOBALS + _BUILTIN_GLOBALS + _MISC_GLOBALS
)

Build the PyTorch Geometric companion packages from source against the installed torch:

# --no-build-isolation makes pip build these extensions in the current
# environment, so their setup.py can see the already-installed torch and
# the CUDA toolchain configured above.
pip install torch-scatter torch-sparse torch-cluster torch-spline-conv --no-build-isolation

run a sample binder script

# Example inference run: design 5 binders of length 75-76 residues against
# chain B of the target PDB, steering toward the listed hotspot residues.
# Replace the /path/to/... placeholders with real paths before running.
python sample_binder.py \
    --input_pdb /path/to/target.pdb \
    --target_chain B \
    --binder_chain A \
    --config /path/to/configs/inference_binder.yaml \
    --specified_hotspots "B119,B141,B200" \
    --samples_min_length 75 \
    --samples_max_length 76 \
    --samples_per_target 5 \
    --model_weights /path/to/model/binder.ckpt \
    --output_dir /path/to/output/binder_test \
    --name IL7Ra

Metadata

Metadata

Assignees

No one assigned

    Labels

    documentationImprovements or additions to documentation

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions