import pytest
import numpy as np
from moteur_de_recherche_web.representation_vectorielle_tfidf import _tf, _idf, _tfidf, \
    index_en_representation_vectorielle_tfidf, document_en_vecteur, DocumentNAppartenantPasAuCorpusErreur, \
    similarite_cosinus

@pytest.fixture
def index():
    """Return a fresh four-document corpus mapping each document id to its words."""
    corpus = {}
    corpus["doc1"] = ["mot1", "mot2", "mot3"]
    corpus["doc2"] = ["mot2", "mot3", "mot4", "mot5"]
    corpus["doc3"] = ["mot3", "mot4", "mot5"]
    corpus["doc4"] = ["mot1", "mot3", "mot4", "mot5"]
    return corpus

def test_tf(index):
    """Check the term-frequency matrix: 4 documents (rows) by 5 words (columns)."""
    documents = ["doc1", "doc2", "doc3", "doc4"]
    words = ["mot1", "mot2", "mot3", "mot4", "mot5"]
    # Each entry is the number of times the word occurs in the document;
    # in the fixture corpus every word occurs at most once per document.
    expected = np.array([
        [1, 1, 1, 0, 0],
        [0, 1, 1, 1, 1],
        [0, 0, 1, 1, 1],
        [1, 0, 1, 1, 1],
    ])

    result = _tf(index, documents, words)

    assert result.shape == (4, 5)
    assert np.array_equal(result, expected)

def test_idf(index):
    """Check the inverse document frequency of each word of the fixture corpus.

    The expected value for a word is ``log(4 / df)``, where 4 is the number
    of documents and ``df`` the number of documents containing the word.
    """
    idf = _idf(index, ["mot1", "mot2", "mot3", "mot4", "mot5"])
    assert idf.shape == (5,)

    # Document frequencies of mot1..mot5 in the fixture corpus.
    document_frequencies = np.array([2, 2, 4, 3, 3])
    expected = np.log(4. / document_frequencies)

    # Floating-point results must be compared with a tolerance: an
    # implementation computing e.g. log(N) - log(df) is equally correct but
    # may differ from log(N / df) in the last bits. ``atol`` is needed because
    # one expected value is exactly 0 (log(4/4)).
    np.testing.assert_allclose(idf, expected, rtol=1e-9, atol=1e-12)
    
def test_tfidf(index):
    """Check the TF-IDF matrix computed by ``_tfidf`` for the fixture corpus.

    The expected matrix is the term-count matrix weighted column-wise by
    ``log(4 / df)`` and divided by ``log(4 / 2)``.
    """
    tfidf = _tfidf(index, ["doc1", "doc2", "doc3", "doc4"], ["mot1", "mot2", "mot3", "mot4", "mot5"])
    assert tfidf.shape == (4, 5)

    # Rows: doc1..doc4, columns: mot1..mot5.
    term_counts = np.array([[1, 1, 1, 0, 0],
                            [0, 1, 1, 1, 1],
                            [0, 0, 1, 1, 1],
                            [1, 0, 1, 1, 1]])
    idf = np.log(4. / np.array([2, 2, 4, 3, 3]))
    # NOTE(review): the division by log(4/2) presumably normalises by the
    # largest IDF value -- confirm against the implementation.
    expected = term_counts * idf / np.log(4. / 2)

    # Compare with a tolerance: exact equality on log-derived floats breaks on
    # harmless rounding differences. ``atol`` covers the entries equal to 0.
    np.testing.assert_allclose(tfidf, expected, rtol=1e-9, atol=1e-12)
    
def test_index_en_representation_vectorielle_tfidf(index):
    """Check the words, documents and TF-IDF matrix built from a whole index."""
    representation_vectorielle = index_en_representation_vectorielle_tfidf(index)
    assert set(representation_vectorielle.mots) == {"mot1", "mot2", "mot3", "mot4", "mot5"}
    assert set(representation_vectorielle.documents) == {"doc1", "doc2", "doc3", "doc4"}
    assert representation_vectorielle.tfidf.shape == (4, 5)

    # NOTE(review): the matrix below is laid out with rows doc1..doc4 and
    # columns mot1..mot5, whereas the checks above only compare sets. Confirm
    # that the implementation guarantees this ordering.
    term_counts = np.array([[1, 1, 1, 0, 0],
                            [0, 1, 1, 1, 1],
                            [0, 0, 1, 1, 1],
                            [1, 0, 1, 1, 1]])
    idf = np.log(4. / np.array([2, 2, 4, 3, 3]))
    expected = term_counts * idf / np.log(4. / 2)

    # Compare with a tolerance: exact equality on log-derived floats breaks on
    # harmless rounding differences. ``atol`` covers the entries equal to 0.
    np.testing.assert_allclose(representation_vectorielle.tfidf, expected, rtol=1e-9, atol=1e-12)
    
def test_document_en_vecteur(index):
    """Check that the vector of ``doc1`` is its row of the TF-IDF matrix."""
    representation_vectorielle = index_en_representation_vectorielle_tfidf(index)
    vecteur = document_en_vecteur("doc1", representation_vectorielle)
    assert vecteur.shape == (5,)

    # doc1 contains mot1, mot2 and mot3; the columns are expected in the
    # order mot1..mot5.
    expected = np.array([np.log(4. / 2), np.log(4. / 2), np.log(4. / 4), 0, 0]) / np.log(4. / 2)

    # Compare with a tolerance: exact equality on log-derived floats breaks on
    # harmless rounding differences. ``atol`` covers the entries equal to 0.
    np.testing.assert_allclose(vecteur, expected, rtol=1e-9, atol=1e-12)
                          
def test_document_en_vecteur_erreur(index):
    """Requesting the vector of a document absent from the corpus must raise."""
    corpus_vectors = index_en_representation_vectorielle_tfidf(index)
    unknown_document = "doc5"  # not one of the fixture's doc1..doc4
    with pytest.raises(DocumentNAppartenantPasAuCorpusErreur):
        document_en_vecteur(unknown_document, corpus_vectors)

def test_similarite_cosinus(index):
    """Check ``similarite_cosinus`` against the cosine formula on the document vectors."""
    representation_vectorielle = index_en_representation_vectorielle_tfidf(index)
    v_doc1 = document_en_vecteur("doc1", representation_vectorielle)
    v_doc2 = document_en_vecteur("doc2", representation_vectorielle)

    # Reference value: dot product divided by the product of the norms.
    expected = np.dot(v_doc1, v_doc2) / (np.linalg.norm(v_doc1) * np.linalg.norm(v_doc2))
    actual = similarite_cosinus("doc1", "doc2", representation_vectorielle)

    # Two correct implementations can order the floating-point operations
    # differently, so compare with a tolerance and not with ``==``.
    np.testing.assert_allclose(actual, expected, rtol=1e-9, atol=1e-12)
