# [Feature] Add producer-consumer pipeline for uploading folder (#1671) · modelscope/modelscope@25de84c
# Copyright (c) Alibaba, Inc. and its affiliates.

import hashlib
import os
import tempfile
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple, Union

import json

from modelscope.utils.logger import get_logger

logger = get_logger()

# Default file name of the per-folder checkpoint used to resume uploads.
UPLOAD_PROGRESS_FILE = '.ms_upload_progress'
class UploadProgress:
    """Tracks committed batch indices for upload_folder resume.

    Stored as JSON at {folder_path}/.ms_upload_progress. On resume,
    already-committed batches are skipped. Validates repo_id to prevent
    cross-repo confusion.
    """

    def __init__(self, checkpoint_path: Union[str, Path], repo_id: str):
        """Initialize checkpoint.

        Args:
            checkpoint_path: Path to the checkpoint file.
            repo_id: Repository ID for validation on resume.
        """
        self._path = Path(checkpoint_path)
        self._repo_id = repo_id
        # Indices of batches whose commit completed successfully.
        self._committed_batches: Set[int] = set()
        # batch index -> fingerprint of the batch's file set at commit time.
        self._batch_fingerprints: Dict[int, str] = {}
        self._load()

    @staticmethod
    def compute_fingerprint(items: List[Tuple[str, str]]) -> str:
        """Compute a fingerprint from (file_path_in_repo, metadata) pairs.

        Used to detect when a batch's file set changes between runs,
        invalidating stale batch indices. The metadata element is
        typically 'mtime|size' but can be any string that changes
        when the file content changes. Called per-batch to produce
        individual batch fingerprints.
        """
        # Sort so the fingerprint does not depend on input ordering.
        parts = [f'{path}|{fhash}' for path, fhash in sorted(items)]
        return hashlib.sha256('||'.join(parts).encode()).hexdigest()

    def validate_batch_fingerprint(self, batch_idx: int,
                                   fingerprint: str) -> bool:
        """Check if a committed batch's fingerprint still matches.

        Returns True if batch is committed and fingerprint matches (safe to skip).
        If committed but fingerprint mismatches, clears the batch's committed status.
        """
        if batch_idx not in self._committed_batches:
            return False
        stored_fp = self._batch_fingerprints.get(batch_idx)
        if stored_fp is None:
            # Legacy checkpoint or first run — trust committed status
            self._batch_fingerprints[batch_idx] = fingerprint
            self._save()
            return True
        if stored_fp == fingerprint:
            return True
        # Fingerprint mismatch — invalidate this batch only
        self._committed_batches.discard(batch_idx)
        self._batch_fingerprints.pop(batch_idx, None)
        self._save()
        return False

    def is_batch_committed(self, batch_index: int) -> bool:
        """Check if a batch has already been committed."""
        return batch_index in self._committed_batches

    def mark_batch_committed(self, batch_idx: int, fingerprint: str):
        """Mark a batch as committed with its fingerprint and persist."""
        self._committed_batches.add(batch_idx)
        self._batch_fingerprints[batch_idx] = fingerprint
        self._save()

    def clear(self):
        """Reset in-memory state and remove the checkpoint file."""
        self._committed_batches.clear()
        self._batch_fingerprints.clear()
        try:
            if self._path.exists():
                self._path.unlink()
                logger.info(f'Upload checkpoint cleared: {self._path}')
        except Exception as e:
            # Best-effort: a leftover checkpoint file is harmless — it will
            # be validated (and possibly discarded) on the next load.
            logger.warning(f'Failed to remove checkpoint file: {e}')

    def _load(self):
        """Load checkpoint from disk. Invalidates if repo_id mismatches."""
        if not self._path.exists():
            return
        try:
            with open(self._path, 'r') as f:
                data = json.load(f)
            # Validate repo_id to prevent cross-repo confusion
            if data.get('repo_id') != self._repo_id:
                logger.warning(
                    f'Checkpoint repo_id mismatch '
                    f'(cached: {data.get("repo_id")}, current: {self._repo_id}), '
                    f'ignoring stale checkpoint.')
                return
            # JSON object keys are strings; restore the int batch indices.
            self._batch_fingerprints = {
                int(k): v
                for k, v in data.get('batch_fingerprints', {}).items()
            }
            self._committed_batches = set(data.get('committed_batches', []))
            if self._committed_batches:
                logger.info(
                    f'Upload checkpoint loaded: {len(self._committed_batches)} '
                    f'batch(es) already committed.')
        except Exception as e:
            logger.warning(f'Failed to load checkpoint, starting fresh: {e}')
            # Reset BOTH structures: a failure after fingerprints were parsed
            # must not leave stale fingerprints behind an empty committed set.
            self._committed_batches = set()
            self._batch_fingerprints = {}

    def _save(self):
        """Atomic persist via temp file + rename."""
        try:
            self._path.parent.mkdir(parents=True, exist_ok=True)
            data = {
                'repo_id': self._repo_id,
                'batch_fingerprints':
                    {str(k): v
                     for k, v in self._batch_fingerprints.items()},
                'committed_batches': sorted(self._committed_batches),
            }
            # Write to a temp file in the destination directory, then
            # atomically replace, so a crash never leaves a truncated file.
            fd, tmp_path = tempfile.mkstemp(
                dir=str(self._path.parent), prefix='.ms_upload_ckpt_tmp_')
            try:
                with os.fdopen(fd, 'w') as f:
                    json.dump(data, f)
                os.replace(tmp_path, str(self._path))
            except BaseException:
                # Don't leave the orphaned temp file behind on failure.
                os.unlink(tmp_path)
                raise
            logger.info(
                f'Checkpoint saved: batches {sorted(self._committed_batches)} -> {self._path}'
            )
        except Exception as e:
            # Saving is best-effort; a failed save only costs resume ability.
            logger.warning(f'Failed to save checkpoint to {self._path}: {e}')