|
|
|
|
|
""" |
|
|
Sync Strawberry Picker Models to HuggingFace Repository |
|
|
|
|
|
This script automates the process of syncing trained models from strawberryPicker |
|
|
to the HuggingFace strawberryPicker repository. |
|
|
|
|
|
Usage: |
|
|
python sync_to_huggingface.py [--repo-path PATH] [--dry-run] |
|
|
""" |
|
|
|
|
|
import os |
|
|
import sys |
|
|
import json |
|
|
import shutil |
|
|
import argparse |
|
|
import subprocess |
|
|
from pathlib import Path |
|
|
from datetime import datetime |
|
|
import hashlib |
|
|
|
|
|
def calculate_file_hash(file_path):
    """Return the hex-encoded SHA-256 digest of the file at ``file_path``.

    The file is opened in binary mode and consumed in 4096-byte reads, so
    the whole file never has to be held in memory at once.
    """
    digest = hashlib.sha256()
    with open(file_path, "rb") as stream:
        while True:
            chunk = stream.read(4096)
            if not chunk:
                # An empty read signals end of file.
                break
            digest.update(chunk)
    return digest.hexdigest()
|
|
|
|
|
def find_updated_models(repo_path):
    """Collect the model weight files present under ``repo_path``.

    Two fixed locations are probed, in this order:
    ``detection/best.pt`` and
    ``classification/best_enhanced_classifier.pth``.

    Returns a list with one dict per file that exists, holding the keys
    ``'component'``, ``'path'`` (a ``Path``) and ``'hash'`` (the value
    returned by ``calculate_file_hash``).

    Note: nothing here compares against an earlier sync; a file is reported
    whenever it exists.
    """
    root = Path(repo_path)
    candidates = (
        ("detection", root / "detection" / "best.pt"),
        ("classification", root / "classification" / "best_enhanced_classifier.pth"),
    )

    found = []
    for component, weights in candidates:
        if not weights.exists():
            continue
        found.append({
            'component': component,
            'path': weights,
            'hash': calculate_file_hash(weights),
        })
    return found
|
|
|
|
|
def export_detection_to_onnx(model_path, output_dir):
    """Export the detection model to ONNX with the ``yolo`` command-line tool.

    Args:
        model_path: Path to the detection weights file (e.g. ``best.pt``).
        output_dir: Directory that should hold the exported ``.onnx`` file.

    Returns:
        True when the export command exits with status 0 and, if a move was
        needed, the exported file was placed in ``output_dir``; False
        otherwise. Failures are printed rather than raised.
    """
    model_path = Path(model_path)
    output_dir = Path(output_dir)

    # NOTE(review): no "dir=..." argument is passed. The Ultralytics export
    # documentation lists no such argument, and the CLI rejects unknown
    # keys, which would make every export fail. The exporter writes the
    # .onnx file next to the weights file instead - confirm against the
    # installed Ultralytics version.
    cmd = [
        "yolo", "export",
        f"model={model_path}",
        "format=onnx",
        "opset=12",
    ]

    print("Exporting detection model to ONNX...")
    try:
        result = subprocess.run(cmd, capture_output=True, text=True)
    except (OSError, ValueError, subprocess.SubprocessError) as e:
        # OSError covers a missing or non-executable "yolo" binary;
        # ValueError covers invalid arguments and output decoding errors.
        print(f"Error during ONNX export: {e}")
        return False

    if result.returncode != 0:
        print(f"Export failed: {result.stderr}")
        return False

    # Honour output_dir: relocate the exported file only when it did not
    # already land there.
    exported = model_path.with_suffix(".onnx")
    target = output_dir / exported.name
    if exported.exists() and exported.resolve() != target.resolve():
        try:
            output_dir.mkdir(parents=True, exist_ok=True)
            shutil.move(str(exported), str(target))
        except OSError as e:
            print(f"Error during ONNX export: {e}")
            return False

    print("Successfully exported detection model to ONNX")
    return True
|
|
|
|
|
def update_model_metadata(repo_path, models_info):
    """Write ``sync_metadata.json`` describing the models of this sync run.

    Args:
        repo_path: Repository directory; the metadata file is written
            directly inside it as ``sync_metadata.json``.
        models_info: Iterable of dicts with the keys ``'component'``,
            ``'hash'`` and ``'path'``. An empty iterable yields an empty
            ``"models"`` mapping.

    Returns:
        True once the file has been written. Errors from opening or writing
        the file are not caught here.
    """
    # Take one timestamp for the whole run so "last_sync" and every
    # "last_updated" entry agree. This is naive local time, as produced by
    # datetime.now() without a timezone.
    timestamp = datetime.now().isoformat()

    metadata = {
        "last_sync": timestamp,
        "models": {
            model['component']: {
                "hash": model['hash'],
                "path": str(model['path']),
                "last_updated": timestamp,
            }
            for model in models_info
        },
    }

    metadata_file = Path(repo_path) / "sync_metadata.json"
    # Explicit encoding keeps the output independent of the platform locale.
    with open(metadata_file, 'w', encoding="utf-8") as f:
        json.dump(metadata, f, indent=2)

    print(f"Updated metadata file: {metadata_file}")
    return True
|
|
|
|
|
def sync_repository(repo_path, dry_run=False):
    """Run one sync pass over the strawberryPicker repository.

    Steps: list the model files found by ``find_updated_models``, export
    the detection model to ONNX (when present), then write the metadata
    file via ``update_model_metadata``.

    Args:
        repo_path: Path to the repository to sync.
        dry_run: When True, only report what was found; no export is run
            and no metadata is written.

    Returns:
        True when nothing needed doing, on a dry run, or when every step
        succeeded. False when the ONNX export or the metadata update
        reported failure.
    """
    print(f"Syncing strawberryPicker repository at {repo_path}")

    if dry_run:
        print("DRY RUN MODE - No changes will be made")

    updated_models = find_updated_models(repo_path)

    if not updated_models:
        print("No models found to sync")
        return True

    print(f"Found {len(updated_models)} model components:")
    for model in updated_models:
        print(f" - {model['component']} (hash: {model['hash'][:8]}...)")

    if dry_run:
        return True

    # Export only applies to the detection component; with no detection
    # model there is nothing to export and that is not a failure.
    export_ok = True
    detection_model = next((m for m in updated_models if m['component'] == 'detection'), None)
    if detection_model:
        detection_dir = Path(repo_path) / "detection"
        export_ok = export_detection_to_onnx(detection_model['path'], detection_dir)

    # The recorded hashes describe the weight files, not the ONNX output,
    # so the metadata is still written after a failed export.
    metadata_ok = update_model_metadata(repo_path, updated_models)

    if not (export_ok and metadata_ok):
        # Report the failure so the caller can exit non-zero instead of
        # announcing success.
        print("\nSync finished with errors - see messages above")
        return False

    print("\nSync completed successfully!")
    print("Remember to:")
    print("1. Review and commit changes to git")
    print("2. Push to HuggingFace: git push origin main")
    print("3. Update READMEs with any new performance metrics")

    return True
|
|
|
|
|
def main():
    """Command-line entry point.

    Parses ``--repo-path`` and ``--dry-run``, exits with status 1 when the
    repository path does not exist, then runs ``sync_repository`` and exits
    with status 1 if it reports failure.
    """
    cli = argparse.ArgumentParser(description="Sync strawberryPicker models to HuggingFace")
    cli.add_argument(
        "--repo-path",
        default="/home/user/machine-learning/HuggingfaceModels/strawberryPicker",
        help="Path to strawberryPicker repository",
    )
    cli.add_argument(
        "--dry-run",
        action="store_true",
        help="Show what would be done without making changes",
    )
    options = cli.parse_args()

    # Validate the target before doing any work.
    repo_path = options.repo_path
    if not os.path.exists(repo_path):
        print(f"Error: Repository path {repo_path} does not exist")
        sys.exit(1)

    if not sync_repository(repo_path, options.dry_run):
        sys.exit(1)


if __name__ == "__main__":
    main()