Term frequency over time: how to plot 200+ graphs in one plot with Python/pandas/matplotlib?


I am conducting a textual content analysis of several blogs and am now focusing on finding emerging trends. To do so for one blog, I coded a multi-step process:

  1. loop over all the posts and find the top 5 keywords in each post
  2. add them to a list if they are not already in it
  3. calculate the term frequency (tf) of every term in the list for every single post
  4. create a list of dictionaries, one per post, holding the post's date and the tf of every term
  5. create a data frame from this list of dictionaries and plot it
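
The five steps above can be sketched roughly as follows. This is only an illustration, not the original code: the `posts` sample data, the whitespace tokenizer, and the helper names `tokenize`, `top_keywords`, and `build_rows` are all assumptions, and "top keywords" is approximated here by raw word counts.

```python
from collections import Counter
from datetime import date

# Hypothetical sample data: (publication date, post text) pairs.
posts = [
    (date(2014, 1, 1), "python pandas plot plot pandas python data"),
    (date(2014, 2, 1), "matplotlib plot legend legend colour data"),
]


def tokenize(text):
    """Naive whitespace tokenizer (an assumption; real posts need more care)."""
    return text.lower().split()


def top_keywords(text, n=5):
    """Step 1: the n most frequent tokens of one post."""
    return [word for word, _ in Counter(tokenize(text)).most_common(n)]


def build_rows(posts, n=5):
    # Steps 1-2: collect each post's top keywords, skipping duplicates.
    vocabulary = []
    for _, text in posts:
        for word in top_keywords(text, n):
            if word not in vocabulary:
                vocabulary.append(word)
    # Steps 3-4: one dict per post with its date and the tf of every term.
    rows = []
    for when, text in posts:
        tokens = tokenize(text)
        counts = Counter(tokens)
        row = {"date": when}
        for word in vocabulary:
            row[word] = counts[word] / len(tokens) if tokens else 0.0
        rows.append(row)
    return rows


rows = build_rows(posts)
# Step 5 would then be pd.DataFrame(rows), followed by the plotting code.
```

Every row carries the same keys, so the resulting data frame has one column per term plus the date column.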

This all works fine, except that I get 1000 graphs in one plot, while I only care about those that peak above a certain threshold. Those should have a different set of colours, or be easily recognizable in some other way, and should appear in the legend; the rest should not. Any ideas how to do that?

Here is the code I use now, which produces an unreadable plot:

from pattern.db import Database
import pandas as pd
import matplotlib.pyplot as plt


def plot_trends(keywords_and_date_list):
    d1 = pd.DataFrame(keywords_and_date_list)
    # DataFrame.sort() has been removed from pandas; sort by date explicitly
    d1.sort_values("date", inplace=True)

    # monthly mean of every term's frequency
    grouped = d1.groupby(pd.Grouper(freq='1M', key="date")).mean()

    plt.style.use("ggplot")

    fig = plt.figure(figsize=(25,6))

    # one labelled line per term: this is what makes the plot unreadable
    for i in d1.columns:
        if i == 'date':
            continue
        plt.plot(grouped.index, grouped[i], lw=2, label="monthly average " + i)

    plt.ylim(0,0.015)
    plt.legend(prop={'size':7})
    plt.title("Occurrence of various words in blogs")
    plt.xlabel("Post publication date")
    plt.ylabel("Term Frequency")
    plt.show()

Do any of you have any ideas of how to feasibly differentiate between the graphs that peak over, let's say 0.004, and assign them a different colour set and labels?

I played with a small data set to try to achieve this with pandas' max function, but I can't get it to work.

import pandas as pd
import numpy as np
from pattern.db import date
import matplotlib.pyplot as plt

l = [dict(date=date('2015-01-02'), one=0.1, two=0.2)]
l.append(dict(date=date('2014-01-01'), one=0.2, two=0.5))
l.append(dict(date=date('2014-02-01'), one=0.5, two=0.6))
l.append(dict(date=date('2014-03-01'), one=0.1, two=0.7))

d1 = pd.DataFrame(l)

d2 = d1.set_index('date')

plt.style.use("ggplot")

fig = plt.figure(figsize=(10,6))

for i in d1.columns:
    if d1.max() >= 0.6:
        plt.plot(d1.index, d1[i], lw=2, label="monthly average " + i)
else:
    plt.plot(d1.index, d1[i], lw=2)

plt.ylim(0,1)
plt.legend(prop={'size':10})
plt.title("Occurence of various words in Naoki's blog")
plt.xlabel("Post publication date")
plt.ylabel("Term Frequency")
plt.show()

What I would like to see as a result is one graph with a label and one without. I tried different syntaxes, but I get either two labelled graphs, a value error, or an error saying that a date type and a float are not comparable.

There is 1 answer

Answer by Andreus (best answer)

Your small data set script is largely correct but has some minor errors.

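  • You are missing the if i=='date': continue line (the source of your 'not comparable' error).
  • In your post, your else line is mis-indented.
  • Possibly (only possibly) you need a call to plt.hold(True) to prevent the creation of new graphs. This applies only to old Matplotlib releases: plt.hold was deprecated in 2.0 and removed in 3.0, and repeated plt.plot calls draw on the same axes by default.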
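
One further change is worth spelling out: the condition has to test the maximum of a single column, `d1[i].max()`, rather than `d1.max()`. Without a column selector, `max()` returns one value per column, so the comparison yields a boolean Series, and a Series has no single truth value for `if`; this is the likely source of the value error. A minimal sketch of the difference, using a made-up numeric frame (the names `df`, `over_threshold`, and `labelled` are only for illustration):

```python
import pandas as pd

# Made-up numeric frame standing in for the term-frequency columns.
df = pd.DataFrame({"one": [0.1, 0.2, 0.5, 0.1], "two": [0.2, 0.5, 0.6, 0.7]})

# Frame-level max: one value per column, so the comparison is a boolean Series.
over_threshold = df.max() >= 0.6

# Using that Series as an `if` condition is ambiguous and raises ValueError.
frame_level_error = None
try:
    bool(over_threshold)
except ValueError as exc:
    frame_level_error = exc

# Column-level max: a single scalar per column, which `if` can test directly.
labelled = [name for name in df.columns if df[name].max() >= 0.6]
```

Here only the column `two` reaches the 0.6 threshold, so it is the only one that would get a label.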

Here is my modified version of your script. I swapped pattern.db.date for pd.to_datetime, added more random columns, and de-emphasized the low-valued data lines using alpha and linewidth.

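import pandas as pd
import numpy as np
# from pattern.db import date
import matplotlib.pyplot as plt

l = [dict(date=pd.to_datetime('2015-01-02'), one=0.1, two=0.2)]
l.append(dict(date=pd.to_datetime('2014-01-01'), one=0.2, two=0.5))
l.append(dict(date=pd.to_datetime('2014-02-01'), one=0.5, two=0.6))
l.append(dict(date=pd.to_datetime('2014-03-01'), one=0.1, two=0.7))


d1 = pd.DataFrame(l)
# 100 random low-valued columns (all values below 0.5) ...
for i in range(100):
    d1['s%i' % i] = np.random.uniform(0, .5, 4)
# ... of which the first 10 are overwritten so that some may peak above 0.6
for i in range(10):
    d1['s%i' % i] = np.random.uniform(0, .7, 4)

# index by date and sort, so the x axis shows dates in chronological order
d2 = d1.set_index('date').sort_index()

plt.style.use("ggplot")

fig = plt.figure(figsize=(10,6))

# plt.hold(True) is omitted: it was removed in Matplotlib 3.0, and
# successive plt.plot calls draw on the same axes by default
for i in d1.columns:
    if i == 'date':
        continue
    if d2[i].max() >= 0.6:
        # peaks over the threshold: thick line, listed in the legend
        plt.plot(d2.index, d2[i], lw=2, label="monthly average " + i)
    else:
        # background line: thin, translucent, no legend entry
        # (lw and linewidth are aliases, so only one of them may be passed)
        plt.plot(d2.index, d2[i], alpha=0.7, linewidth=0.3)

plt.ylim(0,1)
plt.legend(prop={'size':10})
plt.title("Occurrence of various words in Naoki's blog")
plt.xlabel("Post publication date")
plt.ylabel("Term Frequency")
plt.show()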
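
For the full data set with hundreds of terms, the same idea could be wrapped in a small helper that takes the monthly-averaged frame and a threshold. The sketch below is one possible arrangement, not part of the original script: `plot_highlighted`, its parameters, the grey colour, and the random demo frame are all assumptions, and the `Agg` backend is selected only so that the example runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line for interactive use
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def plot_highlighted(frame, threshold, ax=None):
    """Plot every column of `frame`; label only those peaking at or above `threshold`."""
    if ax is None:
        _, ax = plt.subplots(figsize=(10, 6))
    peaks = frame.max()  # one peak value per column
    highlighted = []
    for name in frame.columns:
        if peaks[name] >= threshold:
            # Thick line with a label, so it shows up in the legend.
            ax.plot(frame.index, frame[name], lw=2, label=str(name))
            highlighted.append(name)
        else:
            # Thin grey line without a label, so the legend skips it.
            ax.plot(frame.index, frame[name], color="grey", alpha=0.5, lw=0.3)
    if highlighted:
        ax.legend(prop={"size": 10})
    return ax, highlighted


# Hypothetical demo data: 200 low-valued terms over 12 months, two with a spike.
rng = np.random.default_rng(0)
dates = pd.date_range("2014-01-01", periods=12, freq="MS")
demo = pd.DataFrame(
    rng.uniform(0, 0.003, size=(12, 200)),
    index=dates,
    columns=["w%d" % i for i in range(200)],
)
demo.loc[dates[5], "w7"] = 0.009
demo.loc[dates[8], "w42"] = 0.006

ax, highlighted = plot_highlighted(demo, threshold=0.004)
ax.set_xlabel("Post publication date")
ax.set_ylabel("Term Frequency")
```

With the 0.004 threshold from the question, only the two spiked columns are labelled; the other 198 lines stay in the background.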