# eval_command.py
from gogole import query
from gogole.query import vectorial_query
from gogole.parser import CACMParser
from gogole.parser import QRelsParser

def _precision_recall(retrieved, relevant):
    """Return the ``(precision, recall)`` pair for a single query.

    :param retrieved: results kept from the search (the top-n cut).
    :param relevant: results judged relevant for the query; must support
        ``in`` and ``len``.

    Precision is ``hits / len(retrieved)`` and recall is
    ``hits / len(relevant)``, where ``hits`` counts the retrieved entries
    that appear in ``relevant``. Either value is ``0`` when its denominator
    is empty, instead of raising ``ZeroDivisionError``.
    """
    hits = sum(1 for doc in retrieved if doc in relevant)
    n_retrieved = len(retrieved)
    n_relevant = len(relevant)
    precision = hits / n_retrieved if n_retrieved else 0
    recall = hits / n_relevant if n_relevant else 0
    return precision, recall


def run(collection, args):
    """Evaluate vectorial search against the CACM relevance judgments.

    For every weighting scheme in ``vectorial_query.WEIGHTING_TYPES``:
    - run the first ``args.nrequests[0]`` queries of ``data/query.text``
      through a vectorial query browser;
    - keep the top 10 results of each query;
    - print the mean precision and mean recall ("rappel") over the queries
      that have a non-empty set of relevance judgments.

    :param collection: the collection handed to the query browser.
    :param args: parsed command-line arguments; ``args.nrequests[0]`` is the
        number of queries to evaluate (converted with ``int``).
    """
    top_n = 10

    # The queries file has the same structure as the corpus, so the CACM
    # parser reads it: each "document" it yields is a query whose text is in
    # ``abstract``. Parse once and reuse the list for every weighting scheme.
    cacm_parser = CACMParser("data/query.text")
    nrequests = int(args.nrequests[0])
    queries = list(cacm_parser.find_documents(limit=nrequests))

    qrels_parser = QRelsParser()
    relevant_docs_by_query = qrels_parser.parse_all(nrequests)

    query_cls = query.QUERY_MAP[query.QUERY_TYPE_VECTORIAL]

    for weight_type in vectorial_query.WEIGHTING_TYPES:
        # The browser depends only on the collection and the weighting
        # scheme, not on the individual query, so build it once per scheme.
        query_browser = query_cls(collection, weight_type)

        precision_sum = 0
        recall_sum = 0
        nb_queries = 0

        for document in queries:
            query_id = document.document_id

            # Queries without relevance judgments cannot be scored: drop
            # them before paying for the search. An empty judgment list
            # would make recall undefined, so it is dropped as well.
            if query_id not in relevant_docs_by_query:
                continue
            relevant_docs = relevant_docs_by_query[query_id]
            if not relevant_docs:
                continue

            all_results, _elapsed = query_browser.timed_search(document.abstract)
            retrieved = list(
                query_browser.find_n_first_elements(all_results, n=top_n)
            )

            precision, recall = _precision_recall(retrieved, relevant_docs)
            precision_sum += precision
            recall_sum += recall
            nb_queries += 1

        if nb_queries == 0:
            # Avoid ZeroDivisionError when no query could be scored.
            print(
                "for weight {weight}: no query with relevance judgments".format(
                    weight=weight_type
                )
            )
            continue

        precision = precision_sum / nb_queries
        recall = recall_sum / nb_queries
        print("for weight {weight}: precision: {precision}, rappel: {recall}".format(weight=weight_type, precision=precision, recall=recall))