Commit 88651c5e authored by Dr.李's avatar Dr.李

FEATURE: added fetch industry

parent 04962c48
......@@ -25,7 +25,8 @@ from sqlalchemy import (
from PyFin.api import advanceDateByCalendar
from alphamind.data.dbmodel.models_rl import (
Market
Market,
Industry
)
if "DB_VENDOR" in os.environ and os.environ["DB_VENDOR"].lower() == "rl":
......@@ -35,6 +36,7 @@ else:
from alphamind.data.engines.universe import Universe
from alphamind.data.processing import factor_processing
from alphamind.data.engines.utilities import _map_industry_category
DAILY_RETURN_OFFSET = 0
......@@ -161,6 +163,31 @@ class SqlEngine:
dates: Iterable[str] = None) -> pd.DataFrame:
return universe.query(self, start_date, end_date, dates)
def fetch_industry(self,
ref_date: str,
codes: Iterable[int] = None,
category: str = 'sw',
level: int = 1):
code_name = 'industry_code' + str(level)
category_name = 'industry_name' + str(level)
cond = and_(
Industry.trade_date == ref_date,
Industry.code.in_(codes),
Industry.flag == 1
) if codes else and_(
Industry.trade_date == ref_date,
Industry.flag == 1
)
query = select([Industry.code.label("code"),
getattr(Industry, code_name).label('industry_code'),
getattr(Industry, category_name).label('industry')]).where(
cond
).distinct()
return pd.read_sql(query, self.engine).dropna().drop_duplicates(['code'])
if __name__ == "__main__":
db_url = "mysql+mysqldb://reader:Reader#2020@121.37.138.1:13317/vision?charset=utf8"
......@@ -175,3 +202,5 @@ if __name__ == "__main__":
print(df)
df = sql_engine.fetch_dx_return_range(universe, start_date=start_date, end_date=end_date)
print(df)
df = sql_engine.fetch_industry(ref_date="2020-10-09", codes=["2010031963"])
print(df)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment