GitHub - umzi2/MoSR
def detect(state) -> dict:
    """Infer model hyper-parameters from a checkpoint state dict.

    The values are derived purely from the shapes of a handful of weights,
    printed as a ``name: value`` report (one per line), and also returned so
    callers can use them programmatically.

    Args:
        state: Mapping from parameter names to objects exposing ``.shape``
            (presumably a PyTorch ``state_dict`` -- TODO confirm).

    Returns:
        A dict with the keys ``in_ch``, ``out_ch``, ``dim``, ``n_block``,
        ``upsampler``, ``upscale``, ``kernel_size``, ``expansion_ratio`` and
        ``conv_ratio``, in the same order as the printed report.

    Raises:
        KeyError: If a weight required for detection is missing from ``state``.
    """
    # Imported locally so the snippet works without relying on a file-level
    # import that is not visible here.
    import math

    # NOTE(review): the constant 6 presumably accounts for the non-block
    # entries of the ``gblocks`` sequence -- confirm against the model code.
    n_block = get_seq_len(state, "gblocks") - 6

    first_weight = state["gblocks.0.weight"]
    in_ch = first_weight.shape[1]
    dim = first_weight.shape[0]

    # Expansion ratio and convolution ratio of the first block.
    fc1_weight = state["gblocks.1.fc1.weight"]
    expansion_ratio = (fc1_weight.shape[0] / fc1_weight.shape[1]) / 2
    conv_weight = state["gblocks.1.conv.weight"]
    conv_ratio = conv_weight.shape[0] / dim
    kernel_size = conv_weight.shape[2]

    # The upsampler type is recognised by which of its keys are present; the
    # upscale factor is then recovered from a channel count.
    # NOTE(review): the divisor 8 in the "dys" and "gps" branches is
    # hard-coded; presumably it reflects a fixed group setting -- confirm.
    if "upsampler.init_pos" in state:
        upsampler = "dys"
        out_ch = state["upsampler.end_conv.weight"].shape[0]
        upscale = math.isqrt(state["upsampler.offset.weight"].shape[0] // 8)
    elif "upsampler.in_to_k.weight" in state:
        upsampler = "gps"
        # The output channel count is taken to equal the input channel count.
        out_ch = in_ch
        upscale = math.isqrt(
            state["upsampler.in_to_k.weight"].shape[0] // 8 // out_ch
        )
    else:
        upsampler = "ps"
        # The output channel count is taken to equal the input channel count.
        out_ch = in_ch
        upscale = math.isqrt(state["upsampler.0.weight"].shape[0] // out_ch)

    detected = {
        "in_ch": in_ch,
        "out_ch": out_ch,
        "dim": dim,
        "n_block": n_block,
        "upsampler": upsampler,
        "upscale": upscale,
        "kernel_size": kernel_size,
        "expansion_ratio": expansion_ratio,
        "conv_ratio": conv_ratio,
    }

    # Report: a leading blank line, then one "name: value" pair per line.
    print("\n" + "\n".join(f"{name}: {value}" for name, value in detected.items()))
    return detected


# State-dict keys listed as the architecture's signature.
# NOTE(review): presumably used to recognise checkpoints of this model by key
# presence -- confirm against the code that consumes this list.
signature = [
    "gblocks.0.weight",
    "gblocks.0.bias",
    "gblocks.1.norm.weight",
    "gblocks.1.norm.bias",
    "gblocks.1.fc1.weight",
    "gblocks.1.fc1.bias",
    "gblocks.1.conv.weight",
    "gblocks.1.conv.bias",
    "gblocks.1.fc2.weight",
    "gblocks.1.fc2.bias",
]