matplotlib でローソク足チャートを描く

python の matplotlib を使い、ローソク足チャートを描くコードを書いた。

概要

下図のようなローソク足チャートを描画する関数を書いた。

関数の周辺

関数を定義

ローソク足チャートを描く関数を dfplot と名付けた。データフレーム df とチャートのタイトルに使う文字列 codeBrand を渡して動かす仕様にした。

def dfplot(df,codeBrand):
    #

関数に渡す df データ

関数に渡す df は株価データ。年月日 date、始値 o、高値 h、安値 l、終値 c、出来高 v の情報をまとめている。最初の25日分は移動平均線の描画に使うため、ローソク足としては表示しない。

>>>print(df)
          date        o        h        l        c             v
0    2010-11-17   9693.0   9817.0   9693.0   9812.0  1.541120e+09
1    2010-11-18   9820.0  10013.0   9798.0  10014.0  2.440270e+09
2    2010-11-19  10124.0  10130.0  10019.0  10022.0  2.147210e+09
3    2010-11-22  10133.0  10157.0  10091.0  10115.0  1.745160e+09
4    2010-11-24   9942.0  10064.0   9904.0  10030.0  2.121290e+09
..          ...      ...      ...      ...      ...           ...
320  2012-03-09   9911.0  10007.0   9853.0   9930.0  3.479760e+09
321  2012-03-12  10015.0  10021.0   9889.0   9890.0  2.260850e+09
322  2012-03-13   9921.0  10011.0   9888.0   9899.0  2.756410e+09
323  2012-03-14  10064.0  10115.0  10050.0  10051.0  2.341250e+09
324  2012-03-15  10115.0  10158.0  10077.0  10123.0  2.413540e+09

df の型は下記。(MySQLから出力したらこうなっていたから、こうした。)

>>>print(df.dtypes)
date     object
o       float64
h       float64
l       float64
c       float64
v       float64

インポート

関数を書く前に、いくつかインポートした。

#import sys
import numpy as np
import pandas as pd
#import mysql.connector
import matplotlib.pyplot as plt
from matplotlib.dates import date2num
from datetime import datetime,timedelta
import matplotlib.dates as mdates

# 日本語フォント用                                                                                                                                              
from matplotlib import rcParams
rcParams['font.sans-serif'] = ['Noto Sans CJK JP']

関数の内容と機能

関数に書いている内容と、機能を順に示す。

ベースの描画

チャートのベースを書く。

subplot2grid でキャンパスを分割し 、2つのチャートを作る。ax1 をローソク足チャートに、ax2 を出来高の棒グラフにする。ただし最初、ax1 は散布図で描く。

date の object 型は扱いにくいので、date2num 関数で数字にしておく。

def dfplot(df,codeBrand):
    #ベースとなるチャート作成                                                                                                                                   
    fig = plt.figure(figsize=(10,7))
    ax1 = plt.subplot2grid((4,1), (0,0),rowspan=3)
    ax2 = plt.subplot2grid((4,1), (3,0))
    ax1.set_title(codeBrand)
    ax1.scatter(date2num(df['date']), df['c'], color='k',s=1)
    vbar = ax2.bar(date2num(df['date']), df['v'])
    ax1.grid()
    ax2.grid()

表示範囲の微調整

高値、安値の水準を考慮して、データの y 軸の表示範囲を指定する。

最初の25日のデータは、移動平均線の作成のみに使うので、最後は表示しないのだが、表示範囲を考える際も最初の25日は使わない。

    #ax1 範囲指定(微調整含む)                                                                                                                                 
    ax1yDatamin = df['l'][25:].min()
    ax1yDatamax = df['h'][25:].max()
    ax1Datarange = ax1yDatamax - ax1yDatamin
    ax1ymin = ax1yDatamin - ax1Datarange/20.
    ax1ymax = ax1yDatamax + ax1Datarange/20.
    ax1range = ax1ymax - ax1ymin
    ax1.set_ylim(ax1ymin,ax1ymax)

y 軸の表示範囲が広がった。範囲の調整をしないと、後でローソク足が窮屈になってしまう。

移動平均線の描画

移動平均線を計算する関数 sma を書き、5日移動平均線と、25日移動平均線を計算して、ax1 にプロットする。

#移動平均                                                                                                                                                       
def sma(ohlcv, period):
  sma = ohlcv["c"].rolling(period).mean()
  vstack = np.vstack((date2num(df["date"].values), sma.values.T)).T  # x軸データを整数に                                                                        
  return vstack
   #ax1 移動平均線を描く                                                                                                                                     
    sma5 = sma(df, 5)
    sma25 = sma(df, 25)
    ax1.plot(sma5[:, 0], sma5[:, 1], color='g',linestyle='dashed')
    ax1.plot(sma25[:, 0], sma25[:, 1], color='b',linestyle='dashdot')

ローソク足の描画

ローソクは axvline の線で表現する。始値と終値を比較し、線の色を変更する。陽線は赤( col = ‘r’ )、陰線は黒( col = ‘k’ )にした。

また、データの日数 n の大小で、線の太さを調整した。

    #ax1 ローソク足チャート                                                                                                                                     
    #調整用にdfに載っている日数を取得                                                                                                                           
    n = len(df)

    if n < 80:
        lwAdd = 1 #linewidthに足して補正                                                                                                                        
    else:
        lwAdd = 0
    for i in range(len(df)):
        o = df['o'][i]
        h = df['h'][i]
        l = df['l'][i]
        c = df['c'][i]
        if c < o:
            col = 'k'
            ax1.axvline(x=date2num(df['date'])[i],ymin=(c-ax1ymin)/ax1range,ymax=(o-ax1ymin)/ax1range,linewidth=3+lwAdd,color=col)
            ax1.axvline(x=date2num(df['date'])[i],ymin=(l-ax1ymin)/ax1range,ymax=(h-ax1ymin)/ax1range,linewidth=1+lwAdd,color=col)
        else:
            col = 'r'
            ax1.axvline(x=date2num(df['date'])[i],ymin=(o-ax1ymin)/ax1range,ymax=(c-ax1ymin)/ax1range,linewidth=3+lwAdd,color=col)
            ax1.axvline(x=date2num(df['date'])[i],ymin=(l-ax1ymin)/ax1range,ymax=(h-ax1ymin)/ax1range,linewidth=1+lwAdd,color=col)

軸の調整

軸に関して、表記を整える。

25日移動平均線が切れるのを防ぐために、set_xlim を使い最初の25日を非表示にする。

また、set_ylabel を使い株価や出来高のラベルをつける。

    #ax1,ax2 x軸y軸の表記調整                                                                                                                                   
    ax1.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, loc: "{:,}".format(int(x))))
    ax2.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, loc: "{:,}".format(int(x))))
    months  = mdates.MonthLocator()
    days    = mdates.DayLocator()
    daysFmt = mdates.DateFormatter('%Y-%m-$d')
    ax1.set_xlim(df["date"].iloc[25], df["date"].iloc[len(df)-1]+timedelta(days=1)) #x軸の範囲                                                                  
    ax2.set_xlim(df["date"].iloc[25], df["date"].iloc[len(df)-1]+timedelta(days=1)) #x軸の範囲                                                                  
    fig.autofmt_xdate() #x軸のオートフォーマット                                                                                                                
    ax1.xaxis.set_major_locator(months)
    ax2.xaxis.set_major_locator(months)
    plt.xticks(rotation=0,ha="left",size='small')
    ax1.set_ylabel("株価")
    ax2.set_ylabel("出来高")

イベントの注記

株価に影響を与えるイベントを表示する。

ネガティブなイベントは右上か左下、ポジティブなイベントは左上か右下に書くと良いと考え、if 文で分岐させた。このように注記をどの座標に書くかの指定に手間をかけたのだが、まだ「そこじゃない」感が出るかもしれない。

   #イベント加筆                                                                                                                                               
    event_data= [
        ("2008-09-16","リーマン・ブラザーズ破綻翌日","N"),
        ("2008-10-28","日経平均株価の安値がバブル後最安","P"),
        ("2009-03-10","日経平均株価の終値がバブル後最安","P"),
        ("2011-03-11","東日本大震災","N"),
        ("2011-08-02","米国債務上限引き上げ法の可決","P"),
        ("2011-08-08","S&Pの米国債格下げの翌営業日","N"),
        ("2012-02-14","日本銀行のバレンタイン緩和","P")
        ]
    for date, label, pos  in event_data:
        eventTFdf = df[date2num(df["date"])==date2num([date])]
        if not(eventTFdf.empty):
            h = df["h"][eventTFdf.index[0]]
            l = df["l"][eventTFdf.index[0]]
            if h + ax1range/4. < ax1yDatamax:
                point_y = h
                text_y  = h + ax1range/3.
                if pos == "N":
                    HValign = ["left","bottom"]
                else:
                    HValign = ["right","bottom"]
            else:
                point_y = l
                text_y  = l - ax1range/3.
                if pos == "N":
                    HValign = ["right","top"]
                else:
                    HValign = ["left","top"]
            ax1.annotate(label, xy=(date2num([date]), point_y),
                         xytext=(date2num([date]),text_y),
                         arrowprops=dict(arrowstyle='->',facecolor='black'),
                         horizontalalignment=HValign[0], verticalalignment=HValign[1])
        else:
            print(label,date,"は表示範囲外のためコメントされません")

以上で完成。

関数全体

def dfplot(df,codeBrand):
    #ベースとなるチャート作成                                                                                                                                   
    fig = plt.figure(figsize=(10,7))
    ax1 = plt.subplot2grid((4,1), (0,0),rowspan=3)
    ax2 = plt.subplot2grid((4,1), (3,0))
    ax1.set_title(codeBrand)
    ax1.scatter(date2num(df['date']), df['c'], color='k',s=1)
    vbar = ax2.bar(date2num(df['date']), df['v'])
    ax1.grid()
    ax2.grid()

    #ax1 範囲指定(微調整含む)                                                                                                                                 
    ax1yDatamin = df['l'][25:].min()
    ax1yDatamax = df['h'][25:].max()
    ax1Datarange = ax1yDatamax - ax1yDatamin
    ax1ymin = ax1yDatamin - ax1Datarange/20.
    ax1ymax = ax1yDatamax + ax1Datarange/20.
    ax1range = ax1ymax - ax1ymin
    ax1.set_ylim(ax1ymin,ax1ymax)

    #ax1 移動平均線から描く                                                                                                                                     
    sma5 = sma(df, 5)
    sma25 = sma(df, 25)
    ax1.plot(sma5[:, 0], sma5[:, 1], color='g',linestyle='dashed')
    ax1.plot(sma25[:, 0], sma25[:, 1], color='b',linestyle='dashdot')

    #ax1 ローソク足チャート                                                                                                                                     
    #調整用にdfに載っている日数を取得                                                                                                                           
    n = len(df)

    if n < 80:
        lwAdd = 1 #linewidthに足して補正                                                                                                                        
    else:
        lwAdd = 0
    for i in range(len(df)):
        o = df['o'][i]
        h = df['h'][i]
        l = df['l'][i]
        c = df['c'][i]
        if c < o:
            col = 'k'
            ax1.axvline(x=date2num(df['date'])[i],ymin=(c-ax1ymin)/ax1range,ymax=(o-ax1ymin)/ax1range,linewidth=3+lwAdd,color=col)
            ax1.axvline(x=date2num(df['date'])[i],ymin=(l-ax1ymin)/ax1range,ymax=(h-ax1ymin)/ax1range,linewidth=1+lwAdd,color=col)
        else:
            col = 'r'
            ax1.axvline(x=date2num(df['date'])[i],ymin=(o-ax1ymin)/ax1range,ymax=(c-ax1ymin)/ax1range,linewidth=3+lwAdd,color=col)
            ax1.axvline(x=date2num(df['date'])[i],ymin=(l-ax1ymin)/ax1range,ymax=(h-ax1ymin)/ax1range,linewidth=1+lwAdd,color=col)


    #ax1,ax2 x軸y軸の表記調整                                                                                                                                   
    ax1.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, loc: "{:,}".format(int(x))))
    ax2.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, loc: "{:,}".format(int(x))))
    months  = mdates.MonthLocator()
    days    = mdates.DayLocator()
    daysFmt = mdates.DateFormatter('%Y-%m-$d')
    ax1.set_xlim(df["date"].iloc[25], df["date"].iloc[len(df)-1]+timedelta(days=1)) #x軸の範囲                                                                  
    ax2.set_xlim(df["date"].iloc[25], df["date"].iloc[len(df)-1]+timedelta(days=1)) #x軸の範囲                                                                  
    fig.autofmt_xdate() #x軸のオートフォーマット                                                                                                                
    ax1.xaxis.set_major_locator(months)
    ax2.xaxis.set_major_locator(months)
    plt.xticks(rotation=0,ha="left",size='small')
    ax1.set_ylabel("株価")
    ax2.set_ylabel("出来高")

    #イベント加筆                                                                                                                                               
    event_data= [
        ("2008-09-16","リーマン・ブラザーズ破綻翌日","N"),
        ("2008-10-28","日経平均株価の安値がバブル後最安","P"),
        ("2009-03-10","日経平均株価の終値がバブル後最安","P"),
        ("2011-03-11","東日本大震災","N"),
        ("2011-08-02","米国債務上限引き上げ法の可決","P"),
        ("2011-08-08","S&Pの米国債格下げの翌営業日","N"),
        ("2012-02-14","日本銀行のバレンタイン緩和","P")
        ]
    for date, label, pos  in event_data:
        eventTFdf = df[date2num(df["date"])==date2num([date])]
        if not(eventTFdf.empty):
            h = df["h"][eventTFdf.index[0]]
            l = df["l"][eventTFdf.index[0]]
            if h + ax1range/4. < ax1yDatamax:
                point_y = h
                text_y  = h + ax1range/3.
                if pos == "N":
                    HValign = ["left","bottom"]
                else:
                    HValign = ["right","bottom"]
            else:
                point_y = l
                text_y  = l - ax1range/3.
                if pos == "N":
                    HValign = ["right","top"]
                else:
                    HValign = ["left","top"]
            ax1.annotate(label, xy=(date2num([date]), point_y),
                         xytext=(date2num([date]),text_y),
                         arrowprops=dict(arrowstyle='->',facecolor='black'),
                         horizontalalignment=HValign[0], verticalalignment=HValign[1])
        else:
            print(label,date,"は表示範囲外のためコメントされません")

    plt.savefig("保存フォルダ名/" + str(code) + "-" + str(df["date"].iloc[len(df)-1]) + "-" + str(n-25) +"d.png")
    plt.show()

コメント

タイトルとURLをコピーしました