From 74cd3eb8e718b95d218fa69067d5600f02a0ceb7 Mon Sep 17 00:00:00 2001 From: Yuri Jean Fabris Date: Wed, 15 Mar 2023 16:48:26 -0300 Subject: [PATCH 1/3] feat: add claim bash argument and query abstract candidates on elastic search index --- .gitignore | 1 + multivers/data.py | 45 +++++++++++++++++++++++++++++++----------- multivers/predict.py | 12 +++++++---- script/predict.sh | 47 +++++--------------------------------------- 4 files changed, 48 insertions(+), 57 deletions(-) diff --git a/.gitignore b/.gitignore index b7af894..5b40e83 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ __pycache__/ **/checkpoints*/* **/prediction/* **/scratch/* +.env diff --git a/multivers/data.py b/multivers/data.py index 6be9402..301fa64 100644 --- a/multivers/data.py +++ b/multivers/data.py @@ -7,9 +7,21 @@ from transformers import AutoTokenizer, BatchEncoding import torch import numpy as np +from elasticsearch import Elasticsearch +import scispacy +import spacy +import os import util +es_host = os.getenv("ES_HOST") +es_username = os.getenv("ES_USERNAME") +es_pswd = os.getenv("ES_PSWD") + +print(es_host, es_username, es_pswd) + +es = Elasticsearch(hosts=[es_host], basic_auth=[es_username, es_pswd], verify_certs=False) +nlp = spacy.load("en_core_sci_sm") def get_tokenizer(): "Need to add a few special tokens to the default longformer checkpoint." @@ -177,8 +189,7 @@ class MultiVerSReader: Class to handle SciFact with retrieved documents. """ def __init__(self, predict_args): - self.data_file = predict_args.input_file - self.corpus_file = predict_args.corpus_file + self.claim = predict_args.claim # Basically, I used two different sets of labels. This was dumb, but # doing this mapping fixes it. self.label_map = {"SUPPORT": "SUPPORTS", @@ -189,19 +200,31 @@ def get_data(self, tokenizer): Get the data for the relevant fold. """ res = [] - - corpus = util.load_jsonl(self.corpus_file) - corpus_dict = {x["doc_id"]: x for x in corpus} - claims = util.load_jsonl(self.data_file) + candidates = es.search(index='fractalflows', body={ + 'min_score': 20, + 'size': 10000, + 'query': { + "match": { + "abstract": { + "query": self.claim + } + } + } + }) + claims = [{"id": 1, "claim": self.claim}] + + print(candidates["hits"]["total"]) for claim in claims: - for doc_id in claim["doc_ids"]: - candidate_doc = corpus_dict[doc_id] + for hit in candidates["hits"]["hits"]: + candidate_doc = hit["_source"] + doc = nlp(candidate_doc["abstract"]) + abstract_sents = [sent.text for sent in doc.sents] to_tensorize = {"claim": claim["claim"], - "sentences": candidate_doc["abstract"], - "title": candidate_doc["title"]} + "sentences": abstract_sents, + "title": candidate_doc["title"][1]} entry = {"claim_id": claim["id"], - "abstract_id": candidate_doc["doc_id"], + "abstract_id": int(candidate_doc["pmc"][3:]), "to_tensorize": to_tensorize} res.append(entry) diff --git a/multivers/predict.py b/multivers/predict.py index 6c0746d..516d349 100644 --- a/multivers/predict.py +++ b/multivers/predict.py @@ -1,6 +1,7 @@ from tqdm import tqdm import argparse from pathlib import Path +import os from model import MultiVerSModel from data import get_dataloader @@ -10,9 +11,8 @@ def get_args(): parser = argparse.ArgumentParser() parser.add_argument("--checkpoint_path", type=str) - parser.add_argument("--input_file", type=str) - parser.add_argument("--corpus_file", type=str) parser.add_argument("--output_file", type=str) + parser.add_argument("--claim", type=str) parser.add_argument("--batch_size", type=int, default=1) parser.add_argument("--device", default=0, type=int) parser.add_argument("--num_workers", default=4, type=int) @@ -37,7 +37,9 @@ def get_predictions(args): model.label_threshold = 0.0 # Since we're not running the training loop, gotta put model on GPU. - model.to(f"cuda:{args.device}") + if os.getenv("GPU") == "true": + model.to(f"cuda:{args.device}") + model.eval() model.freeze() @@ -63,7 +65,8 @@ def get_predictions(args): def format_predictions(args, predictions_all): # Need to get the claim ID's from the original file, since the data loader # won't have a record of claims for which no documents were retireved. - claims = util.load_jsonl(args.input_file) + # claims = util.load_jsonl(args.input_file) + claims = [{"id": 1, "claim": args.claim}] claim_ids = [x["id"] for x in claims] assert len(claim_ids) == len(set(claim_ids)) @@ -95,6 +98,7 @@ def format_predictions(args, predictions_all): def main(): args = get_args() + print(args.claim) outname = Path(args.output_file) predictions = get_predictions(args) diff --git a/script/predict.sh b/script/predict.sh index b870f7b..acc285c 100644 --- a/script/predict.sh +++ b/script/predict.sh @@ -1,47 +1,10 @@ # Make model predictions. -function predict_scifact() { - python multivers/predict.py \ - --checkpoint_path=checkpoints/scifact.ckpt \ - --input_file=data/scifact/claims_test_retrieved.jsonl \ - --corpus_file=data/scifact/corpus.jsonl \ - --output_file=prediction/scifact.jsonl -} - -function predict_healthver() { - python multivers/predict.py \ - --checkpoint_path=checkpoints/healthver.ckpt \ - --input_file=data/healthver/claims_test.jsonl \ - --corpus_file=data/healthver/corpus.jsonl \ - --output_file=prediction/healthver.jsonl -} - -function predict_covidfact() { - # NOTE: For covidfact, many of the claims are paper titles (or closely - # related). To avoid data leakage for this dataset, we evaluate using a - # version of the corpus with titles removed. - python multivers/predict.py \ - --checkpoint_path=checkpoints/covidfact.ckpt \ - --input_file=data/covidfact/claims_test.jsonl \ - --corpus_file=data/covidfact/corpus_without_titles.jsonl \ - --output_file=prediction/covidfact.jsonl -} - -######################################## - -model=$1 +claim=$1 mkdir -p prediction -if [[ $model == "scifact" ]] -then - predict_scifact -elif [[ $model == "covidfact" ]] -then - predict_covidfact -elif [[ $model == "healthver" ]] -then - predict_healthver -else - echo "Allowed options are: {scifact,covidfact,healthver}." -fi +python multivers/predict.py \ + --checkpoint_path=checkpoints/scifact.ckpt \ + --output_file=prediction/scifact.jsonl \ + --claim="$claim" From 6b1570ceca0dc6c27fcd2de18b329b8307665917 Mon Sep 17 00:00:00 2001 From: Yuri Jean Fabris Date: Wed, 15 Mar 2023 17:35:59 -0300 Subject: [PATCH 2/3] chore: freeze elasticsearch into requirements.txt --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 749007d..855ba54 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,3 +15,4 @@ tqdm==4.49 transformers==4.2.2 gdown==4.5.4 openai==0.26.4 +elasticsearch==8.6.2 From df6d223bfba3e2adce0a514e1a4a35ee80dc74d0 Mon Sep 17 00:00:00 2001 From: Yuri Jean Fabris Date: Wed, 15 Mar 2023 17:58:12 -0300 Subject: [PATCH 3/3] chore: add ES_VERIFY_CERTS env var --- multivers/data.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/multivers/data.py b/multivers/data.py index 301fa64..56d01f5 100644 --- a/multivers/data.py +++ b/multivers/data.py @@ -17,10 +17,11 @@ es_host = os.getenv("ES_HOST") es_username = os.getenv("ES_USERNAME") es_pswd = os.getenv("ES_PSWD") +es_verify_certs = False if os.getenv("ES_VERIFY_CERTS") == "false" else True -print(es_host, es_username, es_pswd) +print(es_host, es_username, es_pswd, es_verify_certs) -es = Elasticsearch(hosts=[es_host], basic_auth=[es_username, es_pswd], verify_certs=False) +es = Elasticsearch(hosts=[es_host], basic_auth=[es_username, es_pswd], verify_certs=es_verify_certs) nlp = spacy.load("en_core_sci_sm") def get_tokenizer():