[Feature] Add producer-consumer pipeline for uploading folder (#1671) · modelscope/modelscope@25de84c

1+

# Copyright (c) Alibaba, Inc. and its affiliates.

2+3+

import hashlib

4+

import os

5+

import tempfile

6+

from pathlib import Path

7+

from typing import Dict, List, Optional, Set, Tuple, Union

8+9+

import json

10+11+

from modelscope.utils.logger import get_logger

12+13+

# Module-level logger shared by the checkpoint helpers in this file.
logger = get_logger()

# File name of the JSON progress checkpoint written inside the folder being
# uploaded; an existing file lets an interrupted upload_folder run resume.
UPLOAD_PROGRESS_FILE = '.ms_upload_progress'

16+17+18+

class UploadProgress:
    """Tracks committed batch indices for upload_folder resume.

    Stored as JSON at {folder_path}/.ms_upload_progress. On resume,
    already-committed batches are skipped. Validates repo_id to prevent
    cross-repo confusion.
    """

    def __init__(self, checkpoint_path: Union[str, Path], repo_id: str):
        """Initialize checkpoint.

        Args:
            checkpoint_path: Path to the checkpoint file.
            repo_id: Repository ID for validation on resume.
        """
        self._path = Path(checkpoint_path)
        self._repo_id = repo_id
        # Indices of batches whose commit has completed successfully.
        self._committed_batches: Set[int] = set()
        # batch_idx -> fingerprint of the file set committed for that batch.
        self._batch_fingerprints: Dict[int, str] = {}
        self._load()

    @staticmethod
    def compute_fingerprint(items: List[Tuple[str, str]]) -> str:
        """Compute a fingerprint from (file_path_in_repo, metadata) pairs.

        Used to detect when a batch's file set changes between runs,
        invalidating stale batch indices. The metadata element is
        typically 'mtime|size' but can be any string that changes
        when the file content changes. Called per-batch to produce
        individual batch fingerprints.
        """
        # Sort so the fingerprint is independent of input ordering.
        parts = [f'{path}|{fhash}' for path, fhash in sorted(items)]
        return hashlib.sha256('||'.join(parts).encode()).hexdigest()

    def validate_batch_fingerprint(self, batch_idx: int,
                                   fingerprint: str) -> bool:
        """Check if a committed batch's fingerprint still matches.

        Returns True if batch is committed and fingerprint matches (safe to skip).
        If committed but fingerprint mismatches, clears the batch's committed status.
        """
        if batch_idx not in self._committed_batches:
            return False
        stored_fp = self._batch_fingerprints.get(batch_idx)
        if stored_fp is None:
            # Legacy checkpoint or first run — trust committed status and
            # record the fingerprint so future runs can validate it.
            self._batch_fingerprints[batch_idx] = fingerprint
            self._save()
            return True
        if stored_fp == fingerprint:
            return True
        # Fingerprint mismatch — invalidate this batch only.
        self._committed_batches.discard(batch_idx)
        self._batch_fingerprints.pop(batch_idx, None)
        self._save()
        return False

    def is_batch_committed(self, batch_index: int) -> bool:
        """Check if a batch has already been committed."""
        return batch_index in self._committed_batches

    def mark_batch_committed(self, batch_idx: int, fingerprint: str):
        """Mark a batch as committed with its fingerprint and persist."""
        self._committed_batches.add(batch_idx)
        self._batch_fingerprints[batch_idx] = fingerprint
        self._save()

    def clear(self):
        """Remove checkpoint file and reset in-memory state."""
        self._committed_batches.clear()
        self._batch_fingerprints.clear()
        try:
            if self._path.exists():
                self._path.unlink()
                logger.info(f'Upload checkpoint cleared: {self._path}')
        except Exception as e:
            # Best-effort removal: a leftover file only triggers a
            # repo_id/fingerprint check on the next run.
            logger.warning(f'Failed to remove checkpoint file: {e}')

    def _load(self):
        """Load checkpoint from disk. Invalidates if repo_id mismatches."""
        if not self._path.exists():
            return
        try:
            # Checkpoint is always written as UTF-8; pin it on read so the
            # result does not depend on the platform's locale encoding.
            with open(self._path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            # Validate repo_id to prevent cross-repo confusion.
            if data.get('repo_id') != self._repo_id:
                logger.warning(
                    f'Checkpoint repo_id mismatch '
                    f'(cached: {data.get("repo_id")}, current: {self._repo_id}), '
                    f'ignoring stale checkpoint.')
                return
            # JSON object keys are strings; restore the int batch indices.
            self._batch_fingerprints = {
                int(k): v
                for k, v in data.get('batch_fingerprints', {}).items()
            }
            self._committed_batches = set(data.get('committed_batches', []))
            if self._committed_batches:
                logger.info(
                    f'Upload checkpoint loaded: {len(self._committed_batches)} '
                    f'batch(es) already committed.')
        except Exception as e:
            logger.warning(f'Failed to load checkpoint, starting fresh: {e}')
            # Reset BOTH structures: a failure after the fingerprint dict was
            # assigned must not leave stale fingerprints behind.
            self._committed_batches = set()
            self._batch_fingerprints = {}

    def _save(self):
        """Atomic persist via temp file + rename."""
        try:
            self._path.parent.mkdir(parents=True, exist_ok=True)
            data = {
                'repo_id': self._repo_id,
                'batch_fingerprints':
                {str(k): v
                 for k, v in self._batch_fingerprints.items()},
                'committed_batches': sorted(self._committed_batches),
            }
            # Write to a temp file in the destination directory, then
            # atomically replace, so a crash never leaves a truncated file.
            fd, tmp_path = tempfile.mkstemp(
                dir=str(self._path.parent), prefix='.ms_upload_ckpt_tmp_')
            try:
                with os.fdopen(fd, 'w', encoding='utf-8') as f:
                    json.dump(data, f)
                os.replace(tmp_path, str(self._path))
            except BaseException:
                os.unlink(tmp_path)
                raise
            logger.info(
                f'Checkpoint saved: batches {sorted(self._committed_batches)} -> {self._path}'
            )
        except Exception as e:
            logger.warning(f'Failed to save checkpoint to {self._path}: {e}')