[Feat]: Add general key-value storage interface (KVStorageManager) & yuanrong storage backend by dpj135 · Pull Request #96 · TransferQueue/TransferQueue
@dpj135 Thank you for adding the import! The StorageClientFactory import is now properly included on line 31, which resolves the issue. The KVStorageManager class can now be instantiated without any NameError. ✅
[transfer_queue/storage/managers/factory.py:26-40]
def register(cls, manager_type: str, manager_cls: type[TransferQueueStorageManager]):
    """Register *manager_cls* in the factory registry under *manager_type*.

    Raises:
        TypeError: if manager_cls is not a TransferQueueStorageManager subclass.
    """
    if not issubclass(manager_cls, TransferQueueStorageManager):
        shown = getattr(manager_cls, "name", repr(manager_cls))
        raise TypeError(f"manager_cls {shown} must be a subclass of TransferQueueStorageManager")
    # Later registrations under the same name intentionally overwrite earlier ones.
    cls._registry[manager_type] = manager_cls
End snippet
[transfer_queue/storage/managers/factory.py:35-40]
def create(cls, manager_type: str, config: dict[str, Any]) -> TransferQueueStorageManager:
    """Instantiate the storage manager registered under *manager_type*.

    Args:
        manager_type (str): The registered name of the manager implementation.
        config (dict[str, Any]): Configuration passed to the manager constructor.

    Returns:
        TransferQueueStorageManager: A new instance of the requested manager.

    Raises:
        ValueError: If no manager is registered under *manager_type*.
    """
    if manager_type not in cls._registry:
        raise ValueError(
            f"Unknown manager_type: {manager_type}. Supported managers include: {list(cls._registry.keys())}"
        )
    # BUG FIX: the original `cls._registrymanager_type` was missing the registry
    # subscription and the constructor call (config was never used); look up the
    # registered class and build it with the supplied config.
    return cls._registry[manager_type](config)
End snippet
[transfer_queue/storage/clients/factory.py:35-40]
def register(cls, client_type: str):
    """
    Decorator factory that registers a concrete client class under *client_type*.

    Args:
        client_type (str): The name used to identify the client

    Returns:
        Callable: A decorator that records the class in the registry and
        returns it unchanged, so normal class definition is unaffected.
    """

    def decorator(client_class: TransferQueueStorageClient) -> TransferQueueStorageClient:
        # Record the mapping on the factory's shared registry.
        cls._registry[client_type] = client_class
        return client_class

    return decorator
End snippet
[transfer_queue/storage/clients/factory.py:59-71]
def create(cls, client_type: str, config: dict) -> TransferQueueStorageClient:
    """
    Create and return an instance of the storage client by name.

    Args:
        client_type (str): The registered name of the client
        config (dict): Configuration forwarded to the client constructor

    Returns:
        TransferQueueStorageClient: An instance of the requested client
            (docstring previously claimed ``StorageClientFactory`` — fixed)

    Raises:
        ValueError: If no client is registered with the given name
    """
    if client_type not in cls._registry:
        raise ValueError(f"Unknown StorageClient: {client_type}")
    # BUG FIX: the original `cls._registryclient_type` was missing the registry
    # subscription and the constructor call; build the client from the registry.
    return cls._registry[client_type](config)
End snippet
[transfer_queue/utils/zmq_utils.py:31-65]
class ZMQRequestType(ExplicitEnum):
    """Request/response type tags carried in ZMQ messages between TransferQueue roles."""

    # HANDSHAKE
    HANDSHAKE = "HANDSHAKE"  # TransferQueueStorageUnit -> TransferQueueController
    HANDSHAKE_ACK = "HANDSHAKE_ACK"  # TransferQueueController -> TransferQueueStorageUnit

    # DATA_OPERATION
    # NOTE(review): GET_DATA/PUT_DATA carry short wire values ("GET"/"PUT")
    # unlike their *_RESPONSE counterparts — confirm this asymmetry is intended.
    GET_DATA = "GET"
    PUT_DATA = "PUT"
    GET_DATA_RESPONSE = "GET_DATA_RESPONSE"
    PUT_DATA_RESPONSE = "PUT_DATA_RESPONSE"
    CLEAR_DATA = "CLEAR_DATA"
    CLEAR_DATA_RESPONSE = "CLEAR_DATA_RESPONSE"
    # Error variants for the data operations above.
    # NOTE(review): PUT_GET_OPERATION_ERROR and PUT_GET_ERROR look redundant;
    # verify both are still emitted somewhere before consolidating.
    PUT_GET_OPERATION_ERROR = "PUT_GET_OPERATION_ERROR"
    PUT_GET_ERROR = "PUT_GET_ERROR"
    PUT_ERROR = "PUT_ERROR"
    GET_ERROR = "GET_ERROR"
    CLEAR_DATA_ERROR = "CLEAR_DATA_ERROR"

    # META_OPERATION
    GET_META = "GET_META"
    GET_META_RESPONSE = "GET_META_RESPONSE"
    GET_CLEAR_META = "GET_CLEAR_META"
    GET_CLEAR_META_RESPONSE = "GET_CLEAR_META_RESPONSE"
    CLEAR_META = "CLEAR_META"
    CLEAR_META_RESPONSE = "CLEAR_META_RESPONSE"

    # CHECK_CONSUMPTION
    CHECK_CONSUMPTION = "CHECK_CONSUMPTION"
    CONSUMPTION_RESPONSE = "CONSUMPTION_RESPONSE"

    # NOTIFY_DATA_UPDATE
    NOTIFY_DATA_UPDATE = "NOTIFY_DATA_UPDATE"
    NOTIFY_DATA_UPDATE_ACK = "NOTIFY_DATA_UPDATE_ACK"
    NOTIFY_DATA_UPDATE_ERROR = "NOTIFY_DATA_UPDATE_ERROR"
End snippet
[transfer_queue/utils/zmq_utils.py:68-87]
class ZMQServerInfo:
    """Connection info (role, id, ip, named ports) for a ZMQ server endpoint."""

    # BUG FIX: the constructor was named plain `init` (dunder underscores were
    # presumably lost in transit); Python never calls `init` automatically, so
    # instances were created with no attributes set.
    def __init__(self, role: TransferQueueRole, id: str, ip: str, ports: dict[str, str]):
        self.role = role
        self.id = id  # NOTE: `id` shadows the builtin; kept for interface compatibility
        self.ip = ip
        self.ports = ports  # maps a logical port name -> port number as a string

    def to_addr(self, port_name: str) -> str:
        """Build a tcp:// address string for the named port."""
        return f"tcp://{self.ip}:{self.ports[port_name]}"

    def to_dict(self):
        """Return the server info as a plain dict (e.g. for serialization)."""
        return {
            "role": self.role,
            "id": self.id,
            "ip": self.ip,
            "ports": self.ports,
        }

    def __str__(self) -> str:
        # BUG FIX: previously printed "ZMQSocketInfo", mismatching the class name.
        return f"ZMQServerInfo(role={self.role}, id={self.id}, ip={self.ip}, ports={self.ports})"
End snippet
[transfer_queue/utils/zmq_utils.py:100-114]
def create(
    cls,
    request_type: ZMQRequestType,
    sender_id: str,
    body: dict[str, Any],
    receiver_id: Optional[str] = None,
) -> "ZMQMessage":
    """Build a ZMQMessage stamped with a fresh 8-char request id and the current time."""
    # uuid4().hex is already a str; the surrounding str() is kept for parity
    # with the original behavior (it is a no-op on str input).
    short_request_id = str(uuid4().hex[:8])
    message = cls(
        request_type=request_type,
        sender_id=sender_id,
        receiver_id=receiver_id,
        body=body,
        request_id=short_request_id,
        timestamp=time.time(),
    )
    return message
End snippet
[transfer_queue/utils/zmq_utils.py:116-118]
def serialize(self) -> bytes:
    """Serialize this ZMQMessage into a pickle byte string for the wire."""
    payload = pickle.dumps(self)
    return payload
End snippet
[transfer_queue/utils/zmq_utils.py:121-131]
def deserialize(cls, data: bytes | list[bytes]):
"""Using pickle to deserialize ZMQMessage objects"""
if isinstance(data, list):
# Process multiple byte streams by deserializing each in sequence
result = []
for d in data:
result.append(pickle.loads(d))
return result
else:
# Single byte stream case
return pickle.loads(data)
End snippet
[transfer_queue/metadata.py:136-151]
class BatchMeta:
    """Records the metadata of a batch of data samples."""

    # NOTE(review): dataclasses.field is used below, so the class is presumably
    # decorated with @dataclass in the full file — decorator not visible here.
    samples: list[SampleMeta]
    extra_info: dict[str, Any] = dataclasses.field(default_factory=dict)

    def __post_init__(self):
        """Initialize all computed properties during initialization"""
        # Basic properties. object.__setattr__ bypasses the dataclass setattr,
        # which suggests the dataclass is frozen — TODO confirm in the full file.
        object.__setattr__(self, "_size", len(self.samples))
        object.__setattr__(self, "_is_ready", all(sample.is_ready for sample in self.samples))

        # Pre-compute all list properties for better performance
        if self.samples:
            for idx, sample in enumerate(self.samples):
                object.__setattr__(sample, "_batch_index", idx)  # Ensure batch_index is set correctly
            object.__setattr__(self, "_global_indexes", [sample.global_index for sample in self.samples])
            # assume all samples have the same fields.
            object.__setattr__(self, "_field_names", sorted(self.samples[0].field_names))
        else:
            # Empty batch: cached list-valued properties default to empty lists.
            object.__setattr__(self, "_global_indexes", [])
            object.__setattr__(self, "_field_names", [])
End snippet
[transfer_queue/metadata.py:167-169]
def global_indexes(self) -> list[int]:
    """Get all global indexes in this batch (empty list if not yet computed)."""
    # _global_indexes is cached during __post_init__; fall back to [] when absent.
    if hasattr(self, "_global_indexes"):
        return self._global_indexes
    return []
End snippet
[transfer_queue/metadata.py:74-76]
def get_field_by_name(self, name: str) -> Optional[FieldMeta]:
    """Get FieldMeta by field name, or None when no such field is registered."""
    try:
        return self.fields[name]
    except KeyError:
        return None
End snippet
[transfer_queue/storage/clients/yuanrong_client.py:41-63]
def put(self, keys: list[str], values: list[Tensor]):
    """
    Store tensors to remote storage.

    Args:
        keys (list): List of string keys
        values (list): List of torch.Tensor on NPU

    Raises:
        ValueError: If keys/values are not lists, their lengths differ, a
            value is not a torch.Tensor, or a tensor is not on an NPU device.
        NotImplementedError: If more than 10000 keys are supplied.
    """
    if not isinstance(keys, list) or not isinstance(values, list):
        raise ValueError("keys and values must be lists")
    if len(keys) != len(values):
        raise ValueError("Number of keys must match number of values")
    # TODO: Support the situation when the number of keys is greater than 10000
    if len(keys) > 10000:
        # BUG FIX: corrected "int the future" typo in the error message.
        raise NotImplementedError('We will support the number of keys greater than 10000 in the future')
    for value in values:
        if not isinstance(value, torch.Tensor):
            raise ValueError(f"Expected torch.Tensor, got {type(value)}")
        # dev_mset transfers device memory, so every tensor must already be on NPU.
        if value.device.type != 'npu':
            raise ValueError(f"Tensor is on {value.device}, not on NPU")
    self._ds_client.dev_mset(keys, values)
End snippet
[transfer_queue/storage/clients/yuanrong_client.py:65-90]
def get(self, keys: list[str], shapes=None, dtypes=None) -> list[Tensor]:
    """
    Retrieve tensors from remote storage.

    Args:
        keys (list): List of keys to fetch
        shapes (list): Expected shapes of returned tensors
        dtypes (list): Expected dtypes of returned tensors

    Returns:
        list: List of retrieved NPU tensors

    Raises:
        ValueError: If shapes/dtypes are missing or the lengths of keys,
            shapes and dtypes disagree.
        NotImplementedError: If more than 10000 keys are supplied.
    """
    if shapes is None:
        raise ValueError('Yuanrong storage client needs Expected shapes of returned tensors')
    if dtypes is None:
        raise ValueError('Yuanrong storage client needs Expected dtypes of returned tensors')
    if len(dtypes) != len(shapes):
        raise ValueError('Length of dtypes must equal length of shapes')
    # BUG FIX: also require one shape per key; a mismatch would otherwise hand
    # dev_mget a wrongly sized output-buffer list.
    if len(keys) != len(shapes):
        raise ValueError('Length of keys must equal length of shapes')
    # TODO: Support the situation when the number of keys is greater than 10000
    if len(keys) > 10000:
        # BUG FIX: corrected "int the future" typo in the error message.
        raise NotImplementedError('We will support the number of keys greater than 10000 in the future')
    # Pre-allocate output tensors (after validation) that dev_mget fills in place.
    values: list[Tensor] = self._create_empty_tensorlist(shapes=shapes, dtypes=dtypes)
    # Timeout set to 2000ms
    self._ds_client.dev_mget(keys, values, 2000)
    return values
End snippet
[transfer_queue/storage/clients/yuanrong_client.py:92-98]
def clear(self, keys: list[str]):
    """
    Delete entries from storage by keys.

    Args:
        keys (list): List of keys to delete

    Note:
        Delegates directly to the datasystem client with no local validation
        (unlike put/get in this class); error handling is the backend's.
    """
    # Single bulk delete call covering all given keys.
    self._ds_client.dev_delete(keys)
End snippet
[transfer_queue/storage/metadata.py:136-151]
class BatchMeta:
    """Records the metadata of a batch of data samples."""

    # NOTE(review): dataclasses.field is used below, so the class is presumably
    # decorated with @dataclass in the full file — decorator not visible here.
    samples: list[SampleMeta]
    extra_info: dict[str, Any] = dataclasses.field(default_factory=dict)

    def __post_init__(self):
        """Initialize all computed properties during initialization"""
        # Basic properties. object.__setattr__ bypasses the dataclass setattr,
        # which suggests the dataclass is frozen — TODO confirm in the full file.
        object.__setattr__(self, "_size", len(self.samples))
        object.__setattr__(self, "_is_ready", all(sample.is_ready for sample in self.samples))

        # Pre-compute all list properties for better performance
        if self.samples:
            for idx, sample in enumerate(self.samples):
                object.__setattr__(sample, "_batch_index", idx)  # Ensure batch_index is set correctly
            object.__setattr__(self, "_global_indexes", [sample.global_index for sample in self.samples])
            # assume all samples have the same fields.
            object.__setattr__(self, "_field_names", sorted(self.samples[0].field_names))
        else:
            # Empty batch: cached list-valued properties default to empty lists.
            object.__setattr__(self, "_global_indexes", [])
            object.__setattr__(self, "_field_names", [])
End snippet
[transfer_queue/storage/managers/base.py:— end —]
Note: The above snippets are provided as-is from the referenced files to help understand interactions with the file transfer_queue/storage/managers/base.py and its related components.