2
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

アテンション機構でトレンドを捉えるProphetモデルを提案してみた

Posted at

はじめに

こんにちは、事業会社で働いているデータサイエンティストです。

今回の記事では、時系列予測モデルでよく利用されるMetaのProphetをベイズアテンション(注意機構)で拡張して、トレンドの形と季節成分の要素を自動で学習モデルを紹介します。

時系列予測において、トレンドの抽出は常に最も難しい課題のひとつです。たとえば、Prophetのように、トレンドの形状をあらかじめ関数として仮定するアプローチもありますが、実際のトレンドが本当にその形に従うとは限りません。

一方で、ガウス過程などを用いる柔軟なモデルもありますが、平均構造を適切にモデル化しないと、トレンド成分が早い段階でゼロに回帰してしまい、長期予測にはあまり向きません。

今回紹介するモデルでは、ガウス過程のようにトレンドをデータドリブンな形で自動的に学習します。さらに、アテンションに似た構造を導入しているため、学習期間を超えてもトレンドがゼロに回帰しにくく、長期予測にも対応できるという利点があります。

モデル定式化

まず、時系列$y$は、このように生成されると仮定します:

$$
y_{t} \sim Normal(\mu + f_{trend, t} + f_{year, t} + f_{month, t} + f_{week, t}, \sigma)
$$

要するに、時系列の変動は、切片($\mu$)、長期トレンド($f_{trend}$)、年次周期性($f_{year}$)、月次周期性($f_{month}$)、週次周期性($f_{week}$)に分解できると仮定します。では、まず長期トレンドのモデリング方法を説明します。

長期トレンド

全体の構造から説明すると、こうなります:

$$
f_{trend, t} = softmax(アテンションスコア(t)) '\cdot \beta
$$

ここで、$アテンションスコア$も$\beta$も無限次元のベクトルです。続いて、アテンションスコアの要素$d$の指数変換はこのように計算されます:

$$
exp(アテンションスコア_{d}(t)) = \omega_{d} \cdot \left( \frac{1}{\sqrt{2\pi} \sigma_{d}} \exp\left( -\frac{1}{2} \left( \frac{t - \mu_{d}}{\sigma_{d}} \right)^2 \right) \right)
$$

ここで、$\mu_{d}$は要素$d$が担当する時間の中心、$\sigma_{d}$はその要素が管轄する時間の広がり(分散)を表します。したがって、例えばある時点$t$が要素1の中心$\mu_1$と要素2の中心$\mu_2$の間にあり、$\mu_1$に近い場合、その時点のトレンド成分は$\beta_1$の影響を強く受けつつも、$\beta_2$の影響も少し受け、結果として$\beta_1$と$\beta_2$の加重平均として表現されます。

この構造は、ChatGPTの中核にもなっているアテンション機構論文リンク)と非常に似ています。具体的には:

  • 時点$t$が「クエリ(query)」として機能し、
  • 各要素$d$の中心$\mu_{d}$と範囲$\sigma_{d}$が「キー(key)」として、
  • トレンドパラメータ$\beta_{d}$が「バリュー(value)」として振る舞います。

そして、
$$
アテンションスコア_d(t) \quad \leftrightarrow \quad \text{クエリとキーの類似度スコア}
$$
に相当し、
$$
softmax(\text{アテンションスコア}(t))
$$
で各要素の重要度(注意重み)を計算し、これを値($\beta$)の重みとして加重平均することで、最終的なトレンド$f_{\text{trend}, t}$を得ています。

このように、トレンド成分は時間ごとに異なるトレンド「プロトタイプ」へ注意を向け、重み付きの平均を取ることで柔軟に変化を表現しています。Prophetのようにトレンドの形を固定せず、また単純なガウス過程でもない、アテンション的かつベイズ的なトレンド表現が可能となっています。

ここではさらに、要素の数を研究者の恣意的な意思決定ではなく、データドリブンな形で判断するため、$\omega_{d}$がすぐゼロに収束するよう、棒折り過程で構築します。

具体的には、要素の重要度を決定するためのハイパーパラメータ $\gamma$ を、以下のようにガンマ分布からサンプリングします:

$$
\gamma \sim Gamma(1, 1)
$$

続いて、棒折り過程に基づき、各要素 $d$ に対応する重み $\omega_d$ を構築します:

$$
\delta_d \sim Beta(1, \gamma)
$$

$$
\omega_d = \delta_d \prod\limits_{l=1}^{d - 1} (1 - \delta_l)
$$

この処理は $d = 1$ から $d = \infty$ まで理論的に繰り返されます。ただ、$\omega_{d}$がすぐゼロに収束するため、実際は有限個の要素しか利用されません。

周期性

周期性に関しては、Prophetの論文を参考に

$$
s(t) = \sum_{n = 1}^{\infty} \phi_{n} \cdot (a_{n} cos(\frac{2\pi n t}{P}) + b_{n}sin(\frac{2 \pi n t}{P}))
$$

のフーリエ級数構造で表します。Prophetの論文では分析者が最大で考慮するフーリエ級数の要素量を設定していますが、本モデルでは長期トレンドと同じように、棒折り過程でデータドリブンな形で推定します。また、本モデルでは年次周期性、月次周期性、週次周期性を設定していますが、それぞれの周期性に異なる$\phi$、$a$、$b$を設定します。

Stanでの実装

Stanでの実装コードはこちらです:

attention_prophet.stan
functions {
  vector stick_breaking(vector breaks){
    int length = size(breaks) + 1;
    vector[length] result;
    
    result[1] = breaks[1];
    real summed = result[1];
    for (d in 2:(length - 1)) {
      result[d] = (1 - summed) * breaks[d];
      summed += result[d];
    }
    result[length] = 1 - summed;
    
    return result;
  }
}
data {
  int time_type;
  int approx_type;
  
  array[time_type] real time_seq;
  
  int N;
  vector[N] y;
}
parameters {
  real<lower=0> week_alpha;
  vector<lower=0, upper=1>[approx_type - 1] week_breaks;
  vector[approx_type] week_a;
  vector[approx_type] week_b;
  
  real<lower=0> month_alpha;
  vector<lower=0, upper=1>[approx_type - 1] month_breaks;
  vector[approx_type] month_a;
  vector[approx_type] month_b;
  
  real<lower=0> year_alpha;
  vector<lower=0, upper=1>[approx_type - 1] year_breaks;
  vector[approx_type] year_a;
  vector[approx_type] year_b;
  
  real<lower=0> trend_alpha;
  vector<lower=0, upper=1>[approx_type - 1] trend_breaks;
  vector[approx_type] trend_midpoint;
  vector<lower=0>[approx_type] trend_spread;
  vector[approx_type] trend_beta;
  
  real intercept;
  real<lower=0> sigma;
}
transformed parameters {
  simplex[approx_type] week_dimension;
  simplex[approx_type] month_dimension;
  simplex[approx_type] year_dimension;
  simplex[approx_type] trend_dimension;
  
  week_dimension = stick_breaking(week_breaks);
  month_dimension = stick_breaking(month_breaks);
  year_dimension = stick_breaking(year_breaks);
  trend_dimension = stick_breaking(trend_breaks);
  
  vector[time_type] f_week = rep_vector(0.0, time_type);
  vector[time_type] f_month = rep_vector(0.0, time_type);
  vector[time_type] f_year = rep_vector(0.0, time_type);
  vector[time_type] f_trend;
  
  for (i in 1:time_type){
    vector[approx_type] case_when;
    for (j in 1:approx_type){
      f_week[i] += week_dimension[j] * (week_a[j] * cos((2 * pi() * j * i)/7) + week_b[j] * sin((2 * pi() * j * i)/7));
      f_month[i] += month_dimension[j] * (month_a[j] * cos((2 * pi() * j * i)/30.4375) + month_b[j] * sin((2 * pi() * j * i)/30.4375));
      f_year[i] += year_dimension[j] * (year_a[j] * cos((2 * pi() * j * i)/365.25) + year_b[j] * sin((2 * pi() * j * i)/365.25));
      case_when[j] = log(trend_dimension[j]) + normal_lpdf(time_seq[i] | trend_midpoint[j], trend_spread[j]);
    }
    f_trend[i] = softmax(case_when) '* trend_beta;
  }
}
model {
  week_alpha ~ gamma(1, 1);
  week_breaks ~ beta(1, week_alpha);
  week_a ~ normal(0, 1);
  week_b ~ normal(0, 1);
  
  month_alpha ~ gamma(1, 1);
  month_breaks ~ beta(1, month_alpha);
  month_a ~ normal(0, 1);
  month_b ~ normal(0, 1);
  
  year_alpha ~ gamma(1, 1);
  year_breaks ~ beta(1, year_alpha);
  year_a ~ normal(0, 1);
  year_b ~ normal(0, 1);
  
  trend_alpha ~ gamma(1, 1);
  trend_breaks ~ beta(1, trend_alpha);
  trend_midpoint ~ normal(0, 10);
  trend_spread ~ inv_gamma(5, 5);
  trend_beta ~ normal(0, 1);
  
  intercept ~ normal(0, 5);
  sigma ~ inv_gamma(5, 5);
  
  y ~ normal(intercept + f_week[1:N] + f_month[1:N] + f_year[1:N] + f_trend[1:N], sigma);
}
generated quantities {
  array[time_type] real predict;
  predict = normal_rng(intercept + f_week + f_month + f_year + f_trend, sigma);
}

モデル学習

本記事では、WikipediaにおけるPeyton Manningの日次ログページビューの日時データを利用します。データはこちらからダウンロードできます。

ダウンロードしたデータを読み込み:

df <- readr::read_csv("example_wp_log_peyton_manning.csv")

モデルをコンパイルして:

m_init <- cmdstanr::cmdstan_model("attention_prophet.stan")

最後にモデルを変分推論で学習します。ここでは精度確認のため、最後の90日間のデータをモデルに渡さずに、検証データとして手元に残します:

> m_estimate <- m_init$variational(
     seed = 12345,
     data = list(
         time_type = nrow(df),
         approx_type = 20,
         
         time_seq = ((1:nrow(df)) - mean(1:nrow(df)))/sd(1:nrow(df)),
         
         N = length(df$y[1:(length(df$y) - 90)]),
         y = df$y[1:(length(df$y) - 90)]
     )
 )
------------------------------------------------------------ 
EXPERIMENTAL ALGORITHM: 
  This procedure has not been thoroughly tested and may be unstable 
  or buggy. The interface is subject to change. 
------------------------------------------------------------ 
Gradient evaluation took 0.023673 seconds 
1000 transitions using 10 leapfrog steps per transition would take 236.73 seconds. 
Adjust your expectations accordingly! 
Begin eta adaptation. 
Iteration:   1 / 250 [  0%]  (Adaptation) 
Iteration:  50 / 250 [ 20%]  (Adaptation) 
Iteration: 100 / 250 [ 40%]  (Adaptation) 
Iteration: 150 / 250 [ 60%]  (Adaptation) 
Iteration: 200 / 250 [ 80%]  (Adaptation) 
Success! Found best value [eta = 1] earlier than expected. 
Begin stochastic gradient ascent. 
  iter             ELBO   delta_ELBO_mean   delta_ELBO_med   notes  
   100        -3223.266             1.000            1.000 
   200        -2747.425             0.587            1.000 
   300        -2689.026             0.398            0.173 
   400        -2663.638             0.301            0.173 
   500        -2706.251             0.244            0.022 
   600        -2663.923             0.206            0.022 
   700        -2659.598             0.177            0.016 
   800        -2702.841             0.157            0.016 
   900        -2730.739             0.140            0.016 
  1000        -2657.634             0.129            0.016 
  1100        -2643.428             0.030            0.016 
  1200        -2671.969             0.013            0.016 
  1300        -2648.325             0.012            0.011 
  1400        -2657.023             0.012            0.011 
  1500        -2642.324             0.011            0.010 
  1600        -2645.681             0.009            0.009   MEAN ELBO CONVERGED   MEDIAN ELBO CONVERGED 
Drawing a sample of size 1000 from the approximate posterior...  
COMPLETED. 
Finished in  43.7 seconds.

43秒で終わりました、早いですね!

最後に、推定結果をデータフレイムに保存します:

m_summary <- m_estimate$summary()

学習結果可視化

ここでは早速、推定されたトレンド、年次周期性、月次周期性、週次周期性を確認しましょう:

g_trend <- m_summary |>
  dplyr::filter(stringr::str_detect(variable, "^f_trend\\[")) |>
  dplyr::bind_cols(date = df$ds) |>
  ggplot2::ggplot() + 
  ggplot2::geom_line(ggplot2::aes(x = date, y = mean)) + 
  ggplot2::geom_ribbon(ggplot2::aes(x = date, ymin = q5, ymax = q95), fill = ggplot2::alpha("blue", 0.3)) +
  ggplot2::labs(
    title = "長期トレンド",
    x = "", y = ""
  ) +
  ggplot2::scale_x_date(
    date_breaks = "1 year",
    date_labels = "%Y"
  ) +
  ggplot2::theme_gray(base_family = "HiraKakuPro-W3")

g_year <- m_summary |>
  dplyr::filter(stringr::str_detect(variable, "^f_year\\[")) |>
  dplyr::bind_cols(date = df$ds) |>
  ggplot2::ggplot() + 
  ggplot2::geom_line(ggplot2::aes(x = date, y = mean)) + 
  ggplot2::geom_ribbon(ggplot2::aes(x = date, ymin = q5, ymax = q95), fill = ggplot2::alpha("blue", 0.3)) + 
  ggplot2::labs(
    title = "年次周期性",
    x = "", y = ""
  ) +
  ggplot2::scale_x_date(
    date_breaks = "1 year",
    date_labels = "%Y"
  ) +
  ggplot2::theme_gray(base_family = "HiraKakuPro-W3") + 
  ggplot2::theme(
    axis.text.x = ggplot2::element_text(angle = 45, hjust = 1)
  )

g_month <- m_summary |>
  dplyr::filter(stringr::str_detect(variable, "^f_month\\[")) |>
  dplyr::bind_cols(date = df$ds) |>
  dplyr::filter(dplyr::between(date, as.Date("2015-01-01"), as.Date("2015-05-31"))) |>
  ggplot2::ggplot() + 
  ggplot2::geom_line(ggplot2::aes(x = date, y = mean)) + 
  ggplot2::geom_ribbon(ggplot2::aes(x = date, ymin = q5, ymax = q95), fill = ggplot2::alpha("blue", 0.3)) + 
  ggplot2::labs(
    title = "月次周期性",
    x = "", y = ""
  ) +
  ggplot2::theme_gray(base_family = "HiraKakuPro-W3") + 
  ggplot2::theme(
    axis.text.x = ggplot2::element_text(angle = 45, hjust = 1)
  )


g_week <- m_summary |>
  dplyr::filter(stringr::str_detect(variable, "^f_week\\[")) |>
  dplyr::bind_cols(date = df$ds) |>
  dplyr::filter(dplyr::between(date, as.Date("2015-01-01"), as.Date("2015-01-31"))) |>
  ggplot2::ggplot() + 
  ggplot2::geom_line(ggplot2::aes(x = date, y = mean)) + 
  ggplot2::geom_ribbon(ggplot2::aes(x = date, ymin = q5, ymax = q95), fill = ggplot2::alpha("blue", 0.3)) + 
  ggplot2::labs(
    title = "週次周期性",
    x = "", y = ""
  ) +
  ggplot2::scale_x_date(
    breaks = seq(as.Date("2015-01-01"), as.Date("2015-01-31"), by = "1 day"),
    labels = scales::date_format("%A")
  ) +
  ggplot2::theme_gray(base_family = "HiraKakuPro-W3") + 
  ggplot2::theme(
    axis.text.x = ggplot2::element_text(angle = 45, hjust = 1)
  )

gridExtra::grid.arrange(g_trend, g_year, g_month, g_week, nrow = 2)

components.png

長期トレンドは、最初は上昇傾向にありましたが、2012年6月頃にピークを迎え、その後は下降トレンドに転じました。

次に周期性のスケールに注目してください。年次成分はおおよそ-1から1の範囲で変動しているのに対し、週次成分は-0.1から0.1、月次成分はさらに小さく-0.05から0.05の範囲で推移しています。したがって、周期成分の影響の強さは、年次 > 週次 > 月次 の順であることがわかります。

Peyton Manningを検索する人々は、おそらく仕事中のビジネスパーソンのように月末の最終営業日などを気にするわけではないため、月次の周期性があまり現れていないと考えられます。

年次の傾向を見ると、検索量は年の前半に少なく、後半に向けて増加する傾向が見られます。また週次の傾向としては、週末に検索量が減少する傾向があることも明らかになりました。

最後に、予測結果と実際のデータを比較します:

m_summary |>
  dplyr::filter(stringr::str_detect(variable, "^predict\\[")) |>
  dplyr::bind_cols(date = df$ds, 
                   answer = df$y,
                   status = c(rep("train", length(df$y[1:(length(df$y) - 90)])), rep("test", 90))
                   ) |>
  ggplot2::ggplot() + 
  ggplot2::geom_point(ggplot2::aes(x = date, y = answer, color = status)) +
  ggplot2::geom_line(ggplot2::aes(x = date, y = mean)) + 
  ggplot2::geom_ribbon(ggplot2::aes(x = date, ymin = q5, ymax = q95), fill = ggplot2::alpha("blue", 0.3)) + 
  ggplot2:: scale_color_manual(
    values = c(
      "train" = ggplot2::alpha("blue", 0.5),
      "test" = ggplot2::alpha("red", 0.5)
    )
  ) + 
  ggplot2::labs(
    title = "予測",
    x = "", y = "", color = ""
  ) + 
  ggplot2::theme_gray(base_family = "HiraKakuPro-W3") 

prediction.png

青い線は予測結果の事後分布における平均値を示しており、青い帯はその5%分位点と95%分位点の範囲、つまり信用区間(90%)を表しています。また、青い点は学習データ、赤い点は検証データを表しています。視覚的に見ると、このモデルは学習期間中だけでなく検証期間においても、実際のデータとよくフィットしていることが確認できます。

実際に全体の平均絶対パーセント誤差を計算しましょう;

> m_summary |>
     dplyr::filter(stringr::str_detect(variable, "^predict\\[")) |>
     dplyr::bind_cols(date = df$ds, 
                      answer = df$y,
                      status = c(rep("train", length(df$y[1:(length(df$y) - 90)])), rep("test", 90))
     ) |>
     dplyr::mutate(
         ape = abs((answer - mean)/answer)
     ) |>
     dplyr::pull(ape) |>
     mean()
[1] 0.04945335

4%です。なかなか精度高いですね。90日間の検証データだけで計算しても:

> m_summary |>
     dplyr::filter(stringr::str_detect(variable, "^predict\\[")) |>
     dplyr::bind_cols(date = df$ds, 
                      answer = df$y,
                      status = c(rep("train", length(df$y[1:(length(df$y) - 90)])), rep("test", 90))
     ) |>
     dplyr::mutate(
         ape = abs((answer - mean)/answer)
     ) |>
     dplyr::filter(status == "test") |>
     dplyr::pull(ape) |>
     mean()
[1] 0.06456061

6%なので、ビジネスの実務にも耐えられそうな精度です。

結論

いかがでしたか?

このように、棒折り過程と言語モデルにおけるアテンション機構を活用することで、解釈可能で平均に回帰しないトレンドを推定し、精度の高いモデルを推定することができます。

最後に、私たちと一緒に働きたい方はぜひ下記のリンクもご確認ください:

2
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?