Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 28 additions & 21 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
ARG PYTORCH_CUDA_VERSION=2.0.1-cuda11.7-cudnn8
FROM pytorch/pytorch:${PYTORCH_CUDA_VERSION}-runtime
ARG PYTORCH_CUDA_VERSION=2.9.0-cuda12.8-cudnn9
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

needed to bump CUDA version or else flash-attn would have a compile error. decided to bump PyTorch version to the latest compatible version too

# Need devel version for flash-attn to compile
FROM pytorch/pytorch:${PYTORCH_CUDA_VERSION}-devel

ENV DEBIAN_FRONTEND=noninteractive

Expand All @@ -17,7 +18,7 @@ RUN apt-get update -q \
&& rm -rf /var/lib/apt/lists/*

# Install uv
COPY --from=ghcr.io/astral-sh/uv:0.8.17-python3.10-alpine /usr/local/bin/uv /usr/local/bin/uv
COPY --from=ghcr.io/astral-sh/uv:0.9.22-python3.10-alpine /usr/local/bin/uv /usr/local/bin/uv

# Don't run as root, instead run as non-root user that can still use
# passwordless sudo for admin privileges
Expand All @@ -32,24 +33,30 @@ RUN addgroup --gid ${GROUPID} ${USERNAME} \
&& chown -R ${USERNAME}:${USERNAME} "/workspace"
USER ${USERNAME}

# uv will install into VIRTUAL_ENV. We add VIRTUAL_ENV to the path so that users
# of image can use `python` rather than `uv run python`.
ENV VIRTUAL_ENV=/workspace/SafeTuneBed/.venv
ENV PATH="$VIRTUAL_ENV/bin:/home/dev/.local/bin:$PATH"

# Copy package installation files
COPY --chown=${USERNAME}:${USERNAME} pyproject.toml uv.lock /workspace/SafeTuneBed/

# Create venv and install dependencies
# Run Pyright so its Node.js package gets installed
RUN cd /workspace/SafeTuneBed \
&& mkdir -p src/safetunebed \
&& touch src/safetunebed/__init__.py \
&& uv sync --extra dev --frozen \
&& rm -rf "${HOME}/.cache" \
&& basedpyright . \
&& cd .. \
&& rm -rf SafeTuneBed
# Install dependencies into a venv outside the repo directory. The repo will be
# cloned at runtime (see k8s/devbox.yaml), but the venv with deps is pre-baked.
# UV_PROJECT_ENVIRONMENT tells uv where to create the venv instead of .venv.
#
# This env var affects all uv projects in the container, so installing a
# different uv project will clobber the venv. If that becomes a problem, then
# make UV_PROJECT_ENVIRONMENT local to the uv commands in the Dockerfile and
# then create a symlink from /home/dev/.venv into /workspace/SafeTuneBed/.venv
# after cloning in the k8s/ YAMLs.
ENV UV_PROJECT_ENVIRONMENT=/home/dev/.venv
ENV PATH="/home/dev/.venv/bin:/home/dev/.local/bin:$PATH"

# Create venv and install dependencies (but not the project itself).
# flash_attn is memory intensive to compile, MAX_JOBS decreases the memory usage
# though compilation will be slower. If building locally you may need to go into
# "Settings > Resources" in Docker Desktop and bump the memory limit.
# (https://github.com/Dao-AILab/flash-attention#installation-and-features)
# Run Pyright so its Node.js package gets installed.
RUN --mount=type=bind,source=uv.lock,target=uv.lock \
--mount=type=bind,source=pyproject.toml,target=pyproject.toml \
MAX_JOBS=1 uv sync --all-groups --locked --no-install-project \
&& uv run basedpyright --version

WORKDIR /workspace

# Default command to run -- may be changed at runtime
CMD ["/bin/bash"]
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
VERSION ?= 0.0.1
VERSION ?= 0.0.2
FULL_IMAGE_NAME = ghcr.io/tomtseng/tamperbench:$(VERSION)

.PHONY: docker-build devbox cpu large
Expand Down
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ dependencies = [
"bitsandbytes>=0.46.1; platform_system == 'Linux'",
"datasets>=3.3.0",
"evaluate>=0.4.0",
"flash-attn>=2.8.3; platform_system == 'Linux'",
"fsspec>=2023.12.2",
"harmtune @ git+https://github.com/AlignmentResearch/harmtune.git",
"immutabledict>=4.0.0",
Expand Down Expand Up @@ -58,6 +59,11 @@ name = "pytorch-cu128"
url = "https://download.pytorch.org/whl/cu128"
explicit = true

# Needed for installing flash-attn:
# https://github.com/astral-sh/uv/issues/6437#issuecomment-3167274955
[tool.uv.extra-build-dependencies]
flash-attn = [{ requirement = "torch", match-runtime = true }]

[project.optional-dependencies]
dev = [
"basedpyright>=1.28.1",
Expand Down
Loading
Loading