SERAC training still doesn't work

Hi, since the old issue has been closed, I'm submitting a new one.
After using the new counterfact.py, SERAC training still gives low results. I ran it on Cheng98/llama-160m and anton-l/gpt-j-tiny-random as stand-ins for the llama2 and gpt-j experiments. In the training log, both get a low acc_val (0.2778 for llama-160m and 0.0 for gpt-j).
Just to give it a try, I also used the trained models to perform model editing; both gave low results, similar to the old issue.
Here is my training code:

#!/usr/bin/env python
# coding: utf-8
from easyeditor import EditTrainer, SERACTrainingHparams, ZsreDataset, CounterFactDataset
from argparse import ArgumentParser

if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('--hparams', type=str, required=True)
    parser.add_argument('--train_dataset', type=str, required=True)
    parser.add_argument('--eval_dataset', type=str, required=True)
    args = parser.parse_args()

    training_hparams = SERACTrainingHparams.from_hparams(args.hparams)
    if 'zsre' in args.train_dataset:
        train_ds = ZsreDataset(args.train_dataset, config=training_hparams)
        eval_ds = ZsreDataset(args.eval_dataset, config=training_hparams)
    else:
        train_ds = CounterFactDataset(args.train_dataset, config=training_hparams)
        eval_ds = CounterFactDataset(args.eval_dataset, config=training_hparams)
    trainer = EditTrainer(
        config=training_hparams,
        train_set=train_ds,
        val_set=eval_ds
    )
    trainer.run()
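
For reference, I launch training roughly like this (the script file name and the paths are just placeholders for my local setup; the three arguments are the ones defined above):

python train_serac.py \
    --hparams ./hparams/TRAINING/SERAC/llama-160m.yaml \
    --train_dataset ./data/counterfact/counterfact-train.json \
    --eval_dataset ./data/counterfact/counterfact-val.json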

The training data is counterfact-train.json, and here is my editing code:

from easyeditor import BaseEditor, KnowEditDataset
from easyeditor import SERACHparams
import argparse
import os
import json


def save_serac_model(SERAC, save_dir):
    SERAC.replacement.save_pretrained(os.path.join(save_dir, "replace_model"))
    SERAC.classifier.save_pretrained(os.path.join(save_dir, "classifier_model"))
    json.dump(SERAC.cache_inputs, open(os.path.join(save_dir, "cache_inputs.json"), 'w', encoding='utf8'))
    json.dump(SERAC.cache_labels, open(os.path.join(save_dir, "cache_labels.json"), 'w', encoding='utf8'))



if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--editing_method', required=True, type=str)
    parser.add_argument('--hparams_dir', required=True, type=str)
    parser.add_argument('--data_dir', required=True, type=str)
    parser.add_argument('--ds_size', default=None, type=int)
    parser.add_argument('--metrics_save_dir', default='./output', type=str)
    parser.add_argument('--model_save_dir', default='./serac_model', type=str)
    parser.add_argument('--datatype', default=None, type=str)

    args = parser.parse_args()

    datas = KnowEditDataset(args.data_dir, size=args.ds_size)
    if args.datatype == 'counterfact' or args.datatype == 'recent' or args.datatype == 'zsre':
        prompts = [data['prompt'] for data in datas]
        subjects = [data['subject'] for data in datas]
        target_new = [data['target_new'] for data in datas]

        portability_r = [data['portability_r'] for data in datas]
        portability_s = [data['portability_s'] for data in datas]
        portability_l = [data['portability_l'] for data in datas]

        ground_truth = [data['ground_truth'] for data in datas]
        rephrase_prompts = [data['rephrase_prompt'] for data in datas]  # newly added parameter

        portability_reasoning_prompts = []
        portability_reasoning_ans = []
        portability_Logical_Generalization_prompts = []
        portability_Logical_Generalization_ans = []
        portability_Subject_Aliasing_prompts = []
        portability_Subject_Aliasing_ans = []

        portability_data = [portability_r, portability_s, portability_l]
        portability_prompts = [portability_reasoning_prompts, portability_Subject_Aliasing_prompts,
                               portability_Logical_Generalization_prompts]
        portability_answers = [portability_reasoning_ans, portability_Subject_Aliasing_ans,
                               portability_Logical_Generalization_ans]
        for data, portable_prompts, portable_answers in zip(portability_data, portability_prompts, portability_answers):
            for item in data:
                if item is None:
                    portable_prompts.append(None)
                    portable_answers.append(None)
                else:
                    temp_prompts = []
                    temp_answers = []
                    for pr in item:
                        prompt = pr["prompt"]
                        an = pr["ground_truth"]
                        while isinstance(an, list):
                            an = an[0]
                        if an.strip() == "":
                            continue
                        temp_prompts.append(prompt)
                        temp_answers.append(an)
                    portable_prompts.append(temp_prompts)
                    portable_answers.append(temp_answers)
        assert len(prompts) == len(portability_reasoning_prompts) == len(
            portability_Logical_Generalization_prompts) == len(portability_Subject_Aliasing_prompts)

        locality_rs = [data['locality_rs'] for data in datas]
        locality_f = [data['locality_f'] for data in datas]
        locality_Relation_Specificity_prompts = []
        locality_Relation_Specificity_ans = []
        locality_Forgetfulness_prompts = []
        locality_Forgetfulness_ans = []

        locality_data = [locality_rs, locality_f]
        locality_prompts = [locality_Relation_Specificity_prompts, locality_Forgetfulness_prompts]
        locality_answers = [locality_Relation_Specificity_ans, locality_Forgetfulness_ans]
        for data, local_prompts, local_answers in zip(locality_data, locality_prompts, locality_answers):
            for item in data:
                if item is None:
                    local_prompts.append(None)
                    local_answers.append(None)
                else:
                    temp_prompts = []
                    temp_answers = []
                    for pr in item:
                        prompt = pr["prompt"]
                        an = pr["ground_truth"]
                        while isinstance(an, list):
                            an = an[0]
                        if an.strip() == "":
                            continue
                        temp_prompts.append(prompt)
                        temp_answers.append(an)
                    local_prompts.append(temp_prompts)
                    local_answers.append(temp_answers)
        assert len(prompts) == len(locality_Relation_Specificity_prompts) == len(locality_Forgetfulness_prompts)

        locality_inputs = {
            'Relation_Specificity': {
                'prompt': locality_Relation_Specificity_prompts,
                'ground_truth': locality_Relation_Specificity_ans
            },
            'Forgetfulness': {
                'prompt': locality_Forgetfulness_prompts,
                'ground_truth': locality_Forgetfulness_ans
            }
        }
        portability_inputs = {
            'Subject_Aliasing': {
                'prompt': portability_Subject_Aliasing_prompts,
                'ground_truth': portability_Subject_Aliasing_ans
            },
            'reasoning': {
                'prompt': portability_reasoning_prompts,
                'ground_truth': portability_reasoning_ans
            },
            'Logical_Generalization': {
                'prompt': portability_Logical_Generalization_prompts,
                'ground_truth': portability_Logical_Generalization_ans
            }
        }
    else:
        raise NotImplementedError

    hparams = SERACHparams.from_hparams(args.hparams_dir)
    editor = BaseEditor.from_hparams(hparams)

    metrics, edited_model, _= editor.edit(
        prompts=prompts,
        target_new=target_new,
        ground_truth=ground_truth,
        rephrase_prompts=rephrase_prompts,  # newly added parameter
        subject=subjects,
        locality_inputs=locality_inputs,
        portability_inputs=portability_inputs,
        copy=True,
        return_orig_weights=True,
        keep_original_weight=True,
    )
    if not os.path.exists(args.metrics_save_dir):
        os.makedirs(args.metrics_save_dir)
    json.dump(metrics, open(os.path.join(args.metrics_save_dir,
                                         f'{args.editing_method}_{args.datatype}_{hparams.model_name.split("/")[-1]}_results.json'),
                            'w'), indent=4)
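
To double-check the numbers, I average the post-edit rewrite accuracy from the saved metrics file roughly like this (the 'post' / 'rewrite_acc' keys are my assumption based on the printed summary; the file name is whatever the script above wrote):

import json
import numpy as np

# path assumed: the results file written by the editing script above
results = json.load(open('./output/SERAC_counterfact_results.json'))
rewrite_acc = [np.mean(r['post']['rewrite_acc']) for r in results]
print('mean post-edit rewrite_acc:', float(np.mean(rewrite_acc)))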

I have only successfully run this experiment on the zsre dataset with llama2, whose replacement model is provided in this link. So maybe the problem is still in the training?