-
Notifications
You must be signed in to change notification settings - Fork 77
Migrate 2 r2 #26
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Migrate 2 r2 #26
Changes from all commits
542b261
773a26c
f4f0ea6
822ec06
c9c3b09
16f71d9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,13 @@ | ||
| repos: | ||
| - repo: https://github.com/pre-commit/pre-commit-hooks | ||
| rev: v4.6.0 | ||
| hooks: | ||
| - id: check-yaml | ||
| - id: end-of-file-fixer | ||
| - id: trailing-whitespace | ||
| - repo: https://github.com/astral-sh/ruff-pre-commit | ||
| rev: v0.6.4 | ||
| hooks: | ||
| - id: ruff | ||
| args: [--fix] | ||
| - id: ruff-format | ||
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,18 +1,15 @@ | ||
| import json | ||
| import os | ||
| import time | ||
|
|
||
| import requests | ||
| import yaml | ||
| from loguru import logger | ||
| from huggingface_hub import HfApi | ||
|
|
||
| from demo import LoraTrainingArguments, train_lora | ||
| from utils.constants import model2base_model, model2size | ||
| from utils.flock_api import get_task, submit_task | ||
| from utils.flock_api import get_task, submit_task, get_address | ||
| from utils.gpu_utils import get_gpu_type | ||
|
|
||
| HF_USERNAME = os.environ["HF_USERNAME"] | ||
| from utils.cloudflare_utils import CloudStorage | ||
|
|
||
| if __name__ == "__main__": | ||
| task_id = os.environ["TASK_ID"] | ||
|
|
@@ -59,33 +56,26 @@ | |
| gpu_type = get_gpu_type() | ||
|
|
||
| try: | ||
| logger.info("Start to push the lora weight to the hub...") | ||
| api = HfApi(token=os.environ["HF_TOKEN"]) | ||
| repo_name = f"{HF_USERNAME}/task-{task_id}-{model_id.replace('/', '-')}" | ||
| # check whether the repo exists | ||
| try: | ||
| api.create_repo( | ||
| repo_name, | ||
| exist_ok=False, | ||
| repo_type="model", | ||
| ) | ||
| except Exception as e: | ||
| logger.info( | ||
| f"Repo {repo_name} already exists. Will commit the new version." | ||
| ) | ||
| logger.info("Start to push the lora weight to the cloudflare R2...") | ||
|
|
||
| commit_message = api.upload_folder( | ||
| folder_path="outputs", | ||
| repo_id=repo_name, | ||
| repo_type="model", | ||
| upload_data = get_address(task_id) | ||
| cf_storage = CloudStorage( | ||
| access_key=upload_data["data"]["access_key"], | ||
| secret_key=upload_data["data"]["secret_key"], | ||
| endpoint_url=upload_data["data"]["endpoint_url"], | ||
| session_token=upload_data["data"]["session_token"], | ||
| bucket=upload_data["data"]["bucket"], | ||
| ) | ||
| cf_storage.initialize() | ||
| cf_storage.upload_folder( | ||
| local_folder="outputs", cloud_path=upload_data["data"]["folder_name"] | ||
| ) | ||
| # get commit hash | ||
| commit_hash = commit_message.oid | ||
| logger.info(f"Commit hash: {commit_hash}") | ||
| logger.info(f"Repo name: {repo_name}") | ||
| # submit | ||
| submit_task( | ||
| task_id, repo_name, model2base_model[model_id], gpu_type, commit_hash | ||
| task_id, | ||
| model2base_model[model_id], | ||
| gpu_type, | ||
| upload_data["data"]["bucket"], | ||
| upload_data["data"]["folder_name"], | ||
|
Comment on lines
+61
to
+78
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion: Enhance error handling for the upload process. While basic error handling exists, it could be more specific to help with debugging and recovery. upload_data = get_address(task_id)
+ if not upload_data or "data" not in upload_data:
+ raise ValueError("Invalid upload data received from API")
+
cf_storage = CloudStorage(
access_key=upload_data["data"]["access_key"],
secret_key=upload_data["data"]["secret_key"],
endpoint_url=upload_data["data"]["endpoint_url"],
session_token=upload_data["data"]["session_token"],
bucket=upload_data["data"]["bucket"],
)
cf_storage.initialize()
+ except ValueError as e:
+ logger.error(f"API Error: {e}")
+ continue
+ except ConnectionError as e:
+ logger.error(f"Network Error during upload: {e}")
+ continue
except Exception as e:
- logger.error(f"Error: {e}")
+ logger.exception("Unexpected error during upload process")
logger.info("Proceed to the next model...")
|
||
| ) | ||
| logger.info("Task submitted successfully") | ||
| except Exception as e: | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -5,3 +5,4 @@ loguru | |||||
| trl>=0.9.3,<=0.9.6 | ||||||
| bitsandbytes | ||||||
| pyyaml | ||||||
| boto3 | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add version constraints for boto3. Other critical dependencies have version constraints to ensure compatibility. Since boto3 is essential for the R2 storage functionality, it should also have version constraints to prevent potential breaking changes in future updates. Apply this diff to add version constraints: -boto3
+boto3>=1.26.0,<2.0.0📝 Committable suggestion
Suggested change
|
||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,83 @@ | ||
| import os | ||
| import boto3 | ||
| import botocore | ||
| import threading | ||
| from tqdm import tqdm | ||
| from loguru import logger | ||
|
|
||
|
|
||
class ProgressPercentage:
    """Callback for boto3 uploads that logs cumulative transfer progress.

    An instance is passed as ``Callback=`` to ``client.upload_file``; boto3
    invokes it (possibly from worker threads during multipart uploads) with
    the number of bytes just transferred, so the running total is guarded
    by a lock.
    """

    def __init__(self, filename):
        # Path of the file being uploaded; its on-disk size is the 100% mark.
        self.filename = filename
        self.size = float(os.path.getsize(filename))
        self._seen_so_far = 0
        # boto3 may invoke the callback from multiple threads.
        self._lock = threading.Lock()

    def __call__(self, bytes_amount):
        with self._lock:
            self._seen_so_far += bytes_amount
            # Guard zero-byte files: the original divided unconditionally and
            # would raise ZeroDivisionError for an empty file.
            percentage = (
                (self._seen_so_far / self.size) * 100 if self.size else 100.0
            )
            # NOTE: loguru's logger.info() accepts no print-style kwargs such
            # as end=""; the original call raised TypeError on every callback.
            # Use loguru's lazy brace-style formatting instead.
            logger.info(
                "{} {} / {} ({:.2f}%)",
                self.filename,
                self._seen_so_far,
                self.size,
                percentage,
            )
|
|
||
|
|
||
class CloudStorage:
    """Thin wrapper around an S3-compatible client for Cloudflare R2.

    Credentials are stored at construction time; the boto3 client is created
    lazily by :meth:`initialize` so an instance can be configured before the
    (temporary) credentials are actually used.
    """

    def __init__(
        self,
        access_key=None,
        secret_key=None,
        endpoint_url=None,
        bucket=None,
        session_token=None,
    ):
        # Connection/credential info for the R2 (S3-compatible) endpoint.
        self.access_key = access_key
        self.secret_key = secret_key
        self.endpoint_url = endpoint_url
        # Default bucket used by upload_folder() when none is passed.
        self.bucket = bucket
        # Created by initialize(); None until then.
        self.client = None
        # Optional temporary-credential token; boto3 accepts None.
        self.session_token = session_token

    def initialize(self):
        """Create the boto3 S3 client and return ``self`` for chaining.

        Raises:
            ValueError: if access_key, secret_key or endpoint_url is missing.
        """
        if (
            self.access_key is None
            or self.secret_key is None
            or self.endpoint_url is None
        ):
            # The original used a bare `raise` here: outside an `except` block
            # that produces "RuntimeError: No active exception to re-raise".
            # Raise a descriptive ValueError instead. The message also no
            # longer claims session_token is required — it is never checked.
            msg = "Please provide access_key, secret_key and endpoint_url"
            logger.error(msg)
            raise ValueError(msg)
        self.client = boto3.client(
            "s3",
            endpoint_url=self.endpoint_url,
            aws_access_key_id=self.access_key,
            aws_secret_access_key=self.secret_key,
            aws_session_token=self.session_token,
        )
        return self

    def upload_folder(self, local_folder, cloud_path, bucket=None):
        """Recursively upload ``local_folder`` into ``bucket`` under ``cloud_path``.

        Individual upload failures are logged and skipped (best effort).
        Silently returns after logging if no bucket is configured.
        """
        if bucket is None:
            bucket = self.bucket
        if bucket is None:
            logger.error("Please provide bucket name")
            return
        for root, _dirs, files in tqdm(os.walk(local_folder)):
            for name in files:
                local_path = os.path.join(root, name)
                rel_path = os.path.relpath(local_path, local_folder)
                # S3 object keys always use "/" — normalise Windows separators.
                key = os.path.join(cloud_path, rel_path).replace("\\", "/")
                try:
                    self.client.upload_file(
                        local_path,
                        bucket,
                        key,
                        Callback=ProgressPercentage(local_path),
                    )
                except botocore.exceptions.ClientError as e:
                    # Best effort: log and continue with the remaining files.
                    logger.error(e)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
💡 Codebase verification
Update pre-commit hooks to latest stable versions and consider additional security checks
The current configuration is using outdated versions. Here's what needs to be updated:
`pre-commit-hooks` from v4.6.0 to v5.0.0; `ruff` from v0.6.4 to v0.8.4. Additionally, consider adding these security-focused hooks from pre-commit-hooks:
`check-added-large-files`, `check-case-conflict`, `check-merge-conflict`, `detect-private-key`. Here's the suggested updated configuration:
🔗 Analysis chain
Consider additional hooks and verify versions.
The current configuration provides basic linting and formatting. Consider adding these useful hooks:
`check-added-large-files`: prevents large files from being committed; `check-case-conflict`: checks for files with names that would conflict on case-insensitive filesystems; `check-merge-conflict`: checks for files containing merge conflict strings; `detect-private-key`: detects private keys. Let's verify the latest stable versions of the hooks:
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
Length of output: 147