TFT Model
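🔥 A concise TFT usage guide covering the complete workflow (data preparation → model definition → training → prediction → visualization), suited to regression forecasting tasks.
🧠 Temporal Fusion Transformer: A Concise Usage Guide (pytorch-forecasting)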
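✅ 1. Install dependencies
```bash
pip install pytorch-forecasting
```
Recent pytorch-forecasting releases (1.x) are built on the `lightning` package and import it as `lightning.pytorch`. Do not pin an old standalone `pytorch-lightning` (such as 1.6.0) alongside them. A Trainer from that package rejects the model because the model is not one of its own LightningModules. Make sure PyTorch is installed. Python 3.9 or 3.10 is a reasonable choice, but check the requirements of the release you install.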
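✅ 2. Data format requirements
The data is a pandas DataFrame in long format (one row per group and time step) that contains at least the following columns: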
group_id | time_idx | target | known_features... | unknown_features... |
---|---|---|---|---|
A | 0 | 1.2 | ... | ... |
A | 1 | 1.4 | ... | ... |
B | 0 | 3.3 | ... | ... |
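Core fields:
Field | Meaning |
---|---|
group_id | Series (group) ID, e.g. a stock ticker or a product SKU |
time_idx | Time index: an integer that increases by 1 per time step (gaps are only allowed with `allow_missing_timesteps=True`) |
target | Target value (the variable you want to predict) |
known_reals | Numeric features whose future values are known in advance (passed via `time_varying_known_reals`) |
unknown_reals | Features whose future values are unknown, e.g. sales or price (passed via `time_varying_unknown_reals`) |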
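Here is a minimal pandas sketch of getting raw data into this shape. It assumes the raw table has a timestamp column named `date` (an illustrative name) and holds made-up values. Each distinct timestamp is mapped to a consecutive integer that is shared by all groups. This keeps `time_idx` comparable across series, which matters because the training cutoff in section 3 is computed from the global maximum `time_idx`:

```python
import pandas as pd

# Hypothetical raw data: two series ("A", "B"), rows not yet sorted
raw = pd.DataFrame({
    "group_id": ["A", "A", "B", "A", "B"],
    "date": pd.to_datetime(
        ["2023-01-02", "2023-01-01", "2023-01-01", "2023-01-03", "2023-01-02"]
    ),
    "target": [1.4, 1.2, 3.3, 1.5, 3.1],
})

# Sort by group and time so each series is in chronological order
data = raw.sort_values(["group_id", "date"]).reset_index(drop=True)

# Map every distinct timestamp to a consecutive integer shared by all groups
data["time_idx"] = data["date"].rank(method="dense").astype(int) - 1

print(data[["group_id", "date", "time_idx", "target"]])
```

Dates that no group has (such as weekends in trading data) are skipped entirely, so the index stays consecutive. If one group lacks a date that other groups have, that group has a gap, and `TimeSeriesDataSet` then needs `allow_missing_timesteps=True`.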
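✅ 3. Build the TimeSeriesDataSet
```python
from pytorch_forecasting import TimeSeriesDataSet, GroupNormalizer

# `data` is the long-format DataFrame described in section 2
max_encoder_length = 30     # use the past 30 steps as input
max_prediction_length = 7   # forecast the next 7 steps

# hold out the last max_prediction_length time steps for validation
training_cutoff = data["time_idx"].max() - max_prediction_length

training = TimeSeriesDataSet(
    data[data.time_idx <= training_cutoff],
    time_idx="time_idx",
    target="target",
    group_ids=["group_id"],
    max_encoder_length=max_encoder_length,
    max_prediction_length=max_prediction_length,
    time_varying_known_reals=["time_idx", "known_feature1", "known_feature2"],
    time_varying_unknown_reals=["target", "unknown_feature1"],
    # per-group target normalization: recommended when series have different scales
    # (not strictly required; the default target_normalizer is chosen automatically)
    target_normalizer=GroupNormalizer(groups=["group_id"]),
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
)
```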
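Then build the validation set:
```python
# predict=True: one sample per series, whose prediction window is the last
# max_prediction_length steps; stop_randomization=True makes it deterministic
validation = TimeSeriesDataSet.from_dataset(training, data, predict=True, stop_randomization=True)
```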
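The next sketch shows which time steps end up in training and which end up in validation. It is a simplified pure-Python model of the cutoff logic for one series with `time_idx` 0..n-1. It assumes a fixed (not randomized) encoder length and one training sample per possible window start. The function names are hypothetical, and this is not the library's sampling code, since `TimeSeriesDataSet` also supports variable encoder lengths:

```python
def training_windows(n_steps, max_encoder_length, max_prediction_length):
    """Return (encoder_steps, decoder_steps) for every training sample of one series.

    The series has time_idx 0..n_steps-1; the last max_prediction_length steps
    are held out, mirroring training_cutoff in section 3.
    """
    training_cutoff = (n_steps - 1) - max_prediction_length
    window = max_encoder_length + max_prediction_length
    samples = []
    # a sample starting at `start` is allowed if its last decoder step is <= cutoff
    for start in range(training_cutoff - window + 2):
        encoder = range(start, start + max_encoder_length)
        decoder = range(start + max_encoder_length, start + window)
        samples.append((encoder, decoder))
    return samples


def predict_window(n_steps, max_encoder_length, max_prediction_length):
    """predict=True: a single sample whose decoder covers the last steps."""
    decoder = range(n_steps - max_prediction_length, n_steps)
    encoder = range(decoder.start - max_encoder_length, decoder.start)
    return encoder, decoder


samples = training_windows(n_steps=10, max_encoder_length=3, max_prediction_length=2)
enc, dec = predict_window(n_steps=10, max_encoder_length=3, max_prediction_length=2)
print(len(samples), list(samples[-1][1]), list(enc), list(dec))
# prints: 4 [6, 7] [5, 6, 7] [8, 9]
```

The validation encoder overlaps the training period, but its decoder steps (8 and 9) never appear as training targets.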
✅ 4. Build the DataLoaders
```python
batch_size = 64
# to_dataloader returns a standard PyTorch DataLoader (shuffled for training);
# validation can use larger batches because no gradients are computed
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=0)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size * 10, num_workers=0)
```
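✅ 5. Define the TFT model
```python
from pytorch_forecasting import TemporalFusionTransformer
from pytorch_forecasting.metrics import RMSE

tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=0.001,
    hidden_size=16,
    attention_head_size=1,
    dropout=0.1,
    loss=RMSE(),  # for regression: RMSE, MAE, or QuantileLoss (the default)
    log_interval=10,
    reduce_on_plateau_patience=4,  # lower the learning rate when val loss stops improving
)
```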
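The `loss` argument decides what the network learns to output. RMSE and MAE train a single point forecast, while QuantileLoss trains one output per quantile and so yields prediction intervals. Below is a hand-written sketch of the three formulas on made-up numbers. It is not the library's implementation, which also handles batching, masking and normalization and may scale the quantile loss differently:

```python
import math


def rmse(y_true, y_pred):
    """Root mean squared error: penalizes large errors more heavily."""
    return math.sqrt(sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true))


def mae(y_true, y_pred):
    """Mean absolute error: every unit of error costs the same."""
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


def quantile_loss(y_true, y_pred, q):
    """Pinball loss for a single quantile q in (0, 1)."""
    total = 0.0
    for t, p in zip(y_true, y_pred):
        err = t - p
        # under-forecasts (err > 0) cost q per unit, over-forecasts cost (1 - q)
        total += max(q * err, (q - 1) * err)
    return total / len(y_true)


y_true = [10.0, 12.0, 11.0]
y_pred = [11.0, 11.0, 13.0]
print(rmse(y_true, y_pred))                # ~1.414 (sqrt(2))
print(mae(y_true, y_pred))                 # ~1.333
print(quantile_loss(y_true, y_pred, 0.5))  # ~0.667, half the MAE
print(quantile_loss(y_true, y_pred, 0.9))  # ~0.4
```

At q = 0.9 an under-forecast costs 9 times as much as an over-forecast of the same size. This pushes that output towards the 90th percentile of the target.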
✅ 6. Train the model
```python
from lightning.pytorch import Trainer

trainer = Trainer(
    max_epochs=20,
    gradient_clip_val=0.1,      # clip gradients to stabilize training
    enable_model_summary=True,
    accelerator="auto",         # use a GPU if available, otherwise the CPU
)
trainer.fit(
    tft,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)
```
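During training, `reduce_on_plateau_patience=4` from the model definition comes into play. In pytorch-forecasting this parameter configures a reduce-on-plateau learning-rate schedule that watches the validation loss. The sketch below mimics the rule of PyTorch's `torch.optim.lr_scheduler.ReduceLROnPlateau` in simplified form, without the threshold or cooldown options. The function name and the `factor` and `min_lr` values are illustrative assumptions, not the library's settings:

```python
def plateau_schedule(val_losses, lr, patience, factor=0.5, min_lr=1e-5):
    """Learning rate used after each epoch under a simplified plateau rule.

    Once more than `patience` consecutive epochs pass without a new best
    validation loss, the learning rate is multiplied by `factor`.
    """
    best = float("inf")
    bad_epochs = 0
    lrs = []
    for loss in val_losses:
        if loss < best:
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
            if bad_epochs > patience:
                lr = max(lr * factor, min_lr)
                bad_epochs = 0
        lrs.append(lr)
    return lrs


# Made-up validation losses: improvement stops after epoch 1
losses = [1.0, 0.9, 0.95, 0.96, 0.97, 0.98, 0.99]
print(plateau_schedule(losses, lr=0.001, patience=4))
```

With `patience=4`, the rate is first lowered on the fifth epoch in a row without improvement (the last epoch here).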
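✅ 7. Prediction and evaluation
```python
from pytorch_forecasting.metrics import MAE

# Load the checkpoint tracked by the trainer. With the default checkpoint callback
# (no monitored metric) this is the last epoch; pass ModelCheckpoint(monitor="val_loss")
# in the Trainer's callbacks to keep the epoch with the lowest validation loss instead.
best_model = TemporalFusionTransformer.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)

# Predict on the validation set; return_y=True also returns the actual values
predictions = best_model.predict(val_dataloader, return_y=True)
print(MAE()(predictions.output, predictions.y))  # mean absolute error
```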
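An error number means more when it is compared with a trivial forecast. pytorch-forecasting ships a `Baseline` model that repeats the last known target value. The same idea is shown below in plain Python on made-up numbers. The `model_forecast` list is a hypothetical stand-in for TFT output, not a real result:

```python
def naive_last_value_forecast(history, horizon):
    """Repeat the last observed value for every future step."""
    return [history[-1]] * horizon


def mean_absolute_error(actual, forecast):
    return sum(abs(a - f) for a, f in zip(actual, forecast)) / len(actual)


history = [100.0, 102.0, 106.0, 108.0]   # encoder window (made up)
future = [110.0, 109.0, 111.0]           # true future values (made up)
model_forecast = [109.0, 110.0, 110.5]   # hypothetical model output

baseline = naive_last_value_forecast(history, len(future))
baseline_mae = mean_absolute_error(future, baseline)
model_mae = mean_absolute_error(future, model_forecast)
print(baseline_mae, model_mae)  # a useful model should have the lower MAE
```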
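✅ 8. Visualize the results
```python
import matplotlib.pyplot as plt

# mode="raw" returns the full network output; return_x=True also returns the model inputs
raw = best_model.predict(val_dataloader, mode="raw", return_x=True)
best_model.plot_prediction(raw.x, raw.output, idx=0, show_future_observed=True)
plt.show()
```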
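✅ 9. Advanced parameters (optional)
Parameter | Meaning |
---|---|
max_encoder_length | Length of the input (history) window |
max_prediction_length | Number of steps to forecast |
GroupNormalizer | Target normalizer that scales each group separately (passed as `target_normalizer`) |
hidden_size | Main hidden size of the network (used by the LSTM, gating and attention layers) |
attention_head_size | Number of attention heads |
learning_rate | Learning rate |
loss | Loss function (RMSE / MAE / QuantileLoss; QuantileLoss is the default) |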
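`GroupNormalizer` scales the target separately for each group. As a result, a series trading around 100 and one trading around 10,000 look alike to the network. The following is a simplified sketch of per-group standard scaling (subtract the group mean, divide by the group standard deviation) and its inverse. The function names are hypothetical. The real class offers further methods and transformations and is fitted on the training data. The per-group center and scale are also what `add_target_scales=True` feeds to the model as extra inputs:

```python
from statistics import mean, pstdev


def fit_group_scales(rows):
    """rows: list of (group, value). Returns {group: (center, scale)}."""
    by_group = {}
    for group, value in rows:
        by_group.setdefault(group, []).append(value)
    # fall back to a scale of 1.0 for constant series to avoid dividing by zero
    return {g: (mean(vs), pstdev(vs) or 1.0) for g, vs in by_group.items()}


def normalize(rows, scales):
    return [(g, (v - scales[g][0]) / scales[g][1]) for g, v in rows]


def denormalize(rows, scales):
    return [(g, v * scales[g][1] + scales[g][0]) for g, v in rows]


rows = [("A", 100.0), ("A", 102.0), ("A", 104.0),
        ("B", 10000.0), ("B", 10200.0), ("B", 10400.0)]
scales = fit_group_scales(rows)
normed = normalize(rows, scales)
print(normed)  # both groups map to roughly -1.22, 0.0, 1.22
```

Predictions are made on the normalized scale and mapped back with the inverse transform (`denormalize` here).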
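✅ Example prediction plot
```python
import matplotlib.pyplot as plt

raw_predictions = best_model.predict(val_dataloader, mode="raw", return_x=True)
# plot the prediction for the first sample (first group) of the validation set
best_model.plot_prediction(raw_predictions.x, raw_predictions.output, idx=0)
plt.show()
```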
A complete, runnable TFT regression forecasting example:
✅ Data fields (as provided):
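Field | Meaning | Type |
---|---|---|
_time | Timestamp | datetime or index |
OPEN | Opening price | continuous |
HIGH | High price | continuous |
LOW | Low price | continuous |
CLOSE | Closing price (target) | continuous |
VOL | Trading volume | continuous |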
✅ 📄 Complete runnable code example (tft_close_predict.py)
```python
import pandas as pd
import numpy as np
import lightning.pytorch as pl
import matplotlib.pyplot as plt
from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer, GroupNormalizer
from pytorch_forecasting.metrics import RMSE

# 1. Load the data (assumed to be stored in a CSV file)
df = pd.read_csv("your_data.csv")

# 2. Preprocess: sort by time, drop incomplete rows, build an integer time index
df["_time"] = pd.to_datetime(df["_time"])
df = df.sort_values("_time").dropna().reset_index(drop=True)
df["time_idx"] = np.arange(len(df))  # consecutive rows (e.g. trading days) -> step of 1
df["group"] = "asset_1"  # single-asset task: add a constant group_id column

# 3. The target is CLOSE itself. TimeSeriesDataSet always predicts the steps that follow
#    the encoder window, so max_prediction_length=1 means "next day's close" and no
#    manual shift(-1) of the target is needed.
max_encoder_length = 30
max_prediction_length = 1

# 4. Training set: hold out the last max_prediction_length steps
training_cutoff = df["time_idx"].max() - max_prediction_length
training = TimeSeriesDataSet(
    df[df.time_idx <= training_cutoff],
    time_idx="time_idx",
    target="CLOSE",
    group_ids=["group"],
    max_encoder_length=max_encoder_length,
    max_prediction_length=max_prediction_length,
    time_varying_known_reals=["time_idx"],
    time_varying_unknown_reals=["OPEN", "HIGH", "LOW", "CLOSE", "VOL"],
    target_normalizer=GroupNormalizer(groups=["group"]),
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
)

# 5. Validation set: the last prediction window of the series
validation = TimeSeriesDataSet.from_dataset(training, df, predict=True, stop_randomization=True)

# 6. DataLoaders
batch_size = 64
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=0)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size * 10, num_workers=0)

# 7. Model
tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=0.001,
    hidden_size=16,
    attention_head_size=1,
    dropout=0.1,
    loss=RMSE(),
    log_interval=10,
    reduce_on_plateau_patience=4,
)

# 8. Train
trainer = pl.Trainer(
    max_epochs=20,
    accelerator="auto",
    gradient_clip_val=0.1,
)
trainer.fit(tft, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)

# 9. Predict and plot (predict returns an object with .output and .x)
best_model = TemporalFusionTransformer.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)
raw_predictions = best_model.predict(val_dataloader, mode="raw", return_x=True)
best_model.plot_prediction(raw_predictions.x, raw_predictions.output, idx=0)
plt.show()
```
✅ Example data file format (your_data.csv)
```csv
_time,OPEN,HIGH,LOW,CLOSE,VOL
2023-01-01,100,105,98,102,1000
2023-01-02,102,107,101,106,1200
2023-01-03,106,110,104,108,900
...
```
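Before training, it can be worth checking that a file matches this layout. Here is a minimal sketch that uses Python's standard `csv` module. It assumes the column names above and ISO dates (`YYYY-MM-DD`), and the function name is hypothetical. It checks that the required columns exist, that the values are numeric, that dates strictly increase, and that OPEN and CLOSE lie within LOW..HIGH:

```python
import csv
import io
from datetime import date

REQUIRED = ["_time", "OPEN", "HIGH", "LOW", "CLOSE", "VOL"]


def validate_ohlcv(text):
    """Parse OHLCV CSV text into a list of dicts, raising ValueError on problems."""
    reader = csv.DictReader(io.StringIO(text))
    missing = [c for c in REQUIRED if c not in (reader.fieldnames or [])]
    if missing:
        raise ValueError(f"missing columns: {missing}")
    rows, prev = [], None
    for line_no, rec in enumerate(reader, start=2):  # line 1 is the header
        day = date.fromisoformat(rec["_time"])
        if prev is not None and day <= prev:
            raise ValueError(f"line {line_no}: dates must be strictly increasing")
        values = {c: float(rec[c]) for c in REQUIRED[1:]}  # raises on non-numeric text
        lo = min(values["OPEN"], values["CLOSE"])
        hi = max(values["OPEN"], values["CLOSE"])
        if not (values["LOW"] <= lo and hi <= values["HIGH"]):
            raise ValueError(f"line {line_no}: OPEN/CLOSE outside the LOW..HIGH range")
        rows.append({"_time": day, **values})
        prev = day
    return rows


sample = """_time,OPEN,HIGH,LOW,CLOSE,VOL
2023-01-01,100,105,98,102,1000
2023-01-02,102,107,101,106,1200
2023-01-03,106,110,104,108,900
"""
rows = validate_ohlcv(sample)
print(len(rows), rows[-1]["CLOSE"])
```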
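🔍 Optional improvements
- If the data contains several instruments, use a column such as `symbol` or `code` as the `group_id`.
- If you have information that is known in advance (holidays, day of week, macroeconomic indicators), add it to `time_varying_known_reals`. Only include features whose future values really are available at prediction time.
- To predict the closing price several days ahead, change `max_prediction_length=1` to the number of days you want.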
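Calendar features are the easiest known-in-advance inputs, because they can be computed for any future date. Here is a pandas sketch that assumes the `_time` column from the example data. The column names `day_of_week` and `month` are illustrative. They would then be listed in `time_varying_known_reals`:

```python
import pandas as pd

# Dates from the example data file
df = pd.DataFrame({"_time": pd.to_datetime(["2023-01-01", "2023-01-02", "2023-01-03"])})

# Known-in-advance calendar features
df["day_of_week"] = df["_time"].dt.dayofweek  # Monday=0 ... Sunday=6
df["month"] = df["_time"].dt.month

# Rows for the prediction window can be filled the same way;
# freq="B" (business days) skips weekends but not exchange holidays
future = pd.DataFrame({"_time": pd.date_range("2023-01-04", periods=2, freq="B")})
future["day_of_week"] = future["_time"].dt.dayofweek
future["month"] = future["_time"].dt.month
print(df, future, sep="\n")
```

Exchange holidays still need a proper trading calendar. Note also that the first sample date (2023-01-01) is a Sunday, so real trading data would usually not contain it.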