import os, sys
import logging
import numpy as np
import pandas as pd
import multiprocessing
import gmpy2
from gmpy2 import mpfr
from functools import partial
from gensim import models, matutils
from gensim.corpora import MmCorpus, Dictionary
from scipy.stats import entropy
from scipy.spatial.distance import pdist, squareform
from scipy.linalg import svdvals
# Adapted from pyLDAvis (pyLDAvis/gensim.py).
def extract_data(topic_model, corpus, dictionary, doc_topic_dists=None):
    """Collect the distributions the topic-number metrics in this module need.

    Args:
        topic_model: a fitted gensim topic model. Models exposing ``lda_alpha``
            / ``lda_beta`` (identified as HDP models by the original code) use
            those attributes; others use ``num_topics`` and
            ``state.get_lambda()``.
        corpus: a gensim bag-of-words corpus, or a term-by-document matrix
            accepted by ``matutils.ismatrix``.
        dictionary: the gensim ``Dictionary`` the corpus was built with.
        doc_topic_dists: optional precomputed document-topic weights: a list
            of gensim sparse vectors, a scipy sparse (topics x docs) matrix, or
            a dense (docs x topics) array. Inferred from the model when None.

    Returns:
        dict with keys:
            ``topic_term_dists`` -- (topics x terms) array, rows normalised,
                columns ordered like ``dictionary.token2id.values()``;
            ``doc_topic_dists`` -- (docs x topics) array, rows normalised;
            ``doc_lengths`` -- column sums of the term-by-document matrix;
            ``log_likelihood`` -- ``topic_model.bound(...)`` on the corpus.
    """
    from scipy.sparse import issparse

    if not matutils.ismatrix(corpus):
        corpus_csc = matutils.corpus2csc(corpus, num_terms=len(dictionary))
    else:
        corpus_csc = corpus
    # Need corpus to be a streaming gensim list corpus for len and inference below.
    corpus = matutils.Sparse2Corpus(corpus_csc)

    # TODO: add a smoothing hyperparameter? Online LDA has no beta prior.
    # NOTE(review): nothing below currently guards against zero probabilities.
    fnames_argsort = np.asarray(list(dictionary.token2id.values()), dtype=np.int_)
    # np.asarray copes with both np.matrix (sparse .sum) and ndarray results.
    doc_lengths = np.asarray(corpus_csc.sum(axis=0)).ravel()
    assert doc_lengths.shape[0] == len(corpus), 'Document lengths and corpus have different sizes {} != {}'.format(doc_lengths.shape[0], len(corpus))

    if hasattr(topic_model, 'lda_alpha'):
        num_topics = len(topic_model.lda_alpha)
    else:
        num_topics = topic_model.num_topics

    # Stays None when the caller supplies doc_topic_dists; bound() then runs
    # its own inference instead of failing on an undefined name.
    gamma = None
    if doc_topic_dists is None:
        if hasattr(topic_model, 'lda_beta'):
            # HDP-style model: inference() returns gamma alone.
            gamma = topic_model.inference(corpus)
        else:
            gamma, _ = topic_model.inference(corpus)
        doc_topic_dists = gamma / gamma.sum(axis=1)[:, None]
    else:
        if isinstance(doc_topic_dists, list):
            doc_topic_dists = matutils.corpus2dense(doc_topic_dists, num_topics).T
        elif issparse(doc_topic_dists):
            doc_topic_dists = doc_topic_dists.T.toarray()
        doc_topic_dists = np.asarray(doc_topic_dists, dtype=float)
        # keepdims: divide each row by its own sum (a flat (n,) sum would
        # broadcast across columns instead).
        doc_topic_dists = doc_topic_dists / doc_topic_dists.sum(axis=1, keepdims=True)
    assert doc_topic_dists.shape[1] == num_topics, 'Document topics and number of topics do not match {} != {}'.format(doc_topic_dists.shape[1], num_topics)

    # Get the topic-term distribution straight from gensim without
    # iterating over tuples.
    if hasattr(topic_model, 'lda_beta'):
        topic = topic_model.lda_beta
    else:
        topic = topic_model.state.get_lambda()
    topic = topic / topic.sum(axis=1)[:, None]
    topic_term_dists = topic[:, fnames_argsort]
    assert topic_term_dists.shape[0] == doc_topic_dists.shape[1]

    log_likelihood = topic_model.bound(corpus, gamma=gamma)
    return {'topic_term_dists': topic_term_dists, 'doc_topic_dists': doc_topic_dists, 'doc_lengths': doc_lengths, 'log_likelihood': log_likelihood}
def griffiths_2004(log_likelihood):
    """Return the Griffiths & Steyvers (2004) score for a topic model.

    The score is the harmonic mean of the likelihoods, taken in log space.
    For log-likelihood values l_1..l_n it returns

        log(n / sum_i exp(-l_i)) = log(n) - logsumexp(-l)

    ``logsumexp`` keeps this finite even for very negative log-likelihoods,
    where exp(-l) would overflow. A single scalar, such as the variational
    bound from ``extract_data``, is returned unchanged.

    Args:
        log_likelihood: a scalar or an iterable of log-likelihood values.

    Returns:
        float: log of the harmonic-mean likelihood.

    Raises:
        ValueError: if no values are given.
    """
    from scipy.special import logsumexp

    lls = np.atleast_1d(np.asarray(log_likelihood, dtype=float)).ravel()
    if lls.size == 0:
        raise ValueError('griffiths_2004 needs at least one log-likelihood value')
    return float(np.log(lls.size) - logsumexp(-lls))
def cao_juan_2009(topic_term_dists, num_topics):
    """Return the Cao Juan et al. (2009) score for a topic model.

    The score is the mean cosine similarity over all unordered pairs of
    topic-term rows. The paper selects the K that minimises it, i.e. the one
    whose topics are most distinct.

    Args:
        topic_term_dists: (num_topics x num_terms) array of topic-term weights.
        num_topics: number of topics (rows of ``topic_term_dists``).

    Returns:
        float: sum of pairwise cosine similarities / (K * (K - 1) / 2).

    Raises:
        ValueError: if fewer than two topics are given.
    """
    n_pairs = num_topics * (num_topics - 1) / 2
    if n_pairs <= 0:
        raise ValueError('cao_juan_2009 needs at least two topics')
    # pdist's condensed output holds each unordered pair exactly once. Summing
    # the full squareform matrix would count each pair twice.
    # scipy's 'cosine' metric is a distance (1 - similarity), so convert back.
    cos_sims = 1.0 - pdist(topic_term_dists, metric='cosine')
    return float(np.sum(cos_sims) / n_pairs)
def arun_2010(topic_term_dists, doc_topic_dists, doc_lengths, num_topics):
    """Return the Arun et al. (2010) score for a topic model.

    The score is the symmetric KL divergence between two vectors:
        C_M1 -- singular values of the topic-term matrix;
        C_M2 -- document-length-weighted topic mass, L @ M2, scaled by ||L||.
    ``scipy.stats.entropy`` normalises both vectors to sum to one, so the
    ||L|| scaling does not change the result.

    Args:
        topic_term_dists: (topics x terms) array.
        doc_topic_dists: (docs x topics) array or np.matrix.
        doc_lengths: length-``docs`` vector of document lengths.
        num_topics: unused; kept for signature compatibility.

    Returns:
        float: KL(C_M1 || C_M2) + KL(C_M2 || C_M1).
    """
    c_m1 = svdvals(np.asarray(topic_term_dists))
    # ravel: with an np.matrix doc_topic_dists the product is (1, K). entropy
    # would then return K per-column values instead of one divergence.
    c_m2 = np.asarray(np.matmul(doc_lengths, doc_topic_dists)).ravel() / np.linalg.norm(doc_lengths)
    # The metric is the *symmetric* KL; entropy(p, q) alone is only KL(p || q).
    return float(entropy(c_m1, c_m2) + entropy(c_m2, c_m1))
def deveaud_2014(topic_term_dists, num_topics):
    """Return the Deveaud et al. (2014) score: mean Jensen-Shannon divergence between topics.

    Each unordered pair of topic-term rows appears twice in the full
    symmetric distance matrix. Summing that matrix and dividing by
    ``num_topics * (num_topics - 1)`` (the number of ordered pairs of
    distinct topics) therefore yields the mean pairwise divergence.
    """
    condensed = pdist(topic_term_dists, metric=jensen_shannon)
    full_matrix = squareform(condensed)
    ordered_pairs = num_topics * (num_topics - 1)
    return np.sum(full_matrix) / ordered_pairs
def jensen_shannon(P, Q):
    """Return the Jensen-Shannon divergence between P and Q, using natural logs.

    JSD(P, Q) = 0.5 * KL(P || M) + 0.5 * KL(Q || M), where M = (P + Q) / 2.

    Inputs are converted to float arrays first. With plain lists, ``P + Q``
    would concatenate the lists, and the following float multiply would raise.
    ``scipy.stats.entropy`` normalises its arguments to sum to one.
    """
    P = np.asarray(P, dtype=float)
    Q = np.asarray(Q, dtype=float)
    M = 0.5 * (P + Q)
    return 0.5 * (entropy(P, M) + entropy(Q, M))
def create_models(num_topics, **kwargs):
    """Train and return a gensim ``LdaMulticore`` model with ``num_topics`` topics.

    NOTE(review): the signature suggests this is meant to be mapped over topic
    counts via ``functools.partial(create_models, corpus=..., dictionary=...)``.
    Confirm against the caller.

    Args:
        num_topics: number of topics to fit.
        **kwargs: must contain ``corpus`` and ``dictionary``; the dictionary is
            passed as ``id2word``. Other keys are ignored.

    Raises:
        KeyError: if ``corpus`` or ``dictionary`` is missing.
    """
    # Only ``models`` is imported from gensim; the bare name ``gensim`` is unbound.
    return models.LdaMulticore(corpus=kwargs['corpus'], id2word=kwargs['dictionary'], num_topics=num_topics)
def main(topic_model, corpus, dictionary, num_topics=range(10, 50, 25)):
    """Load a saved model, dictionary and corpus, then print each topic-number metric.

    Args:
        topic_model: path handed to ``models.LdaModel.load``.
        corpus: path handed to ``MmCorpus``.
        dictionary: path handed to ``Dictionary.load``.
        num_topics: currently unused.
    """
    model = models.LdaModel.load(topic_model)
    vocab = Dictionary.load(dictionary)
    mm_corpus = MmCorpus(corpus)
    data = extract_data(model, mm_corpus, vocab)

    k = model.num_topics
    # Each metric is computed lazily inside the loop, so values are printed
    # one at a time in the same order as they are evaluated.
    reports = (
        ('Num Topics', lambda: k),
        ('Griffiths2004', lambda: griffiths_2004(data['log_likelihood'])),
        ('CaoJuan2009', lambda: cao_juan_2009(data['topic_term_dists'], k)),
        ('Arun2010', lambda: arun_2010(data['topic_term_dists'], data['doc_topic_dists'], data['doc_lengths'], k)),
        ('Deveaud2014', lambda: deveaud_2014(data['topic_term_dists'], k)),
    )
    for label, compute in reports:
        print('{}: {}'.format(label, compute()))
if __name__ == '__main__':
    # Positional arguments, in the order main() expects them:
    # TOPIC_MODEL CORPUS DICTIONARY.
    if len(sys.argv) < 4:
        # sys.exit(str) prints the message to stderr and exits with status 1,
        # instead of failing with an IndexError traceback.
        sys.exit('usage: {} TOPIC_MODEL CORPUS DICTIONARY'.format(sys.argv[0]))
    sys.exit(main(sys.argv[1], sys.argv[2], sys.argv[3]))