qteasy.database 源代码

# coding=utf-8
# ======================================
# File:     database.py
# Author:   Jackie PENG
# Contact:  jackie.pengzhao@gmail.com
# Created:  2020-11-29
# Desc:
#   Definition of DataSource class, managing
# local data storage and acquiring.
# ======================================

import os
import pandas as pd
import numpy as np
import warnings

from tqdm import tqdm
from os import path
from typing import Union, Callable, Any

from functools import lru_cache

from .datatables import (
    AVAILABLE_DATA_FILE_TYPES,
    TABLE_MASTERS,
    get_table_master,
)

from .utilfuncs import (
    str_to_list,
    regulate_date_format,
    human_file_size,
    human_units,
    date_to_month_format,
    date_to_quarter_format,
    sanitize_filename,
)

from .datatables import (
    get_built_in_table_schema,
    set_primary_key_index,
    set_primary_key_frame,
)


class DataSource:
    """Unified entry point for managing local historical data (files or a database).

    A DataSource talks to local files or to a database and offers one interface
    for reading, writing and summarising historical data tables, so that the
    resulting data structures can be consumed by HistoryPanel and the
    higher-level API. When a table is missing, DataSource does not download
    data by itself; maintenance is done together with functions such as
    ``refill_data_source``.

    For the supported file types, database kinds and further initialisation
    details, see the documentation chapter about DataSource and local data
    sources.

    Examples
    --------
    The example only shows the class name (stable output); real usage needs
    your own local data directory or database connection parameters:

    >>> import qteasy as qt
    >>> qt.DataSource.__name__
    'DataSource'
    """

    def __init__(self,
                 source_type: str = 'file',
                 file_type: str = 'csv',
                 file_loc: str = 'data/',
                 host: str = 'localhost',
                 port: int = 3306,
                 user: str = None,
                 password: str = None,
                 db_name: str = 'qt_db',
                 allow_drop_table: bool = False):
        """Create a DataSource and decide how local data is stored.

        Depending on ``source_type`` either a MySQL connection pool is created
        or a local data directory is prepared. If the database cannot be set
        up, a RuntimeWarning is issued and the object falls back to csv files.

        Parameters
        ----------
        source_type: str, Default: file
            Type of the data source:
            - db/database: data is stored in a mysql database
            - file: data is stored in local files
        file_type: str, {'csv', 'hdf', 'hdf5', 'feather', 'fth'}, Default: csv
            File format used when the source type is file:
            - csv: plain text, readable by Excel, but large and slow
            - hdf/hdf5: pytables based files, needs pytables installed
            - feather/fth: light-weight files, needs pyarrow installed
        file_loc: str, Default: data/
            Directory (joined onto the qteasy root path) holding the data files
        host: str, Default: localhost
            Database host, used when the source type is database
        port: int, Default: 3306
            Database port, used when the source type is database
        user: str, Default: None
            Database user name, used when the source type is database
        password: str, Default: None
            Database password, used when the source type is database
        db_name: str, Default: 'qt_db'
            Database name, used when the source type is database
        allow_drop_table: bool, Default: False
            Whether dropping data tables is allowed; stored on the instance

        Raises
        ------
        TypeError
            If ``source_type`` or ``file_type`` is not a string
        ValueError
            If ``source_type`` is not one of file / database / db
        AssertionError
            If ``file_type`` is not in AVAILABLE_DATA_FILE_TYPES
        SystemError
            If the local data directory can not be created

        Notes
        -----
        Missing optional packages (pymysql, dbutils, tables, pyarrow) do not
        raise; they trigger a RuntimeWarning and a fallback to csv files.
        """
        if not isinstance(source_type, str):
            raise TypeError(f'source type should be a string, got {type(source_type)} instead.')
        kind = source_type.lower()
        if kind not in ('file', 'database', 'db'):
            raise ValueError('invalid source_type')

        self._table_list = set()

        if kind in ('db', 'database'):
            # try to create pymysql connections; any failure falls back to csv files
            self.source_type = 'db'
            fallback_reason = None
            try:
                # optional packages to be imported
                import pymysql
                from dbutils.pooled_db import PooledDB

                # set up connection parameters to database
                assert isinstance(port, int), f'port should be int type, got {type(port)} instead!'
                assert user is not None, 'Missing user name for database connection'
                assert password is not None, 'Missing password for database connection'

                self.pool = PooledDB(
                        creator=pymysql,  # module used to create database connections
                        mincached=3,  # connections created at start-up, 0 means none
                        maxconnections=5,  # max connections in the pool, 0/None means unlimited
                        blocking=True,  # wait (True) or raise (False) when the pool is exhausted
                        host=host,
                        port=port,
                        user=user,
                        password=password,
                        database=db_name,
                )
                self.connection_type = f'mysql://{host}@{port}/{db_name}'
                self.host = host
                self.port = port
                self.db_name = db_name
                self.file_type = None
                self.file_path = None
                self.__user__ = user
                self.__password__ = password
            except ImportError as e:
                fallback_reason = (
                    f'{str(e)}'
                    f'Install pymysql or DButils: \n$ pip install pymysql\n$ pip install dbutils\n'
                    f'Can not set data source type to "db", will fall back to csv file'
                )
            except AssertionError as e:
                fallback_reason = (
                    f'Failed setting up mysql connection: {str(e)}\n'
                    f'Can not set data source type to "db", will fall back to csv file'
                )
            except Exception as e:
                fallback_reason = (
                    f'Mysql connection failed: {str(e)}\n'
                    f'Can not set data source type to "db", will fall back to csv file'
                )

            if fallback_reason is not None:
                warnings.warn(fallback_reason, RuntimeWarning, stacklevel=2)
                kind = 'file'
                file_type = 'csv'

        if kind == 'file':
            from qteasy import QT_ROOT_PATH

            # Keep the value originally requested by the user so that warnings
            # can quote it; only hdf and fth are physically implemented, hdf5
            # and feather are aliases of them.
            file_type_original = file_type
            if not isinstance(file_type, str):
                raise TypeError(f'file type should be a string, got {type(file_type)} instead!')
            file_type = file_type.lower()
            assert file_type in AVAILABLE_DATA_FILE_TYPES, (
                f'Wrong file type! supported file types are {AVAILABLE_DATA_FILE_TYPES}'
            )
            aliases = {'hdf5': 'hdf', 'feather': 'fth'}
            normalized_file_type = aliases.get(file_type, file_type)

            # Make sure the directory is usable first: even when an optional
            # dependency is missing and csv is used instead, file_path must be
            # initialised and the directory must exist.
            try:
                self.file_path = path.join(QT_ROOT_PATH, file_loc)
                os.makedirs(self.file_path, exist_ok=True)
            except Exception as e:
                raise SystemError(
                        f"{str(e)}, Failed creating data directory '{file_loc}' in qt root path, "
                        f"please check your input."
                )

            # Check the optional dependency; fall back to csv when it is missing.
            try:
                if normalized_file_type == 'hdf':
                    import tables  # optional dependency
                    file_type = 'hdf'
                elif normalized_file_type == 'fth':
                    import pyarrow  # optional dependency
                    file_type = 'fth'
                else:
                    file_type = 'csv'
            except ImportError:
                if normalized_file_type == 'hdf':
                    msg = (
                        f"Missing optional dependency 'tables' for datasource file type '{file_type_original}'. "
                        f"Please install pytables: $ conda install pytables\n"
                        f'Fallback to csv is applied.'
                    )
                elif normalized_file_type == 'fth':
                    msg = (
                        f"Missing optional dependency 'pyarrow' for datasource file type '{file_type_original}'. "
                        f"Please install pyarrow: $ conda install pyarrow\n"
                        f'Fallback to csv is applied.'
                    )
                else:
                    msg = (
                        f"Missing optional dependency for datasource file type '{file_type_original}'. "
                        f'Fallback to csv is applied.'
                    )
                warnings.warn(msg, RuntimeWarning, stacklevel=2)
                file_type = 'csv'

            self.source_type = 'file'
            self.file_type = file_type
            self.file_loc = file_loc
            self.connection_type = f'file://{file_type}@qt_root/{file_loc}'
            self.host = None
            self.port = None
            self.db_name = None
            self.__user__ = None
            self.__password__ = None

        self._allow_drop_table = allow_drop_table
        # (conn, cursor) of the transaction currently in progress, None when idle
        self._active_tx = None

    @property
    def tables(self) -> list:
        """List of the tables recorded in this data source's table set."""
        return list(self._table_list)

    @property
    def all_tables(self) -> list:
        """Names of all tables listed in the table master."""
        return get_table_master().index.to_list()

    @property
    def all_sys_tables(self) -> list:
        """Names of all tables whose ``table_usage`` is 'sys'."""
        master = get_table_master()
        return master[master['table_usage'] == 'sys'].index.to_list()
[文档] def none_sys_tables(self) -> list: """ 获取所有非系统数据表的清单""" tables = get_table_master() return tables[tables['table_usage'] != 'sys'].index.to_list()
@property def all_data_tables(self) -> list: """ 获取所有历史数据表(不包括调整表)的清单""" tables = get_table_master() return tables[~tables['table_usage'].isin(['sys', 'adj'])].index.to_list() @property def all_basic_tables(self) -> list: """ 获取所有基础数据表的清单""" tables = get_table_master() return tables[tables['table_usage'] == 'basics'].index.to_list() @property def allow_drop_table(self) -> bool: """ 获取是否允许删除数据表""" return self._allow_drop_table @allow_drop_table.setter def allow_drop_table(self, value: bool): """ 设置是否允许删除数据表""" if not isinstance(value, bool): err = TypeError(f'allow_drop_table should be a boolean, got {type(value)} instead!') raise err self._allow_drop_table = value def __repr__(self): if self.source_type == 'db': return f'DataSource(\'db\', \'{self.host}\', {self.port})' elif self.source_type == 'file': return f'DataSource(\'file\', \'{self.file_type}\', \'{self.file_loc}\')' else: return def __str__(self): return self.connection_type
[文档] def info(self): """ 格式化打印database对象的各种主要信息 Returns ------- """ if self.source_type == 'file': print(f'DataSource Info: \n' f'{"=" * 40}\n' f'{"Source Type":<20}: {self.source_type}\n' f'{"File Type":<20}: {self.file_type}\n' f'{"File Location":<20}: {self.file_loc}\n') elif self.source_type == 'db': print(f'DataSource Info: \n' f'{"=" * 40}\n' f'{"Source Type":<20}: {self.source_type}\n' f'{"Host":<20}: {self.host}\n' f'{"Port":<20}: {self.port}\n' f'{"User":<20}: {self.__user__}\n' f'{"Database":<20}: {self.db_name}\n')
[文档] def overview(self, tables=None, print_out=True, include_sys_tables=False) -> pd.DataFrame: """ 以表格形式列出所有数据表的当前数据状态 Parameters ---------- tables: str or list of str, Default None 指定要列出的数据表,如果为None则列出所有数据表 print_out: bool, Default True 是否打印数据表总揽 include_sys_tables: bool, Default False 是否包含系统表 Returns ------- pd.DataFrame, 包含所有数据表的数据状态 """ all_tables = get_table_master() if not include_sys_tables: all_tables = all_tables[all_tables['table_usage'] != 'sys'] all_table_names = all_tables.index if tables is not None: if isinstance(tables, str): tables = str_to_list(tables) if not isinstance(tables, list): err = TypeError(f'tables should be a list of str, got {type(tables)} instead!') raise err all_table_names = [table_name for table_name in all_table_names if table_name in tables] all_info = [] print('Analyzing local data source tables... depending on size of tables, it may take a few minutes') total_table_count = len(all_table_names) completed_reading_count = 0 for table_name in tqdm(all_table_names, desc='Analyzing tables', total=total_table_count): all_info.append(self.get_table_info(table_name, verbose=False, print_info=False, human=True).values()) completed_reading_count += 1 print(f'Analyzing completed!') all_info = pd.DataFrame(all_info, columns=['table', 'has_data', 'size', 'records', 'pk1', 'records1', 'min1', 'max1', 'pk2', 'records2', 'min2', 'max2']) all_info.index = all_info['table'] all_info.drop(columns=['table'], inplace=True) if print_out: info_to_print = all_info.loc[all_info.has_data == True][['has_data', 'size', 'records', 'min2', 'max2']] print(f'\nFinished analyzing datasource: \n{self}\n' f'{len(info_to_print)} table(s) out of {len(all_info)} contain local data as summary below, ' f'to view complete list, print returned DataFrame\n' f'{"tables with local data":=^84}') print(info_to_print.to_string(columns=['has_data', 'size', 'records', 'min2', 'max2'], header=['Has_data', 'Size_on_disk', 'Record_count', 'Record_start', 'Record_end'], 
justify='center' ) ) return all_info
# 文件操作层函数,只操作文件,不修改数据 def _get_file_path_name(self, file_name): """获取完整文件路径名(文件名经跨平台安全规范化)。""" if self.source_type == 'db': err = RuntimeError('can not check file system while source type is "db"') raise err if not isinstance(file_name, str): err = TypeError(f'file_name name must be a string, {file_name} is not a valid input!') raise err safe_filename = sanitize_filename(file_name + '.' + self.file_type) file_path_name = path.join(self.file_path, safe_filename) return file_path_name def _file_exists(self, file_name): """ 检查文件是否已存在 Parameters ---------- file_name: 需要检查的文件名(不含扩展名) Returns ------- Boolean: 文件存在时返回真,否则返回假 """ file_path_name = self._get_file_path_name(file_name) return path.exists(file_path_name) def _write_file(self, df, file_name): """ 将df写入本地文件,在把文件写入文件之前,需要将primary key写入index,使用 set_primary_key_index()函数 Parameters ---------- df: 待写入文件的DataFrame,primary key 为index file_name: 本地文件名(不含扩展名) Returns ------- str: file_name 如果数据保存成功,返回完整文件路径名称 """ file_path_name = self._get_file_path_name(file_name) if self.file_type == 'csv': df.to_csv(file_path_name, encoding='utf-8') elif self.file_type == 'fth': df.reset_index().to_feather(file_path_name) elif self.file_type == 'hdf': df.to_hdf(file_path_name, key='df') else: # for some unexpected cases err = TypeError(f'Invalid file type: {self.file_type}') raise err return len(df) def _read_file(self, file_name, primary_key, pk_dtypes, share_like_pk=None, shares=None, date_like_pk=None, start=None, end=None, chunk_size=50000): """ 从文件中读取DataFrame,当文件类型为csv时,支持分块读取且完成数据筛选 Parameters ---------- file_name: str 文件名 primary_key: list of str 用于生成primary_key index 的主键 pk_dtypes: list of str primary_key的数据类型 share_like_pk: str 用于按值筛选数据的主键 shares: list of str 用于筛选数据的主键的值 date_like_pk: str 用于按日期筛选数据的主键 start: datetime-like 用于按日期筛选数据的起始日期 end: datetime-like 用于按日期筛选数据的结束日期 chunk_size: int 分块读取csv大文件时的分块大小 Returns ------- DataFrame:从文件中读取的DataFrame,如果数据有主键,将主键设置为df的index """ # TODO: 历史数据表的规模较大,如果数据存储在数据库中,读取和存储时 # 
没有问题,但是如果数据存储在文件中,需要优化存储和读取过程 # ,以便提高效率。目前优化了csv文件的读取,通过分块读取提高 # csv文件的读取效率,其他文件系统的读取还需要进一步优化 file_path_name = self._get_file_path_name(file_name) if not self._file_exists(file_name): # 如果文件不存在,则返回空的DataFrame return pd.DataFrame() if date_like_pk is not None: start_dt = pd.to_datetime(start) end_dt = pd.to_datetime(end) pk_idx = primary_key.index(date_like_pk) if date_like_pk in primary_key else None is_datetime_pk = (pk_idx is not None) and (pk_dtypes[pk_idx] == 'datetime') if is_datetime_pk: # datetime 主键应保留时分秒;若仅给了日期,则自动补齐到当天收尾时刻 if end_dt.hour == 0 and end_dt.minute == 0 and end_dt.second == 0: end_dt = end_dt.replace(hour=23, minute=59, second=59) start = start_dt.strftime('%Y-%m-%d %H:%M:%S') end = end_dt.strftime('%Y-%m-%d %H:%M:%S') else: start = start_dt.strftime('%Y-%m-%d') end = end_dt.strftime('%Y-%m-%d') if self.file_type == 'csv': # 这里针对csv文件进行了优化,通过分块读取文件,避免当文件过大时导致读取异常 try: df_reader = pd.read_csv(file_path_name, chunksize=chunk_size) except FileNotFoundError: raise FileNotFoundError(f'File {file_name} not found!') except FileExistsError: raise FileExistsError(f'File {file_name} exists but can not be read!') except Exception as e: err = RuntimeError(f'{e}, file reading error encountered.') raise err df_picker = (chunk for chunk in df_reader) if (share_like_pk is not None) and (date_like_pk is not None): df_picker = (chunk.loc[(chunk[share_like_pk].isin(shares)) & (chunk[date_like_pk] >= start) & (chunk[date_like_pk] <= end)] for chunk in df_reader) elif (share_like_pk is None) and (date_like_pk is not None): df_picker = (chunk.loc[(chunk[date_like_pk] >= start) & (chunk[date_like_pk] <= end)] for chunk in df_reader) elif (share_like_pk is not None) and (date_like_pk is None): df_picker = (chunk.loc[(chunk[share_like_pk].isin(shares))] for chunk in df_reader) df = pd.concat(df_picker) set_primary_key_index(df, primary_key=primary_key, pk_dtypes=pk_dtypes) return df if self.file_type == 'hdf': # TODO: hdf5/feather的大文件读取尚未优化 try: df = 
pd.read_hdf(file_path_name, 'df') except ValueError as e: if 'pickle protocol: 5' in e.__str__(): # check when the file is written in a higher pickle protocol err = EnvironmentError(f'File {file_name} is written in a higher version of python which uses ' f'pickle protocol 5, to avoid this error, install pickle5 package and ' f're-save the file.') raise err else: err = RuntimeError(f'{e}, file reading error encountered.') raise err except Exception as e: err = RuntimeError(f'{e}, file reading error encountered.') raise err df = set_primary_key_frame(df, primary_key=primary_key, pk_dtypes=pk_dtypes) elif self.file_type == 'fth': # TODO: feather大文件读取尚未优化 try: df = pd.read_feather(file_path_name) except Exception as e: err = RuntimeError(f'{e}, file reading error encountered.') raise err else: # for some unexpected cases err = TypeError(f'Invalid file type: {self.file_type}') raise err try: # 如果self.file_type 为 hdf/fth,那么需要筛选数据 if (share_like_pk is not None) and (date_like_pk is not None): df = df.loc[(df[share_like_pk].isin(shares)) & (df[date_like_pk] >= start) & (df[date_like_pk] <= end)] elif (share_like_pk is None) and (date_like_pk is not None): df = df.loc[(df[date_like_pk] >= start) & (df[date_like_pk] <= end)] elif (share_like_pk is not None) and (date_like_pk is None): df = df.loc[(df[share_like_pk].isin(shares))] except Exception as e: import traceback traceback.print_exc() raise e set_primary_key_index(df, primary_key=primary_key, pk_dtypes=pk_dtypes) return df def _delete_file_records(self, file_name, primary_key, record_ids) -> int: """ 从文件中删除指定的记录 Parameters ---------- file_name: str 文件名 primary_key: list of str 主键 record_ids: list of int or tuple of int 待删除的记录的主键值 Returns ------- rows_deleted: int 删除的记录数 """ # check that all record_ids are integers if not all(isinstance(record_id, int) for record_id in record_ids): err = TypeError(f'All record_ids must be integers, got {record_ids} instead!') raise err # read all data from file into a dataframe 
primary_key = [primary_key] if isinstance(primary_key, str) else primary_key df = self._read_file(file_name, primary_key, pk_dtypes=['int']) # check if the record_ids are in the dataframe, remove them if they are rows_deleted = 0 for record_id in record_ids: if record_id in df.index: df.drop(record_id, inplace=True) rows_deleted += 1 # write the updated dataframe back to the file, make sure primary keys are in index df = set_primary_key_frame(df, primary_key=primary_key, pk_dtypes=['int']) set_primary_key_index(df, primary_key=primary_key, pk_dtypes=['int']) self._write_file(df, file_name) return rows_deleted def _get_file_table_coverage(self, table, column, primary_key, pk_dtypes, min_max_only): """ 检查数据表文件关键列的内容,去重后返回该列的内容清单 Parameters ---------- table: str 数据表名 column: str 关键列名 primary_key: list of str 数据表的主键名称列表 pk_dtypes: list of str 数据表的主键数据类型列表 min_max_only: bool 为True时仅输出最小、最大以及总数量,False输出完整列表 Returns ------- list of str 数据表中存储的数据关键列的清单 """ if not self._file_exists(table): return list() df = self._read_file(table, primary_key, pk_dtypes) if df.empty: return list() if column in list(df.index.names): extracted_val = df.index.get_level_values(column).unique() else: extracted_val = df[column].unique() if isinstance(extracted_val[0], pd.Timestamp): extracted_val = extracted_val.strftime('%Y%m%d') res = list() if min_max_only: res.append(extracted_val.min()) res.append(extracted_val.max()) res.append(len(extracted_val)) else: res.extend(extracted_val) return list(res) def _drop_file(self, file_name): """ 删除本地文件 Parameters ---------- file_name: str 将被删除的文件名 Returns ------- None """ import os if self._file_exists(file_name): file_path_name = self._get_file_path_name(file_name) os.remove(file_path_name) def _get_file_size(self, file_name): """ 获取文件大小,输出 Parameters ---------- file_name: str 文件名 Returns ------- str representing file size """ import os file_path_name = self._get_file_path_name(file_name) try: file_size = os.path.getsize(file_path_name) return file_size 
except FileNotFoundError: return -1 except Exception as e: err = RuntimeError(f'{e}, unknown error encountered.') raise err def _get_file_rows(self, file_name): """获取csv、hdf、feather文件中数据的行数""" file_path_name = self._get_file_path_name(file_name) if self.file_type == 'csv': with open(file_path_name, 'r', encoding='utf-8') as fp: line_count = None for line_count, line in enumerate(fp): pass return line_count elif self.file_type == 'hdf': df = pd.read_hdf(file_path_name, 'df') return len(df) elif self.file_type == 'fth': df = pd.read_feather(file_path_name) return len(df) # 数据库操作层函数,只操作具体的数据表,不操作数据 def _db_open_connection(self): """从数据连接池中获取数据连接,返回con和cursor对象""" try: conn = self.pool.connection() cursor = conn.cursor() # 表示读取的数据为字典类型 return conn, cursor except Exception as e: err = RuntimeError(f'{e}, error in opening database connection') return None, None @staticmethod def _db_close_connection(conn, cursor): """关闭数据库连接""" cursor.close() conn.close() def _db_execute_one(self, sql, data=None, *, fetch_and_return=True, return_cursor=False, rollback=False) -> any: """从mysql连接池获取一个新的连接,执行一条sql语句, 返回结果并关闭连接 Parameters ---------- sql: str 需要执行的sql语句 data: tuple, optional, 需要执行的数据 fetch_and_return: bool, Default True 是否需要返回执行sql语句的结果 return_cursor: bool, Default False 是否需要同时返回cursor对象,默认不需要,如果fetch_and_return为False,则无效 rollback: bool, Default False 是否在执行sql语句出错时回滚 Returns ------- result: any 执行sql语句的结果 """ use_existing_tx = self._active_tx is not None if use_existing_tx: conn, cursor = self._active_tx else: conn, cursor = self._db_open_connection() try: rows_affected = cursor.execute(sql, data) if not use_existing_tx: conn.commit() if fetch_and_return: result = cursor.fetchall() if return_cursor: return result, cursor else: return result return rows_affected except Exception as e: if rollback and not use_existing_tx: conn.rollback() err = RuntimeError(f'{e}, error in executing sql: {sql}') raise err finally: if not use_existing_tx: self._db_close_connection(conn, cursor) 
def _db_execute_many(self, sql, data, rollback=False) -> any: """从mysql连接池获取一个新的连接,执行包含多条数据的sql语句,返回执行的结果并关闭连接 Parameters ---------- sql: str 需要执行的sql语句 data: tuple of tuple 需要执行的数据 rollback: bool, Default False 是否在执行sql语句出错时回滚 Returns ------- rows_affected: int: 执行sql语句的结果 """ use_existing_tx = self._active_tx is not None if use_existing_tx: conn, cursor = self._active_tx else: conn, cursor = self._db_open_connection() try: rows_affected = cursor.executemany(sql, data) if not use_existing_tx: conn.commit() return rows_affected except Exception as e: if rollback and not use_existing_tx: conn.rollback() err = RuntimeError(f'{e}, error in executing sql: {sql}') raise err finally: if not use_existing_tx: self._db_close_connection(conn, cursor)
[文档] def db_run_in_transaction(self, action: Callable[..., Any], *args, **kwargs) -> Any: """在 DataSource 上执行最小事务封装(DB 真事务,file 为 no-op)。 Parameters ---------- action: Callable[..., Any] 需要在事务中执行的可调用对象 *args: 传递给 ``action`` 的位置参数 **kwargs: 传递给 ``action`` 的关键字参数 Returns ------- Any ``action`` 的返回值 """ if not callable(action): err = TypeError(f'action should be callable, got {type(action)} instead') raise err # 对 file/csv/hdf/fth 提供 no-op 包装,保持接口统一但不承诺原子性 if self.source_type != 'db': return action(*args, **kwargs) # 支持嵌套调用:若已有外层事务,直接复用外层事务上下文 if self._active_tx is not None: return action(*args, **kwargs) conn, cursor = self._db_open_connection() if conn is None or cursor is None: raise RuntimeError('failed to open db connection for transaction') self._active_tx = (conn, cursor) try: result = action(*args, **kwargs) conn.commit() return result except Exception: conn.rollback() raise finally: self._active_tx = None self._db_close_connection(conn, cursor)
def _read_database(self, db_table, share_like_pk=None, shares=None, date_like_pk=None, start=None, end=None):
    """Read one database table, optionally filtered by share codes and a date range.

    Filter values are sent as query parameters rather than formatted into
    the SQL text.

    Parameters
    ----------
    db_table: str
        table to read
    share_like_pk: str, optional
        column holding the security code; the share filter is applied only
        when both this and ``shares`` are given
    shares: str or list of str, optional
        codes to keep (``WHERE share_like_pk IN (...)``); a plain string is
        treated as a single code
    date_like_pk: str, optional
        column holding the date-like key; the range filter is applied only
        when this, ``start`` and ``end`` are all given
    start: datetime like, optional
        lower bound of the ``BETWEEN`` filter
    end: datetime like, optional
        upper bound of the ``BETWEEN`` filter

    Returns
    -------
    pd.DataFrame
        rows read from the table; empty if the table does not exist or an
        empty ``shares`` collection is passed
    """
    if not self._db_table_exists(db_table):
        return pd.DataFrame()

    conditions = []
    params = []

    if (share_like_pk is not None) and (shares is not None):
        codes = [shares] if isinstance(shares, str) else list(shares)
        if not codes:
            # filtering on zero codes can not match anything
            return pd.DataFrame()
        placeholders = ', '.join(['%s'] * len(codes))
        conditions.append(f'`{share_like_pk}` IN ({placeholders})')
        params.extend(str(code) for code in codes)

    if (date_like_pk is not None) and (start is not None) and (end is not None):
        conditions.append(f'`{date_like_pk}` BETWEEN %s AND %s')
        # str() keeps the textual form the values had when formatted inline
        params.extend([str(start), str(end)])

    sql = f'SELECT * FROM `{db_table}`'
    if conditions:
        sql += '\nWHERE ' + ' AND '.join(conditions)

    res, cursor = self._db_execute_one(sql, tuple(params) if params else None, return_cursor=True)
    return pd.DataFrame(res, columns=[i[0] for i in cursor.description])


def _prepare_db_insert(self, df, db_table):
    """Build the shared part of an INSERT statement and its parameter rows.

    Parameters
    ----------
    df: pd.DataFrame
        data to insert; columns must equal the table columns, in order
    db_table: str
        target table

    Returns
    -------
    tuple
        ``(tbl_columns, insert_body, rows)`` where ``insert_body`` is
        ``"`table` (cols)\\nVALUES\\n(%s, ...)\\n"`` and ``rows`` is a tuple
        of row tuples with missing values replaced by ``None``

    Raises
    ------
    KeyError
        if the DataFrame columns do not match the table schema
    """
    tbl_columns = tuple(self._get_db_table_schema(db_table).keys())
    if list(df.columns) != list(tbl_columns):
        raise KeyError(f'df columns {df.columns.to_list()} does not fit table schema {list(tbl_columns)}')

    # cast to object first so that None is stored instead of being coerced
    # back to NaN, without having to compare pandas version strings
    cleaned = df.astype(object).where(pd.notna(df), None)
    rows = tuple(cleaned.itertuples(index=False, name=None))

    column_list = ', '.join(f'`{col}`' for col in tbl_columns)
    placeholders = ', '.join(['%s'] * len(tbl_columns))
    insert_body = f"`{db_table}` ({column_list})\nVALUES\n({placeholders})\n"
    return tbl_columns, insert_body, rows


def _write_database(self, df, db_table, primary_key):
    """Append the rows of a DataFrame to a table, ignoring duplicated keys.

    If the table does not exist it is created first, with column types
    derived from the DataFrame dtypes.

    Parameters
    ----------
    df: pd.DataFrame
        data to append; columns must equal the table columns, in order
    db_table: str
        target table
    primary_key: tuple
        primary key columns, used only when the table has to be created

    Returns
    -------
    int
        value returned by ``_db_execute_many`` (rows affected)

    Raises
    ------
    KeyError
        if the DataFrame columns do not match the table schema
    """
    if not self._db_table_exists(db_table):
        dtype_mapping = {
            'object': 'varchar(255)',
            'datetime64[ns]': 'datetime',
            'int64': 'int',
            'float32': 'float',
            'float64': 'double',
        }
        sql_types = [dtype_mapping.get(str(dtype.name), 'varchar(255)') for dtype in df.dtypes.tolist()]
        self._new_db_table(
                db_table=db_table,
                columns=df.columns,
                dtypes=sql_types,
                primary_key=primary_key,
        )

    _, insert_body, rows = self._prepare_db_insert(df, db_table)
    return self._db_execute_many("INSERT IGNORE INTO " + insert_body, rows)


def _update_database(self, df, db_table, primary_key):
    """Insert the rows of a DataFrame, updating rows whose key already exists.

    Parameters
    ----------
    df: pd.DataFrame
        data to write; columns must equal the table columns, in order
    db_table: str
        target table, whose primary key must already be defined
    primary_key: tuple
        primary key columns; these are excluded from the UPDATE clause

    Returns
    -------
    int
        value returned by ``_db_execute_many`` (rows affected)

    Raises
    ------
    KeyError
        if the DataFrame columns do not match the table schema
    """
    tbl_columns, insert_body, rows = self._prepare_db_insert(df, db_table)
    key_columns = primary_key if primary_key is not None else ()
    update_cols = [col for col in tbl_columns if col not in key_columns]

    if not update_cols:
        # every column is part of the key: nothing can be updated, so a
        # duplicated key can only be skipped
        return self._db_execute_many("INSERT IGNORE INTO " + insert_body, rows)

    sql = "INSERT INTO " + insert_body + "ON DUPLICATE KEY UPDATE\n"
    sql += ",\n".join(f"`{col}`=VALUES(`{col}`)" for col in update_cols)
    return self._db_execute_many(sql, rows)


def _delete_database_records(self, db_table, primary_key, record_ids):
    """Delete the records whose primary key value is in ``record_ids``.

    Parameters
    ----------
    db_table: str
        table name
    primary_key: str
        name of the primary key column
    record_ids: list or tuple
        key values of the records to delete

    Returns
    -------
    int
        rows affected; 0 when ``record_ids`` is empty
    """
    if not record_ids:
        return 0

    ids = tuple(record_ids)
    placeholders = ', '.join(['%s'] * len(ids))
    sql = f"DELETE FROM `{db_table}` WHERE `{primary_key}` IN ({placeholders})"
    return self._db_execute_one(sql, ids, fetch_and_return=False)


def _get_db_table_coverage(self, db_table, column):
    """Return the sorted distinct values of one column of a table.

    Parameters
    ----------
    db_table: str
        table name
    column: str
        column name

    Returns
    -------
    list
        distinct values; ``datetime`` values are formatted as ``YYYYMMDD``
        strings; empty if the table does not exist or has no rows
    """
    import datetime

    if not self._db_table_exists(db_table):
        return list()

    sql = f'SELECT DISTINCT `{column}` FROM `{db_table}` ORDER BY `{column}`'
    res = [item[0] for item in self._db_execute_one(sql)]
    if res and isinstance(res[0], datetime.datetime):
        res = list(pd.to_datetime(res).strftime('%Y%m%d'))
    return res


def _get_db_table_minmax(self, db_table, column, with_count=False):
    """Return the minimum and maximum (and optionally distinct count) of a column.

    Parameters
    ----------
    db_table: str
        table name
    column: str
        column name
    with_count: bool, default False
        also return ``COUNT(DISTINCT column)``, which may be slow

    Returns
    -------
    list
        ``[min, max]`` or ``[min, max, count]``; ``datetime`` bounds are
        formatted as ``YYYYMMDD`` strings; empty if the table does not exist
    """
    import datetime

    if not self._db_table_exists(db_table):
        return list()

    add_sql = f', COUNT(DISTINCT(`{column}`))' if with_count else ''
    sql = f'SELECT MIN(`{column}`), MAX(`{column}`){add_sql} FROM `{db_table}`'

    res = list(self._db_execute_one(sql)[0])
    if isinstance(res[0], datetime.datetime):
        # only the two bounds are dates; the optional count stays untouched
        res[:2] = list(pd.to_datetime(res[:2]).strftime('%Y%m%d'))
    return res


def _db_table_exists(self, db_table):
    """Check whether a table exists in the connected database.

    Parameters
    ----------
    db_table: str
        table name

    Returns
    -------
    bool
    """
    # NOTE: LIKE treats `_` and `%` in the name as wildcards, so this is a
    # pattern match rather than a strict equality test
    res = self._db_execute_one("SHOW TABLES LIKE %s", (db_table,))
    return res is not None and len(res) > 0


def _new_db_table(self, db_table, columns, dtypes, primary_key: list, auto_increment_id: bool = False,
                  index_col: list = None, partition_by: list = None, partitions: int = None) -> None:
    """Create a table (if it does not exist yet) with the given schema.

    Parameters
    ----------
    db_table: str
        table name
    columns: list of str
        column names
    dtypes: list of str
        SQL data type of each column, aligned with ``columns``
    primary_key: list of str or None
        primary key columns; they are declared NOT NULL, and every key
        column after the first also gets its own KEY index
    auto_increment_id: bool, Default: False
        add AUTO_INCREMENT to the primary key columns
    index_col: list of str, Default: None
        currently unused
    partition_by: list of str, Default: None
        columns used for ``PARTITION BY KEY``
    partitions: int, Default: None
        number of partitions

    Returns
    -------
    None
    """
    if primary_key is None:
        key_columns = []
    elif isinstance(primary_key, str):
        key_columns = [primary_key]
    else:
        key_columns = list(primary_key)

    definitions = []
    for col_name, dtype in zip(columns, dtypes):
        if col_name in key_columns:
            suffix = " NOT NULL AUTO_INCREMENT" if auto_increment_id else " NOT NULL"
        else:
            suffix = " DEFAULT NULL"
        definitions.append(f"`{col_name}` {dtype}{suffix}")

    if key_columns:
        definitions.append(f"PRIMARY KEY (`{'`, `'.join(key_columns)}`)")
        definitions.extend(f"KEY (`{col}`)" for col in key_columns[1:])

    # joining the pieces avoids a dangling comma when there is no primary key
    sql = f"CREATE TABLE IF NOT EXISTS `{db_table}` (\n" + ",\n".join(definitions) + "\n)"

    if partition_by is not None:
        partition_columns = [partition_by] if isinstance(partition_by, str) else list(partition_by)
        sql += "\nPARTITION BY KEY(" + ", ".join(f"`{col}`" for col in partition_columns) + ")"
        if partitions is not None:
            sql += f" PARTITIONS {partitions}"

    self._db_execute_one(sql, fetch_and_return=False)
    # TODO: create the additional index requested through index_col


# ==============
# Special database-level helpers used when a table structure changes:
# schema inspection, dropping tables, indexing and partitioning
def _get_db_table_schema(self, db_table):
    """Return the column names and data types of a database table.

    Parameters
    ----------
    db_table: str
        table name

    Returns
    -------
    dict
        ``{column: data_type, ...}`` in ordinal order; empty when the query
        returns no rows
    """
    sql = "SELECT COLUMN_NAME, DATA_TYPE " \
          "FROM INFORMATION_SCHEMA.COLUMNS " \
          "WHERE TABLE_SCHEMA = Database() " \
          "AND table_name = %s " \
          "ORDER BY ordinal_position;"

    result = self._db_execute_one(sql, (db_table,))
    if not result:
        return {}
    return {col: typ for col, typ in result}


def _drop_db_table(self, db_table):
    """Drop a table from the database if it exists.

    Parameters
    ----------
    db_table: str
        table name

    Returns
    -------
    None
    """
    sql = f"DROP TABLE IF EXISTS `{db_table}`;"
    self._db_execute_one(sql, fetch_and_return=False)


def _get_db_table_size(self, db_table):
    """Return the row count and disk usage reported for a table.

    Parameters
    ----------
    db_table: str
        table name

    Returns
    -------
    tuple or int
        ``(rows, size)`` as read from ``INFORMATION_SCHEMA.tables``, where
        size is ``data_length + index_length``; the scalar ``-1`` when the
        table does not exist
    """
    if not self._db_table_exists(db_table):
        return -1

    sql = "SELECT table_rows, data_length + index_length " \
          "FROM INFORMATION_SCHEMA.tables " \
          "WHERE table_schema = %s " \
          "AND table_name = %s;"
    rows, size = self._db_execute_one(sql, (self.db_name, db_table), fetch_and_return=True)[0]
    return rows, size


def _alter_db_table(self, db_table, columns, dtypes, primary_key, auto_increment_id=False):
    """Alter the schema of an existing table. Not implemented yet."""
    raise NotImplementedError


def _alter_db_table_index(self, db_table, index_name, columns, unique=False):
    """Create or change an index of an existing table. Not implemented yet."""
    raise NotImplementedError


def _alter_db_table_partition(self, db_table, partition_name, partition_type, partition_key):
    """Change the partitioning of an existing table. Not implemented yet."""
    raise NotImplementedError

# ==============
# Logical table layer: read and write data at table level, delegating the
# storage work to the file helpers or the database helpers
def table_data_exists(self, table):
    """Tell whether a logical data table exists in the current data source.

    Parameters
    ----------
    table: str
        name of the data table

    Returns
    -------
    bool
        True if the table exists, False otherwise

    Raises
    ------
    KeyError
        if ``self.source_type`` is neither 'db' nor 'file'
    """
    backend = self.source_type
    if backend == 'file':
        return self._file_exists(table)
    if backend == 'db':
        return self._db_table_exists(db_table=table)
    raise KeyError(f'invalid source_type: {self.source_type}')
@lru_cache(maxsize=32)
def read_cached_table_data(
        self,
        table: str,
        *,
        shares: str = None,
        start: str = None,
        end: str = None,
        primary_key_in_index: bool = True,
) -> pd.DataFrame:
    """Memoized variant of ``read_table_data`` for repeated identical reads.

    Intended for callers that fetch the same table with the same arguments
    over and over; regular table maintenance should call
    ``read_table_data`` directly, since a cached result does not reflect
    later changes to the table.

    All arguments are forwarded unchanged to ``read_table_data`` and take
    part in the cache key, so they must be hashable.

    NOTE(review): the cache is attached to the function, so ``self`` is part
    of every key and a cache hit returns the very same object that was
    stored - callers should presumably not mutate the result in place.
    """
    forwarded = dict(
            shares=shares,
            start=start,
            end=end,
            primary_key_in_index=primary_key_in_index,
    )
    return self.read_table_data(table, **forwarded)
def read_table_data(self, table, *, shares: Union[str, list] = None, start: str = None, end: str = None,
                    primary_key_in_index: bool = True) -> pd.DataFrame:
    """Read a local data table into a DataFrame, optionally filtered.

    All columns are read; rows can be restricted to a set of share codes
    and to a date range. The data itself is not reformatted.

    Parameters
    ----------
    table: str
        name of the data table, must be a key of ``TABLE_MASTERS``
    shares: str or list of str, optional
        share codes to keep; a string is split with ``str_to_list``.
        No share filter when omitted
    start: str, optional
        start date, YYYYMMDD; a date filter is applied only when both
        ``start`` and ``end`` are given
    end: str, optional
        end date, YYYYMMDD
    primary_key_in_index: bool, default True
        when False the result is passed through ``set_primary_key_frame``
        so that the primary keys come back as ordinary columns

    Returns
    -------
    pd.DataFrame
        data read from the table, possibly empty

    Raises
    ------
    TypeError
        if ``table`` is not a string or the source type is unknown
    KeyError
        if ``table`` is not a known table name
    AssertionError
        if ``start`` is later than ``end``

    Warns
    -----
    RuntimeWarning
        if a share or date filter is requested but the table's primary key
        has no matching column; the filter is then ignored
    """
    if not isinstance(table, str):
        err = TypeError(f'table name should be a string, got {type(table)} instead.')
        raise err
    if table not in TABLE_MASTERS:
        raise KeyError(f'Invalid table name: {table}.')

    if isinstance(shares, str):
        shares = str_to_list(shares)

    has_date_range = (start is not None) and (end is not None)
    if has_date_range:
        start = regulate_date_format(start)
        end = regulate_date_format(end)
        # explicit check instead of an assert statement so that the validation
        # survives `python -O`; the exception type is kept for compatibility
        if not (pd.to_datetime(start) <= pd.to_datetime(end)):
            raise AssertionError(f'start date ({start}) should not be later than end date ({end})')

    columns, dtypes, primary_key, pk_dtypes = get_built_in_table_schema(table)

    # find the share-like key: the first primary key of a varchar type
    share_like_pk = None
    if shares is not None:
        try:
            varchar_like_dtype = [item for item in pk_dtypes if item[:7] == 'varchar'][0]
            share_like_pk = primary_key[pk_dtypes.index(varchar_like_dtype)]
        except Exception:
            msg = f'can not find share-like primary key in the table {table}!\n' \
                  f'passed argument shares will be ignored!'
            warnings.warn(msg, RuntimeWarning, stacklevel=2)
            share_like_pk = None
            shares = None

    # find the date-like key: month and quarter keys take precedence and
    # need the range converted to their own format
    date_like_pk = None
    if has_date_range:
        if "month" in primary_key:
            date_like_pk = 'month'
            start = date_to_month_format(start)
            end = date_to_month_format(end)
        elif "quarter" in primary_key:
            date_like_pk = 'quarter'
            start = date_to_quarter_format(start)
            end = date_to_quarter_format(end)
        else:
            try:
                date_like_dtype = [item for item in pk_dtypes if item in ['date', 'datetime']][0]
                date_like_pk = primary_key[pk_dtypes.index(date_like_dtype)]
            except Exception as e:
                msg = f'{e}\ncan not find date-like primary key in the table {table}!\n' \
                      f'passed start({start}) and end({end}) arguments will be ignored!'
                warnings.warn(msg, RuntimeWarning, stacklevel=2)
                date_like_pk = None
                start = None
                end = None

    if self.source_type == 'file':
        # the filter arguments are handed to the file reader, so no further
        # filtering is done at this level
        df = self._read_file(file_name=table,
                             primary_key=primary_key,
                             pk_dtypes=pk_dtypes,
                             share_like_pk=share_like_pk,
                             shares=shares,
                             date_like_pk=date_like_pk,
                             start=start,
                             end=end)
        if df.empty:
            return df
    elif self.source_type == 'db':
        # a missing table is not created here; it simply reads as empty.
        # the database reader filters the rows itself, and the primary key
        # index is set afterwards
        df = self._read_database(db_table=table,
                                 share_like_pk=share_like_pk,
                                 shares=shares,
                                 date_like_pk=date_like_pk,
                                 start=start,
                                 end=end)
        if df.empty:
            return df
        set_primary_key_index(df, primary_key, pk_dtypes)
    else:  # for unexpected cases:
        err = TypeError(f'Invalid value DataSource.source_type: {self.source_type}')
        raise err

    if not primary_key_in_index:
        df = set_primary_key_frame(df, primary_key=primary_key, pk_dtypes=pk_dtypes)

    return df
def export_table_data(self, table, file_name=None, file_path=None, shares=None, start=None, end=None):
    """Read a data table and export the result to a csv file.

    Works the same way for every source type: the caller only names the
    table and the optional filters. System tables can not be exported.

    Parameters
    ----------
    table: str
        name of the data table
    file_name: str, optional
        name of the exported file, defaults to the table name
    file_path: str, optional
        directory of the exported file, defaults to the current working
        directory
    shares: list of str, optional
        share codes to keep, no filter when omitted
    start: DateTime like, optional
        start date, YYYYMMDD
    end: Datetime like, optional
        end date, YYYYMMDD, effective together with ``start``

    Returns
    -------
    file_path_name: str
        full path of the exported file

    Raises
    ------
    ValueError
        if ``table`` is not a non-system table
    FileExistsError
        if the target file already exists
    RuntimeError
        if writing the csv file fails
    """
    table_master = get_table_master()
    exportable_tables = table_master[table_master['table_usage'] != 'sys'].index.to_list()
    if table not in exportable_tables:
        raise ValueError(f'Invalid table name: {table}!')

    target_name = table if file_name is None else file_name
    target_dir = os.getcwd() if file_path is None else file_path

    # the file name is sanitized before joining; existing files are never overwritten
    file_path_name = path.join(target_dir, sanitize_filename(target_name))
    if os.path.exists(file_path_name):
        raise FileExistsError(f'File {file_path_name} already exists!')

    df = self.read_table_data(table=table, shares=shares, start=start, end=end)
    try:
        df.to_csv(file_path_name, encoding='utf-8')
    except Exception as e:
        raise RuntimeError(f'{e}, Failed to export table {table} to file {file_path_name}!')

    return file_path_name
def write_table_data(self, df, table, on_duplicate='ignore'):
    """Write a DataFrame to a local data table (file or database).

    For database sources the table is created when missing and the rows
    are then inserted; ``on_duplicate`` decides what happens to rows whose
    primary key already exists. For file sources the frame is written
    through ``_write_file``.

    Parameters
    ----------
    df: pd.DataFrame
        data to write; column names should match the table definition
    table: str
        name of the local data table
    on_duplicate: str
        handling of duplicated keys, only used for database sources:
        - ignore: default, insert with duplicated keys ignored
        - update: overwrite existing rows that have the same key

    Returns
    -------
    int
        number of rows written

    Raises
    ------
    TypeError
        if ``table`` is not a string
    KeyError
        if ``table`` is unknown or ``on_duplicate`` is invalid

    Notes
    -----
    The data is not validated here; use ``update_table_data()`` to write
    or refresh local data.
    """
    assert isinstance(df, pd.DataFrame)
    if not isinstance(table, str):
        raise TypeError(f'table name should be a string, got {type(table)} instead.')
    if table not in TABLE_MASTERS.keys():
        raise KeyError(f'Invalid table name.')

    columns, dtypes, primary_key, pk_dtype = get_built_in_table_schema(table)
    frame = set_primary_key_frame(df, primary_key=primary_key, pk_dtypes=pk_dtype)

    written = 0
    backend = self.source_type
    if backend == 'file':
        set_primary_key_index(frame, primary_key=primary_key, pk_dtypes=pk_dtype)
        written = self._write_file(frame, file_name=table)
    elif backend == 'db':
        if not self._db_table_exists(table):
            self._new_db_table(db_table=table, columns=columns, dtypes=dtypes, primary_key=primary_key)
        if on_duplicate not in ('ignore', 'update'):
            raise KeyError(f'Invalid process mode on duplication: {on_duplicate}')
        writer = self._write_database if on_duplicate == 'ignore' else self._update_database
        written = writer(frame, db_table=table, primary_key=primary_key)

    self._table_list.add(table)
    return written
def update_table_data(self, table, df, merge_type='update') -> int:
    """Clean a DataFrame against the table definition and merge it into the table.

    Steps:
    1, columns unknown to the table are dropped, and the remaining columns
       are brought into the order of the table definition (columns missing
       from ``df`` are added by the reindex);
    2, for database sources the cleaned data is written with
       ``on_duplicate=merge_type``;
    3, for file sources the whole local table is read, merged with the new
       data according to ``merge_type`` and written back.

    Parameters
    ----------
    table: str
        name of the data table
    df: pd.DataFrame
        data to merge into the table
    merge_type: str
        how overlapping rows are handled:
        - 'update': default, new data replaces local data
        - 'ignore': overlapping rows in the new data are discarded

    Returns
    -------
    int
        number of rows written to the table

    Raises
    ------
    TypeError
        if ``df`` is not a DataFrame or ``merge_type`` is not a string
    KeyError
        if ``merge_type`` or the source type is invalid
    ValueError
        if at least 75% of the table columns are missing from ``df``
    """
    if not isinstance(df, pd.DataFrame):
        raise TypeError(f'df should be a dataframe, got {type(df)} instead')
    if not isinstance(merge_type, str):
        raise TypeError(f'merge type should be a string, got {type(merge_type)} instead.')
    if merge_type not in ['ignore', 'update']:
        raise KeyError(f'Invalid merge type, should be either "ignore" or "update"')

    dnld_data = df
    if dnld_data.empty:
        return 0

    table_columns, dtypes, primary_keys, pk_dtypes = get_built_in_table_schema(table)
    dnld_data = set_primary_key_frame(dnld_data, primary_key=primary_keys, pk_dtypes=pk_dtypes)
    dnld_columns = dnld_data.columns.to_list()

    # too many table columns absent from df means df does not belong to this table
    missing_columns = [col for col in table_columns if col not in dnld_columns]
    if len(missing_columns) >= (len(table_columns) * 0.75):
        raise ValueError(f'there are too many missing columns in downloaded df, can not merge to local table:'
                         f'table_columns:\n{[table_columns]}\n'
                         f'downloaded:\n{[dnld_columns]}')

    # drop the columns the table does not define
    surplus_columns = [col for col in dnld_columns if col not in table_columns]
    if len(surplus_columns) > 0:
        dnld_data.drop(columns=surplus_columns, inplace=True)

    # the reindex both adds the missing columns and fixes the column order
    order_differs = any(item_d != item_t for item_d, item_t in zip(dnld_columns, table_columns))
    if len(missing_columns) > 0 or order_differs:
        dnld_data = dnld_data.reindex(columns=table_columns, copy=False)

    if self.source_type == 'db':
        return self.write_table_data(df=dnld_data, table=table, on_duplicate=merge_type)
    if self.source_type != 'file':
        raise KeyError(f'invalid data source type: {self.source_type}')

    # file sources: the complete local table is loaded and merged in memory
    local_data = self.read_table_data(table)
    set_primary_key_index(dnld_data, primary_key=primary_keys, pk_dtypes=pk_dtypes)
    if merge_type == 'ignore':
        # keep local rows, discard the overlapping new rows
        dnld_data = dnld_data[~dnld_data.index.isin(local_data.index)]
    elif merge_type == 'update':
        # keep new rows, discard the overlapping local rows
        local_data = local_data[~local_data.index.isin(dnld_data.index)]
    else:  # for unexpected cases
        raise KeyError(f'Invalid merge type, got "{merge_type}"')
    return self.write_table_data(pd.concat([local_data, dnld_data]), table=table)
def drop_table_data(self, table):
    """Delete a locally stored data table. The operation can not be undone.

    Parameters
    ----------
    table: str
        name of the local data table

    Returns
    -------
    None

    Raises
    ------
    RuntimeError
        if the data source was created with ``allow_drop_table`` set to False
    """
    if not self.allow_drop_table:
        raise RuntimeError('Can\'t drop table from current datasource according to setting, please check: '
                           'datasource.allow_drop_table')

    backend = self.source_type
    if backend == 'file':
        self._drop_file(file_name=table)
    elif backend == 'db':
        self._drop_db_table(db_table=table)

    # forget the table whether or not it was registered
    self._table_list.discard(table)
    return None
def get_table_data_coverage(self, table, column, min_max_only=False):
    """Return the coverage of one column of a local data table.

    Parameters
    ----------
    table: str
        name of the data table
    column: str
        column whose values are inspected
    min_max_only: bool, default False
        when True only the extremes are returned instead of every distinct
        value; the result starts with the minimum and the maximum value

    Returns
    -------
    list
        values describing the coverage of the column

    Raises
    ------
    TypeError
        if ``self.source_type`` is neither 'db' nor 'file'

    Examples
    --------
    >>> import qteasy as qt
    >>> qt.QT_DATA_SOURCE.get_table_data_coverage('stock_daily', 'ts_code')
    Out:
    ['000001.SZ', '000002.SZ', '000003.SZ', '000004.SZ', '000005.SZ', '000006.SZ', ...]
    >>> qt.QT_DATA_SOURCE.get_table_data_coverage('stock_daily', 'ts_code', min_max_only=True)
    Out:
    ['000001.SZ', '873593.BJ']
    """
    backend = self.source_type
    if backend == 'file':
        columns, dtypes, primary_keys, pk_dtypes = get_built_in_table_schema(table)
        return self._get_file_table_coverage(table, column, primary_keys, pk_dtypes, min_max_only)
    if backend == 'db':
        reader = self._get_db_table_minmax if min_max_only else self._get_db_table_coverage
        return reader(table, column)
    raise TypeError(f'Invalid source type: {self.source_type}')
def get_data_table_size(self, table, human=True, string_form=True):
    """Return the disk usage and the row count of a data table.

    Parameters
    ----------
    table: str
        name of the data table
    human: bool, default True
        with ``string_form``, format the numbers for reading (through
        ``human_file_size`` / ``human_units``) instead of printing the raw
        numbers
    string_form: bool, default True
        return strings; when False the raw numbers are returned and
        ``human`` has no effect

    Returns
    -------
    tuple (size, rows)
        tuple of int or str; ``(0, 0)`` when the table is reported missing

    Raises
    ------
    RuntimeError
        if ``self.source_type`` is neither 'file' nor 'db'
    """
    if self.source_type == 'file':
        size = self._get_file_size(table)
        if size == -1:
            # missing file: return before trying to count its rows
            return 0, 0
        rows = self._get_file_rows(table)
    elif self.source_type == 'db':
        table_info = self._get_db_table_size(table)
        if not isinstance(table_info, (tuple, list)):
            # a scalar (-1) is reported for a missing table and can not be unpacked
            return 0, 0
        rows, size = table_info
    else:
        err = RuntimeError(f'unknown source type: {self.source_type}')
        raise err

    if size == -1:
        return 0, 0
    if not string_form:
        return size, rows
    if human:
        return f'{human_file_size(size)}', f'{human_units(rows)}'
    else:
        return f'{size}', f'{rows}'
def get_table_info(self, table, verbose=True, print_info=True, human=True) -> dict:
    """Collect, and optionally print, summary information about a data table.

    Parameters
    ----------
    table: str
        name of the data table
    verbose: bool, Default: True
        also print the table schema (only when ``print_info`` is True)
    print_info: bool, Default: True
        print the collected information
    human: bool, Default: True
        report size and row count in an easy-to-read form

    Returns
    -------
    dict
        structured information about the table:
        {
            table:         str, name of the table
            table_exists:  bool, whether the table exists
            table_size:    int/str, disk usage of the table
            table_rows:    int/str, number of rows of the table
            primary_key1:  str, name of the first primary key
            pk_records1:   number of entries of the first primary key,
                           'unknown' when not reported
            pk_min1:       obj, first value of the first primary key
            pk_max1:       obj, last value of the first primary key
            primary_key2:  str, name of the second primary key
            pk_records2:   number of entries of the second primary key,
                           'unknown' when not reported
            pk_min2:       obj, first value of the second primary key
            pk_max2:       obj, last value of the second primary key
        }
        entries of a primary key the table does not have are None

    Raises
    ------
    TypeError
        if ``table`` is not a string
    ValueError
        if ``table`` is not a known table name
    """
    if not isinstance(table, str):
        raise TypeError(f'table should be name of a table, got {type(table)} instead')
    if table not in TABLE_MASTERS:
        raise ValueError(f'in valid table name: {table}')

    columns, dtypes, remarks, primary_keys, pk_dtypes = get_built_in_table_schema(table,
                                                                                  with_remark=True,
                                                                                  with_primary_keys=True)
    table_desc = TABLE_MASTERS[table][1]
    critical_key = TABLE_MASTERS[table][6]
    table_schema = pd.DataFrame({'columns': columns,
                                 'dtypes':  dtypes,
                                 'remarks': remarks})
    table_exists = self.table_data_exists(table)

    # size and row count: strings when printing, otherwise as requested by `human`
    if table_exists and print_info:
        table_size, table_rows = self.get_data_table_size(table, human=human)
    elif table_exists:
        table_size, table_rows = self.get_data_table_size(table, string_form=human, human=human)
    elif print_info:
        table_size, table_rows = '0 MB', '0'
    else:
        table_size, table_rows = 0, 0

    if print_info:
        print(f'<{table}>--<{table_desc}>\n{table_size}/{table_rows} records on disc\n'
              f'primary keys: \n'
              f'----------------------------------------')

    # only the first two primary keys are reported in the result
    key_stats = {}
    for position, pk in enumerate(primary_keys, start=1):
        coverage = self.get_table_data_coverage(table, pk, min_max_only=True)
        record_count = coverage[2] if len(coverage) == 3 else 'unknown'
        if len(coverage) == 0:
            coverage = ['N/A', 'N/A']
        if print_info:
            critical = "*" if pk == critical_key else " "
            print(f'{position}: {critical} {pk}: <{record_count}> entries\n'
                  f' starts:'
                  f' {coverage[0]}, end: {coverage[1]}')
        if position <= 2:
            key_stats[position] = (pk, record_count, coverage[0], coverage[1])

    if verbose and print_info:
        print(f'\ncolumns of table:\n'
              f'------------------------------------\n'
              f'{table_schema}\n')

    pk1, pk_records1, pk_min1, pk_max1 = key_stats.get(1, (None, None, None, None))
    pk2, pk_records2, pk_min2, pk_max2 = key_stats.get(2, (None, None, None, None))
    return {
        'table':        table,
        'table_exists': table_exists,
        'table_size':   table_size,
        'table_rows':   table_rows,
        'primary_key1': pk1,
        'pk_records1':  pk_records1,
        'pk_min1':      pk_min1,
        'pk_max1':      pk_max1,
        'primary_key2': pk2,
        'pk_records2':  pk_records2,
        'pk_min2':      pk_min2,
        'pk_max2':      pk_max2,
    }
# ============== # 系统操作表操作函数,专门用于操作sys_operations表,记录系统操作信息,数据格式简化 # ==============
def get_sys_table_last_id(self, table):
    """Return the largest record id currently used in a system table.

    Parameters
    ----------
    table: str
        name of the system table

    Returns
    -------
    last_id: int
        the largest id in use, 0 when the table holds no record yet;
        None for an unknown source type
    """
    from .datatables import ensure_sys_table

    ensure_sys_table(table)

    if self.source_type == 'file':
        # file sources: read the table and take the largest index value
        df = self.read_sys_table_data(table)
        if df.empty:
            return 0
        return int(df.index.max())

    if self.source_type == 'db':
        if not self._db_table_exists(table):
            # create the missing table with an auto increment id
            columns, dtypes, prime_keys, pk_dtypes = get_built_in_table_schema(table)
            self._new_db_table(table, columns=columns, dtypes=dtypes, primary_key=prime_keys,
                               auto_increment_id=True)
            return 0

        columns, dtypes, primary_keys, pk_dtypes = get_built_in_table_schema(table, with_primary_keys=True)
        primary_key = primary_keys[0]
        # select the key column itself so the result does not depend on the
        # physical position of that column in the table
        sql = f"SELECT `{primary_key}` FROM `{table}` ORDER BY `{primary_key}` DESC LIMIT 1;"
        res = self._db_execute_one(sql)
        if not res:
            return 0
        return res[0][0]

    # for other unexpected cases
    return None
def read_sys_table_data(self, table, **kwargs) -> pd.DataFrame:
    """Read a system table, optionally keeping only rows that match given values.

    Parameters
    ----------
    table: str
        name of the system table
    kwargs: dict
        equality filters given as ``column=value``, e.g. ``account_id=123``

    Returns
    -------
    pd.DataFrame
        the (filtered) table sorted by index; an empty DataFrame when the
        table is empty, when a filter column is not present in the data
        that was read, or when the source type is unknown

    Raises
    ------
    KeyError
        if a filter names a column that is not defined for the table
    """
    from .datatables import ensure_sys_table

    ensure_sys_table(table)

    columns, dtypes, p_keys, pk_dtypes = get_built_in_table_schema(table)
    unknown_filters = [k for k in kwargs if k not in columns]
    if unknown_filters:
        raise KeyError(f'kwargs not valid: {unknown_filters}')

    # the whole table is always read; filtering happens in memory below
    backend = self.source_type
    if backend == 'file':
        res_df = self._read_file(table, p_keys, pk_dtypes)
    elif backend == 'db':
        res_df = self._read_database(table)
        if res_df.empty:
            return res_df
        set_primary_key_index(res_df, primary_key=p_keys, pk_dtypes=pk_dtypes)
    else:  # for other unexpected cases
        return pd.DataFrame()

    if res_df.empty:
        return res_df

    for column, wanted in kwargs.items():
        # a column absent from the stored data can not match any row
        if column not in res_df.columns:
            return pd.DataFrame()
        res_df = res_df.loc[res_df[column] == wanted]

    return res_df.sort_index()
def read_sys_table_record(self, table, *, record_id: int, **kwargs) -> dict:
    """Read one record of a system table and return it as a dict.

    The table is read through ``read_sys_table_data()`` (with ``kwargs`` as
    filters) and the row whose index equals ``record_id`` is returned.

    Parameters
    ----------
    table: str
        name of the system table
    record_id: int
        id of the record to read
    kwargs: dict
        equality filters given as ``column=value``, e.g. ``account_id=123``

    Returns
    -------
    data: dict
        ``{column: value}`` of the record; an empty dict when ``record_id``
        is not positive, the table is empty or the id is not found
    """
    # ids are positive, anything else can not exist
    if record_id is not None and record_id <= 0:
        return {}

    records = self.read_sys_table_data(table, **kwargs)
    found = (not records.empty) and (record_id in records.index)
    if not found:
        return {}
    return records.loc[record_id].to_dict()
def update_sys_table_data(self, table: str, record_id: int, **data) -> int:
    """Update fields of one existing record of a system table.

    One record is updated per call. The current record is read, the given
    fields are overwritten and the result is written back through
    ``update_table_data()`` with ``merge_type='update'``. Primary key
    fields can not be updated.

    Parameters
    ----------
    table: str
        name of the system table
    record_id: int
        id of the record to update
    data: dict
        fields to update given as ``column=value``, e.g. ``account_id=123``

    Returns
    -------
    id: int
        id of the updated record

    Raises
    ------
    KeyError
        if the record does not exist, or if a field is not a non-key column
        of the table
    """
    from .datatables import ensure_sys_table

    ensure_sys_table(table)

    # TODO: this goes through update_table_data() to save development time;
    #  writing the record directly (file rewrite / SQL UPDATE) would be faster
    current = self.read_sys_table_record(table, record_id=record_id)
    if not current:
        raise KeyError(f'record_id({record_id}) not found in table {table}')

    columns, dtypes, p_keys, pk_dtypes = get_built_in_table_schema(table)
    data_columns = [col for col in columns if col not in p_keys]
    invalid_fields = [k for k in data.keys() if k not in data_columns]
    if invalid_fields:
        raise KeyError(f'kwargs not valid: {invalid_fields}')

    # overlay the new values on the stored record and write it as a one-row frame
    current.update(data)
    df_data = pd.DataFrame(current, index=[record_id])
    df_data.index.name = p_keys[0]
    self.update_table_data(table, df_data, merge_type='update')
    return record_id
def insert_sys_table_data(self, table: str, **data) -> int:
    """Insert one new record into a system table.

    The record id is not supplied by the caller: it is the last id in use
    plus one. The data must provide exactly the non-key columns of the
    table.

    Parameters
    ----------
    table: str
        name of the system table
    data: dict
        the record given as ``column=value``; the keys must be the same as
        the non-key columns of the table

    Returns
    -------
    record_id: int
        id of the inserted record

    Raises
    ------
    KeyError
        if fields are missing or unknown fields are given
    """
    from .datatables import ensure_sys_table

    ensure_sys_table(table)

    # TODO: this goes through update_table_data() to save development time;
    #  appending to the file / a direct SQL INSERT would be simpler and faster
    last_id = self.get_sys_table_last_id(table)
    record_id = 1 if last_id is None else last_id + 1

    columns, dtypes, primary_keys, pk_dtypes = get_built_in_table_schema(table)
    data_columns = [col for col in columns if col not in primary_keys]
    # the given fields must be exactly the non-key columns: none unknown, none missing
    if set(data.keys()) != set(data_columns):
        raise KeyError(f'Input data keys must be the same as the table data columns, '
                       f'got {list(data.keys())} vs {data_columns}')

    df = pd.DataFrame(data, index=[record_id], columns=data.keys())
    df = df.reindex(columns=columns)
    df.index.name = primary_keys[0]

    # TODO: review whether 'ignore' or 'update' is the right merge type here
    self.update_table_data(table, df, merge_type='ignore')
    return record_id
def delete_sys_table_data(self, table: str, record_ids: Union[list, tuple]) -> int:
    """Delete records, identified by their ids, from a system table.

    Parameters
    ----------
    table: str
        Name of the table to delete records from.
    record_ids: list of int or tuple of int
        Ids of the records to delete.

    Returns
    -------
    int
        Number of deleted records as reported by the backend helper. 0 is
        returned when ``table`` fails the system-table check (a warning is
        emitted in that case) or when ``record_ids`` is empty.

    Raises
    ------
    TypeError
        If ``record_ids`` is not a list/tuple or contains a non-int item.
    RuntimeError
        If ``self.source_type`` is neither 'db' nor 'file'.
    """
    from .datatables import ensure_sys_table

    # A table that is not a valid system table is not an error for the
    # caller: report why nothing is deleted and return 0.
    try:
        ensure_sys_table(table)
    except KeyError:
        warnings.warn(f'{table} is not valid table, can not delete records from it.',
                      RuntimeWarning, stacklevel=2)
        return 0
    except TypeError:
        warnings.warn(f'{table} is not a system table, can not delete records from it.',
                      RuntimeWarning, stacklevel=2)
        return 0
    except Exception as e:
        warnings.warn(f'An error occurred when checking table {table}: {e}',
                      RuntimeWarning, stacklevel=2)
        return 0

    # Validate the ids before handing them to the storage backend.
    if not isinstance(record_ids, (list, tuple)):
        err = TypeError(f'record_ids should be a list or tuple, got {type(record_ids)} instead')
        raise err
    if not all(isinstance(rid, int) for rid in record_ids):
        err = TypeError(f'all record_ids should be int, got {[type(rid) for rid in record_ids]} instead')
        raise err

    columns, dtypes, primary_keys, pk_dtypes = get_built_in_table_schema(table, with_primary_keys=True)
    primary_key = primary_keys[0]

    # Pick the backend-specific deletion routine.
    if self.source_type == 'db':
        delete_records = self._delete_database_records
    elif self.source_type == 'file':
        delete_records = self._delete_file_records
    else:
        err = RuntimeError(f'invalid source type: {self.source_type}')
        raise err

    # Nothing to delete: skip the backend call altogether.
    if not record_ids:
        return 0

    return delete_records(table, primary_key=primary_key, record_ids=record_ids)
# ==============
# Special functions: some useful APIs
# ==============
def get_all_basic_table_data(self, refresh_cache=False, raise_error=True):
    """Return all basic tables, served from a cache whenever possible.

    Public wrapper around the cached ``self._get_all_basic_table_data``.

    Parameters
    ----------
    refresh_cache: bool, Default False
        If True, the cache of ``_get_all_basic_table_data`` is cleared
        before the call, so the tables are read again.
    raise_error: bool, Default True
        Forwarded to ``_get_all_basic_table_data`` as a keyword argument.

    Returns
    -------
    The value returned by ``self._get_all_basic_table_data``.
    """
    cached_reader = self._get_all_basic_table_data
    if refresh_cache:
        cached_reader.cache_clear()
    return cached_reader(raise_error=raise_error)
@lru_cache(maxsize=1)
def _get_all_basic_table_data(self, raise_error=True):
    """Read every basic table and return them together.

    The result is memoized by ``functools.lru_cache`` with ``maxsize=1``, so
    only the most recent ``(self, raise_error)`` combination stays cached.

    Parameters
    ----------
    raise_error: bool, Default True
        If True, a ValueError is raised as soon as one table is found empty.

    Returns
    -------
    tuple of DataFrames, in this order:
        df_s: stock_basic
        df_i: index_basic
        df_f: fund_basic
        df_ft: future_basic
        df_o: opt_basic
        df_ths: ths_index_basic

    Raises
    ------
    ValueError
        If any of the tables is empty and ``raise_error`` is True.
    """
    basic_tables = (
        'stock_basic',
        'index_basic',
        'fund_basic',
        'future_basic',
        'opt_basic',
        'ths_index_basic',
    )
    frames = []
    for table_name in basic_tables:
        frame = self.read_table_data(table_name)
        # Stop at the first empty table; tables after it are not read.
        if frame.empty and raise_error:
            err = ValueError(f'{table_name} table is empty, please refill data source with '
                             f'"qt.refill_data_source(tables="{table_name}")"')
            raise err
        frames.append(frame)
    return tuple(frames)
def drop_empty_tables(self) -> int:
    """Drop every table of this data source that reports zero rows.

    Each table in ``self.all_tables`` is inspected through
    ``self.get_table_info``; a table whose ``'table_rows'`` entry equals 0 is
    removed with ``self.drop_table_data``.

    Returns
    -------
    int
        Number of tables that were dropped.
    """
    n_dropped = 0
    for table_name in self.all_tables:
        row_count = self.get_table_info(table_name, print_info=False)['table_rows']
        is_empty = row_count == 0
        if not is_empty:
            continue
        self.drop_table_data(table_name)
        n_dropped += 1
    return n_dropped
# TODO: 在新的架构下,似乎不再需要这个函数,可以考虑删除
def reconnect(self):
    """Check the database connection and reconnect if possible.

    Intended for the case where the database connection was lost, e.g.
    because of a timeout.

    Returns
    -------
    True: connected
    False: connection failed

    NOTE(review): only non-database sources get an explicit ``True``. For a
    'db' source the body currently does nothing and falls through, so the
    call returns None rather than the documented True/False.
    """
    if self.source_type == 'db':
        # No reconnection logic is implemented for database sources.
        return None
    return True