Commit 88651c5e authored by Dr.李's avatar Dr.李

FEATURE: added fetch industry

parent 04962c48
...@@ -25,7 +25,8 @@ from sqlalchemy import ( ...@@ -25,7 +25,8 @@ from sqlalchemy import (
from PyFin.api import advanceDateByCalendar from PyFin.api import advanceDateByCalendar
from alphamind.data.dbmodel.models_rl import ( from alphamind.data.dbmodel.models_rl import (
Market Market,
Industry
) )
if "DB_VENDOR" in os.environ and os.environ["DB_VENDOR"].lower() == "rl": if "DB_VENDOR" in os.environ and os.environ["DB_VENDOR"].lower() == "rl":
...@@ -35,6 +36,7 @@ else: ...@@ -35,6 +36,7 @@ else:
from alphamind.data.engines.universe import Universe from alphamind.data.engines.universe import Universe
from alphamind.data.processing import factor_processing from alphamind.data.processing import factor_processing
from alphamind.data.engines.utilities import _map_industry_category
DAILY_RETURN_OFFSET = 0 DAILY_RETURN_OFFSET = 0
...@@ -161,6 +163,31 @@ class SqlEngine: ...@@ -161,6 +163,31 @@ class SqlEngine:
dates: Iterable[str] = None) -> pd.DataFrame: dates: Iterable[str] = None) -> pd.DataFrame:
return universe.query(self, start_date, end_date, dates) return universe.query(self, start_date, end_date, dates)
def fetch_industry(self,
ref_date: str,
codes: Iterable[int] = None,
category: str = 'sw',
level: int = 1):
code_name = 'industry_code' + str(level)
category_name = 'industry_name' + str(level)
cond = and_(
Industry.trade_date == ref_date,
Industry.code.in_(codes),
Industry.flag == 1
) if codes else and_(
Industry.trade_date == ref_date,
Industry.flag == 1
)
query = select([Industry.code.label("code"),
getattr(Industry, code_name).label('industry_code'),
getattr(Industry, category_name).label('industry')]).where(
cond
).distinct()
return pd.read_sql(query, self.engine).dropna().drop_duplicates(['code'])
if __name__ == "__main__": if __name__ == "__main__":
db_url = "mysql+mysqldb://reader:Reader#2020@121.37.138.1:13317/vision?charset=utf8" db_url = "mysql+mysqldb://reader:Reader#2020@121.37.138.1:13317/vision?charset=utf8"
...@@ -175,3 +202,5 @@ if __name__ == "__main__": ...@@ -175,3 +202,5 @@ if __name__ == "__main__":
print(df) print(df)
df = sql_engine.fetch_dx_return_range(universe, start_date=start_date, end_date=end_date) df = sql_engine.fetch_dx_return_range(universe, start_date=start_date, end_date=end_date)
print(df) print(df)
df = sql_engine.fetch_industry(ref_date="2020-10-09", codes=["2010031963"])
print(df)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment