13 changes: 13 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,13 @@
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.6.0
    hooks:
      - id: check-yaml
      - id: end-of-file-fixer
      - id: trailing-whitespace
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.6.4
    hooks:
      - id: ruff
        args: [--fix]
      - id: ruff-format
Comment on lines +1 to +13
💡 Codebase verification

Update pre-commit hooks to latest stable versions and consider additional security checks

The current configuration pins outdated hook versions. Here's what needs updating:

  • Update pre-commit-hooks from v4.6.0 to v5.0.0
  • Update ruff from v0.6.4 to v0.8.4

Additionally, consider adding these security-focused hooks from pre-commit-hooks:

  • check-added-large-files
  • check-case-conflict
  • check-merge-conflict
  • detect-private-key

Here's the suggested updated configuration:

repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v5.0.0
    hooks:
      - id: check-yaml
      - id: end-of-file-fixer
      - id: trailing-whitespace
      - id: check-added-large-files
      - id: check-case-conflict
      - id: check-merge-conflict
      - id: detect-private-key
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.8.4
    hooks:
      - id: ruff
        args: [--fix]
      - id: ruff-format
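
To adopt the new revisions, an invocation along these lines should work (standard pre-commit CLI commands, assuming pre-commit is installed locally):

pre-commit autoupdate      # bumps every rev to its latest tagged release
pre-commit run --all-files # re-runs all hooks across the repository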
🔗 Analysis chain

Consider additional hooks and verify versions.

The current configuration provides basic linting and formatting. Consider adding these useful hooks:

  • check-added-large-files: Prevents large files from being committed
  • check-case-conflict: Checks for files with names that would conflict on case-insensitive filesystems
  • check-merge-conflict: Checks for files containing merge conflict strings
  • detect-private-key: Detects accidentally committed private keys

Let's verify the latest stable versions of the hooks:

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check latest versions of pre-commit hooks and ruff

# Check latest pre-commit-hooks version
gh api repos/pre-commit/pre-commit-hooks/releases/latest --jq .tag_name

# Check latest ruff version
gh api repos/astral-sh/ruff/releases/latest --jq .tag_name

Length of output: 147

2 changes: 1 addition & 1 deletion Dockerfile
@@ -15,4 +15,4 @@ RUN pip3 install --no-cache-dir -r requirements.txt
# Copy the rest of the application code
COPY . .

CMD ["python", "full_automation.py"]
CMD ["python", "full_automation.py"]
2 changes: 1 addition & 1 deletion README.md
@@ -87,4 +87,4 @@ curl --location 'https://fed-ledger-prod.flock.io/api/v1/tasks/submit-result' \
- `hg_repo_id`: The Hugging Face repository where the model is stored, typically in the format `username/repository-name`.
- `base_model`: The base model used for training. A list of all supported models can be found [here](https://github.com/FLock-io/llm-loss-validator/blob/main/src/core/constant.py).
- `gpu_type`: The type of GPU used for training the model.
-- `revision`: The commit hash from the Hugging Face repository. This uniquely identifies the version of the model that was trained and submitted, allowing for precise tracking of changes and updates.
\ No newline at end of file
+- `revision`: The commit hash from the Hugging Face repository. This uniquely identifies the version of the model that was trained and submitted, allowing for precise tracking of changes and updates.
1 change: 0 additions & 1 deletion demo.py
@@ -7,7 +7,6 @@
 from trl import SFTTrainer, SFTConfig
 
 from dataset import SFTDataCollator, SFTDataset
-from merge import merge_lora_to_base_model
 from utils.constants import model2template
 
 
2 changes: 1 addition & 1 deletion demo_data.jsonl

Large diffs are not rendered by default.

48 changes: 19 additions & 29 deletions full_automation.py
@@ -1,18 +1,15 @@
 import json
 import os
-import time
 
 import requests
 import yaml
 from loguru import logger
-from huggingface_hub import HfApi
 
 from demo import LoraTrainingArguments, train_lora
 from utils.constants import model2base_model, model2size
-from utils.flock_api import get_task, submit_task
+from utils.flock_api import get_task, submit_task, get_address
 from utils.gpu_utils import get_gpu_type
-
-HF_USERNAME = os.environ["HF_USERNAME"]
+from utils.cloudflare_utils import CloudStorage
 
 if __name__ == "__main__":
     task_id = os.environ["TASK_ID"]
@@ -59,33 +56,26 @@
         gpu_type = get_gpu_type()
 
         try:
-            logger.info("Start to push the lora weight to the hub...")
-            api = HfApi(token=os.environ["HF_TOKEN"])
-            repo_name = f"{HF_USERNAME}/task-{task_id}-{model_id.replace('/', '-')}"
-            # check whether the repo exists
-            try:
-                api.create_repo(
-                    repo_name,
-                    exist_ok=False,
-                    repo_type="model",
-                )
-            except Exception as e:
-                logger.info(
-                    f"Repo {repo_name} already exists. Will commit the new version."
-                )
+            logger.info("Start to push the lora weight to the cloudflare R2...")
-
-            commit_message = api.upload_folder(
-                folder_path="outputs",
-                repo_id=repo_name,
-                repo_type="model",
+            upload_data = get_address(task_id)
+            cf_storage = CloudStorage(
+                access_key=upload_data["data"]["access_key"],
+                secret_key=upload_data["data"]["secret_key"],
+                endpoint_url=upload_data["data"]["endpoint_url"],
+                session_token=upload_data["data"]["session_token"],
+                bucket=upload_data["data"]["bucket"],
             )
+            cf_storage.initialize()
+            cf_storage.upload_folder(
+                local_folder="outputs", cloud_path=upload_data["data"]["folder_name"]
+            )
-            # get commit hash
-            commit_hash = commit_message.oid
-            logger.info(f"Commit hash: {commit_hash}")
-            logger.info(f"Repo name: {repo_name}")
             # submit
             submit_task(
-                task_id, repo_name, model2base_model[model_id], gpu_type, commit_hash
+                task_id,
+                model2base_model[model_id],
+                gpu_type,
+                upload_data["data"]["bucket"],
+                upload_data["data"]["folder_name"],
Comment on lines +61 to +78

🛠️ Refactor suggestion

Enhance error handling for upload process

While basic error handling exists, it could be more specific to help with debugging and recovery.

             upload_data = get_address(task_id)
+            if not upload_data or "data" not in upload_data:
+                raise ValueError("Invalid upload data received from API")
+
             cf_storage = CloudStorage(
                 access_key=upload_data["data"]["access_key"],
                 secret_key=upload_data["data"]["secret_key"],
                 endpoint_url=upload_data["data"]["endpoint_url"],
                 session_token=upload_data["data"]["session_token"],
                 bucket=upload_data["data"]["bucket"],
             )
             cf_storage.initialize()
+        except ValueError as e:
+            logger.error(f"API Error: {e}")
+            continue
+        except ConnectionError as e:
+            logger.error(f"Network Error during upload: {e}")
+            continue
         except Exception as e:
-            logger.error(f"Error: {e}")
+            logger.exception("Unexpected error during upload process")
             logger.info("Proceed to the next model...")

Committable suggestion skipped: line range outside the PR's diff.

             )
             logger.info("Task submitted successfully")
         except Exception as e:
1 change: 0 additions & 1 deletion merge.py
@@ -1,4 +1,3 @@
-import json
 import torch
 from peft import PeftModel
 from transformers import AutoModelForCausalLM, AutoTokenizer
1 change: 1 addition & 0 deletions requirements.txt
@@ -5,3 +5,4 @@ loguru
 trl>=0.9.3,<=0.9.6
 bitsandbytes
 pyyaml
+boto3

⚠️ Potential issue

Add version constraints for boto3.

Other critical dependencies have version constraints to ensure compatibility. Since boto3 is essential for the R2 storage functionality, it should also have version constraints to prevent potential breaking changes in future updates.

Apply this diff to add version constraints:

-boto3
+boto3>=1.26.0,<2.0.0
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change:
-boto3
+boto3>=1.26.0,<2.0.0
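
Before committing to a pin, a quick sanity check of the locally installed version (standard pip command) can confirm compatibility:

pip show boto3   # prints the installed boto3 version and metadata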

83 changes: 83 additions & 0 deletions utils/cloudflare_utils.py
@@ -0,0 +1,83 @@
import os
import boto3
import botocore
import threading
from tqdm import tqdm
from loguru import logger


class ProgressPercentage:
    def __init__(self, filename):
        self.filename = filename
        self.size = float(os.path.getsize(filename))
        self._seen_so_far = 0
        self._lock = threading.Lock()

    def __call__(self, bytes_amount):
        # boto3 invokes this callback from its transfer threads, so guard the counter
        with self._lock:
            self._seen_so_far += bytes_amount
            percentage = (self._seen_so_far / self.size) * 100
            # loguru's logger.info takes no `end=` parameter (a print() idiom),
            # so emit one complete log line per progress update
            logger.info(
                "%s %s / %s (%.2f%%)"
                % (self.filename, self._seen_so_far, self.size, percentage)
            )


class CloudStorage:
    def __init__(
        self,
        access_key=None,
        secret_key=None,
        endpoint_url=None,
        bucket=None,
        session_token=None,
    ):
        self.access_key = access_key
        self.secret_key = secret_key
        self.endpoint_url = endpoint_url
        self.bucket = bucket
        self.client = None
        self.session_token = session_token

    def initialize(self):
        if (
            self.access_key is None
            or self.secret_key is None
            or self.endpoint_url is None
        ):
            logger.error("Please provide access_key, secret_key, and endpoint_url")
            # a bare `raise` outside an except block is itself an error; raise explicitly
            raise ValueError("access_key, secret_key, and endpoint_url are required")
        self.client = boto3.client(
            "s3",
            endpoint_url=self.endpoint_url,
            aws_access_key_id=self.access_key,
            aws_secret_access_key=self.secret_key,
            aws_session_token=self.session_token,
        )
        return self

    def upload_folder(self, local_folder, cloud_path, bucket=None):
        if bucket is None and self.bucket is None:
            logger.error("Please provide bucket name")
            return
        if bucket is None:
            bucket = self.bucket
        for root, dirs, files in tqdm(os.walk(local_folder)):
            for file in files:
                local_file_path = os.path.join(root, file)
                relative_path = os.path.relpath(local_file_path, local_folder)
                # S3-style keys always use forward slashes, even on Windows
                cloud_file_path = os.path.join(cloud_path, relative_path).replace(
                    "\\", "/"
                )
                try:
                    self.client.upload_file(
                        local_file_path,
                        bucket,
                        cloud_file_path,
                        Callback=ProgressPercentage(local_file_path),
                    )
                except botocore.exceptions.ClientError as e:
                    logger.error(e)
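
For reviewers who want to exercise the class in isolation, a minimal usage sketch follows. Every credential value is a placeholder; in the real flow they are issued per task by get_address(task_id):

from utils.cloudflare_utils import CloudStorage

# All values below are placeholders for illustration only
storage = CloudStorage(
    access_key="<r2-access-key>",
    secret_key="<r2-secret-key>",
    endpoint_url="https://<account-id>.r2.cloudflarestorage.com",
    session_token="<session-token>",
    bucket="<bucket-name>",
)
storage.initialize()

# Mirrors the call in full_automation.py: pushes the local outputs/ folder
# under the given key prefix in the bucket
storage.upload_folder(local_folder="outputs", cloud_path="task-1/example-run")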
27 changes: 24 additions & 3 deletions utils/flock_api.py
@@ -15,16 +15,16 @@ def get_task(task_id: int):


 def submit_task(
-    task_id: int, hg_repo_id: str, base_model: str, gpu_type: str, revision: str
+    task_id: int, base_model: str, gpu_type: str, bucket: str, folder_name: str
 ):
     payload = json.dumps(
         {
             "task_id": task_id,
             "data": {
-                "hg_repo_id": hg_repo_id,
                 "base_model": base_model,
                 "gpu_type": gpu_type,
-                "revision": revision,
+                "bucket": bucket,
+                "folder_name": folder_name,
             },
         }
     )
@@ -41,3 +41,24 @@ def submit_task(
     if response.status_code != 200:
         raise Exception(f"Failed to submit task: {response.text}")
     return response.json()
+
+
+def get_address(task_id: int):
+    payload = json.dumps(
+        {
+            "task_id": task_id,
+        }
+    )
+    headers = {
+        "flock-api-key": FLOCK_API_KEY,
+        "Content-Type": "application/json",
+    }
+    response = requests.request(
+        "POST",
+        f"{FED_LEDGER_BASE_URL}/tasks/get_storage_credentials",
+        headers=headers,
+        data=payload,
+    )
+    if response.status_code != 200:
+        raise Exception(f"Failed to get storage credentials: {response.text}")
+    return response.json()
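
For reference, the way full_automation.py consumes this response implies a payload shaped roughly as follows. This is inferred from usage, not from API documentation, and every value is illustrative:

# Expected shape of the get_address() response, inferred from full_automation.py
expected_response = {
    "data": {
        "access_key": "<temporary R2 access key>",
        "secret_key": "<temporary R2 secret key>",
        "session_token": "<session token>",
        "endpoint_url": "https://<account-id>.r2.cloudflarestorage.com",
        "bucket": "<bucket name>",
        "folder_name": "<key prefix for the uploaded outputs>",
    }
}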