
TFT Model

· 6 min read
Jiujiuwhoami
Digital nomads

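🔥 A concise guide to using TFT, covering the full workflow from data preparation → model definition → training → prediction → visualization. It is aimed at regression forecasting tasks.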


🧠 Temporal Fusion Transformer: A Concise Usage Guide (pytorch-forecasting)


✅ 1. Install dependencies

pip install pytorch-forecasting pytorch-lightning==1.6.0

Make sure PyTorch (>=1.8) is installed in your environment. Python 3.9 or 3.10 is recommended.


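✅ 2. Data format requirements

The data is a DataFrame in long format (one row per group and time step) that contains at least the following columns: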

| group_id | time_idx | target | known_features... | unknown_features... |
|---|---|---|---|---|
| A | 0 | 1.2 | ... | ... |
| A | 1 | 1.4 | ... | ... |
| B | 0 | 3.3 | ... | ... |

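Key fields:

| Field | Meaning |
|---|---|
| group_id | Series (group) ID, e.g. a stock ticker or product number |
| time_idx | Time index (must be an integer) |
| target | Target value (the variable you want to predict) |
| known_reals | Numeric features whose future values are known in advance |
| unknown_reals | Features whose future values are unknown (e.g. sales volume, price) |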
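
As a rough illustration of this layout, the sketch below builds a toy long-format DataFrame and derives an integer time_idx from a date column. The `date` column and all values are made up for illustration; only group_id, time_idx, and target come from the guide.

```python
import pandas as pd

# Toy data: two series ("A" and "B") observed on consecutive days.
# The "date" column and all values are illustrative assumptions.
raw = pd.DataFrame(
    {
        "group_id": ["A", "A", "A", "B", "B"],
        "date": pd.to_datetime(
            ["2023-01-01", "2023-01-02", "2023-01-03", "2023-01-01", "2023-01-02"]
        ),
        "target": [1.2, 1.4, 1.1, 3.3, 3.6],
    }
)

# time_idx must be an integer that increases by 1 per time step.
# Here it is the number of days since the earliest date in the data,
# so the same calendar day gets the same index in every group.
raw["time_idx"] = (raw["date"] - raw["date"].min()).dt.days.astype("int64")

data = raw.sort_values(["group_id", "time_idx"]).reset_index(drop=True)
print(data)
```

Deriving the index from the date, rather than numbering rows, keeps gaps in the data visible as gaps in time_idx.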

✅ 3. Build the TimeSeriesDataSet

from pytorch_forecasting import TimeSeriesDataSet, GroupNormalizer

max_encoder_length = 30      # use the past 30 steps as input
max_prediction_length = 7    # forecast the next 7 steps

# Hold out the last max_prediction_length steps for validation
training_cutoff = data["time_idx"].max() - max_prediction_length

training = TimeSeriesDataSet(
    data[data.time_idx <= training_cutoff],
    time_idx="time_idx",
    target="target",
    group_ids=["group_id"],
    max_encoder_length=max_encoder_length,
    max_prediction_length=max_prediction_length,
    time_varying_known_reals=["time_idx", "known_feature1", "known_feature2"],
    time_varying_unknown_reals=["target", "unknown_feature1"],
    # Normalize the target separately for each group (recommended for regression)
    target_normalizer=GroupNormalizer(groups=["group_id"]),
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
)
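
To make the role of the per-group target normalizer more concrete, the following sketch reproduces the idea in plain pandas: each group's target is centered and scaled using that group's own statistics. This is a conceptual illustration only, not the GroupNormalizer implementation, and the toy values are made up.

```python
import pandas as pd

# Two series on very different scales (illustrative values).
df = pd.DataFrame(
    {
        "group_id": ["A", "A", "A", "B", "B", "B"],
        "target": [1.0, 2.0, 3.0, 100.0, 200.0, 300.0],
    }
)

# Per-group mean and standard deviation, broadcast back to each row.
grouped = df.groupby("group_id")["target"]
center = grouped.transform("mean")
scale = grouped.transform("std")

# After scaling, both groups live on the same scale, which is the
# motivation for normalizing the target group by group.
df["target_scaled"] = (df["target"] - center) / scale
print(df)
```

Without this step, a series with large values would dominate the loss and the small-valued series would contribute very little to training.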

Then build the validation set:

validation = TimeSeriesDataSet.from_dataset(training, data, predict=True, stop_randomization=True)
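
The split logic can be checked with simple index arithmetic. The sketch below assumes a single series with 100 time steps (a made-up number) and an encoder length of 30 with a 7-step horizon. It only illustrates which time steps end up where and does not call pytorch-forecasting.

```python
# Assumed toy setup: time_idx runs from 0 to 99 for one series.
time_idx = list(range(100))
max_encoder_length = 30
max_prediction_length = 7

# Everything up to the cutoff is available for training.
training_cutoff = max(time_idx) - max_prediction_length
train_steps = [t for t in time_idx if t <= training_cutoff]

# With predict=True, the validation sample for a series is its last window:
# the final max_prediction_length steps are the forecast horizon, and the
# max_encoder_length steps before them are the model input.
decoder_steps = time_idx[-max_prediction_length:]
encoder_steps = time_idx[-(max_prediction_length + max_encoder_length):-max_prediction_length]

print("training covers time_idx 0 ..", train_steps[-1])
print("validation encoder:", encoder_steps[0], "..", encoder_steps[-1])
print("validation decoder:", decoder_steps[0], "..", decoder_steps[-1])
```

The forecast horizon of the validation window lies entirely after the training cutoff, so the model is scored on values it never saw as targets.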

✅ 4. Build the DataLoaders

from torch.utils.data import DataLoader

batch_size = 64

train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=0)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size * 10, num_workers=0)

✅ 5. Define the TFT model

from pytorch_forecasting import TemporalFusionTransformer
from pytorch_forecasting.metrics import RMSE

tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=0.001,
    hidden_size=16,
    attention_head_size=1,
    dropout=0.1,
    loss=RMSE(),  # for regression, use RMSE, MAE, QuantileLoss, etc.
    log_interval=10,
    reduce_on_plateau_patience=4,
)

✅ 6. Train the model

from pytorch_lightning import Trainer

trainer = Trainer(
    max_epochs=20,
    gradient_clip_val=0.1,
    enable_model_summary=True,
    accelerator="auto",
)

trainer.fit(
    tft,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)

✅ 7. Predict and evaluate

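import torch

# Load the best checkpoint found during training
best_model = TemporalFusionTransformer.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)

# Collect the actual values and predict on the validation set
actuals = torch.cat([y[0] for x, y in iter(val_dataloader)])
predictions = best_model.predict(val_dataloader)

# Evaluate: mean absolute error on the validation set
print((actuals - predictions).abs().mean())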
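
A single error number is easier to interpret next to a naive reference. The sketch below uses NumPy with made-up arrays in place of the real actuals and predictions, and compares the model's MAE against a forecast that simply repeats the last observed value. The array values and the `last_observed` name are assumptions for illustration.

```python
import numpy as np

# Stand-ins for the tensors produced above: shape (n_series, horizon).
actuals = np.array([[10.0, 11.0, 12.0], [20.0, 19.0, 18.0]])
predictions = np.array([[10.5, 11.5, 11.0], [19.0, 19.0, 19.0]])

# Last value seen in each encoder window (hypothetical numbers).
last_observed = np.array([9.0, 21.0])

# Naive forecast: repeat the last observed value over the whole horizon.
naive = np.repeat(last_observed[:, None], actuals.shape[1], axis=1)

model_mae = np.abs(actuals - predictions).mean()
naive_mae = np.abs(actuals - naive).mean()

print(f"model MAE: {model_mae:.3f}, naive MAE: {naive_mae:.3f}")
```

If the model does not beat this kind of baseline on real data, the features or hyperparameters probably need another look before the forecasts are trusted.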

✅ 8. Visualize the results

raw_predictions, x = best_model.predict(val_dataloader, mode="raw", return_x=True)
best_model.plot_prediction(x, raw_predictions, idx=0, show_future_observed=True)

✅ 9. Key parameters (optional)

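| Parameter | Meaning |
|---|---|
| max_encoder_length | Length of the input (history) sequence |
| max_prediction_length | Number of steps to forecast |
| GroupNormalizer | Normalizes the target separately for each group |
| hidden_size | Hidden size of the network (including the LSTM layers) |
| attention_head_size | Number of attention heads |
| learning_rate | Learning rate |
| loss | Loss function (RMSE / MAE / QuantileLoss) |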
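
Since the choice of loss determines what the model is optimized for, a small NumPy sketch of the three options may help. These are textbook definitions written out by hand, not the library's implementations, and the input values are made up.

```python
import numpy as np

def rmse(y_true, y_pred):
    """Root mean squared error: penalizes large errors more heavily."""
    return float(np.sqrt(np.mean((y_true - y_pred) ** 2)))

def mae(y_true, y_pred):
    """Mean absolute error: every unit of error counts the same."""
    return float(np.mean(np.abs(y_true - y_pred)))

def pinball(y_true, y_pred, q):
    """Quantile (pinball) loss for a quantile q in (0, 1)."""
    diff = y_true - y_pred
    return float(np.mean(np.maximum(q * diff, (q - 1) * diff)))

y_true = np.array([1.0, 2.0, 3.0, 4.0])
y_pred = np.array([1.0, 2.0, 3.0, 8.0])  # one large over-prediction

print("RMSE:", rmse(y_true, y_pred))
print("MAE:", mae(y_true, y_pred))
print("pinball q=0.5:", pinball(y_true, y_pred, 0.5))
print("pinball q=0.9:", pinball(y_true, y_pred, 0.9))
```

On this toy input the single large miss weighs more in RMSE than in MAE, and the pinball loss at a high quantile penalizes over-prediction less than the median does. That asymmetry is what lets a quantile loss produce prediction intervals rather than a single point forecast.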

✅ Example prediction plot

import matplotlib.pyplot as plt
raw_predictions, x = best_model.predict(val_dataloader, mode="raw", return_x=True)

# Plot the prediction for the first sample (idx=0)
best_model.plot_prediction(x, raw_predictions, idx=0)
plt.show()


A complete, runnable TFT regression forecasting example:


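✅ Data fields (as you provided them):

| Field | Meaning | Type |
|---|---|---|
| _time | Timestamp | Time or index |
| OPEN | Opening price | Continuous |
| HIGH | Highest price | Continuous |
| LOW | Lowest price | Continuous |
| CLOSE | Closing price (the target) | Continuous |
| VOL | Trading volume | Continuous |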

✅ 📄 Complete runnable example (tft_close_predict.py)

import pandas as pd
import numpy as np
import pytorch_lightning as pl
import matplotlib.pyplot as plt
from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer, GroupNormalizer
from pytorch_forecasting.metrics import RMSE

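# 1. Load the data (assumed to be stored in a CSV file)
df = pd.read_csv("your_data.csv")

# 2. Preprocess: sort by time and build a consecutive integer time index
df["_time"] = pd.to_datetime(df["_time"])
df = df.sort_values("_time").reset_index(drop=True)
df["time_idx"] = np.arange(len(df))
df["group"] = "asset_1"  # single-asset task, so add a constant group ID column

# 3. The target is the next day's closing price
df["target"] = df["CLOSE"].shift(-1)
# Drop rows with missing values (including the last row, which has no next-day close)
df = df.dropna().reset_index(drop=True)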

# 4. Set the window lengths
max_encoder_length = 30
max_prediction_length = 1

# 5. Build the training set
training_cutoff = df["time_idx"].max() - max_prediction_length
training = TimeSeriesDataSet(
    df[df.time_idx <= training_cutoff],
    time_idx="time_idx",
    target="target",
    group_ids=["group"],
    max_encoder_length=max_encoder_length,
    max_prediction_length=max_prediction_length,
    time_varying_known_reals=["time_idx"],
    time_varying_unknown_reals=["OPEN", "HIGH", "LOW", "CLOSE", "VOL", "target"],
    target_normalizer=GroupNormalizer(groups=["group"]),
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
)

# 6. Build the validation set (predict=True uses the last window of each series)
validation = TimeSeriesDataSet.from_dataset(training, df, predict=True, stop_randomization=True)

# 7. Build the dataloaders
batch_size = 64
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=0)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size * 10, num_workers=0)

# 8. Build the model
tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=0.001,
    hidden_size=16,
    attention_head_size=1,
    dropout=0.1,
    loss=RMSE(),
    log_interval=10,
    reduce_on_plateau_patience=4,
)

# 9. Train the model
trainer = pl.Trainer(
    max_epochs=20,
    accelerator="auto",
    gradient_clip_val=0.1,
)

trainer.fit(tft, train_dataloader, val_dataloader)

# 10. Predict with the best checkpoint and plot the result
best_model = TemporalFusionTransformer.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)

raw_predictions, x = best_model.predict(val_dataloader, mode="raw", return_x=True)
best_model.plot_prediction(x, raw_predictions, idx=0)
plt.show()

✅ Example data file format (your_data.csv)

_time,OPEN,HIGH,LOW,CLOSE,VOL
2023-01-01,100,105,98,102,1000
2023-01-02,102,107,101,106,1200
2023-01-03,106,110,104,108,900
...
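
Before training, it can be worth checking that a file in this format loads the way the script expects. The sketch below reads the three sample rows from an in-memory string instead of a real file and applies the same time_idx construction as the script. The `required` column check and the HIGH/LOW sanity check are added assumptions, not part of the original code.

```python
import io

import numpy as np
import pandas as pd

# The sample rows from above, held in memory instead of your_data.csv.
csv_text = """_time,OPEN,HIGH,LOW,CLOSE,VOL
2023-01-01,100,105,98,102,1000
2023-01-02,102,107,101,106,1200
2023-01-03,106,110,104,108,900
"""

df = pd.read_csv(io.StringIO(csv_text), parse_dates=["_time"])

# Fail early if an expected column is missing.
required = ["_time", "OPEN", "HIGH", "LOW", "CLOSE", "VOL"]
missing = [col for col in required if col not in df.columns]
if missing:
    raise ValueError(f"missing columns: {missing}")

# Sort chronologically and assign a consecutive integer index.
df = df.sort_values("_time").reset_index(drop=True)
df["time_idx"] = np.arange(len(df))

# Basic sanity check: the high should never be below the low.
if not (df["HIGH"] >= df["LOW"]).all():
    raise ValueError("found rows where HIGH < LOW")

print(df)
```

To run the same checks against a real file, replace `io.StringIO(csv_text)` with the file path.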

🔍 Optional improvements

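  • If the data contains multiple instruments, use a column such as symbol or code as the group_id.
  • If you have information that is known in advance (e.g. holidays, day of the week, macroeconomic indicators), add it to time_varying_known_reals.
  • To forecast the closing price several days ahead, change max_prediction_length=1 to the number of days you want.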
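
For the calendar-style known features mentioned above, one possible approach is to derive them from the timestamp column, since they can be computed for any future date. The sketch assumes a `_time` column as in the example data. The feature names `day_of_week` and `month` are hypothetical.

```python
import pandas as pd

# Toy frame with a timestamp column, as in the example CSV.
df = pd.DataFrame({"_time": pd.to_datetime(["2023-01-01", "2023-01-02", "2023-01-03"])})

# Calendar features are known in advance for any date.
# They are cast to float because they are passed as real-valued features.
df["day_of_week"] = df["_time"].dt.dayofweek.astype("float64")  # Monday=0 ... Sunday=6
df["month"] = df["_time"].dt.month.astype("float64")

# These names could then be listed alongside "time_idx", e.g.
# time_varying_known_reals=["time_idx", "day_of_week", "month"].
known_reals = ["time_idx", "day_of_week", "month"]
print(df)
```

Any feature added this way must also be available for the forecast horizon at prediction time, which is why calendar fields are a natural fit.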