在這一篇文章中,我們介紹兩個基本的方式來改進DTW的速度和準確性:warping constraint 和 z-normalization。
以及在金融數據上pattern的辨識與預測的簡單應用。
- Warping constrain
首先,讓我們先定義 warping constraint: 所謂的warping constraint就是我們的DTW路徑允許偏離對角線的程度。如下圖所示,warping constraint w = r/n (或是可以直接定義成 w = r)
Sakoe-Chiba Band
Source: Abdullah Mueen, Eamonn J. Keogh: Extracting Optimal Performance from Dynamic Time Warping. KDD 2016: 2129-2130
這個 w 有兩個作用:1. 加速整個演算法,2. 增加準確性
- 加速:
由上圖所示,我們應該就可以很容易看出為什麼 w 可以替我們的DTW加速:因為我們所被允許經過的格子變少了(圖中的灰色通道),由n2 減少到 w*n2。
def dtw_constraint(template, unknown, window = 10, alignment_curve = False, gen_plot = False):
template = np.array(template)
unknown = np.array(unknown)
# local constraint
window = max( window, abs(len(template)-len(unknown)) )
# dtw matrix initialization
dtw = np.ndarray(shape = (len(template), len(unknown)))
dtw[:,:] = float('inf')
for row in range(dtw.shape[0]):
# set the constraint
for col in range(max(0, row-window), min(len(unknown), row+window)):
dist = np.sqrt( np.sum( (template[row] - unknown[col])**2 ) )
if row == 0 and col == 0:
dtw[row, col] = dist
elif row == 0:
dtw[row, col] = dist + dtw[row, col - 1]
elif col == 0:
dtw[row, col] = dist + dtw[row-1, col]
else:
dtw[row, col] = dist + min(dtw[row-1, col], dtw[row-1, col-1], dtw[row, col-1])
idx = np.argsort( [ dtw[row-1, col], dtw[row-1, col-1], dtw[row, col-1] ] )[0]
# trace for alignment curve
if alignment_curve:
row = 0
col = 0
alignment = [ [row, col] ]
while row != dtw.shape[0] - 1 or col != dtw.shape[1] - 1:
if row == dtw.shape[0] - 1:
col += 1
elif col == dtw.shape[1] - 1:
row += 1
else:
idx = np.argsort( [ dtw[row+1, col], dtw[row+1, col+1], dtw[row, col+1] ] )[0]
if idx == 0:
row += 1
elif idx == 1:
row += 1
col += 1
else:
col += 1
alignment.append([row, col])
alignment = np.array(map(np.array, alignment))
if gen_plot:
fig = plt.figure(figsize = (7,7))
plt.imshow( dtw )
plt.xlim(0, dtw.shape[1])
plt.ylim(0, dtw.shape[0])
plt.title("Constrained Dynamic Time Warping Matrix Heat Map")
if alignment_curve:
plt.plot( alignment[:,1], alignment[:,0], linewidth = 3, color = 'white', label = 'alignment curve')
plt.legend(loc = 'best')
plt.show()
if alignment_curve:
return dtw, alignment
else:
return dtw
- 準確性:
為了解釋 w 怎麼幫助提升準確性,我們舉一個簡單的例子。
讓我們先隨便的生成兩類不同的序列。第一種是一條長度100的序列,但是在0~29之間隨機生成一個高度為1的峰值。第二種一樣是一條長度100的序列,但是在70~99之間生成一個高度為1的峰值。
也就是這「兩類」不同的序列,是依照他們的峰值位在哪個區域所決定。
# generate two sets of sequences
np.random.seed(100)
a = {}
b = {}
for i in range(5):
a["seq_%s" %i] = np.zeros(100)
a["seq_%s" %i][np.random.choice(30,1)] = 1
b["seq_%s" %i] = np.zeros(100)
b["seq_%s" %i][np.random.choice(np.arange(70,100),1)] = 1
# you can plot them if you want to
if False:
plt.plot(np.arange(100), a["seq_%s" %i], color = 'b')
plt.plot(np.arange(100), b["seq_%s" %i], color = 'r')
plt.show()
現在,讓我們再次隨機生成一個第一類的序列(query)。DTW能夠告訴我們它是屬於第一類還是第二類嗎?
query = np.zeros(100)
query[np.random.choice(30,1)] = 1
plt.plot(np.arange(100), query)
我們現在先試試看利用上一篇文章所介紹的,最基本的DTW算法試試。
(小提醒:DTW是用來衡量兩段序列相似程度的算法,所以最後算出的數值越小,代表兩個序列越相像」)
basic_dis = []
for i in range(5):
basic_dis.append( dtw_basic(arr1 = a['seq_%s' %i], arr2 = query, alignment_curve = False, gen_plot = False)[-1,-1] )
for i in range(5):
basic_dis.append( dtw_basic(arr1 = b['seq_%s' %i], arr2 = query, alignment_curve = False, gen_plot = False)[-1,-1] )
basic_dis
output: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
可以看到,最原始的DTW無法分辨query是屬於第一類還是第二類。對它來說,都是一樣的。
再來,我們增加 warping constraint來看看。
constraint_dis = []
for i in range(5):
constraint_dis.append( dtw_constraint(template = a['seq_%s' %i], unknown = query, window = 20,
alignment_curve = False, gen_plot = False)[-1,-1] )
for i in range(5):
constraint_dis.append( dtw_constraint(template = b['seq_%s' %i], unknown = query, window = 20,
alignment_curve = False, gen_plot = False)[-1,-1] )
constraint_dis
output: [0.0, 0.0, 0.0, 0.0, 0.0, 2.0, 2.0, 2.0, 2.0, 2.0]
可以看到,這一次改良版的DTW可以分辨出query是屬於第一類的序列。(因為前五個output都為0)
讓我們現在來看看到底發生什麼事。
在第一個例子中,DTW告訴我們query和所有序列都是同一類,其實這正是我們可以預期的結果。因為其實所有的序列,其實都只不過是一個長度100、在0~99之間隨機出現一個高度為1峰值的序列。
但是,如果一但我們引入了warping constraint,我們保證了DTW在搜尋峰值的時候,不會跑得太遠,所以可以將query正確的分到第一類。
(註:在這邊我們可以說「正確的分到第一類」,是因為類別是由我們自己定義的。在實際的情況上,不一定這麼簡單。)
讓我們看看,到底DTW怎麼把兩段不同的序列相連的。
def dtw_plot(query, template, alignment_curve):
"""
Function to plot the connections between two sequences.
"""
fig = plt.figure(figsize = (14,7))
# shift the query sequence up
plt.plot(np.arange(len(query)), query + 2, lw = 2, label = 'query')
plt.plot(np.arange(len(template)), template, lw = 2, label = 'template')
plt.legend(loc = 'best')
for x1, x2 in alignment_curve:
#print query[x1]+2, a['seq_%s' %0][x2]
plt.plot([x1,x2], [template[x1], query[x2]+2], 'r')
我們先來看看原始的DTW怎麼把query和第一類的序列相連。注意:這邊我們把query向上平移了2個單位,以利識別。
c, d = dtw_basic(arr1 = query, arr2 = a['seq_0'], alignment_curve = True, gen_plot = False)
dtw_plot(query, a['seq_0'], alignment_curve = d)
再來看看兩個序列分處於不同類別時候的情況。
c, d = dtw_basic(arr1 = query, arr2 = b['seq_0'], alignment_curve = True, gen_plot = True)
dtw_plot(query, b['seq_0'], alignment_curve = d)
再來,讓我們引入warping constraint。
當兩個序列為同一個類別時:
c, d = dtw_constraint(template = a['seq_0'], unknown = query, window = 20,
alignment_curve = True, gen_plot = True)
dtw_plot(query, a['seq_0'], alignment_curve = d)
序列為不同類別時:
當我們引入warping constraint之後,由DTW對應的路徑是否合理許多了?
- Z-normalization
正如同上面例子所看到的,我們可以單存的平移一個序列而不會改變它的形狀。但是這個「平移」,對於DTW的計算,卻會有很大的影響,因為各個相對應的資料點間的距離都被平移了。
為了處理這個問題,我們引入z-normalization:
$\bar{x}_i$ 和 $\sigma_i$ 分別是序列i的平均值和標準差。
def z_normalize(array):
return 1.0*(array - np.mean(array))/np.std(array)
先讓我們看看normalization是如何影響我們的結果:
# shift the a sequence by 2 and compute the DTW
dtw_constraint(template = a['seq_0'], unknown = a['seq_0']+2, window = 20,
alignment_curve = False, gen_plot = False)[-1,-1]
output: 200.0
# shift the a sequence by 2 and compute the DTW after normalization
dtw_constraint(template = z_normalize(a['seq_0']),
unknown = z_normalize(a['seq_0']+2), window = 20,
alignment_curve = False, gen_plot = False)[-1,-1]
output: 2.1650736758971334e-13
上面第一個例子,我們將兩個一模一樣的序列拿來比較,但其中一個向上平移了2個單位。
這個簡單的小動作,讓DTW告訴你這兩個序列是非常不相像的!
但是經過normalization後,我們簡單地將平移的效應移除,得到了第二個結果,也就是兩個序列一模一樣。
這個簡單的小動作,讓DTW告訴你這兩個序列是非常不相像的!
但是經過normalization後,我們簡單地將平移的效應移除,得到了第二個結果,也就是兩個序列一模一樣。
- DTW in action
再來,終於到了我們可以將DTW應用在金融數據的時候了!在這裡,我們使用美股的代表性ETF:SPY舉例。
spy = pd.read_csv(os.getcwd() + "/SPY.csv", parse_dates = True, index_col = 0)
spy.tail()
為了便於比較,我們將我們的query和每一段的template切成長度60的序列。
# assign our query
query = spy.iloc[-60:]['Adj Close']
# we slices a 60 days sequence every 20 steps
window = 60
step = 20
n = int( (len(spy) - window)/step )
# create a dataframe to store results
dtw_df = pd.DataFrame(columns = ['start_date', 'end_date', 'dtw_basic', 'dtw_constraint'], index = [0])
for i in range(n):
# template sequence
template = spy.iloc[i*step:i*step + window]['Adj Close']
# apply z-normalization
a = z_normalize(array = query)
b = z_normalize(array = template)
dtw_df.loc[i, 'start_date'] = spy.index[i*step]
dtw_df.loc[i, 'end_date'] = spy.index[i*step+window]
dtw_df.loc[i, 'dtw_basic'] = dtw_basic(arr1 = b, arr2 = a,
alignment_curve = False, gen_plot = False)[-1,-1]
dtw_df.loc[i, 'dtw_constraint'] = dtw_constraint(template = b, unknown = a, window = 5,
alignment_curve = False, gen_plot = False)[-1,-1]
dtw_df.head()
我們將所有的計算結果由小到大排列出來(數值越小越相似),其中最相似的是:
# find out the period with shortest dtw measurements
print dtw_df.sort_values(by = 'dtw_basic').iloc[0]
print dtw_df.sort_values(by = 'dtw_constraint').iloc[0]
和目前(2017-08-28)最相似的一段價格走勢為2005-05-31 到2005-08-24 之間的走勢。
如果我們相信價格走勢在某種程度上是會不停重複出現,那麼我們就可以參照這段歷史作出我們的預測啦!
我們將這兩段走勢的起點都設為1,畫在一起比較如下圖:
start_idx = spy.index.get_loc( dtw_df.sort_values(by = 'dtw_basic').iloc[0]['start_date'] )
end_idx = spy.index.get_loc( dtw_df.sort_values(by = 'dtw_basic').iloc[0]['end_date'] )
# we make our forecast for 20 days forward
# notice that we align our starting points to unity
forecast = spy.iloc[start_idx: end_idx + 20]['Adj Close']/spy.iloc[start_idx]['Adj Close']
query = spy.iloc[-60:]['Adj Close']/spy.iloc[-60]['Adj Close']
fig = plt.figure(figsize = (14,7))
plt.plot(np.arange(len(forecast)), forecast, "-r", label = 'similar trends in history')
plt.plot(np.arange(len(query)), query, "-b", label = 'recent history')
plt.legend(loc = 'best')
plt.show()
文章中的 code 也可以在 github上找到:
https://github.com/S-H-Ho/study_notes/blob/master/dynamic_time_warping_2.ipynb
https://github.com/S-H-Ho/study_notes/blob/master/dynamic_time_warping_2.ipynb
下次,我們來討論如何在需要對比大量序列的時候,增快我們比對的速度。
Evernote helps you remember everything and get organized effortlessly. Download Evernote. |
你好,看完你的文章感覺相當有意思,不知道是否能交流些交易的想法?我的email是: ndc24075@gmail.com
回覆刪除您好 看您做台指機效不錯 請問能代操分潤嗎?
回覆刪除我的line mh_lee
回覆刪除希望能談談 謝謝
本杰明·李先生服务如何给我贷款!!!
回覆刪除大家好,我是来自瑞士苏黎世的Lea Paige Matteo,想用这种媒介感谢本杰明先生履行诺言,给我贷款,我陷入了财政困境,需要再融资和支付账单,以及创业。我试图从私人和企业组织的各种贷款公司寻求贷款,但从未成功,大多数银行拒绝了我的信贷要求。但正如上帝所愿,一位名叫丽莎·赖斯的朋友把我介绍给这个资助服务,并经历了从公司获得贷款的正当程序, 在5个工作日内,我最大的惊喜,就像我的朋友Lisa,我也获得了216,000.00美元的贷款,所以我建议每个人谁希望贷款,"如果你必须联系任何公司,如获得贷款在线低利率1.9%的利率和更好的还款计划/时间表,请联系这个融资服务。此外,他不知道我这样做,但由于我的喜悦,我很高兴,希望让人们知道更多关于这个伟大的公司谁真正提供贷款,这是我的祈祷,上帝应该祝福他们更多,因为他们把微笑的人脸上。您可以通过 [Lfdsloans@outlook.com] 或 Whatsapp 文本 +1-989 394 3740 通过电子邮件与他们联系。