API 参考
在这里你可以找到所有 Qlib
的接口。
数据
数据提供器
- class qlib.data.data.ProviderBackendMixin
这个辅助类旨在使基于存储后端的提供器更加便捷 如果提供器不依赖后端存储,则无需继承此类
- class qlib.data.data.CalendarProvider
日历提供器基类
提供日历数据。
- calendar(start_time=None, end_time=None, freq='day', future=False)
获取特定市场在给定时间范围内的日历
参数
- start_timestr
时间范围的开始。
- end_timestr
时间范围的结束。
- freqstr
时间频率,可选值:year/quarter/month/week/day。
- futurebool
是否包含未来交易日。
返回
- list
日历列表
- locate_index(start_time: Timestamp | str, end_time: Timestamp | str, freq: str, future: bool = False)
在特定频率的日历中定位开始时间索引和结束时间索引
参数
- start_timepd.Timestamp
时间范围的开始。
- end_timepd.Timestamp
时间范围的结束。
- freqstr
时间频率,可选值:year/quarter/month/week/day。
- futurebool
是否包含未来交易日。
返回
- pd.Timestamp
实际开始时间。
- pd.Timestamp
实际结束时间。
- int
开始时间的索引。
- int
结束时间的索引。
- class qlib.data.data.InstrumentProvider
工具提供器基类
提供工具数据。
- static instruments(market: List | str = 'all', filter_pipe: List | None = None)
获取基础市场的通用配置字典,添加多个动态过滤器。
参数
- marketUnion[List, str]
- str:
市场/行业/指数简称,例如 all/sse/szse/sse50/csi300/csi500。
- list:
["ID1", "ID2"]。股票列表
- filter_pipelist
动态过滤器列表。
返回
- dict: 如果 isinstance(market, str)
股票池配置字典。
{market => 基础市场名称, filter_pipe => 过滤器列表}
示例 :
{'market': 'csi500', 'filter_pipe': [{'filter_type': 'ExpressionDFilter', 'rule_expression': '$open<40', 'filter_start_time': None, 'filter_end_time': None, 'keep': False}, {'filter_type': 'NameDFilter', 'name_rule_re': 'SH[0-9]{4}55', 'filter_start_time': None, 'filter_end_time': None}]}
- list: 如果 isinstance(market, list)
直接返回原始列表。 注意: 这将使工具与更多情况兼容,用户代码将更简单。
- class qlib.data.data.FeatureProvider
特征提供器类
提供特征数据。
- class qlib.data.data.PITProvider
- abstractmethod period_feature(instrument, field, start_index: int, end_index: int, cur_time: Timestamp, period: int | None = None) Series
获取`start_index`和`end_index`之间的历史周期数据序列
参数
- start_index: int
start_index是相对于cur_time最新周期的相对索引
- end_index: int
end_index是相对于cur_time最新周期的相对索引 在大多数情况下,start_index和end_index将是非正值 例如,start_index == -3 end_index == 0且当前周期索引为cur_idx, 则将检索[start_index + cur_idx, end_index + cur_idx]之间的数据。
- period: int
用于查询特定周期。 Qlib中周期用整数表示(例如202001可能表示2020年第一季度) 注意: period`将覆盖`start_index`和`end_index
返回
- pd.Series
索引将是整数,表示数据的周期 典型示例将是 TODO
引发
- FileNotFoundError
如果查询的数据不存在,将引发此异常。
- class qlib.data.data.ExpressionProvider
表达式提供器类
提供表达式数据。
- __init__()
- abstractmethod expression(instrument, field, start_time=None, end_time=None, freq='day') Series
获取表达式数据
expression 方法的职责: - 解析 field 并加载相应数据 - 加载数据时应处理数据的时间依赖性,通常会使用 get_expression_instance 方法
参数
- instrumentstr
特定标的
- fieldstr
特征的特定字段
- start_timestr
时间范围开始
- end_timestr
时间范围结束
- freqstr
时间频率,可选:year/quarter/month/week/day
返回
- pd.Series
特定表达式的数据
数据有两种格式:
带日期时间索引的表达式
带整数索引的表达式
因为日期时间格式不如
- class qlib.data.data.DatasetProvider
数据集提供者类
提供数据集数据
- abstractmethod dataset(instruments, fields, start_time=None, end_time=None, freq='day', inst_processors=[])
获取数据集数据
参数
- instrumentslist or dict
标的列表/字典或股票池配置字典
- fieldslist
特征实例列表
- start_timestr
时间范围开始
- end_timestr
时间范围结束
- freqstr
时间频率
- inst_processors: Iterable[Union[dict, InstProcessor]]
对每个标的执行的操作
返回
- pd.DataFrame
带有<instrument, datetime>索引的pandas数据框
- static get_instruments_d(instruments, freq)
Parse different types of input instruments to output instruments_d Wrong format of input instruments will lead to exception.
- static get_column_names(fields)
Get column names from input fields
- static dataset_processor(instruments_d, column_names, start_time, end_time, freq, inst_processors=[])
Load and process the data, return the data set. - default using multi-kernel method.
- static inst_calculator(inst, start_time, end_time, freq, column_names, spans=None, g_config=None, inst_processors=[])
Calculate the expressions for one instrument, return a df result. If the expression has been calculated before, load from cache.
return value: A data frame with index 'datetime' and other data columns.
- class qlib.data.data.LocalCalendarProvider(remote=False, backend={})
Local calendar data provider class
Provide calendar data from local data source.
- __init__(remote=False, backend={})
- load_calendar(freq, future)
Load original calendar timestamp from file.
- 参数:
freq (str) -- frequency of read calendar file.
future (bool)
- 返回类型:
list of timestamps
- class qlib.data.data.LocalInstrumentProvider(backend={})
Local instrument data provider class
Provide instrument data from local data source.
- __init__(backend={}) None
- class qlib.data.data.LocalFeatureProvider(remote=False, backend={})
Local feature data provider class
Provide feature data from local data source.
- __init__(remote=False, backend={})
- class qlib.data.data.LocalPITProvider
- period_feature(instrument, field, start_index, end_index, cur_time, period=None)
获取`start_index`和`end_index`之间的历史周期数据序列
参数
- start_index: int
start_index是相对于cur_time最新周期的相对索引
- end_index: int
end_index是相对于cur_time最新周期的相对索引 在大多数情况下,start_index和end_index将是非正值 例如,start_index == -3 end_index == 0且当前周期索引为cur_idx, 则将检索[start_index + cur_idx, end_index + cur_idx]之间的数据。
- period: int
用于查询特定周期。 Qlib中周期用整数表示(例如202001可能表示2020年第一季度) 注意: period`将覆盖`start_index`和`end_index
返回
- pd.Series
索引将是整数,表示数据的周期 典型示例将是 TODO
引发
- FileNotFoundError
如果查询的数据不存在,将引发此异常。
- class qlib.data.data.LocalExpressionProvider(time2idx=True)
Local expression data provider class
Provide expression data from local data source.
- __init__(time2idx=True)
- expression(instrument, field, start_time=None, end_time=None, freq='day')
获取表达式数据
expression 方法的职责: - 解析 field 并加载相应数据 - 加载数据时应处理数据的时间依赖性,通常会使用 get_expression_instance 方法
参数
- instrumentstr
特定标的
- fieldstr
特征的特定字段
- start_timestr
时间范围开始
- end_timestr
时间范围结束
- freqstr
时间频率,可选:year/quarter/month/week/day
返回
- pd.Series
特定表达式的数据
- class qlib.data.data.LocalDatasetProvider(align_time: bool = True)
Local dataset data provider class
Provide dataset data from local data source.
- __init__(align_time: bool = True)
- 参数:
align_time (bool) --
Will we align the time to calendar the frequency is flexible in some dataset and can't be aligned. For the data with fixed frequency with a shared calendar, the align data to the calendar will provides following benefits
Align queries to the same parameters, so the cache can be shared.
- dataset(instruments, fields, start_time=None, end_time=None, freq='day', inst_processors=[])
获取数据集数据
参数
- instrumentslist or dict
标的列表/字典或股票池配置字典
- fieldslist
特征实例列表
- start_timestr
时间范围开始
- end_timestr
时间范围结束
- freqstr
时间频率
- inst_processors: Iterable[Union[dict, InstProcessor]]
对每个标的执行的操作
返回
- pd.DataFrame
带有<instrument, datetime>索引的pandas数据框
- static multi_cache_walker(instruments, fields, start_time=None, end_time=None, freq='day')
This method is used to prepare the expression cache for the client. Then the client will load the data from expression cache by itself.
- static cache_walker(inst, start_time, end_time, freq, column_names)
If the expressions of one instrument haven't been calculated before, calculate it and write it into expression cache.
- class qlib.data.data.ClientCalendarProvider
Client calendar data provider class
Provide calendar data by requesting data from server as a client.
- __init__()
- class qlib.data.data.ClientInstrumentProvider
Client instrument data provider class
Provide instrument data by requesting data from server as a client.
- __init__()
- class qlib.data.data.ClientDatasetProvider
Client dataset data provider class
Provide dataset data by requesting data from server as a client.
- __init__()
- dataset(instruments, fields, start_time=None, end_time=None, freq='day', disk_cache=0, return_uri=False, inst_processors=[])
获取数据集数据
参数
- instrumentslist or dict
标的列表/字典或股票池配置字典
- fieldslist
特征实例列表
- start_timestr
时间范围开始
- end_timestr
时间范围结束
- freqstr
时间频率
- inst_processors: Iterable[Union[dict, InstProcessor]]
对每个标的执行的操作
返回
- pd.DataFrame
带有<instrument, datetime>索引的pandas数据框
- class qlib.data.data.BaseProvider
Local provider class It is a set of interface that allow users to access data. Because PITD is not exposed publicly to users, so it is not included in the interface.
To keep compatible with old qlib provider.
- features(instruments, fields, start_time=None, end_time=None, freq='day', disk_cache=None, inst_processors=[])
- 参数:
disk_cache (int) -- whether to skip(0)/use(1)/replace(2) disk_cache
This function will try to use cache method which has a keyword disk_cache, and will use provider method if a type error is raised because the DatasetD instance is a provider class.
- class qlib.data.data.LocalProvider
- features_uri(instruments, fields, start_time, end_time, freq, disk_cache=1)
Return the uri of the generated cache of features/dataset
- 参数:
disk_cache
instruments
fields
start_time
end_time
freq
- class qlib.data.data.ClientProvider
Client Provider
Requesting data from server as a client. Can propose requests:
Calendar : Directly respond a list of calendars
Instruments (without filter): Directly respond a list/dict of instruments
Instruments (with filters): Respond a list/dict of instruments
Features : Respond a cache uri
The general workflow is described as follows: When the user use client provider to propose a request, the client provider will connect the server and send the request. The client will start to wait for the response. The response will be made instantly indicating whether the cache is available. The waiting procedure will terminate only when the client get the response saying feature_available is true. BUG : Everytime we make request for certain data we need to connect to the server, wait for the response and disconnect from it. We can't make a sequence of requests within one connection. You can refer to https://python-socketio.readthedocs.io/en/latest/client.html for documentation of python-socketIO client.
- __init__()
- qlib.data.data.register_all_wrappers(C)
过滤器
- class qlib.data.filter.BaseDFilter
动态工具过滤器抽象类
用户可以重写此类来构建自己的过滤器
重写__init__方法以输入过滤规则
重写filter_main方法以使用规则过滤工具
- __init__()
- class qlib.data.filter.SeriesDFilter(fstart_time=None, fend_time=None, keep=False)
动态工具过滤器抽象类,用于过滤特定特征的序列
过滤器应提供以下参数:
过滤开始时间
过滤结束时间
过滤规则
重写__init__方法以分配特定规则来过滤序列。
重写_getFilterSeries方法以使用规则过滤序列并获取{工具 => 序列}的字典,或重写filter_main以实现更高级的序列过滤规则
- __init__(fstart_time=None, fend_time=None, keep=False)
- 过滤器基类的初始化函数。
在fstart_time和fend_time指定的时间段内,根据特定规则过滤一组工具。
参数
- fstart_time: str
过滤规则开始过滤工具的时间。
- fend_time: str
过滤规则停止过滤工具的时间。
- keep: bool
是否保留在过滤时间范围内没有特征数据的工具。
- filter_main(instruments, start_time=None, end_time=None)
Implement this method to filter the instruments.
- 参数:
instruments (dict) -- input instruments to be filtered.
start_time (str) -- start of the time range.
end_time (str) -- end of the time range.
- 返回:
filtered instruments, same structure as input instruments.
- 返回类型:
dict
- class qlib.data.filter.NameDFilter(name_rule_re, fstart_time=None, fend_time=None)
Name dynamic instrument filter
Filter the instruments based on a regulated name format.
A name rule regular expression is required.
- __init__(name_rule_re, fstart_time=None, fend_time=None)
Init function for name filter class
- 参数:
name_rule_re (str) -- regular expression for the name rule.
- class qlib.data.filter.ExpressionDFilter(rule_expression, fstart_time=None, fend_time=None, keep=False)
Expression dynamic instrument filter
Filter the instruments based on a certain expression.
An expression rule indicating a certain feature field is required.
示例
basic features filter : rule_expression = '$close/$open>5'
cross-sectional features filter : rule_expression = '$rank($close)<10'
time-sequence features filter : rule_expression = '$Ref($close, 3)>100'
- __init__(rule_expression, fstart_time=None, fend_time=None, keep=False)
Init function for expression filter class
- 参数:
fstart_time (str) -- filter the feature starting from this time.
fend_time (str) -- filter the feature ending by this time.
rule_expression (str) -- an input expression for the rule.
类
- class qlib.data.base.Expression
表达式基类
表达式用于处理具有以下格式的数据计算 每个工具包含两个维度的数据:
特征(feature)
- 时间(time):可以是观察时间或周期时间
周期时间专为时点数据库设计。例如,周期时间可能是2014Q4,其值可以被多次观察(由于修正,不同时间可能观察到不同值)。
- load(instrument, start_index, end_index, *args)
加载特征 此函数负责基于表达式引擎加载特征/表达式。
- 具体实现分为两部分:
- 缓存数据,处理错误。
这部分由所有表达式共享,并在Expression中实现
- 根据特定表达式处理和计算数据。
这部分在每个表达式中不同,并在每个表达式中实现
表达式引擎由不同数据共享。 不同数据会为`args`提供不同的额外信息。
参数
- instrumentstr
工具代码
- start_indexstr
特征开始索引[在日历中]
- end_indexstr
特征结束索引[在日历中]
- args可能包含以下信息:
- 如果用于基本表达式引擎数据,包含以下参数
- freq: str
特征频率
- 如果用于PIT数据,包含以下参数
- cur_pit:
专为时点数据设计
- period: int
用于查询特定周期 Qlib中周期用整数表示(例如202001可能表示2020年第一季度)
返回
- pd.Series
特征序列:序列的索引是日历索引
- abstractmethod get_longest_back_rolling()
获取特征需要访问的最长历史数据长度
该方法设计用于预先获取计算特定范围内特征所需的数据范围。 但类似Ref(Ref($close, -1), 1)的情况无法正确处理。
因此该方法仅用于检测所需的历史数据长度。
- class qlib.data.base.Feature(name=None)
静态表达式
此类特征将从数据提供者加载数据
- __init__(name=None)
- get_longest_back_rolling()
获取特征需要访问的最长历史数据长度
该方法设计用于预先获取计算特定范围内特征所需的数据范围。 但类似Ref(Ref($close, -1), 1)的情况无法正确处理。
因此该方法仅用于检测所需的历史数据长度。
- class qlib.data.base.PFeature(name=None)
- class qlib.data.base.ExpressionOps
运算符表达式
此类特征将动态使用运算符进行特征构建。
操作符
- class qlib.data.ops.ElemOperator(feature)
元素级运算符
参数
- featureExpression
特征实例
返回
- Expression
特征运算输出
- __init__(feature)
- get_longest_back_rolling()
获取特征需要访问的最长历史数据长度
该方法设计用于预先获取计算特定范围内特征所需的数据范围。 但类似Ref(Ref($close, -1), 1)的情况无法正确处理。
因此该方法仅用于检测所需的历史数据长度。
- class qlib.data.ops.ChangeInstrument(instrument, feature)
更改工具运算符 在某些情况下,计算时可能需要切换到另一个工具,例如计算股票相对于市场指数的贝塔值。 这需要将特征计算从股票(原始工具)更改为指数(参考工具)
参数
- instrument: 要执行下游操作的新工具。
例如,SH000300(沪深300指数)或^GPSC(标普500指数)。
feature: 要为新工具计算的特征。
返回
- Expression
特征运算输出
- __init__(instrument, feature)
- load(instrument, start_index, end_index, *args)
加载特征 此函数负责基于表达式引擎加载特征/表达式。
- 具体实现分为两部分:
- 缓存数据,处理错误。
这部分由所有表达式共享,并在Expression中实现
- 根据特定表达式处理和计算数据。
这部分在每个表达式中不同,并在每个表达式中实现
表达式引擎由不同数据共享。 不同数据会为`args`提供不同的额外信息。
参数
- instrumentstr
工具代码
- start_indexstr
特征开始索引[在日历中]
- end_indexstr
特征结束索引[在日历中]
- args可能包含以下信息:
- 如果用于基本表达式引擎数据,包含以下参数
- freq: str
特征频率
- 如果用于PIT数据,包含以下参数
- cur_pit:
专为时点数据设计
- period: int
用于查询特定周期 Qlib中周期用整数表示(例如202001可能表示2020年第一季度)
返回
- pd.Series
特征序列:序列的索引是日历索引
- class qlib.data.ops.NpElemOperator(feature, func)
Numpy元素级运算符
参数
- featureExpression
特征实例
- funcstr
numpy特征运算方法
返回
- Expression
特征运算输出
- __init__(feature, func)
- class qlib.data.ops.Abs(feature)
特征绝对值
参数
- featureExpression
特征实例
返回
- Expression
具有绝对值输出的特征实例
- __init__(feature)
- class qlib.data.ops.Sign(feature)
特征符号
参数
- featureExpression
特征实例
返回
- Expression
具有符号的特征实例
- __init__(feature)
- class qlib.data.ops.Log(feature)
特征对数
参数
- featureExpression
特征实例
返回
- Expression
具有对数的特征实例
- __init__(feature)
- class qlib.data.ops.Mask(feature, instrument)
特征掩码
参数
- featureExpression
特征实例
- instrumentstr
工具掩码
返回
- Expression
具有掩码工具的特征实例
- __init__(feature, instrument)
- class qlib.data.ops.Not(feature)
Not Operator
- 参数:
feature (Expression) -- feature instance
- 返回:
feature elementwise not output
- 返回类型:
- __init__(feature)
- class qlib.data.ops.PairOperator(feature_left, feature_right)
Pair-wise operator
- 参数:
feature_left (Expression) -- feature instance or numeric value
feature_right (Expression) -- feature instance or numeric value
- 返回:
two features' operation output
- 返回类型:
- __init__(feature_left, feature_right)
- get_longest_back_rolling()
获取特征需要访问的最长历史数据长度
该方法设计用于预先获取计算特定范围内特征所需的数据范围。 但类似Ref(Ref($close, -1), 1)的情况无法正确处理。
因此该方法仅用于检测所需的历史数据长度。
- class qlib.data.ops.NpPairOperator(feature_left, feature_right, func)
Numpy成对运算符
参数
- feature_leftExpression
特征实例或数值
- feature_rightExpression
特征实例或数值
- funcstr
运算符函数
返回
- Feature:
两个特征的运算结果
- __init__(feature_left, feature_right, func)
- class qlib.data.ops.Power(feature_left, feature_right)
幂运算符
参数
- feature_leftExpression
特征实例
- feature_rightExpression
特征实例
返回
- Feature:
feature_left的基数乘以feature_right的指数
- __init__(feature_left, feature_right)
- class qlib.data.ops.Add(feature_left, feature_right)
加法运算符
参数
- feature_leftExpression
特征实例
- feature_rightExpression
特征实例
返回
- Feature:
两个特征的和
- __init__(feature_left, feature_right)
- class qlib.data.ops.Sub(feature_left, feature_right)
减法运算符
参数
- feature_leftExpression
特征实例
- feature_rightExpression
特征实例
返回
- Feature:
两个特征的差
- __init__(feature_left, feature_right)
- class qlib.data.ops.Mul(feature_left, feature_right)
乘法运算符
参数
- feature_leftExpression
特征实例
- feature_rightExpression
特征实例
返回
- Feature:
两个特征的乘积
- __init__(feature_left, feature_right)
- class qlib.data.ops.Div(feature_left, feature_right)
除法运算符
参数
- feature_leftExpression
特征实例
- feature_rightExpression
特征实例
返回值
- Feature:
两个特征的除法结果
- __init__(feature_left, feature_right)
- class qlib.data.ops.Greater(feature_left, feature_right)
取大运算符
参数
- feature_leftExpression
特征实例
- feature_rightExpression
特征实例
返回值
- Feature:
从两个输入特征中取较大元素
- __init__(feature_left, feature_right)
- class qlib.data.ops.Less(feature_left, feature_right)
取小运算符
参数
- feature_leftExpression
特征实例
- feature_rightExpression
特征实例
返回值
- Feature:
从两个输入特征中取较小元素
- __init__(feature_left, feature_right)
- class qlib.data.ops.Gt(feature_left, feature_right)
大于运算符
参数
- feature_leftExpression
特征实例
- feature_rightExpression
特征实例
返回值
- Feature:
布尔序列,表示`left > right`
- __init__(feature_left, feature_right)
- class qlib.data.ops.Ge(feature_left, feature_right)
大于等于运算符
参数
- feature_leftExpression
特征实例
- feature_rightExpression
特征实例
返回值
- Feature:
布尔序列,表示`left >= right`
- __init__(feature_left, feature_right)
- class qlib.data.ops.Lt(feature_left, feature_right)
小于运算符
参数
- feature_leftExpression
特征实例
- feature_rightExpression
特征实例
返回值
- Feature:
布尔序列,表示`left < right`
- __init__(feature_left, feature_right)
- class qlib.data.ops.Le(feature_left, feature_right)
小于等于运算符
参数
- feature_leftExpression
特征实例
- feature_rightExpression
特征实例
返回值
- Feature:
布尔序列,表示`left <= right`
- __init__(feature_left, feature_right)
- class qlib.data.ops.Eq(feature_left, feature_right)
等于运算符
参数
- feature_leftExpression
特征实例
- feature_rightExpression
特征实例
返回值
- Feature:
布尔序列,表示`left == right`
- __init__(feature_left, feature_right)
- class qlib.data.ops.Ne(feature_left, feature_right)
不等于运算符
参数
- feature_leftExpression
特征实例
- feature_rightExpression
特征实例
返回值
- Feature:
布尔序列,表示`left != right`
- __init__(feature_left, feature_right)
- class qlib.data.ops.And(feature_left, feature_right)
与运算符
参数
- feature_leftExpression
特征实例
- feature_rightExpression
特征实例
返回值
- Feature:
两个特征按行进行与运算的结果
- __init__(feature_left, feature_right)
- class qlib.data.ops.Or(feature_left, feature_right)
或运算符
参数
- feature_leftExpression
特征实例
- feature_rightExpression
特征实例
返回值
- Feature:
两个特征按行进行或运算的结果
- __init__(feature_left, feature_right)
- class qlib.data.ops.If(condition, feature_left, feature_right)
条件运算符
参数
- conditionExpression
布尔值特征实例,作为条件
- feature_leftExpression
特征实例
- feature_rightExpression
特征实例
- __init__(condition, feature_left, feature_right)
- get_longest_back_rolling()
获取特征需要访问的最长历史数据长度
该方法设计用于预先获取计算特定范围内特征所需的数据范围。 但类似Ref(Ref($close, -1), 1)的情况无法正确处理。
因此该方法仅用于检测所需的历史数据长度。
- class qlib.data.ops.Rolling(feature, N, func)
滚动窗口运算符 此处的滚动(rolling)和扩展(expanding)与pandas中的概念相同。 当窗口大小设置为0时,运算符的行为应遵循`expanding`(扩展窗口) 否则,遵循`rolling`(滚动窗口)
参数
- featureExpression
特征实例
- Nint
滚动窗口大小
- funcstr
滚动计算方法
返回值
- Expression
滚动计算结果
- __init__(feature, N, func)
- get_longest_back_rolling()
获取特征需要访问的最长历史数据长度
该方法设计用于预先获取计算特定范围内特征所需的数据范围。 但类似Ref(Ref($close, -1), 1)的情况无法正确处理。
因此该方法仅用于检测所需的历史数据长度。
- class qlib.data.ops.Ref(feature, N)
特征引用
参数
- featureExpression
特征实例
- Nint
N = 0,获取首个数据;N > 0,获取N个周期前的数据;N < 0,获取未来数据
返回值
- Expression
目标引用的特征实例
- __init__(feature, N)
- get_longest_back_rolling()
获取特征需要访问的最长历史数据长度
该方法设计用于预先获取计算特定范围内特征所需的数据范围。 但类似Ref(Ref($close, -1), 1)的情况无法正确处理。
因此该方法仅用于检测所需的历史数据长度。
- class qlib.data.ops.Mean(feature, N)
滚动平均值(MA)
参数
- featureExpression
特征实例
- Nint
滚动窗口大小
返回值
- Expression
滚动平均值的特征实例
- __init__(feature, N)
- class qlib.data.ops.Sum(feature, N)
滚动总和
参数
- featureExpression
特征实例
- Nint
滚动窗口大小
返回值
- Expression
滚动总和的特征实例
- __init__(feature, N)
- class qlib.data.ops.Std(feature, N)
滚动标准差
参数
- featureExpression
特征实例
- Nint
滚动窗口大小
返回值
- Expression
滚动标准差的特征实例
- __init__(feature, N)
- class qlib.data.ops.Var(feature, N)
滚动方差
参数
- featureExpression
特征实例
- Nint
滚动窗口大小
返回值
- Expression
滚动方差的特征实例
- __init__(feature, N)
- class qlib.data.ops.Skew(feature, N)
滚动偏度
参数
- featureExpression
特征实例
- Nint
滚动窗口大小
返回值
- Expression
滚动偏度的特征实例
- __init__(feature, N)
- class qlib.data.ops.Kurt(feature, N)
滚动峰度
参数
- featureExpression
特征实例
- Nint
滚动窗口大小
返回值
- Expression
滚动峰度的特征实例
- __init__(feature, N)
- class qlib.data.ops.Max(feature, N)
滚动最大值
参数
- featureExpression
特征实例
- Nint
滚动窗口大小
返回值
- Expression
滚动最大值的特征实例
- __init__(feature, N)
- class qlib.data.ops.IdxMax(feature, N)
滚动最大值索引
参数
- featureExpression
特征实例
- Nint
滚动窗口大小
返回值
- Expression
滚动最大值索引的特征实例
- __init__(feature, N)
- class qlib.data.ops.Min(feature, N)
滚动最小值
参数
- featureExpression
特征实例
- Nint
滚动窗口大小
返回值
- Expression
滚动最小值的特征实例
- __init__(feature, N)
- class qlib.data.ops.IdxMin(feature, N)
滚动最小值索引
参数
- featureExpression
特征实例
- Nint
滚动窗口大小
返回值
- Expression
滚动最小值索引的特征实例
- __init__(feature, N)
- class qlib.data.ops.Quantile(feature, N, qscore)
滚动分位数
参数
- featureExpression
特征实例
- Nint
滚动窗口大小
返回值
- Expression
滚动分位数的特征实例
- __init__(feature, N, qscore)
- class qlib.data.ops.Med(feature, N)
滚动中位数
参数
- featureExpression
特征实例
- Nint
滚动窗口大小
返回值
- Expression
滚动中位数的特征实例
- __init__(feature, N)
- class qlib.data.ops.Mad(feature, N)
滚动平均绝对偏差
参数
- featureExpression
特征实例
- Nint
滚动窗口大小
返回值
- Expression
滚动平均绝对偏差的特征实例
- __init__(feature, N)
- class qlib.data.ops.Rank(feature, N)
滚动排名(百分位)
参数
- featureExpression
特征实例
- Nint
滚动窗口大小
返回值
- Expression
滚动排名的特征实例
- __init__(feature, N)
- class qlib.data.ops.Count(feature, N)
滚动计数
参数
- featureExpression
特征实例
- Nint
滚动窗口大小
返回值
- Expression
滚动非NaN元素数量的特征实例
- __init__(feature, N)
- class qlib.data.ops.Delta(feature, N)
滚动差值
参数
- featureExpression
特征实例
- Nint
滚动窗口大小
返回值
- Expression
滚动窗口内末尾值减起始值的特征实例
- __init__(feature, N)
- class qlib.data.ops.Slope(feature, N)
滚动斜率 该运算符计算`idx`和`feature`之间的斜率。 (例如:[<feature_t1>, <feature_t2>, <feature_t3>] 和 [1, 2, 3])
使用示例: - "Slope($close, %d)/$close"
# TODO: # 有些用户可能需要成对滚动,如`Slope(A, B, N)`
参数
- featureExpression
特征实例
- Nint
滚动窗口大小
返回值
- Expression
给定窗口的线性回归斜率的特征实例
- __init__(feature, N)
- class qlib.data.ops.Rsquare(feature, N)
滚动R值平方
参数
- featureExpression
特征实例
- Nint
滚动窗口大小
返回值
- Expression
给定窗口的线性回归R值平方的特征实例
- __init__(feature, N)
- class qlib.data.ops.Resi(feature, N)
滚动回归残差
参数
- featureExpression
特征实例
- Nint
滚动窗口大小
返回值
- Expression
给定窗口的回归残差的特征实例
- __init__(feature, N)
- class qlib.data.ops.WMA(feature, N)
滚动加权移动平均
参数
- featureExpression
特征实例
- Nint
滚动窗口大小
返回值
- Expression
加权移动平均输出的特征实例
- __init__(feature, N)
- class qlib.data.ops.EMA(feature, N)
滚动指数移动平均 (EMA)
参数
- featureExpression
特征实例
- Nint, float
滚动窗口大小
返回值
- Expression
指数移动平均输出的特征实例
- __init__(feature, N)
- class qlib.data.ops.PairRolling(feature_left, feature_right, N, func)
成对滚动运算符
参数
- feature_leftExpression
特征实例
- feature_rightExpression
特征实例
- Nint
滚动窗口大小
返回值
- Expression
两个输入特征的滚动计算结果的特征实例
- __init__(feature_left, feature_right, N, func)
- get_longest_back_rolling()
获取特征需要访问的最长历史数据长度
该方法设计用于预先获取计算特定范围内特征所需的数据范围。 但类似Ref(Ref($close, -1), 1)的情况无法正确处理。
因此该方法仅用于检测所需的历史数据长度。
- class qlib.data.ops.Corr(feature_left, feature_right, N)
滚动相关系数
参数
- feature_leftExpression
特征实例
- feature_rightExpression
特征实例
- Nint
滚动窗口大小
返回值
- Expression
两个输入特征的滚动相关系数的特征实例
- __init__(feature_left, feature_right, N)
- class qlib.data.ops.Cov(feature_left, feature_right, N)
滚动协方差
参数
- feature_leftExpression
特征实例
- feature_rightExpression
特征实例
- Nint
滚动窗口大小
返回值
- Expression
两个输入特征的滚动协方差的特征实例
- __init__(feature_left, feature_right, N)
- class qlib.data.ops.TResample(feature, freq, func)
- class qlib.data.ops.OpsWrapper
运算符包装器
- __init__()
- register(ops_list: List[Type[ExpressionOps] | dict])
注册运算符
参数
- ops_listList[Union[Type[ExpressionOps], dict]]
如果类型是List[Type[ExpressionOps]],列表中的每个元素表示运算符类,该类必须是`ExpressionOps`的子类
如果类型是List[dict],列表中的每个元素表示运算符的配置,格式如下:
{ "class": "类名", "module_path": "路径" }
注意
- qlib.data.ops.register_all_ops(C)
注册所有运算符
缓存
- class qlib.data.cache.MemCacheUnit(*args, **kwargs)
内存缓存单元。
- __init__(*args, **kwargs)
- property limited
内存缓存是否有限制
- class qlib.data.cache.MemCache(mem_cache_size_limit=None, limit_type='length')
内存缓存。
- class qlib.data.cache.ExpressionCache(provider)
Expression cache mechanism base class.
This class is used to wrap expression provider with self-defined expression cache mechanism.
备注
Override the _uri and _expression method to create your own expression cache mechanism.
- expression(instrument, field, start_time, end_time, freq)
Get expression data.
备注
Same interface as expression method in expression provider
- update(cache_uri: str | Path, freq: str = 'day')
Update expression cache to latest calendar.
Override this method to define how to update expression cache corresponding to users' own cache mechanism.
- 参数:
cache_uri (str or Path) -- the complete uri of expression cache file (include dir path).
freq (str)
- 返回:
0(successful update)/ 1(no need to update)/ 2(update failure).
- 返回类型:
int
- class qlib.data.cache.DatasetCache(provider)
Dataset cache mechanism base class.
This class is used to wrap dataset provider with self-defined dataset cache mechanism.
备注
Override the _uri and _dataset method to create your own dataset cache mechanism.
- dataset(instruments, fields, start_time=None, end_time=None, freq='day', disk_cache=1, inst_processors=[])
Get feature dataset.
备注
Same interface as dataset method in dataset provider
备注
The server use redis_lock to make sure read-write conflicts will not be triggered but client readers are not considered.
- update(cache_uri: str | Path, freq: str = 'day')
Update dataset cache to latest calendar.
Override this method to define how to update dataset cache corresponding to users' own cache mechanism.
- 参数:
cache_uri (str or Path) -- the complete uri of dataset cache file (include dir path).
freq (str)
- 返回:
0(successful update)/ 1(no need to update)/ 2(update failure)
- 返回类型:
int
- static cache_to_origin_data(data, fields)
cache data to origin data
- 参数:
data -- pd.DataFrame, cache data.
fields -- feature fields.
- 返回:
pd.DataFrame.
- static normalize_uri_args(instruments, fields, freq)
normalize uri args
- class qlib.data.cache.DiskExpressionCache(provider, **kwargs)
Prepared cache mechanism for server.
- __init__(provider, **kwargs)
- gen_expression_cache(expression_data, cache_path, instrument, field, freq, last_update)
use bin file to save like feature-data.
- update(sid, cache_uri, freq: str = 'day')
Update expression cache to latest calendar.
Override this method to define how to update expression cache corresponding to users' own cache mechanism.
- 参数:
cache_uri (str or Path) -- the complete uri of expression cache file (include dir path).
freq (str)
- 返回:
0(successful update)/ 1(no need to update)/ 2(update failure).
- 返回类型:
int
- class qlib.data.cache.DiskDatasetCache(provider, **kwargs)
Prepared cache mechanism for server.
- __init__(provider, **kwargs)
- classmethod read_data_from_cache(cache_path: str | Path, start_time, end_time, fields)
read_cache_from
This function can read data from the disk cache dataset
- 参数:
cache_path
start_time
end_time
fields -- The fields order of the dataset cache is sorted. So rearrange the columns to make it consistent.
- 返回:
- class IndexManager(cache_path: str | Path)
The lock is not considered in the class. Please consider the lock outside the code. This class is the proxy of the disk data.
- __init__(cache_path: str | Path)
- gen_dataset_cache(cache_path: str | Path, instruments, fields, freq, inst_processors=[])
备注
This function does not consider the cache read write lock. Please acquire the lock outside this function
The format the cache contains 3 parts(followed by typical filename).
index : cache/d41366901e25de3ec47297f12e2ba11d.index
The content of the file may be in following format(pandas.Series)
start end 1999-11-10 00:00:00 0 1 1999-11-11 00:00:00 1 2 1999-11-12 00:00:00 2 3 ...
备注
The start is closed. The end is open!!!!!
Each line contains two element <start_index, end_index> with a timestamp as its index.
It indicates the start_index (included) and end_index (excluded) of the data for timestamp
meta data: cache/d41366901e25de3ec47297f12e2ba11d.meta
data : cache/d41366901e25de3ec47297f12e2ba11d
This is a hdf file sorted by datetime
- 参数:
cache_path -- The path to store the cache.
instruments -- The instruments to store the cache.
fields -- The fields to store the cache.
freq -- The freq to store the cache.
inst_processors -- Instrument processors.
:return type pd.DataFrame; The fields of the returned DataFrame are consistent with the parameters of the function.
- update(cache_uri, freq: str = 'day')
Update dataset cache to latest calendar.
Override this method to define how to update dataset cache corresponding to users' own cache mechanism.
- 参数:
cache_uri (str or Path) -- the complete uri of dataset cache file (include dir path).
freq (str)
- 返回:
0(successful update)/ 1(no need to update)/ 2(update failure)
- 返回类型:
int
存储
- class qlib.data.storage.storage.BaseStorage
存储基类,提供存储名称的基本实现
- property storage_name: str
获取存储名称
- 返回:
str: 存储名称(小写)
- class qlib.data.storage.storage.CalendarStorage(freq: str, future: bool, **kwargs)
日历存储类,其方法行为与同名的List方法保持一致
- __init__(freq: str, future: bool, **kwargs)
初始化日历存储
- 参数:
freq: 频率字符串 future: 是否包含未来数据 **kwargs: 其他关键字参数
- property data: Iterable[str]
获取所有日历数据
- 返回:
Iterable[CalVT]: 日历数据的可迭代对象
异常:
- ValueError
如果数据(存储)不存在,则引发ValueError
- extend(iterable: Iterable[str]) None
扩展日历存储数据
- 参数:
iterable: 包含日历数据的可迭代对象
异常:
NotImplementedError: 子类必须实现此方法
- class qlib.data.storage.storage.InstrumentStorage(market: str, freq: str, **kwargs)
证券工具存储类,用于管理证券工具的相关数据
- class qlib.data.storage.storage.FeatureStorage(instrument: str, field: str, freq: str, **kwargs)
特征存储类,用于管理证券特征数据
- __init__(instrument: str, field: str, freq: str, **kwargs)
初始化特征存储
- 参数:
instrument: 证券工具代码 field: 特征字段名称 freq: 频率字符串 **kwargs: 其他关键字参数
- property data: Series
获取所有特征数据
- 返回:
pd.Series: 特征数据序列
注意:
如果数据(存储)不存在,返回空的pd.Series: return pd.Series(dtype=np.float32)
- property start_index: int | None
获取特征存储的起始索引
- 返回:
Union[int, None]: 起始索引,如果数据不存在则返回None
注意:
如果数据(存储)不存在,返回None
- property end_index: int | None
获取特征存储的结束索引
- 返回:
Union[int, None]: 结束索引,如果数据不存在则返回None
注意:
数据范围的右索引(闭区间)
下一个数据追加点为 end_index + 1
如果数据(存储)不存在,返回None
- write(data_array: List | ndarray | Tuple, index: int = None)
将数据数组写入特征存储,从指定索引开始
- 参数:
data_array: 要写入的数据数组,可以是列表、numpy数组或元组 index: 起始索引,如果为None则追加数据
注意:
如果index为None,则将data_array追加到特征数据末尾
如果data_array长度为0,则直接返回
如果(index - self.end_index) >= 1,则self[end_index+1: index]区间将填充np.nan
示例:
特征数据: 3 4 4 5 5 6
>>> self.write([6, 7], index=6)
- 特征数据:
3 4 4 5 5 6 6 6 7 7
>>> self.write([8], index=9)
- 特征数据:
3 4 4 5 5 6 6 6 7 7 8 np.nan 9 8
>>> self.write([1, np.nan], index=3)
- 特征数据:
3 1 4 np.nan 5 6 6 6 7 7 8 np.nan 9 8
异常:
NotImplementedError: 子类必须实现此方法
- rebase(start_index: int = None, end_index: int = None)
重新设置特征存储的起始索引和结束索引
- 参数:
start_index: 新的起始索引,默认为None(使用当前起始索引) end_index: 新的结束索引,默认为None(使用当前结束索引)
注意:
start_index和end_index构成闭区间: [start_index, end_index]
示例:
特征数据: 3 4 4 5 5 6 >>> self.rebase(start_index=4) 特征数据: 4 5 5 6 >>> self.rebase(start_index=3) 特征数据: 3 np.nan 4 5 5 6 >>> self.write([3], index=3) 特征数据: 3 3 4 5 5 6 >>> self.rebase(end_index=4) 特征数据: 3 3 4 5 >>> self.write([6, 7, 8], index=4) 特征数据: 3 3 4 6 5 7 6 8 >>> self.rebase(start_index=4, end_index=5) 特征数据: 4 6 5 7
异常:
ValueError: 如果storage.start_index或storage.end_index为None(存储可能不存在)
- class qlib.data.storage.file_storage.FileStorageMixin
文件存储混合类,适用于FileXXXStorage 子类需要具有provider_uri、freq、storage_name、file_name属性
- check()
检查self.uri
- 抛出:
ValueError --
- class qlib.data.storage.file_storage.FileCalendarStorage(freq: str, future: bool, provider_uri: dict = None, **kwargs)
- __init__(freq: str, future: bool, provider_uri: dict = None, **kwargs)
- property data: List[str]
获取所有日历数据
- 返回:
Iterable[CalVT]: 日历数据的可迭代对象
异常:
- ValueError
如果数据(存储)不存在,则引发ValueError
- extend(values: Iterable[str]) None
扩展日历存储数据
- 参数:
iterable: 包含日历数据的可迭代对象
异常:
NotImplementedError: 子类必须实现此方法
- class qlib.data.storage.file_storage.FileInstrumentStorage(market: str, freq: str, provider_uri: dict = None, **kwargs)
- __init__(market: str, freq: str, provider_uri: dict = None, **kwargs)
- class qlib.data.storage.file_storage.FileFeatureStorage(instrument: str, field: str, freq: str, provider_uri: dict = None, **kwargs)
- __init__(instrument: str, field: str, freq: str, provider_uri: dict = None, **kwargs)
- property data: Series
获取所有特征数据
- 返回:
pd.Series: 特征数据序列
注意:
如果数据(存储)不存在,返回空的pd.Series: return pd.Series(dtype=np.float32)
- write(data_array: List | ndarray, index: int = None) None
将数据数组写入特征存储,从指定索引开始
- 参数:
data_array: 要写入的数据数组,可以是列表、numpy数组或元组 index: 起始索引,如果为None则追加数据
注意:
如果index为None,则将data_array追加到特征数据末尾
如果data_array长度为0,则直接返回
如果(index - self.end_index) >= 1,则self[end_index+1: index]区间将填充np.nan
示例:
特征数据: 3 4 4 5 5 6
>>> self.write([6, 7], index=6)
- 特征数据:
3 4 4 5 5 6 6 6 7 7
>>> self.write([8], index=9)
- 特征数据:
3 4 4 5 5 6 6 6 7 7 8 np.nan 9 8
>>> self.write([1, np.nan], index=3)
- 特征数据:
3 1 4 np.nan 5 6 6 6 7 7 8 np.nan 9 8
异常:
NotImplementedError: 子类必须实现此方法
数据集
数据集类
- class qlib.data.dataset.__init__.Dataset(**kwargs)
为模型训练和推理准备数据。
- __init__(**kwargs)
初始化旨在完成以下步骤:
- 初始化子实例和数据集的状态(准备数据所需的信息)
用于准备数据的基本状态名称不应以'_'开头,以便在序列化时可以保存到磁盘。
- 设置数据
数据相关属性的名称应以'_'开头,以便在序列化时不会保存到磁盘。
数据可以指定计算准备所需基本数据的信息
- config(**kwargs)
config is designed to configure and parameters that cannot be learned from the data
- setup_data(**kwargs)
Setup the data.
We split the setup_data function for following situation:
User have a Dataset object with learned status on disk.
User load the Dataset object from the disk.
User call setup_data to load new data.
User prepare data for model based on previous status.
- class qlib.data.dataset.__init__.DatasetH(handler: Dict | DataHandler, segments: Dict[str, Tuple], fetch_kwargs: Dict = {}, **kwargs)
带有数据处理器(DataHandler)的数据集。
用户应尝试将数据预处理函数放入处理器中。 只有以下数据处理函数应放在数据集中: - 与特定模型相关的处理 - 与数据拆分相关的处理
- __init__(handler: Dict | DataHandler, segments: Dict[str, Tuple], fetch_kwargs: Dict = {}, **kwargs)
Setup the underlying data.
- 参数:
handler (Union[dict, DataHandler]) --
处理器可以是:
`DataHandler`的实例
DataHandler`的配置。请参考`DataHandler
segments (dict) --
描述数据分段的选项。 以下是一些示例:
1) 'segments': { 'train': ("2008-01-01", "2014-12-31"), 'valid': ("2017-01-01", "2020-08-01",), 'test': ("2015-01-01", "2016-12-31",), } 2) 'segments': { 'insample': ("2008-01-01", "2014-12-31"), 'outsample': ("2017-01-01", "2020-08-01",), }
- config(handler_kwargs: dict = None, **kwargs)
Initialize the DatasetH
- 参数:
handler_kwargs (dict) --
Config of DataHandler, which could include the following arguments:
arguments of DataHandler.conf_data, such as 'instruments', 'start_time' and 'end_time'.
kwargs (dict) --
Config of DatasetH, such as
- segmentsdict
Config of segments which is same as 'segments' in self.__init__
- setup_data(handler_kwargs: dict = None, **kwargs)
Setup the Data
- 参数:
handler_kwargs (dict) --
init arguments of DataHandler, which could include the following arguments:
init_type : Init Type of Handler
enable_cache : whether to enable cache
- prepare(segments: List[str] | Tuple[str] | str | slice | Index, col_set='__all', data_key='infer', **kwargs) List[DataFrame] | DataFrame
Prepare the data for learning and inference.
- 参数:
segments (Union[List[Text], Tuple[Text], Text, slice]) --
Describe the scope of the data to be prepared Here are some examples:
'train'
['train', 'valid']
col_set (str) --
The col_set will be passed to self.handler when fetching data. TODO: make it automatic:
select DK_I for test data
select DK_L for training data.
data_key (str) -- The data to fetch: DK_* Default is DK_I, which indicate fetching data for inference.
kwargs --
- The parameters that kwargs may contain:
- flt_colstr
It only exists in TSDatasetH, can be used to add a column of data(True or False) to filter data. This parameter is only supported when it is an instance of TSDatasetH.
- 返回类型:
Union[List[pd.DataFrame], pd.DataFrame]
- 抛出:
NotImplementedError: --
数据加载器
- class qlib.data.dataset.loader.DataLoader
DataLoader用于从原始数据源加载原始数据。
- abstractmethod load(instruments, start_time=None, end_time=None) DataFrame
以pd.DataFrame格式加载数据。
数据示例(列的多级索引是可选的):
feature label $close $volume Ref($close, 1) Mean($close, 3) $high-$low LABEL0 datetime instrument 2010-01-04 SH600000 81.807068 17145150.0 83.737389 83.016739 2.741058 0.0032 SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632 0.0042 SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325 0.0289
参数
- instrumentsstr或dict
可以是市场名称或由InstrumentProvider生成的标的配置文件。 如果instruments的值为None,则表示不进行过滤。
- start_timestr
时间范围的开始。
- end_timestr
时间范围的结束。
返回
- pd.DataFrame:
从底层数据源加载的数据
异常
- KeyError:
如果不支持标的过滤,将引发KeyError
- class qlib.data.dataset.loader.DLWParser(config: list | tuple | dict)
(D)ata(L)oader (W)ith (P)arser(带解析器的数据加载器),用于处理特征和名称
提取此类以便QlibDataLoader和其他数据加载器(如QdbDataLoader)可以共享字段。
- __init__(config: list | tuple | dict)
参数
- configUnion[list, tuple, dict]
用于描述字段和列名的配置
<config> := { "group_name1": <fields_info1> "group_name2": <fields_info2> } 或 <config> := <fields_info> <fields_info> := ["expr", ...] | (["expr", ...], ["col_name", ...]) # 注意:列表或元组在解析时将被视为上述结构
- abstractmethod load_group_df(instruments, exprs: list, names: list, start_time: str | Timestamp = None, end_time: str | Timestamp = None, gp_name: str = None) DataFrame
load the dataframe for specific group
- 参数:
instruments -- the instruments.
exprs (list) -- the expressions to describe the content of the data.
names (list) -- the name of the data.
- 返回:
the queried dataframe.
- 返回类型:
pd.DataFrame
- load(instruments=None, start_time=None, end_time=None) DataFrame
以pd.DataFrame格式加载数据。
数据示例(列的多级索引是可选的):
feature label $close $volume Ref($close, 1) Mean($close, 3) $high-$low LABEL0 datetime instrument 2010-01-04 SH600000 81.807068 17145150.0 83.737389 83.016739 2.741058 0.0032 SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632 0.0042 SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325 0.0289
参数
- instrumentsstr或dict
可以是市场名称或由InstrumentProvider生成的标的配置文件。 如果instruments的值为None,则表示不进行过滤。
- start_timestr
时间范围的开始。
- end_timestr
时间范围的结束。
返回
- pd.DataFrame:
从底层数据源加载的数据
异常
- KeyError:
如果不支持标的过滤,将引发KeyError
- class qlib.data.dataset.loader.QlibDataLoader(config: Tuple[list, tuple, dict], filter_pipe: List = None, swap_level: bool = True, freq: str | dict = 'day', inst_processors: dict | list = None)
与QlibDataLoader相同。可以通过配置定义字段。
- __init__(config: Tuple[list, tuple, dict], filter_pipe: List = None, swap_level: bool = True, freq: str | dict = 'day', inst_processors: dict | list = None)
参数
- configTuple[list, tuple, dict]
请参考DLWParser的文档
- filter_pipe :
标的过滤管道
- swap_level :
是否交换多级索引的级别
- freq: dict或str
如果type(config) == dict且type(freq) == str,使用freq加载配置数据。 如果type(config) == dict且type(freq) == dict,使用freq[<group_name>]加载config[<group_name>]数据
- inst_processors: dict | list
如果inst_processors不为None且type(config) == dict;使用inst_processors[<group_name>]加载config[<group_name>]数据 如果inst_processors是列表,则将应用于所有组。
- load_group_df(instruments, exprs: list, names: list, start_time: str | Timestamp = None, end_time: str | Timestamp = None, gp_name: str = None) DataFrame
load the dataframe for specific group
- 参数:
instruments -- the instruments.
exprs (list) -- the expressions to describe the content of the data.
names (list) -- the name of the data.
- 返回:
the queried dataframe.
- 返回类型:
pd.DataFrame
- class qlib.data.dataset.loader.StaticDataLoader(config: dict | str | DataFrame, join='outer')
DataLoader that supports loading data from file or as provided.
- __init__(config: dict | str | DataFrame, join='outer')
- 参数:
config (dict) -- {fields_group: <path or object>}
join (str) -- How to align different dataframes
- load(instruments=None, start_time=None, end_time=None) DataFrame
以pd.DataFrame格式加载数据。
数据示例(列的多级索引是可选的):
feature label $close $volume Ref($close, 1) Mean($close, 3) $high-$low LABEL0 datetime instrument 2010-01-04 SH600000 81.807068 17145150.0 83.737389 83.016739 2.741058 0.0032 SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632 0.0042 SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325 0.0289
参数
- instrumentsstr或dict
可以是市场名称或由InstrumentProvider生成的标的配置文件。 如果instruments的值为None,则表示不进行过滤。
- start_timestr
时间范围的开始。
- end_timestr
时间范围的结束。
返回
- pd.DataFrame:
从底层数据源加载的数据
异常
- KeyError:
如果不支持标的过滤,将引发KeyError
- class qlib.data.dataset.loader.NestedDataLoader(dataloader_l: List[Dict], join='left')
We have multiple DataLoader, we can use this class to combine them.
- __init__(dataloader_l: List[Dict], join='left') None
- 参数:
dataloader_l (list[dict]) --
A list of dataloader, for exmaple
nd = NestedDataLoader( dataloader_l=[ { "class": "qlib.contrib.data.loader.Alpha158DL", }, { "class": "qlib.contrib.data.loader.Alpha360DL", "kwargs": { "config": { "label": ( ["Ref($close, -2)/Ref($close, -1) - 1"], ["LABEL0"]) } } } ] )
join -- it will pass to pd.concat when merging it.
- load(instruments=None, start_time=None, end_time=None) DataFrame
以pd.DataFrame格式加载数据。
数据示例(列的多级索引是可选的):
feature label $close $volume Ref($close, 1) Mean($close, 3) $high-$low LABEL0 datetime instrument 2010-01-04 SH600000 81.807068 17145150.0 83.737389 83.016739 2.741058 0.0032 SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632 0.0042 SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325 0.0289
参数
- instrumentsstr或dict
可以是市场名称或由InstrumentProvider生成的标的配置文件。 如果instruments的值为None,则表示不进行过滤。
- start_timestr
时间范围的开始。
- end_timestr
时间范围的结束。
返回
- pd.DataFrame:
从底层数据源加载的数据
异常
- KeyError:
如果不支持标的过滤,将引发KeyError
- class qlib.data.dataset.loader.DataLoaderDH(handler_config: dict, fetch_kwargs: dict = {}, is_group=False)
DataLoader based on (D)ata (H)andler It is designed to load multiple data from data handler - If you just want to load data from single datahandler, you can write them in single data handler
TODO: What make this module not that easy to use.
For online scenario
The underlayer data handler should be configured. But data loader doesn't provide such interface & hook.
- __init__(handler_config: dict, fetch_kwargs: dict = {}, is_group=False)
- 参数:
handler_config (dict) --
handler_config will be used to describe the handlers
<handler_config> := { "group_name1": <handler> "group_name2": <handler> } or <handler_config> := <handler> <handler> := DataHandler Instance | DataHandler Config
fetch_kwargs (dict) -- fetch_kwargs will be used to describe the different arguments of fetch method, such as col_set, squeeze, data_key, etc.
is_group (bool) -- is_group will be used to describe whether the key of handler_config is group
- load(instruments=None, start_time=None, end_time=None) DataFrame
以pd.DataFrame格式加载数据。
数据示例(列的多级索引是可选的):
feature label $close $volume Ref($close, 1) Mean($close, 3) $high-$low LABEL0 datetime instrument 2010-01-04 SH600000 81.807068 17145150.0 83.737389 83.016739 2.741058 0.0032 SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632 0.0042 SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325 0.0289
参数
- instrumentsstr或dict
可以是市场名称或由InstrumentProvider生成的标的配置文件。 如果instruments的值为None,则表示不进行过滤。
- start_timestr
时间范围的开始。
- end_timestr
时间范围的结束。
返回
- pd.DataFrame:
从底层数据源加载的数据
异常
- KeyError:
如果不支持标的过滤,将引发KeyError
数据处理器
- class qlib.data.dataset.handler.DataHandlerABC(*args, **kwargs)
数据处理器接口。
此类不假设数据处理器的内部数据结构。 它仅为外部用户定义接口(使用DataFrame作为内部数据结构)。
未来,数据处理器的更详细实现应进行重构。以下是一些指导原则:
它包含几个组件:
[数据加载器] -> 数据的内部表示 -> 数据预处理 -> 获取接口的适配器
组合所有组件的工作流程: 工作流程可能非常复杂。DataHandlerLP是其中一种实践,但无法满足所有需求。 因此,为用户提供实现工作流程的灵活性是更合理的选择。
- __init__(*args, **kwargs)
我们应该定义如何为数据获取做好准备。
- class qlib.data.dataset.handler.DataHandler(instruments=None, start_time=None, end_time=None, data_loader: dict | str | DataLoader = None, init_data=True, fetch_orig=True)
DataHandler的设计动机:
- 它提供了BaseDataHandler的一种实现,具体包括:
使用内部加载的DataFrame处理响应
DataFrame由数据加载器加载
使用处理器的步骤: 1. 初始化数据处理器(通过`init`调用)。 2. 使用数据。
数据处理器尝试维护具有两级索引的处理器: `datetime`(日期时间)和`instruments`(标的)。
支持任何顺序的索引级别(顺序将由数据暗示)。 当数据框索引名称缺失时,将使用<datetime, instruments>顺序。
数据示例: 列的多级索引是可选的。
feature label $close $volume Ref($close, 1) Mean($close, 3) $high-$low LABEL0 datetime instrument 2010-01-04 SH600000 81.807068 17145150.0 83.737389 83.016739 2.741058 0.0032 SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632 0.0042 SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325 0.0289
Tips for improving the performance of datahandler - Fetching data with col_set=CS_RAW will return the raw data and may avoid pandas from copying the data when calling loc
- __init__(instruments=None, start_time=None, end_time=None, data_loader: dict | str | DataLoader = None, init_data=True, fetch_orig=True)
参数
- instruments :
要检索的股票列表。
- start_time :
原始数据的开始时间。
- end_time :
原始数据的结束时间。
- data_loaderUnion[dict, str, DataLoader]
用于加载数据的数据加载器。
- init_data :
在构造函数中初始化原始数据。
- fetch_origbool
如果可能,返回原始数据而不是副本。
- config(**kwargs)
数据配置。 # 从数据源加载哪些数据
此方法将在从数据集加载序列化的处理器时使用。 数据将使用不同的时间范围进行初始化。
- setup_data(enable_cache: bool = False)
设置数据,以防多次运行初始化
负责维护以下变量: 1) self._data
参数
- enable_cachebool
默认值为false:
- 如果`enable_cache` == True:
处理后的数据将保存到磁盘,下次调用`init`时处理器将直接从磁盘加载缓存数据
- fetch(selector: Timestamp | slice | str | Index = slice(None, None, None), level: str | int = 'datetime', col_set: str | List[str] = '__all', data_key: Literal['raw', 'infer', 'learn'] = 'infer', squeeze: bool = False, proc_func: Callable | None = None) DataFrame
从底层数据源获取数据
设计动机: - 为底层数据提供统一接口 - 潜在地使接口更友好 - 用户可以在此额外层中提高数据获取性能
参数
- selectorUnion[pd.Timestamp, slice, str]
描述如何按索引选择数据 可以分为以下几类:
获取单个索引
获取索引范围
切片范围
特定索引的pd.Index
可能出现以下冲突:
["20200101", "20210101"]是表示选择此切片还是这两天?
切片具有更高优先级
- levelUnion[str, int]
选择数据的索引级别
col_set : Union[str, List[str]]
如果是str类型:
选择一组有意义的pd.Index列(例如特征、列)
如果col_set == CS_RAW:
将返回原始数据集
如果是List[str]类型:
选择几组有意义的列,返回的数据具有多级索引
proc_func: Callable
提供在获取数据前处理数据的钩子
解释此钩子必要性的示例:
数据集学习了一些与数据分割相关的处理器
每次准备数据时都会应用它们
学习到的处理器要求数据框在拟合和应用时保持相同格式
然而数据格式会根据参数变化
因此处理器应应用于底层数据
- squeezebool
是否压缩列和索引
返回
pd.DataFrame.
- get_cols(col_set='__all') list
get the column names
- 参数:
col_set (str) -- select a set of meaningful columns.(e.g. features, columns)
- 返回:
list of column names
- 返回类型:
list
- get_range_selector(cur_date: Timestamp | str, periods: int) slice
get range selector by number of periods
- 参数:
cur_date (pd.Timestamp or str) -- current date
periods (int) -- number of periods
- get_range_iterator(periods: int, min_periods: int | None = None, **kwargs) Iterator[Tuple[Timestamp, DataFrame]]
get an iterator of sliced data with given periods
- 参数:
periods (int) -- number of periods.
min_periods (int) -- minimum periods for sliced dataframe.
kwargs (dict) -- will be passed to self.fetch.
- class qlib.data.dataset.handler.DataHandlerLP(instruments=None, start_time=None, end_time=None, data_loader: dict | str | DataLoader = None, infer_processors: List = [], learn_processors: List = [], shared_processors: List = [], process_type='append', drop_raw=False, **kwargs)
Motivation: - For the case that we hope using different processor workflows for learning and inference;
DataHandler with (L)earnable (P)rocessor
This handler will produce three pieces of data in pd.DataFrame format.
DK_R / self._data: the raw data loaded from the loader
DK_I / self._infer: the data processed for inference
DK_L / self._learn: the data processed for learning model.
The motivation of using different processor workflows for learning and inference Here are some examples.
The instrument universe for learning and inference may be different.
The processing of some samples may rely on label (for example, some samples hit the limit may need extra processing or be dropped).
These processors only apply to the learning phase.
Tips for data handler
To reduce the memory cost
drop_raw=True: this will modify the data inplace on raw data;
Please note processed data like self._infer or self._learn are concepts different from segments in Qlib's Dataset like "train" and "test"
Processed data like self._infer or self._learn are underlying data processed with different processors
segments in Qlib's Dataset like "train" and "test" are simply the time segmentations when querying data("train" are often before "test" in time-series).
For example, you can query data._infer processed by infer_processors in the "train" time segmentation.
- __init__(instruments=None, start_time=None, end_time=None, data_loader: dict | str | DataLoader = None, infer_processors: List = [], learn_processors: List = [], shared_processors: List = [], process_type='append', drop_raw=False, **kwargs)
- 参数:
infer_processors (list) --
list of <description info> of processors to generate data for inference
example of <description info>:
1) classname & kwargs: { "class": "MinMaxNorm", "kwargs": { "fit_start_time": "20080101", "fit_end_time": "20121231" } } 2) Only classname: "DropnaFeature" 3) object instance of Processor
learn_processors (list) -- similar to infer_processors, but for generating data for learning models
process_type (str) --
PTYPE_I = 'independent'
self._infer will be processed by infer_processors
self._learn will be processed by learn_processors
PTYPE_A = 'append'
self._infer will be processed by infer_processors
self._learn will be processed by infer_processors + learn_processors
(e.g. self._infer processed by learn_processors )
drop_raw (bool) -- Whether to drop the raw data
- fit()
fit data without processing the data
- fit_process_data()
fit and process data
The input of the fit will be the output of the previous processor
- process_data(with_fit: bool = False)
process_data data. Fun processor.fit if necessary
Notation: (data) [processor]
# data processing flow of self.process_type == DataHandlerLP.PTYPE_I
(self._data)-[shared_processors]-(_shared_df)-[learn_processors]-(_learn_df) \ -[infer_processors]-(_infer_df)
# data processing flow of self.process_type == DataHandlerLP.PTYPE_A
(self._data)-[shared_processors]-(_shared_df)-[infer_processors]-(_infer_df)-[learn_processors]-(_learn_df)
- 参数:
with_fit (bool) -- The input of the fit will be the output of the previous processor
- config(processor_kwargs: dict = None, **kwargs)
configuration of data. # what data to be loaded from data source
This method will be used when loading pickled handler from dataset. The data will be initialized with different time range.
- setup_data(init_type: str = 'fit_seq', **kwargs)
Set up the data in case of running initialization for multiple time
- 参数:
init_type (str) -- The type IT_* listed above.
enable_cache (bool) --
default value is false:
if enable_cache == True:
the processed data will be saved on disk, and handler will load the cached data from the disk directly when we call init next time
- fetch(selector: Timestamp | slice | str = slice(None, None, None), level: str | int = 'datetime', col_set='__all', data_key: Literal['raw', 'infer', 'learn'] = 'infer', squeeze: bool = False, proc_func: Callable = None) DataFrame
fetch data from underlying data source
- 参数:
selector (Union[pd.Timestamp, slice, str]) -- describe how to select data by index.
level (Union[str, int]) -- which index level to select the data.
col_set (str) -- select a set of meaningful columns.(e.g. features, columns).
data_key (str) -- the data to fetch: DK_*.
proc_func (Callable) -- please refer to the doc of DataHandler.fetch
- 返回类型:
pd.DataFrame
- get_cols(col_set='__all', data_key: Literal['raw', 'infer', 'learn'] = 'infer') list
get the column names
- 参数:
col_set (str) -- select a set of meaningful columns.(e.g. features, columns).
data_key (DATA_KEY_TYPE) -- the data to fetch: DK_*.
- 返回:
list of column names
- 返回类型:
list
- classmethod cast(handler: DataHandlerLP) DataHandlerLP
Motivation
A user creates a datahandler in his customized package. Then he wants to share the processed handler to other users without introduce the package dependency and complicated data processing logic.
This class make it possible by casting the class to DataHandlerLP and only keep the processed data
- 参数:
handler (DataHandlerLP) -- A subclass of DataHandlerLP
- 返回:
the converted processed data
- 返回类型:
- classmethod from_df(df: DataFrame) DataHandlerLP
Motivation: - When user want to get a quick data handler.
The created data handler will have only one shared Dataframe without processors. After creating the handler, user may often want to dump the handler for reuse Here is a typical use case
from qlib.data.dataset import DataHandlerLP dh = DataHandlerLP.from_df(df) dh.to_pickle(fname, dump_all=True)
TODO: - The StaticDataLoader is quite slow. It don't have to copy the data again...
处理器
- qlib.data.dataset.processor.get_group_columns(df: DataFrame, group: str | None)
从多级索引列DataFrame中获取一组列
参数
- dfpd.DataFrame
具有多级列的DataFrame。
- groupstr
特征组的名称,即组索引的第一级值。
- class qlib.data.dataset.processor.Processor
-
- readonly() bool
处理器在处理时是否将输入数据视为只读(即不写入输入数据)
了解只读信息有助于处理器避免不必要的复制
- class qlib.data.dataset.processor.DropnaProcessor(fields_group=None)
-
- readonly()
处理器在处理时是否将输入数据视为只读(即不写入输入数据)
了解只读信息有助于处理器避免不必要的复制
- class qlib.data.dataset.processor.DropnaLabel(fields_group='label')
-
- is_for_infer() bool
根据标签删除样本,因此不可用于推理
- class qlib.data.dataset.processor.DropCol(col_list=[])
-
- readonly()
处理器在处理时是否将输入数据视为只读(即不写入输入数据)
了解只读信息有助于处理器避免不必要的复制
- class qlib.data.dataset.processor.FilterCol(fields_group='feature', col_list=[])
- __init__(fields_group='feature', col_list=[])
参数
- fields_groupstr, 默认"feature"
要筛选的字段组名称。
- col_listlist, 默认[]
要保留的列名列表。
- readonly()
处理器在处理时是否将输入数据视为只读(即不写入输入数据)
了解只读信息有助于处理器避免不必要的复制
- class qlib.data.dataset.processor.TanhProcess
使用tanh处理噪声数据
- class qlib.data.dataset.processor.ProcessInf
处理无穷大值
- class qlib.data.dataset.processor.Fillna(fields_group=None, fill_value=0)
处理NaN值
- class qlib.data.dataset.processor.MinMaxNorm(fit_start_time, fit_end_time, fields_group=None)
- class qlib.data.dataset.processor.ZScoreNorm(fit_start_time, fit_end_time, fields_group=None)
ZScore标准化
对数据进行ZScore标准化处理:(x - mean) / std
- __init__(fit_start_time, fit_end_time, fields_group=None)
- class qlib.data.dataset.processor.RobustZScoreNorm(fit_start_time, fit_end_time, fields_group=None, clip_outlier=True)
鲁棒ZScore标准化
- 使用鲁棒统计量进行ZScore标准化:
均值(x) = 中位数(x) 标准差(x) = 中位数绝对偏差(MAD) * 1.4826
- 参考:
- __init__(fit_start_time, fit_end_time, fields_group=None, clip_outlier=True)
- class qlib.data.dataset.processor.CSZScoreNorm(fields_group=None, method='zscore')
横截面ZScore标准化
对每个时间点的横截面数据进行ZScore标准化处理
- __init__(fields_group=None, method='zscore')
- class qlib.data.dataset.processor.CSRankNorm(fields_group=None)
横截面排名标准化
"横截面"通常用于描述数据操作。 对不同股票的操作通常称为横截面操作。
例如,CSRankNorm是按天分组并对每天的所有股票进行排名操作。
关于3.46和0.5的解释:
import numpy as np import pandas as pd x = np.random.random(10000) # 任意变量 x_rank = pd.Series(x).rank(pct=True) # 转换为排名后将是均匀分布 x_rank_norm = (x_rank - x_rank.mean()) / x_rank.std() # 通常我们会将其标准化为正态分布 x_rank.mean() # 对应0.5 1 / x_rank.std() # 对应3.46
- __init__(fields_group=None)
- class qlib.data.dataset.processor.CSZFillna(fields_group=None)
横截面填充缺失值
对每个时间点的横截面数据填充缺失值为该时间点的均值
- __init__(fields_group=None)
- class qlib.data.dataset.processor.HashStockFormat
将数据框处理为哈希股票存储格式
将DataFrame转换为HashingStockStorage格式
- class qlib.data.dataset.processor.TimeRangeFlt(start_time: Timestamp | str | None = None, end_time: Timestamp | str | None = None, freq: str = 'day')
股票时间范围过滤器
只保留从start_time到end_time存在的数据(不检查中间是否存在) 警告:可能导致数据泄露!!!
扩展
模型
- class qlib.model.base.Model
可学习模型
- fit(dataset: Dataset, reweighter: Reweighter)
从基础模型学习模型
备注
学习模型的属性名称不应以'_'开头,以便模型可以序列化到磁盘。
以下代码示例展示如何从dataset获取`x_train`、y_train`和`w_train:
# 获取特征和标签 df_train, df_valid = dataset.prepare( ["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L ) x_train, y_train = df_train["feature"], df_train["label"] x_valid, y_valid = df_valid["feature"], df_valid["label"] # 获取权重 try: wdf_train, wdf_valid = dataset.prepare(["train", "valid"], col_set=["weight"], data_key=DataHandlerLP.DK_L) w_train, w_valid = wdf_train["weight"], wdf_valid["weight"] except KeyError as e: w_train = pd.DataFrame(np.ones_like(y_train.values), index=y_train.index) w_valid = pd.DataFrame(np.ones_like(y_valid.values), index=y_valid.index)
参数
- datasetDataset
数据集将生成模型训练所需的处理后的数据
- class qlib.model.base.ModelFT
可微调模型
- abstractmethod finetune(dataset: Dataset)
基于给定数据集微调模型
使用qlib.workflow.R微调模型的典型用例:
# 开始实验训练初始模型 with R.start(experiment_name="init models"): model.fit(dataset) R.save_objects(init_model=model) rid = R.get_recorder().id # 基于之前训练的模型进行微调 with R.start(experiment_name="finetune model"): recorder = R.get_recorder(recorder_id=rid, experiment_name="init models") model = recorder.load_object("init_model") model.finetune(dataset, num_boost_round=10)
参数
- datasetDataset
数据集将生成模型训练所需的处理后的数据
策略
- class qlib.contrib.strategy.TopkDropoutStrategy(*, topk, n_drop, method_sell='bottom', method_buy='top', hold_thresh=1, only_tradable=False, forbid_all_trade_at_limit=True, **kwargs)
- __init__(*, topk, n_drop, method_sell='bottom', method_buy='top', hold_thresh=1, only_tradable=False, forbid_all_trade_at_limit=True, **kwargs)
参数
- topkint
投资组合中的股票数量
- n_dropint
每个交易日要替换的股票数量
- method_sellstr
卖出方法,random/bottom
- method_buystr
买入方法,random/top
- hold_threshint
最小持有天数 在卖出股票前,会检查current.get_stock_count(order.stock_id) >= self.hold_thresh
- only_tradablebool
策略是否只考虑可交易股票
如果only_tradable为True:
策略将根据股票的可交易状态做出决策,避免买卖不可交易股票
否则:
策略将在不考虑股票可交易状态的情况下做出买卖决策
- forbid_all_trade_at_limitbool
是否在达到涨跌停时禁止所有交易
如果forbid_all_trade_at_limit为True:
当价格达到涨跌停时,策略不会进行任何交易,即使在现实中允许在涨停卖出和在跌停买入
否则:
策略将在涨停卖出和在跌停买入
- class qlib.contrib.strategy.WeightStrategyBase(*, order_generator_cls_or_obj=<class 'qlib.contrib.strategy.order_generator.OrderGenWOInteract'>, **kwargs)
- __init__(*, order_generator_cls_or_obj=<class 'qlib.contrib.strategy.order_generator.OrderGenWOInteract'>, **kwargs)
- signal :
描述信号的信息。请参考`qlib.backtest.signal.create_signal_from`的文档 策略的决策将基于给定的信号
- trade_exchangeExchange
提供市场信息的交易所,用于处理订单和生成报告
如果`trade_exchange`为None,self.trade_exchange将通过common_infra设置
允许在不同的执行中使用不同的trade_exchange
例如:
在日频执行中,日频和分钟频交易所都可用,但推荐使用日频交易所,因为它运行更快
在分钟频执行中,日频交易所不可用,只推荐使用分钟频交易所
- class qlib.contrib.strategy.EnhancedIndexingStrategy(*, riskmodel_root, market='csi500', turn_limit=None, name_mapping={}, optimizer_kwargs={}, verbose=False, **kwargs)
增强指数策略
增强指数策略结合了主动管理和被动管理的艺术, 旨在控制风险敞口(又称跟踪误差)的同时超越基准指数(如标普500)的投资组合回报。
用户需要准备如下的风险模型数据:
├── /path/to/riskmodel ├──── 20210101 ├────── factor_exp.{csv|pkl|h5} ├────── factor_cov.{csv|pkl|h5} ├────── specific_risk.{csv|pkl|h5} ├────── blacklist.{csv|pkl|h5} # 可选
风险模型数据可以从风险数据提供商处获取。你也可以使用 `qlib.model.riskmodel.structured.StructuredCovEstimator`来准备这些数据。
- 参数:
riskmodel_path (str): 风险模型路径 name_mapping (dict): 替代文件名
- __init__(*, riskmodel_root, market='csi500', turn_limit=None, name_mapping={}, optimizer_kwargs={}, verbose=False, **kwargs)
- signal :
描述信号的信息。请参考`qlib.backtest.signal.create_signal_from`的文档 策略的决策将基于给定的信号
- trade_exchangeExchange
提供市场信息的交易所,用于处理订单和生成报告
如果`trade_exchange`为None,self.trade_exchange将通过common_infra设置
允许在不同的执行中使用不同的trade_exchange
例如:
在日频执行中,日频和分钟频交易所都可用,但推荐使用日频交易所,因为它运行更快
在分钟频执行中,日频交易所不可用,只推荐使用分钟频交易所
- generate_target_weight_position(score, current, trade_start_time, trade_end_time)
生成目标持仓权重
- 参数:
score: 预测分数 current: 当前持仓 trade_start_time: 交易开始时间 trade_end_time: 交易结束时间
- 返回:
目标持仓权重字典
- get_risk_data(date)
获取指定日期的风险数据
- 参数:
date: 日期
- 返回:
因子暴露矩阵、因子协方差矩阵、特定风险向量、股票列表、黑名单
- class qlib.contrib.strategy.TWAPStrategy(outer_trade_decision: BaseTradeDecision = None, level_infra: LevelInfrastructure = None, common_infra: CommonInfrastructure = None, trade_exchange: Exchange = None)
TWAP交易策略
- 注意:
此TWAP策略在交易时会向上取整。当总交易单位数量小于交易步骤时,这将使TWAP交易策略更早生成订单
- class qlib.contrib.strategy.SBBStrategyBase(outer_trade_decision: BaseTradeDecision = None, level_infra: LevelInfrastructure = None, common_infra: CommonInfrastructure = None, trade_exchange: Exchange = None)
(S)elect the (B)etter one among every two adjacent trading (B)ars to sell or buy.
- generate_trade_decision(execute_result=None)
在每个交易bar生成交易决策
参数
- execute_resultList[object], optional
交易决策的执行结果,默认为None
首次调用generate_trade_decision时,`execute_result`可能为None
- reset(outer_trade_decision: BaseTradeDecision = None, **kwargs)
- 参数:
outer_trade_decision (BaseTradeDecision, optional)
- class qlib.contrib.strategy.SBBStrategyEMA(outer_trade_decision: BaseTradeDecision = None, instruments: List | str = 'csi300', freq: str = 'day', trade_exchange: Exchange = None, level_infra: LevelInfrastructure = None, common_infra: CommonInfrastructure = None, **kwargs)
(S)elect the (B)etter one among every two adjacent trading (B)ars to sell or buy with (EMA) signal.
- __init__(outer_trade_decision: BaseTradeDecision = None, instruments: List | str = 'csi300', freq: str = 'day', trade_exchange: Exchange = None, level_infra: LevelInfrastructure = None, common_infra: CommonInfrastructure = None, **kwargs)
- 参数:
instruments (Union[List, str], optional) -- instruments of EMA signal, by default "csi300"
freq (str, optional) -- freq of EMA signal, by default "day" Note: freq may be different from time_per_step
- reset_level_infra(level_infra)
reset level-shared infra - After reset the trade calendar, the signal will be changed
- class qlib.contrib.strategy.SoftTopkStrategy(model, dataset, topk, order_generator_cls_or_obj=<class 'qlib.contrib.strategy.order_generator.OrderGenWInteract'>, max_sold_weight=1.0, risk_degree=0.95, buy_method='first_fill', trade_exchange=None, level_infra=None, common_infra=None, **kwargs)
- __init__(model, dataset, topk, order_generator_cls_or_obj=<class 'qlib.contrib.strategy.order_generator.OrderGenWInteract'>, max_sold_weight=1.0, risk_degree=0.95, buy_method='first_fill', trade_exchange=None, level_infra=None, common_infra=None, **kwargs)
参数
- topkint
要购买的前N只股票
- risk_degreefloat
- 总价值的持仓百分比。buy_method 选项:
rank_fill: 优先为排名靠前的股票分配权重(最大1/topk) average_fill: 为排名靠前的股票平均分配权重。
- generate_target_weight_position(score, current, trade_start_time, trade_end_time)
参数
- score:
该交易日的预测分数,pd.Series类型,索引为股票ID,包含'score'列
- current:
当前持仓,使用Position()类
- trade_date:
交易日
根据当日分数和当前持仓生成目标持仓。 持仓中未考虑缓存。
- get_risk_degree(trade_step=None)
返回将用于投资的总价值比例。动态调整risk_degree将导致市场择时。
评估
- qlib.contrib.evaluate.risk_analysis(r, N: int = None, freq: str = 'day', mode: Literal['sum', 'product'] = 'sum')
风险分析 注意: 年化收益率的计算方式与年化收益率的定义有所不同。 这是有意为之的设计实现。 Qlib尝试通过求和而非乘积来累积收益,以避免累积曲线呈指数型扭曲。 Qlib中所有年化收益率的计算均遵循此原则。
参数
- qlib.contrib.evaluate.indicator_analysis(df, method='mean')
analyze statistical time-series indicators of trading
- 参数:
df (pandas.DataFrame) --
- columns: like ['pa', 'pos', 'ffr', 'deal_amount', 'value'].
- Necessary fields:
'pa' is the price advantage in trade indicators
'pos' is the positive rate in trade indicators
'ffr' is the fulfill rate in trade indicators
- Optional fields:
'deal_amount' is the total deal deal_amount, only necessary when method is 'amount_weighted'
'value' is the total trade value, only necessary when method is 'value_weighted'
index: Index(datetime)
method (str, optional) --
statistics method of pa/ffr, by default "mean"
if method is 'mean', count the mean statistical value of each trade indicator
if method is 'amount_weighted', count the deal_amount weighted mean statistical value of each trade indicator
if method is 'value_weighted', count the value weighted mean statistical value of each trade indicator
Note: statistics method of pos is always "mean"
- 返回:
statistical value of each trade indicators
- 返回类型:
pd.DataFrame
- qlib.contrib.evaluate.backtest_daily(start_time: str | Timestamp, end_time: str | Timestamp, strategy: str | dict | BaseStrategy, executor: str | dict | BaseExecutor = None, account: float | int | Position = 100000000.0, benchmark: str = 'SH000300', exchange_kwargs: dict = None, pos_type: str = 'Position')
initialize the strategy and executor, then executor the backtest of daily frequency
- 参数:
start_time (Union[str, pd.Timestamp]) -- closed start time for backtest NOTE: This will be applied to the outmost executor's calendar.
end_time (Union[str, pd.Timestamp]) -- closed end time for backtest NOTE: This will be applied to the outmost executor's calendar. E.g. Executor[day](Executor[1min]), setting end_time == 20XX0301 will include all the minutes on 20XX0301
strategy (Union[str, dict, BaseStrategy]) --
for initializing outermost portfolio strategy. Please refer to the docs of init_instance_by_config for more information.
E.g.
# dict strategy = { "class": "TopkDropoutStrategy", "module_path": "qlib.contrib.strategy.signal_strategy", "kwargs": { "signal": (model, dataset), "topk": 50, "n_drop": 5, }, } # BaseStrategy pred_score = pd.read_pickle("score.pkl")["score"] STRATEGY_CONFIG = { "topk": 50, "n_drop": 5, "signal": pred_score, } strategy = TopkDropoutStrategy(**STRATEGY_CONFIG) # str example. # 1) specify a pickle object # - path like 'file:///<path to pickle file>/obj.pkl' # 2) specify a class name # - "ClassName": getattr(module, "ClassName")() will be used. # 3) specify module path with class name # - "a.b.c.ClassName" getattr(<a.b.c.module>, "ClassName")() will be used.
executor (Union[str, dict, BaseExecutor]) -- for initializing the outermost executor.
benchmark (str) -- the benchmark for reporting.
account (Union[float, int, Position]) --
information for describing how to creating the account
For float or int:
Using Account with only initial cash
For Position:
Using Account with a Position
exchange_kwargs (dict) --
the kwargs for initializing Exchange E.g.
exchange_kwargs = { "freq": freq, "limit_threshold": None, # limit_threshold is None, using C.limit_threshold "deal_price": None, # deal_price is None, using C.deal_price "open_cost": 0.0005, "close_cost": 0.0015, "min_cost": 5, }
pos_type (str) -- the type of Position.
- 返回:
report_normal (pd.DataFrame) -- backtest report
positions_normal (pd.DataFrame) -- backtest positions
- qlib.contrib.evaluate.long_short_backtest(pred, topk=50, deal_price=None, shift=1, open_cost=0, close_cost=0, trade_unit=None, limit_threshold=None, min_cost=5, subscribe_fields=[], extract_codes=False)
A backtest for long-short strategy
- 参数:
pred -- The trading signal produced on day T.
topk -- The short topk securities and long topk securities.
deal_price -- The price to deal the trading.
shift -- Whether to shift prediction by one day. The trading day will be T+1 if shift==1.
open_cost -- open transaction cost.
close_cost -- close transaction cost.
trade_unit -- 100 for China A.
limit_threshold -- limit move 0.1 (10%) for example, long and short with same limit.
min_cost -- min transaction cost.
subscribe_fields -- subscribe fields.
extract_codes -- bool. will we pass the codes extracted from the pred to the exchange. NOTE: This will be faster with offline qlib.
- 返回:
The result of backtest, it is represented by a dict. { "long": long_returns(excess), "short": short_returns(excess), "long_short": long_short_returns}
报告
- qlib.contrib.report.analysis_position.report.report_graph(report_df: DataFrame, show_notebook: bool = True) [<class 'list'>, <class 'tuple'>]
display backtest report
Example:
import qlib import pandas as pd from qlib.utils.time import Freq from qlib.utils import flatten_dict from qlib.backtest import backtest, executor from qlib.contrib.evaluate import risk_analysis from qlib.contrib.strategy import TopkDropoutStrategy # init qlib qlib.init(provider_uri=<qlib data dir>) CSI300_BENCH = "SH000300" FREQ = "day" STRATEGY_CONFIG = { "topk": 50, "n_drop": 5, # pred_score, pd.Series "signal": pred_score, } EXECUTOR_CONFIG = { "time_per_step": "day", "generate_portfolio_metrics": True, } backtest_config = { "start_time": "2017-01-01", "end_time": "2020-08-01", "account": 100000000, "benchmark": CSI300_BENCH, "exchange_kwargs": { "freq": FREQ, "limit_threshold": 0.095, "deal_price": "close", "open_cost": 0.0005, "close_cost": 0.0015, "min_cost": 5, }, } # strategy object strategy_obj = TopkDropoutStrategy(**STRATEGY_CONFIG) # executor object executor_obj = executor.SimulatorExecutor(**EXECUTOR_CONFIG) # backtest portfolio_metric_dict, indicator_dict = backtest(executor=executor_obj, strategy=strategy_obj, **backtest_config) analysis_freq = "{0}{1}".format(*Freq.parse(FREQ)) # backtest info report_normal_df, positions_normal = portfolio_metric_dict.get(analysis_freq) qcr.analysis_position.report_graph(report_normal_df)
- 参数:
report_df --
df.index.name must be date, df.columns must contain return, turnover, cost, bench.
return cost bench turnover date 2017-01-04 0.003421 0.000864 0.011693 0.576325 2017-01-05 0.000508 0.000447 0.000721 0.227882 2017-01-06 -0.003321 0.000212 -0.004322 0.102765 2017-01-09 0.006753 0.000212 0.006874 0.105864 2017-01-10 -0.000416 0.000440 -0.003350 0.208396
show_notebook -- whether to display graphics in notebook, the default is True.
- 返回:
if show_notebook is True, display in notebook; else return plotly.graph_objs.Figure list.
- qlib.contrib.report.analysis_position.score_ic.score_ic_graph(pred_label: DataFrame, show_notebook: bool = True, **kwargs) [<class 'list'>, <class 'tuple'>]
分数IC图表
示例:
from qlib.data import D from qlib.contrib.report import analysis_position pred_df_dates = pred_df.index.get_level_values(level='datetime') features_df = D.features(D.instruments('csi500'), ['Ref($close, -2)/Ref($close, -1)-1'], pred_df_dates.min(), pred_df_dates.max()) features_df.columns = ['label'] pred_label = pd.concat([features_df, pred], axis=1, sort=True).reindex(features_df.index) analysis_position.score_ic_graph(pred_label)
- 参数:
pred_label --
索引为**pd.MultiIndex**,索引名称为**[instrument, datetime]**;列名为**[score, label]**.
instrument datetime score label SH600004 2017-12-11 -0.013502 -0.013502 2017-12-12 -0.072367 -0.072367 2017-12-13 -0.068605 -0.068605 2017-12-14 0.012440 0.012440 2017-12-15 -0.102778 -0.102778
show_notebook -- whether to display graphics in notebook, the default is True.
- 返回:
if show_notebook is True, display in notebook; else return plotly.graph_objs.Figure list.
- qlib.contrib.report.analysis_position.cumulative_return.cumulative_return_graph(position: dict, report_normal: DataFrame, label_data: DataFrame, show_notebook=True, start_date=None, end_date=None) Iterable[Figure]
Backtest buy, sell, and holding cumulative return graph
Example:
from qlib.data import D from qlib.contrib.evaluate import risk_analysis, backtest, long_short_backtest from qlib.contrib.strategy import TopkDropoutStrategy # backtest parameters bparas = {} bparas['limit_threshold'] = 0.095 bparas['account'] = 1000000000 sparas = {} sparas['topk'] = 50 sparas['n_drop'] = 5 strategy = TopkDropoutStrategy(**sparas) report_normal_df, positions = backtest(pred_df, strategy, **bparas) pred_df_dates = pred_df.index.get_level_values(level='datetime') features_df = D.features(D.instruments('csi500'), ['Ref($close, -1)/$close - 1'], pred_df_dates.min(), pred_df_dates.max()) features_df.columns = ['label'] qcr.analysis_position.cumulative_return_graph(positions, report_normal_df, features_df)
Graph desc:
Axis X: Trading day.
Axis Y:
Above axis Y: (((Ref($close, -1)/$close - 1) * weight).sum() / weight.sum()).cumsum().
Below axis Y: Daily weight sum.
In the sell graph, y < 0 stands for profit; in other cases, y > 0 stands for profit.
In the buy_minus_sell graph, the y value of the weight graph at the bottom is buy_weight + sell_weight.
In each graph, the red line in the histogram on the right represents the average.
- 参数:
position -- position data
report_normal --
return cost bench turnover date 2017-01-04 0.003421 0.000864 0.011693 0.576325 2017-01-05 0.000508 0.000447 0.000721 0.227882 2017-01-06 -0.003321 0.000212 -0.004322 0.102765 2017-01-09 0.006753 0.000212 0.006874 0.105864 2017-01-10 -0.000416 0.000440 -0.003350 0.208396
label_data --
D.features result; index is pd.MultiIndex, index name is [instrument, datetime]; columns names is [label].
The label T is the change from T to T+1, it is recommended to use
close
, example: D.features(D.instruments('csi500'), ['Ref($close, -1)/$close-1'])label instrument datetime SH600004 2017-12-11 -0.013502 2017-12-12 -0.072367 2017-12-13 -0.068605 2017-12-14 0.012440 2017-12-15 -0.102778
show_notebook -- True or False. If True, show graph in notebook, else return figures
start_date -- start date
end_date -- end date
- 返回:
- qlib.contrib.report.analysis_position.risk_analysis.risk_analysis_graph(analysis_df: DataFrame = None, report_normal_df: DataFrame = None, report_long_short_df: DataFrame = None, show_notebook: bool = True) Iterable[Figure]
生成分析图表和月度分析
示例:
import qlib import pandas as pd from qlib.utils.time import Freq from qlib.utils import flatten_dict from qlib.backtest import backtest, executor from qlib.contrib.evaluate import risk_analysis from qlib.contrib.strategy import TopkDropoutStrategy # 初始化qlib qlib.init(provider_uri=<qlib数据目录>) CSI300_BENCH = "SH000300" FREQ = "day" STRATEGY_CONFIG = { "topk": 50, "n_drop": 5, # pred_score, pd.Series "signal": pred_score, } EXECUTOR_CONFIG = { "time_per_step": "day", "generate_portfolio_metrics": True, } backtest_config = { "start_time": "2017-01-01", "end_time": "2020-08-01", "account": 100000000, "benchmark": CSI300_BENCH, "exchange_kwargs": { "freq": FREQ, "limit_threshold": 0.095, "deal_price": "close", "open_cost": 0.0005, "close_cost": 0.0015, "min_cost": 5, }, } # 策略对象 strategy_obj = TopkDropoutStrategy(**STRATEGY_CONFIG) # 执行器对象 executor_obj = executor.SimulatorExecutor(** EXECUTOR_CONFIG) # 回测 portfolio_metric_dict, indicator_dict = backtest(executor=executor_obj, strategy=strategy_obj, **backtest_config) analysis_freq = "{0}{1}".format(*Freq.parse(FREQ)) # 回测信息 report_normal_df, positions_normal = portfolio_metric_dict.get(analysis_freq) analysis = dict() analysis["excess_return_without_cost"] = risk_analysis( report_normal_df["return"] - report_normal_df["bench"], freq=analysis_freq ) analysis["excess_return_with_cost"] = risk_analysis( report_normal_df["return"] - report_normal_df["bench"] - report_normal_df["cost"], freq=analysis_freq ) analysis_df = pd.concat(analysis) # type: pd.DataFrame analysis_position.risk_analysis_graph(analysis_df, report_normal_df)
- 参数:
analysis_df --
分析数据,索引为**pd.MultiIndex**;列名为**[risk]**.
risk excess_return_without_cost mean 0.000692 std 0.005374 annualized_return 0.174495 information_ratio 2.045576 max_drawdown -0.079103 excess_return_with_cost mean 0.000499 std 0.005372 annualized_return 0.125625 information_ratio 1.473152 max_drawdown -0.088263
report_normal_df --
df.index.name**必须为**date,df.columns必须包含**return**、turnover、cost、bench.
return cost bench turnover date 2017-01-04 0.003421 0.000864 0.011693 0.576325 2017-01-05 0.000508 0.000447 0.000721 0.227882 2017-01-06 -0.003321 0.000212 -0.004322 0.102765 2017-01-09 0.006753 0.000212 0.006874 0.105864 2017-01-10 -0.000416 0.000440 -0.003350 0.208396
report_long_short_df --
df.index.name**必须为**date,df.columns包含**long**、short、long_short.
long short long_short date 2017-01-04 -0.001360 0.001394 0.000034 2017-01-05 0.002456 0.000058 0.002514 2017-01-06 0.000120 0.002739 0.002859 2017-01-09 0.001436 0.001838 0.003273 2017-01-10 0.000824 -0.001944 -0.001120
show_notebook -- 是否在notebook中显示图形,默认为**True**. 若为True,在notebook中显示图形 若为False,返回图形对象
- 返回:
图形对象列表
- qlib.contrib.report.analysis_position.rank_label.rank_label_graph(position: dict, label_data: DataFrame, start_date=None, end_date=None, show_notebook=True) Iterable[Figure]
Ranking percentage of stocks buy, sell, and holding on the trading day. Average rank-ratio(similar to sell_df['label'].rank(ascending=False) / len(sell_df)) of daily trading
Example:
from qlib.data import D from qlib.contrib.evaluate import backtest from qlib.contrib.strategy import TopkDropoutStrategy # backtest parameters bparas = {} bparas['limit_threshold'] = 0.095 bparas['account'] = 1000000000 sparas = {} sparas['topk'] = 50 sparas['n_drop'] = 230 strategy = TopkDropoutStrategy(**sparas) _, positions = backtest(pred_df, strategy, **bparas) pred_df_dates = pred_df.index.get_level_values(level='datetime') features_df = D.features(D.instruments('csi500'), ['Ref($close, -1)/$close-1'], pred_df_dates.min(), pred_df_dates.max()) features_df.columns = ['label'] qcr.analysis_position.rank_label_graph(positions, features_df, pred_df_dates.min(), pred_df_dates.max())
- 参数:
position -- position data; qlib.backtest.backtest result.
label_data --
D.features result; index is pd.MultiIndex, index name is [instrument, datetime]; columns names is [label].
The label T is the change from T to T+1, it is recommended to use
close
, example: D.features(D.instruments('csi500'), ['Ref($close, -1)/$close-1']).label instrument datetime SH600004 2017-12-11 -0.013502 2017-12-12 -0.072367 2017-12-13 -0.068605 2017-12-14 0.012440 2017-12-15 -0.102778
start_date -- start date
end_date -- end_date
show_notebook -- True or False. If True, show graph in notebook, else return figures.
- 返回:
- qlib.contrib.report.analysis_model.analysis_model_performance.ic_figure(ic_df: DataFrame, show_nature_day=True, **kwargs) Figure
信息系数(IC)图表
- 参数:
ic_df -- IC数据框
show_nature_day -- 是否显示非交易日的横坐标
**kwargs -- 包含控制Plotly图表样式的参数,目前支持 - rangebreaks: https://plotly.com/python/time-series/#Hiding-Weekends-and-Holidays
- 返回:
plotly.graph_objs.Figure对象
- qlib.contrib.report.analysis_model.analysis_model_performance.model_performance_graph(pred_label: DataFrame, lag: int = 1, N: int = 5, reverse=False, rank=False, graph_names: list = ['group_return', 'pred_ic', 'pred_autocorr'], show_notebook: bool = True, show_nature_day: bool = False, **kwargs) [<class 'list'>, <class 'tuple'>]
Model performance
- 参数:
pred_label --
index is pd.MultiIndex, index name is [instrument, datetime]; columns names is [score, label]. It is usually same as the label of model training(e.g. "Ref($close, -2)/Ref($close, -1) - 1").
instrument datetime score label SH600004 2017-12-11 -0.013502 -0.013502 2017-12-12 -0.072367 -0.072367 2017-12-13 -0.068605 -0.068605 2017-12-14 0.012440 0.012440 2017-12-15 -0.102778 -0.102778
lag -- pred.groupby(level='instrument', group_keys=False)['score'].shift(lag). It will be only used in the auto-correlation computing.
N -- group number, default 5.
reverse -- if True, pred['score'] *= -1.
rank -- if True, calculate rank ic.
graph_names -- graph names; default ['cumulative_return', 'pred_ic', 'pred_autocorr', 'pred_turnover'].
show_notebook -- whether to display graphics in notebook, the default is True.
show_nature_day -- whether to display the abscissa of non-trading day.
**kwargs -- contains some parameters to control plot style in plotly. Currently, supports - rangebreaks: https://plotly.com/python/time-series/#Hiding-Weekends-and-Holidays
- 返回:
if show_notebook is True, display in notebook; else return plotly.graph_objs.Figure list.
工作流
实验管理器
- class qlib.workflow.expm.ExpManager(uri: str, default_exp_name: str | None)
这是用于管理实验的`ExpManager`类。API设计类似于mlflow。 (链接: https://mlflow.org/docs/latest/python_api/mlflow.html)
ExpManager`预期是一个单例(同时,我们可以有多个具有不同uri的`Experiment。用户可以从不同的uri获取不同的实验,然后比较它们的记录)。全局配置(即`C`)也是一个单例。
因此我们尝试将它们对齐。它们共享同一个变量,称为**默认uri**。有关变量共享的详细信息,请参阅`ExpManager.default_uri`。
当用户开始一个实验时,用户可能希望将uri设置为特定的uri(在此期间它将覆盖**默认uri**),然后取消设置**特定uri**并回退到**默认uri**。`ExpManager._active_exp_uri`就是那个**特定uri**。
- __init__(uri: str, default_exp_name: str | None)
- start_exp(*, experiment_id: str | None = None, experiment_name: str | None = None, recorder_id: str | None = None, recorder_name: str | None = None, uri: str | None = None, resume: bool = False, **kwargs) Experiment
启动一个实验。该方法首先获取或创建一个实验,然后将其设置为活动状态。
`_active_exp_uri`的维护包含在start_exp中,剩余实现应包含在子类的_end_exp中
参数
- experiment_idstr
活动实验的ID
- experiment_namestr
活动实验的名称
- recorder_idstr
要启动的记录器ID
- recorder_namestr
要启动的记录器名称
- uristr
当前跟踪URI
- resumeboolean
是否恢复实验和记录器
返回
一个活动实验对象
- end_exp(recorder_status: str = 'SCHEDULED', **kwargs)
结束一个活动实验
`_active_exp_uri`的维护包含在end_exp中,剩余实现应包含在子类的_end_exp中
参数
- experiment_namestr
活动实验的名称
- recorder_statusstr
实验活动记录器的状态
- create_exp(experiment_name: str | None = None)
创建一个实验
参数
- experiment_namestr
实验名称,必须唯一
返回
一个实验对象
- raises ExpAlreadyExistError:
当实验已存在时抛出
- search_records(experiment_ids=None, **kwargs)
获取符合实验搜索条件的记录DataFrame 输入为用户想要应用的搜索条件
返回
一个pandas.DataFrame记录,其中每个指标、参数和标签 分别展开到名为metrics.*、params.*和tags.*的列中 对于没有特定指标、参数或标签的记录,它们的值将分别为(NumPy)Nan、None或None
- get_exp(*, experiment_id=None, experiment_name=None, create: bool = True, start: bool = False)
检索一个实验。该方法包括获取活动实验,以及获取或创建特定实验
当用户指定实验ID和名称时,方法将尝试返回特定实验 当用户未提供记录器ID或名称时,方法将尝试返回当前活动实验 `create`参数决定如果实验尚未创建,方法是否根据用户规范自动创建新实验
如果`create`为True:
如果`活动实验`存在:
未指定ID或名称,返回活动实验
如果指定了ID或名称,返回指定实验。如果未找到,则使用给定ID或名称创建新实验。如果`start`设为True,实验将被设置为活动状态
如果`活动实验`不存在:
未指定ID或名称,创建默认实验
如果指定了ID或名称,返回指定实验。如果未找到,则使用给定ID或名称创建新实验。如果`start`设为True,实验将被设置为活动状态
如果`create`为False:
如果`活动实验`存在:
未指定ID或名称,返回活动实验
如果指定了ID或名称,返回指定实验。如果未找到,抛出错误
如果`活动实验`不存在:
未指定ID或名称。如果默认实验存在则返回,否则抛出错误
如果指定了ID或名称,返回指定实验。如果未找到,抛出错误
参数
- experiment_idstr
要返回的实验ID
- experiment_namestr
要返回的实验名称
- createboolean
如果实验尚未创建,是否创建它
- startboolean
如果创建了新实验,是否启动它
返回
一个实验对象
- delete_exp(experiment_id=None, experiment_name=None)
删除一个实验
参数
- experiment_idstr
实验ID
- experiment_namestr
实验名称
- property default_uri
从qlib.config.C获取默认跟踪URI
实验
- class qlib.workflow.exp.Experiment(id, name)
这是用于运行每个实验的`Experiment`类。API设计类似于mlflow。 (链接: https://mlflow.org/docs/latest/python_api/mlflow.html)
- __init__(id, name)
- start(*, recorder_id=None, recorder_name=None, resume=False)
开始实验并设置为活动状态。此方法还将启动一个新的记录器。
参数
- recorder_idstr
要创建的记录器ID
- recorder_namestr
要创建的记录器名称
- resumebool
是否恢复第一个记录器
返回
一个活动的记录器。
- end(recorder_status='SCHEDULED')
结束实验。
参数
- recorder_statusstr
结束时记录器要设置的状态(SCHEDULED, RUNNING, FINISHED, FAILED)。
- create_recorder(recorder_name=None)
为每个实验创建记录器
- 参数:
recorder_name (str) -- 要创建记录器的名称
- 返回类型:
记录器对象
- search_records(**kwargs)
获取符合实验搜索条件的记录DataFrame 输入为用户想要应用的搜索条件
- 返回:
包含记录的pandas.DataFrame,其中每个指标、参数和标签
都被展开到名为metrics.、params.*和tags.*的列中*
对于没有特定指标、参数或标签的记录,它们的值将分别为(NumPy)Nan、None或None
- get_recorder(recorder_id=None, recorder_name=None, create: bool = True, start: bool = False) Recorder
为用户检索记录器。当用户指定记录器ID和名称时,该方法会尝试返回特定的记录器。 当用户未提供记录器ID或名称时,该方法会尝试返回当前活动记录器。 `create`参数决定如果记录器尚未创建,该方法是否会根据用户规范自动创建新记录器。
如果`create`为True:
如果`活动记录器`存在:
未指定ID或名称,返回活动记录器。
如果指定了ID或名称,返回指定的记录器。如果未找到对应实验,则使用给定的ID或名称创建新记录器。如果`start`设为True,则将记录器设为活动状态。
如果`活动记录器`不存在:
未指定ID或名称,创建新记录器。
如果指定了ID或名称,返回指定的实验。如果未找到对应实验,则使用给定的ID或名称创建新记录器。如果`start`设为True,则将记录器设为活动状态。
如果`create`为False:
如果`活动记录器`存在:
未指定ID或名称,返回活动记录器。
如果指定了ID或名称,返回指定的记录器。如果未找到对应实验,则抛出错误。
如果`活动记录器`不存在:
未指定ID或名称,抛出错误。
如果指定了ID或名称,返回指定的记录器。如果未找到对应实验,则抛出错误。
参数
- recorder_idstr
要删除的记录器ID。
- recorder_namestr
要删除的记录器名称。
- createboolean
如果记录器尚未创建,则创建它。
- startboolean
如果创建了新记录器,则启动它。
返回
记录器对象。
- list_recorders(rtype: Literal['dict', 'list'] = 'dict', **flt_kwargs) List[Recorder] | Dict[str, Recorder]
列出本实验的所有现有记录器。调用此方法前请先获取实验实例。 如果想使用`R.list_recorders()`方法,请参考`QlibRecorder`中的相关API文档。
- flt_kwargsdict
按条件过滤记录器 例如:list_recorders(status=Recorder.STATUS_FI)
返回
- 返回类型取决于`rtype`
- 如果`rtype` == "dict":
存储的记录器信息的字典(id -> recorder)。
- 如果`rtype` == "list":
记录器列表。
记录器
- class qlib.workflow.recorder.Recorder(experiment_id, name)
This is the Recorder class for experiment recording, with API design similar to mlflow. (Link: https://mlflow.org/docs/latest/python_api/mlflow.html)
The recorder's status can be SCHEDULED, RUNNING, FINISHED or FAILED.
- __init__(experiment_id, name)
- save_objects(local_path=None, artifact_path=None, **kwargs)
保存对象如预测文件或模型检查点到artifact URI。用户 可以通过关键字参数(name:value)保存对象。
请参考qlib.workflow:R.save_objects的文档
参数
- local_pathstr
如果提供,则将文件或目录保存到artifact URI。
- artifact_path=Nonestr
存储在URI中的artifact的相对路径。
- start_run()
启动或恢复记录器。返回值可用作`with`块中的上下文管理器; 否则必须调用end_run()来终止当前运行。(参见mlflow中的`ActiveRun`类)
返回
一个活动运行对象(例如mlflow.ActiveRun对象)。
- end_run()
结束一个活动的记录器。
- log_artifact(local_path: str, artifact_path: str | None = None)
将本地文件或目录记录为当前活动运行的artifact。
参数
- local_pathstr
要写入的文件路径。
- artifact_pathOptional[str]
如果提供,则写入到``artifact_uri``中的目录。
- list_artifacts(artifact_path: str = None)
列出记录器的所有artifacts。
参数
- artifact_pathstr
artifact存储在URI中的相对路径。
返回
存储的artifacts信息列表(名称、路径等)。
记录模板
- class qlib.workflow.record_temp.RecordTemp(recorder)
这是记录模板类,用于生成特定格式的实验结果,例如IC和回测。
- save(**kwargs)
功能与self.recorder.save_objects相同,但提供了更简单的接口,用户无需关心`get_path`和`artifact_path`
- __init__(recorder)
- generate(**kwargs)
生成特定记录如IC、回测等,并保存它们。
- 参数:
kwargs
- load(name: str, parents: bool = True)
功能与self.recorder.load_object相同,但提供了更简单的接口,用户无需关心`get_path`和`artifact_path`
- 参数:
name (str) -- the name for the file to be load.
parents (bool) -- Each recorder has different artifact_path. So parents recursively find the path in parents Sub classes has higher priority
- 返回类型:
The stored records.
- list()
列出支持的工件,用户无需考虑self.get_path
- 返回类型:
A list of all the supported artifacts.
- check(include_self: bool = False, parents: bool = True)
检查记录是否正确生成和保存。在以下情况下很有用:
在生成新内容前检查依赖文件是否完整
检查最终文件是否完成
- 参数:
include_self (bool) -- is the file generated by self included
parents (bool) -- will we check parents
- 抛出:
FileNotFoundError -- whether the records are stored properly.
- class qlib.workflow.record_temp.SignalRecord(model=None, dataset=None, recorder=None)
信号记录类,用于生成信号预测。继承自``RecordTemp``类。
- __init__(model=None, dataset=None, recorder=None)
- generate(**kwargs)
生成特定记录如IC、回测等,并保存它们。
- 参数:
kwargs
- list()
列出支持的工件,用户无需考虑self.get_path
- 返回类型:
A list of all the supported artifacts.
- class qlib.workflow.record_temp.ACRecordTemp(recorder, skip_existing=False)
Automatically checking record template
- __init__(recorder, skip_existing=False)
- generate(*args, **kwargs)
automatically checking the files and then run the concrete generating task
- class qlib.workflow.record_temp.HFSignalRecord(recorder, **kwargs)
信号分析记录类,用于生成IC和IR等分析结果。继承自``RecordTemp``类。
- depend_cls
SignalRecord
的别名
- __init__(recorder, **kwargs)
- generate()
生成特定记录如IC、回测等,并保存它们。
- 参数:
kwargs
- list()
列出支持的工件,用户无需考虑self.get_path
- 返回类型:
A list of all the supported artifacts.
- class qlib.workflow.record_temp.SigAnaRecord(recorder, ana_long_short=False, ann_scaler=252, label_col=0, skip_existing=False)
信号分析记录类,用于生成IC和IR等分析结果。 继承自``RecordTemp``类。
- depend_cls
SignalRecord
的别名
- __init__(recorder, ana_long_short=False, ann_scaler=252, label_col=0, skip_existing=False)
- list()
列出支持的工件,用户无需考虑self.get_path
- 返回类型:
A list of all the supported artifacts.
- class qlib.workflow.record_temp.PortAnaRecord(recorder, config=None, risk_analysis_freq: List | str = None, indicator_analysis_freq: List | str = None, indicator_analysis_method=None, skip_existing=False, **kwargs)
组合分析记录类,用于生成回测等分析结果。继承自``RecordTemp``类。
以下文件将被存储在记录器中:
report_normal.pkl & positions_normal.pkl:
回测的收益报告和详细持仓,由`qlib/contrib/evaluate.py:backtest`返回
port_analysis.pkl : 组合风险分析,由`qlib/contrib/evaluate.py:risk_analysis`返回
- depend_cls
SignalRecord
的别名
- __init__(recorder, config=None, risk_analysis_freq: List | str = None, indicator_analysis_freq: List | str = None, indicator_analysis_method=None, skip_existing=False, **kwargs)
- config["strategy"]dict
定义策略类及其参数
- config["executor"]dict
定义执行器类及其参数
- config["backtest"]dict
定义回测参数
- risk_analysis_freqstr|List[str]
风险分析报告频率
- indicator_analysis_freqstr|List[str]
指标分析报告频率
- indicator_analysis_methodstr, optional, default by None
可选值包括'mean', 'amount_weighted', 'value_weighted'
- list()
列出支持的工件,用户无需考虑self.get_path
- 返回类型:
A list of all the supported artifacts.
- class qlib.workflow.record_temp.MultiPassPortAnaRecord(recorder, pass_num=10, shuffle_init_score=True, **kwargs)
多遍组合分析记录类,多次运行回测并生成分析结果。继承自``PortAnaRecord``类。
如果启用shuffle_init_score,将随机打乱第一个回测日期的预测分数,使初始持仓随机化。 shuffle_init_score仅在信号使用<PRED>占位符时有效。占位符将被记录器中保存的pred.pkl替换。
参数
- recorderRecorder
用于保存回测结果的记录器
- pass_numint
回测遍数
- shuffle_init_scorebool
是否打乱第一个回测日期的预测分数
- depend_cls
SignalRecord
的别名
- __init__(recorder, pass_num=10, shuffle_init_score=True, **kwargs)
- 参数:
recorder (Recorder) -- 用于保存回测结果的记录器
pass_num (int) -- 回测次数
shuffle_init_score (bool) -- 是否打乱第一个回测日期的预测分数
- list()
列出支持的工件,用户无需考虑self.get_path
- 返回类型:
A list of all the supported artifacts.
任务管理
任务生成器
TaskGenerator模块可以根据TaskGen和一些任务模板生成多个任务
- qlib.workflow.task.gen.task_generator(tasks, generators) list
使用TaskGen列表和任务模板列表生成不同的任务
例如:
有3个任务模板a,b,c和2个TaskGen A,B。A会从每个模板生成2个任务,B会生成3个任务 task_generator([a, b, c], [A, B])最终会生成3*2*3 = 18个任务
参数
- tasksList[dict] or dict
任务模板列表或单个任务
- generatorsList[TaskGen] or TaskGen
TaskGen列表或单个TaskGen
返回
- list
任务列表
- class qlib.workflow.task.gen.TaskGen
生成不同任务的基类
示例1:
输入: 一个特定的任务模板和滚动步骤
输出: 任务的滚动版本
示例2:
输入: 一个特定的任务模板和损失列表
输出: 一组具有不同损失的任务
- abstractmethod generate(task: dict) List[dict]
Generate different tasks based on a task template
- 参数:
task (dict) -- a task template
- 返回:
A list of tasks
- 返回类型:
List[dict]
- qlib.workflow.task.gen.handler_mod(task: dict, rolling_gen)
在使用RollingGen时帮助修改handler的结束时间 尝试处理以下情况:
Handler的数据end_time早于dataset的test_data的segments
为此需要扩展handler数据的end_time
如果handler的end_time为None,则无需更改其结束时间
- 参数:
task (dict): 任务模板 rg (RollingGen): RollingGen实例
- qlib.workflow.task.gen.trunc_segments(ta: TimeAdjuster, segments: Dict[str, Timestamp], days, test_key='test')
为避免未来信息泄露,应根据测试开始时间截断segments
- 注意:
此函数将**原地**修改segments
- class qlib.workflow.task.gen.RollingGen(step: int = 40, rtype: str = 'expanding', ds_extra_mod_func: None | ~typing.Callable = <function handler_mod>, test_key='test', train_key='train', trunc_days: int = None, task_copy_func: ~typing.Callable = <function deepcopy>)
- __init__(step: int = 40, rtype: str = 'expanding', ds_extra_mod_func: None | ~typing.Callable = <function handler_mod>, test_key='test', train_key='train', trunc_days: int = None, task_copy_func: ~typing.Callable = <function deepcopy>)
Generate tasks for rolling
- 参数:
step (int) -- step to rolling
rtype (str) -- rolling type (expanding, sliding)
ds_extra_mod_func (Callable) -- A method like: handler_mod(task: dict, rg: RollingGen) Do some extra action after generating a task. For example, use
handler_mod
to modify the end time of the handler of a dataset.trunc_days (int) -- trunc some data to avoid future information leakage
task_copy_func (Callable) -- the function to copy entire task. This is very useful when user want to share something between tasks
- gen_following_tasks(task: dict, test_end: Timestamp) List[dict]
generating following rolling tasks for task until test_end
- 参数:
task (dict) -- Qlib task format
test_end (pd.Timestamp) -- the latest rolling task includes test_end
- 返回:
the following tasks of task`(`task itself is excluded)
- 返回类型:
List[dict]
- generate(task: dict) List[dict]
Converting the task into a rolling task.
- 参数:
task (dict) --
A dict describing a task. For example.
DEFAULT_TASK = { "model": { "class": "LGBModel", "module_path": "qlib.contrib.model.gbdt", }, "dataset": { "class": "DatasetH", "module_path": "qlib.data.dataset", "kwargs": { "handler": { "class": "Alpha158", "module_path": "qlib.contrib.data.handler", "kwargs": { "start_time": "2008-01-01", "end_time": "2020-08-01", "fit_start_time": "2008-01-01", "fit_end_time": "2014-12-31", "instruments": "csi100", }, }, "segments": { "train": ("2008-01-01", "2014-12-31"), "valid": ("2015-01-01", "2016-12-20"), # Please avoid leaking the future test data into validation "test": ("2017-01-01", "2020-08-01"), }, }, }, "record": [ { "class": "SignalRecord", "module_path": "qlib.workflow.record_temp", }, ] }
- 返回:
List[dict]
- 返回类型:
a list of tasks
- class qlib.workflow.task.gen.MultiHorizonGenBase(horizon: List[int] = [5], label_leak_n=2)
- __init__(horizon: List[int] = [5], label_leak_n=2)
基于现有任务为不同周期生成任务
参数
- horizonList[int]
任务可能的周期列表
- label_leak_nint
预测日后需要多少天才能获得完整标签 例如: - 用户在`T`日(收盘后)进行预测 - 标签是在`T + 1`日买入并在`T + 2`日卖出的收益 - `label_leak_n`将为2(即需要泄露2天的信息来利用该样本)
- generate(task: dict)
Generate different tasks based on a task template
- 参数:
task (dict) -- a task template
- 返回:
A list of tasks
- 返回类型:
List[dict]
任务管理器
TaskManager可以自动获取未使用的任务,并通过错误处理管理一组任务的生命周期。 这些功能可以并发运行任务,并确保每个任务只被使用一次。 Task Manager会将所有任务存储在`MongoDB <https://www.mongodb.com/>`_中。 使用此模块时,用户**必须**完成`MongoDB <https://www.mongodb.com/>`_的配置。
TaskManager中的任务由三部分组成 - 任务描述:定义任务内容 - 任务状态:任务的当前状态 - 任务结果:用户可以通过任务描述和任务结果获取任务
- class qlib.workflow.task.manage.TaskManager(task_pool: str)
以下是TaskManager创建的任务示例:
{ 'def': 'pickle序列化的任务定义,使用pickle更方便', 'filter': '类JSON数据,用于过滤任务', 'status': 'waiting' | 'running' | 'done', 'res': 'pickle序列化的任务结果' }
任务管理器假设您只会更新已获取的任务。 MongoDB的获取和更新操作确保数据更新安全。
此类可作为命令行工具使用。以下是几个示例: 查看manage模块帮助的命令: python -m qlib.workflow.task.manage -h # 显示manage模块CLI手册 python -m qlib.workflow.task.manage wait -h # 显示wait命令手册
python -m qlib.workflow.task.manage -t <pool_name> wait python -m qlib.workflow.task.manage -t <pool_name> task_stat
备注
假设:MongoDB中的数据会被编码,取出的数据会被解码
四种状态说明:
STATUS_WAITING: 等待训练
STATUS_RUNNING: 训练中
STATUS_PART_DONE: 已完成部分步骤,等待下一步
STATUS_DONE: 全部工作完成
- __init__(task_pool: str)
Init Task Manager, remember to make the statement of MongoDB url and database name firstly. A TaskManager instance serves a specific task pool. The static method of this module serves the whole MongoDB.
- 参数:
task_pool (str) -- the name of Collection in MongoDB
- static list() list
列出数据库中所有集合(任务池)。
- 返回:
list
- replace_task(task, new_task)
Use a new task to replace a old one
- 参数:
task -- old task
new_task -- new task
- insert_task(task)
Insert a task.
- 参数:
task -- the task waiting for insert
- 返回:
pymongo.results.InsertOneResult
- insert_task_def(task_def)
Insert a task to task_pool
- 参数:
task_def (dict) -- the task definition
- 返回类型:
pymongo.results.InsertOneResult
- create_task(task_def_l, dry_run=False, print_nt=False) List[str]
如果task_def_l中的任务是新的,则插入新任务到任务池并记录inserted_id。 如果任务已存在,则只查询其_id。
参数
- task_def_l: list
任务列表
- dry_run: bool
是否实际插入新任务到任务池
- print_nt: bool
是否打印新任务
返回
- List[str]
task_def_l中各任务的_id列表
- fetch_task(query={}, status='waiting') dict
使用查询获取任务。
- 参数:
query (dict, optional): 查询字典,默认为{} status (str, optional): 任务状态,默认为STATUS_WAITING
- 返回:
dict: 解码后的任务(集合中的文档)
- safe_fetch_task(query={}, status='waiting')
使用contextmanager从任务池中获取任务
参数
- query: dict
查询字典
返回
dict: 解码后的任务(集合中的文档)
- query(query={}, decode=True)
查询集合中的任务。 如果迭代生成器耗时过长,此函数可能抛出异常`pymongo.errors.CursorNotFound: cursor id not found`
- 示例:
python -m qlib.workflow.task.manage -t <your task pool> query '{"_id": "615498be837d0053acbc5d58"}'
参数
- query: dict
查询字典
- decode: bool
是否解码结果
返回
dict: 解码后的任务(集合中的文档)
- re_query(_id) dict
使用_id查询任务。
- 参数:
_id (str): 文档的_id
- 返回:
dict: 解码后的任务(集合中的文档)
- commit_task_res(task, res, status='done')
提交结果到task['res']。
- 参数:
task ([type]): 任务 res (object): 要保存的结果 status (str, optional): STATUS_WAITING, STATUS_RUNNING, STATUS_DONE, STATUS_PART_DONE。默认为STATUS_DONE。
- return_task(task, status='waiting')
Return a task to status. Always using in error handling.
- 参数:
task ([type]) -- [description]
status (str, optional) -- STATUS_WAITING, STATUS_RUNNING, STATUS_DONE, STATUS_PART_DONE. Defaults to STATUS_WAITING.
- remove(query={})
Remove the task using query
- 参数:
query (dict) -- the dict of query
- task_stat(query={}) dict
Count the tasks in every status.
- 参数:
query (dict, optional) -- the query dict. Defaults to {}.
- 返回:
dict
- reset_waiting(query={})
将所有运行中的任务重置为等待状态。可用于某些任务意外退出的情况。
- 参数:
query (dict, optional): 查询字典,默认为{}
- prioritize(task, priority: int)
Set priority for task
- 参数:
task (dict) -- The task query from the database
priority (int) -- the target priority
- wait(query={})
在多进程环境下,主进程可能因为仍有任务在运行而无法从TaskManager获取任务。 因此主进程应等待其他进程或机器完成所有任务。
- 参数:
query (dict, optional): 查询字典,默认为{}
- qlib.workflow.task.manage.run_task(task_func: Callable, task_pool: str, query: dict = {}, force_release: bool = False, before_status: str = 'waiting', after_status: str = 'done', **kwargs)
当任务池不为空(有WAITING状态任务)时,使用task_func获取并运行任务池中的任务
运行此方法后,有以下4种情况(before_status -> after_status):
STATUS_WAITING -> STATUS_DONE: 使用task["def"]作为`task_func`参数,表示任务尚未开始
STATUS_WAITING -> STATUS_PART_DONE: 使用task["def"]作为`task_func`参数
STATUS_PART_DONE -> STATUS_PART_DONE: use task["res"] as task_func param, it means that the task has been started but not completed
STATUS_PART_DONE -> STATUS_DONE: use task["res"] as task_func param
- 参数:
task_func (Callable) --
def (task_def, **kwargs) -> <res which will be committed>
the function to run the task
task_pool (str) -- the name of the task pool (Collection in MongoDB)
query (dict) -- will use this dict to query task_pool when fetching task
force_release (bool) -- will the program force to release the resource
before_status (str:) -- the tasks in before_status will be fetched and trained. Can be STATUS_WAITING, STATUS_PART_DONE.
after_status (str:) -- the tasks after trained will become after_status. Can be STATUS_WAITING, STATUS_PART_DONE.
kwargs -- the params for task_func
训练器
训练器(Trainer)用于训练一系列任务并返回模型记录器列表。 每个训练器包含两个步骤: 1. train - 创建模型记录器 2. end_train - 修改模型记录器
DelayTrainer是一种特殊训练器,可用于在线模拟并行训练: - 第一步仅保存必要信息到记录器 - 第二步在最后执行并发耗时操作(如模型拟合)
Qlib提供两种训练器实现: 1. TrainerR - 基础训练器 2. TrainerRM - 基于TaskManager自动管理任务生命周期
- qlib.model.trainer.begin_task_train(task_config: dict, experiment_name: str, recorder_name: str = None) Recorder
开始任务训练,创建记录器并保存任务配置。
- 参数:
task_config (dict): 任务配置 experiment_name (str): 实验名称 recorder_name (str): 记录器名称,None表示使用rid
- 返回:
Recorder: 模型记录器
- qlib.model.trainer.end_task_train(rec: Recorder, experiment_name: str) Recorder
完成任务训练,执行实际的模型拟合和保存。
- 参数:
rec (Recorder): 需要恢复的记录器 experiment_name (str): 实验名称
- 返回:
Recorder: 模型记录器
- qlib.model.trainer.task_train(task_config: dict, experiment_name: str, recorder_name: str = None) Recorder
基于任务的训练,分为两个步骤执行
参数
- task_configdict
任务配置
- experiment_name: str
实验名称
- recorder_name: str
记录器名称
返回
Recorder: 记录器实例
- class qlib.model.trainer.Trainer
训练器用于训练模型列表 Trainer和DelayTrainer的区别在于完成实际训练的时机不同
- __init__()
- train(tasks: list, *args, **kwargs) list
给定任务定义列表,开始训练并返回模型。
对于Trainer,此方法完成实际训练。 对于DelayTrainer,此方法仅做准备工作。
- 参数:
tasks (list): 任务定义列表
- 返回:
list: 模型列表
- 注意:
对于`Trainer`,此方法将直接训练模型
对于`DelayTrainer`,此方法仅做训练准备
- end_train(models: list, *args, **kwargs) list
给定模型列表,在训练结束时完成必要操作 模型可能是记录器、文本文件、数据库等
对于Trainer,该方法做一些收尾工作 对于DelayTrainer,该方法完成实际训练
- 参数:
models: 模型列表
- 返回:
list: 模型列表
- is_delay() bool
判断训练器是否会延迟完成`end_train`
- 返回:
bool: 是否为DelayTrainer
- worker()
启动工作进程
- 异常:
NotImplementedError: 如果不支持工作进程
- class qlib.model.trainer.TrainerR(experiment_name: str | None = None, train_func: ~typing.Callable = <function task_train>, call_in_subproc: bool = False, default_rec_name: str | None = None)
基于记录器(R)的训练器 以线性方式训练任务列表并返回模型记录器列表
假设:模型由`task`定义,结果将保存到`Recorder`
- __init__(experiment_name: str | None = None, train_func: ~typing.Callable = <function task_train>, call_in_subproc: bool = False, default_rec_name: str | None = None)
初始化TrainerR
- 参数:
experiment_name (str, optional): 默认实验名称 train_func (Callable, optional): 默认训练方法,默认为`task_train` call_in_subproc (bool): 在子进程中调用以强制释放内存
- train(tasks: list, train_func: Callable | None = None, experiment_name: str | None = None, **kwargs) List[Recorder]
给定任务列表并返回训练好的记录器列表,顺序可以保证。
- 参数:
tasks (list): 基于任务字典的定义列表 train_func (Callable): 训练方法,至少需要tasks和experiment_name参数。None表示使用默认训练方法。 experiment_name (str): 实验名称,None表示使用默认名称。 kwargs: train_func的参数。
- 返回:
List[Recorder]: 记录器列表
- class qlib.model.trainer.DelayTrainerR(experiment_name: str = None, train_func=<function begin_task_train>, end_train_func=<function end_task_train>, **kwargs)
基于TrainerR的延迟实现,意味着`train`方法可能只做准备工作,而`end_train`方法完成实际的模型拟合
- __init__(experiment_name: str = None, train_func=<function begin_task_train>, end_train_func=<function end_task_train>, **kwargs)
初始化TrainerRM
- 参数:
experiment_name (str): 默认实验名称 train_func (Callable, optional): 默认训练方法,默认为`begin_task_train` end_train_func (Callable, optional): 默认结束训练方法,默认为`end_task_train`
- end_train(models, end_train_func=None, experiment_name: str = None, **kwargs) List[Recorder]
给定记录器列表并返回训练好的记录器列表 该类将完成实际的数据加载和模型拟合
- 参数:
models (list): 记录器列表,任务已保存到其中 end_train_func (Callable, optional): 结束训练方法,至少需要`recorders`和`experiment_name`参数,默认为None表示使用self.end_train_func experiment_name (str): 实验名称,None表示使用默认名称 kwargs: end_train_func的参数
- 返回:
List[Recorder]: 记录器列表
- class qlib.model.trainer.TrainerRM(experiment_name: str = None, task_pool: str = None, train_func=<function task_train>, skip_run_task: bool = False, default_rec_name: str | None = None)
基于记录器(R)和任务管理器(M)的训练器 可以以多进程方式训练任务列表并返回模型记录器列表
假设:`task`将保存到TaskManager,并且`task`将从TaskManager获取并训练
- __init__(experiment_name: str = None, task_pool: str = None, train_func=<function task_train>, skip_run_task: bool = False, default_rec_name: str | None = None)
初始化TrainerR
- 参数:
experiment_name (str): 默认实验名称 task_pool (str): TaskManager中的任务池名称,None表示使用与experiment_name相同的名称 train_func (Callable, optional): 默认训练方法,默认为`task_train` skip_run_task (bool):
如果skip_run_task == True: 仅在worker中运行run_task,否则跳过run_task
- train(tasks: list, train_func: Callable = None, experiment_name: str = None, before_status: str = 'waiting', after_status: str = 'done', default_rec_name: str | None = None, **kwargs) List[Recorder]
给定任务列表并返回训练好的记录器列表,顺序可以保证。
此方法默认为单进程,但TaskManager提供了并行训练的强大方式。 用户可以自定义train_func实现多进程甚至多机器训练。
- 参数:
tasks (list): 基于任务字典的定义列表 train_func (Callable): 训练方法,至少需要tasks和experiment_name参数。None表示使用默认训练方法。 experiment_name (str): 实验名称,None表示使用默认名称。 before_status (str): 处于before_status状态的任务将被获取并训练。可以是STATUS_WAITING, STATUS_PART_DONE。 after_status (str): 训练后的任务将变为after_status状态。可以是STATUS_WAITING, STATUS_PART_DONE。 kwargs: train_func的参数。
- 返回:
List[Recorder]: 记录器列表
- end_train(recs: list, **kwargs) List[Recorder]
为记录器设置STATUS_END标签。
- 参数:
recs (list): 训练好的记录器列表。
- 返回:
List[Recorder]: 与参数相同的列表。
- class qlib.model.trainer.DelayTrainerRM(experiment_name: str = None, task_pool: str = None, train_func=<function begin_task_train>, end_train_func=<function end_task_train>, skip_run_task: bool = False, **kwargs)
基于TrainerRM的延迟实现,意味着`train`方法可能只做准备工作,而`end_train`方法完成实际模型拟合。
- __init__(experiment_name: str = None, task_pool: str = None, train_func=<function begin_task_train>, end_train_func=<function end_task_train>, skip_run_task: bool = False, **kwargs)
初始化DelayTrainerRM。
- 参数:
experiment_name (str): 默认实验名称。 task_pool (str): TaskManager中的任务池名称。None表示使用与experiment_name相同的名称。 train_func (Callable, optional): 默认训练方法。默认为`begin_task_train`。 end_train_func (Callable, optional): 默认结束训练方法。默认为`end_task_train`。 skip_run_task (bool):
如果skip_run_task == True: 仅在worker中运行run_task。否则跳过run_task。 例如:在CPU虚拟机上启动训练器,然后等待任务在GPU虚拟机上完成。
- train(tasks: list, train_func=None, experiment_name: str = None, **kwargs) List[Recorder]
与TrainerRM的`train`方法相同,after_status将为STATUS_PART_DONE。
- 参数:
tasks (list): 基于任务字典的定义列表 train_func (Callable): 训练方法,至少需要tasks和experiment_name参数。None表示使用self.train_func。 experiment_name (str): 实验名称,None表示使用默认名称。
- 返回:
List[Recorder]: 记录器列表
- end_train(recs, end_train_func=None, experiment_name: str = None, **kwargs) List[Recorder]
给定记录器列表并返回训练好的记录器列表。 此类将完成实际数据加载和模型拟合。
- 参数:
recs (list): 记录器列表,任务已保存到其中。 end_train_func (Callable, optional): 结束训练方法,至少需要recorders和experiment_name参数。None表示使用self.end_train_func。 experiment_name (str): 实验名称,None表示使用默认名称。 kwargs: end_train_func的参数。
- 返回:
List[Recorder]: 记录器列表
收集器
收集器模块可以从各处收集对象并进行处理,例如合并、分组、平均等操作。
- class qlib.workflow.task.collect.Collector(process_list=[])
用于收集不同结果的收集器
- __init__(process_list=[])
初始化收集器
- 参数:
process_list (list or Callable): 处理字典的处理器列表或单个处理器实例
- collect() dict
收集结果并返回一个类似{key: 值}的字典
- 返回:
dict: 收集后的字典
例如:
{"prediction": pd.Series}
{"IC": {"Xgboost": pd.Series, "LSTM": pd.Series}}
...
- static process_collect(collected_dict, process_list=[], *args, **kwargs) dict
对collect返回的字典进行一系列处理并返回一个类似{key: 值}的字典 例如可以进行分组和集成
- 参数:
collected_dict (dict): collect方法返回的字典 process_list (list or Callable): 处理字典的处理器列表或单个处理器实例
处理器顺序与列表顺序相同 例如: [Group1(..., Ensemble1()), Group2(..., Ensemble2())]
- 返回:
dict: 处理后的字典
- class qlib.workflow.task.collect.MergeCollector(collector_dict: Dict[str, Collector], process_list: List[Callable] = [], merge_func=None)
用于收集其他收集器结果的收集器
例如:
我们有两个收集器A和B A可以收集{"prediction": pd.Series},B可以收集{"IC": {"Xgboost": pd.Series, "LSTM": pd.Series}} 经过本类收集后,我们可以收集{"A_prediction": pd.Series, "B_IC": {"Xgboost": pd.Series, "LSTM": pd.Series}}
...
- __init__(collector_dict: Dict[str, Collector], process_list: List[Callable] = [], merge_func=None)
初始化MergeCollector
- 参数:
collector_dict (Dict[str,Collector]): 类似{collector_key, Collector}的字典 process_list (List[Callable]): 处理字典的处理器列表或单个处理器实例 merge_func (Callable): 生成最外层键的方法。参数是collector_dict中的``collector_key``和每个收集器收集后的``key``
如果为None则使用元组连接它们,例如"ABC"+("a","b") -> ("ABC", ("a","b"))
- collect() dict
收集collector_dict中的所有结果并将最外层键改为重组后的键
- 返回:
dict: 收集后的字典
- class qlib.workflow.task.collect.RecorderCollector(experiment, process_list=[], rec_key_func=None, rec_filter_func=None, artifacts_path={'pred': 'pred.pkl'}, artifacts_key=None, list_kwargs={}, status: Iterable = {'FINISHED'})
- __init__(experiment, process_list=[], rec_key_func=None, rec_filter_func=None, artifacts_path={'pred': 'pred.pkl'}, artifacts_key=None, list_kwargs={}, status: Iterable = {'FINISHED'})
初始化RecorderCollector。
- 参数:
- experiment:
(Experiment或str): Experiment实例或Experiment名称 (Callable): 可调用函数,返回实验列表
process_list (list或Callable): 处理器列表或处理字典的处理器实例 rec_key_func (Callable): 获取记录器键的函数。如果为None,则使用记录器ID rec_filter_func (Callable, 可选): 通过返回True或False过滤记录器。默认为None artifacts_path (dict, 可选): 记录器中工件名称及其路径。默认为{"pred": "pred.pkl"} artifacts_key (str或List, 可选): 要获取的工件键。如果为None,则获取所有工件 list_kwargs (str): list_recorders函数的参数 status (Iterable): 仅收集具有特定状态的记录器。None表示收集所有记录器
- collect(artifacts_key=None, rec_filter_func=None, only_exist=True) dict
基于过滤后的记录器收集不同的工件。
- 参数:
artifacts_key (str或List, 可选): 要获取的工件键。如果为None,则使用默认值 rec_filter_func (Callable, 可选): 通过返回True或False过滤记录器。如果为None,则使用默认值 only_exist (bool, 可选): 是否仅当记录器确实拥有时才收集工件。
如果为True,加载时出现异常的记录器将不会被收集。但如果为False,则会引发异常
- 返回:
dict: 收集后的字典,格式为{artifact: {rec_key: object}}
- get_exp_name() str
获取实验名称
- 返回:
str: 实验名称
分组
Group可以根据`group_func`对一组对象进行分组并将其转换为字典。 分组后,我们提供了一种方法来归约它们。
例如:
group: {(A,B,C1): object, (A,B,C2): object} -> {(A,B): {C1: object, C2: object}} reduce: {(A,B): {C1: object, C2: object}} -> {(A,B): object}
- class qlib.model.ens.group.Group(group_func=None, ens: Ensemble = None)
Group the objects based on dict
- __init__(group_func=None, ens: Ensemble = None)
初始化Group。
- 参数:
group_func (Callable, optional): 给定一个字典并返回分组键和其中一个分组元素。
For example: {(A,B,C1): object, (A,B,C2): object} -> {(A,B): {C1: object, C2: object}}
默认为None。
ens (Ensemble, optional): 如果不为None,则在分组后对分组值进行集成。
- group(*args, **kwargs) dict
将一组对象分组并转换为字典。
For example: {(A,B,C1): object, (A,B,C2): object} -> {(A,B): {C1: object, C2: object}}
- 返回:
dict: 分组后的字典
- reduce(*args, **kwargs) dict
归约分组后的字典。
For example: {(A,B): {C1: object, C2: object}} -> {(A,B): object}
- 返回:
dict: 归约后的字典
- class qlib.model.ens.group.RollingGroup(ens=<qlib.model.ens.ensemble.RollingEnsemble object>)
滚动字典分组
- group(rolling_dict: dict) dict
给定一个滚动字典如{(A,B,R): things},返回分组后的字典如{(A,B): {R:things}}
注意:这里假设滚动键在键元组的末尾,因为滚动结果通常需要先进行集成。
- 参数:
rolling_dict (dict): 滚动字典。如果键不是元组,则不进行任何操作。
- 返回:
dict: 分组后的字典
- __init__(ens=<qlib.model.ens.ensemble.RollingEnsemble object>)
初始化Group。
- 参数:
group_func (Callable, optional): 给定一个字典并返回分组键和其中一个分组元素。
For example: {(A,B,C1): object, (A,B,C2): object} -> {(A,B): {C1: object, C2: object}}
默认为None。
ens (Ensemble, optional): 如果不为None,则在分组后对分组值进行集成。
集成
集成模块可以合并Ensemble中的对象。例如,如果有多个子模型预测,我们可能需要将它们合并为一个集成预测。
- class qlib.model.ens.ensemble.Ensemble
将ensemble_dict合并为一个集成对象。
例如: {Rollinga_b: 对象, Rollingb_c: 对象} -> 对象
当调用此类时:
- Args:
ensemble_dict (dict): the ensemble dict like {name: things} waiting for merging
- Returns:
object: the ensemble object
- class qlib.model.ens.ensemble.SingleKeyEnsemble
如果字典中只有一个键值对,则提取该对象使结果更易读。 {唯一键: 唯一值} -> 唯一值
如果有超过1个键或少于1个键,则不进行任何操作。 甚至可以递归运行以使字典更易读。
注意:默认递归运行。
当调用此类时:
- 参数:
ensemble_dict (dict): 字典。字典的键将被忽略。
- 返回:
dict: 更易读的字典。
- class qlib.model.ens.ensemble.RollingEnsemble
将类似`prediction`或`IC`的滚动数据字典合并为一个集成。
注意:字典的值必须是pd.DataFrame,并且具有"datetime"索引。
当调用此类时:
- 参数:
ensemble_dict (dict): 类似{"A": pd.DataFrame, "B": pd.DataFrame}的字典。 字典的键将被忽略。
- 返回:
pd.DataFrame: 滚动的完整结果。
- class qlib.model.ens.ensemble.AverageEnsemble
将相同形状的数据字典(如`prediction`或`IC`)进行平均和标准化,合并为一个集成。
注意:字典的值必须是pd.DataFrame,并且具有"datetime"索引。如果是嵌套字典,则将其展平。
当调用此类时:
- 参数:
ensemble_dict (dict): 类似{"A": pd.DataFrame, "B": pd.DataFrame}的字典。 字典的键将被忽略。
- 返回:
pd.DataFrame: 平均和标准化的完整结果。
工具
Some tools for task management.
- qlib.workflow.task.utils.get_mongodb() Database
获取MongoDB数据库实例,使用前需要先声明数据库地址和名称。
示例:
使用qlib.init():
mongo_conf = { "task_url": task_url, # MongoDB地址 "task_db_name": task_db_name, # 数据库名称 } qlib.init(..., mongo=mongo_conf)
在qlib.init()之后:
C["mongo"] = { "task_url" : "mongodb://localhost:27017/", "task_db_name" : "rolling_db" }
- 返回:
Database: 数据库实例
- qlib.workflow.task.utils.list_recorders(experiment, rec_filter_func=None)
列出实验中通过过滤器的所有记录器。
- 参数:
experiment (str or Experiment): 实验名称或实例 rec_filter_func (Callable, optional): 返回True以保留给定记录器。默认为None。
- 返回:
dict: 过滤后的字典{rid: recorder}
- class qlib.workflow.task.utils.TimeAdjuster(future=True, end_time=None)
Find appropriate date and adjust date.
- __init__(future=True, end_time=None)
- set_end_time(end_time=None)
Set end time. None for use calendar's end time.
- 参数:
end_time
- max() Timestamp
返回日历中的最大日期时间
- align_idx(time_point, tp_type='start') int
对齐日历中时间点的索引。
参数
time_point tp_type : str
返回
index : int
- cal_interval(time_point_A, time_point_B) int
Calculate the trading day interval (time_point_A - time_point_B)
- 参数:
time_point_A -- time_point_A
time_point_B -- time_point_B (is the past of time_point_A)
- 返回:
the interval between A and B
- 返回类型:
int
- align_time(time_point, tp_type='start') Timestamp
Align time_point to trade date of calendar
- 参数:
time_point -- Time point
tp_type -- str time point type ("start", "end")
- 返回:
pd.Timestamp
- align_seg(segment: dict | tuple) dict | tuple
Align the given date to the trade date
for example:
input: {'train': ('2008-01-01', '2014-12-31'), 'valid': ('2015-01-01', '2016-12-31'), 'test': ('2017-01-01', '2020-08-01')} output: {'train': (Timestamp('2008-01-02 00:00:00'), Timestamp('2014-12-31 00:00:00')), 'valid': (Timestamp('2015-01-05 00:00:00'), Timestamp('2016-12-30 00:00:00')), 'test': (Timestamp('2017-01-03 00:00:00'), Timestamp('2020-07-31 00:00:00'))}
- 参数:
segment
- 返回:
Union[dict, tuple]
- 返回类型:
the start and end trade date (pd.Timestamp) between the given start and end date.
- truncate(segment: tuple, test_start, days: int) tuple
Truncate the segment based on the test_start date
- 参数:
segment (tuple) -- time segment
test_start
days (int) -- The trading days to be truncated the data in this segment may need 'days' data days are based on the test_start. For example, if the label contains the information of 2 days in the near future, the prediction horizon 1 day. (e.g. the prediction target is Ref($close, -2)/Ref($close, -1) - 1) the days should be 2 + 1 == 3 days.
- 返回:
tuple
- 返回类型:
new segment
- shift(seg: tuple, step: int, rtype='sliding') tuple
Shift the datetime of segment
If there are None (which indicates unbounded index) in the segment, this method will return None.
- 参数:
seg -- datetime segment
step (int) -- rolling step
rtype (str) -- rolling type ("sliding" or "expanding")
- 返回:
tuple
- 返回类型:
new segment
- 抛出:
KeyError: -- shift will raise error if the index(both start and end) is out of self.cal
- qlib.workflow.task.utils.replace_task_handler_with_cache(task: dict, cache_dir: str | Path = '.') dict
Replace the handler in task with a cache handler. It will automatically cache the file and save it in cache_dir.
>>> import qlib >>> qlib.auto_init() >>> import datetime >>> # it is simplified task >>> task = {"dataset": {"kwargs":{'handler': {'class': 'Alpha158', 'module_path': 'qlib.contrib.data.handler', 'kwargs': {'start_time': datetime.date(2008, 1, 1), 'end_time': datetime.date(2020, 8, 1), 'fit_start_time': datetime.date(2008, 1, 1), 'fit_end_time': datetime.date(2014, 12, 31), 'instruments': 'CSI300'}}}}} >>> new_task = replace_task_handler_with_cache(task) >>> print(new_task) {'dataset': {'kwargs': {'handler': 'file...Alpha158.3584f5f8b4.pkl'}}}
在线服务
在线管理器
OnlineManager可以管理一组`Online Strategy <#Online Strategy>`_并动态运行它们。
随着时间的推移,决策模型也会发生变化。在本模块中,我们将这些贡献模型称为`online`模型。 在每个例行程序(如每天或每分钟)中,`online`模型可能会发生变化,需要更新它们的预测。 因此本模块提供了一系列方法来控制这个过程。
本模块还提供了一种在历史中模拟`Online Strategy <#Online Strategy>`_的方法。 这意味着您可以验证您的策略或找到更好的策略。
在不同情况下使用不同训练器共有4种情况:
以下是一些伪代码,展示了每种情况的工作流程
- 为简单起见
策略中只使用一个策略
`update_online_pred`仅在在线模式下调用并被忽略
在线+训练器
tasks = first_train()
models = trainer.train(tasks)
trainer.end_train(models)
for day in online_trading_days:
# OnlineManager.routine
models = trainer.train(strategy.prepare_tasks()) # 对每个策略
strategy.prepare_online_models(models) # 对每个策略
trainer.end_train(models)
prepare_signals() # 每日准备交易信号
在线+延迟训练器: 工作流程与`在线+训练器`相同。
模拟+延迟训练器
# 模拟
tasks = first_train()
models = trainer.train(tasks)
for day in historical_calendars:
# OnlineManager.routine
models = trainer.train(strategy.prepare_tasks()) # 对每个策略
strategy.prepare_online_models(models) # 对每个策略
# delay_prepare()
# FIXME: 目前delay_prepare没有以正确的方式实现。
trainer.end_train(<for all previous models>)
prepare_signals()
# 我们可以简化当前的工作流程吗?
可以减少任务的状态数量吗?
对于每个任务,我们有三个阶段(即任务、部分训练的任务、最终训练的任务)
- class qlib.workflow.online.manager.OnlineManager(strategies: OnlineStrategy | List[OnlineStrategy], trainer: Trainer = None, begin_time: str | Timestamp = None, freq='day')
OnlineManager可以通过`Online Strategy <#Online Strategy>`_管理在线模型。 它还提供了哪些模型在什么时间在线的历史记录。
- __init__(strategies: OnlineStrategy | List[OnlineStrategy], trainer: Trainer = None, begin_time: str | Timestamp = None, freq='day')
初始化OnlineManager。 一个OnlineManager必须至少有一个OnlineStrategy。
- 参数:
strategies (Union[OnlineStrategy, List[OnlineStrategy]]): OnlineStrategy实例或OnlineStrategy列表 begin_time (Union[str,pd.Timestamp], 可选): OnlineManager将在此时间开始。默认为None表示使用最新日期。 trainer (qlib.model.trainer.Trainer): 用于训练任务的训练器。None表示使用TrainerR。 freq (str, 可选): 数据频率。默认为"day"。
- first_train(strategies: List[OnlineStrategy] = None, model_kwargs: dict = {})
从每个策略的first_tasks方法获取任务并训练它们。 如果使用DelayTrainer,它可以在每个策略的first_tasks之后一起完成所有训练。
- 参数:
strategies (List[OnlineStrategy]): 策略列表(添加策略时需要此参数)。None表示使用默认策略。 model_kwargs (dict): `prepare_online_models`的参数
- routine(cur_time: str | Timestamp = None, task_kwargs: dict = {}, model_kwargs: dict = {}, signal_kwargs: dict = {})
每个策略的典型更新过程并记录在线历史。
例行程序(如逐日或逐月)后的典型更新过程。 过程是: 更新预测 -> 准备任务 -> 准备在线模型 -> 准备信号。
如果使用DelayTrainer,它可以在每个策略的prepare_tasks之后一起完成所有训练。
- get_collector(**kwargs) MergeCollector
获取`Collector <../advanced/task_management.html#Task Collecting>`_实例以收集每个策略的结果。 此收集器可以作为信号准备的基础。
- 参数:
**kwargs: get_collector的参数。
- 返回:
MergeCollector: 用于合并其他收集器的收集器。
- add_strategy(strategies: OnlineStrategy | List[OnlineStrategy])
向OnlineManager添加一些新策略。
- 参数:
strategy (Union[OnlineStrategy, List[OnlineStrategy]]): OnlineStrategy列表
- prepare_signals(prepare_func: ~typing.Callable = <qlib.model.ens.ensemble.AverageEnsemble object>, over_write=False)
在准备完最后一个例行程序(箱线图中的一个框)的数据后,这意味着例行程序的结束,我们可以为下一个例行程序准备交易信号。
注意: 给定一组预测,这些预测结束时间之前的所有信号都将准备好。
即使最新的信号已经存在,最新的计算结果也将被覆盖。
备注
给定某个时间的预测,此时间之前的所有信号都将准备好。
- 参数:
prepare_func (Callable, 可选): 从收集后的字典中获取信号。默认为AverageEnsemble(),由MergeCollector收集的结果必须是{xxx:pred}。 over_write (bool, 可选): 如果为True,新信号将覆盖。如果为False,新信号将附加到信号末尾。默认为False。
- 返回:
pd.DataFrame: 信号。
- get_signals() Series | DataFrame
获取准备好的在线信号。
- 返回:
Union[pd.Series, pd.DataFrame]: pd.Series表示每个日期时间只有一个信号。 pd.DataFrame表示多个信号,例如买卖操作使用不同的交易信号。
- simulate(end_time=None, frequency='day', task_kwargs={}, model_kwargs={}, signal_kwargs={}) Series | DataFrame
从当前时间开始,此方法将模拟OnlineManager中的每个例行程序,直到结束时间。
考虑到并行训练,模型和信号可以在所有例行程序模拟后准备。
延迟训练方式可以是``DelayTrainer``,延迟准备信号方式可以是``delay_prepare``。
在线策略
OnlineStrategy模块是在线服务的一个组件。
- class qlib.workflow.online.strategy.OnlineStrategy(name_id: str)
OnlineStrategy与`Online Manager <#Online Manager>`_配合使用,负责处理任务生成、模型更新和信号准备的方式。
- __init__(name_id: str)
初始化OnlineStrategy。 此模块**必须**使用`Trainer <../reference/api.html#qlib.model.trainer.Trainer>`_来完成模型训练。
- 参数:
name_id (str): 唯一的名称或ID。 trainer (qlib.model.trainer.Trainer, 可选): Trainer的实例。默认为None。
- prepare_tasks(cur_time, **kwargs) List[dict]
在例行程序结束后,根据当前时间(None表示最新)检查是否需要准备和训练一些新任务。 返回等待训练的新任务。
您可以通过OnlineTool.online_models找到最后的在线模型。
- prepare_online_models(trained_models, cur_time=None) List[object]
从训练好的模型中选择一些模型并将它们设置为在线模型。 这是一个将所有训练好的模型设为在线的典型实现,您可以重写它来实现更复杂的方法。 如果仍需要,可以通过OnlineTool.online_models找到最后的在线模型。
注意:将所有在线模型重置为训练好的模型。如果没有训练好的模型,则不执行任何操作。
- 注意:
当前实现非常简单。以下是一个更接近实际场景的复杂情况: 1. 在`test_start`前一天(时间戳`T`)训练新模型 2. 在`test_start`时(通常是时间戳`T + 1`)切换模型
- 参数:
models (list): 模型列表。 cur_time (pd.Dataframe): 来自OnlineManger的当前时间。None表示最新。
- 返回:
List[object]: 在线模型列表。
- first_tasks() List[dict]
首先生成一系列任务并返回它们。
- class qlib.workflow.online.strategy.RollingStrategy(name_id: str, task_template: dict | List[dict], rolling_gen: RollingGen)
此示例策略始终使用最新的滚动模型作为在线模型。
- __init__(name_id: str, task_template: dict | List[dict], rolling_gen: RollingGen)
初始化RollingStrategy。
假设:name_id的字符串、实验名称和训练器的实验名称相同。
- 参数:
name_id (str): 唯一的名称或ID。也将作为实验的名称。 task_template (Union[dict, List[dict]]): 任务模板列表或单个模板,将用于通过rolling_gen生成多个任务。 rolling_gen (RollingGen): RollingGen的实例
- get_collector(process_list=[<qlib.model.ens.group.RollingGroup object>], rec_key_func=None, rec_filter_func=None, artifacts_key=None)
获取`Collector <../advanced/task_management.html#Task Collecting>`_实例以收集结果。返回的收集器必须能够区分不同模型的结果。
假设:可以根据模型名称和滚动测试段来区分模型。 如果不希望此假设,请实现您自己的方法或使用其他rec_key_func。
- 参数:
rec_key_func (Callable): 获取记录器键的函数。如果为None,则使用记录器ID。 rec_filter_func (Callable, 可选): 通过返回True或False来过滤记录器。默认为None。 artifacts_key (List[str], 可选): 要获取的工件键。如果为None,则获取所有工件。
- first_tasks() List[dict]
使用rolling_gen基于task_template生成不同的任务
- 返回:
List[dict]: 任务列表
- prepare_tasks(cur_time) List[dict]
根据当前时间(最新为None)准备新任务
可以通过OnlineToolR.online_models查找最新的在线模型
- 返回:
List[dict]: 新任务列表
在线工具
OnlineTool是一个用于设置和取消设置一系列`online`模型的模块。 `online`模型是在某些时间点的决定性模型,可以随时间变化而改变。 这使我们能够使用高效的子模型来适应市场风格的变化。
- class qlib.workflow.online.utils.OnlineTool
OnlineTool将管理包含模型记录器的实验中的`online`模型。
- __init__()
初始化OnlineTool。
- set_online_tag(tag, recorder: list | object)
设置模型的`tag`标记其是否为在线状态。
- 参数:
tag (str): `ONLINE_TAG`或`OFFLINE_TAG`中的标签 recorder (Union[list,object]): 模型的记录器
- get_online_tag(recorder: object) str
给定模型记录器,返回其在线标签。
- 参数:
recorder (Object): 模型的记录器
- 返回:
str: 在线标签
- reset_online_tag(recorder: list | object)
将所有模型下线并将指定记录器设置为'online'。
- 参数:
- recorder (Union[list,object]):
要重置为'online'的记录器
- update_online_pred(to_date=None)
将`online`模型的预测更新到to_date。
- 参数:
to_date (pd.Timestamp): 更新此日期之前的预测。None表示更新到最新。
- class qlib.workflow.online.utils.OnlineToolR(default_exp_name: str = None)
基于记录器(R)的OnlineTool实现。
- __init__(default_exp_name: str = None)
初始化OnlineToolR。
- 参数:
default_exp_name (str): 默认实验名称
- set_online_tag(tag, recorder: Recorder | List)
设置模型记录器的`tag`标记其是否为在线状态。
- 参数:
tag (str): ONLINE_TAG、`NEXT_ONLINE_TAG`或`OFFLINE_TAG`中的标签 recorder (Union[Recorder, List]): 记录器列表或单个记录器实例
- get_online_tag(recorder: Recorder) str
给定模型记录器,返回其在线标签。
- 参数:
recorder (Recorder): 记录器实例
- 返回:
str: 在线标签
- reset_online_tag(recorder: Recorder | List, exp_name: str = None)
将所有模型下线并将指定记录器设置为'online'。
- 参数:
- recorder (Union[Recorder, List]):
要重置为'online'的记录器
exp_name (str): 实验名称。如果为None则使用default_exp_name
- online_models(exp_name: str = None) list
获取当前`online`模型
- 参数:
exp_name (str): 实验名称。如果为None则使用default_exp_name
- 返回:
list: `online`模型列表
- update_online_pred(to_date=None, from_date=None, exp_name: str = None)
将在线模型的预测更新到to_date。
- 参数:
to_date (pd.Timestamp): 更新此日期之前的预测。None表示更新到日历中的最新时间 exp_name (str): 实验名称。如果为None则使用default_exp_name
记录更新器
更新器模块,用于在股票数据更新时更新预测等artifact。
- class qlib.workflow.online.update.RMDLoader(rec: Recorder)
Recorder Model Dataset Loader
- get_dataset(start_time, end_time, segments=None, unprepared_dataset: DatasetH | None = None) DatasetH
加载、配置和设置数据集
该数据集用于推理
- 参数:
- start_time :
基础数据的开始时间
- end_time :
基础数据的结束时间
- segmentsdict
数据集的分段配置 对于时间序列数据集(TSDatasetH),测试段可能与开始时间和结束时间不同
- unprepared_dataset: Optional[DatasetH]
如果用户不想从记录器加载数据集,请指定用户的数据集
- 返回:
DatasetH: DatasetH实例
- class qlib.workflow.online.update.RecordUpdater(record: Recorder, *args, **kwargs)
Update a specific recorders
- abstractmethod update(*args, **kwargs)
Update info for specific recorder
- class qlib.workflow.online.update.DSBasedUpdater(record: ~qlib.workflow.recorder.Recorder, to_date=None, from_date=None, hist_ref: int | None = None, freq='day', fname='pred.pkl', loader_cls: type = <class 'qlib.workflow.online.update.RMDLoader'>)
基于数据集的更新器
提供基于Qlib数据集更新数据的功能
假设条件
基于Qlib数据集
要更新的数据是多级索引的pd.DataFrame,例如标签、预测
LABEL0 datetime instrument 2021-05-10 SH600000 0.006965 SH600004 0.003407 ... ... 2021-05-28 SZ300498 0.015748 SZ300676 -0.001321
- __init__(record: ~qlib.workflow.recorder.Recorder, to_date=None, from_date=None, hist_ref: int | None = None, freq='day', fname='pred.pkl', loader_cls: type = <class 'qlib.workflow.online.update.RMDLoader'>)
初始化预测更新器
在以下情况下的预期行为:
如果`to_date`大于日历中的最大日期,数据将更新到最新日期
如果有数据在`from_date`之前或`to_date`之后,只有`from_date`和`to_date`之间的数据会受到影响
- 参数:
- recordRecorder
记录器
- to_date :
更新预测到`to_date`
如果to_date为None:
数据将更新到最新日期
- from_date :
更新将从`from_date`开始
如果from_date为None:
更新将在历史数据中最新数据的下一个时间点进行
- hist_refint
有时数据集会有历史依赖 将历史依赖长度的问题留给用户设置 如果用户不指定此参数,更新器将尝试加载数据集自动确定hist_ref
备注
start_time不包含在`hist_ref`中;因此`hist_ref`在大多数情况下会是`step_len - 1`
- loader_clstype
加载模型和数据集的类
- prepare_data(unprepared_dataset: DatasetH | None = None) DatasetH
加载数据集 - 如果指定了unprepared_dataset,则直接准备数据集 - 否则
分离此函数将使重用数据集更容易
- 返回:
DatasetH: DatasetH实例
- class qlib.workflow.online.update.PredUpdater(record: ~qlib.workflow.recorder.Recorder, to_date=None, from_date=None, hist_ref: int | None = None, freq='day', fname='pred.pkl', loader_cls: type = <class 'qlib.workflow.online.update.RMDLoader'>)
更新记录器中的预测
- class qlib.workflow.online.update.LabelUpdater(record: Recorder, to_date=None, **kwargs)
更新记录器中的标签
假设条件 - 标签由record_temp.SignalRecord生成
- __init__(record: Recorder, to_date=None, **kwargs)
初始化预测更新器
在以下情况下的预期行为:
如果`to_date`大于日历中的最大日期,数据将更新到最新日期
如果有数据在`from_date`之前或`to_date`之后,只有`from_date`和`to_date`之间的数据会受到影响
- 参数:
- recordRecorder
记录器
- to_date :
更新预测到`to_date`
如果to_date为None:
数据将更新到最新日期
- from_date :
更新将从`from_date`开始
如果from_date为None:
更新将在历史数据中最新数据的下一个时间点进行
- hist_refint
有时数据集会有历史依赖 将历史依赖长度的问题留给用户设置 如果用户不指定此参数,更新器将尝试加载数据集自动确定hist_ref
备注
start_time不包含在`hist_ref`中;因此`hist_ref`在大多数情况下会是`step_len - 1`
- loader_clstype
加载模型和数据集的类
工具
可序列化
- class qlib.utils.serial.Serializable
Serializable类将改变pickle的行为。
判断属性在dump时是否保留的规则(优先级从高到低): - 在config属性列表中 -> 总是丢弃 - 在include属性列表中 -> 总是保留 - 在exclude属性列表中 -> 总是丢弃 - 不以`_`开头的属性名 -> 保留 - 以`_`开头的属性名 -> 如果`dump_all`为true则保留,否则丢弃
它提供了一种语法糖来区分用户不希望保存的属性。 - 例如,一个可学习的Datahandler在dump到磁盘时只想保存参数而不保存数据
- __init__()
- property dump_all
对象是否dump所有属性
- config(recursive=False, **kwargs)
配置可序列化对象
参数
kwargs可能包含以下键:
- dump_allbool
对象是否dump所有属性
- excludelist
不被dump的属性列表
- includelist
被dump的属性列表
- recursivebool
是否递归配置
- to_pickle(path: Path | str, **kwargs)
将对象dump到pickle文件
path (Union[Path, str]): dump文件路径
kwargs可能包含以下键:
- dump_allbool
对象是否dump所有属性
- excludelist
不被dump的属性列表
- includelist
被dump的属性列表
- classmethod load(filepath)
从文件路径加载可序列化类
- 参数:
filepath (str): 文件路径
- 异常:
TypeError: pickle文件必须是`type(cls)`类型
- 返回:
type(cls): `type(cls)`的实例
- classmethod get_backend()
返回Serializable类的真实后端。pickle_backend值可以是"pickle"或"dill"
- 返回:
module: 基于pickle_backend的pickle或dill模块
强化学习
基础组件
- class qlib.rl.Interpreter
解释器是模拟器产生的状态与强化学习策略所需状态之间的媒介。 解释器是双向的:
从模拟器状态到策略状态(又称观测),参见:class:StateInterpreter。
从策略动作到模拟器接受的动作,参见:class:ActionInterpreter。
通过继承这两个子类之一来定义自己的解释器。 此基类仅用于类型检查。
建议解释器设计为无状态,即在解释器中使用``self.xxx``存储临时信息是反模式。未来可能支持通过调用``self.env.register_state()``注册解释器相关状态,但第一版暂不支持。
- class qlib.rl.StateInterpreter
状态解释器,将qlib执行器的执行结果解释为强化学习环境状态
- interpret(simulator_state: StateType) ObsType
解释模拟器的状态。
参数
- simulator_state
通过``simulator.get_state()``获取的模拟器状态。
返回
策略所需的状态,应符合``observation_space``中定义的状态空间。
- validate(obs: ObsType) None
Validate whether an observation belongs to the pre-defined observation space.
- class qlib.rl.ActionInterpreter
动作解释器,将强化学习智能体的动作解释为qlib订单
- interpret(simulator_state: StateType, action: PolicyActType) ActType
将策略动作转换为模拟器动作。
参数
- simulator_state
通过``simulator.get_state()``获取的模拟器状态。
- action
策略给出的原始动作。
返回
模拟器所需的动作。
- validate(action: PolicyActType) None
验证动作是否属于预定义的动作空间。
- class qlib.rl.Reward
奖励计算组件,接受单个参数:模拟器状态。返回一个实数:奖励。
子类应实现``reward(simulator_state)``来实现自定义奖励计算逻辑。
- log(name: str, value: Any) None
记录奖励相关指标到日志。
- reward(simulator_state: SimulatorState) float
实现此方法以定义自定义奖励。
- class qlib.rl.RewardCombination(rewards: Dict[str, Tuple[Reward, float]])
多个奖励的组合。
- class qlib.rl.Simulator(initial: InitialStateType, **kwargs: Any)
通过``__init__``重置并通过``step(action)``转换的模拟器。
为了使数据流清晰,我们对模拟器做了以下限制:
修改模拟器内部状态的唯一方法是使用``step(action)``。
外部模块可以通过``simulator.get_state()``*读取*模拟器的状态, 并通过调用``simulator.done()``检查模拟器是否处于结束状态。
模拟器被定义为与三种类型绑定:
InitialStateType:用于创建模拟器的数据类型
StateType:模拟器**状态**的类型
ActType:**动作**的类型,即每一步接收的输入
不同的模拟器可能共享相同的StateType。例如,当它们处理相同的任务但使用不同的模拟实现时。 通过相同的类型,它们可以安全地共享MDP中的其他组件。
模拟器是短暂的。模拟器的生命周期从初始状态开始,到轨迹结束为止。 换句话说,当轨迹结束时,模拟器会被回收。 如果模拟器之间需要共享上下文(例如为了加速), 可以通过访问环境包装器的弱引用来实现。
属性
- env
环境包装器的引用,在某些特殊情况下可能有用。 不建议模拟器使用此属性,因为它容易引发错误。
- __init__(initial: InitialStateType, **kwargs: Any) None
- done() bool
检查模拟器是否处于"done"状态。当模拟器处于"done"状态时, 它不应再接收任何``step``请求。由于模拟器是短暂的,要重置模拟器,应销毁旧的模拟器并创建一个新的。
- step(action: ActType) None
接收一个ActType类型的动作。
模拟器应更新其内部状态,并返回None。 更新后的状态可以通过``simulator.get_state()``获取。
策略
- class qlib.rl.strategy.SingleOrderStrategy(order: Order, trade_range: TradeRange | None = None)
用于生成只包含一个订单的交易决策的策略。
- __init__(order: Order, trade_range: TradeRange | None = None) None
参数
- outer_trade_decisionBaseTradeDecision, optional
本策略依赖的外部策略交易决策,将在[start_time, end_time]区间内交易,默认为None
如果策略用于拆分交易决策,将会使用此参数
如果策略用于投资组合管理,可以忽略此参数
- level_infraLevelInfrastructure, optional
回测共享的层级基础设施,包括交易日历等
- common_infraCommonInfrastructure, optional
回测共享的公共基础设施,包括交易账户、交易交易所等
- trade_exchangeExchange
提供市场信息的交易所,用于处理订单和生成报告
如果`trade_exchange`为None,self.trade_exchange将从common_infra中获取
允许在不同的执行中使用不同的交易所
例如:
在日线执行中,日线交易所和分钟线交易所都可用,但推荐使用日线交易所,因为它运行更快
在分钟线执行中,日线交易所不可用,只能使用分钟线交易所
训练器
Train, test, inference utilities.
- class qlib.rl.trainer.Trainer(*, max_iters: int | None = None, val_every_n_iters: int | None = None, loggers: LogWriter | List[LogWriter] | None = None, callbacks: List[Callback] | None = None, finite_env_type: Literal['dummy', 'subproc', 'shmem'] = 'subproc', concurrency: int = 2, fast_dev_run: int | None = None)
用于在特定任务上训练策略的工具。
与传统深度学习训练器不同,此训练器的迭代单位是"收集"(collect), 而非"epoch"或"mini-batch"。 每次收集时,:class:`Collector`会收集一定数量的策略-环境交互数据, 并累积到回放缓冲区中。此缓冲区用作训练策略的"数据"。 每次收集结束时,策略会被*更新*多次。
API与`PyTorch Lightning <https://pytorch-lightning.readthedocs.io/>`__有些相似, 但由于此训练器专为RL应用构建,大多数配置都在RL上下文中, 因此本质上不同。 我们仍在寻找整合现有训练器库的方法,因为构建与这些库同样强大的训练器 需要大量工作,且这不是我们的主要目标。
与`tianshou的内置训练器 <https://tianshou.readthedocs.io/en/master/api/tianshou.trainer.html>`__ 也完全不同,因为此实现要复杂得多。
参数
- max_iters
停止前的最大迭代次数。
- val_every_n_iters
每n次迭代(即训练收集)执行一次验证。
- logger
记录回测结果的日志记录器。必须提供日志记录器, 否则所有信息都将丢失。
- finite_env_type
有限环境实现类型。
- concurrency
并行工作进程数。
- fast_dev_run
创建用于调试的子集。 具体实现取决于训练容器的实现方式。 对于:class:~qlib.rl.vessel.TrainingVessel,如果大于零, 将使用大小为``fast_dev_run``的随机子集 替代``train_initial_states``和``val_initial_states``。
- __init__(*, max_iters: int | None = None, val_every_n_iters: int | None = None, loggers: LogWriter | List[LogWriter] | None = None, callbacks: List[Callback] | None = None, finite_env_type: Literal['dummy', 'subproc', 'shmem'] = 'subproc', concurrency: int = 2, fast_dev_run: int | None = None)
- fit(vessel: TrainingVesselBase, ckpt_path: Path | None = None) None
在定义的模拟器上训练RL策略。
参数
- vessel
训练中使用的所有元素的集合。
- ckpt_path
加载预pre-trained / paused的训练检查点。
- initialize()
初始化整个训练过程。
此处的状态应与state_dict保持同步。
- initialize_iter()
初始化一次迭代/收集。
- load_state_dict(state_dict: dict) None
将所有状态加载到当前训练器中。
- named_callbacks() Dict[str, Callback]
获取带有名称的回调函数集合。 在保存检查点时很有用。
- state_dict() dict
尽可能将当前训练的所有状态存入字典。
它不会尝试处理一次训练收集中间可能出现的所有状态类型。 对于大多数情况,在每次迭代结束时,结果通常是正确的。
注意,收集器中回放缓冲区数据的丢失是预期行为。
- test(vessel: TrainingVesselBase) None
在模拟器上测试RL策略。
模拟器将使用``test_seed_iterator``生成的数据。
参数
- vessel
所有相关元素的集合。
- venv_from_iterator(iterator: Iterable[InitialStateType]) FiniteVectorEnv
从迭代器和训练容器创建向量化环境。
- should_stop: bool
设置为true可停止训练。
- metrics: dict
训练/验证/测试中产生的数值指标。 在训练/验证过程中,指标来自最新的一轮(episode)。 当每次训练/验证迭代完成时,指标将是该迭代中 所有轮次(episode)的聚合结果。
每次新训练迭代开始时会被清空。
在fit过程中,验证指标会以``val/``为前缀。
- current_iter: int
当前训练迭代(收集)次数。
- class qlib.rl.trainer.TrainingVessel(*, simulator_fn: Callable[[InitialStateType], Simulator[InitialStateType, StateType, ActType]], state_interpreter: StateInterpreter[StateType, ObsType], action_interpreter: ActionInterpreter[StateType, PolicyActType, ActType], policy: BasePolicy, reward: Reward, train_initial_states: Sequence[InitialStateType] | None = None, val_initial_states: Sequence[InitialStateType] | None = None, test_initial_states: Sequence[InitialStateType] | None = None, buffer_size: int = 20000, episode_per_iter: int = 1000, update_kwargs: Dict[str, Any] = None)
训练容器的默认实现。
__init__``接受初始状态序列以便创建迭代器。 ``train
、validate
、``test``各执行一次收集(训练中还包含更新)。 默认情况下,训练初始状态会在训练期间无限重复, 收集器会控制每次迭代的轮次(episode)数量。 在验证和测试中,验证/测试初始状态将仅使用一次。额外超参数(仅用于训练)包括:
buffer_size
: 回放缓冲区大小episode_per_iter
: 每次训练收集的轮次数量。可被快速开发模式覆盖。update_kwargs
: 传递给``policy.update``的关键字参数。 例如``dict(repeat=10, batch_size=64)``。
- __init__(*, simulator_fn: Callable[[InitialStateType], Simulator[InitialStateType, StateType, ActType]], state_interpreter: StateInterpreter[StateType, ObsType], action_interpreter: ActionInterpreter[StateType, PolicyActType, ActType], policy: BasePolicy, reward: Reward, train_initial_states: Sequence[InitialStateType] | None = None, val_initial_states: Sequence[InitialStateType] | None = None, test_initial_states: Sequence[InitialStateType] | None = None, buffer_size: int = 20000, episode_per_iter: int = 1000, update_kwargs: Dict[str, Any] = None)
- test(vector_env: FiniteVectorEnv) Dict[str, Any]
实现此方法以在测试环境中评估策略一次。
- test_seed_iterator() ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]
重写此方法以创建测试用的种子迭代器。
- train(vector_env: FiniteVectorEnv) Dict[str, Any]
创建收集器并收集``episode_per_iter``轮次(episodes)。 在收集的回放缓冲区上更新策略。
- train_seed_iterator() ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]
重写此方法以创建训练用的种子迭代器。 如果可迭代对象是上下文管理器,整个训练将在with块中调用, 并且迭代器会在训练完成后自动关闭。
- val_seed_iterator() ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]
重写此方法以创建验证用的种子迭代器。
- validate(vector_env: FiniteVectorEnv) Dict[str, Any]
实现此方法以对策略进行一次验证。
- class qlib.rl.trainer.TrainingVesselBase
包含模拟器、解释器和策略的容器,将被发送给训练器。 此类控制训练中与算法相关的部分,而训练器负责运行时部分。
该容器还定义了核心训练部分最重要的逻辑, 以及(可选)一些回调函数用于在特定事件插入自定义逻辑。
- load_state_dict(state_dict: Dict) None
从之前保存的状态字典恢复检查点。
- state_dict() Dict
返回当前容器状态的检查点。
- test(vector_env: FiniteVectorEnv) Dict[str, Any]
实现此方法以在测试环境中评估策略一次。
- test_seed_iterator() ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]
重写此方法以创建测试用的种子迭代器。
- train(vector_env: BaseVectorEnv) Dict[str, Any]
实现此方法以进行一次训练迭代。在RL中,一次迭代通常指一次收集。
- train_seed_iterator() ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]
重写此方法以创建训练用的种子迭代器。 如果可迭代对象是上下文管理器,整个训练将在with块中调用, 并且迭代器会在训练完成后自动关闭。
- val_seed_iterator() ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]
重写此方法以创建验证用的种子迭代器。
- validate(vector_env: FiniteVectorEnv) Dict[str, Any]
实现此方法以对策略进行一次验证。
- class qlib.rl.trainer.Checkpoint(dirpath: Path, filename: str = '{iter:03d}.pth', save_latest: Literal['link', 'copy'] | None = 'link', every_n_iters: int | None = None, time_interval: int | None = None, save_on_fit_end: bool = True)
定期保存检查点以实现持久化和恢复。
参数
- dirpath
保存检查点文件的目录。
- filename
检查点文件名。可包含命名格式化选项自动填充。 例如:
{iter:03d}-{reward:.2f}.pth
。 支持的参数名有:iter (整数)
trainer.metrics中的指标
时间字符串,格式为``%Y%m%d%H%M%S``
- save_latest
在``latest.pth``中保存最新检查点。 如果为``link``,
latest.pth``将创建为软链接。 如果为``copy
,``latest.pth``将保存为独立副本。 设为none可禁用此功能。- every_n_iters
每n次训练迭代结束时保存检查点, 如果有验证则在验证后保存。
- time_interval
再次保存检查点的最大时间间隔(秒)。
- save_on_fit_end
在训练结束时保存最后一个检查点。 如果该位置已有检查点则不执行任何操作。
- __init__(dirpath: Path, filename: str = '{iter:03d}.pth', save_latest: Literal['link', 'copy'] | None = 'link', every_n_iters: int | None = None, time_interval: int | None = None, save_on_fit_end: bool = True)
- on_fit_end(trainer: Trainer, vessel: TrainingVesselBase) None
在整个训练过程结束后调用。
- on_iter_end(trainer: Trainer, vessel: TrainingVesselBase) None
每次迭代结束时调用。 在``current_iter``递增**之后**调用, 表示前一次迭代已完成。
- class qlib.rl.trainer.EarlyStopping(monitor: str = 'reward', min_delta: float = 0.0, patience: int = 0, mode: Literal['min', 'max'] = 'max', baseline: float | None = None, restore_best_weights: bool = False)
当监控指标停止改善时停止训练。
每次验证结束时触发早停回调。 它会检查验证产生的指标, 获取名为``monitor``的指标(默认为``reward``), 判断其是否不再增加/减少。 根据``min_delta``和``patience``参数决定是否停止。 如果发现指标不再改善, 则设置``trainer.should_stop``为true, 终止训练过程。
实现参考: https://github.com/keras-team/keras/blob/v2.9.0/keras/callbacks.py#L1744-L1893
- __init__(monitor: str = 'reward', min_delta: float = 0.0, patience: int = 0, mode: Literal['min', 'max'] = 'max', baseline: float | None = None, restore_best_weights: bool = False)
- load_state_dict(state_dict: dict) None
从保存的状态字典恢复回调函数。
- on_fit_start(trainer: Trainer, vessel: TrainingVesselBase) None
在整个训练过程开始前调用。
- on_validate_end(trainer: Trainer, vessel: TrainingVesselBase) None
验证结束时调用。
- state_dict() dict
获取回调函数的状态字典用于暂停和恢复。
- class qlib.rl.trainer.MetricsWriter(dirpath: Path)
将训练指标写入文件。
- __init__(dirpath: Path) None
- on_train_end(trainer: Trainer, vessel: TrainingVesselBase) None
训练结束时调用。 要访问训练期间产生的所有输出,可在trainer或vessel中缓存数据, 并在此钩子中进行后处理。
- on_validate_end(trainer: Trainer, vessel: TrainingVesselBase) None
验证结束时调用。
- qlib.rl.trainer.train(simulator_fn: Callable[[InitialStateType], Simulator], state_interpreter: StateInterpreter, action_interpreter: ActionInterpreter, initial_states: Sequence[InitialStateType], policy: BasePolicy, reward: Reward, vessel_kwargs: Dict[str, Any], trainer_kwargs: Dict[str, Any]) None
使用RL框架提供的并行能力训练策略。
实验性API,参数可能会变更。
参数
- simulator_fn
接收初始种子并返回模拟器的可调用对象。
- state_interpreter
解释模拟器状态。
- action_interpreter
解释策略动作。
- initial_states
初始状态集合,每个状态将恰好运行一次。
- policy
待训练的策略。
- reward
奖励函数。
- vessel_kwargs
传递给:class:TrainingVessel`的关键字参数,如``episode_per_iter`。
- trainer_kwargs
传递给:class:Trainer`的关键字参数,如``finite_env_type`,
concurrency
。
- qlib.rl.trainer.backtest(simulator_fn: Callable[[InitialStateType], Simulator], state_interpreter: StateInterpreter, action_interpreter: ActionInterpreter, initial_states: Sequence[InitialStateType], policy: BasePolicy, logger: LogWriter | List[LogWriter], reward: Reward | None = None, finite_env_type: Literal['dummy', 'subproc', 'shmem'] = 'subproc', concurrency: int = 2) None
使用RL框架提供的并行能力进行回测。
实验性API,参数可能会变更。
参数
- simulator_fn
接收初始种子并返回模拟器的可调用对象。
- state_interpreter
解释模拟器状态。
- action_interpreter
解释策略动作。
- initial_states
初始状态集合,每个状态将恰好运行一次。
- policy
待测试的策略。
- logger
记录回测结果的日志器。必须提供日志器,否则所有信息都将丢失。
- reward
可选的奖励函数。对于回测,仅用于测试和记录奖励。
- finite_env_type
有限环境实现类型。
- concurrency
并行工作线程数。
订单执行
当前支持单一资产订单执行。 多资产支持正在开发中。
- class qlib.rl.order_execution.FullHistoryStateInterpreter(max_step: int, data_ticks: int, data_dim: int, processed_data_provider: dict | ProcessedDataProvider)
包含所有历史数据的观察值,包括今天(直到当前时刻)和昨天的数据。
参数
- max_step
总步数(上限估计)。例如390分钟/每步30分钟=13步。
- data_ticks
等于总记录数。例如在每分钟的SAOE中, 总ticks数就是一天中的分钟数。
- data_dim
数据的维度数。
- processed_data_provider
处理数据的提供者。
- __init__(max_step: int, data_ticks: int, data_dim: int, processed_data_provider: dict | ProcessedDataProvider) None
- class qlib.rl.order_execution.CurrentStepStateInterpreter(max_step: int)
当前步骤的观察值。
用于策略仅依赖于最新状态而不依赖历史的情况。 键列表不完整,如果您的策略需要更多信息可以添加。
- __init__(max_step: int) None
- class qlib.rl.order_execution.CategoricalActionInterpreter(values: int | List[float], max_step: int | None = None)
将离散策略动作转换为连续动作,然后乘以``order.amount``。
参数
- values
可以是长度为$L$的列表:$[a_1, a_2, ldots, a_L]$。 当策略给出决策$x$时,输出为$a_x$乘以订单数量。 也可以是一个整数$n$,此时会自动生成长度为$n+1$的列表, 即$[0, 1/n, 2/n, ldots, n/n]$。
- max_step
总步数(上限估计)。例如390分钟/每步30分钟=13步。
- __init__(values: int | List[float], max_step: int | None = None) None
- class qlib.rl.order_execution.TwapRelativeActionInterpreter
将连续比率转换为交易数量。
该比率相对于当天剩余时间的TWAP策略。 例如,剩余5个步骤,剩余头寸为300。 使用TWAP策略,每个步骤应交易60。 当此解释器收到动作$a$时,其输出为$60 cdot a$。
- class qlib.rl.order_execution.Recurrent(obs_space: FullHistoryObs, hidden_dim: int = 64, output_dim: int = 32, rnn_type: Literal['rnn', 'lstm', 'gru'] = 'gru', rnn_num_layers: int = 1)
`OPD <https://seqml.github.io/opd/opd_aaai21_supplement.pdf>`_中提出的网络架构。
在每个时间步,策略网络的输入分为两部分: 公共变量和私有变量,分别由本网络中的``raw_rnn``和``pri_rnn``处理。
一个小的区别是,在此实现中,我们不假设方向是固定的。 因此添加了另一个``dire_fc``来生成额外的方向相关特征。
- __init__(obs_space: FullHistoryObs, hidden_dim: int = 64, output_dim: int = 32, rnn_type: Literal['rnn', 'lstm', 'gru'] = 'gru', rnn_num_layers: int = 1) None
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(batch: Batch) Tensor
输入应该是一个至少包含以下内容的字典:
data_processed: [N, T, C]
cur_step: [N] (int)
cur_time: [N] (int)
position_history: [N, S] (S is number of steps)
target: [N]
num_step: [N] (int)
acquiring: [N] (0 or 1)
- class qlib.rl.order_execution.AllOne(obs_space: Space, action_space: Space, fill_value: float | int = 1.0)
前向传播返回全为1的批次。
在实现某些基线(如TWAP)时很有用。
- __init__(obs_space: Space, action_space: Space, fill_value: float | int = 1.0) None
- forward(batch: Batch, state: dict | Batch | ndarray = None, **kwargs: Any) Batch
Compute action over the given batch data.
- 返回:
A
Batch
which MUST have the following keys:act
an numpy.ndarray or a torch.Tensor, the action over given batch data.state
a dict, an numpy.ndarray or a torch.Tensor, the internal state of the policy,None
as default.
Other keys are user-defined. It depends on the algorithm. For example,
# some code return Batch(logits=..., act=..., state=None, dist=...)
The keyword
policy
is reserved and the corresponding data will be stored into the replay buffer. For instance,# some code return Batch(..., policy=Batch(log_prob=dist.log_prob(act))) # and in the sampled data batch, you can directly use # batch.policy.log_prob to get your data.
备注
In continuous action space, you should do another step "map_action" to get the real action:
act = policy(batch).act # doesn't map to the target action range act = policy.map_action(act, batch)
- class qlib.rl.order_execution.PPO(network: Module, obs_space: Space, action_space: Space, lr: float, weight_decay: float = 0.0, discount_factor: float = 1.0, max_grad_norm: float = 100.0, reward_normalization: bool = True, eps_clip: float = 0.3, value_clip: bool = True, vf_coef: float = 1.0, gae_lambda: float = 1.0, max_batch_size: int = 256, deterministic_eval: bool = True, weight_file: Path | None = None)
tianshou PPOPolicy的包装器。
区别:
自动创建actor和critic网络。仅支持离散动作空间。
去除actor网络和critic网络之间的重复参数 (不确定最新版tianshou是否已包含此功能)。
支持加载检查点的``weight_file``参数。
某些参数的默认值与原始版本不同。
- __init__(network: Module, obs_space: Space, action_space: Space, lr: float, weight_decay: float = 0.0, discount_factor: float = 1.0, max_grad_norm: float = 100.0, reward_normalization: bool = True, eps_clip: float = 0.3, value_clip: bool = True, vf_coef: float = 1.0, gae_lambda: float = 1.0, max_batch_size: int = 256, deterministic_eval: bool = True, weight_file: Path | None = None) None
- class qlib.rl.order_execution.PAPenaltyReward(penalty: float = 100.0, scale: float = 1.0)
鼓励更高的PA(价格优势),但对在短时间内堆积所有交易量进行惩罚。 形式上,每个时间步的奖励是 \((PA_t * vol_t / target - vol_t^2 * penalty)\)。
参数
- penalty
短时间内大交易量的惩罚系数。
- scale
用于放大或缩小奖励的权重。
- __init__(penalty: float = 100.0, scale: float = 1.0) None
- class qlib.rl.order_execution.SingleAssetOrderExecutionSimple(order: Order, data_dir: Path, feature_columns_today: List[str] = [], feature_columns_yesterday: List[str] = [], data_granularity: int = 1, ticks_per_step: int = 30, vol_threshold: float | None = None)
单资产订单执行(SAOE)模拟器。
由于简单模拟器中没有"日历",使用tick进行交易。 一个tick是pickle格式数据文件中的一条记录。 每个tick被视为一个独立的交易机会。 如果不需要如此精细的粒度,可以使用``ticks_per_step``来 延长每个步骤的tick数量。
在每个步骤中,交易量被"平均"分配到每个tick, 然后受限于最大执行量(即``vol_threshold``), 如果是最后一步,则尝试确保执行全部数量。
参数
- order
启动SAOE模拟器的种子订单。
- data_dir
加载回测数据的路径。
- feature_columns_today
今日特征列。
- feature_columns_yesterday
昨日特征列。
- data_granularity
连续数据条目之间的tick数量。
- ticks_per_step
每个步骤包含多少tick。
- vol_threshold
最大执行量(除以市场执行量)。
- __init__(order: Order, data_dir: Path, feature_columns_today: List[str] = [], feature_columns_yesterday: List[str] = [], data_granularity: int = 1, ticks_per_step: int = 30, vol_threshold: float | None = None) None
- done() bool
检查模拟器是否处于"done"状态。当模拟器处于"done"状态时, 它不应再接收任何``step``请求。由于模拟器是短暂的,要重置模拟器,应销毁旧的模拟器并创建一个新的。
- step(amount: float) None
Execute one step or SAOE.
- 参数:
amount -- The amount you wish to deal. The simulator doesn't guarantee all the amount to be successfully dealt.
- history_exec: pd.DataFrame
所有可能时间点的执行历史记录。可用列参见:class:SAOEMetrics。 索引为``datetime``。
- history_steps: pd.DataFrame
每个步骤的仓位。第一步之前的仓位也会被记录。 可用列参见:class:SAOEMetrics。 索引为``datetime``,即每个步骤的**开始**时间。
- metrics: SAOEMetrics | None
指标。仅在完成后可用。
- twap_price: float
用于计算价格优势的价格。 定义为订单开始时间到结束时间期间的平均价格。
- ticks_index: pd.DatetimeIndex
当日所有可用的tick(不限于订单)。
- ticks_for_order: pd.DatetimeIndex
可用于交易的tick(按订单切片)。
- class qlib.rl.order_execution.SAOEStateAdapter(order: Order, trade_decision: BaseTradeDecision, executor: BaseExecutor, exchange: Exchange, ticks_per_step: int, backtest_data: IntradayBacktestData, data_granularity: int = 1)
Maintain states of the environment. SAOEStateAdapter accepts execution results and update its internal state according to the execution results with additional information acquired from executors & exchange. For example, it gets the dealt order amount from execution results, and get the corresponding market price / volume from exchange.
Example usage:
adapter = SAOEStateAdapter(...) adapter.update(...) state = adapter.saoe_state
- __init__(order: Order, trade_decision: BaseTradeDecision, executor: BaseExecutor, exchange: Exchange, ticks_per_step: int, backtest_data: IntradayBacktestData, data_granularity: int = 1) None
- generate_metrics_after_done() None
Generate metrics once the upper level execution is done
- class qlib.rl.order_execution.SAOEMetrics
SAOE(单资产订单执行)的指标数据,可累计计算一个"周期"内的指标。 可以按天累计,或按时间段(如30分钟)累计,或每分钟单独计算。
警告
类型提示是针对单个元素的,但很多时候这些指标可以是向量化的。 例如,``market_volume``可以是一个浮点数列表(或ndarray)而不仅是单个浮点数。
- stock_id: str
该记录的股票ID。
- datetime: Timestamp | DatetimeIndex
该记录的时间戳(在数据框中作为索引)。
- direction: int
订单方向。0表示卖出,1表示买入。
- market_volume: ndarray | float
该时间段内的(总)市场成交量。
- market_price: ndarray | float
成交价格。如果是时间段,则为该时间段内的平均市场成交价。
- amount: ndarray | float
策略计划交易的总量(成交量)。
- inner_amount: ndarray | float
下层策略计划交易的总量 (可能大于amount,例如为了确保ffr)。
- deal_amount: ndarray | float
实际生效的交易量(必须小于inner_amount)。
- trade_price: ndarray | float
该策略的平均成交价格。
- trade_value: ndarray | float
交易总价值。在简单模拟中,trade_value = deal_amount * price。
- position: ndarray | float
该"周期"后剩余的持仓量。
- ffr: ndarray | float
已完成每日订单的百分比。
- pa: ndarray | float
与基准(即使用基准市场价格交易)相比的价格优势。 基准是使用TWAP策略执行该订单时的交易价格。 请注意这里可能存在数据泄漏。 单位为BP(基点,1/10000)。
- class qlib.rl.order_execution.SAOEState(order: Order, cur_time: pd.Timestamp, cur_step: int, position: float, history_exec: pd.DataFrame, history_steps: pd.DataFrame, metrics: SAOEMetrics | None, backtest_data: BaseIntradayBacktestData, ticks_per_step: int, ticks_index: pd.DatetimeIndex, ticks_for_order: pd.DatetimeIndex)
SAOE(单资产订单执行)模拟器的状态数据结构。
- backtest_data: BaseIntradayBacktestData
状态中包含回测数据。 实际上,目前只需要该数据的时间索引。 包含完整数据是为了支持依赖原始数据的算法(如VWAP)的实现。 解释器可以按需使用这些数据,但应注意避免泄漏未来数据。
- cur_step: int
当前步骤,例如0。
- cur_time: pd.Timestamp
30。
- Type:
当前时间,例如9
- history_exec: pd.DataFrame
参见
SingleAssetOrderExecution.history_exec
。
- history_steps: pd.DataFrame
参见
SingleAssetOrderExecution.history_steps
。
- metrics: SAOEMetrics | None
每日指标,仅在交易处于"完成"状态时可用。
- order: Order
正在处理的订单。
- position: float
当前剩余待执行的交易量。
- ticks_for_order: pd.DatetimeIndex
46, ..., 14:44]。
- Type:
按订单切片的交易tick,例如[9
- Type:
45, 9
- ticks_index: pd.DatetimeIndex
31, ..., 14:59]。
- Type:
全天的交易tick,未按订单切片(在数据中定义)。例如[9
- Type:
30, 9
- ticks_per_step: int
每个步骤包含多少个tick。
- class qlib.rl.order_execution.SAOEStrategy(policy: BasePolicy, outer_trade_decision: BaseTradeDecision | None = None, level_infra: LevelInfrastructure | None = None, common_infra: CommonInfrastructure | None = None, data_granularity: int = 1, **kwargs: Any)
RL-based strategies that use SAOEState as state.
- __init__(policy: BasePolicy, outer_trade_decision: BaseTradeDecision | None = None, level_infra: LevelInfrastructure | None = None, common_infra: CommonInfrastructure | None = None, data_granularity: int = 1, **kwargs: Any) None
- 参数:
policy -- RL policy for generate action
- generate_trade_decision(execute_result: list | None = None) BaseTradeDecision | Generator[Any, Any, BaseTradeDecision]
For SAOEStrategy, we need to update the self._last_step_range every time a decision is generated. This operation should be invisible to developers, so we implement it in generate_trade_decision() The concrete logic to generate decisions should be implemented in _generate_trade_decision(). In other words, all subclass of SAOEStrategy should overwrite _generate_trade_decision() instead of generate_trade_decision().
- post_exe_step(execute_result: list | None) None
A hook for doing sth after the corresponding executor finished its execution.
- 参数:
execute_result -- the execution result
- post_upper_level_exe_step() None
A hook for doing sth after the upper level executor finished its execution (for example, finalize the metrics collection).
- class qlib.rl.order_execution.ProxySAOEStrategy(outer_trade_decision: BaseTradeDecision | None = None, level_infra: LevelInfrastructure | None = None, common_infra: CommonInfrastructure | None = None, **kwargs: Any)
Proxy strategy that uses SAOEState. It is called a 'proxy' strategy because it does not make any decisions by itself. Instead, when the strategy is required to generate a decision, it will yield the environment's information and let the outside agents to make the decision. Please refer to _generate_trade_decision for more details.
- __init__(outer_trade_decision: BaseTradeDecision | None = None, level_infra: LevelInfrastructure | None = None, common_infra: CommonInfrastructure | None = None, **kwargs: Any) None
- 参数:
policy -- RL policy for generate action
- class qlib.rl.order_execution.SAOEIntStrategy(policy: dict | BasePolicy, state_interpreter: dict | StateInterpreter, action_interpreter: dict | ActionInterpreter, network: dict | Module | None = None, outer_trade_decision: BaseTradeDecision | None = None, level_infra: LevelInfrastructure | None = None, common_infra: CommonInfrastructure | None = None, **kwargs: Any)
(SAOE)state based strategy with (Int)preters.
- __init__(policy: dict | BasePolicy, state_interpreter: dict | StateInterpreter, action_interpreter: dict | ActionInterpreter, network: dict | Module | None = None, outer_trade_decision: BaseTradeDecision | None = None, level_infra: LevelInfrastructure | None = None, common_infra: CommonInfrastructure | None = None, **kwargs: Any) None
- 参数:
policy -- RL policy for generate action
工具
- class qlib.rl.utils.LogLevel(value)
强化学习训练的日志级别。 每个日志级别的处理行为取决于:class:`LogWriter`的实现。
- DEBUG = 10
仅在调试模式下查看指标。
- PERIODIC = 20
定期查看指标。
- INFO = 30
重要日志消息。
- CRITICAL = 40
LogWriter应始终处理CRITICAL消息
- class qlib.rl.utils.DataQueue(dataset: Sequence[T], repeat: int = 1, shuffle: bool = True, producer_num_workers: int = 0, queue_maxsize: int = 0)
主进程(生产者)生成数据并存储在队列中。 子进程(消费者)可以从队列中获取数据点。 数据点通过从``dataset``读取项生成。
:class:`DataQueue`是临时性的。当``repeat``耗尽时, 必须创建一个新的DataQueue。
更多背景请参阅:class:`qlib.rl.utils.FiniteVectorEnv`文档。
参数
- dataset
从中读取数据的数据集。必须实现``__len__``和``__getitem__``。
- repeat
数据点迭代次数。使用``-1``表示无限迭代。
- shuffle
如果为True,项将以随机顺序读取。
- producer_num_workers
数据加载的并发工作线程数。
- queue_maxsize
队列阻塞前可放入的最大项数。
示例
>>> data_queue = DataQueue(my_dataset) >>> with data_queue: ... ...
在工作进程中:
>>> for data in data_queue: ... print(data)
- __init__(dataset: Sequence[T], repeat: int = 1, shuffle: bool = True, producer_num_workers: int = 0, queue_maxsize: int = 0) None
- class qlib.rl.utils.EnvWrapper(simulator_fn: Callable[[...], Simulator[InitialStateType, StateType, ActType]], state_interpreter: StateInterpreter[StateType, ObsType], action_interpreter: ActionInterpreter[StateType, PolicyActType, ActType], seed_iterator: Iterable[InitialStateType] | None, reward_fn: Reward | None = None, aux_info_collector: AuxiliaryInfoCollector[StateType, Any] | None = None, logger: LogCollector | None = None)
基于Qlib的RL环境,继承自``gym.Env``。 组件包装器,包含模拟器、状态解释器、动作解释器和奖励函数。
这是RL训练中模拟器-解释器-策略框架的体现。 除策略外的所有组件需要组装成一个称为"环境"的对象。 "环境"被复制到多个工作进程,在tianshou实现中, 单个策略(agent)与一批环境交互。
参数
- simulator_fn
模拟器工厂函数。 当``seed_iterator``存在时,工厂函数接受一个参数(种子/初始状态), 否则不接受参数。
- state_interpreter
状态-观测转换器。
- action_interpreter
策略-模拟器动作转换器。
- seed_iterator
种子迭代器。借助:class:qlib.rl.utils.DataQueue, 不同进程的环境工作器可以共享一个``seed_iterator``。
- reward_fn
接受StateType并返回浮点数的可调用对象(至少单智能体情况下)。
- aux_info_collector
收集辅助信息,在MARL中可能有用。
- logger
日志收集器,收集的日志通过``env.step()``返回值传回主进程。
属性
- statusEnvWrapperStatus
状态指示器,所有术语采用*RL语言*。 当用户关心RL侧数据时可以使用。 没有轨迹时可能为None。
- __init__(simulator_fn: Callable[[...], Simulator[InitialStateType, StateType, ActType]], state_interpreter: StateInterpreter[StateType, ObsType], action_interpreter: ActionInterpreter[StateType, PolicyActType, ActType], seed_iterator: Iterable[InitialStateType] | None, reward_fn: Reward | None = None, aux_info_collector: AuxiliaryInfoCollector[StateType, Any] | None = None, logger: LogCollector | None = None) None
- render(mode: str = 'human') None
Compute the render frames as specified by render_mode attribute during initialization of the environment.
The set of supported modes varies per environment. (And some third-party environments may not support rendering at all.) By convention, if render_mode is:
None (default): no render is computed.
human: render return None. The environment is continuously rendered in the current display or terminal. Usually for human consumption.
rgb_array: return a single frame representing the current state of the environment. A frame is a numpy.ndarray with shape (x, y, 3) representing RGB values for an x-by-y pixel image.
rgb_array_list: return a list of frames representing the states of the environment since the last reset. Each frame is a numpy.ndarray with shape (x, y, 3), as with rgb_array.
ansi: Return a strings (str) or StringIO.StringIO containing a terminal-style text representation for each time step. The text can include newlines and ANSI escape sequences (e.g. for colors).
备注
Make sure that your class's metadata 'render_modes' key includes the list of supported modes. It's recommended to call super() in implementations to use the functionality of this method.
- reset(**kwargs: Any) ObsType
尝试从状态队列获取状态,并用此状态初始化模拟器。 如果队列耗尽,则生成无效(nan)观测值。
- step(policy_action: PolicyActType, **kwargs: Any) Tuple[ObsType, float, bool, InfoDict]
环境步骤。
结合代码和注释查看此处发生的事件序列。
- class qlib.rl.utils.LogCollector(min_loglevel: int | LogLevel = LogLevel.PERIODIC)
日志首先在每个环境工作进程中收集,然后在向量环境的中央线程中聚合流。
在:class:`LogCollector`中,每个指标都被添加到一个字典中,需要在每步调用``reset()``清空。 该字典通过``env.step()``中的``info``发送,并由向量环境中的:class:`LogWriter`解码。
``min_loglevel``用于优化目的:避免网络/管道中的过多流量。
- add_any(name: str, obj: Any, loglevel: int | LogLevel = LogLevel.PERIODIC) None
记录任意类型的对象。
由于是"任意"对象,唯一能接受它的LogWriter是pickle。 因此,pickle必须能够序列化它。
- add_array(name: str, array: ndarray | DataFrame | Series, loglevel: int | LogLevel = LogLevel.PERIODIC) None
添加带名称的数组到日志中。
- add_scalar(name: str, scalar: Any, loglevel: int | LogLevel = LogLevel.PERIODIC) None
添加带名称的标量到日志内容中。 标量将被转换为浮点数。
- add_string(name: str, string: str, loglevel: int | LogLevel = LogLevel.PERIODIC) None
添加带名称的字符串到日志内容中。
- reset() None
清除所有已收集的内容。
- class qlib.rl.utils.LogWriter(loglevel: int | LogLevel = LogLevel.PERIODIC)
日志写入器基类,由有限环境在每次重置和步骤时触发。
如何处理特定日志取决于子类化:class:`LogWriter`的实现。 一般原则是,它应该处理高于其日志级别(包括)的日志,并丢弃不可接受的日志。例如,控制台日志器显然无法处理图像。
- static aggregation(array: Sequence[Any], name: str | None = None) Any
Aggregation function from step-wise to episode-wise.
If it's a sequence of float, take the mean. Otherwise, take the first element.
If a name is specified and,
if it's
reward
, the reduction will be sum.
- clear()
Clear all the metrics for a fresh start. To make the logger instance reusable.
- load_state_dict(state_dict: dict) None
Load the states of current logger from a dict.
- log_episode(length: int, rewards: List[float], contents: List[Dict[str, Any]]) None
在每个轨迹结束时触发。
参数
- length
此轨迹的长度。
- rewards
本回合每步奖励的列表。
- contents
每步的日志内容。
- on_env_all_done() None
所有操作完成,进行清理工作。
- on_env_all_ready() None
当所有环境准备就绪可以运行时调用。 通常,日志器应在此处重置。
- on_env_reset(env_id: int, _: ObsType) None
有限环境的回调函数。
重置回合统计信息。由于tianshou的限制<https://github.com/thu-ml/tianshou/issues/605>,此处不记录任何任务特定信息。
- on_env_step(env_id: int, obs: ObsType, rew: float, done: bool, info: InfoDict) None
Callback for finite env, on each step.
- state_dict() dict
Save the states of the logger to a dict.
- episode_count: int
回合计数器。
- step_count: int
步数计数器。
- global_step: int
全局步数计数器。在``clear``中不会被清除。
- global_episode: int
全局回合计数器。在``clear``中不会被清除。
- active_env_ids: Set[int]
向量环境中活跃的环境ID集合。
- episode_lengths: Dict[int, int]
从环境ID到回合长度的映射。
- episode_rewards: Dict[int, List[float]]
从环境ID到回合总奖励的映射。
- episode_logs: Dict[int, list]
从环境ID到回合日志的映射。
- qlib.rl.utils.vectorize_env(env_factory: Callable[[...], Env], env_type: Literal['dummy', 'subproc', 'shmem'], concurrency: int, logger: LogWriter | List[LogWriter]) FiniteVectorEnv
创建向量环境的辅助函数。可用于替换常规的VectorEnv。
例如,如果你曾经这样写::
DummyVectorEnv([lambda: gym.make(task) for _ in range(env_num)])
现在你可以替换为::
finite_env_factory(lambda: gym.make(task), "dummy", env_num, my_logger)
通过这样的替换,你将获得两个额外的功能(与普通VectorEnv相比):
向量环境将检查NaN观测值,并在发现时终止工作进程。 有关原因,请参见:class:FiniteVectorEnv。
一个显式收集环境工作进程日志的日志记录器。
参数
- env_factory
用于实例化单个``gym.Env``的可调用对象。 所有并发工作进程将使用相同的``env_factory``。
- env_type
dummy或subproc或shmem。对应于 tianshou中的并行方式。
- concurrency
并发环境工作进程数。
- logger
日志记录器。
警告
请勿在此处为``env_factory``使用lambda表达式,因为这可能会创建不正确共享的实例。
不要这样做::
vectorize_env(lambda: EnvWrapper(...), ...)
请这样做::
def env_factory(): ... vectorize_env(env_factory, ...)
- class qlib.rl.utils.ConsoleWriter(log_every_n_episode: int = 20, total_episodes: int | None = None, float_format: str = ':.4f', counter_format: str = ':4d', loglevel: int | LogLevel = LogLevel.PERIODIC)
定期将日志消息写入控制台。
它为每个指标跟踪一个平均计量器,即从上次``clear()``到现在的平均值。 每个指标的显示格式为``<名称> <最新值> (<平均值>)``。
非单一数字指标会自动跳过。
- prefix: str
Prefix can be set via
writer.prefix
.
- __init__(log_every_n_episode: int = 20, total_episodes: int | None = None, float_format: str = ':.4f', counter_format: str = ':4d', loglevel: int | LogLevel = LogLevel.PERIODIC) None
- clear() None
Clear all the metrics for a fresh start. To make the logger instance reusable.
- class qlib.rl.utils.CsvWriter(output_dir: Path, loglevel: int | LogLevel = LogLevel.PERIODIC)
将所有回合指标转储到``result.csv``文件中。
这不是正确的实现,仅用于第一次迭代。
- clear() None
Clear all the metrics for a fresh start. To make the logger instance reusable.
- log_episode(length: int, rewards: List[float], contents: List[Dict[str, Any]]) None
在每个轨迹结束时触发。
参数
- length
此轨迹的长度。
- rewards
本回合每步奖励的列表。
- contents
每步的日志内容。
- on_env_all_done() None
所有操作完成,进行清理工作。
- class qlib.rl.utils.EnvWrapperStatus
EnvWrapper中使用的状态数据结构。 这里的字段采用RL语义。 例如,``obs``表示输入策略的观测值, ``action``表示策略返回的原始动作。
- class qlib.rl.utils.LogBuffer(callback: Callable[[bool, bool, LogBuffer], None], loglevel: int | LogLevel = LogLevel.PERIODIC)
将所有数字保存在内存中。
无法聚合的对象(如字符串、张量、图像)不能存储在缓冲区中。 要持久化它们,请使用:class:PickleWriter。
每次日志缓冲区收到新指标时,都会触发回调, 这在训练器内部跟踪指标时非常有用。
参数
- callback
接收三个参数的回调函数:
on_episode: 是否在回合结束时调用
on_collect: 是否在收集结束时调用
log_buffer: :class:`LogBbuffer`对象
不需要返回值。
- __init__(callback: Callable[[bool, bool, LogBuffer], None], loglevel: int | LogLevel = LogLevel.PERIODIC)
- clear()
Clear all the metrics for a fresh start. To make the logger instance reusable.
- collect_metrics() dict[str, float]
Retrieve the aggregated metrics of the latest collect.
- episode_metrics() dict[str, float]
Retrieve the numeric metrics of the latest episode.
- load_state_dict(state_dict: dict) None
Load the states of current logger from a dict.
- log_episode(length: int, rewards: list[float], contents: list[dict[str, Any]]) None
在每个轨迹结束时触发。
参数
- length
此轨迹的长度。
- rewards
本回合每步奖励的列表。
- contents
每步的日志内容。
- on_env_all_done() None
所有操作完成,进行清理工作。
- state_dict() dict
Save the states of the logger to a dict.