SERAC training still doesn't work
Hi, the old issue is closed so I will submit a new one.
After using the new counterfact.py, the training process of SERAC still gets low results. I ran it on Cheng98/llama-160m and anton-l/gpt-j-tiny-random as stand-ins for the llama2 and gpt-j experiments. In the training log, both get a low acc_val (0.2778 for llama-160m, 0.0 for gpt-j).
Just to give it a try, I then used the trained models to perform model editing; both got low results, similar to the old issue.
Here is my training code:
#!/usr/bin/env python
# coding: utf-8
"""Train a SERAC editor on either the ZsRE or the CounterFact dataset."""
from argparse import ArgumentParser

from easyeditor import CounterFactDataset, EditTrainer, SERACTrainingHparams, ZsreDataset

if __name__ == '__main__':
    cli = ArgumentParser()
    cli.add_argument('--hparams', type=str, required=True)
    cli.add_argument('--train_dataset', type=str, required=True)
    cli.add_argument('--eval_dataset', type=str, required=True)
    opts = cli.parse_args()

    hparams = SERACTrainingHparams.from_hparams(opts.hparams)

    # The dataset class is picked from the training-file path: any path
    # containing "zsre" is loaded as ZsRE, everything else as CounterFact.
    dataset_cls = ZsreDataset if 'zsre' in opts.train_dataset else CounterFactDataset
    train_set = dataset_cls(opts.train_dataset, config=hparams)
    val_set = dataset_cls(opts.eval_dataset, config=hparams)

    EditTrainer(config=hparams, train_set=train_set, val_set=val_set).run()
The training data is counterfact-train.json, and here is my editing code:
from easyeditor import BaseEditor, KnowEditDataset
from easyeditor import SERACHparams
import argparse
import os
import json
def save_serac_model(SERAC, save_dir):
    """Persist a trained SERAC editor to *save_dir*.

    Saves the replacement (counterfactual) model and the scope classifier via
    their ``save_pretrained`` methods, and dumps the editor's cached edit
    inputs/labels as JSON.

    Args:
        SERAC: trained SERAC object exposing ``replacement``, ``classifier``,
            ``cache_inputs`` and ``cache_labels``.
        save_dir: destination directory; created if it does not exist.
    """
    # Create the target directory up front; json.dump below would otherwise
    # fail when save_dir does not exist yet.
    os.makedirs(save_dir, exist_ok=True)
    SERAC.replacement.save_pretrained(os.path.join(save_dir, "replace_model"))
    SERAC.classifier.save_pretrained(os.path.join(save_dir, "classifier_model"))
    # Use context managers so the file handles are closed deterministically
    # (the original passed bare open(...) into json.dump and leaked them).
    with open(os.path.join(save_dir, "cache_inputs.json"), 'w', encoding='utf8') as f:
        json.dump(SERAC.cache_inputs, f)
    with open(os.path.join(save_dir, "cache_labels.json"), 'w', encoding='utf8') as f:
        json.dump(SERAC.cache_labels, f)
def _expand_annotation_items(raw_items):
    """Flatten KnowEdit portability/locality entries into parallel prompt/answer lists.

    Each element of *raw_items* is either ``None`` (kept as a ``None``
    placeholder so the output stays aligned with the edit prompts) or a list of
    dicts carrying ``"prompt"`` and ``"ground_truth"`` keys.  Nested
    ground-truth lists are unwrapped to their first leaf, and entries whose
    answer is blank are skipped.

    Returns:
        (prompts, answers): two lists, each the same length as *raw_items*.
    """
    all_prompts, all_answers = [], []
    for item in raw_items:
        if item is None:
            all_prompts.append(None)
            all_answers.append(None)
            continue
        item_prompts, item_answers = [], []
        for record in item:
            answer = record["ground_truth"]
            # Ground truths are sometimes nested lists; keep the first leaf.
            while isinstance(answer, list):
                answer = answer[0]
            if answer.strip() == "":
                continue  # drop entries with no usable answer
            item_prompts.append(record["prompt"])
            item_answers.append(answer)
        all_prompts.append(item_prompts)
        all_answers.append(item_answers)
    return all_prompts, all_answers


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--editing_method', required=True, type=str)
    parser.add_argument('--hparams_dir', required=True, type=str)
    parser.add_argument('--data_dir', required=True, type=str)
    parser.add_argument('--ds_size', default=None, type=int)
    parser.add_argument('--metrics_save_dir', default='./output', type=str)
    parser.add_argument('--model_save_dir', default='./serac_model', type=str)
    parser.add_argument('--datatype', default=None, type=str)
    args = parser.parse_args()

    datas = KnowEditDataset(args.data_dir, size=args.ds_size)
    if args.datatype in ('counterfact', 'recent', 'zsre'):
        prompts = [data['prompt'] for data in datas]
        subjects = [data['subject'] for data in datas]
        target_new = [data['target_new'] for data in datas]
        ground_truth = [data['ground_truth'] for data in datas]
        rephrase_prompts = [data['rephrase_prompt'] for data in datas]  # new para

        # Portability fields: _r -> reasoning, _s -> subject aliasing,
        # _l -> logical generalization.
        portability_reasoning_prompts, portability_reasoning_ans = \
            _expand_annotation_items([data['portability_r'] for data in datas])
        portability_Subject_Aliasing_prompts, portability_Subject_Aliasing_ans = \
            _expand_annotation_items([data['portability_s'] for data in datas])
        portability_Logical_Generalization_prompts, portability_Logical_Generalization_ans = \
            _expand_annotation_items([data['portability_l'] for data in datas])
        assert len(prompts) == len(portability_reasoning_prompts) == len(
            portability_Logical_Generalization_prompts) == len(portability_Subject_Aliasing_prompts)

        # Locality fields: _rs -> relation specificity, _f -> forgetfulness.
        locality_Relation_Specificity_prompts, locality_Relation_Specificity_ans = \
            _expand_annotation_items([data['locality_rs'] for data in datas])
        locality_Forgetfulness_prompts, locality_Forgetfulness_ans = \
            _expand_annotation_items([data['locality_f'] for data in datas])
        assert len(prompts) == len(locality_Relation_Specificity_prompts) == len(locality_Forgetfulness_prompts)

        locality_inputs = {
            'Relation_Specificity': {
                'prompt': locality_Relation_Specificity_prompts,
                'ground_truth': locality_Relation_Specificity_ans
            },
            'Forgetfulness': {
                'prompt': locality_Forgetfulness_prompts,
                'ground_truth': locality_Forgetfulness_ans
            }
        }
        portability_inputs = {
            'Subject_Aliasing': {
                'prompt': portability_Subject_Aliasing_prompts,
                'ground_truth': portability_Subject_Aliasing_ans
            },
            'reasoning': {
                'prompt': portability_reasoning_prompts,
                'ground_truth': portability_reasoning_ans
            },
            'Logical_Generalization': {
                'prompt': portability_Logical_Generalization_prompts,
                'ground_truth': portability_Logical_Generalization_ans
            }
        }
    else:
        raise NotImplementedError

    hparams = SERACHparams.from_hparams(args.hparams_dir)
    editor = BaseEditor.from_hparams(hparams)
    metrics, edited_model, _ = editor.edit(
        prompts=prompts,
        target_new=target_new,
        ground_truth=ground_truth,
        rephrase_prompts=rephrase_prompts,  # new para
        subject=subjects,
        locality_inputs=locality_inputs,
        portability_inputs=portability_inputs,
        copy=True,
        return_orig_weights=True,
        keep_original_weight=True,
    )

    # exist_ok avoids the check-then-create race of the original
    # `if not os.path.exists(...)` guard.
    os.makedirs(args.metrics_save_dir, exist_ok=True)
    metrics_path = os.path.join(
        args.metrics_save_dir,
        f'{args.editing_method}_{args.datatype}_{hparams.model_name.split("/")[-1]}_results.json')
    # Context manager closes the handle (the original leaked it).
    with open(metrics_path, 'w') as f:
        json.dump(metrics, f, indent=4)
I have only successfully conducted the experiment on the zsre dataset with llama2, for which the replacement model is provided in this link. So maybe it is still a problem with the training?