import numpy as np
import pytest

from moteur_de_recherche_web.representation_vectorielle_tfidf import RepresentationVectorielle
from moteur_de_recherche_web.visiteur_modele_booleen_etendu import visiter
from moteur_de_recherche_web.analyseur import Et, Ou, Non

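# Inactive reference corpus, kept for documentation only: the number of
# documents containing each word here is consistent with the log(4/n) weights
# used in the `representation_vectorielle` fixture.
# @pytest.fixture
# def index():
#     return {
#         "doc1": ["mot1", "mot2", "mot3"],
#         "doc2": ["mot2", "mot3", "mot4", "mot5"],
#         "doc3": ["mot3", "mot4", "mot5"],
#         "doc4": ["mot1", "mot3", "mot4","mot5"],
#     }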

@pytest.fixture
def representation_vectorielle():
    """Provide the weighted representation shared by the tests in this module.

    The vocabulary is ``mot1``..``mot5`` and the documents are
    ``doc1``..``doc4``.  The assertions in this file read the weight matrix
    with one row per document and one column per term.  Non-zero entries are
    ``log(4/n)`` for ``n`` in {2, 3, 4}; the whole matrix is divided by
    ``log(4/2)``, its largest entry, so every weight lies in ``[0, 1]``.
    """
    terms = ["mot1", "mot2", "mot3", "mot4", "mot5"]
    documents = ["doc1", "doc2", "doc3", "doc4"]

    # The three distinct non-zero weights appearing in the matrix.
    w2 = np.log(4./2)
    w3 = np.log(4./3)
    w4 = np.log(4./4)

    raw_weights = np.array([
        [w2, w2, w4, 0, 0],
        [0, w2, w4, w3, w3],
        [0, 0, w4, w3, w3],
        [w2, 0, w4, w3, w3],
    ])

    return RepresentationVectorielle(terms, documents, raw_weights / w2)

def test_visiter_mot(representation_vectorielle):
    """Check the score of a single-term query against every document.

    Each expected value is the fixture's weight-matrix entry for the given
    (term, document) pair: ``log(4/n) / log(4/2)`` where the entry is
    non-zero, ``0`` otherwise.

    Scores are compared with ``pytest.approx`` rather than ``==`` so the test
    does not hinge on bit-identical floating-point results.
    """
    scale = np.log(4./2)
    expected_scores = [
        ("mot1", "doc1", np.log(4./2) / scale),
        ("mot1", "doc2", 0),
        ("mot1", "doc3", 0),
        ("mot1", "doc4", np.log(4./2) / scale),
        ("mot2", "doc1", np.log(4./2) / scale),
        ("mot2", "doc2", np.log(4./2) / scale),
        ("mot2", "doc3", 0),
        ("mot2", "doc4", 0),
        ("mot5", "doc1", 0),
        ("mot5", "doc2", np.log(4./3) / scale),
        ("mot5", "doc3", np.log(4./3) / scale),
        ("mot5", "doc4", np.log(4./3) / scale),
    ]

    for term, doc, expected in expected_scores:
        score = visiter(term, doc, representation_vectorielle)
        # The message identifies the failing case, since all cases share one assert.
        assert score == pytest.approx(expected), f"{term} in {doc}"

def test_visiter_et(representation_vectorielle):
    """Check the score of a two-term ``Et`` (AND) query.

    For operand weights ``a`` and ``b`` the expected score is
    ``1 - sqrt(((1 - a)**2 + (1 - b)**2) / 2)``.

    Scores are compared with ``pytest.approx`` rather than ``==``: the
    expected value chains squares, a mean and a square root, so an equivalent
    implementation may legitimately differ by a rounding step.
    """
    def expected_et(a, b):
        # Reference formula the visitor's result is checked against.
        return 1 - np.sqrt(((1 - a)**2 + (1 - b)**2) / 2)

    present = np.log(4./2) / np.log(4./2)  # weight of mot1/mot2 where they occur
    absent = 0. / np.log(4./2)             # weight of mot1 in doc2

    requete = Et("mot1", "mot2")
    assert visiter(requete, "doc1", representation_vectorielle) == pytest.approx(
        expected_et(present, present))
    assert visiter(requete, "doc2", representation_vectorielle) == pytest.approx(
        expected_et(absent, present))
    
def test_visiter_ou(representation_vectorielle):
    """Check the score of a two-term ``Ou`` (OR) query.

    For operand weights ``a`` and ``b`` the expected score is
    ``sqrt((a**2 + b**2) / 2)``.

    Scores are compared with ``pytest.approx`` rather than ``==``: the
    expected value chains squares, a mean and a square root, so an equivalent
    implementation may legitimately differ by a rounding step.
    """
    def expected_ou(a, b):
        # Reference formula the visitor's result is checked against.
        return np.sqrt((a**2 + b**2) / 2)

    present = np.log(4./2) / np.log(4./2)  # weight of mot1/mot2 where they occur
    absent = 0.0                           # weight of mot1 in doc2

    requete = Ou("mot1", "mot2")
    assert visiter(requete, "doc1", representation_vectorielle) == pytest.approx(
        expected_ou(present, present))
    assert visiter(requete, "doc2", representation_vectorielle) == pytest.approx(
        expected_ou(absent, present))

def test_visiter_non(representation_vectorielle):
    """Check the score of a negated single-term query against every document.

    Each expected value is ``1 - w`` where ``w`` is the fixture's
    weight-matrix entry for the given (term, document) pair.

    Scores are compared with ``pytest.approx`` rather than ``==`` so the test
    does not hinge on bit-identical floating-point results.
    """
    scale = np.log(4./2)
    expected_scores = [
        ("mot1", "doc1", 1 - np.log(4./2) / scale),
        ("mot1", "doc2", 1 - 0.0),
        ("mot1", "doc3", 1 - 0.0),
        ("mot1", "doc4", 1 - np.log(4./2) / scale),
        ("mot2", "doc1", 1 - np.log(4./2) / scale),
        ("mot2", "doc2", 1 - np.log(4./2) / scale),
        ("mot2", "doc3", 1 - 0.0),
        ("mot2", "doc4", 1 - 0.0),
        ("mot5", "doc1", 1 - 0.0),
        ("mot5", "doc2", 1 - np.log(4./3) / scale),
        ("mot5", "doc3", 1 - np.log(4./3) / scale),
        ("mot5", "doc4", 1 - np.log(4./3) / scale),
    ]

    for term, doc, expected in expected_scores:
        score = visiter(Non(term), doc, representation_vectorielle)
        # The message identifies the failing case, since all cases share one assert.
        assert score == pytest.approx(expected), f"Non({term}) in {doc}"
