# Source code for betalens.datafeed.industry

#%%By Janis 260602
"""
行业归属查询模块(函数式)

设计要点
--------
行业归属是 point-in-time 的"状态"数据:某公司从某日起属于某行业,直到下次变更。
因此不另造存储模型,而是复用现有长格式时序表,约定:
    - metric : 分类体系名,如 '申万一级行业' / '申万二级行业' / '中信一级行业'
    - value  : 行业代码的数值部分(如 801780),便于数值索引与分组
    - remark : JSONB,存行业名等文本,约定 {"ind_name", "ind_code", "scheme", "level"}
    - datetime: 该归属关系的生效时点(最早可知日)

查询语义 = 取 datetime <= 查询日 的最近一条,与 query.query_nearest_before 同构,
天然避免前视偏差。本模块额外把 remark(JSONB) 解析出来返回行业名。

主要接口
--------
    query_industry        : 正查——某公司在某日所属行业
    get_industry_members  : 反查——某日某行业的成分股
    build_industry_records: 入库辅助——把 (code,name,生效日,行业) 整理成长格式
"""

import itertools
import logging
from typing import Optional, List, Tuple, Union

import pandas as pd


DEFAULT_TABLE = 'industry'


def _get_default_logger():
    """Return the module's 'IndustryQuery' logger, configuring it on first use.

    The logger is given INFO level and a single INFO-level stream handler
    with a timestamped format. If the logger already has any handler, it is
    returned untouched so repeated calls never stack handlers.
    """
    log = logging.getLogger('IndustryQuery')
    if log.handlers:
        # Already configured (by an earlier call or by the application).
        return log

    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    stream = logging.StreamHandler()
    stream.setLevel(logging.INFO)
    stream.setFormatter(formatter)

    log.setLevel(logging.INFO)
    log.addHandler(stream)
    return log


def _explode_remark(df: pd.DataFrame) -> pd.DataFrame:
    """Expand the ``remark`` column into ind_name / ind_code / scheme columns.

    The frame is modified in place and also returned. Each ``remark`` cell
    that is a ``dict`` contributes its value for the matching key; any other
    cell (None, string, ...) yields None. When the frame is empty or has no
    ``remark`` column, the three columns are still added, filled with None.
    """
    fields = ('ind_name', 'ind_code', 'scheme')

    if df.empty or 'remark' not in df.columns:
        for field in fields:
            df[field] = None
        return df

    def _picker(key):
        # Build a per-key extractor that tolerates non-dict remark cells.
        def _pick(cell):
            if isinstance(cell, dict):
                return cell.get(key)
            return None
        return _pick

    for field in fields:
        df[field] = df['remark'].apply(_picker(field))
    return df


def _scheme_clause(scheme: str, exact: bool, col: str = 't.metric') -> Tuple[str, str]:
    """Build the SQL fragment and bound parameter that match the metric column.

    Version-agnostic lookup: when ``scheme`` carries no version suffix (e.g.
    '申万一级行业') a prefix match is used, so every versioned metric such as
    '申万一级行业(2021)' is covered. Combined with the callers' "latest row
    with datetime <= query date" ordering, the row picked is the one in
    effect on the query date, without hard-coding version boundaries.

    When ``scheme`` already includes a version suffix the prefix match
    degenerates to that single version. ``exact=True`` forces equality.

    Args:
        scheme: Classification scheme name (the metric value to match).
        exact: If True, emit an equality test instead of a prefix match.
        col: Reference to the metric column, optionally alias-qualified
            (e.g. 't.metric' or 'metric').

    Returns:
        ``(sql_fragment, parameter)``. The fragment is either
        ``'{col} = %s'`` or ``"{col} LIKE %s ESCAPE '\\'"``; the parameter
        is the scheme itself or the escaped scheme followed by ``%``.
    """
    if not exact:
        # Escape LIKE metacharacters in one pass so a literal '%', '_' or
        # backslash in the scheme name cannot act as a wildcard.
        like_escapes = str.maketrans({'\\': '\\\\', '%': '\\%', '_': '\\_'})
        pattern = scheme.translate(like_escapes) + '%'
        return f"{col} LIKE %s ESCAPE '\\'", pattern
    return f'{col} = %s', scheme


def query_industry(
    cursor,
    codes: List[str],
    dates: Union[str, List[str]],
    scheme: str = '申万一级行业',
    table_name: str = DEFAULT_TABLE,
    exact: bool = False,
    logger: Optional[logging.Logger] = None,
) -> pd.DataFrame:
    """Forward lookup: the industry each (code, date) pair belongs to.

    Point-in-time semantics: for every pair, the row with the greatest
    ``datetime`` that is ``<= date`` is returned.

    Args:
        cursor: DB-API cursor. A dict-returning cursor (e.g. psycopg2's
            RealDictCursor) is recommended; for positional-row cursors the
            column names are recovered from ``cursor.description``.
        codes: Security codes. A single string is treated as one code.
        dates: One query date or a list of them, formatted 'YYYY-MM-DD' or
            'YYYY-MM-DD HH:MM:SS'.
        scheme: Classification scheme (the metric). Without a version suffix
            (e.g. '申万一级行业') all versions are matched and the latest row
            decides; with a suffix (e.g. '申万一级行业(2021)') only that
            version is queried.
        table_name: Table to query, default 'industry'. It is interpolated
            into the SQL text as-is (not a bound parameter), so it must come
            from trusted code, never from user input.
        exact: Force exact metric matching (disables the prefix match).
        logger: Logger to use; the module default is used when None.

    Returns:
        DataFrame with columns
        code | query_date | effective_dt | sec_name | industry_value |
        remark | ind_name | ind_code | scheme.
        Pairs without any membership record keep a row (LEFT JOIN) whose
        industry fields are NaN/None.

    Raises:
        ValueError: If ``codes`` or ``dates`` is empty.
    """
    if logger is None:
        logger = _get_default_logger()

    # Normalise both inputs to concrete lists: a bare string must count as a
    # single item (not be iterated character by character), and len() below
    # needs a sized container.
    if isinstance(codes, str):
        codes = [codes]
    elif codes is not None:
        codes = list(codes)
    if not codes:
        raise ValueError("codes不能为空")

    if isinstance(dates, str):
        dates = [dates]
    elif dates is not None:
        dates = list(dates)
    if not dates:
        raise ValueError("dates不能为空")

    pairs = list(itertools.product(codes, dates))
    value_ph = ', '.join(['(%s, %s::TIMESTAMP)'] * len(pairs))
    metric_clause, metric_param = _scheme_clause(scheme, exact, col='t.metric')

    sql = f"""
    WITH input_data (code, q_date) AS (
        VALUES {value_ph}
    ),
    cand AS (
        SELECT
            i.code,
            i.q_date,
            t.datetime AS effective_dt,
            t.name AS sec_name,
            t.value AS industry_value,
            t.remark AS remark,
            ROW_NUMBER() OVER (
                PARTITION BY i.code, i.q_date
                ORDER BY t.datetime DESC
            ) AS rn
        FROM input_data i
        LEFT JOIN {table_name} t
            ON i.code = t.code
            AND {metric_clause}
            AND t.datetime <= i.q_date
    )
    SELECT code, q_date AS query_date, effective_dt, sec_name, industry_value, remark
    FROM cand
    WHERE rn = 1
    """

    # Placeholder order in the SQL text: all (code, date) pairs first, then
    # the metric parameter inside the JOIN condition.
    params: List = [*itertools.chain.from_iterable(pairs), metric_param]

    logger.info(f"query_industry: {len(codes)}代码 × {len(dates)}日期, scheme={scheme}")
    cursor.execute(sql, params)

    rows = cursor.fetchall()
    description = getattr(cursor, 'description', None)
    if rows and description and not isinstance(rows[0], dict):
        # Positional rows carry no column names; take them from the cursor
        # so that 'remark' can be found and expanded.
        df = pd.DataFrame(rows, columns=[col[0] for col in description])
    else:
        df = pd.DataFrame(rows)

    df = _explode_remark(df)
    logger.info(f"返回 {len(df)} 条")
    return df
def get_industry_members(
    cursor,
    industry: Union[str, int, float],
    date: str,
    scheme: str = '申万一级行业',
    table_name: str = DEFAULT_TABLE,
    by: str = 'name',
    exact: bool = False,
    logger: Optional[logging.Logger] = None,
) -> pd.DataFrame:
    """Reverse lookup: the constituents of an industry on a given date.

    For every security the latest membership row with ``datetime <= date``
    is selected first; only then is the target industry filtered, so a
    security that has since moved to another industry is not returned.

    Args:
        cursor: DB-API cursor. For positional-row cursors the column names
            are recovered from ``cursor.description``.
        industry: Target industry, either its name (str, compared with
            ``remark->>'ind_name'``) or its numeric code (int/float,
            compared with ``value``).
        date: Query date.
        scheme: Classification scheme (the metric). Without a version suffix
            all versions are matched and the latest row decides; with a
            suffix only that version is queried.
        table_name: Table to query. It is interpolated into the SQL text
            as-is (not a bound parameter), so it must come from trusted
            code, never from user input.
        by: 'name' matches on the industry name, 'value' on the numeric
            code. A numeric ``industry`` always selects value matching,
            whatever ``by`` says.
        exact: Force exact metric matching (disables the prefix match).
        logger: Logger to use; the module default is used when None.

    Returns:
        DataFrame with columns
        code | sec_name | industry_value | remark | ind_name | ind_code |
        scheme, ordered by code. An empty result carries the same columns.
    """
    if logger is None:
        logger = _get_default_logger()

    use_value = (by == 'value') or isinstance(industry, (int, float))
    if use_value:
        match_cond = "WHERE rn = 1 AND industry_value = %s"
        match_param: Union[str, int, float] = industry
    else:
        match_cond = "WHERE rn = 1 AND (remark->>'ind_name') = %s"
        match_param = str(industry)

    metric_clause, metric_param = _scheme_clause(scheme, exact, col='metric')

    sql = f"""
    WITH latest AS (
        SELECT
            code,
            name AS sec_name,
            value AS industry_value,
            remark,
            datetime,
            ROW_NUMBER() OVER (
                PARTITION BY code
                ORDER BY datetime DESC
            ) AS rn
        FROM {table_name}
        WHERE {metric_clause}
          AND datetime <= %s::TIMESTAMP
    )
    SELECT code, sec_name, industry_value, remark
    FROM latest
    {match_cond}
    ORDER BY code
    """
    # Order follows the placeholders in the SQL text: metric, date, industry.
    params = [metric_param, date, match_param]

    logger.info(f"get_industry_members: {scheme}={industry} @ {date}")
    cursor.execute(sql, params)

    rows = cursor.fetchall()
    description = getattr(cursor, 'description', None)
    if not rows:
        # Keep the documented schema even when the industry has no members,
        # so callers can select columns without checking for emptiness.
        df = pd.DataFrame(columns=['code', 'sec_name', 'industry_value', 'remark'])
    elif description and not isinstance(rows[0], dict):
        # Positional rows carry no column names; take them from the cursor
        # so that 'remark' can be found and expanded.
        df = pd.DataFrame(rows, columns=[col[0] for col in description])
    else:
        df = pd.DataFrame(rows)

    df = _explode_remark(df)
    logger.info(f"成分股 {len(df)} 只")
    return df
def build_industry_records(
    df: pd.DataFrame,
    scheme: str = '申万一级行业',
    code_col: str = 'code',
    name_col: str = 'name',
    date_col: str = 'effective_dt',
    ind_name_col: str = 'ind_name',
    ind_code_col: Optional[str] = 'ind_code',
) -> pd.DataFrame:
    """Reshape industry-membership details into the long storage format.

    Each input row is one membership event (security, effective date,
    industry). The output has the columns
    datetime, code, name, metric (= ``scheme``), value (= numeric part of
    the industry code) and remark (a dict).

    Args:
        df: Detail frame, one row per membership event.
        scheme: Classification scheme; written to ``metric`` and into each
            remark dict.
        code_col: Column holding the security code.
        name_col: Column holding the security name.
        date_col: Column holding the effective date (parsed with
            ``pd.to_datetime``).
        ind_name_col: Column holding the industry name.
        ind_code_col: Column holding the industry code (e.g. '801780.SI').
            If None, or if the column is absent from ``df``, ``value`` and
            the remark's ``ind_code`` are left empty.

    Returns:
        Long-format DataFrame (datetime, code, name, metric, value, remark)
        sharing the index of ``df``.
    """
    import re

    first_digits = re.compile(r'\d+')

    def _to_num(raw):
        # Numeric part of an industry code: '801780.SI' -> 801780.
        if raw is None or (isinstance(raw, float) and pd.isna(raw)):
            return None
        found = first_digits.search(str(raw))
        return int(found.group()) if found else None

    def _text_or_none(item):
        # Missing values stay None instead of becoming the string 'nan'.
        return None if pd.isna(item) else str(item)

    out = pd.DataFrame()
    out['datetime'] = pd.to_datetime(df[date_col])
    out['code'] = df[code_col].astype(str)
    out['name'] = df[name_col].astype(str)
    out['metric'] = scheme

    if ind_code_col and ind_code_col in df.columns:
        ind_codes = df[ind_code_col]
        out['value'] = ind_codes.apply(_to_num)
    else:
        ind_codes = pd.Series([None] * len(df), index=df.index)
        out['value'] = None

    out['remark'] = [
        {
            'ind_name': _text_or_none(ind_name),
            'ind_code': _text_or_none(ind_code),
            'scheme': scheme,
        }
        for ind_name, ind_code in zip(df[ind_name_col], ind_codes)
    ]
    return out