# coding=utf-8
# ======================================
# File: strategy.py
# Author: Jackie PENG
# Contact: jackie.pengzhao@gmail.com
# Created: 2020-09-27
# Desc:
# Strategy Base Classes and its derived
# Classes.
# ======================================
import numpy as np
import pandas as pd
from abc import abstractmethod, ABCMeta
from typing import Union, List, Tuple, Dict, Any, Callable, Literal, Iterable
import warnings
from qteasy.utilfuncs import (
TIME_FREQ_STRINGS,
input_to_list, str_to_list,
)
from qteasy.datatypes import (
DataType,
StgData,
)
from qteasy.parameter import Parameter
def _dict_par_format_is_valid(par_name: str, pars, value_type, key_type):
"""检查字典参数的键和值是否符合约定格式。
Parameters
----------
par_name : str
参数名,用于错误信息提示。
pars : Any
待检查的对象,期望为 dict。
value_type : type
dict 的 value 期望类型。
key_type : str
期望 value 对象上存在的属性名,该属性需为 str,用于与 dict 的 key 对齐。
Returns
-------
bool
检查通过返回 True。
Raises
------
AssertionError
当 ``pars`` 非 dict、value 类型不匹配、或 key 与 value 对象属性不一致时抛出(英文)。
"""
assert isinstance(pars, dict), f'parameter "{par_name}" is invalid, please check your input'
assert all(isinstance(dtype, value_type) for dtype in pars.values()), \
f'parameter "{par_name}" should be a dict of {value_type} objects, got {pars} instead'
assert all(isinstance(dtype.__getattribute__(key_type), str) for dtype in pars.values()), \
f'parameter "{par_name}" should be a dict of {value_type} objects with {key_type} as key, ' \
f'got {pars} instead'
assert all(key == dtype.__getattribute__(key_type) for key, dtype in pars.items()), \
f'parameter "{par_name}" should be a dict of {value_type} objects with {key_type} as key, ' \
f'got {pars} instead'
return True
[文档]class BaseStrategy:
"""量化投资策略的抽象基类,所有具体策略都应从本类继承并实现交易信号生成逻辑。
一个完整的策略通常由三部分组成:可调参数(pars)、所需历史数据的声明
(data_types + window_length 等)以及基于这些数据生成信号的 ``realize()`` 逻辑。
在策略运行过程中,可通过 ``get_pars()`` 和 ``get_data()`` 访问参数与数据,输出
的实数信号再由 Operator 按 PT/PS/VS 信号语义解析为实际交易指令。关于自定义策略
实现步骤与 PT/PS/VS 的详细说明,见文档「交易策略与 BaseStrategy」相关章节。
Examples
--------
BaseStrategy 为抽象基类,通常通过继承实现自定义策略。下面示例展示其类型信息(稳定输出):
>>> import qteasy as qt
>>> qt.BaseStrategy.__name__
'BaseStrategy'
"""
__metaclass__ = ABCMeta
AVAILABLE_STG_RUN_TIMING = ['open', 'close', 'unit_nav', 'accum_nav']
def __init__(
        self,
        *,
        name: str = '',
        description: str = '',
        stg_type: str = 'BASE',
        pars: Union[Parameter, List[Parameter], Dict[str, Parameter]] = None,
        data_types: Union[DataType, List[DataType], Dict[str, DataType]] = None,
        use_latest_data_cycle: Union[bool, List[bool], Dict[str, bool]] = False,
        window_length: Union[int, List[int], Dict[str, int]] = 270,
        opt_tag: int = 0,
        par_values: Union[Tuple[Any], List[Any]] = None,
):
    """Initialize the strategy.

    Parameters
    ----------
    name: str, default ''
        User-defined strategy name, used to tell strategies apart. Stored as str.
    description: str, default ''
        User-defined description of the strategy. Stored as str.
    stg_type: str, default 'BASE'
        User-defined strategy type label, e.g. moving-average or trend-following strategy.
    pars: Parameter, list of Parameter, dict {str: Parameter}, default None
        Tunable parameters of the strategy, defining parameter types and value ranges.
        Validated and stored with ``set_pars()``.
    data_types: DataType, list of DataType, dict {str: DataType}, default None
        Data types used by the strategy, stored with ``set_data_types()``.
    use_latest_data_cycle: bool, list of bool, dict {str: bool}, default False
        Whether to use the latest (current) data cycle to generate trading signals.
        If True:
            In live runs the system tries to download the latest data of the current
            cycle, or estimate it from recent real-time data. Beware of look-ahead bias:
            e.g. when running at market open, close/high/low prices must not be used.
            In backtests the latest data at the backtest time point is used; again avoid
            look-ahead bias, e.g. using the current close when running at the open.
        If False (default):
            Both in backtests and live runs only the already-known data up to the previous
            cycle is used. Signals may lag at low run frequencies, but look-ahead bias is
            avoided.
        ``StgData`` entries in ``data_types`` may override this value per data type.
    window_length: int, list of int, dict {str: int}, default 270
        Length of the historical data window used by the strategy. ``StgData`` entries in
        ``data_types`` may override this value per data type.
    opt_tag: int, default 0
        Optimization tag: 0 means the strategy does not take part in optimization, 1 means
        it does. ``set_opt_tag()`` accepts integers in the range [0, 2].
    par_values: tuple or list, default None
        Initial parameter values in the order of ``pars``; applied with
        ``update_par_values()`` when non-empty.

    Returns
    -------
    None
    """
    # validate and store the basic strategy attributes
    from qteasy import logger_core
    logger_core.info(f'initializing new Strategy: type: {stg_type}, name: {name}, text: {description}')
    self._stg_name = str(name)
    self._stg_description = str(description)
    self._pars = None
    self.set_pars(pars)  # set_pars() also validates the parameters
    self._opt_tag = None
    self.set_opt_tag(opt_tag)  # optimization tag, validated by set_opt_tag()
    self._stg_type = stg_type  # strategy type
    if par_values:
        self.update_par_values(*par_values)
    logger_core.info(f'Strategy created with basic parameters set, pars={pars}, par_count={self.par_count},'
                     f' par_types={self.par_types}, par_range={self.par_range}')
    self._data_types = None
    self._data_ids = None  # list of the dtype_ids of all data types
    self._data_ULC = None  # dict {dtype_id: use-latest-data-cycle flag}
    self._data_WL = None  # dict {dtype_id: window length}
    self.set_data_types(data_types, use_latest_data_cycle, window_length)
    logger_core.info(f'Strategy data types set:\n'
                     f'data_types={self.data_types}, data_ids={self._data_ids}, ')
    # dynamic attributes produced while the strategy runs
    self._share_names = None
    self._strategy_id = None  # unique strategy ID, assigned by the system at run time
    self._group_id = None  # deprecated: ID of the strategy group (a group holds several strategies)
    self._group = None  # owning strategy group, assigned with self.assign_group()
    # tracing -- users may record variables in realize(); the values are kept in _trace_data
    self._trace_enabled = False  # whether tracing is enabled
    self._trace_data = {}  # {trace variable name: recorded values}, used to build a trace DataFrame
    self._trace_max_steps = 0  # max number of runs in one session; equals the rows of run_schedule
    self._trace_step = 0
    # step index of each record: when several strategy groups run, this strategy does not run
    # at every time point, so trace_step is its position in the whole schedule, keeping the
    # recorded values aligned with the correct run time points
    # other settable attributes
    self.debug = False  # debug mode
    self.trace_mode = False  # trace mode
    self.logger = None  # logger of the strategy
    # set by the Operator at run time, lets the strategy access the Operator (e.g. process data)
    self._operator = None
@property
def name(self):
    """Strategy name; it is printed whenever strategy information is shown."""
    return self._stg_name

@name.setter
def name(self, name: str):
    # cast to str, consistent with __init__ and the description setter
    self._stg_name = str(name)
@property
def strategy_id(self):
    """Unique ID of the strategy, assigned by the system when the strategy runs."""
    return self._strategy_id

@property
def group(self):
    """Strategy group this strategy belongs to, or None until one is assigned."""
    return self._group

@property
def group_id(self):
    """ID of the owning strategy group, i.e. the group's ``name``.

    Assigned by the system after the strategy is added to an Operator; None before that.
    """
    if self._group is None:
        return None
    return self.group.name

@property
def description(self):
    """Descriptive text briefly introducing how the strategy works."""
    return self._stg_description

@description.setter
def description(self, description: str):
    text = str(description)
    self._stg_description = text

@property
def stg_type(self):
    """Strategy type label, indicating the base class of the strategy, e.g.:

    - GeneralStg: GENERAL
    - FactorSorter: FACTOR
    - RuleIterator: RULE-ITER
    """
    return self._stg_type
@property
def has_pars(self) -> bool:
    """True if the strategy has tunable parameters, otherwise False."""
    # bool() treats both an empty dict and an unset (None) parameter dict as "no parameters"
    return bool(self.pars)

@property
def pars(self):
    """Tunable parameters of the strategy, as a dict {parameter name: Parameter}."""
    return self._pars

@pars.setter
def pars(self, new_pars):
    """Replace the tunable parameters; validation is performed by set_pars()."""
    self.set_pars(new_pars)
@property
def par_count(self):
    """Number of tunable parameters of the strategy (0 when parameters are not set)."""
    return len(self._pars) if self._pars is not None else 0

@property
def par_values(self) -> tuple:
    """Current parameter values, in parameter order.

    Returns
    -------
    tuple or None
        One value per parameter (an empty tuple when the strategy has no parameters),
        or None if the parameter dict has not been set.
    """
    return tuple(par.value for par in self._pars.values()) if self._pars is not None else None

@par_values.setter
def par_values(self, pars: tuple):
    """Set parameter values in order; validation is performed by update_par_values()."""
    self.update_par_values(*pars)

@property
def par_names(self):
    """List of parameter names (each Parameter's ``name``); empty when parameters are not set."""
    return [par.name for par in self._pars.values()] if self._pars is not None else []

@property
def par_types(self):
    """Dict {parameter name: par_type} of the strategy parameters; empty when not set."""
    return {name: par.par_type for name, par in (self._pars or {}).items()}

@property
def par_range(self):
    """Dict {parameter name: par_range}, the parameter space used for optimization; empty when not set."""
    return {name: par.par_range for name, par in (self._pars or {}).items()}
@property
def opt_tag(self):
    """Optimization tag of the strategy, an int validated by set_opt_tag()."""
    tag = self._opt_tag
    return tag

@opt_tag.setter
def opt_tag(self, opt_tag):
    # validation lives in set_opt_tag(), delegate to it
    self.set_opt_tag(opt_tag=opt_tag)
def _group_or_raise(self, attr_name: str):
    """Return the owning group, or raise AttributeError naming ``attr_name`` if there is none."""
    owner = self._group
    if owner is None:
        raise AttributeError(
            f"Strategy must be added to an Operator (and thus a Group) before {attr_name} can be accessed. "
            "Use op.add_strategy(stg, run_freq='d', run_timing='close') to add strategy to operator."
        )
    return owner

@property
def run_freq(self):
    """Run frequency of the strategy, read from its Group.

    The strategy must be added to an Operator before this property can be read.
    """
    return self._group_or_raise('run_freq').run_freq

@property
def run_timing(self):
    """Run timing of the strategy, read from its Group.

    The run timing determines when the strategy runs live, and the price type used in
    backtests. The strategy must be added to an Operator before this property can be read.
    """
    return self._group_or_raise('run_timing').run_timing
@property
def data_type_count(self):
    """Number of historical data types the strategy depends on."""
    return len(self.data_types)

@property
def data_types(self):
    """Historical data types the strategy depends on, as stored by set_data_types()."""
    return self._data_types

@data_types.setter
def data_types(self, data_types: Union[DataType, List[DataType], Dict[str, DataType]]):
    """Replace the data types; a single StgData supplies its own ULC flag and window length."""
    is_stg_data = isinstance(data_types, StgData)
    # NOTE(review): the fallback window length 30 differs from the __init__ default 270 — confirm
    self.set_data_types(
        data_types,
        use_latest_data_cycle=data_types.use_latest_data_cycle if is_stg_data else False,
        window_length=data_types.window_length if is_stg_data else 30,
    )
@property
def data_type_ids(self):
    """List of the dtype_ids of the historical data types the strategy depends on."""
    return self._data_ids

@property
def data_ids(self):
    """Alias of data_type_ids: list of the dtype_ids of the strategy's data types."""
    return self._data_ids

@property
def data_names(self):
    """Dict {dtype_id: name} of the historical data types the strategy depends on."""
    return {dtype_id: dtype.name for dtype_id, dtype in self._data_types.items()}

@property
def data_freqs(self):
    """Dict {dtype_id: freq} of the historical data types the strategy depends on."""
    return {dtype_id: dtype.freq for dtype_id, dtype in self._data_types.items()}

@property
def data_ulc(self):
    """Dict {dtype_id: use-latest-data-cycle flag} of the strategy's data types."""
    return self._data_ULC

@property
def data_window_lengths(self):
    """Dict {dtype_id: window length} of the strategy's data types."""
    return self._data_WL

@property
def window_lengths(self):
    """Alias of data_window_lengths: dict {dtype_id: window length}."""
    return self._data_WL

@property
def max_window_length(self):
    """Largest window length among all data types of the strategy; 0 when there are none."""
    return max(self._data_WL.values()) if self._data_WL else 0
@property
def share_count(self):
    """Runtime attribute: number of shares the strategy runs on; 0 until shares are set."""
    names = self._share_names
    return 0 if names is None else len(names)

@property
def share_names(self):
    """Runtime attribute: list of share names; an empty list (with a warning) until set."""
    if self._share_names is not None:
        return self._share_names
    warnings.warn('share_names is not set, please initialize the strategy first')
    return []
def get_use_latest_data_cycle(self, data_type: str = None) -> bool:
    """Return the use-latest-data-cycle flag of the data type whose dtype_id is ``data_type``."""
    ulc_by_id = self._data_ULC
    return ulc_by_id[data_type]

def get_data_ulc(self, data_type: str = None) -> bool:
    """Return the use-latest-data-cycle flag of data type ``data_type`` (same as get_use_latest_data_cycle)."""
    ulc_by_id = self._data_ULC
    return ulc_by_id[data_type]

def get_window_length(self, data_type: str = None) -> int:
    """Return the historical window length of the data type whose dtype_id is ``data_type``."""
    wl_by_id = self._data_WL
    return wl_by_id[data_type]

def get_data_name(self, data_type: str = None) -> str:
    """Return the name of the data type whose dtype_id is ``data_type``."""
    names_by_id = self.data_names
    return names_by_id[data_type]
def __str__(self):
    """Return a short summary: the strategy type and name."""
    return 'Strategy {}({})'.format(self.stg_type, self.name)

def __repr__(self):
    """Return a representation with the strategy type, name and parameter values.

    Returns
    -------
    str
    """
    return '{}({}, {})'.format(self._stg_type, self.name, self.par_values)
def info(self, verbose: bool = False, status: bool = False, stg_id: str = None, extra_info: str = None) -> None:
    """Print the main information and attributes of the strategy.

    Parameters
    ----------
    verbose: bool, default False
        If True, print the description plus detailed parameter and data-type tables;
        otherwise print a one-line summary of parameters and data types.
    status: bool, default False
        Whether to print the running status of the strategy (not used by this method).
    stg_id: str, default None
        Identifier shown in the title bar; the strategy name is used when None.
    extra_info: str, default None
        Any extra text. Printed once, right after the main strategy line and before the
        parameter and data-type information.

    Returns
    -------
    None
    """
    from rich import print as rprint
    from shutil import get_terminal_size
    from .utilfuncs import adjust_string_length

    if stg_id is None:
        stg_id = self.name
    term_width = get_terminal_size().columns
    info_width = int(term_width * 0.75) if term_width > 120 else term_width
    key_width = max(24, int(info_width * 0.3))
    value_width = max(7, info_width - key_width)

    stg_title = f' Strategy: {stg_id} '
    rprint(f'{stg_title:=^{info_width}}')
    if verbose:
        rprint(f'{self.__str__()}: {self.description}')
    else:
        rprint(self.__str__())
    # extra info is printed exactly once, between the main line and the parameter/data
    # sections, in both verbose and summary mode
    if extra_info:
        rprint(extra_info)

    if not verbose:
        par_info = f'{self.par_names} = {self.par_values}'
        dtype_info = ', '.join([f'{dtype} x {window}' for dtype, window in self.window_lengths.items()])
        rprint(f'Parameters: {par_info:<{value_width}}')
        rprint(f'Data Types: {dtype_info:<{value_width}}')
        return

    # table of all tunable parameters
    par_name_width = int(info_width * .1)
    par_type_width = int(info_width * .2)
    par_range_width = int(info_width * .2)
    par_value_width = int(info_width * .5)
    rprint(f'{" Parameters ":-^{info_width}}\n'
           f'{"name":<{par_name_width}}'
           f'{"type":<{par_type_width}}'
           f'{"range":<{par_range_width}}'
           f'{"value":<{par_value_width}}')
    for par_name, par in self.pars.items():
        rprint(
            f'{adjust_string_length(par_name, par_name_width) :<{par_name_width}}'
            f'{adjust_string_length(par.par_type, par_type_width) :<{par_type_width}}'
            f'{adjust_string_length(str(par.par_range), par_range_width) :^{par_range_width}}'
            f'{adjust_string_length(str(par.value), par_value_width) :^{par_value_width}}'
        )
    # table of all data types
    dtype_id_width = int(info_width * .2)
    window_width = int(info_width * .1)
    ulc_width = int(info_width * .2)
    description_width = int(info_width * .5)
    rprint(f'{" Data Types ":-^{info_width}}\n'
           f'{"id":<{dtype_id_width}}'
           f'{"window":<{window_width}}'
           f'{"use latest":<{ulc_width}}'
           f'{"description":<{description_width}}')
    for dtype_id, dtype in self.data_types.items():
        rprint(
            f'{adjust_string_length(dtype_id, dtype_id_width) :<{dtype_id_width}}'
            f'{adjust_string_length(str(self.get_window_length(dtype_id)), window_width) :<{window_width}}'
            f'{adjust_string_length(str(self.get_data_ulc(dtype_id)), ulc_width) :^{ulc_width}}'
            f'{adjust_string_length(str(dtype.description), description_width, hans_aware=True) :^{description_width}}'
        )
def set_pars(self, pars: Union[Parameter, List[Parameter], Dict[str, Parameter]]) -> bool:
    """Set the tunable parameters of the strategy (the definitions; values come from them).

    Every parameter is also exposed as an instance attribute named after the parameter,
    holding the parameter's current value.

    Parameters
    ----------
    pars: None, Parameter, list or tuple of Parameter, dict {str: Parameter}
        Parameters to set; None means no parameters. For a single Parameter or a
        list/tuple the key is each parameter's ``name``; for a dict, each Parameter's
        ``name`` is overwritten by its key.

    Returns
    -------
    bool
        True when the parameters are set successfully.

    Raises
    ------
    TypeError
        If ``pars`` or any of its elements is not a Parameter.
    ValueError
        If the normalized parameter dict fails the format check.
    """
    if pars is None:
        pars = {}
    elif isinstance(pars, Parameter):
        pars = {pars.name: pars}
    elif isinstance(pars, (list, tuple)):
        for par in pars:
            if not isinstance(par, Parameter):
                raise TypeError(f'pars should be a list of Parameter objects, got {type(par)} instead')
        pars = {par.name: par for par in pars}
    elif isinstance(pars, dict):
        for key, par in pars.items():
            if not isinstance(par, Parameter):
                raise TypeError(f'pars should be a dict of Parameter objects, got {type(par)} instead')
            # keep each Parameter's name in sync with the key it is stored under
            if key != par.name:
                par.name = key
    else:
        raise TypeError(f'pars should be a list or a dict of Parameter Objects! got {type(pars)} instead')
    if not _dict_par_format_is_valid('pars', pars, Parameter, 'name'):
        raise ValueError(f'pars is invalid! ({pars})')
    # keys equal Parameter names at this point; store a copy of the mapping
    self._pars = dict(pars)
    for name, par in self._pars.items():
        setattr(self, name, par.value)
    return True
def get_pars(self, *par_names):
    """Return parameter value(s) by name; several names may be requested at once.

    One name returns its value, several names return a tuple of values. The lookup is
    shared with data access (via _get_pars_or_data), so dtype_ids are accepted as well.
    """
    requested = par_names
    return self._get_pars_or_data(*requested)
def _get_pars_or_data(self, *names: str):
    """Return the value(s) of parameters or data by parameter name / dtype_id.

    Names are validated against ``par_names`` and ``data_type_ids``; the values are read
    from the instance attributes of the same name (set by set_pars / set_data_types /
    update_running_data_window).

    Returns
    -------
    Any or tuple
        The value for a single name, or a tuple of values for several names.

    Raises
    ------
    ValueError
        If no name is given.
    KeyError
        If any name is neither a parameter name nor a data type id of the strategy.
    """
    if not names:
        raise ValueError('at least one parameter name or data type id must be provided')
    par_names = self.par_names
    data_ids = self.data_type_ids or []
    undefined_names = [name for name in names if name not in par_names and name not in data_ids]
    if undefined_names:
        raise KeyError(f'names {undefined_names} not defined in strategy {self}')
    values = tuple(getattr(self, name) for name in names)
    return values if len(values) > 1 else values[0]
def set_data_types(self,
                   data_types: Union[DataType, List[DataType], Dict[str, DataType]],
                   use_latest_data_cycle,
                   window_length) -> None:
    """Set the historical data types of the strategy, with their ULC flags and window lengths.

    Parameters
    ----------
    data_types: None, DataType, list/tuple of DataType, dict {str: DataType}
        Data types used by the strategy. For a single DataType or a list/tuple the key is
        each ``dtype_id``; for a dict each DataType's id is overwritten by its key. A dict
        input is shallow-copied, so later updates do not modify the caller's dict.
    use_latest_data_cycle: bool, list/tuple of bool, dict {str: bool}, None
        Whether each data type uses the latest data cycle to generate signals (by default
        only data up to the previous cycle is used). Lists are normalized with
        ``input_to_list(value, n, False)``; dicts update a default of False; None means
        False for all. ``StgData`` entries with a non-None ``use_latest_data_cycle``
        override this value.
    window_length: int/float, list/tuple of int, dict {str: int}, None
        Historical window length of each data type. Numbers must be positive and are
        truncated to int; lists are normalized with ``input_to_list(value, n, 20)``; dicts
        update a default of 20; None means 20 for all. ``StgData`` entries with a non-None
        ``window_length`` override this value.

    Returns
    -------
    None

    Raises
    ------
    TypeError
        If an argument or a data type element has an unsupported type.
    ValueError
        If a window length is not positive, or the data types fail the format check.
    """
    # 1. normalize data_types into a dict {dtype_id: DataType}
    if data_types is None:
        data_types = {}
    elif isinstance(data_types, DataType):
        data_types = {data_types.dtype_id: data_types}
    elif isinstance(data_types, (list, tuple)):
        for dtype in data_types:
            if not isinstance(dtype, DataType):
                raise TypeError(f'data_types should be a list of DataType objects, got {type(dtype)} instead')
        data_types = {dtype.dtype_id: dtype for dtype in data_types}
    elif isinstance(data_types, dict):
        for key, dtype in data_types.items():
            if not isinstance(dtype, DataType):
                raise TypeError(f'data_types should be a dict of DataType objects, got {type(dtype)} instead')
            # make each DataType's id agree with the key it is stored under
            if key != dtype.dtype_id:
                dtype._dtype_id = key
        # shallow copy: later in-place updates (e.g. update_data_types) must not modify
        # the dict passed in by the caller
        data_types = dict(data_types)
    else:
        raise TypeError(f'data_types is invalid! ({data_types})')
    if not _dict_par_format_is_valid('data_types', data_types, DataType, 'dtype_id'):
        raise ValueError(f'data_types is invalid! ({data_types})')
    self._data_types = data_types
    self._data_ids = list(data_types)
    dtype_count = len(data_types)

    # 2. use-latest-data-cycle flags, one per dtype_id
    if isinstance(use_latest_data_cycle, bool):
        self._data_ULC = {d_id: use_latest_data_cycle for d_id in self._data_ids}
    elif isinstance(use_latest_data_cycle, (list, tuple)):
        ulcs = input_to_list(use_latest_data_cycle, dtype_count, False)
        self._data_ULC = {self._data_ids[i]: ulcs[i] for i in range(len(ulcs))}
    elif isinstance(use_latest_data_cycle, dict):
        self._data_ULC = {d_id: False for d_id in self._data_ids}
        self._data_ULC.update(use_latest_data_cycle)
    elif use_latest_data_cycle is None:
        self._data_ULC = {d_id: False for d_id in self._data_ids}
    else:
        raise TypeError(f'parameter "use_latest_data_cycles" is invalid ({use_latest_data_cycle}), '
                        f'please check your input')
    # StgData entries carrying their own ULC flag override the values above
    for d_id, dtype in data_types.items():
        if not isinstance(dtype, StgData) or dtype.use_latest_data_cycle is None:
            continue
        if not isinstance(dtype.use_latest_data_cycle, bool):
            raise TypeError(f'use_latest_data_cycle should be a boolean, '
                            f'got {dtype.use_latest_data_cycle} instead')
        self._data_ULC[d_id] = dtype.use_latest_data_cycle

    # 3. window lengths, one per dtype_id
    if isinstance(window_length, (int, float)):
        if window_length <= 0:
            raise ValueError(f'window_length should be a positive integer, got {window_length} instead')
        window_length = int(window_length)
        self._data_WL = {d_id: window_length for d_id in self._data_ids}
    elif isinstance(window_length, (list, tuple)):
        wls = input_to_list(window_length, dtype_count, 20)
        self._data_WL = {self._data_ids[i]: wls[i] for i in range(len(wls))}
    elif isinstance(window_length, dict):
        self._data_WL = {d_id: 20 for d_id in self._data_ids}
        self._data_WL.update(window_length)
    elif window_length is None:
        self._data_WL = {d_id: 20 for d_id in self._data_ids}
    else:
        raise TypeError(f'parameter "window_length" is invalid ({window_length}), please check your input')
    # StgData entries carrying their own window length override the values above
    for d_id, dtype in data_types.items():
        if not isinstance(dtype, StgData) or dtype.window_length is None:
            continue
        if not isinstance(dtype.window_length, int) or dtype.window_length <= 0:
            raise ValueError(f'window_length should be a positive integer, got '
                             f'{dtype.window_length}({type(dtype.window_length)}) instead')
        self._data_WL[d_id] = dtype.window_length

    # 4. one attribute per data type, filled with data at run time
    for d_id in data_types:
        setattr(self, d_id, None)
def get_data(self,
             *dtype_id: str,
             lag: Union[int, str, None] = None,
             window: Union[str, None] = None):
    """Get historical data or trading-process data by dtype_id.

    Regular historical data (ids without the ``proc.`` prefix):
        Several ids may be requested at once; ``lag`` / ``window`` are not supported.
        One id returns its value, several ids return a tuple of values.
    Trading-process data (ids starting with ``proc.``):
        Exactly one id per call. Either ``lag`` (an int number of steps, or a time spec
        such as '1d' / '8h') or ``window`` (a time spec) may be given, not both, to
        locate account / position / trade history by run step or time window.

    Raises
    ------
    ValueError
        If no id is given, static and process ids are mixed in one call, more than one
        process id is requested, lag/window is used with static data, or lag and window
        are both given.
    """
    if not dtype_id:
        raise ValueError('at least one data type id must be provided')
    names = list(dtype_id)
    proc_names = [n for n in names if isinstance(n, str) and n.startswith('proc.')]
    static_names = [n for n in names if n not in proc_names]

    # mixing static and process data in one call is rejected; the user must split the call
    if proc_names and static_names:
        raise ValueError(
            'get_data() cannot mix static data and process data in one call. '
            'Please call get_data() separately for static sources and proc.* sources.'
        )
    # static data only: lag/window are not supported
    if static_names:
        if lag is not None or window is not None:
            raise ValueError('lag/window parameters are only supported for proc.* (process) data.')
        return self._get_pars_or_data(*static_names)

    # only process data remains here (names is non-empty and none of them is static).
    # One proc.* field per call keeps the return structure and lag/window semantics unambiguous.
    if len(proc_names) != 1:
        raise ValueError(
            'get_data() only supports one proc.* (process data) field per call. '
            'Please call get_data() separately for each proc.* source.'
        )
    if lag is not None and window is not None:
        raise ValueError('lag and window cannot be used at the same time for proc.* data.')
    return self._get_process_data_single(proc_names[0], lag=lag, window=window)
def _get_process_data_single(self,
                             name: str,
                             *,
                             lag: Union[int, str, None],
                             window: Union[str, None]) -> np.ndarray:
    """Return one ``proc.*`` trading-process data field, optionally located by lag or window.

    Data is read from the Operator of the owning group (``self._group._operator``): its
    ``_process_data_sources`` dict, ``_current_signal_index`` and ``_process_time_index``.
    Only the history up to the current signal row is visible.

    Parameters
    ----------
    name: str
        One of 'proc.own_cash', 'proc.available_cash', 'proc.own_amounts',
        'proc.available_amounts', 'proc.trade_records', 'proc.trade_cost',
        'proc.trade_price', 'proc.position_value', 'proc.total_value'.
    lag: int, str or None
        An int (numpy integers accepted) selects the single row ``lag`` steps before the
        latest visible row, clamped to the first row. A string such as '1d' / '8h'
        selects the last row whose time stamp is at least that long before the latest
        visible time stamp, falling back to the first row.
    window: str or None
        A string such as '1d' / '8h' selects all rows whose time stamp lies within that
        span of the latest visible time stamp (possibly none).

    Returns
    -------
    np.ndarray
        The selected rows, taken from a copy of the source data (1D or 2D like the source).

    Raises
    ------
    RuntimeError
        If the strategy is not managed by an Operator, or required sources are missing.
    KeyError
        If ``name`` is not a known process data field.
    ValueError
        If the lag/window specification is invalid.
    """
    # 1. locate the Operator and its process data sources
    group = getattr(self, '_group', None)
    op = getattr(group, '_operator', None) if group is not None else None
    if op is None:
        raise RuntimeError('Process data is only available when strategy is managed by an Operator.')
    sources = getattr(op, '_process_data_sources', None)
    if not sources:
        raise RuntimeError('Process data sources are not initialized; proc.* data is not available.')
    current_idx = getattr(op, '_current_signal_index', None)
    if current_idx is None:
        raise RuntimeError('Current signal index is not available for process data access.')
    time_index = getattr(op, '_process_time_index', None)

    # 2. build the history visible "up to now" (excluding the result of the signal in progress)
    def _slice_until_now(arr: np.ndarray, offset: int = 0) -> np.ndarray:
        """Copy of the leading rows of ``arr`` that are visible at the current signal row.

        ``current_idx`` is the row of the signal being generated, i.e. the number of
        completed signals. ``offset=1`` is used for arrays with an extra leading row for
        the initial state (own_cashes: shape (n_signals + 1,), row 0 = initial state).
        """
        stop = min(current_idx + offset, arr.shape[0])
        return arr[:stop].copy()

    # proc field -> (source key, row offset)
    direct_fields = {
        'proc.own_cash': ('own_cashes', 1),
        'proc.available_cash': ('available_cashes', 1),
        'proc.own_amounts': ('own_amounts', 1),
        'proc.available_amounts': ('available_amounts', 1),
        'proc.trade_records': ('trade_records', 0),
        'proc.trade_cost': ('trade_costs', 0),
        'proc.trade_price': ('trade_prices', 0),
    }
    if name in direct_fields:
        source_key, offset = direct_fields[name]
        base = _slice_until_now(sources[source_key], offset=offset)
    elif name in ('proc.position_value', 'proc.total_value'):
        # value the holdings with price_data
        own_amounts = _slice_until_now(sources['own_amounts'], offset=1)
        cashes = _slice_until_now(sources['own_cashes'], offset=1)
        price_data = sources.get('price_data', None)
        if price_data is None:
            raise RuntimeError('price_data is not available for computing position_value/total_value.')
        # price_data is roughly (n_signals, share_count), own_amounts (<=n_signals, share_count)
        # NOTE(review): own_amounts carries an initial-state row (offset=1); confirm the row
        # alignment with price_data is intended.
        steps = own_amounts.shape[0]
        prices = price_data[:steps, :]
        position_values = (prices * own_amounts).sum(axis=1)
        # total_value = position_value + cash
        base = position_values if name == 'proc.position_value' else position_values + cashes
    else:
        raise KeyError(f'Unknown process data field "{name}".')

    # 3. optional second-stage selection by lag / window
    def _apply_int_lag(arr: np.ndarray, k) -> np.ndarray:
        """Single row ``k`` steps before the last row; the first row if history is too short."""
        if k < 0:
            raise ValueError('lag must be non-negative when using integer lag.')
        if arr.shape[0] == 0:
            return arr
        idx = max(arr.shape[0] - k - 1, 0)
        return arr[idx:idx + 1]

    def _parse_time_delta(spec: str) -> np.timedelta64:
        """Parse a spec such as '1d' or '8h' into a numpy timedelta64."""
        invalid_msg = f'invalid time lag/window spec "{spec}", expected like "1d" or "8h".'
        if not isinstance(spec, str) or len(spec) < 2:
            raise ValueError(invalid_msg)
        unit = spec[-1].lower()
        try:
            val = int(spec[:-1])
        except ValueError:
            raise ValueError(invalid_msg) from None
        np_unit = {'d': 'D', 'h': 'h'}.get(unit)
        if np_unit is None:
            raise ValueError(f'unsupported time unit in lag/window spec "{spec}", only "d"/"h" are supported.')
        return np.timedelta64(val, np_unit)

    def _apply_time_lag(arr: np.ndarray, spec: str) -> np.ndarray:
        """Last row whose time stamp is at least ``spec`` before the latest one."""
        if time_index is None:
            raise RuntimeError('time index is not available for time-based lag/window on process data.')
        if arr.shape[0] == 0:
            return arr
        delta = _parse_time_delta(spec)
        # the rows of arr correspond to the first arr.shape[0] entries of the time index
        ts = np.asarray(time_index[:arr.shape[0]])
        hits = np.nonzero(ts <= ts[-1] - delta)[0]
        # not enough history: fall back to the first row
        idx = hits[-1] if hits.size else 0
        return arr[idx:idx + 1]

    def _apply_time_window(arr: np.ndarray, spec: str) -> np.ndarray:
        """All rows whose time stamp is within ``spec`` of the latest one (may be empty)."""
        if time_index is None:
            raise RuntimeError('time index is not available for time-based lag/window on process data.')
        if arr.shape[0] == 0:
            return arr
        delta = _parse_time_delta(spec)
        ts = np.asarray(time_index[:arr.shape[0]])
        hits = np.nonzero(ts > ts[-1] - delta)[0]
        if hits.size == 0:
            # no row inside the window: empty result with a compatible shape
            return arr[:0]
        return arr[hits[0]:]

    if lag is None and window is None:
        return base
    if isinstance(lag, (int, np.integer)):
        return _apply_int_lag(base, lag)
    if isinstance(lag, str):
        return _apply_time_lag(base, lag)
    if isinstance(window, str):
        return _apply_time_window(base, window)
    raise ValueError('invalid lag/window specification for proc.* data.')
def update_shares(self,
                  share_count: int = None,
                  share_names: Union[str, list[str]] = None):
    """Update the share names of the strategy, or generate placeholder names from a count.

    If ``share_names`` is given, ``share_count`` is ignored. Otherwise ``share_count`` is
    required and the names 'Share_1' ... 'Share_<share_count>' are generated.

    Parameters
    ----------
    share_count: int, optional
        Number of shares; only used when ``share_names`` is None. Must be a positive int.
    share_names: str, list of str or tuple of str, optional
        Share names. A string is converted with ``str_to_list``; a tuple is converted to
        a list.

    Returns
    -------
    None

    Raises
    ------
    TypeError
        If ``share_names`` has an unsupported type, or ``share_count`` is not an int.
    ValueError
        If neither argument is given, or ``share_count`` is not positive.
    """
    if share_names is not None:
        if isinstance(share_names, str):
            share_names = str_to_list(share_names)
        elif isinstance(share_names, tuple):
            # accept tuples as well, stored as a list like the other inputs
            share_names = list(share_names)
        if not isinstance(share_names, (str, list)):
            raise TypeError(f'share_names should be a string or list of str, got {type(share_names)}')
        self._share_names = share_names
        return

    if share_count is None:
        raise ValueError('Either share_names or share_count should be given')
    if not isinstance(share_count, int):
        raise TypeError(f'share count should be a integer, got {type(share_count)} instead')
    if share_count <= 0:
        raise ValueError(f'share count should be given and be larger than 0, got {share_count}')
    self._share_names = [f'Share_{i + 1}' for i in range(share_count)]
def update_data_types(self,
                      dtype_id=None,
                      *,
                      use_latest_data_cycle: Union[bool, tuple[bool], list[bool], dict[str, bool]] = None,
                      window_length: Union[int, tuple[int], list[int], dict[str, int]] = None,
                      freq: Union[str, tuple[str], list[str], dict[str, str]] = None,
                      asset_type: Union[str, tuple[str], list[str], dict[str, str]] = None) -> None:
    """Update the data settings of one data type, or of all data types.

    With ``dtype_id`` only that data type is updated, and every given value must be a
    scalar. Without it, scalars apply to all data types, lists/tuples are matched to the
    data types in order, and dicts update the dtype_ids they contain.

    Supported settings are window_length, use_latest_data_cycle, freq and asset_type.
    Changing freq or asset_type replaces the DataType with a new instance, because the
    dtype_id is derived from name/freq/asset_type; the order of ``data_ids`` is kept and
    the ULC flag and window length move to the new dtype_id. If the new dtype_id already
    exists, the old entry is dropped and the existing entry receives the old ULC/window.

    Parameters
    ----------
    dtype_id: str, optional
        dtype_id of the single data type to update; None updates all data types.
    use_latest_data_cycle: bool, list/tuple of bool, dict {str: bool}, optional
    window_length: int, list/tuple of int, dict {str: int}, optional
        Must be positive.
    freq: str, list/tuple of str, dict {str: str}, optional
    asset_type: str, list/tuple of str, dict {str: str}, optional

    Returns
    -------
    None

    Raises
    ------
    KeyError
        If ``dtype_id`` is not a data type of the strategy.
    AssertionError
        If a value has the wrong type, or a window length is not positive. Raised
        explicitly (not with ``assert``) so the checks also run under ``python -O``.
    TypeError
        If a batch argument has an unsupported type.
    ValueError
        If a list/tuple argument has the wrong length, or a new DataType cannot be created.
    """
    # 1) window_length and use_latest_data_cycle, applied under the current dtype_ids
    if dtype_id is not None:
        if dtype_id not in self.data_types:
            raise KeyError(f'data type {dtype_id} is not defined in the strategy')
        if use_latest_data_cycle is not None:
            if not isinstance(use_latest_data_cycle, bool):
                raise AssertionError(f'use_latest_data_cycle should be a boolean, '
                                     f'got {type(use_latest_data_cycle)} instead')
            self._data_ULC[dtype_id] = use_latest_data_cycle
        if window_length is not None:
            if not (isinstance(window_length, int) and window_length > 0):
                raise AssertionError(f'window_length should be a positive integer, got {window_length} instead')
            self._data_WL[dtype_id] = window_length
    else:  # no dtype_id: update all data types, or the ones listed in a dict
        if use_latest_data_cycle is not None:
            if isinstance(use_latest_data_cycle, bool):
                self._data_ULC = {d_name: use_latest_data_cycle for d_name in self.data_types}
            elif isinstance(use_latest_data_cycle, (list, tuple)):
                if len(use_latest_data_cycle) != len(self.data_types):
                    raise ValueError(f'Length of use_latest_data_cycle should be {len(self.data_types)}, '
                                     f'got {len(use_latest_data_cycle)} instead')
                ulcs = input_to_list(use_latest_data_cycle, len(self.data_types), False)
                self._data_ULC = {self._data_ids[i]: ulcs[i] for i in range(len(ulcs))}
            elif isinstance(use_latest_data_cycle, dict):
                if not all(isinstance(v, bool) for v in use_latest_data_cycle.values()):
                    raise AssertionError(f'All use_latest_data_cycle should be boolean, '
                                         f'got {use_latest_data_cycle} instead')
                self._data_ULC.update(use_latest_data_cycle)
            else:
                raise TypeError(f'parameter "use_latest_data_cycle" is invalid ({use_latest_data_cycle}), '
                                f'please check your input')
        if window_length is not None:
            if isinstance(window_length, (int, float)):
                if window_length <= 0:
                    raise ValueError(f'window_length should be a positive integer, got {window_length} instead')
                window_length = int(window_length)
                self._data_WL = {d_name: window_length for d_name in self.data_types}
            elif isinstance(window_length, (list, tuple)):
                if len(window_length) != len(self.data_types):
                    raise ValueError(f'Length of window_length should be {len(self.data_types)}, '
                                     f'got {len(window_length)} instead')
                wls = input_to_list(window_length, len(self.data_types), 20)
                self._data_WL = {self._data_ids[i]: wls[i] for i in range(len(wls))}
            elif isinstance(window_length, dict):
                # positivity is checked too, as the message states
                if not all(isinstance(v, int) and v > 0 for v in window_length.values()):
                    raise AssertionError(f'All window lengths should be positive integers, '
                                         f'got {window_length} instead')
                self._data_WL.update(window_length)
            else:
                raise TypeError(f'parameter "window_length" is invalid ({window_length}), please check your input')

    # 2) freq / asset_type: replace the old DataType with a new one, keeping _data_ids order
    if freq is None and asset_type is None:
        return

    def _pick(spec, did, current, label):
        """Resolve the new value of one field (freq or asset_type) for data type ``did``."""
        if spec is None:
            return current
        if dtype_id is not None:
            # single update: the value must be a scalar str
            if not isinstance(spec, str):
                raise AssertionError(f'{label} should be str when dtype_id is given, got {type(spec)}')
            return spec
        # batch update: scalar, list/tuple (by position) or dict (by dtype_id)
        if isinstance(spec, str):
            return spec
        if isinstance(spec, (list, tuple)):
            idx = self._data_ids.index(did)
            return spec[idx] if idx < len(spec) else current
        if isinstance(spec, dict):
            return spec.get(did, current)
        raise TypeError(f'{label} should be str, list, tuple or dict, got {type(spec)}')

    targets = [dtype_id] if dtype_id is not None else list(self._data_ids)
    replacements = []  # (old_id, new_dtype, new_id)
    for did in targets:
        dt = self._data_types[did]
        new_f = _pick(freq, did, dt.freq, 'freq')
        new_a = _pick(asset_type, did, dt.asset_type, 'asset_type')
        if new_f is None:
            continue
        if (new_f, new_a) == (dt.freq, dt.asset_type):
            continue
        try:
            new_dtype = DataType(dt.name, freq=new_f, asset_type=new_a)
        except Exception as e:
            raise ValueError(f'Failed to create DataType({dt.name!r}, freq={new_f!r}, '
                             f'asset_type={new_a!r}): {e}') from e
        replacements.append((did, new_dtype, new_dtype.dtype_id))

    for old_id, new_dtype, new_id in replacements:
        ulc_val = self._data_ULC[old_id]
        wl_val = self._data_WL[old_id]
        del self._data_types[old_id]
        del self._data_ULC[old_id]
        del self._data_WL[old_id]
        if new_id not in self._data_types:
            self._data_types[new_id] = new_dtype
            self._data_ULC[new_id] = ulc_val
            self._data_WL[new_id] = wl_val
            self._data_ids = [new_id if x == old_id else x for x in self._data_ids]
            setattr(self, new_id, None)
        else:
            # new_id already exists: drop old_id from the list, update the existing ULC/WL
            self._data_ids = [x for x in self._data_ids if x != old_id]
            self._data_ULC[new_id] = ulc_val
            self._data_WL[new_id] = wl_val
def update_par_values(self, *par_values: Any, **kwargs: Any) -> None:
    """Quickly update the values of the strategy parameters.

    Parameters
    ----------
    *par_values: Any
        Values in parameter order; may cover only the first few parameters.
    **kwargs: Any
        Values by parameter name. Applied after the positional values, so both forms can
        be combined (keyword values win for the same parameter).

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If more positional values than parameters are given, or the strategy has
        parameters but neither positional nor keyword values are given.
    KeyError
        If a keyword does not name a parameter of the strategy. All names are checked
        before any value is changed.
    """
    par_names = self.par_names
    # partial updates are allowed, only an excess of positional values is an error
    if len(par_values) > len(par_names):
        raise ValueError(f'Number of par_values should not exceed {self.par_count}, '
                         f'got {len(par_values)} instead')
    for par_name in kwargs:
        if par_name not in par_names:
            raise KeyError(f'parameter {par_name} is not defined in the strategy')
    if not par_values and not kwargs and par_names:
        raise ValueError('par_values is None, please provide par_values or kwargs to update parameters')
    for par_name, par_value in list(zip(par_names, par_values)) + list(kwargs.items()):
        par = self._pars[par_name]
        par.value = par_value
        # mirror the value stored in the Parameter, as set_pars() does
        setattr(self, par_name, par.value)
def update_par_ranges(self, *par_ranges: Any, **kwargs) -> None:
    """Quickly update the value ranges of the strategy parameters.

    Positional ranges are applied to ``par_names`` in order (fewer than ``par_count``
    ranges update only the leading parameters); keyword ranges are then applied by
    parameter name, so both forms may be combined in one call.

    Parameters
    ----------
    par_ranges: tuple, optional
        New ranges in the order of ``par_names``; at most ``par_count`` of them. Each is
        passed unchanged to the parameter's ``update_par_range(new_range=...)``.
    kwargs: dict
        New ranges keyed by parameter name.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If more than ``par_count`` positional ranges are given, or if no range is given.
    KeyError
        If a keyword does not name a parameter of the strategy.
    """
    if not par_ranges and not kwargs:
        raise ValueError('par_ranges is None, please provide par_ranges or kwargs to update parameter ranges')
    # partial updates are allowed, so only an upper bound on the count is enforced
    if len(par_ranges) > self.par_count:
        raise ValueError(f'Number of par_ranges should not exceed {self.par_count}, '
                         f'got {len(par_ranges)} instead')
    for par_name, par_range in zip(self.par_names, par_ranges):
        self._pars[par_name].update_par_range(new_range=par_range)
    # keyword ranges are applied after the positional ones; previously they were
    # silently dropped whenever positional ranges were also passed
    for par_name, par_range in kwargs.items():
        if par_name not in self.par_names:
            raise KeyError(f'parameter {par_name} is not defined in the strategy')
        self._pars[par_name].update_par_range(new_range=par_range)
def set_opt_tag(self, opt_tag: int) -> int:
    """Store the optimization tag of the strategy.

    Parameters
    ----------
    opt_tag: int
        Optimization tag; accepted values are 0, 1 and 2.

    Returns
    -------
    int
        The tag that was stored.

    Raises
    ------
    AssertionError
        If ``opt_tag`` is not an int or lies outside [0, 2]. The checks are ``assert``
        statements and are skipped when Python runs with ``-O``.
    """
    assert isinstance(opt_tag, int), f'optimization tag should be an integer, got {type(opt_tag)} instead'
    assert 0 <= opt_tag <= 2, f'ValueError, optimization tag should be between 0 and 2, got {opt_tag} instead'
    self._opt_tag = opt_tag
    return self._opt_tag
def update_running_data_window(self, data_windows:dict, window_indices:dict, window_index:int):
    """Point every data-type attribute at the data window selected by ``window_index``.

    For each name in ``self.data_types`` the window stack ``data_windows[name]`` is
    indexed with ``window_indices[name][window_index]`` and the result is stored as the
    attribute ``name`` of the strategy.
    """
    for dtype_name in self.data_types:
        window_stack = data_windows[dtype_name]
        position = window_indices[dtype_name][window_index]
        setattr(self, dtype_name, window_stack[position])
def set_custom_pars(self, **kwargs):
    """Assign extra (user-defined) values to attributes that already exist on the instance.

    Every name is validated before anything is assigned, so an unknown name leaves the
    strategy unchanged instead of half-updated.

    Parameters
    ----------
    **kwargs:
        Attribute names and their new values; each name must already be present in the
        instance ``__dict__``.

    Raises
    ------
    KeyError
        If a name is not an existing instance attribute.
    """
    for k in kwargs:
        if k not in self.__dict__:
            raise KeyError(f'The strategy does not have property \'{k}\'')
    for k, v in kwargs.items():
        setattr(self, k, v)
def disable_tracing(self):
    """Switch strategy tracing off and discard any recorded trace data."""
    self._trace_enabled, self._trace_data = False, {}
    self._trace_max_steps, self._trace_step = 0, 0
def enable_tracing(self, max_steps: int):
    """Switch strategy tracing on and reset the trace buffer.

    Parameters
    ----------
    max_steps: int
        Number of trace steps to record. Floats are accepted and truncated to int as
        before, but the truncated value must still be at least 1.

    Raises
    ------
    ValueError
        If ``max_steps`` is not a finite number whose integer part is at least 1. This
        includes values such as 0.5, which used to slip through as 0 trace steps.
    """
    steps = 0
    if isinstance(max_steps, (int, np.integer)):
        steps = int(max_steps)
    elif isinstance(max_steps, (float, np.floating)) and np.isfinite(max_steps):
        # truncate like before, but validate the truncated value below
        steps = int(max_steps)
    if steps <= 0:
        raise ValueError(f'max_steps should be a positive integer, got {max_steps}({type(max_steps)}) instead')
    self._trace_enabled = True
    self._trace_data = {}
    self._trace_max_steps = steps
    self._trace_step = 0
def update_trace_step(self, step: int):
    """Move the trace cursor to ``step``.

    Steps outside the closed range [0, _trace_max_steps] are ignored and leave the
    cursor where it is.
    """
    if not (0 <= step <= self._trace_max_steps):
        return
    self._trace_step = step
def trace(self, name: str, var: Union[int, float, bool, str]) -> None:
    """Record an intermediate variable (or a text note) at the current trace step.

    Call this from ``realize()``. Nothing is recorded unless tracing is enabled. The first
    value recorded under a name fixes the storage dtype of that name. Steps that are never
    traced read as NaN (float), 0 (int), False (bool) or None (str). Values at or beyond
    ``_trace_max_steps`` are silently dropped.

    Parameters
    ----------
    name: str
        Name of the trace record, used to tell different traced variables apart.
    var: int, float, bool or str
        The value to record, or a text note.

    Raises
    ------
    RuntimeError
        If ``name`` is None while tracing is enabled.
    TypeError
        If the first value recorded under ``name`` is not an int/float/bool/str.
    """
    # TODO: FactorSorter and GeneralStg realize() work on ndarray vectors, which cannot be
    #  recorded here, so those strategy types still need special handling.
    if not self._trace_enabled:
        return
    if name is None:
        raise RuntimeError('When trace variables, name must be given and can not be None!')
    if name not in self._trace_data:
        # bool must be tested before int: bool is a subclass of int, so testing int first
        # stored Python bools as int64 and made the bool branch unreachable
        if isinstance(var, (bool, np.bool_)):
            dtype, fill = np.bool_, False
        elif isinstance(var, (int, np.integer)):
            dtype, fill = np.int64, 0
        elif isinstance(var, (float, np.floating)):
            dtype, fill = np.float64, np.nan
        elif isinstance(var, (str, np.str_)):
            dtype, fill = object, None
        else:
            raise TypeError(f'When trace variables, only int, float, bool, str types are supported, '
                            f'got {type(var)} instead!')
        # pre-fill rather than np.empty, so untraced steps hold a defined value
        # instead of leftover memory
        self._trace_data[name] = {
            'values': np.full(self._trace_max_steps, fill, dtype=dtype),
        }
    idx = self._trace_step
    if idx < self._trace_max_steps:
        self._trace_data[name]['values'][idx] = var
def get_trace_data(self):
    """Return the recorded trace values as a DataFrame.

    The frame has one row per trace step (index ``range(_trace_max_steps)``) and one
    column per trace name, labelled ``<strategy_id>_<name>``.
    """
    trace_frame = pd.DataFrame(index=range(self._trace_max_steps))
    for trace_name, trace_info in self._trace_data.items():
        column_label = self.strategy_id + '_' + trace_name
        trace_frame[column_label] = trace_info['values']
    return trace_frame
@abstractmethod
def generate(self):
    """Abstract hook: produce the strategy output from the current data and parameters.

    Derived strategy classes implement this; the base version does nothing and
    returns None.

    Returns
    -------
    stg_signal: np.ndarray
        The strategy output, e.g. trading signals or trading instructions.
    """
    ...
class GeneralStg(BaseStrategy):
    """General-purpose trading strategy.

    The user defines the complete trading logic in ``realize()``; ``generate()`` returns
    its output unchanged. See the qteasy documentation for more details on GeneralStg.
    """
    # NOTE(review): ``__metaclass__`` is the Python 2 way of setting a metaclass; Python 3
    # treats it as an ordinary class attribute, so it does not enforce @abstractmethod.
    __metaclass__ = ABCMeta

    def __init__(self,
                 name: str = 'General',
                 description: str = 'description of General strategy',
                 **kwargs):
        """Create a general strategy; extra keyword arguments go to BaseStrategy.__init__."""
        base_options = {'stg_type': 'GENERAL', 'name': name, 'description': description}
        super().__init__(**base_options, **kwargs)

    def generate(self):
        """Return the output of realize(), which holds all of the strategy logic."""
        output = self.realize()
        return output

    @abstractmethod
    def realize(self):
        """Compute the strategy output for the current window of data.

        Implemented by subclasses. The returned trading signal must be a 1-D numpy
        array of floats.
        """
        pass
class FactorSorter(BaseStrategy):
    """Factor-sorting stock-selection strategy.

    The user computes one selection factor per share in ``realize()``. ``generate()`` then
    filters the factors with ``condition``/``lbound``/``ubound``, sorts them, keeps a
    limited number of shares (``max_sel_count``) and weights the kept shares according to
    ``weighting``. Signal values lie between 0 and 1, so signal_type "PT" is recommended.
    See the qteasy documentation for details on strategy classes.
    """
    # NOTE(review): ``__metaclass__`` is the Python 2 way of setting a metaclass; Python 3
    # treats it as an ordinary class attribute, so it does not enforce @abstractmethod.
    __metaclass__ = ABCMeta

    def __init__(self,
                 name: str = 'Factor',
                 description: str = 'description of factor sorter strategy',
                 max_sel_count: float = 0.5,
                 condition: str = 'any',
                 lbound: float = -np.inf,
                 ubound: float = np.inf,
                 sort_ascending: bool = False,
                 weighting: str = 'even',
                 **kwargs):
        """Create a factor-sorter strategy.

        Parameters
        ----------
        name: str
            Strategy name.
        description: str
            Strategy description.
        max_sel_count: float
            Values below 1 are the fraction of all shares to select (0.5 = half of them);
            values of 1 or more are truncated to int and used as the number of shares.
            At least one share is always targeted.
        condition: str
            Filter applied before sorting: "any" (no filter), "greater" (keep
            factor >= ubound), "less" (keep factor <= lbound), "between" (keep
            lbound <= factor <= ubound) or "not_between" (keep factor <= lbound or
            factor >= ubound).
        lbound: float
            Lower bound used by ``condition``.
        ubound: float
            Upper bound used by ``condition``.
        sort_ascending: bool
            False selects the shares with the largest factors, True the smallest.
        weighting: str
            Weighting of the selected shares: "even", "linear", "distance",
            "proportion" or "ones".
        **kwargs:
            Passed on to BaseStrategy.__init__.
        """
        super().__init__(stg_type='FACTOR',
                         name=name,
                         description=description,
                         **kwargs)
        self.max_sel_count = max_sel_count
        self.condition = condition
        self.lbound = lbound
        self.ubound = ubound
        self.sort_ascending = sort_ascending
        self.weighting = weighting

    def info(self, verbose: bool = False, stg_id=None, **kwargs):
        """Display the strategy information plus the FactorSorter selection properties.

        Parameters
        ----------
        verbose: bool
            If True, display more properties.
        stg_id: str
            Strategy id; if None, the strategy name is used.
        **kwargs:
            Other parameters (currently unused).
        """
        from .utilfuncs import adjust_string_length
        from shutil import get_terminal_size
        term_width = get_terminal_size().columns
        info_width = int(term_width * 0.75) if term_width > 120 else term_width
        key_width = max(24, int(info_width * 0.3))
        value_width = max(7, info_width - key_width)
        extra_info = f'{" Selection Properties ":-^{info_width}}\n'
        # must agree with generate(): values >= 1 are a share count, values < 1 a fraction
        # (max_sel_count == 1 used to be shown as 100% although it selects a single share)
        if self.max_sel_count >= 1:
            extra_info += f'{"Max select count":<{key_width}}{int(self.max_sel_count)}\n'
        else:
            extra_info += f'{"Max select count":<{key_width}}{self.max_sel_count:.1%}\n'
        extra_info += f'{"Sort Ascending":<{key_width}}{self.sort_ascending}\n' \
                      f'{"Weighting":<{key_width}}{adjust_string_length(self.weighting, value_width)}\n' \
                      f'{"Filter Condition":<{key_width}}{adjust_string_length(self.condition, value_width)}\n' \
                      f'{"Filter ubound":<{key_width}}{self.ubound}\n' \
                      f'{"Filter lbound":<{key_width}}{self.lbound}'
        super().info(verbose=verbose, stg_id=stg_id, extra_info=extra_info)

    def generate(self):
        """Turn the factors returned by realize() into a selection signal.

        Factors rejected by ``condition`` are treated as NaN and never selected. The
        remaining factors are sorted, up to ``max_sel_count`` shares are kept, and the kept
        shares are weighted according to ``weighting``.

        Returns
        -------
        chosen: numpy.ndarray
            1-D float array with one value per share; unselected shares get 0. With the
            "even", "linear", "distance" and "proportion" weightings the selected weights
            sum to 1 ("proportion" yields all zeros when no selected factor is positive).
            With "ones" every selected share gets 1.

        Raises
        ------
        ValueError
            If ``condition`` is not a supported name.
        KeyError
            If ``weighting`` is not a supported name.
        """
        pct = self.max_sel_count
        condition = self.condition
        lbound = self.lbound
        ubound = self.ubound
        sort_ascending = self.sort_ascending  # True: select the smallest, False: the largest
        weighting = self.weighting
        # validate up front, so a bad setting is reported even on steps where nothing is selected
        if weighting not in ('ones', 'linear', 'distance', 'proportion', 'even'):
            raise KeyError(f'invalid weighting type: "{weighting}". '
                           f'should be one of ["ones", "linear", "distance", "proportion", "even"]')

        # Work on a float copy. The filters below write NaN in place, which must not leak
        # into the array returned by realize() (it may be a view of the strategy's data).
        # NaN and fractional weights also need a float dtype when the factors are ints.
        factors = np.array(self.realize(), dtype=float)
        if factors.ndim != 1:
            # e.g. an (N, 1) factor column is flattened into N factors
            factors = factors.ravel()
        share_count = factors.shape[0]

        if pct < 1:
            # a fraction of all shares, e.g. 0.5 selects half of them
            pct = int(share_count * pct)
        else:
            # an absolute number of shares, e.g. 5 selects five of them
            pct = int(pct)
        if pct < 1:
            pct = 1

        chosen = np.zeros_like(factors)

        # mark factors rejected by the filter condition as NaN
        if condition == 'any':
            pass
        elif condition == 'greater':
            factors[factors < ubound] = np.nan
        elif condition == 'less':
            factors[factors > lbound] = np.nan
        elif condition == 'between':
            factors[(factors < lbound) | (factors > ubound)] = np.nan
        elif condition == 'not_between':
            factors[(factors > lbound) & (factors < ubound)] = np.nan
        else:
            raise ValueError(f'invalid selection condition \'{condition}\', '
                             f'should be one of ["any", "greater", "less", "between", "not_between"]')

        nan_count = int(np.isnan(factors).sum())
        if nan_count == share_count:
            # nothing left to select (also covers an empty factor vector)
            return chosen

        # numpy sorts NaN after every number, so for a descending selection the cut
        # position has to leave room for the NaNs at the end
        if sort_ascending:
            pos = pct
        else:
            pos = max(share_count - pct - nan_count, 0)

        if weighting == 'even' and 0 < pos < share_count:
            # order within the selection is irrelevant for even weights, so the cheaper
            # partial sort suffices (argpartition needs 0 <= kth < share_count)
            ordered = factors.argpartition(pos)
        else:
            # full sort: the other weightings rely on ascending factor order. It also
            # handles cut positions argpartition rejects, e.g. pos >= share_count when more
            # shares are requested than exist.
            ordered = factors.argsort()
        share_found = ordered[:pos] if sort_ascending else ordered[pos:]

        # drop NaN entries; with assume_unique=True setdiff1d keeps the order of
        # share_found, which the "linear" and "distance" weightings depend on
        share_nan = np.where(np.isnan(factors))[0]
        args = np.setdiff1d(share_found, share_nan, assume_unique=True)
        arg_count = len(args)
        if arg_count == 0:
            return chosen

        if weighting == 'ones':
            # every selected share gets 1
            chosen[args] = 1.
        elif weighting == 'linear':
            # weights 1, 1 + 2/n, ... (just below 3) in ascending factor order. linspace
            # always yields exactly arg_count values; arange with a float step may yield one
            # more through rounding and break the assignment.
            dist = np.linspace(1., 3., num=arg_count, endpoint=False)
            chosen[args] = dist / dist.sum()
        elif weighting == 'distance':
            # weights proportional to the distance from the worst selected factor, plus a
            # base of a tenth of the factor range so the worst share still gets a weight
            dist = factors[args]
            d_max = dist[-1]
            d_min = dist[0]
            d = d_max - d_min
            if not sort_ascending:
                dist = dist - d_min + d / 10.
            else:
                dist = d_max - dist + d / 10.
            d_sum = dist.sum()
            if not np.any(dist):  # all selected factors are equal: even weights
                chosen[args] = 1 / len(dist)
            elif d_sum == 0:
                chosen[args] = dist / len(dist)
            else:
                chosen[args] = dist / d_sum
        elif weighting == 'proportion':
            # weights proportional to the factor value; non-positive factors get 0
            f = factors[args]
            f = np.where(f < 0, 0, f)  # np.where is much faster than np.clip here
            f_sum = f.sum()
            if f_sum > 0:
                chosen[args] = f / f_sum
            # otherwise no selected factor is positive and every weight stays 0
            # (the division used to produce NaN weights here)
        else:  # 'even': every selected share gets the same weight
            chosen[args] = 1. / arg_count
        return chosen

    @abstractmethod
    def realize(self):
        """Compute and return the selection factor of every share.

        Implemented by subclasses. Return one factor per share (an (N, 1) array is
        flattened); generate() does the filtering, sorting and weighting.
        """
        pass
[文档]class RuleIterator(BaseStrategy):
""" 规则迭代策略类。这一类策略不考虑每一只股票的区别,将同一套规则同时迭代应用到所有的股票上。
RuleIterator策略类的特殊功能是可以对同一套交易规则,将不同的参数应用到投资组合中的不同股票上。
例如,用户可以设计一个均线交叉策略,并将其应用到投资组合中的所有股票上,同时可以为每只股票
设定不同的均线周期参数。关于Strategy类的更详细说明,请参见qteasy的文档。
**多标的且每股参数不同**:须通过 ``update_par_values({股票代码: (p1, p2, ...), ...})`` 传入与
``share_names`` 一致的股票代码键;若仅用位置元组初始化,多标的时全体共享一套运行时参数(见
用户文档《三种策略基类》中 RuleIterator 一节)。字典中可用键名 ``default`` 为未单独列出的标的
提供同一套默认初值;不支持 ``others`` 等其它保留键名。
"""
__metaclass__ = ABCMeta
def __init__(self,
             name: str = 'Rule-Iterator',
             description: str = 'description of rule iterator strategy',
             allow_multi_par: bool = True,
             **kwargs):
    """Create a rule-iterator strategy.

    Parameters
    ----------
    name: str
        Strategy name.
    description: str
        Strategy description.
    allow_multi_par: bool
        True allows every share to use its own parameter tuple (``multi_pars``).
    **kwargs:
        Passed on to BaseStrategy.__init__.
    """
    super().__init__(stg_type='RULE-ITER',
                     name=name,
                     description=description,
                     **kwargs)
    # data windows cached by update_running_data_window(); generate() slices one
    # share's column out of each of them
    self._data_windows = {}
    # True: the strategy may apply different parameters to different shares
    self.allow_multi_par = allow_multi_par
    # per-share parameter tuples, or None while multi-parameters are not set
    self.multi_pars = None
    # raised while _update_multi_pars() validates rows through update_par_values(), so
    # that those internal calls do not rewrite multi_pars
    self._updating_multi_pars_internal: bool = False
def info(self, verbose: bool = False, stg_id=None, **kwargs):
    """Display the strategy information plus the RuleIterator iteration properties.

    Parameters
    ----------
    verbose: bool
        If True, list the parameter tuple of every share; otherwise only the first one.
    stg_id: str
        Strategy id; if None, the strategy name is used.
    **kwargs:
        Other parameters (currently unused).
    """
    from shutil import get_terminal_size
    term_width = get_terminal_size().columns
    info_width = term_width if term_width <= 120 else int(term_width * 0.75)
    key_width = max(24, int(info_width * 0.3))
    sections = [
        f'{" Iteration Properties ":-^{info_width}}\n',
        f'{"Allow multi pars":<{key_width}}{self.allow_multi_par}',
    ]
    if self.allow_multi_par:
        if not self.multi_pars:
            sections.append(f'\n{"Multi-parameter not set":<{info_width}}')
        elif verbose:
            # one "<share>: <parameter tuple>" line per share
            multi_par_str = '\n'.join(
                f'{sn}: {mp}' for sn, mp in zip(self.share_names, self.multi_pars)
            )
            sections.append(f'\n{"Multi-parameter":<{info_width}}\n{multi_par_str}')
        else:
            sections.append(
                f'\n{"Multi-parameter (first share only; full list: strategies -d in Trader CLI)":<{info_width}}\n'
                f'{self.multi_pars[0]!r} ... ({len(self.multi_pars)} shares)\n'
            )
    super().info(verbose=verbose, stg_id=stg_id, extra_info=''.join(sections))
def get_multi_pars(self) -> Union[tuple, None]:
    """Return the per-share parameter tuples in share order, or None when not set.

    Returns
    -------
    tuple or None
        A tuple with one parameter tuple per share (length ``share_count``), or None
        when multi-parameters are not in use.
    """
    current = self.multi_pars
    return current
def get_pars_for_share(self, share: Union[str, int]) -> Tuple[Any, ...]:
    """Return the parameter tuple of one share from ``multi_pars``.

    Parameters
    ----------
    share : str or int
        A share code exactly as in ``share_names``, or the share's position in the pool.
        numpy integers (e.g. from ``np.where``/``np.arange``) are accepted as positions.

    Returns
    -------
    tuple
        The share's parameter tuple. Without ``multi_pars`` the current ``par_values`` are
        returned, since then every share uses the same parameters.

    Raises
    ------
    IndexError
        If an integer position is negative or beyond the number of shares.
    TypeError
        If ``share`` is neither a str nor an integer.
    KeyError
        If a share code is not in ``share_names``.
    """
    if self.multi_pars is None:
        return self.par_values
    if isinstance(share, (int, np.integer)):
        position = int(share)
        if position < 0 or position >= len(self.multi_pars):
            raise IndexError(f'share index {share} out of range for multi_pars length {len(self.multi_pars)}')
        return tuple(self.multi_pars[position])
    if not isinstance(share, str):
        raise TypeError(f'share should be str or int, got {type(share)}')
    try:
        position = self.share_names.index(share)
    except ValueError:
        raise KeyError(f'share {share!r} not in share_names {self.share_names}') from None
    return tuple(self.multi_pars[position])
@staticmethod
def _overlay_tuple_prefix_on_row(row: Tuple[Any, ...], prefix: Tuple[Any, ...]) -> Tuple[Any, ...]:
    """Return ``row`` with its leading slots replaced by the values of ``prefix``.

    Used for partial positional parameter updates. Prefix values beyond the length of
    ``row`` are dropped, so the result always has the length of ``row``.
    """
    slots = list(row)
    for slot, value in zip(range(len(slots)), prefix):
        slots[slot] = value
    return tuple(slots)
def _merge_row_with_kwargs(self, row: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Tuple[Any, ...]:
    """Return ``row`` with the slots named in ``kwargs`` replaced by their new values.

    Raises
    ------
    KeyError
        If a key is not one of ``par_names``.
    """
    merged = list(row)
    for key, val in kwargs.items():
        if key not in self.par_names:
            raise KeyError(f'parameter {key} is not defined in the strategy')
        merged[self.par_names.index(key)] = val
    return tuple(merged)
def commit_share_par_values(self, *values: Any) -> None:
    """Write a complete parameter tuple for the share being processed by generate().

    The values go to the current share's row of ``multi_pars`` and to ``_pars``. Calling
    ``update_par_values(*values)`` from ``realize()`` has the same effect; this method is
    kept for compatibility.

    Parameters
    ----------
    *values : Any
        Exactly ``par_count`` values, in the order of ``par_names``.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If the number of values differs from ``par_count``.
    RuntimeError
        If not called while generate() is running, or if multi-parameters are not active.
    """
    expected = self.par_count
    if len(values) != expected:
        raise ValueError(
            f'commit_share_par_values expected {expected} values, got {len(values)}'
        )
    if getattr(self, '_generate_share_index', None) is None:
        raise RuntimeError(
            'commit_share_par_values() must be called from realize() while '
            'RuleIterator.generate() is running (missing _generate_share_index).'
        )
    multi_active = self.allow_multi_par and self.multi_pars is not None
    if not multi_active:
        raise RuntimeError(
            'commit_share_par_values() requires allow_multi_par=True and non-None multi_pars.'
        )
    self.update_par_values(*values)
def update_par_values(self, *par_values: Any, **kwargs: Any) -> None:
    """Quickly update strategy parameter values, keeping ``multi_pars`` and ``_pars`` in sync.

    Synchronization rules:

    - A ``dict`` as the first positional argument (share code -> parameter tuple)
      triggers a full rebuild through ``_update_multi_pars``; ``_pars`` is then set to
      ``multi_pars[0]``.
    - When ``allow_multi_par`` is True, ``multi_pars`` is set, and the call does not come
      from the internal loop of ``_update_multi_pars``:
        - ``_generate_share_index is not None`` (set by ``generate()`` before it calls
          ``realize()``): positional values and kwargs affect only the row of the
          **current share**. Positional values of length ``par_count`` replace the whole
          row; shorter ones overwrite the leading slots of that row. kwargs are merged
          into the row by parameter name.
        - ``_generate_share_index is None``: there is no "current share". Positional
          values of length ``par_count`` are broadcast as the same row to every share;
          shorter ones overwrite the same leading slots of every row. kwargs are merged
          by name into every row. Finally ``_pars`` is aligned with ``multi_pars[0]``.
    - Without ``multi_pars`` (or with ``allow_multi_par`` False) the behavior is that of
      ``BaseStrategy.update_par_values``.

    NOTE(review): kwargs are only consulted when no positional values are given. None of
    the branches that handle positional values use them, so a call mixing both silently
    ignores the kwargs.

    Parameters
    ----------
    par_values: tuple, optional
        Positional parameter values; a dict as the first element means a full
        multi_par setting.
    kwargs: dict
        Partial update by parameter name.

    Returns
    -------
    None

    Raises
    ------
    TypeError
        If the only positional argument is a length-1 tuple/list wrapping a dict
        (usually a stray trailing comma after a dict literal).
    """
    if not par_values:
        # keyword-only update
        if (self.allow_multi_par and self.multi_pars is not None
                and not self._updating_multi_pars_internal and kwargs):
            gen_idx = getattr(self, '_generate_share_index', None)
            if gen_idx is not None:
                # inside generate(): merge the kwargs into the current share's row only
                base_row = tuple(self.multi_pars[gen_idx])
                merged = self._merge_row_with_kwargs(base_row, kwargs)
                mp = list(self.multi_pars)
                mp[gen_idx] = merged
                self.multi_pars = tuple(mp)
                super().update_par_values(*merged)
            else:
                # outside generate(): merge the kwargs into every share's row
                new_mp = []
                for i in range(len(self.multi_pars)):
                    merged = self._merge_row_with_kwargs(tuple(self.multi_pars[i]), kwargs)
                    new_mp.append(merged)
                self.multi_pars = tuple(new_mp)
                super().update_par_values(*self.multi_pars[0])
        else:
            super().update_par_values(**kwargs)
        return
    # Detect a common misuse: update_par_values(dict_value,) with a trailing comma
    # makes par_values == ((dict,),)
    if (len(par_values) == 1
            and isinstance(par_values[0], (tuple, list))
            and len(par_values[0]) == 1
            and isinstance(par_values[0][0], dict)):
        raise TypeError(
            f'Expected a dict for multi_par, but got a length-1 '
            f'{type(par_values[0]).__name__} wrapping a dict. '
            f'This usually happens when the dict literal has a trailing '
            f'comma (e.g. "par_values = {{...}},"). '
            f'Remove the trailing comma and call '
            f'update_par_values(par_values) again.'
        )
    # Only a dict as the first argument means multi_par; tuple/list positional values
    # are never treated as multi_par.
    if isinstance(par_values[0], dict):
        # full multi_par setting: rebuild multi_pars from the dict
        par_values_dict = par_values[0]
        self._update_multi_pars(par_values_dict)
        # then load the first share's row into the strategy parameters
        par_values = self.multi_pars[0]
        super().update_par_values(*par_values)
    elif self._updating_multi_pars_internal:
        # validation call from _update_multi_pars(): only touch _pars
        super().update_par_values(*par_values)
    elif self.allow_multi_par and self.multi_pars is not None:
        gen_idx = getattr(self, '_generate_share_index', None)
        tail = tuple(par_values)
        if gen_idx is not None:
            # inside generate(): update only the current share's row
            base_row = tuple(self.multi_pars[gen_idx])
            if len(tail) == self.par_count:
                new_row = tail
            else:
                new_row = self._overlay_tuple_prefix_on_row(base_row, tail)
            mp = list(self.multi_pars)
            mp[gen_idx] = new_row
            self.multi_pars = tuple(mp)
            super().update_par_values(*new_row)
        else:
            # outside generate(): apply the same update to every share's row
            new_mp = []
            for i in range(len(self.multi_pars)):
                base_row = tuple(self.multi_pars[i])
                if len(tail) == self.par_count:
                    new_row = tail
                else:
                    new_row = self._overlay_tuple_prefix_on_row(base_row, tail)
                new_mp.append(new_row)
            self.multi_pars = tuple(new_mp)
            super().update_par_values(*self.multi_pars[0])
    else:
        super().update_par_values(*par_values)
def _update_multi_pars(self, multi_pars):
    """Set a separate parameter tuple for every share (multi-parameter mode).

    Parameters
    ----------
    multi_pars: dict {str: tuple or list}
        Keys are share codes exactly as in ``share_names``. Each value is that share's
        parameter tuple and must have ``par_count`` elements. Shares without their own
        entry use the value stored under the key ``"default"``. There is no other reserved
        key (``"others"`` is NOT supported). Keys that match no share trigger a
        UserWarning and are otherwise ignored.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``allow_multi_par`` is False (``multi_pars`` is reset to None first), if no
        shares are set, or if a parameter tuple has the wrong length.
    TypeError
        If ``multi_pars`` is not a dict, or a value is not a tuple/list.
    KeyError
        If a share has no entry and no ``"default"`` is given.
    """
    if not self.allow_multi_par:
        self.multi_pars = None
        raise ValueError('multi_pars is not allowed, you need to set allow_multi_par to True first')
    if not isinstance(multi_pars, dict):
        raise TypeError(f'multi_pars should be a dict, not {type(multi_pars)}')
    # while set, update_par_values() only updates _pars and leaves multi_pars alone
    prev_flag = self._updating_multi_pars_internal
    self._updating_multi_pars_internal = True
    try:
        # share_count == 0 means set_shares() has not been called yet
        if self.share_count == 0:
            raise ValueError(
                '无法设置 multi_par:share_count 为 0,请先调用 set_shares() 设置股票列表'
            )
        # Keys matching no share (typos, or unsupported names such as "others") used to be
        # ignored silently, and the affected shares quietly fell back to "default".
        known_shares = set(self.share_names)
        unknown_keys = [key for key in multi_pars if key != 'default' and key not in known_shares]
        if unknown_keys:
            warnings.warn(f'multi_pars keys {unknown_keys} do not match any share in share_names '
                          f'and are ignored; use exact share codes, or "default" for the '
                          f'remaining shares', UserWarning, stacklevel=3)
        default_value = multi_pars.get('default', None)
        # resolve rows in share_names order, not in dict key order
        result = []
        for share_name in self.share_names:
            if share_name in multi_pars:
                par_tuple = multi_pars[share_name]
            elif default_value is not None:
                par_tuple = default_value
            else:
                raise KeyError(f'par_value of {share_name} is not provided in multi_pars and default '
                               f'value is not given, please provide par_value for {share_name} or '
                               f'set a default value like "default": (par1, par2, ...) in multi_pars')
            if not isinstance(par_tuple, (tuple, list)):
                raise TypeError(
                    f'par values must be tuple or list,got {type(par_tuple)} for share {share_name}'
                )
            par_tuple = tuple(par_tuple)
            if len(par_tuple) != self.par_count:
                raise ValueError(
                    f'par values count should be equal to par_count ({self.par_count}),'
                    f'got {len(par_tuple)} for share {share_name}'
                )
            # validate the row by loading it into the strategy parameters
            self.update_par_values(*par_tuple)
            result.append(par_tuple)
        if len(result) != self.share_count:
            raise ValueError(
                f'length of multi_pars must equal to share_count ({self.share_count}),'
                f'got {len(result)}'
            )
        self.multi_pars = tuple(result)
    finally:
        self._updating_multi_pars_internal = prev_flag
def update_running_data_window(self, data_windows:dict, window_indices:dict, window_index:int):
    """Cache the data windows selected by ``window_index`` in ``self._data_windows``.

    Unlike the base-class version, the windows are not assigned to the data-type
    attributes here. generate() later assigns each share's column of the cached windows
    to those attributes before calling realize().
    """
    self._data_windows.update(
        (dtype_name, data_windows[dtype_name][window_indices[dtype_name][window_index]])
        for dtype_name in self.data_types
    )
def generate(self):
    """Run realize() once per share and collect the results into one signal vector.

    For share ``i`` the method:

    1. stores ``i`` in ``_generate_share_index``, so update_par_values() called from
       realize() edits only this share's row of ``multi_pars``;
    2. when multi-parameters are active, loads ``multi_pars[i]`` into the strategy
       parameters;
    3. assigns column ``i`` of every cached data window to the matching data-type
       attribute;
    4. stores the value returned by realize() in ``signal[i]``.

    ``_generate_share_index`` is reset to None afterwards, even if realize() raises.

    Returns
    -------
    signal: np.ndarray
        1-D float array of length ``share_count``, one signal value per share.
    """
    signal = np.empty(self.share_count, dtype=float)
    try:
        for i in range(self.share_count):
            self._generate_share_index = i
            if self.allow_multi_par and self.multi_pars:
                # Load this share's row straight into _pars. Going through
                # self.update_par_values() would rebuild the whole multi_pars tuple for every
                # share (O(share_count ** 2) per call) only to store the same row again.
                super().update_par_values(*self.multi_pars[i])
            for dtype_name in self.data_types:
                setattr(self, dtype_name, self._data_windows[dtype_name][:, i])
            signal[i] = self.realize()
    finally:
        self._generate_share_index = None
    return signal
@abstractmethod
def realize(self):
    """Compute the output for the single share currently being processed.

    Implemented by subclasses. generate() calls it once per share and stores the
    returned value in that share's slot of the float signal vector.
    """
    ...