I have a pandas DataFrame with three columns and a datetime index:

date        px_last  200dma      50dma
2014-12-24  2081.88  1953.16760  2019.2726
2014-12-26  2088.77  1954.37975  2023.7982
2014-12-29  2090.57  1955.62695  2028.3544
2014-12-30  2080.35  1956.73455  2032.2262
2014-12-31  2058.90  1957.66780  2035.3240

I would like to make a time series plot of the 'px_last' column that is colored green on days where the 50dma is above the 200dma and red on days where the 50dma is below the 200dma. I have seen this example, but can't make it work for my case: http://matplotlib.org/examples/pylab_examples/multicolored_line.html
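
For comparison, the LineCollection technique from the linked example can be adapted to this case roughly as follows. This is a minimal sketch that assumes the column names from the question; the helper name colored_line and the toy data are hypothetical. Each segment between consecutive days is colored by whether the 50dma is above the 200dma at the segment's left endpoint.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.collections import LineCollection


def colored_line(ax, df):
    """Hypothetical helper: draw px_last as a single LineCollection,
    green where 50dma > 200dma and red otherwise."""
    x = mdates.date2num(df.index.to_numpy())  # dates -> matplotlib float days
    y = df["px_last"].to_numpy()
    # Segment i joins point i to point i + 1, giving len(df) - 1 segments.
    points = np.column_stack([x, y]).reshape(-1, 1, 2)
    segments = np.concatenate([points[:-1], points[1:]], axis=1)
    # Color each segment by the MA relationship at its left endpoint.
    above = (df["50dma"] > df["200dma"]).to_numpy()[:-1]
    colors = ["g" if a else "r" for a in above]
    lc = LineCollection(segments, colors=colors, linewidths=2)
    ax.add_collection(lc)
    ax.autoscale_view()  # rescale the view to include the collection
    ax.xaxis_date()      # format the float x values as dates
    return lc


# Toy data (hypothetical): the 50dma drops below the 200dma on the third day.
toy = pd.DataFrame(
    {"px_last": [1.0, 2.0, 3.0, 2.0],
     "50dma": [2.0, 2.0, 1.0, 1.0],
     "200dma": [1.0, 1.0, 2.0, 2.0]},
    index=pd.date_range("2014-12-24", periods=4, freq="D"),
)
fig, ax = plt.subplots()
lc = colored_line(ax, toy)
```

Because all segments live in one artist and neighbouring segments share an endpoint, the line stays continuous across color changes.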

  • Does anyone have a way to extend this to 3 or more colors? (commented Jan 11, 2016 at 19:42)

2 Answers

Here is an example that does it without matplotlib.collections.LineCollection. The idea is to first identify the crossover points and then plot each stretch between crossovers separately, using groupby.
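
To see how the grouping key in the code works, here is the crossover test applied to a toy label series. This is a minimal sketch; the label values are hypothetical and stand in for the +1/-1 column built from the moving averages.

```python
import pandas as pd

# Toy labels (hypothetical): +1 where 50dma > 200dma, -1 otherwise.
label = pd.Series([1, 1, -1, -1, -1, 1])

# A crossover is a row whose label has the opposite sign to the previous row,
# i.e. label.shift() * label < 0. The first row is compared with NaN, and
# NaN < 0 is False, so it never counts as a crossover.
crossover = label.shift() * label < 0

# The running count of crossovers gives each run of equal labels its own id.
segment_id = crossover.cumsum()
print(segment_id.tolist())  # [0, 0, 1, 1, 1, 2]
```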

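import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# simulate data
# =============================
np.random.seed(1234)
df = pd.DataFrame({'px_last': 100 + np.random.randn(1000).cumsum()}, index=pd.date_range('2010-01-01', periods=1000, freq='B'))
# pd.rolling_mean was removed from pandas; the .rolling() method replaces it
df['50dma'] = df['px_last'].rolling(window=50).mean()
df['200dma'] = df['px_last'].rolling(window=200).mean()
# +1 where the 50dma is above the 200dma, -1 otherwise
df['label'] = np.where(df['50dma'] > df['200dma'], 1, -1)


# plot
# =============================
# drop the warm-up rows where the moving averages are still NaN
df = df.dropna(axis=0, how='any')

fig, ax = plt.subplots()

def plot_func(group):
    # every row in a group has the same label, so one color fits the segment
    color = 'r' if (group['label'] < 0).all() else 'g'
    lw = 2.0
    ax.plot(group.index, group.px_last, c=color, linewidth=lw)

# a new group starts at each crossover, i.e. wherever the label changes sign;
# looping over the groups (instead of groupby.apply) means the return value
# of plot_func does not matter
for _, group in df.groupby((df['label'].shift() * df['label'] < 0).cumsum()):
    plot_func(group)

# add ma lines
ax.plot(df.index, df['50dma'], 'k--', label='MA-50')
ax.plot(df.index, df['200dma'], 'b--', label='MA-200')
ax.legend(loc='best')
plt.show()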

Comments

I was getting a StopIteration error because plot_func wasn't returning anything, so I added the line "return group" to the function. Besides that, it's perfect. Thanks!
@Mishiko It's quite strange; the code runs properly on my PC. Can you upload your sample data (via a Dropbox share link or Google Drive) if it's not proprietary?
I was getting the same error with your random data, but yes, I will upload a link to some sample data of mine later today.
@Mishiko Make sure you have the most recent version of pandas and matplotlib. :-)
Is there any way to extend this to 3 or more colors?
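
On the question of 3 or more colors: the sign-product key in the answer only detects flips between +1 and -1, but comparing each label with the previous one works for any number of categories. A minimal sketch, assuming a categorical label column; the labels, colors, and data below are hypothetical.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import pandas as pd

# Hypothetical three-state labels; any hashable labels work the same way.
df = pd.DataFrame(
    {"px_last": [1.0, 2.0, 3.0, 2.5, 2.0, 3.5],
     "label": ["up", "up", "flat", "flat", "down", "up"]},
    index=pd.date_range("2020-01-01", periods=6, freq="D"),
)
label2color = {"up": "green", "flat": "orange", "down": "red"}

# A new segment starts wherever the label differs from the previous row.
# The first row is compared with NaN, which counts as "different", so the
# ids start at 1.
segment_id = df["label"].ne(df["label"].shift()).cumsum()

fig, ax = plt.subplots()
for _, group in df.groupby(segment_id):
    color = label2color[group["label"].iloc[0]]  # one label per segment
    ax.plot(group.index, group["px_last"], color=color, linewidth=2)
```

As in the answer, each segment is drawn as a separate line, so nothing is drawn between the last point of one segment and the first point of the next; starting each segment one row earlier closes those gaps.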

Building on @Jianxun Li's answer, here's a version that's more easily extendible to 3+ colors:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


# Simulate data
np.random.seed(1234)
df = pd.DataFrame(
    {'px_last': 100 + np.random.randn(1000).cumsum()},
    index=pd.date_range('2010-01-01', periods=1000, freq='B'),
)
df['50dma'] = df['px_last'].rolling(window=50, center=False).mean()
df['200dma'] = df['px_last'].rolling(window=200, center=False).mean()

## Apply labels
df['label'] = 'out of bounds'
df.loc[abs(df['50dma'] - df['200dma']) >= 7, 'label'] = '|50dma - 200dma| >= 7'
df.loc[abs(df['50dma'] - df['200dma']) < 7, 'label'] = '|50dma - 200dma| < 7'
df.loc[abs(df['50dma'] - df['200dma']) < 5, 'label'] = '|50dma - 200dma| < 5'
df.loc[abs(df['50dma'] - df['200dma']) < 3, 'label'] = '|50dma - 200dma| < 3'
df = df[df['label'] != 'out of bounds']

## Convert labels to colors
label2color = {
    '|50dma - 200dma| < 3': 'green',
    '|50dma - 200dma| < 5': 'yellow',
    '|50dma - 200dma| < 7': 'orange',
    '|50dma - 200dma| >= 7': 'red',
}
df['color'] = df['label'].map(label2color)

# Create plot
fig, ax = plt.subplots()

def gen_repeating(s):
    """Generator: groups repeated elements in an iterable
    E.g.
        'abbccc' -> [('a', 0, 0), ('b', 1, 2), ('c', 3, 5)]
    """
    i = 0
    while i < len(s):
        j = i
        while j < len(s) and s[j] == s[i]:
            j += 1
        yield (s[i], i, j-1)
        i = j

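## Add px_last lines
## (pass a plain list: positional s[i] lookups on a Series with a
## DatetimeIndex are deprecated in recent pandas)
for color, start, end in gen_repeating(df['color'].tolist()):
    if start > 0:  # step back one point so adjacent segments connect
        start -= 1
    idx = df.index[start:end+1]
    # an empty label keeps these segments out of the legend
    df.loc[idx, 'px_last'].plot(ax=ax, color=color, label='')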

## Add 50dma and 200dma lines
df['50dma'].plot(ax=ax, color='k', ls='--', label='MA$_{50}$')
df['200dma'].plot(ax=ax, color='b', ls='--', label='MA$_{200}$')

## Get artists and labels for legend and choose which ones to display
handles, labels = ax.get_legend_handles_labels()

## Create custom artists
g_line = plt.Line2D((0,1),(0,0), color='green')
y_line = plt.Line2D((0,1),(0,0), color='yellow')
o_line = plt.Line2D((0,1),(0,0), color='orange')
r_line = plt.Line2D((0,1),(0,0), color='red')

## Create legend from custom artist/label lists
ax.legend(
    handles + [g_line, y_line, o_line, r_line],
    labels + [
        '|MA$_{50} - $MA$_{200}| < 3$',
        '|MA$_{50} - $MA$_{200}| < 5$',
        '|MA$_{50} - $MA$_{200}| < 7$',
        r'|MA$_{50} - $MA$_{200}| \geq 7$',
    ],
    loc='best',
)

# Display plot
plt.show()

I've also added a fancy-ish legend.
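
The labelling and run-finding steps can also be written with library helpers. The sketch below swaps in pd.cut for the chain of df.loc assignments and itertools.groupby for the hand-written scanning loop in gen_repeating; the function name runs and the sample gaps are hypothetical. It yields the same (color, start, end) tuples that the plotting loop consumes.

```python
from itertools import groupby

import numpy as np
import pandas as pd


def runs(seq):
    """Yield (value, first_pos, last_pos) for each run of equal values,
    matching the output format of gen_repeating."""
    pos = 0
    for value, grp in groupby(seq):
        n = sum(1 for _ in grp)  # length of this run
        yield value, pos, pos + n - 1
        pos += n


# Hypothetical |50dma - 200dma| gaps; NaN stands in for the rolling warm-up rows.
gap = pd.Series([np.nan, 0.5, 2.0, 3.0, 4.9, 6.0, 7.0, 12.0])

# Left-closed bins [0, 3), [3, 5), [5, 7), [7, inf) mirror the answer's thresholds.
color = pd.cut(gap, bins=[0, 3, 5, 7, np.inf], right=False,
               labels=["green", "yellow", "orange", "red"])
color = color.dropna()  # plays the role of filtering out 'out of bounds' rows

segments = list(runs(color.tolist()))
print(segments)
```

Adding a fifth color then means adding one bin edge and one label rather than another df.loc line and dictionary entry.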
