Skip to content

VisualBERT VQA model inference gives low validation accuracy (around 40%) with the Hugging Face framework #45

@guanhdrmq

Description

@guanhdrmq
class VQADataset(torch.utils.data.Dataset):
    """VQA (v2) dataset yielding one tokenized question plus visual features.

    Item ``idx`` pairs ``questions[idx]`` with ``annotations[idx]``. The image
    path is looked up in the module-level ``id_to_filename`` mapping and the
    target vector is sized from the module-level ``config.id2label``; both
    must be defined before items are fetched.

    Args:
        questions: Sequence of dicts; only the ``"question"`` key is read.
        annotations: Sequence of dicts providing ``"image_id"``, ``"labels"``
            and ``"scores"``.
        tokenizer: Callable invoked with the question text and the padding /
            truncation keyword arguments used below.
        image_preprocess: Callable mapping an image path to
            ``(images, sizes, scales_yx)``.
        frcnn: Detector called on the preprocessed image; its result must
            offer ``.get("roi_features")``.
        frcnn_cfg: Detector config; ``max_detections`` is read from it.
    """

    def __init__(self, questions, annotations, tokenizer, image_preprocess, frcnn, frcnn_cfg):
        self.questions = questions
        self.annotations = annotations
        self.tokenizer = tokenizer
        self.image_preprocess = image_preprocess
        self.frcnn = frcnn
        self.frcnn_cfg = frcnn_cfg

    def __len__(self) -> int:
        """Return the number of annotations (one item per annotation)."""
        return len(self.annotations)

    def __getitem__(self, idx: int):
        """Build the model inputs for sample ``idx``.

        Returns:
            The tokenizer output extended with ``"visual_embeds"``,
            ``"visual_attention_mask"``, ``"labels"`` (a float vector with the
            annotation scores written at the annotation label indices) and
            ``"text"`` (the raw question string). Tensors have their leading
            batch dimension removed.
        """
        annotation = self.annotations[idx]
        question = self.questions[idx]
        text = question["question"]

        image_path = id_to_filename[annotation["image_id"]]
        image_path = image_path.replace("./multimodal_data/vqa2/val2014/.", "", 1)

        inputs = self.tokenizer(
            text,
            padding="max_length",
            max_length=25,
            truncation=True,
            return_token_type_ids=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt",
        )

        images, sizes, scales_yx = self.image_preprocess(image_path)
        # Feature extraction is inference only: without no_grad every sample
        # would drag an autograd graph through the DataLoader.
        with torch.no_grad():
            output_dict = self.frcnn(
                images,
                sizes,
                scales_yx=scales_yx,
                padding="max_detections",
                max_detections=self.frcnn_cfg.max_detections,
                return_tensors="pt",
            )
        feature = output_dict.get("roi_features")

        # Only data fields go into the item. Control flags such as
        # output_attentions belong on the model call: as a dataset field the
        # flag is collated into a tensor, which is ambiguous for batch_size > 1.
        inputs.update(
            {
                "visual_embeds": feature,
                "visual_attention_mask": torch.ones(feature.shape[:-1], dtype=torch.float),
            }
        )

        # Remove only the leading batch dimension; a bare squeeze() would also
        # drop any other size-1 axis (e.g. a single detection).
        for key, value in inputs.items():
            if isinstance(value, torch.Tensor):
                inputs[key] = value.squeeze(0)

        # Soft target vector: each annotated label index receives its score.
        targets = torch.zeros(len(config.id2label), dtype=torch.float)
        for label, score in zip(annotation["labels"], annotation["scores"]):
            targets[label] = score
        inputs["labels"] = targets
        # Raw question string, kept for inspection; it is not a tensor, so
        # consumers must not move it to a device or pass it to the model.
        inputs["text"] = text

        return inputs

from visualbert.processing_image import Preprocess
from visualbert.visualizing_image import SingleImageViz
from visualbert.modeling_frcnn import GeneralizedRCNN
from visualbert.utils import Config

# Load the GeneralizedRCNN config and weights from one checkpoint, then build
# the image preprocessor from that same config.
FRCNN_CHECKPOINT = "unc-nlp/frcnn-vg-finetuned"
frcnn_cfg = Config.from_pretrained(FRCNN_CHECKPOINT)
frcnn = GeneralizedRCNN.from_pretrained(FRCNN_CHECKPOINT, config=frcnn_cfg)
image_preprocess = Preprocess(frcnn_cfg)

from transformers import VisualBertForQuestionAnswering, AutoTokenizer, BertTokenizerFast
# Tokenizer for the question text.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# The answer head is sized and labelled from the project-level `config`.
# NOTE(review): id2label / label2id passed here replace whatever mapping the
# checkpoint ships with; if their order differs from the answer vocabulary the
# checkpoint was trained on, predictions decode to the wrong answers — confirm.
answer_vocab = config.id2label
model = VisualBertForQuestionAnswering.from_pretrained(
    "uclanlp/visualbert-vqa",
    num_labels=len(answer_vocab),
    id2label=answer_vocab,
    label2id=config.label2id,
    output_hidden_states=True,
)

# Inference only: move to the target device and switch to eval mode.
model.to(device)
model.eval()

# Evaluate on the first 100 question/annotation pairs.
dataset = VQADataset(
    questions=questions[:100],
    annotations=annotations[:100],
    tokenizer=tokenizer,
    image_preprocess=image_preprocess,
    frcnn=frcnn,
    frcnn_cfg=frcnn_cfg,
)

# Run the model over the dataset and report accuracy.
#
# Two numbers are tracked:
#   * soft accuracy  - the target score stored at the predicted answer index,
#     so any acceptable answer earns its score (assumes annotation["scores"]
#     hold VQA-style soft accuracies - TODO confirm against the preprocessing);
#   * strict match   - prediction equals the single highest-scored target
#     answer, which is what the previous version of this loop reported.
test_dataloader = DataLoader(dataset, batch_size=1, shuffle=False)
soft_score_sum = 0.0
strict_matches = 0
total = 0

with torch.no_grad():  # inference only; no autograd bookkeeping needed
    for batch in tqdm(test_dataloader):
        # The collated batch also carries non-tensor entries (the raw question
        # strings under "text"); those cannot be moved to a device or passed
        # to the model. "output_attentions" is a control flag, not data, so it
        # is dropped here and passed explicitly below.
        model_inputs = {
            k: v.to(device)
            for k, v in batch.items()
            if isinstance(v, torch.Tensor) and k != "output_attentions"
        }
        outputs = model(**model_inputs, output_attentions=False)
        logits = outputs.logits  # [batch_size, num_labels]
        labels = model_inputs["labels"]

        pre = logits.argmax(dim=1)
        target = labels.argmax(dim=1)
        print("prediction:", pre)
        print("target:", target)
        # Iterate per sample so the loop also works for batch_size > 1.
        for pred_id, target_id in zip(pre.tolist(), target.tolist()):
            print("Predicted answer:", model.config.id2label[pred_id])
            print("Target answer:", model.config.id2label[target_id])

        # Score credited to the predicted answer for each sample.
        soft_score_sum += labels.gather(1, pre.unsqueeze(1)).sum().item()
        strict_matches += (pre == target).sum().item()
        total += labels.size(0)
        print(total)

# Divide by the number of samples actually evaluated; guard the empty case.
final_acc = soft_score_sum / total if total else 0.0
strict_acc = strict_matches / total if total else 0.0
print(f"Accuracy of test: {100 * final_acc:f} %")
print(f"Strict top-1 match: {100 * strict_acc:f} %")

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions