Source code for betalens.datafeed.core

#%%By Janis 250226
'''
## 研究数据库结构(摘要)

### 表1 个券行情(日频)
- 列:入库实际时间(=最早可交易时间)、windcode、中文名、数据性质(收盘价/成交量等)、数值、备注(json)
- 规则:开盘价最早 09:30;其余价量最早 15:00 可确定
- 处理:计算日频收益率(如 close-to-close)用于回测与挖掘

### 表2 个券基本面(日频入库,事件驱动)
- 列:入库实际时间(=最早可交易时间)、数据理论发生时间(报告期,如 0331/0630/0930/1231)、windcode、中文名、数据性质(如归母净利润)、数值、备注(json)
- 规则:按公告时点入库,关注盘前/盘中/盘后;理论发生时间仅用于展示,不做因果外推
- 年报次序不一:优先一致预期或线性外推作占位,尽量避免非原生数据入库

### 表3 宏观经济(事件驱动)
- 列:入库实际时间(=最早可交易时间)、数据理论发生时间(如“1月GDP”)、windcode、中文名、数据性质(可含均线算子)、数值、备注(json)
- 规则:与表2一致(事件驱动、区分公告时点与发生时点)

### 表4 因子库
- 列:入库实际时间(=最早可交易时间)、数据编制方式、数值、备注(json)

## 投资数据库结构(摘要)
- 不保留完整历史
- 从研究数据库按需拉取最近滚动窗口数据;其余数据在线拉取
- 在线数据仅记录入库时间(可交易可用时点)

## 投资数据库结构:
- 出于投资目的,实际不需要历史数据
- 从研究数据库拉取所需最近一个滚动窗口内的数据,其余全部在线拉取,并直接记录入库时间

更新日志:
2025-10-31: 重构datafeed.py,新增工具模块
- excel.py: Excel文件处理模块
- validation.py: 数据验证和异常检查工具
- query.py: 数据库查询功能重构
- integration.py: 数据库与Excel交互功能
- wind_ingest.py: Wind数据抓取模块

2025-11-03: 简化core.py为薄封装层
- 移除所有业务逻辑到子模块
- core.py仅保留连接管理和函数组合
- 所有数据转换、插入逻辑位于子模块
'''
import numpy as np
import pandas as pd
import psycopg2
import psycopg2.extras
import os
import itertools
from psycopg2.extras import execute_values
from functools import wraps
import time
import logging
from pathlib import Path
from datetime import datetime
import warnings

# 导入新的工具模块(函数式)
from .excel import read_file, cross_section_to_db_format, check_excel_errors, apply_time_alignment
from .validation import validate_and_fix
from .query import (
    query_nearest_after as _query_nearest_after,
    query_nearest_before as _query_nearest_before,
    query_nearest_in_range_after as _query_nearest_in_range_after,
    query_nearest_in_range_before as _query_nearest_in_range_before,
    query_time_range as _query_time_range,
    get_available_dates as _get_available_dates,
    get_latest_date as _get_latest_date,
    build_query as _build_query
)
from .integration import (
    process_directory_tree as _process_directory_tree,
    incremental_insert as _incremental_insert,
    insert_dataframe as _insert_dataframe,
    process_excel_to_db_format as _process_excel_to_db_format
)
from .wind_ingest import (
    fetch_daily_market as _fetch_daily_market,
    fetch_daily_index as _fetch_daily_index,
    fetch_daily_fund as _fetch_daily_fund,
    fetch_daily_bond as _fetch_daily_bond
)
from .ede_processor import (
    process_ede_file as _process_ede_file
)
from .config import get_database_config, get_logging_config

[docs] def func_timer(function): ''' 用装饰器实现函数计时 :param function: 需要计时的函数 :return: None ''' @wraps(function) def function_timer(*args, **kwargs): print('[Function: {name} start...]'.format(name = function.__name__)) t0 = time.time() result = function(*args, **kwargs) t1 = time.time() print('[Function: {name} finished, spent time: {time:.2f}s]'.format(name = function.__name__,time = t1 - t0)) return result return function_timer
[docs] class Datafeed(): """ 数据管理主类(薄封装层) 职责: - 管理数据库连接 - 提供统一的日志记录 - 组合子模块功能为高层API - 不包含业务逻辑(所有逻辑在子模块中) 使用示例: # 创建实例(使用默认配置) df = Datafeed("daily_market_data") # 创建实例(自定义数据库配置) df = Datafeed( table_name="daily_market_data", db_config={ 'dbname': 'my_database', 'user': 'my_user', 'password': 'my_password', 'host': 'localhost', 'port': '5432' } ) # 单文件插入 df.insert_csv_file("data.csv", config={...}) # Wind数据抓取 df.ingest_wind_daily_market( codes=['000001.SZ'], start_date='2024-01-01', end_date='2024-01-31' ) # 查询 data = df.query_time_range( codes=['000001.SZ'], start_date='2024-01-01', metric='收盘价(元)' ) # 关闭 df.close() """ _initialized = False
[docs] def __init__( self, table_name, db_config=None, log_dir=None ): """ 初始化Datafeed实例 Args: table_name: 数据库表名 db_config: 数据库配置字典,包含以下键: - dbname: 数据库名 - user: 用户名 - password: 密码 - host: 主机地址 - port: 端口 如果为None,使用config.json中的配置 log_dir: 日志目录,如果为None,使用config.json中的配置 """ # 从配置文件获取默认配置 default_db_config = get_database_config() default_log_config = get_logging_config() # 合并用户配置和默认配置 if db_config is None: db_config = default_db_config else: # 使用用户提供的配置,未提供的使用默认值 db_config = {**default_db_config, **db_config} # 设置日志目录 if log_dir is None: log_dir = default_log_config.get('log_dir', './logs') # 建立数据库连接 self.conn = psycopg2.connect( dbname=db_config['dbname'], user=db_config['user'], password=db_config['password'], host=db_config['host'], port=db_config['port'] ) self.cursor = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) self.sheet = table_name # 设置logger(用于函数式调用) log_dir_path = Path(log_dir) log_dir_path.mkdir(parents=True, exist_ok=True) log_file = log_dir_path / f"datafeed_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log" self.logger = logging.getLogger('Datafeed') self.logger.setLevel(logging.INFO) if not self.logger.handlers: fh = logging.FileHandler(log_file, encoding='utf-8') fh.setLevel(logging.INFO) ch = logging.StreamHandler() ch.setLevel(logging.WARNING) formatter = logging.Formatter( '%(asctime)s - %(name)s - %(levelname)s - %(message)s' ) fh.setFormatter(formatter) ch.setFormatter(formatter) self.logger.addHandler(fh) self.logger.addHandler(ch) self.__class__._initialized = True
# ========== 插入功能 ==========
[docs] def insert_csv_file( self, filepath: str, config: dict, mode: str = 'incremental' ) -> dict: """ 单文件CSV插入(薄封装) 组合流程: 1. excel.read_file - 读取文件 2. excel.cross_section_to_db_format - 转换格式(如果需要) 3. validation.validate_and_fix - 验证和修复(如果配置) 4. integration.insert_dataframe/incremental_insert - 插入 Args: filepath: 文件路径 config: 配置字典,包含: - key_columns: 键列列表(用于cross_section转换) - value_columns: 值列列表(可选,自动推断) - key_value_mapping: 列名映射 - additional_fields: 额外字段 - validation: 验证配置 - apply_time_alignment: 是否应用时间对齐(开盘09:30,其他15:00) mode: 'insert'(直接插入)或'incremental'(增量插入,默认) Returns: 统计信息字典 """ self.logger.info(f"开始处理文件: {filepath}, 模式: {mode}") # 1. 处理Excel为DB格式 df, errors = _process_excel_to_db_format( filepath=filepath, config=config, logger=self.logger ) if df is None or errors: self.logger.error(f"文件处理失败: {errors}") return {'success': False, 'errors': errors} # 2. 可选:应用时间对齐 if config.get('apply_time_alignment', False): df = apply_time_alignment( df, date_column=config.get('date_column', 'datetime'), metric_column=config.get('metric_column', 'metric'), logger=self.logger ) # 3. 插入数据库 if mode == 'insert': success, message, stats = _insert_dataframe( cursor=self.cursor, conn=self.conn, df=df, table=self.sheet, logger=self.logger ) return {'success': success, 'message': message, 'stats': stats} elif mode == 'incremental': new_rows, skipped_rows = _incremental_insert( cursor=self.cursor, conn=self.conn, df=df, table=self.sheet, logger=self.logger ) return { 'success': True, 'new_rows': new_rows, 'skipped_rows': skipped_rows } else: raise ValueError(f"未知的插入模式: {mode}")
[docs] def insert_ede_file( self, filepath: str, date_from: str = 'filename', default_datetime: str = None, mode: str = 'incremental' ) -> dict: """ 处理并插入EDE格式的Excel文件(薄封装) EDE格式特征: - 第一列:证券代码 -第二列:证券简称 - 第三列及之后:指标列,格式为"指标名 [元数据] 值类型 [元数据] 单位" 示例EDE格式: 证券代码 证券简称 流通A股 [交易日期] 最新 [单位] 股 002460.SZ 赣锋锂业 1,211,379,763.0000 1772.HK 赣锋锂业 1,211,379,763.0000 处理流程: 1. 读取Excel文件 2. 清理数据(去除空值、"数据来源:Wind"等) 3. 识别code、name列 4. 解析metric列,提取指标名称和元数据 5. 构建日期列(从文件名或列名中提取) 6. 转换为数据库格式 7. 插入数据库 Args: filepath: EDE格式Excel文件路径 date_from: 日期来源,可选值: - 'filename': 从文件名提取日期(如EDE20251103.xlsx -> 2025-11-03 15:30:00) - 'metric': 从列名中的[日期]部分提取 default_datetime: 默认日期时间(当无法从文件名或列名提取时使用) 格式:'YYYY-MM-DD HH:MM:SS',如'2025-11-03 15:30:00' mode: 插入模式 - 'incremental': 增量插入(默认),只插入新数据 - 'insert': 直接插入,会检查重复并跳过 Returns: 统计信息字典,包含: - success: 是否成功 - new_rows: 新增行数 - skipped_rows: 跳过行数 - errors: 错误列表(如果有) Example: >>> df = Datafeed('daily_market_data') >>> result = df.insert_ede_file( ... 'EDE20251103.xlsx', ... date_from='filename', ... mode='incremental' ... ) >>> print(f"新增{result['new_rows']}行,跳过{result['skipped_rows']}行") """ self.logger.info(f"开始处理EDE文件: {filepath}, 日期来源: {date_from}, 模式: {mode}") # 1. 处理EDE文件为DB格式 df, errors = _process_ede_file( filepath=filepath, date_from=date_from, default_datetime=default_datetime, logger=self.logger ) if df is None or errors: self.logger.error(f"EDE文件处理失败: {errors}") return {'success': False, 'errors': errors} self.logger.info(f"EDE文件转换完成: {len(df)} 行数据") # 2. 插入数据库 if mode == 'insert': success, message, stats = _insert_dataframe( cursor=self.cursor, conn=self.conn, df=df, table=self.sheet, logger=self.logger ) return { 'success': success, 'message': message, 'new_rows': stats.get('inserted_rows', 0), 'skipped_rows': stats.get('duplicate_rows', 0), 'stats': stats } elif mode == 'incremental': new_rows, skipped_rows = _incremental_insert( cursor=self.cursor, conn=self.conn, df=df, table=self.sheet, logger=self.logger ) return { 'success': True, 'new_rows': new_rows, 'skipped_rows': skipped_rows } else: raise ValueError(f"未知的插入模式: {mode}")
[docs] def batch_process_excel_files( self, folder_path: str, config: dict, file_pattern: str = "*.csv", recursive: bool = True, mode: str = 'insert' ): """ 批量处理Excel文件并插入数据库(薄封装) 直接调用 integration.process_directory_tree Args: folder_path: 文件夹路径 config: 处理配置 file_pattern: 文件匹配模式 recursive: 是否递归搜索 mode: 插入模式,'insert'或'incremental' Returns: 处理统计字典 """ return _process_directory_tree( cursor=self.cursor, conn=self.conn, root_dir=folder_path, table=self.sheet, config=config, file_pattern=file_pattern, recursive=recursive, mode=mode, logger=self.logger )
[docs] def incremental_update( self, df: pd.DataFrame, date_column: str = 'datetime', code_column: str = 'code', metric_column: str = 'metric' ): """ 增量更新数据到数据库(薄封装) 直接调用 integration.incremental_insert Args: df: 待更新的DataFrame date_column: 日期列名 code_column: 代码列名 metric_column: 指标列名 Returns: (新增行数, 重复行数) """ new_rows, skipped_rows = _incremental_insert( cursor=self.cursor, conn=self.conn, df=df, table=self.sheet, date_column=date_column, code_column=code_column, metric_column=metric_column, logger=self.logger ) return new_rows, skipped_rows
[docs] def insert_with_conflict_check( self, df: pd.DataFrame, date_column: str = 'datetime', code_column: str = 'code', metric_column: str = 'metric' ): """ 插入数据时检测重复和冲突(批量查询优化版本) - 如果key(datetime, code, metric)相同但value不同:记录冲突,不插入 - 如果key和value完全相同:跳过,不插入 - 如果key不存在:插入新记录 Args: df: 待插入的DataFrame,需包含columns: datetime, code, name, metric, value date_column: 日期列名 code_column: 代码列名 metric_column: 指标列名 Returns: (新增行数, 跳过行数, 冲突列表) 冲突列表格式:[{datetime, code, name, metric, db_value, new_value}, ...] """ # 必需列 required_cols = [date_column, code_column, 'name', metric_column, 'value'] for col in required_cols: if col not in df.columns: raise ValueError(f"DataFrame缺少必需列: {col}") if df.empty: return 0, 0, [] # 1. 批量查询所有key的现有记录 keys = [ (row[date_column], row[code_column], row[metric_column]) for _, row in df.iterrows() ] query = f""" SELECT t.dt, t.cd, t.mt, d.value, d.name FROM (VALUES %s) AS t(dt, cd, mt) LEFT JOIN {self.sheet} d ON d.{date_column} = t.dt AND d.{code_column} = t.cd AND d.{metric_column} = t.mt """ execute_values(self.cursor, query, keys, template="(%s, %s, %s)", page_size=1000) results = self.cursor.fetchall() # 构建字典:key -> (value, name) existing_records = {} for r in results: key = (r['dt'], r['cd'], r['mt']) existing_records[key] = {'value': r['value'], 'name': r['name']} # 2. 在Python中比对差异 rows_to_insert = [] conflicts = [] skipped_rows = 0 for _, row in df.iterrows(): datetime_val = row[date_column] code_val = row[code_column] name_val = row['name'] metric_val = row[metric_column] value_val = row['value'] key = (datetime_val, code_val, metric_val) existing = existing_records.get(key) if existing is None or existing['value'] is None: # key不存在,准备插入 rows_to_insert.append(row) else: # key存在,比较value db_value = existing['value'] # 比较value(考虑浮点数精度) try: if isinstance(db_value, (int, float)) and isinstance(value_val, (int, float)): value_equal = abs(float(db_value) - float(value_val)) < 1e-9 else: value_equal = (db_value == value_val) except (ValueError, TypeError): value_equal = (db_value == value_val) if not value_equal: # 冲突:key相同但value不同 conflicts.append({ 'datetime': datetime_val, 'code': code_val, 'name': name_val, 'metric': metric_val, 'db_value': db_value, 'new_value': value_val }) skipped_rows += 1 else: # 完全重复,跳过 skipped_rows += 1 # 3. 批量插入新记录 new_rows = len(rows_to_insert) if rows_to_insert: insert_query = f""" INSERT INTO {self.sheet} ({date_column}, {code_column}, name, {metric_column}, value) VALUES %s """ values = [ (row[date_column], row[code_column], row['name'], row[metric_column], row['value']) for row in rows_to_insert ] execute_values(self.cursor, insert_query, values, page_size=1000) self.conn.commit() self.logger.info(f"插入 {new_rows} 行新数据") if conflicts: self.logger.warning(f"发现 {len(conflicts)} 条冲突记录") if skipped_rows: self.logger.info(f"跳过 {skipped_rows} 行重复/冲突数据") return new_rows, skipped_rows, conflicts
[docs] def update_data( self, df: pd.DataFrame, date_column: str = 'datetime', code_column: str = 'code', metric_column: str = 'metric' ): """ 使用SQL UPDATE批量更新数据 对于存在的记录(相同datetime, code, metric),更新value和name 对于不存在的记录,可选择插入(upsert模式) Args: df: 待更新的DataFrame,需包含columns: datetime, code, name, metric, value date_column: 日期列名 code_column: 代码列名 metric_column: 指标列名 Returns: 更新行数 """ # 必需列 required_cols = [date_column, code_column, 'name', metric_column, 'value'] for col in required_cols: if col not in df.columns: raise ValueError(f"DataFrame缺少必需列: {col}") updated_rows = 0 # 使用 UPSERT (INSERT ... ON CONFLICT ... DO UPDATE) upsert_query = f""" INSERT INTO {self.sheet} ({date_column}, {code_column}, name, {metric_column}, value) VALUES %s ON CONFLICT ({date_column}, {code_column}, {metric_column}) DO UPDATE SET value = EXCLUDED.value, name = EXCLUDED.name """ values = [ (row[date_column], row[code_column], row['name'], row[metric_column], row['value']) for _, row in df.iterrows() ] execute_values(self.cursor, upsert_query, values) updated_rows = self.cursor.rowcount self.conn.commit() self.logger.info(f"更新/插入 {updated_rows} 行数据") return updated_rows
# ========== Wind数据抓取 ==========
[docs] def ingest_wind_daily_market( self, codes: list, start_date: str, end_date: str, fields: list = None, asset_type: str = 'stock', mode: str = 'incremental' ) -> dict: """ 从Wind获取日行情并插入数据库(薄封装) 组合流程: 1. wind_ingest.fetch_daily_market - 获取Wind数据 2. integration.incremental_insert/insert_dataframe - 插入 Args: codes: 代码列表 start_date: 开始日期,格式'YYYY-MM-DD' end_date: 结束日期 fields: 字段列表,None使用默认字段 asset_type: 资产类型,'stock', 'index', 'fund', 'bond' mode: 插入模式,'incremental'(默认)或'insert' Returns: 统计信息字典 """ self.logger.info( f"开始从Wind获取数据: codes={len(codes)}, " f"date_range={start_date}~{end_date}, asset_type={asset_type}" ) # 1. 从Wind获取数据 df = _fetch_daily_market( codes=codes, start_date=start_date, end_date=end_date, fields=fields, asset_type=asset_type, apply_time_stamps=True, logger=self.logger ) if df.empty: self.logger.warning("未获取到任何数据") return {'success': False, 'message': 'No data fetched'} # 2. 插入数据库 if mode == 'incremental': new_rows, skipped_rows = _incremental_insert( cursor=self.cursor, conn=self.conn, df=df, table=self.sheet, logger=self.logger ) return { 'success': True, 'new_rows': new_rows, 'skipped_rows': skipped_rows, 'total_fetched': len(df) } else: success, message, stats = _insert_dataframe( cursor=self.cursor, conn=self.conn, df=df, table=self.sheet, logger=self.logger ) return {'success': success, 'message': message, 'stats': stats}
[docs] def ingest_wind_daily_index(self, codes: list, start_date: str, end_date: str, fields: list = None, mode: str = 'incremental') -> dict: """Wind指数数据抓取(便捷封装)""" return self.ingest_wind_daily_market(codes, start_date, end_date, fields, 'index', mode)
[docs] def ingest_wind_daily_fund(self, codes: list, start_date: str, end_date: str, fields: list = None, mode: str = 'incremental') -> dict: """Wind基金数据抓取(便捷封装)""" return self.ingest_wind_daily_market(codes, start_date, end_date, fields, 'fund', mode)
[docs] def ingest_wind_daily_bond(self, codes: list, start_date: str, end_date: str, fields: list = None, mode: str = 'incremental') -> dict: """Wind债券数据抓取(便捷封装)""" return self.ingest_wind_daily_market(codes, start_date, end_date, fields, 'bond', mode)
# ========== 查询功能(薄封装)==========
[docs] @func_timer def run_query(self, conditions: list = None, params: list = None, select_columns: str = '*'): """ 执行自定义SQL查询(薄封装) 替代原 query_data 方法,使用 query.build_query Args: conditions: SQL条件列表,如['datetime >= %s', 'code = %s'] params: 参数列表 select_columns: 要选择的列 Returns: DataFrame """ query, params = _build_query( table_name=self.sheet, conditions=conditions, params=params, select_columns=select_columns ) self.logger.info(f"执行查询: {query}") self.cursor.execute(query, params) result = pd.DataFrame(self.cursor.fetchall()) return result
[docs] @func_timer def query_time_range( self, codes: list | None = None, start_date: str | None = None, end_date: str | None = None, metric: str | None = None, limit: int | None = None, ): """ 查询指定时间范围的数据(薄封装) 直接调用 query.query_time_range Args: codes: 代码列表 start_date: 开始日期 end_date: 结束日期 metric: 指标 limit: 最大返回行数,None表示不限制 Returns: DataFrame """ return _query_time_range( cursor=self.cursor, table_name=self.sheet, codes=codes, start_date=start_date, end_date=end_date, metric=metric, limit=limit, logger=self.logger )
[docs] @func_timer def query_nearest_after(self, params=None): """ 根据输入时间戳序列查找每个时点之后最近的有效值(薄封装) 主要用于回测时提取价格 直接调用 query.query_nearest_after Args: params (dict): 必须包含以下键: - codes: 代码列表 - datetimes: 目标时间戳列表(格式:'YYYY-MM-DD HH:MM') - metric: 查询的指标名称 - time_tolerance: 允许的最大时间间隔(单位:小时,默认不限制) Returns: DataFrame: 包含以下列: code | input_ts | datetime | diff_hours | value | name """ # 参数校验 required_keys = ['codes', 'datetimes', 'metric'] if not all(k in params for k in required_keys): raise ValueError(f"必须提供参数: {required_keys}") # 使用函数式查询 df = _query_nearest_after( cursor=self.cursor, table_name=self.sheet, codes=params['codes'], datetimes=params['datetimes'], metric=params['metric'], time_tolerance=params.get('time_tolerance'), logger=self.logger ) return df
[docs] @func_timer def query_nearest_in_range_after(self, params=None): """ 在 (start, end) 区间内查找距 start 最近的有效值(薄封装) Args: params (dict): 必须包含以下键: - codes: 代码列表 - ranges: [(start, end), ...] 区间列表 - metric: 指标名 - time_tolerance: 锚点容差(小时,可选) Returns: DataFrame: code | input_ts(=start) | datetime | diff_hours | value | name """ required_keys = ['codes', 'ranges', 'metric'] if not all(k in params for k in required_keys): raise ValueError(f"必须提供参数: {required_keys}") df = _query_nearest_in_range_after( cursor=self.cursor, table_name=self.sheet, codes=params['codes'], ranges=params['ranges'], metric=params['metric'], time_tolerance=params.get('time_tolerance'), logger=self.logger ) return df
[docs] @func_timer def query_nearest_in_range_before(self, params=None): """ 在 (start, end) 区间内查找距 end 最近的有效值(薄封装) Args: params (dict): 必须包含以下键: - codes: 代码列表 - ranges: [(start, end), ...] 区间列表 - metric: 指标名 - time_tolerance: 锚点容差(小时,可选) Returns: DataFrame: code | input_ts(=end) | datetime | diff_hours | value | name """ required_keys = ['codes', 'ranges', 'metric'] if not all(k in params for k in required_keys): raise ValueError(f"必须提供参数: {required_keys}") df = _query_nearest_in_range_before( cursor=self.cursor, table_name=self.sheet, codes=params['codes'], ranges=params['ranges'], metric=params['metric'], time_tolerance=params.get('time_tolerance'), logger=self.logger ) return df
[docs] @func_timer def query_nearest_before(self, params=None): """ 根据输入时间戳序列查找每个时点之前最近的有效值(薄封装) 主要用于回测时提取历史价格特征 直接调用 query.query_nearest_before Args: params (dict): 必须包含以下键: - codes: 代码列表 - datetimes: 目标时间戳列表(格式:'YYYY-MM-DD HH:MM') - metric: 查询的指标名称 - time_tolerance: 允许的最大时间间隔(单位:小时,默认不限制) Returns: DataFrame: 包含以下列: code | input_ts | datetime | diff_hours | value | name """ # 参数校验 required_keys = ['codes', 'datetimes', 'metric'] if not all(k in params for k in required_keys): raise ValueError(f"必须提供参数: {required_keys}") # 使用函数式查询 df = _query_nearest_before( cursor=self.cursor, table_name=self.sheet, codes=params['codes'], datetimes=params['datetimes'], metric=params['metric'], time_tolerance=params.get('time_tolerance'), logger=self.logger ) return df
[docs] def get_latest_date(self, code: str = None, metric: str = None): """ 获取数据库中的最新日期(薄封装) 直接调用 query.get_latest_date Args: code: 代码,None表示所有代码 metric: 指标,None表示所有指标 Returns: 最新日期 """ return _get_latest_date( cursor=self.cursor, table_name=self.sheet, code=code, metric=metric, logger=self.logger )
[docs] def get_available_dates( self, code: str, metric: str, start_date: str = None, end_date: str = None ): """ 获取指定代码和指标的可用日期列表(薄封装) 直接调用 query.get_available_dates Args: code: 代码 metric: 指标 start_date: 开始日期 end_date: 结束日期 Returns: 日期列表 """ return _get_available_dates( cursor=self.cursor, table_name=self.sheet, code=code, metric=metric, start_date=start_date, end_date=end_date, logger=self.logger )
# ========== 验证功能(薄封装)==========
[docs] def validate_dataframe( self, df: pd.DataFrame, validations: dict ): """ 验证和修复DataFrame(薄封装) 直接调用 validation.validate_and_fix Args: df: 待验证的DataFrame validations: 验证配置 Returns: (修复后的DataFrame, 验证报告) """ return validate_and_fix( df, validations=validations, logger=self.logger, inplace=False )
[docs] def check_excel_file( self, filepath: str, checks: dict = None ): """ 检查Excel文件中的错误(薄封装) 直接调用 excel.check_excel_errors Args: filepath: 文件路径 checks: 检查配置 Returns: (是否通过, 错误列表) """ df = read_file(filepath, logger=self.logger) return check_excel_errors(df, checks, logger=self.logger)
# ========== 表管理 ==========
[docs] def truncate_table(self): """ 清空表中所有数据 WARNING: 此操作不可逆,会删除表中所有记录 Returns: 删除的行数 """ # 先查询表中的行数 self.cursor.execute(f"SELECT COUNT(*) FROM {self.sheet}") count = self.cursor.fetchone()['count'] # 清空表 self.cursor.execute(f"TRUNCATE TABLE {self.sheet}") self.conn.commit() self.logger.warning(f"已清空表 {self.sheet},删除 {count} 行数据") return count
# ========== 连接管理 ==========
[docs] def close(self): """关闭数据库连接""" if self.cursor: self.cursor.close() if self.conn: self.conn.close() self.logger.info("数据库连接已关闭")
# ========== 辅助函数(保留)==========
[docs] def get_absolute_trade_days(begin_date, end_date, period, use_pmc=True): """ 获取交易日序列 Args: begin_date: 开始日期,字符串格式 end_date: 结束日期,字符串格式 period: 周期,如'D'(日), 'W'(周), 'M'(月), 'Q'(季), 'S'(半年), 'Y'(年) use_pmc: 默认True,使用pandas_market_calendars(中国A股/北京时区);False时使用akshare Returns: 交易日列表(datetime.datetime对象) """ import pandas as pd period_map = {"D": None, "W": "W", "M": "M", "Q": "Q", "S": "2Q", "Y": "Y"} freq = period_map.get(period.upper()) if use_pmc: import pandas_market_calendars as mcal cal = mcal.get_calendar("XSHG") schedule = cal.schedule(start_date=begin_date, end_date=end_date, tz="Asia/Shanghai") dates = pd.to_datetime(schedule.index).tz_localize(None).to_series().reset_index(drop=True) else: import akshare as ak df = ak.tool_trade_date_hist_sina() dates = pd.to_datetime(df["trade_date"]) mask = (dates >= pd.Timestamp(begin_date)) & (dates <= pd.Timestamp(end_date)) dates = dates[mask].sort_values().reset_index(drop=True) if freq: dates = dates.groupby(dates.dt.to_period(freq)).last().reset_index(drop=True) return [d.date() for d in dates]
[docs] def trade_days_offset(begin_datetime, offset, period='D'): """ 交易日偏移计算 Args: begin_datetime: 起始datetime对象 offset: 偏移量(整数) period: 周期,默认'D' Returns: 偏移后的datetime对象 """ from datetime import datetime import akshare as ak import pandas as pd df = ak.tool_trade_date_hist_sina() all_days = pd.to_datetime(df["trade_date"]).sort_values().tolist() begin_date = pd.Timestamp(begin_datetime.date()) idx = next((i for i, d in enumerate(all_days) if d >= begin_date), None) target = all_days[idx + offset] return datetime.combine(target.date(), begin_datetime.time())