# coding=utf-8
# ======================================
# File: qt_operator.py
# Author: Jackie PENG
# Contact: jackie.pengzhao@gmail.com
# Created: 2020-02-21
# Desc:
# Operator Class definition.
# ======================================
import logging
import os
import numpy as np
import pandas as pd
from typing import Generator, Optional, Union, Any, Iterable, Mapping
import qteasy
from qteasy.strategy import BaseStrategy, RuleIterator
from qteasy.group import Group
from qteasy.parameter import Parameter
from qteasy.datatypes import DataType
from qteasy.history import (
check_and_prepare_trade_prices,
check_and_prepare_benchmark_data,
check_and_prepare_backtest_data,
)
from qteasy.utilfuncs import (
TIME_FREQ_STRINGS,
AVAILABLE_OP_TYPES,
str_to_list,
rolling_window,
SlideView,
sanitize_filename,
)
from qteasy.built_in import (
available_built_in_strategies,
BUILT_IN_STRATEGIES,
get_built_in_strategy,
)
# Integer code assigned to each signal-type key ('pt', 'ps', 'vs').
SIGNAL_TYPE_ID = {'pt': 0, 'ps': 1, 'vs': 2}
# Asset-type codes currently verified to run in simulated live trading (live_trade);
# the codes follow the same convention as DataType.
_LIVE_TRADE_SUPPORTED_ASSET_TYPES = frozenset({'E', 'FD'})
[文档]class Operator:
"""qteasy 中的核心「交易员」对象,用于承载策略组并在统一时间表上生成交易信号。
Operator 负责管理一个或多个策略组,准备并缓存运行所需的历史数据,按照运行频率
与时机调用各策略生成信号,并按组内/组间合并规则将信号合成为最终的交易指令,
可用于回测、优化和实盘运行。关于 Operator 与 Strategy / Group 的关系、run_freq /
run_timing 单一来源等架构细节,详见文档「Operator / Strategy / Group 架构」章节。
"""
def __init__(self,
             strategies: Union[str, BaseStrategy, type, list[Union[str, BaseStrategy, type]]] = None,
             *,
             name: str = None,
             signal_type: str = 'pt',
             op_type: str = 'batch',
             group_merge_type: str = 'None',
             run_freq: str = 'd',
             run_timing: str = 'close',
             ) -> None:
    """Create an Operator and initialise its groups, buffers and signal settings.

    Parameters
    ----------
    strategies : str or BaseStrategy or type or list of these, optional
        Strategies to add. A string is split with ``str_to_list``; a single
        strategy instance or class is wrapped in a list; a list is used as
        given. Any other value (including None) results in an Operator
        without strategies.
    name : str, optional
        Name of the Operator; it is shown by ``repr()``.
    signal_type : {'pt', 'ps', 'vs'} or None, default 'pt'
        Signal mode written to every group created during construction.
        None is treated as 'pt'.
    op_type : {'batch', 'stepwise'} or None, default 'batch', deprecated
        Running type, kept for backward compatibility. None is treated as
        'batch'.
    group_merge_type : {'None', 'and', 'or'}, default 'None'
        How signals of several groups are merged (case-insensitive).
    run_freq : str, default 'd'
        Run frequency used for the strategies added here. The part before
        an optional '-' must be one of ``TIME_FREQ_STRINGS``.
    run_timing : str, default 'close'
        Run timing used for the strategies added here.

    Raises
    ------
    TypeError
        If ``group_merge_type`` or ``run_freq`` is not a string.
    ValueError
        If ``group_merge_type`` or ``run_freq`` holds an unsupported value.

    Examples
    --------
    >>> import qteasy as qt
    >>> op = qt.Operator('dma, macd')
    >>> op.strategy_ids
    ['dma', 'macd']
    >>> op2 = qt.Operator(['dma', 'macd'], signal_type='pt', run_freq='d', run_timing='close')
    >>> isinstance(op2, qt.Operator)
    True
    """
    # global signal row counters used for tracing and for the process-data API
    self._trace_signal_index = 0
    self._current_signal_index: int = 0
    self.debug = False  # in debug mode the Operator always reports itself as ready
    self.name = name

    # normalise the ``strategies`` argument into a list
    if isinstance(strategies, str):
        initial_strategies = str_to_list(strategies)
    elif isinstance(strategies, (BaseStrategy, type)):
        initial_strategies = [strategies]
    elif isinstance(strategies, list):
        initial_strategies = strategies
    else:
        initial_strategies = []

    # fall back to defaults for settings explicitly given as None
    signal_type = 'pt' if signal_type is None else signal_type
    op_type = 'batch' if op_type is None else op_type

    # early validation of group_merge_type (the property setter validates again)
    if group_merge_type:
        if not isinstance(group_merge_type, str):
            raise TypeError(f'group_merge_type should be a string, got {type(group_merge_type)} instead.')
        group_merge_type = group_merge_type.lower()
        if group_merge_type not in ['none', 'and', 'or']:
            raise ValueError(f'Invalid group_merge_type ({group_merge_type})')

    # validate run_freq: only the part before the first '-' is checked
    if not isinstance(run_freq, str):
        raise TypeError(f'run_freq should be a string, got {type(run_freq)} instead')
    main_freq = run_freq.split('-')[0]
    if main_freq not in TIME_FREQ_STRINGS:
        raise ValueError(f'run_freq should be one of {TIME_FREQ_STRINGS}, got {run_freq} instead')

    # ---- working data, maintained by the Operator itself ----
    self._op_type = ''
    self._next_stg_index = 0  # increasing strategy index
    # strategy groups: strategies sharing run_freq and run_timing live in one group
    self._groups = []
    self._group_merge_type = None
    self.group_timing_table = None  # running timetable of the groups
    self.group_merge_type = group_merge_type  # validated by the property setter
    self.group_schedules = {}  # per-group running schedules
    # history data caches and sliding-window caches
    self.data_buffers = {}
    self.dynamic_data_buffers = {}
    self.data_window_views = {}
    self.data_window_indices = {}
    # process-data sources and time index, read through get_data('proc.xxx', ...)
    self._process_data_sources: dict[str, Any] = {}
    self._process_time_index = None
    # signals generated in batch mode and their metadata
    self._op_signals = None
    self._op_signal_types = None
    self._op_signal_index = None
    self._op_signal_shares = {}  # {share: idx}

    # ---- key settings ----
    self.op_type = op_type  # deprecated, goes through the property setter
    self.add_strategies(initial_strategies, run_freq=run_freq, run_timing=run_timing)
    self._trace_enabled = False
    if signal_type:
        # apply the requested signal type to every group created above
        for stg_group in self._groups:
            stg_group.signal_type = signal_type
def __repr__(self):
    """Return ``Operator([<comma separated strategy ids>], name='<name>')``."""
    # strategy ids are only read when the operator actually holds strategies
    id_text = ', '.join(self.strategy_ids) if self.strategy_count > 0 else ''
    return f'Operator([{id_text}], name=\'{self.name}\')'
@property
def empty(self):
    """True when the operator holds no strategy at all."""
    return len(self.strategies) == 0
@property
def strategies(self):
    """All strategies of the operator as one flat list, in group order."""
    return [member for stg_group in self._groups for member in stg_group.members]
@property
def strategy_count(self):
    """Number of strategies held by the operator."""
    all_members = self.strategies
    return len(all_members)
@property
def strategy_ids(self):
    """IDs of all strategies, in the same order as ``strategies``."""
    ids = []
    for strategy in self.strategies:
        ids.append(strategy.strategy_id)
    return ids
@property
def op_type(self):  # deprecated
    """Running type of the operator as stored in ``_op_type`` (deprecated)."""
    current_type = self._op_type
    return current_type
@op_type.setter
def op_type(self, op_type):  # deprecated
    """Set the running type of the operator (deprecated).

    The value is lower-cased and must be contained in ``AVAILABLE_OP_TYPES``.
    The aliases 's', 'st', 'step' and 'stepwise' are stored as 'stepwise';
    every other accepted value is stored as 'batch'.

    Raises
    ------
    KeyError
        If ``op_type`` is not a string or is not an available op type.
    """
    if not isinstance(op_type, str):
        raise KeyError(f'op_type should be a string, got {type(op_type)} instead.')
    normalized = op_type.lower()
    if normalized not in AVAILABLE_OP_TYPES:
        raise KeyError(f'Invalid op_type ({normalized})')
    stepwise_aliases = ('s', 'st', 'step', 'stepwise')
    self._op_type = 'stepwise' if normalized in stepwise_aliases else 'batch'
@property
def op_data_type_ids(self):
    """De-duplicated IDs of all data types required by the strategies.

    The result is built from a set, so its order is not specified.
    """
    unique_ids = set()
    for strategy in self.strategies:
        unique_ids.update(strategy.data_type_ids)
    return list(unique_ids)
@property
def op_data_types(self):
    """De-duplicated data type objects required by the strategies.

    Collected from the values of each strategy's ``data_types`` mapping; the
    result is built from a set, so its order is not specified.
    """
    unique_types = set()
    for strategy in self.strategies:
        unique_types.update(strategy.data_types.values())
    return list(unique_types)
@property
def op_data_type_count(self):
    """Number of distinct data types required by the strategies."""
    distinct_types = self.op_data_types
    return len(distinct_types)
@property
def op_signal_index(self) -> Optional[pd.Index]:
    """Index of the signals generated by the operator.

    Returns
    -------
    pd.Index or None
        None while ``group_timing_table`` has not been created. Otherwise:

        - merge type 'None': the index of the stacked timing table restricted
          to entries greater than 0, i.e. one (row label, group column) pair
          per active entry;
        - any other merge type: the product of the timing table index with
          the single label 'merged'.
    """
    timing_table = self.group_timing_table
    if timing_table is None:
        # the timetable is not built yet, so there is no signal index to report
        return None
    if self.group_merge_type == 'None':
        # every active (time, group) pair becomes one signal row
        stacked_timing_table = timing_table.stack()
        return stacked_timing_table[stacked_timing_table > 0].index
    # groups are merged: add a second level holding the constant label 'merged'
    # TODO: the index should name the group id and only use 'merged' when
    #  several groups actually run at the same time
    return pd.MultiIndex.from_product(
        [timing_table.index, ['merged']],
    )
@property
def op_data_freq(self) -> dict:  # deprecated
    """Map each required data type id to its sampling frequency (deprecated).

    Returns
    -------
    dict
        ``{dtype_id: freq}`` for every item of ``all_strategy_data_types``.
    """
    freq_by_id = {}
    for data_type in self.all_strategy_data_types:
        freq_by_id[data_type.dtype_id] = data_type.freq
    return freq_by_id
@property
def group_merge_type(self):
    """Merge type used to combine the signals of the strategy groups."""
    merge_type = self._group_merge_type
    return merge_type
@group_merge_type.setter
def group_merge_type(self, group_merge_type):
    """Set the merge type of the strategy groups.

    None is stored as the string "None". A string is capitalised and must
    then be one of 'None', 'And' or 'Or'.

    Raises
    ------
    TypeError
        If the value is neither None nor a string.
    ValueError
        If the capitalised value is not a supported merge type.
    """
    if group_merge_type is None:
        self._group_merge_type = "None"
        return
    if not isinstance(group_merge_type, str):
        raise TypeError(f'group_merge_type should be a string, got {type(group_merge_type)} instead.')
    normalized = group_merge_type.capitalize()
    if normalized not in ('None', 'And', 'Or'):
        raise ValueError(f'Invalid group_merge_type ({normalized})')
    self._group_merge_type = normalized
@property
def strategy_groups(self):
    """All strategy groups as a dict keyed by group name."""
    return dict((stg_group.name, stg_group) for stg_group in self._groups)
@property
def groups(self):
    """All strategy groups as a dict keyed by group name."""
    by_name = {}
    for stg_group in self._groups:
        by_name[stg_group.name] = stg_group
    return by_name
@property
def groups_by_index(self):
    """All strategy groups as a dict keyed by their position (0-based)."""
    return dict(enumerate(self._groups))
@property
def group_ids(self):
    """Names of all strategy groups, in group order."""
    names = []
    for stg_group in self._groups:
        names.append(stg_group.name)
    return names
@property
def group_names(self):
    """Names of all strategy groups, in group order."""
    return list(stg_group.name for stg_group in self._groups)
@property
def all_strategy_data_types(self):
    """De-duplicated list of the data types in ``op_data_types``.

    The result is built from a set, so its order is not specified.
    """
    distinct_types = set(self.op_data_types)
    return [*distinct_types]
@property
def all_dynamic_dtypes(self):
    """Legacy collection of dynamic data types; deprecated, always empty.

    Process data is accessed through ``get_data('proc.xxx')`` instead of being
    declared as DataType. The property is kept only for callers that still
    read it and returns a new empty dict on every access.
    """
    return dict()
@property
def opt_space_par(self):
    """Parameter-space description of the strategies taking part in optimization.

    Each strategy is handled according to its ``opt_tag``:

    - 0: the strategy is skipped;
    - 1: every value of ``par_range`` / ``par_types`` becomes its own dimension;
    - 2: the whole ``par_range.values()`` view is appended as a single
      dimension of type 'enum'.

    Returns
    -------
    tuple[list, list]
        ``(ranges, types)``, one entry per optimized dimension.
    """
    dim_ranges, dim_types = [], []
    for strategy in self.strategies:
        tag = strategy.opt_tag
        if tag == 1:
            # each parameter is optimized individually
            dim_ranges.extend(strategy.par_range.values())
            dim_types.extend(strategy.par_types.values())
        elif tag == 2:
            # all parameters together form one enumerated dimension
            dim_ranges.append(strategy.par_range.values())
            dim_types.append('enum')
    return dim_ranges, dim_types
@property
def opt_tags(self):
    """Optimization tags of all strategies, in strategy order."""
    tags = []
    for strategy in self.strategies:
        tags.append(strategy.opt_tag)
    return tags
@property
def max_window_length(self):
    """Longest ``max_window_length`` among all strategies.

    Returns
    -------
    int
        0 when the operator holds no strategy, otherwise the maximum of the
        strategies' ``max_window_length`` values.
    """
    if self.strategy_count == 0:
        return 0
    lengths = [strategy.max_window_length for strategy in self.strategies]
    return max(lengths)
@property
def strategy_group_count(self):
    """Number of entries in ``strategy_groups``.

    Returns
    -------
    int
    """
    groups_by_name = self.strategy_groups
    return len(groups_by_name)
@property
def op_signal_types(self) -> list:
    """Signal type of every row of the generated signal list.

    For every entry of ``group_schedules`` a one-column frame is built that
    repeats the group's ``signal_type`` over the schedule index; the frames
    are concatenated column-wise.

    Returns
    -------
    list
        With merge type 'None', the stacked values of that table. Otherwise
        the unique signal type of each time point.

    Raises
    ------
    RuntimeError
        If groups are merged and some time point carries more than one
        distinct signal type.
    """
    per_group_frames = [
        pd.DataFrame(
            data=self.groups[group_name].signal_type,
            index=schedule.index,
            columns=['signal_type'],
        )
        for group_name, schedule in self.group_schedules.items()
    ]
    signal_type_table = pd.concat(per_group_frames, axis=1)
    stacked_types = signal_type_table.stack()
    if self.group_merge_type == 'None':
        # groups run independently: the stacked table already lists one type per row
        return stacked_types.to_list()
    # groups are merged: merging groups with different signal types would be
    # ambiguous, so every time point must carry exactly one signal type
    unique_asset_types = stacked_types.groupby(level=0).unique()
    has_conflict = unique_asset_types.apply(len) > 1
    if any(has_conflict):
        conflict_rows = unique_asset_types[has_conflict].head()
        raise RuntimeError(f'You are trying to run multiple strategy groups that are generating '
                           f'different types of signals, which will result in ambiguous results:. \n'
                           f'{conflict_rows}\n'
                           f'Please make sure all groups are having same signal type settings or '
                           f'set operator.group_merge_type = "None"')
    return unique_asset_types.explode().to_list()
@property
def ready(self):
    """Property form of ``is_ready()`` called with its default arguments."""
    readiness = self.is_ready()
    return readiness
def is_ready(self,
             tell_me_why: bool = False,
             raise_error: bool = False, ) -> bool:
    """Check whether the Operator fulfils the minimum conditions to generate signals.

    The checks are: at least one strategy, a blender on every group,
    consistent share information on every strategy, a group timing table and
    group schedules. When the strategies require history data, data buffers,
    window views and non-negative window indices are checked as well.

    Parameters
    ----------
    tell_me_why : bool, default False
        Print the collected reasons when the Operator is not ready.
    raise_error : bool, default False
        Raise ``RuntimeError`` with the collected reasons when the Operator
        is not ready.

    Returns
    -------
    bool
        True if ready. Always True while ``self.debug`` is set.

    Raises
    ------
    RuntimeError
        If the Operator is not ready and ``raise_error`` is True.
    """
    if self.debug:
        return True

    message = ['Operator readiness: ']
    is_ready = True

    # the operator must contain strategies
    if self.strategy_count == 0:
        message.append('No strategy -- add strategies to Operator!\n')
        is_ready = False

    # every strategy group must have a blender
    group_no_blender = [g.name for g in self._groups if g.blender is None]
    if len(group_no_blender) > 0:
        message.append(f'No blender -- some of the strategy groups ({group_no_blender}) does not have blender '
                       f'set!\n')
        is_ready = False

    # every strategy needs share_count and share_names, and both must agree
    for stg in self.strategies:
        if (stg.share_count == 0) or (stg.share_names is None):
            message.append(f'Strategy ({stg.strategy_id}) share info not set -- '
                           f'share_count or share_names not set!\n')
            is_ready = False
        elif stg.share_count != len(stg.share_names):
            message.append(f'Strategy ({stg.strategy_id}) share info invalid -- share_count ({stg.share_count}) '
                           f'not equal to len(share_names) ({len(stg.share_names)})!\n')
            is_ready = False

    if self.all_strategy_data_types:
        # ``not x`` covers both None and empty containers without calling len(None)
        if not self.data_buffers:
            message.append('No data buffer -- data buffers are empty!\n')
            is_ready = False
        if not self.data_window_views:
            message.append('No data window -- data window views are not created!\n')
            is_ready = False
        if not self.data_window_indices:
            message.append('No data indices -- data window indices are not set!\n')
            is_ready = False
        else:
            # validate the window indices of every strategy
            for stg, data_window_indices in self.data_window_indices.items():
                if not isinstance(data_window_indices, Mapping):
                    message.append(f'Invalid data indices -- data window indices of strategy {stg} is not a dict!\n')
                    is_ready = False
                    # nothing to iterate over for this entry
                    continue
                for dtype, indices in data_window_indices.items():
                    if any(index < 0 for index in indices):
                        message.append(f'Invalid data indices for dtype({dtype}) of stg({stg}) -- '
                                       f'Some data window indices are negative! Normally this means the history '
                                       f'data is not enough to cover start date of backtest!\n')
                        is_ready = False

    # the running plan (timing table and schedules) must exist
    if self.group_timing_table is None:
        message.append('No group timing table -- group timing table is not created!\n')
        is_ready = False
    if self.group_schedules == {} or self.group_schedules is None:
        message.append('No group running schedule -- group schedules are not set!\n')
        is_ready = False

    message.insert(1, f'{"Ready" if is_ready else "Not Ready"}\n')
    report = ''.join(message)
    if (not is_ready) and tell_me_why:
        print(report)
    if (not is_ready) and raise_error:
        # raise with the readable report instead of the raw list of fragments
        raise RuntimeError(report)
    return is_ready
def reset(self):
    """Clear the running state of the Operator.

    Resets the debug flag, the strategy index counter, the data buffers, the
    window views and indices, and the generated signals with their metadata.

    Raises
    ------
    NotImplementedError
        Always raised after the attributes above have been cleared.
    """
    self.debug = False
    self._next_stg_index = 0
    # dict-valued state is replaced by fresh, independent empty dicts
    for attr_name in ('data_buffers', 'data_window_views', 'data_window_indices', '_op_signal_shares'):
        setattr(self, attr_name, {})
    # generated signals and their types are dropped
    for attr_name in ('_op_signals', '_op_signal_types'):
        setattr(self, attr_name, None)
    raise NotImplementedError
def __getitem__(self, item: Union[str, int]) -> BaseStrategy:
    """Return a strategy by its id or by its position.

    Parameters
    ----------
    item : int or str
        A string is looked up among the strategy ids; an integer is used as a
        0-based position in the order of ``strategy_ids``.

    Returns
    -------
    BaseStrategy
        The matching strategy.

    Raises
    ------
    TypeError
        If ``item`` is neither an int nor a str.
    KeyError
        If no strategy has the id ``item`` or the position is out of range
        (negative positions are out of range).

    Examples
    --------
    >>> op = Operator('dma, macd')
    >>> op[0]
    RULE-ITER(DMA)
    >>> op['dma']
    RULE-ITER(DMA)
    >>> op[999]
    Traceback (most recent call last):
        ...
    KeyError: 'Strategy index out of range: 999 out of total 2 strategies'

    See Also
    --------
    get_stg()
    get_strategy_by_id()
    """
    if not isinstance(item, (int, str)):
        raise TypeError(f'strategy id should be either an integer or a string, got {type(item)} instead!')

    by_id = dict(zip(self.strategy_ids, self.strategies))
    all_ids = list(by_id)

    if isinstance(item, str):
        if item in all_ids:
            return by_id[item]
        raise KeyError(f'No such strategy with ID ({item}) in {all_ids}!')

    # integer position
    strategy_count = self.strategy_count
    if 0 <= item <= strategy_count - 1:
        return by_id[all_ids[item]]
    raise KeyError(f'Strategy index out of range: {item} out of total {strategy_count} strategies')
def get_stg(self, stg_id):
    """Return a strategy; an alias of ``Operator[stg_id]``.

    Parameters
    ----------
    stg_id : int or str
        Id or position of the strategy, passed unchanged to item access.

    Returns
    -------
    Strategy
        Whatever ``self[stg_id]`` returns; errors raised by item access
        propagate unchanged.

    Examples
    --------
    >>> op = Operator('dma, macd')
    >>> op.get_stg(0)
    RULE-ITER(DMA)
    >>> op.get_stg('dma')
    RULE-ITER(DMA)

    See Also
    --------
    get_strategy_by_id()
    """
    strategy = self[stg_id]
    return strategy
def get_strategy_by_id(self, stg_id):
    """Return a strategy; an alias of ``Operator[stg_id]``.

    Parameters
    ----------
    stg_id : int or str
        Id or position of the strategy, passed unchanged to item access.

    Returns
    -------
    Strategy
        Whatever ``self[stg_id]`` returns.

    See Also
    --------
    get_stg()
    """
    strategy = self[stg_id]
    return strategy
def get_strategy_id_pairs(self):
    """Return an iterator over ``(strategy_id, strategy)`` pairs.

    Returns
    -------
    zip
        A lazy iterator pairing ``strategy_ids`` with ``strategies``.

    Examples
    --------
    >>> op = Operator('dma, macd')
    >>> list(op.get_strategy_id_pairs())
    [('dma', RULE-ITER(DMA)), ('macd', RULE-ITER(MACD))]
    """
    ids = self.strategy_ids
    members = self.strategies
    return zip(ids, members)
def get_group(self, group_idx: Union[str, int]) -> Group:
    """Return a strategy group; equivalent to ``get_group_by_id``.

    Parameters
    ----------
    group_idx : int or str
        Position or name of the group, forwarded as ``group_id``.

    Returns
    -------
    Group

    Examples
    --------
    >>> op = Operator('dma, macd')
    >>> op.get_group(0)
    Group(name=Group_1, members=[RULE-ITER(DMA), RULE-ITER(MACD)])
    >>> op.get_group('Group_1')
    Group(name=Group_1, members=[RULE-ITER(DMA), RULE-ITER(MACD)])
    """
    found_group = self.get_group_by_id(group_id=group_idx)
    return found_group
def get_group_by_id(self, group_id: Union[str, int]) -> Group:
    """Return a strategy group by its name or by its position.

    Parameters
    ----------
    group_id : str or int
        A string is used as the group name; an integer is used as a 0-based
        position in ``group_ids``.

    Returns
    -------
    Group

    Raises
    ------
    IndexError
        If an integer ``group_id`` is negative or not smaller than the number
        of groups.
    KeyError
        If no group carries the given name.
    TypeError
        If ``group_id`` is neither an int nor a str.

    Examples
    --------
    >>> op = Operator('dma, macd')
    >>> op.get_group_by_id(0)
    Group(name=Group_1, members=[RULE-ITER(DMA), RULE-ITER(MACD)])
    >>> op.get_group_by_id('Group_1')
    Group(name=Group_1, members=[RULE-ITER(DMA), RULE-ITER(MACD)])
    """
    if isinstance(group_id, str):
        return self.groups[group_id]
    if not isinstance(group_id, int):
        raise TypeError(f'group_id should be an integer or a string, got {type(group_id)} instead!')
    # integer position: translate it into the group name first
    if group_id < 0 or group_id >= len(self.groups):
        raise IndexError(f'group_id index ({group_id}) out of range!')
    group_name = self.group_ids[group_id]
    return self.groups[group_name]
def add_strategies(self, strategies: Union[str, list[Union[str, BaseStrategy, type]]],
                   run_freq: str = 'd',
                   run_timing: str = 'close',
                   **kwargs: Any):
    """Add several strategies to the Operator.

    Every item is added through ``add_strategy`` with the same ``run_freq``,
    ``run_timing`` and keyword arguments.

    Parameters
    ----------
    strategies : str or list of (str or BaseStrategy or type)
        A comma separated string (split with ``str_to_list``) or a list whose
        items are strategy id strings, strategy instances or strategy
        classes; the kinds may be mixed.
    run_freq : str, optional
        Run frequency passed to ``add_strategy`` for every item.
    run_timing : str, optional
        Run timing passed to ``add_strategy`` for every item.
    **kwargs : Any
        Attributes shared by all added strategies, forwarded to
        ``add_strategy``.

    Returns
    -------
    None

    Raises
    ------
    AssertionError
        If ``strategies`` is neither a string nor a list.
    TypeError
        If an item is not a str, a BaseStrategy or a class.
    ValueError
        If ``add_strategy`` fails for an item; the original error is chained
        as the cause.

    Examples
    --------
    >>> op = Operator()
    >>> op.add_strategies(['dma', 'macd'])
    >>> op.strategies
    [RULE-ITER(DMA), RULE-ITER(MACD)]
    """
    if isinstance(strategies, str):
        strategies = str_to_list(strategies)
    # Explicit raise instead of ``assert`` so the check is not stripped under
    # ``python -O``; AssertionError is kept for callers that already catch it.
    if not isinstance(strategies, list):
        raise AssertionError(f'TypeError, the strategies '
                             f'should be a list of string, got {type(strategies)} instead')
    for stg in strategies:
        if not isinstance(stg, (str, BaseStrategy, type)):
            msg = (f'WrongType! some of the items in strategies '
                   f'can not be added - got {stg}')
            raise TypeError(msg)
        try:
            self.add_strategy(stg, run_freq=run_freq, run_timing=run_timing, **kwargs)
        except Exception as e:
            # the traceback is still printed, as before; the cause is now chained too
            import traceback
            traceback.print_exc()
            raise ValueError(f'Failed to add strategy {stg} to operator - {e}') from e
def add_strategy(self, stg: Union[str, BaseStrategy, type, tuple, list, Any],
                 run_freq: str = 'd',
                 run_timing: str = 'close',
                 **kwargs):
    """Add one strategy to the Operator, optionally setting its attributes.

    The strategy is placed into the first group whose ``run_freq`` and
    ``run_timing`` both match; if there is none, a new group is created.
    Afterwards ``kwargs`` are applied through ``set_parameter``; if that
    fails, the strategy is removed again.

    Parameters
    ----------
    stg : str or BaseStrategy or type
        A built-in strategy id (resolved with ``get_built_in_strategy``), a
        strategy instance, or a strategy class (instantiated without
        arguments).
    run_freq : str, optional
        Run frequency of the strategy; None is treated as 'd'.
    run_timing : str, optional
        Run timing of the strategy; None is treated as 'close'.
    **kwargs
        Strategy attributes forwarded to ``set_parameter``, e.g. ``pars``,
        ``opt_tag``, ``data_types``, ``data_type_id``, ``window_length``,
        ``use_latest_data_cycle``, ``freq``, ``asset_type``.

    Returns
    -------
    None

    Raises
    ------
    TypeError
        If ``stg`` has an unsupported type (a tuple or list included), or if
        ``run_freq`` / ``run_timing`` is not a string.
    ValueError
        If ``stg`` is already contained in the Operator.
    RuntimeError
        If applying ``kwargs`` fails; the addition is rolled back and the
        original error is chained as the cause.

    Examples
    --------
    >>> op = Operator()
    >>> op.add_strategy('dma', opt_tag=1, pars=(50, 10, 20))
    >>> op.strategies
    [RULE-ITER(DMA)]
    >>> op.strategies[0].opt_tag
    1
    >>> op.strategies[0].par_values
    (50, 10, 20)
    """
    # TODO: allow setting the name attribute while adding a strategy
    # TODO: allow setting the description attribute while adding a strategy

    def _id_of_class(stg_class):
        """Built-in id of a strategy class, or 'custom' if it is not built in."""
        if stg_class in available_built_in_strategies:
            position = list(available_built_in_strategies).index(stg_class)
            return list(BUILT_IN_STRATEGIES)[position]
        return 'custom'

    # resolve the input into a strategy object and a base id
    if isinstance(stg, str):
        stg_id = stg.lower()
        strategy = get_built_in_strategy(stg)
    elif isinstance(stg, BaseStrategy):
        stg_id = _id_of_class(type(stg))
        strategy = stg
    elif isinstance(stg, type):
        stg_id = _id_of_class(stg)
        strategy = stg()
    elif isinstance(stg, (tuple, list)):
        raise TypeError(f'Strategy can not be a tuple of a list, only one strategy can be added! \n'
                        f'To add multiple strategies in the same time, use add_strategies() method instead!')
    else:
        raise TypeError(f'The strategy type \'{type(stg)}\' is not supported!')

    if stg in self.strategies:
        raise ValueError(f'The strategy {stg} is already in operator, '
                         f'please add a different strategy or make a copy.')

    # Normalise and validate run_freq / run_timing BEFORE touching the strategy,
    # so a rejected call does not leave the passed-in object with a new id.
    if run_freq is None:
        run_freq = 'd'
    if run_timing is None:
        run_timing = 'close'
    if not isinstance(run_freq, str):
        raise TypeError(f'run_freq should be a string, got {type(run_freq)} instead!')
    if not isinstance(run_timing, str):
        raise TypeError(f'run_timing should be a string, got {type(run_timing)} instead!')

    stg_id = self._next_stg_id(stg_id)
    strategy._strategy_id = stg_id

    # groups are matched on the given run_freq / run_timing, not on the strategy
    target_group = next(
        (group for group in self._groups
         if run_timing == group.run_timing and run_freq == group.run_freq),
        None,
    )
    if target_group is None:
        # no group runs at this timing and frequency yet: create one
        target_group = Group(name=self._next_group_id(),
                             signal_type='PT',
                             blender=None,
                             run_freq=run_freq,
                             run_timing=run_timing, )
        # back reference so strategies can reach the Operator through the group
        target_group._operator = self
        target_group.add_strategy(strategy)
        self._groups.append(target_group)
    else:
        target_group.add_strategy(strategy)

    # apply the attribute overrides; undo the addition if any of them fails
    try:
        self.set_parameter(stg_id=stg_id, **kwargs)
    except Exception as e:
        self.remove_strategy(stg_id)
        raise RuntimeError(f'{e} - strategy {stg_id} addition rolled back.') from e
def _next_stg_id(self, stg_id: str):
    """Return an id for a new strategy that is unique within the operator.

    An existing id counts as a use of ``stg_id`` when it equals ``stg_id``
    (serial 0) or has the form ``<stg_id>_<digits>`` (serial <digits>).
    Matching against the full base id keeps ids that themselves contain
    underscores working.

    Parameters
    ----------
    stg_id : str
        Base id of the strategy to be added.

    Returns
    -------
    str
        ``stg_id`` itself when it is unused, otherwise ``<stg_id>_<n>`` where
        n is the highest serial in use plus one.
    """
    numbered_prefix = stg_id + '_'
    used_serials = []
    for existing_id in self.strategy_ids:
        if existing_id == stg_id:
            # a bare id is the implicit serial 0
            used_serials.append(0)
        elif existing_id.startswith(numbered_prefix):
            suffix = existing_id[len(numbered_prefix):]
            # only ASCII digits count as a serial; other suffixes belong to different ids
            if suffix.isascii() and suffix.isdigit():
                used_serials.append(int(suffix))
    if not used_serials:
        return stg_id
    return f'{stg_id}_{max(used_serials) + 1}'
def _next_group_id(self):
    """Return the name for the next strategy group.

    Existing names are expected to look like ``Group_<n>``. The result is
    'Group_1' when there is no group yet, otherwise ``Group_<max n + 1>``.
    """
    serials = [int(group_name.split("_")[1]) for group_name in self.group_ids]
    if not serials:
        return 'Group_1'
    return 'Group' + "_" + str(max(serials) + 1)
def remove_strategy(self, id_or_pos: Optional[Union[str, int]] = None) -> None:
    """Remove one strategy from the Operator, given its id or its position.

    Parameters
    ----------
    id_or_pos : str or int, optional
        - None: the last strategy is removed;
        - int: position of the strategy. Negative values count from the end,
          and values outside the valid range are clamped to the first or the
          last strategy;
        - str: id of the strategy.

    Raises
    ------
    IndexError
        If the Operator holds no strategy.
    ValueError
        If a string id does not match any strategy.
    TypeError
        If ``id_or_pos`` is not None, an int or a str.

    Notes
    -----
    The strategy is removed through its group's ``remove_strategy``. A group
    left without members is removed from the Operator as well.
    """
    total = self.strategy_count
    if total == 0:
        raise IndexError("There's no strategy to be removed from operator")

    if id_or_pos is None:
        pos = total - 1
    elif isinstance(id_or_pos, int):
        if id_or_pos < 0:
            # count from the end, but never before the first strategy
            pos = max(self.strategy_count + id_or_pos, 0)
        else:
            # never beyond the last strategy
            pos = min(id_or_pos, self.strategy_count - 1)
    elif isinstance(id_or_pos, str):
        known_ids = self.strategy_ids
        if id_or_pos not in known_ids:
            raise ValueError(f'the strategy {id_or_pos} is not in operator')
        pos = known_ids.index(id_or_pos)
    else:
        raise TypeError(f'Must give the position or id of strategy as int or string, got {type(id_or_pos)}')

    # remove through the owning group
    strategy = self[pos]
    owner = self.groups[strategy._group_id]
    owner.remove_strategy(strategy)
    # drop the group once it has no members left
    if len(owner.members) == 0:
        self._groups.remove(owner)
    return
def clear_strategies(self):
    """Remove all strategies and all strategy groups from the Operator.

    When the Operator holds strategies, ``clear_strategies`` is called on
    every group first; in all cases the group list is replaced by a new
    empty list.
    """
    if self.strategy_count > 0:
        for stg_group in self._groups:
            stg_group.clear_strategies()
    self._groups = []
    return
def get_strategies_by_group(self, group_id: str):
    """Return the member strategies of one strategy group.

    Parameters
    ----------
    group_id : str
        Name of the group; a KeyError is raised by the lookup if no group
        has this name.

    Returns
    -------
    The ``members`` attribute of the group.
    """
    selected_group = self.groups[group_id]
    return selected_group.members
def get_strategy_count_by_group(self, group_id: str):
    """Return the number of strategies in the group named ``group_id``."""
    members = self.get_strategies_by_group(group_id)
    return len(members)
def get_strategy_names_by_group(self, group_id: str):
    """Return the ``name`` of every strategy in the group named ``group_id``."""
    names = []
    for strategy in self.get_strategies_by_group(group_id):
        names.append(strategy.name)
    return names
def get_strategy_id_by_group(self, group_id: str):
    """Return the ``strategy_id`` of every strategy in the group named ``group_id``."""
    ids = []
    for strategy in self.get_strategies_by_group(group_id):
        ids.append(strategy.strategy_id)
    return ids
def get_max_window_length_by_dtype_id(self, dtype: str) -> int:
    """Return the longest window length that any strategy uses for one data type.

    Parameters
    ----------
    dtype : str
        The ``dtype_id`` of the data type to look up.

    Returns
    -------
    int
        Maximum of ``window_lengths[dtype]`` over all strategies whose
        ``data_type_ids`` contain ``dtype``.

    Raises
    ------
    TypeError
        If ``dtype`` is not a string.
    ValueError
        If ``dtype`` is not among the operator's data type ids, if the
        operator holds no strategy, or if no strategy uses ``dtype``.
    """
    if not isinstance(dtype, str):
        raise TypeError(f'dtype should be a string, got {type(dtype)} instead.')
    if dtype not in self.op_data_type_ids:
        raise ValueError(f'data type {dtype} is not in operator data types {self.op_data_types}')
    if self.strategy_count == 0:
        raise ValueError(f'no strategy in operator!')
    lengths = []
    for strategy in self.strategies:
        if dtype in strategy.data_type_ids:
            lengths.append(strategy.window_lengths[dtype])
    if not lengths:
        raise ValueError(f'no strategy in operator uses data type {dtype}!')
    return max(lengths)
def get_share_idx(self, share):  # deprecated
    """Return the index stored for a share code (deprecated).

    Parameters
    ----------
    share : str
        Share code to look up in ``_op_signal_shares``.

    Returns
    -------
    int or None
        The stored index, or None while the share mapping is empty. A share
        missing from a non-empty mapping raises KeyError.
    """
    share_positions = self._op_signal_shares
    if share_positions == {}:
        return None
    return share_positions[share]
def set_opt_par_values(self, par_values):
    """Distribute one optimizer parameter vector over the strategies.

    The vector is consumed from left to right according to each strategy's
    ``opt_tag``:

    - 0: the strategy takes nothing from the vector;
    - 1: the next ``par_count`` items are passed as individual arguments;
    - 2: the next single item, itself a complete parameter tuple, is unpacked
      and passed.

    Parameters
    ----------
    par_values : tuple
        Candidate point of the optimizer; its layout must follow the order
        produced by ``opt_space_par``.

    Returns
    -------
    None

    Notes
    -----
    Values are written with ``update_par_values``; nothing is validated here.
    """
    cursor = 0  # position of the next unused item in par_values
    for strategy in self.strategies:
        tag = strategy.opt_tag
        if tag == 1:
            stop = cursor + strategy.par_count
            strategy.update_par_values(*par_values[cursor:stop])
            cursor = stop
        elif tag == 2:
            strategy.update_par_values(*par_values[cursor])
            cursor += 1
def set_blender(self,
                blender: Union[str, list[str], dict[str, str]],
                group_id: Union[str, None] = None):
    """Set the signal blender expression of one or all strategy groups.

    Parameters
    ----------
    blender : str or list of str or dict of str
        With ``group_id`` given: the blender expression of that group (must
        be a string). With ``group_id`` None:

        - str: the same expression is set on every group;
        - list of str: expressions are assigned to the groups in order; if
          the list is shorter than the number of groups its last item is
          reused, surplus items are ignored. The list is not modified;
        - dict: ``{group_id: expression}``, each entry is set on its group;
        - None: nothing happens.
    group_id : str, optional
        Name of the group to update; None updates according to ``blender``
        as described above.

    Returns
    -------
    None
        The call is a no-op when the Operator holds no strategy.

    Raises
    ------
    TypeError
        If ``blender`` or ``group_id`` has an unsupported type.
    ValueError
        If an empty list is given, or if the group rejects the expression
        with a ValueError (chained as the cause).
    KeyError
        If ``group_id`` is not the name of a group of this Operator.

    Examples
    --------
    >>> op = Operator()
    >>> op.add_strategy('dma', run_timing='close')   # goes to Group_1
    >>> op.add_strategy('trix', run_timing='close')  # goes to Group_1
    >>> op.add_strategy('macd', run_timing='open')   # goes to Group_2
    >>> op.add_strategy('bband', run_timing='open')  # goes to Group_2
    >>> op.set_blender('s0+s1', 'Group_1')           # one group only
    >>> op.set_blender('s0 + s1')                    # every group
    >>> op.set_blender(['s0 + s1', 's0*s1'], None)   # one expression per group, in order
    """
    if self.strategy_count == 0:
        return

    if group_id is None:
        # no group given: spread the blender(s) over the groups
        if blender is None:
            return
        if isinstance(blender, str):
            blender = [blender]
        if isinstance(blender, list):
            if len(blender) == 0:
                raise ValueError('Empty blender list!')
            if any(not isinstance(b, str) for b in blender):
                raise TypeError('All items in blender list should be strings!')
            # Pad a copy with the last expression so the caller's list is left untouched.
            missing = self.strategy_group_count - len(blender)
            expressions = blender + [blender[-1]] * max(missing, 0)
            for bldr, group in zip(expressions, self.group_ids):
                self.set_blender(blender=bldr, group_id=group)
        elif isinstance(blender, dict):
            for group, bldr in blender.items():
                self.set_blender(blender=bldr, group_id=group)
        else:
            raise TypeError(f'Wrong type of blender, a string or a list of strings should be given,'
                            f' got {type(blender)} instead')
        return

    if not isinstance(group_id, str):
        raise TypeError(f'group should be a string, got {type(group_id)} instead')
    if group_id not in self.group_ids:
        msg = f"Strategy group '{group_id}' is not valid in current Operator: {self.group_ids}!"
        raise KeyError(msg)
    if not isinstance(blender, str):
        raise TypeError(f'Wrong type of blender, a string should be given, got {type(blender)} instead')
    try:
        group = self.strategy_groups[group_id]
        group.blender_str = blender
    except ValueError as e:
        raise ValueError(f'Invalid blender expression: "{blender}" - {e}') from e
    return
def get_blender(self, group_name=None):
    """Return the blender of one group, or the blenders of all groups.

    Parameters
    ----------
    group_name : str, optional
        Name of a strategy group.

    Returns
    -------
    dict or object or None
        - ``group_name`` is None: ``{group name: blender}`` for every group
          whose blender is not None;
        - ``group_name`` is unknown: None;
        - otherwise: the ``blender`` attribute of that group.
    """
    if group_name is None:
        blenders = {}
        for stg_group in self._groups:
            group_blender = stg_group.blender
            if group_blender is not None:
                blenders[stg_group.name] = group_blender
        return blenders
    if group_name in self.strategy_groups:
        return self.groups[group_name].blender
    return None
def view_blender(self, group: Union[str, int] = None) -> Union[dict, str]:
    """Return a human readable version of the blender expression(s).

    Parameters
    ----------
    group : str or int, optional
        Name of a strategy group or its 0-based position. When omitted, the
        readable blenders of all groups are returned.

    Returns
    -------
    dict or str
        - ``group`` is None: ``{group id: readable blender}``, built with
          ``qteasy.blender.human_blender`` from each group's ``blender_str``
          and the ids of its strategies;
        - otherwise: the ``human_blender`` attribute of the selected group.

    Raises
    ------
    KeyError
        If ``group`` is an unknown name or an out-of-range position.
    """
    if group is None:
        # imported lazily: only this branch needs the helper
        from qteasy.blender import human_blender
        all_blenders = {}
        for group_id, stg_group in self.groups.items():
            all_blenders[group_id] = human_blender(
                stg_group.blender_str,
                strategy_ids=self.get_strategy_id_by_group(group_id),
            )
        return all_blenders

    if isinstance(group, int):
        # A position can never be found among the group names, so it is
        # validated against the number of groups instead.
        if not 0 <= group < len(self.strategy_groups):
            raise KeyError(f'No such strategy group with ID ({group})!')
    elif group not in self.strategy_groups:
        raise KeyError(f'No such strategy group with ID ({group})!')
    return self.get_group(group).human_blender
def set_parameter(self,
                  stg_id: Union[str, int],
                  pars: Union[Parameter, tuple, list, dict] = None,
                  opt_tag: int = None,
                  data_types: Union[DataType, list[DataType], dict[str, DataType]] = None,
                  data_type_id: str = None,
                  window_length: Union[int, tuple[int, ...], list[int], dict[str, int]] = None,
                  use_latest_data_cycle: Union[bool, list[bool], tuple[bool, ...], dict[str, bool]] = None,
                  freq: Union[str, list[str], tuple[str, ...], dict[str, str]] = None,
                  asset_type: Union[str, list[str], tuple[str, ...], dict[str, str]] = None,
                  par_values: Union[tuple, list, dict[str, Any]] = None,
                  par_range: Union[tuple, list, dict[str, tuple]] = None,
                  run_freq: str = None,
                  run_timing: str = None,
                  **kwargs):
    """Single entry point for changing the settings of one member strategy.

    ``stg_id`` selects the strategy; every other argument that is not ``None``
    is forwarded to the matching setter of that strategy, in the order in
    which the arguments are handled below.

    Parameters
    ----------
    stg_id : str or int
        Name (ID) of the strategy whose parameters are to be changed.
    pars : Parameter or tuple or list or dict, optional
        Tunable strategy parameters, passed to ``strategy.set_pars``.
    opt_tag : int, optional
        Optimization flag:
        0 - excluded from optimization;
        1 - parameters are actively adjusted by the optimizer;
        2 - enumerated: the optimizer only picks from given combinations.
    data_types : DataType or list of DataType or dict, optional
        History data types required by the strategy. When given, the
        strategy's data types are replaced via ``strategy.set_data_types``.
    data_type_id : str, optional
        ID of an existing data type whose settings should be updated.
    window_length : int or sequence of int or dict, optional
        Length of the data window used by the strategy.
    use_latest_data_cycle : bool or sequence of bool or dict, optional
        Whether the most recent data cycle may be used.
    freq : str or sequence of str or dict, optional
        Frequency of the data type; when given the corresponding DataType is
        replaced by a new one (its dtype_id changes accordingly).
    asset_type : str or sequence of str or dict, optional
        Asset type of the data type; when given the corresponding DataType is
        replaced by a new one (its dtype_id changes accordingly).
    par_values : tuple or list or dict, optional
        Concrete values of the strategy parameters.
    par_range : tuple or list or dict of tuples, optional
        Value ranges of the strategy parameters; its length must equal
        ``strategy.par_count``.
    run_freq : str, optional
        New running frequency. A change moves the strategy to another
        strategy group (or updates its group in place, see below).
    run_timing : str, optional
        New running timing. Handled together with ``run_freq``.
    kwargs : dict
        Any other custom parameters, forwarded to
        ``strategy.set_custom_pars``.

    Raises
    ------
    AssertionError
        If ``stg_id`` is neither an int nor a str.
    KeyError
        If no strategy can be found for ``stg_id``.
    ValueError
        If ``strategy.set_pars`` reports failure, if ``par_range`` has the
        wrong length, or if a multi-parameter dict only contains ``default``.
    TypeError
        If ``par_range`` is not a list, tuple or dict.
    RuntimeError
        If more than one group matches the new (run_freq, run_timing) pair.
    """
    assert isinstance(stg_id, (int, str)), f'stg_id should be a int or a string, got {type(stg_id)} instead'
    # Look the strategy object up by its name or ID.
    # TODO: allow setting parameters of several strategies at once
    #  (very useful for arguments such as opt_tag)
    strategy = self.get_strategy_by_id(stg_id)
    if strategy is None:
        raise KeyError(f'Specified strategy does not exist or can not be found!')
    # Apply the given settings to the strategy one by one.
    if pars is not None:  # tunable parameters
        if not strategy.set_pars(pars):
            raise ValueError(f'parameter setting error: {pars}')
    if opt_tag is not None:  # optimization flag
        strategy.set_opt_tag(opt_tag)
    if data_types is not None:  # replace the data types of the strategy
        strategy.set_data_types(
                data_types=data_types,
                window_length=window_length,
                use_latest_data_cycle=use_latest_data_cycle,
        )
    # NOTE(review): when data_types is given together with window_length /
    # use_latest_data_cycle, both set_data_types() above and
    # update_data_types() below receive those values - confirm this is intended.
    if (data_type_id is not None) or (window_length is not None) or (use_latest_data_cycle is not None) \
            or (freq is not None) or (asset_type is not None):
        # Update a data type by ID and/or its settings (including freq / asset_type replacement).
        strategy.update_data_types(
                dtype_id=data_type_id,
                window_length=window_length,
                use_latest_data_cycle=use_latest_data_cycle,
                freq=freq,
                asset_type=asset_type,
        )
    if par_values is not None:  # concrete parameter values
        if isinstance(par_values, dict) and isinstance(strategy, RuleIterator):
            # A dict for a RuleIterator is ambiguous; tell apart
            # "parameter name -> value" (kwargs) from "share -> parameter tuple" (multi_par).
            par_names = set(getattr(strategy, 'par_names', ()))
            keys_only_default = set(par_values.keys()) == {'default'}
            # kwargs form: every key is a parameter name, there is no 'default' key, and
            # every value is a scalar or a sequence of at most one element (not a parameter tuple).
            is_kwargs = (
                    par_names
                    and set(par_values.keys()) <= par_names
                    and 'default' not in par_values
                    and all(
                            not isinstance(v, (tuple, list)) or len(v) <= 1
                            for v in par_values.values()
                    )
            )
            if is_kwargs:
                strategy.update_par_values(**par_values)
            else:
                # multi_par form: at least one key other than 'default' is required
                if keys_only_default:
                    raise ValueError(
                            'multi_par中应该至少有一个不同于default的stock_id'
                    )
                strategy.update_par_values(par_values)
        else:  # expected to be a tuple or list; unpacked positionally
            strategy.update_par_values(*par_values)
    if par_range is not None:  # parameter value ranges
        if not isinstance(par_range, (list, tuple, dict)):
            raise TypeError(f'par_range should be a tuple or dict of tuples, got {type(par_range)} instead!')
        if len(par_range) != strategy.par_count:
            raise ValueError(f'par_range should have the same length as the number of strategy parameters, '
                             f'expected {strategy.par_count}, got {len(par_range)} instead!')
        if isinstance(par_range, (tuple, list)):
            strategy.update_par_ranges(*par_range)
        else:  # par_range is a dict
            strategy.update_par_ranges(**par_range)
    if ((run_freq is not None) and (run_freq != strategy.run_freq)) or \
            ((run_timing is not None) and (run_timing != strategy.run_timing)):  # running frequency / timing
        # Groups have to be rearranged, so this is only done when the new
        # run_freq / run_timing really differs from the current one.
        # run_freq / run_timing are stored on the Group only; the Strategy reads them from its Group.
        old_group_id = strategy._group_id
        old_group = self.groups[old_group_id]
        new_run_freq = run_freq if run_freq is not None else old_group.run_freq
        new_run_timing = run_timing if run_timing is not None else old_group.run_timing
        if old_group.strategy_count == 1:
            # The strategy is alone in its group: look for another group with the same
            # (run_freq, run_timing) that it can be merged into.
            target_group = [
                g for g in self._groups
                if g is not old_group and g.run_timing == new_run_timing and g.run_freq == new_run_freq
            ]
            if len(target_group) == 1:
                # merge: leave the current group (which is then dropped) and join the target group
                old_group.remove_strategy(strategy)
                self._groups.remove(old_group)
                group = target_group[0]
                group.add_strategy(strategy)
                strategy._group_id = group.name
            else:
                # no (unique) target group: update the existing group in place
                old_group.run_freq = new_run_freq
                old_group.run_timing = new_run_timing
        else:
            # The strategy shares its group with others: take it out and move it
            # to a matching group, creating that group when necessary.
            old_group.remove_strategy(strategy)
            if len(old_group.members) == 0:
                self._groups.remove(old_group)
            # find or create the target group
            target_group = [
                g for g in self._groups if
                new_run_timing == g.run_timing and new_run_freq == g.run_freq
            ]
            if len(target_group) == 0:
                group_id = self._next_group_id()
                # NOTE(review): the new group is created with a fixed signal_type 'PT' and
                # no blender rather than inheriting from old_group - confirm this is intended.
                new_group = Group(name=group_id,
                                  signal_type='PT',
                                  blender=None,
                                  run_freq=new_run_freq,
                                  run_timing=new_run_timing, )
                new_group._operator = self
                new_group.add_strategy(strategy)
                strategy._group_id = group_id
                self._groups.append(new_group)
            elif len(target_group) == 1:
                group = target_group[0]
                group.add_strategy(strategy)
                strategy._group_id = group.name
            else:
                raise RuntimeError(f'more than one target group found for strategy {stg_id} '
                                   f'with run_timing={new_run_timing} and run_freq={new_run_freq}')
    # Set any other custom parameters. (No rollback is performed here if this call fails.)
    strategy.set_custom_pars(**kwargs)
def set_shares(self, shares: list[str]):
    """Set the list of traded symbols for the operator and all its strategies.

    Parameters
    ----------
    shares : list of str
        Symbol codes of the assets to trade. The position of each symbol in
        the list is stored as its index in ``self._op_signal_shares``.

    Returns
    -------
    None

    Raises
    ------
    TypeError
        If ``shares`` is not a list or contains a non-string item.
    """
    if not isinstance(shares, list):
        raise TypeError(f'shares should be a list of strings, got {type(shares)} instead')
    # validate every item before any state is touched
    for item in shares:
        if isinstance(item, str):
            continue
        raise TypeError(f'share should be a string, got {type(item)} instead')

    # symbol -> column position lookup
    self._op_signal_shares = dict(zip(shares, range(len(shares))))
    for stg in self.strategies:
        stg.update_shares(share_names=shares)
def set_group_parameters(self,
                         group: Union[str, int],
                         run_timing: str = None,
                         run_freq: str = None,
                         signal_type: str = None,
                         blender_str: str = None,
                         **kwargs):
    """Set or change the parameters of one strategy group.

    Parameters
    ----------
    group : str or int
        Identifier of the strategy group, resolved by ``get_group_by_id``.
    run_timing : str, optional
        New running timing of the group.
    run_freq : str, optional
        New running frequency of the group.
    signal_type : str, optional
        New signal type of the group.
    blender_str : str, optional
        New blender expression of the group; must be a string.
    kwargs : dict, optional
        Any other attribute that already exists on the group object.

    Returns
    -------
    None

    Raises
    ------
    TypeError
        If ``blender_str`` is given but is not a string.
    ValueError
        If a key in ``kwargs`` is not an attribute of the group.

    Notes
    -----
    If the new (run_freq, run_timing) pair equals that of another group, all
    member strategies are moved into that group and the current group is
    removed from the operator.
    NOTE(review): in that case ``signal_type``, ``blender_str`` and
    ``kwargs`` are still applied to the removed group object, not to the
    group the strategies were merged into - confirm whether this is intended.
    """
    current = self.get_group_by_id(group)

    if (run_freq is not None) or (run_timing is not None):
        # fall back to the group's present setting for whichever value is missing
        wanted_freq = current.run_freq if run_freq is None else run_freq
        wanted_timing = current.run_timing if run_timing is None else run_timing
        # first other group that already runs on the requested schedule, if any
        destination = next(
                (
                    candidate for candidate in self._groups
                    if wanted_freq == candidate.run_freq
                    and wanted_timing == candidate.run_timing
                    and candidate.name != current.name
                ),
                None,
        )
        if destination is None:
            # no clash: update the group in place (strategies read
            # run_freq / run_timing from their group)
            current.run_freq = wanted_freq
            current.run_timing = wanted_timing
        else:
            # clash: merge all members into the existing group and drop this one;
            # iterate over a copy because members are removed while looping
            for member in list(current.member_strategies):
                current.remove_strategy(member)
                member._group_id = destination.name
                destination.add_strategy(member)
            self._groups.remove(current)

    if signal_type is not None:
        current.signal_type = signal_type

    if blender_str is not None:
        if not isinstance(blender_str, str):
            raise TypeError(f'blender should be a string or a list of strings, got {type(blender_str)} instead')
        current.blender_str = blender_str

    for attr_name, attr_value in kwargs.items():
        if not hasattr(current, attr_name):
            raise ValueError(f'Invalid group parameter: {attr_name}')
        setattr(current, attr_name, attr_value)
def _strategies_use_proc_data(self) -> bool:
"""检测是否有任意策略在 realize() 中使用了 get_data('proc.xxx') 形式的 process data。"""
import inspect
for stg in self.strategies:
try:
src = inspect.getsource(stg.realize)
if "'proc." in src or '"proc.' in src:
return True
except (TypeError, OSError):
continue
return False
def check_dynamic_data(self):
    """Tell whether signal generation depends on dynamic (trade-result) data.

    Returns ``True`` when the operator runs in ``'stepwise'`` mode; otherwise
    the result of ``_strategies_use_proc_data()`` is returned, i.e. whether
    any strategy reads ``proc.*`` process data inside ``realize()``.
    """
    return self.op_type == 'stepwise' or self._strategies_use_proc_data()
# =================================================
# 下面是Operation模块的公有方法:
def info(self, verbose=False):
    """Print information about this Operator to the terminal.

    The output covers the operator name and run mode, and for every strategy
    group its signal type, run timing / frequency, member strategies and
    blender, followed by a table of the strategies (or, with ``verbose``,
    each strategy's own ``info()`` output).

    Parameters
    ----------
    verbose : bool, default False
        If True, print the detailed information of every strategy instead of
        the compact strategy table.
    """
    # NOTE(review): the nesting below was reconstructed from a flattened copy of the
    # source; the per-strategy sections are placed inside the group loop because they
    # use ``group_id`` - confirm against the original file.
    from .utilfuncs import adjust_string_length
    from rich import print as rprint
    from shutil import get_terminal_size
    terminal_width = get_terminal_size().columns
    # use three quarters of wide terminals, the full width of narrow ones
    info_width = int(terminal_width * 0.75) if terminal_width > 120 else terminal_width
    signal_type_descriptions = {
        'pt': 'Position Target, signal represents position holdings in percentage of total value',
        'ps': 'Percentage trade signal, represents buy/sell stock in percentage of total value',
        'vs': 'Value trade signal, represent tha amount of stocks to be sold/bought'
    }
    op_type_description = {
        'batch':    'All history operation signals are generated before back testing',
        'stepwise': 'History op signals are generated one by one, every piece of signal will be back tested before '
                    'the next signal being generated.'
    }
    # display names for run frequencies; unknown frequencies are printed as they are
    data_freq_name = {
        'y':     'year',
        'Y':     'year',
        'ye':    'year end',
        'q':     'quarter',
        'Q':     'quarter',
        'qe':    'quarter end',
        'QE':    'quarter end',
        'M':     'month',
        'm':     'month',
        'ME':    'month end',
        'me':    'month end',
        'W':     'week',
        'w':     'week',
        'd':     'days',
        'min':   'min',
        '1min':  'min',
        '5min':  '5min',
        '15min': '15min',
        '30min': '30min',
        'h':     'hours',
    }
    rprint(f'{"Operator Information":=^{info_width}}\n'
           f'Name:           {self.name}\n'
           f'Run Mode:       {self.op_type} - {op_type_description[self.op_type]}\n'
           f'Groups:         {self.strategy_count} Strategy(s) in {self.strategy_group_count} Group(s)\n')
    # print the information of every group in turn
    for group_id, group in self.groups.items():
        # NOTE(review): the lookup below raises KeyError unless group.signal_type is
        # exactly 'pt', 'ps' or 'vs' (lower case) - confirm how Group stores it.
        rprint(f'{group_id:-^{info_width}}\n'
               f'Signal Type:    {group.signal_type} - {signal_type_descriptions[group.signal_type]}\n'
               f'Run Timing:     {group.run_timing} @ {group.run_freq} - '
               f'{data_freq_name.get(group.run_freq, group.run_freq)}\n'
               f'Strategies ({group.strategy_count}): {self.get_strategy_id_by_group(group_id)}'
               )
        if group.blender_str:
            rprint(f'Signal blenders: {group.human_blender}\n')
        else:
            rprint(f'Signal blender not set\n')
        # compact table with the basic information of every strategy in the group
        if (self.strategy_count > 0) and (not verbose):
            # column widths: 20% id, 30% name, 50% parameters
            id_width = int(info_width * .2)
            name_width = int(info_width * .3)
            par_width = int(info_width * .5)
            rprint(f'{"Strategies in group":-^{info_width}}\n'
                   f'{"stg_id":<{id_width}}'
                   f'{"name":<{name_width}}'
                   f'{"parameters":<{par_width}}\n'
                   f'{"-" * info_width}')
            for stg in self.get_strategies_by_group(group_id=group_id):
                from .utilfuncs import parse_freq_string
                stg_id = stg.strategy_id
                # NOTE(review): qty / main_freq / sub_freq are computed but not printed below
                qty, main_freq, sub_freq = parse_freq_string(stg.run_freq)
                qty = '' if qty == 1 else qty  # to prevent from printing 1x
                rprint(f'{adjust_string_length(stg_id, id_width) :<{id_width}}'
                       f'{adjust_string_length(stg.name, name_width) :<{name_width}}'
                       f'{adjust_string_length(str(stg.par_values), par_width) :^{par_width}}')
                if getattr(stg, 'multi_pars', None):
                    # the strategy has per-share parameters; point the user to the detailed view
                    hint = '[multi_pars] use strategies -d for per-share parameters.'
                    print(f'{adjust_string_length("", id_width) :<{id_width}}'
                          f'{adjust_string_length("", name_width) :<{name_width}}'
                          f'{hint}')
            print('=' * info_width)
        # detailed information of every strategy in the group
        if (self.strategy_count > 0) and verbose:
            print(f'{"Strategy Details":-^{info_width}}')
            for stg in self.get_strategies_by_group(group_id=group_id):
                from .utilfuncs import parse_freq_string
                stg_id = stg.strategy_id
                stg.info(stg_id=stg_id, verbose=verbose)
            print('=' * info_width)
# Adding functions for the new operator class
def enable_tracing(self):
    """Switch tracing on for the operator and for every strategy in it.

    The trace signal index is reset to 0 and every strategy is told to trace
    up to ``get_signal_count()`` steps, so the running schedule must have
    been prepared beforehand.

    Returns
    -------
    None
    """
    self._trace_enabled, self._trace_signal_index = True, 0
    step_budget = self.get_signal_count()
    for strategy in self.strategies:
        strategy.enable_tracing(max_steps=step_budget)
def disable_tracing(self):
    """Switch tracing off for the operator and for every strategy in it.

    Returns
    -------
    None
    """
    self._trace_enabled = False
    for strategy in self.strategies:
        strategy.disable_tracing()
def prepare_running_schedule(self,
                             start_date=None,
                             end_date=None,
                             **kwargs,
                             ) -> None:
    """Build the running schedule (time table) of all strategy groups.

    For every group a time index is generated with
    ``qteasy.trading_util.trade_time_index`` from the group's ``run_freq``
    and ``run_timing``. Two attributes are then set:

    - ``group_schedules``: dict mapping group name to a one-column DataFrame
      (``is_running`` == 1) indexed by the group's run times;
    - ``group_timing_table``: the column-wise concatenation of those frames,
      one column per group, with 1 where the group runs and 0 elsewhere.

    Parameters
    ----------
    start_date : str or pd.Timestamp, optional
        Start of the schedule, forwarded to ``trade_time_index``.
    end_date : str or pd.Timestamp, optional
        End of the schedule, forwarded to ``trade_time_index``.
    kwargs : dict, optional
        Extra arguments forwarded to ``trade_time_index``. ``end_pm`` and
        ``start_am``, when present, also provide the time of day used for
        the ``'close'`` and ``'open'`` run timings.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If a group has no run timing or frequency, or if its frequency is
        not one of the supported forms.
    """
    from qteasy.trading_util import trade_time_index as tti

    intraday_freqs = ('1min', '5min', '15min', '30min', 'h', 'H')
    one_char_prefixes = ('d', 'W', 'w', 'M', 'Q', 'Y')
    two_char_prefixes = ('ME', 'MS', 'QS', 'QE', 'YS', 'YE')

    self.group_schedules = {}
    for grp in self._groups:
        if grp.run_timing is None or grp.run_freq is None:
            raise ValueError(f"Group {grp.name} has no run timing or frequency defined.")

        freq = grp.run_freq
        if freq in intraday_freqs:
            # intraday schedules take their times directly from the trade time index
            run_times = tti(
                    start=start_date,
                    end=end_date,
                    freq=freq,
                    **kwargs,
            ) + pd.Timedelta(hours=0)
        elif (freq[0] in one_char_prefixes) or (freq[0:2] in two_char_prefixes):
            # daily and longer: one run per period at a fixed time of day,
            # 15:00 for 'close' and 9:30 for 'open' unless overridden via kwargs
            timing = grp.run_timing
            if timing == 'close':
                if 'end_pm' in kwargs:
                    time_offset = pd.to_datetime(kwargs['end_pm']).strftime("%H:%M")
                else:
                    time_offset = "15:00"
            elif timing == 'open':
                if 'start_am' in kwargs:
                    time_offset = pd.to_datetime(kwargs['start_am']).strftime("%H:%M")
                else:
                    time_offset = "9:30"
            else:
                # any other run timing is passed through unchanged as the time offset
                time_offset = timing
            run_times = tti(
                    start=start_date,
                    end=end_date,
                    freq=freq,
                    time_offset=time_offset,
                    **kwargs,
            )
        else:
            raise ValueError(f"Unsupported frequency '{grp.run_freq}' for group '{grp.name}'.")

        self.group_schedules[grp.name] = pd.DataFrame(
                data=1,
                index=run_times,
                columns=['is_running'],
        )

    # align all group schedules on one shared time index; gaps mean "not running"
    merged = pd.concat(self.group_schedules.values(), axis=1)
    merged.columns = self.group_schedules.keys()
    self.group_timing_table = merged.fillna(0).astype('int')
@staticmethod
def _build_schedule_time_kwargs_from_config(config: Mapping[str, Any]) -> dict:
    """Build the trading-session keyword arguments for ``prepare_running_schedule()``.

    The four market open / close times are read from ``config`` and passed,
    together with all four ``include_*`` flags set to True, to
    ``qteasy.trading_util.build_operator_schedule_time_kwargs``, whose result
    is returned.
    """
    from qteasy.trading_util import build_operator_schedule_time_kwargs

    session_keys = (
        'market_open_time_am',
        'market_close_time_am',
        'market_open_time_pm',
        'market_close_time_pm',
    )
    session_times = {key: config[key] for key in session_keys}
    boundary_flags = dict.fromkeys(
            ('include_start_am', 'include_end_am', 'include_start_pm', 'include_end_pm'),
            True,
    )
    return build_operator_schedule_time_kwargs(**session_times, **boundary_flags)
def get_signal_count(self, steps=None) -> int:
    """Return the number of trade signals produced by the running schedule.

    Parameters
    ----------
    steps : list of int, optional
        Positional row indices of ``group_timing_table``. When given, only
        these steps are counted; otherwise the whole table is counted.

    Returns
    -------
    int
        With ``group_merge_type == 'None'`` every running group produces its
        own signal, so the result is the sum of all cells of the selected
        rows. With any other merge type each step produces one merged signal,
        so the result is the number of selected rows. The value is always a
        built-in ``int``.

    Raises
    ------
    AssertionError
        If ``group_timing_table`` is empty.
    """
    assert not self.group_timing_table.empty, "Group timing table is empty. Please prepare it first."
    timing_table = self.group_timing_table
    running_schedule = timing_table if steps is None else timing_table.iloc[steps]
    if self.group_merge_type == 'None':
        # DataFrame.sum().sum() yields a numpy integer; convert it so that the
        # declared ``int`` return type holds for callers
        return int(running_schedule.sum().sum())
    # merged modes: one signal per step
    return len(running_schedule)
def prepare_data_buffer(self, *,
                        start_date: Union[str, pd.Timestamp],
                        end_date: Union[str, pd.Timestamp],
                        data_package: dict) -> None:
    """Validate a data package and store the data needed by the strategies.

    ``self.data_buffers`` is reset and then filled, for every data type in
    ``self.all_strategy_data_types``, with the matching entry of
    ``data_package`` (the object is stored as is, without copying or
    slicing). Before an entry is stored it is checked to hold enough rows
    ahead of ``start_date`` to build the longest data window required for
    that data type.

    Parameters
    ----------
    start_date : str or pd.Timestamp
        First date on which strategies will run; the end of the first full
        data window must not fall after this date.
    end_date : str or pd.Timestamp
        End date of the run. Not used by the checks in this method.
    data_package : dict
        Maps data type id (str) to a pandas DataFrame or Series. All
        DataFrame entries must have identical columns in identical order.

    Raises
    ------
    TypeError
        If a key is not a str or a value is not a DataFrame / Series.
    ValueError
        If DataFrame columns differ, a required data type is missing, or
        there is not enough leading data for the data windows.
    """
    # drop whatever was buffered before
    self.data_buffers = {}

    # every key must be a str and every value a DataFrame or a Series
    for pkg_key, pkg_value in data_package.items():
        if not isinstance(pkg_key, str):
            raise TypeError(f"Data package keys must be strings, got {type(pkg_key)} instead.")
        if not isinstance(pkg_value, (pd.DataFrame, pd.Series)):
            raise TypeError(f"Data package values must be pandas DataFrame or Series, got {type(pkg_value)} instead.")

    # all DataFrame entries must share the same columns in the same order
    # (Series entries are not compared)
    frame_columns = [value.columns for value in data_package.values() if isinstance(value, pd.DataFrame)]
    if frame_columns:
        first_cols = frame_columns[0]
        for cols in frame_columns[1:]:
            if not first_cols.equals(cols):
                raise ValueError("Data package columns must be the same and in the same order for all data types, "
                                 f"got {first_cols} and {cols} instead.")

    first_run_date = None
    for data_type in self.all_strategy_data_types:
        dtype_id = data_type.dtype_id
        if dtype_id not in data_package:
            raise ValueError(f"Data type '{data_type}' required by strategies is missing in data package.")
        data = data_package[dtype_id]
        dtype_max_window = self.get_max_window_length_by_dtype_id(dtype_id)
        if len(data) < dtype_max_window:
            msg = (f"Not enough data for data type '{data_type}' to create data windows. "
                   f"Required: {dtype_max_window}, Available: {len(data)}")
            raise ValueError(msg)
        # the first complete window ends at row (window - 1); it must not end after start_date
        first_window_end = data.index[dtype_max_window - 1]
        first_run_date = pd.to_datetime(start_date).date()
        if first_window_end.date() > first_run_date:
            msg = (f"Not enough data for data type '{data_type}' to create data windows. \n"
                   f"Data package starts on {data.index[0]}, "
                   f"and start_date is {start_date}, \nbut the first available window starts on "
                   f" {first_window_end} (window length: "
                   f"{dtype_max_window}). ")
            raise ValueError(msg)
        self.data_buffers[dtype_id] = data
def prepare_dynamic_data_buffer(self, *,
                                trade_records: np.ndarray,
                                trade_prices: np.ndarray,
                                own_cashes: np.ndarray,
                                available_cashes: np.ndarray,
                                holding_positions: np.ndarray,
                                available_positions: np.ndarray) -> None:
    """Reserved interface that intentionally does nothing.

    Process data is provided to strategies through ``proc.*`` data (accessed
    with ``get_data('proc.xxx')``), so none of the arguments is used here.
    """
    return None
def create_data_windows(self):
    """Create rolling data windows and their lookup indices for every strategy.

    For each strategy and each of its non-dynamic data types two entries are
    written:

    - ``data_window_views[strategy_id][data_type]``: the rolling-window view
      over the buffered data along axis 0;
    - ``data_window_indices[strategy_id][data_type]``: for every row of the
      running schedule, the position of the window to use at that time.

    Raises
    ------
    ValueError
        If ``group_timing_table`` has not been set.
    """
    timing_table = self.group_timing_table
    if timing_table is None:
        raise ValueError("Group timing table is not set. Please set it before creating data windows.")

    for group in self._groups:
        for strategy in group.members:
            stg_id = strategy.strategy_id
            views = self.data_window_views[stg_id] = {}
            indices = self.data_window_indices[stg_id] = {}
            for data_type in strategy.data_types:
                if data_type in self.all_dynamic_dtypes:
                    # dynamic data types get no rolling window
                    continue
                window_length = strategy.data_window_lengths[data_type]
                use_latest_cycle = strategy.data_ulc[data_type]
                buffered = self.data_buffers.get(data_type, None)

                views[data_type] = rolling_window(buffered.values, window=window_length, axis=0)

                # row position at which each window ends, and the matching time stamps
                window_end_rows = np.arange(window_length - 1, len(buffered))
                window_end_times = buffered.index[window_end_rows]
                # With "use latest data cycle" a window ending exactly at the run time may be
                # used (side='right' finds times <= run time); otherwise the window must end
                # strictly before the run time (side='left').
                side = 'right' if use_latest_cycle else 'left'
                indices[data_type] = np.searchsorted(window_end_times, timing_table.index, side=side) - 1
def run_strategy(self,
                 step_index) -> Generator[
    Union[tuple[Any, int, Any], tuple[Optional[Any], int, Union[int, Any]]], Any, None]:
    """Run all strategy groups scheduled at one step and yield their signals.

    This is a generator function. With ``group_merge_type == 'None'`` one
    tuple is yielded per running group; with any other merge type the blended
    signals of all running groups are combined and a single tuple is yielded
    for the step.

    Parameters
    ----------
    step_index : int
        Row position of the current step in ``group_timing_table``.

    Yields
    ------
    tuple of (signal_type, step_index, signal)
        signal_type : signal type of the group (of the last running group in
            merged modes; ``None`` if no group runs at this step)
        step_index : the step index passed in
        signal : the blended trade signal; its meaning depends on signal_type

    Raises
    ------
    ValueError
        If ``group_timing_table`` is not set, or if ``group_merge_type`` is
        not one of 'None', 'Or', 'And' while a group is being processed.
    """
    if self.group_timing_table is None:
        raise ValueError("Group timing table is not set. Please set it before running steps.")
    # Global signal row number at which this step starts (aligned with op_signal_index).
    # Process data relies on this index whether or not tracing is enabled.
    if self.group_merge_type == 'None':
        # one signal row per running group: count all group runs before this step
        base_signal_index = int(self.group_timing_table.iloc[:step_index].values.sum())
    else:
        # merged modes produce exactly one signal row per step
        base_signal_index = step_index
    # Tracing uses the same global signal row number as op_signal_index.
    if self._trace_enabled:
        self._trace_signal_index = base_signal_index
    # print(f'taking step index: {step_index} from group_timing_table with shape {self.group_timing_table.shape}')
    group_timing = self.group_timing_table.iloc[step_index].values
    group_count = len(self.groups)
    # only the groups whose schedule flag is set for this step are run
    groups = [self.groups_by_index[i] for i in range(group_count) if group_timing[i]]
    signal_type = groups[0].signal_type if groups else None
    # neutral start value for merging: 0 for the additive 'Or' merge, 1 for the multiplicative 'And' merge
    # NOTE(review): comparisons here use 'Or' / 'And' (capitalised) - confirm this matches
    # the values stored in group_merge_type.
    signal = 0 if self.group_merge_type == 'Or' else 1
    # DEBUG:
    # print(f'In current op run step, following groups are running: {groups}')
    current_index = base_signal_index
    for group in groups:
        # ---- set up the data window of each strategy for this step
        for strategy in group.members:
            strategy.update_running_data_window(
                    data_windows=self.data_window_views[strategy.strategy_id],
                    window_indices=self.data_window_indices[strategy.strategy_id],
                    window_index=step_index,
            )
        # ---- update the trace step if tracing is enabled (global signal row number)
        if self._trace_enabled:
            for stg in group.members:
                stg.update_trace_step(step=self._trace_signal_index)
        # ---- end setting up data windows
        signal_type = group.signal_type
        # publish the current global signal row number before generating signals,
        # so that process data access can use it
        self._current_signal_index = current_index
        signals = [stg.generate() for stg in group.members]
        if self.group_merge_type == 'None':
            signal = group.blend(signals)
            yield signal_type, step_index, signal
            # each yielded group signal occupies one global signal row
            current_index += 1
            if self._trace_enabled:
                self._trace_signal_index += 1
        elif self.group_merge_type == 'Or':
            signal += group.blend(signals)
        elif self.group_merge_type == 'And':
            signal *= group.blend(signals)
        else:
            raise ValueError(f'Invalid group merge type: {self.group_merge_type}')
    if self.group_merge_type != 'None':
        # In the AND / OR merge modes a step yields exactly one merged signal.
        self._current_signal_index = current_index
        yield signal_type, step_index, signal
        if self._trace_enabled:
            self._trace_signal_index += 1
def run_strategies(self, steps: Iterable) -> Iterable:
    """Run the strategies for a sequence of steps and yield all signals.

    Readiness is checked once with ``is_ready(raise_error=True)`` when
    iteration starts; then every step is passed to ``run_strategy`` and its
    results are yielded in order.

    Parameters
    ----------
    steps : Iterable
        Step indices to run.

    Yields
    ------
    tuple of (signal_type, step_index, signal)
        signal_type : signal type of the strategy group
        step_index : index of the step that produced the signal
        signal : the trade signal; its meaning depends on signal_type
    """
    self.is_ready(raise_error=True)
    for step_index in steps:
        yield from self.run_strategy(step_index)
# ================= High level running functions ===================
def run(self, config, datasource=None, logger=None):
    """Run the operator in the mode selected by ``config['mode']``.

    Parameters
    ----------
    config : dict
        Run configuration. ``config['mode']`` selects the runner:
        0 / 'live' -> ``run_live_trade``, 1 / 'backtest' -> ``run_backtest``,
        2 / 'optimize' -> ``run_optimization``, 3 / 'predict' ->
        ``run_prediction``.
    datasource : DataSource, optional
        Data source object, passed through to the selected runner.
    logger : Logger, optional
        Logger object, passed through to the selected runner.

    Returns
    -------
    object
        Whatever the selected runner returns.

    Raises
    ------
    ValueError
        If ``config['mode']`` is none of the values listed above.
    """
    run_mode = config['mode']
    if run_mode in (0, 'live'):
        # live trading: acquire real-time prices, generate signals and trade
        runner = self.run_live_trade
    elif run_mode in (1, 'backtest'):
        # back test: evaluate the strategies on real history prices
        runner = self.run_backtest
    elif run_mode in (2, 'optimize'):
        # optimization: repeatedly test the strategies to search for the best parameters
        runner = self.run_optimization
    elif run_mode in (3, 'predict'):
        # prediction: evaluate the strategies on randomly generated history data
        runner = self.run_prediction
    else:
        raise ValueError(f'Invalid run mode: {run_mode}')
    return runner(config=config, datasource=datasource, logger=logger)
def run_live_trade(self, config, datasource=None, logger=None):
    """Run the operator in live-trade mode.

    The method resolves the live-trade account, optionally books the initial
    holdings given in the configuration, creates the broker and the Trader,
    refills missing data-source data and finally starts the configured user
    interface (CLI or TUI).

    Parameters
    ----------
    config : dict
        Run configuration; the keys read are visible in the body below.
    datasource : DataSource, optional
        Data source to use; ``qteasy.QT_DATA_SOURCE`` is used when None.
    logger : Logger, optional
        Not used inside this method.

    Raises
    ------
    ValueError
        If ``config['asset_type']`` is not supported in live-trade mode, or
        if ``live_trade_init_holdings`` is given but is not a dict.
    TypeError
        If ``asset_pool`` is neither a str nor a list, or if the configured
        UI type is neither 'cli' nor 'tui'.
    """
    from qteasy.config_parser import (
        parse_trade_cost_params,
    )
    # Live signal generation is only allowed for verified asset types
    # (stocks 'E', exchange-traded funds 'FD').
    _at = config['asset_type']
    if _at not in _LIVE_TRADE_SUPPORTED_ASSET_TYPES:
        allowed = ', '.join(sorted(_LIVE_TRADE_SUPPORTED_ASSET_TYPES))
        raise ValueError(
                f'Live trade mode supports asset_type in {{{allowed}}}, got {_at!r} instead. '
                f'For example: qt.configure(asset_type="E") or asset_type="FD".'
        )
    from qteasy.live_config import build_live_trade_config
    live_cfg = build_live_trade_config(config)
    import qteasy as qt
    ds = datasource if datasource is not None else qt.QT_DATA_SOURCE
    from qteasy.trade_recording import get_or_create_position, resolve_live_trade_account_id, update_position
    init_holdings = config['live_trade_init_holdings']
    account_id = resolve_live_trade_account_id(config, data_source=ds)
    # if init_holdings is not None then add holdings to account
    # NOTE(review): the quantities are applied as changes on every call of this
    # method - confirm that repeated starts should add the holdings again.
    if init_holdings is not None:
        if not isinstance(init_holdings, dict):
            err = ValueError(f'init_holdings must be a dict, got {type(init_holdings)} instead.')
            raise err
        for symbol, amount in init_holdings.items():
            # positive amounts are booked as long positions, all others as short positions
            pos_id = get_or_create_position(
                    account_id=account_id,
                    symbol=symbol,
                    position_type='long' if amount > 0 else 'short',
                    data_source=ds,
            )
            update_position(
                    position_id=pos_id,
                    data_source=ds,
                    **{
                        'qty_change':           abs(amount),
                        'available_qty_change': abs(amount),
                    }
            )
    # if account is ready then create trader and broker
    broker_type = live_cfg.live_trade_broker_type
    # broker parameters from the live config take precedence over the raw config entry
    broker_params = (
        dict(live_cfg.live_trade_broker_params)
        if live_cfg.live_trade_broker_params is not None
        else config['live_trade_broker_params']
    )
    if (broker_type == 'simulator') and (broker_params is None):
        # default simulator settings derived from the trade cost configuration
        broker_params = {
            "fee_rate_buy":    config['cost_rate_buy'],
            "fee_rate_sell":   config['cost_rate_sell'],
            "fee_min_buy":     config['cost_min_buy'],
            "fee_min_sell":    config['cost_min_sell'],
            "slippage":        config['cost_slippage'],
            "moq_buy":         config['trade_batch_size'],
            "moq_sell":        config['sell_batch_size'],
            "delay":           1.0,
            "price_deviation": 0.001,
            "probabilities":   (0.5, 0.45, 0.05),  # originally: (0.9, 0.08, 0.02)
        }
    from qteasy.broker import BrokerFacade, get_broker
    from qteasy.trader import Trader
    broker = get_broker(broker_type, broker_params)
    broker = BrokerFacade(broker)
    cost_params = np.array(
            list(parse_trade_cost_params(config, asset_type=config.get('asset_type')).values()),
            dtype='float',
    )  # trade cost parameters
    # Same as in the back-test path: inject the symbol list before live trading starts,
    # so that share_count / share_names of the strategies are correct.
    raw_asset_pool = config['asset_pool']
    if isinstance(raw_asset_pool, str):
        from qteasy.utilfuncs import str_to_list
        live_asset_pool = str_to_list(raw_asset_pool)
    elif isinstance(raw_asset_pool, list):
        live_asset_pool = raw_asset_pool
    else:
        raise TypeError(
                f'asset_pool should be str or list[str], got {type(raw_asset_pool)} instead'
        )
    self.set_shares(live_asset_pool)
    trader = Trader(
            operator=self,
            account_id=account_id,
            broker=broker,
            datasource=ds,
            asset_pool=live_asset_pool,
            asset_type=config['asset_type'],
            time_zone=config['time_zone'],
            exchange='SSE',
            market_open_time_am=config['market_open_time_am'],
            market_open_time_pm=config['market_open_time_pm'],
            market_close_time_am=config['market_close_time_am'],
            market_close_time_pm=config['market_close_time_pm'],
            live_price_channel=config['live_price_acquire_channel'],
            live_price_freq=config['live_price_acquire_freq'],
            live_data_channel=config['live_trade_data_refill_channel'],
            live_data_batch_size=config['live_trade_data_refill_batch_size'],
            live_data_batch_interval=config['live_trade_data_refill_batch_interval'],
            watched_price_refresh_interval=config['watched_price_refresh_interval'],
            benchmark_asset=config['benchmark_asset'],
            live_sys_logger=None,
            cost_params=cost_params,
            pt_buy_threshold=config['PT_buy_threshold'],
            pt_sell_threshold=config['PT_sell_threshold'],
            allow_sell_short=config['allow_sell_short'],
            trade_batch_size=config['trade_batch_size'],
            sell_batch_size=config['sell_batch_size'],
            long_position_limit=config['long_position_limit'],
            short_position_limit=config['short_position_limit'],
            stock_delivery_period=config['stock_delivery_period'],
            cash_delivery_period=config['cash_delivery_period'],
            open_close_timing_offset=config['strategy_open_close_timing_offset'],
            daily_refill_tables=config['live_trade_daily_refill_tables'],
            weekly_refill_tables=config['live_trade_weekly_refill_tables'],
            monthly_refill_tables=config['live_trade_monthly_refill_tables'],
            debug=False,
            live_config=live_cfg,
    )
    trader.register_broker(debug=trader.debug)
    # find out datasource availabilities, refill data source if table data not available
    from qteasy.trader import refill_missing_datasource_data
    refill_missing_datasource_data(
            operator=self,
            trader=trader,
            datasource=ds,
    )
    # hand control over to the selected user interface
    ui_type = live_cfg.live_trade_ui_type
    if ui_type.lower() == 'cli':
        from .trader_cli import TraderShell
        TraderShell(trader).run()
    elif ui_type.lower() == 'tui':
        from .trader_tui import TraderApp
        TraderApp(trader).run()
    else:
        err = TypeError(f'Invalid ui type: ({ui_type})! use "cli" or "tui" instead.')
        raise err
def run_backtest(self, config, datasource=None, logger=None) -> dict:
    """Run the operator in backtest mode and return the backtest result.

    The method builds the running schedule for the backtest period, derives
    the cash-investment / delivery arrays, prepares history data, trade
    prices, benchmark data and evaluation prices, runs a ``Backtester`` and
    evaluates its result.  Depending on ``config`` it additionally writes
    trade-log files, prints a report and plots the result.

    Parameters
    ----------
    config : dict
        Backtest configuration.  Besides being handed to the ``parse_*``
        helpers, the following keys are read directly here: 'asset_pool',
        'backtest_price_adj', 'benchmark_asset', 'trace_log',
        'test_indicators', 'trade_log', 'report', 'visual',
        'buy_sell_points' and 'show_positions'.
    datasource : DataSource, optional
        Forwarded unchanged to the data preparation helpers.  According to
        the original documentation ``None`` selects the global default data
        source (that fallback is implemented by the helpers, not here).
    logger : Logger, optional
        Forwarded to the ``Backtester``.  Also used to announce the saved
        value-curve file; when ``None`` the 'qteasy' logger is used for that
        message.

    Returns
    -------
    The ``backtest_result`` attribute of the ``Backtester`` that was run.
    NOTE(review): the signature is annotated ``dict`` while the original
    docstring called this a ``BacktestResult`` -- confirm which is accurate.

    Notes
    -----
    The ``Backtester`` instance is stored on ``self.backtested`` so that its
    arrays stay accessible after the call.
    """
    from qteasy.config_parser import (
        parse_backtest_cash_plan,
        parse_backtest_start_end_dates,
        parse_trade_cost_params,
        parse_signal_parsing_params,
        parse_trading_moq_params,
        parse_trading_delivery_params,
    )
    from qteasy.backtest import Backtester, generate_cash_invest_and_delivery_arrays
    from qteasy.history import check_and_prepare_evaluate_price_data

    shares = config['asset_pool']
    start_date, end_date = parse_backtest_start_end_dates(config=config)

    # The running schedule must exist before anything that reads
    # self.group_timing_table / self.op_signal_index below.
    schedule_time_kwargs = self._build_schedule_time_kwargs_from_config(config)
    self.prepare_running_schedule(
        start_date=start_date,
        end_date=end_date,
        **schedule_time_kwargs,
    )

    # Cash investment and delivery arrays, aligned with the timing table.
    invest_cash_plan = parse_backtest_cash_plan(config)
    (cash_investment_array,
     cash_inflation_array,
     delivery_day_indicators) = generate_cash_invest_and_delivery_arrays(
        invest_cash_plan=invest_cash_plan,
        group_merge_type=self.group_merge_type,
        timing_table=self.group_timing_table,
    )

    # NOTE(review): the plan is parsed a second time instead of reusing
    # ``invest_cash_plan``; kept as-is in case the object above is modified
    # by the array generation.
    cash_plan = parse_backtest_cash_plan(config)
    cost_params = np.array(list(parse_trade_cost_params(config).values()), dtype='float')
    signal_parsing_params = parse_signal_parsing_params(config)
    trading_moq_params = parse_trading_moq_params(config)
    trading_delivery_params = parse_trading_delivery_params(config)

    data_package = check_and_prepare_backtest_data(
        op=self,
        backtest_start=start_date,
        backtest_end=end_date,
        shares=shares,
        datasource=datasource,
    )
    trade_prices = check_and_prepare_trade_prices(
        op=self,
        shares=shares,
        price_adj=config['backtest_price_adj'],
        datasource=datasource,
    )
    # Align trade prices with every signal time point.  NaN prices (e.g. a
    # suspended share) are deliberately left unfilled: per the original
    # comment the backtest engine skips trading on NaN prices.
    # (A former ``if any(np.isnan(trade_prices)): pass`` branch was removed:
    # it had no effect and its condition iterated column labels.)
    trade_prices = trade_prices.reindex(index=self.op_signal_index.get_level_values(0))

    hist_benchmark = check_and_prepare_benchmark_data(
        op=self,
        benchmark_symbol=config['benchmark_asset'],
        datasource=datasource,
        backtest_start=start_date,
        backtest_end=end_date,
    )
    evaluate_price_data = check_and_prepare_evaluate_price_data(
        op=self,
        shares=shares,
        datasource=datasource,
        backtest_start=start_date,
        backtest_end=end_date,
        backtest_price_adj=config['backtest_price_adj'],
    )

    self.prepare_data_buffer(
        start_date=start_date,
        end_date=end_date,
        data_package=data_package,
    )
    self.create_data_windows()

    # Generate the trades, backtest them and keep the Backtester around so
    # its arrays (own amounts, cash, trade prices, trade costs, ...) can be
    # inspected by tests or later processing.
    backtested = Backtester(
        op=self,
        shares=shares,
        cash_plan=cash_plan,
        benchmark_data=hist_benchmark,
        evaluate_price_data=evaluate_price_data,
        cash_investment_array=cash_investment_array,
        cash_inflation_array=cash_inflation_array,
        delivery_day_indicators=delivery_day_indicators,
        cost_params=cost_params,
        signal_parsing_params=signal_parsing_params,
        trading_moq_params=trading_moq_params,
        trading_delivery_params=trading_delivery_params,
        trade_price_data=trade_prices.values,
        enable_tracing=config['trace_log'],
        logger=logger,
    ).run()
    self.backtested = backtested

    backtested.evaluate_result(
        indicators=config['test_indicators'],
    )

    backtest_datetime = pd.Timestamp.now().strftime("%Y%m%d_%H%M%S")

    if config['trade_log']:
        from qteasy import QT_TRADE_LOG_PATH

        def _log_file_path(prefix):
            # <prefix>_<operator name>_<timestamp>.csv inside the trade-log folder
            file_name = sanitize_filename(f'{prefix}_{self.name}_{backtest_datetime}.csv')
            return os.path.join(QT_TRADE_LOG_PATH, file_name)

        backtested.generate_trade_logs(save_to_file_path=_log_file_path('trade_log'))
        backtested.generate_trade_summary(save_to_file_path=_log_file_path('trade_summary'))
        # NOTE(review): the value-curve export is kept inside this branch
        # because QT_TRADE_LOG_PATH is only imported here -- confirm.
        saved_value_curve = backtested.save_complete_values(
            save_to_file_path=_log_file_path('value_curve'),
        )
        if saved_value_curve is not None:
            qteasy_log = logger if logger is not None else logging.getLogger('qteasy')
            qteasy_log.info(
                'value curve (complete_values) saved to %s', saved_value_curve
            )

    if config['report']:
        # Formatted text report of the backtest result.
        print(backtested.report_result())

    if config['visual']:
        # Chart of the investment return history.
        if self.name:
            plot_title = f'Backtest Report: {self.name} - {backtest_datetime}'
        else:
            plot_title = f'Backtest Report - {backtest_datetime}'
        backtested.plot_result(
            plot_title=plot_title,
            buy_sell_markers=config['buy_sell_points'],
            show_positions=config['show_positions'],
        )

    return backtested.backtest_result
def run_optimization(self, config, datasource=None, logger=None) -> dict:
    """Run the operator in optimization mode and return the result pool.

    The run has two stages that share one ``Optimizer``:

    1. optimization -- schedule, data buffer and data windows are prepared
       for the optimization period and ``optimizer.optimize`` is called;
    2. validation -- the same preparation is repeated for the test period,
       ``self.is_ready(raise_error=True)`` is checked and
       ``optimizer.validate`` is called.

    Progress messages are written with ``print``.

    Parameters
    ----------
    config : dict
        Optimization configuration.  Besides being handed to the ``parse_*``
        helpers, the following keys are read directly here: 'asset_pool',
        'opti_method', 'benchmark_asset', 'opti_output_count',
        'optimize_target', 'optimize_direction', 'parallel',
        'test_indicators', 'indicator_plot_type', 'backtest_price_adj',
        'report' and 'visual'.
    datasource : DataSource, optional
        Forwarded unchanged to the data preparation helpers.
    logger : Logger, optional
        Forwarded to the ``Optimizer``.

    Returns
    -------
    The ``result_pool`` attribute of the ``Optimizer``.
    NOTE(review): annotated as ``dict`` -- confirm against ``Optimizer``.

    Raises
    ------
    AssertionError
        If ``self.opt_space_par[0]`` equals ``[]``, i.e. no strategy
        parameter is adjustable.
    """
    from qteasy.config_parser import (
        parse_trade_cost_params,
        parse_signal_parsing_params,
        parse_trading_moq_params,
        parse_trading_delivery_params,
        parse_optimization_start_end_dates,
        parse_optimization_cash_plan,
        parse_all_optimization_params
    )
    from qteasy.optimization import Optimizer

    shares = config['asset_pool']

    # Start / end dates of the optimization period and of the test period.
    opti_start, opti_end, test_start, test_end = parse_optimization_start_end_dates(config=config)
    opti_cash_plan, test_cash_plan = parse_optimization_cash_plan(config)

    # Trade related parameters.
    cost_params = np.array(list(parse_trade_cost_params(config).values()), dtype='float')
    signal_parsing_params = parse_signal_parsing_params(config)
    trading_moq_params = parse_trading_moq_params(config)
    trading_delivery_params = parse_trading_delivery_params(config)

    # At least one strategy parameter must be adjustable.  Raised explicitly
    # (instead of via ``assert``) so the check survives ``python -O``; the
    # exception type and message are the ones the assert produced.
    has_adjustable_params = self.opt_space_par[0] != []
    if not has_adjustable_params:
        raise AssertionError(
            'ConfigError, none of the strategy parameters is adjustable, set opt_tag to be 1 or 2 to '
            'activate optimization in mode 2, and make sure strategy has adjustable parameters'
        )

    optimization_config = parse_all_optimization_params(config=config)

    opti_data_package = check_and_prepare_backtest_data(
        op=self,
        backtest_start=opti_start,
        backtest_end=opti_end,
        shares=shares,
        datasource=datasource,
    )
    test_data_package = check_and_prepare_backtest_data(
        op=self,
        backtest_start=test_start,
        backtest_end=test_end,
        shares=shares,
        datasource=datasource,
    )

    optimizer = Optimizer(
        op=self,
        method=config['opti_method'],
        shares=shares,
        benchmark=config['benchmark_asset'],
        pool_size=config['opti_output_count'],
        opti_target=config['optimize_target'],
        opti_direction=config['optimize_direction'],
        parallel=config['parallel'],
        search_config=optimization_config,
        opti_start_date=opti_start,
        opti_end_date=opti_end,
        test_start_date=test_start,
        test_end_date=test_end,
        opti_cash_plan=opti_cash_plan,
        test_cash_plan=test_cash_plan,
        cost_params=cost_params,
        signal_parsing_params=signal_parsing_params,
        trading_moq_params=trading_moq_params,
        trading_delivery_params=trading_delivery_params,
        logger=logger,
        evaluate_indicators=config['test_indicators'],
        test_plot_type=config['indicator_plot_type'],
    )

    # ======== stage 1: optimization ========
    schedule_time_kwargs = self._build_schedule_time_kwargs_from_config(config)
    print(f'Preparing optimization data from {opti_start} to {opti_end}...')
    self.prepare_running_schedule(
        start_date=opti_start,
        end_date=opti_end,
        **schedule_time_kwargs,
    )
    self.prepare_data_buffer(
        start_date=opti_start,
        end_date=opti_end,
        data_package=opti_data_package,
    )
    self.create_data_windows()
    opti_trade_prices = check_and_prepare_trade_prices(
        op=self,
        shares=shares,
        price_adj=config['backtest_price_adj'],
        datasource=datasource,
    )
    opti_benchmark = check_and_prepare_benchmark_data(
        op=self,
        benchmark_symbol=config['benchmark_asset'],
        datasource=datasource,
        backtest_start=opti_start,
        backtest_end=opti_end,
    )
    optimizer.optimize(
        benchmark_data=opti_benchmark,
        trade_price_data=opti_trade_prices.values,
    )
    print('Optimization finished, best parameters:\n')
    if config['report']:
        print(optimizer.report_result(stage='optimization'))

    # ======== stage 2: validation ========
    # The schedule, buffer and windows are rebuilt for the test period.
    print(f'Preparing test data from {test_start} to {test_end}...')
    self.prepare_running_schedule(
        start_date=test_start,
        end_date=test_end,
        **schedule_time_kwargs,
    )
    print('preparing data buffer...')
    self.prepare_data_buffer(
        start_date=test_start,
        end_date=test_end,
        data_package=test_data_package,
    )
    print('creating data windows...')
    self.create_data_windows()
    # Per the original comment, is_ready() collects every outstanding
    # problem and raises if the operator is not ready.
    self.is_ready(raise_error=True)
    print('Preparing trade prices for test...')
    test_trade_prices = check_and_prepare_trade_prices(
        op=self,
        shares=shares,
        price_adj=config['backtest_price_adj'],
        datasource=datasource,
    )
    print('Preparing benchmark data...')
    test_benchmark = check_and_prepare_benchmark_data(
        op=self,
        benchmark_symbol=config['benchmark_asset'],
        datasource=datasource,
        backtest_start=test_start,
        backtest_end=test_end,
    )
    print('Starting validation on test data...')
    optimizer.validate(
        benchmark_data=test_benchmark,
        trade_price_data=test_trade_prices.values,
    )
    if config['report']:
        print(optimizer.report_result(stage='validation'))
    if config['visual']:
        optimizer.plot_result()

    return optimizer.result_pool
def run_prediction(self, config, datasource=None, logger=None) -> dict:
    """Run the operator in prediction mode (not implemented).

    Parameters
    ----------
    config : dict
        Accepted for signature parity with the other ``run_*`` methods; unused.
    datasource : DataSource, optional
        Unused.
    logger : Logger, optional
        Unused.

    Raises
    ------
    NotImplementedError
        Always.
    """
    raise NotImplementedError('prediction mode is not implemented yet')