Commit 58e26c91 authored by Dr.李's avatar Dr.李

return x as DataFrame instead of array

parent ce9b5c02
......@@ -101,7 +101,8 @@ def prepare_data(engine: SqlEngine,
['trade_date', 'code', 'weight', 'isOpen', 'industry_code', 'industry'] + transformer.names]
def batch_processing(x_values,
def batch_processing(names,
x_values,
y_values,
groups,
group_label,
......@@ -132,10 +133,11 @@ def batch_processing(x_values,
else:
this_risk_exp = None
train_x_buckets[end] = factor_processing(this_raw_x,
pre_process=pre_process,
risk_factors=this_risk_exp,
post_process=post_process)
train_x_buckets[end] = pd.DataFrame(factor_processing(this_raw_x,
pre_process=pre_process,
risk_factors=this_risk_exp,
post_process=post_process),
columns=names)
train_y_buckets[end] = factor_processing(this_raw_y,
pre_process=pre_process,
......@@ -163,7 +165,7 @@ def batch_processing(x_values,
inner_left_index = bisect.bisect_left(sub_dates, end)
inner_right_index = bisect.bisect_right(sub_dates, end)
predict_x_buckets[end] = ne_x[inner_left_index:inner_right_index]
predict_x_buckets[end] = pd.DataFrame(ne_x[inner_left_index:inner_right_index], columns=names)
predict_risk_buckets[end] = this_risk_exp[inner_left_index:inner_right_index]
predict_codes_bucket[end] = this_codes[inner_left_index:inner_right_index]
......@@ -198,8 +200,8 @@ def fetch_data_package(engine: SqlEngine,
pre_process: Iterable[object] = None,
post_process: Iterable[object] = None) -> dict:
alpha_logger.info("Starting data package fetching ...")
transformer = Transformer(alpha_factors)
names = transformer.names
dates, return_df, factor_df = prepare_data(engine,
transformer,
start_date,
......@@ -210,7 +212,7 @@ def fetch_data_package(engine: SqlEngine,
warm_start)
return_df, dates, date_label, risk_exp, x_values, y_values, train_x, train_y, codes = \
_merge_df(engine, transformer.names, factor_df, return_df, universe, dates, risk_model, neutralized_risk)
_merge_df(engine, names, factor_df, return_df, universe, dates, risk_model, neutralized_risk)
alpha_logger.info("data merging finished")
......@@ -226,7 +228,8 @@ def fetch_data_package(engine: SqlEngine,
alpha_logger.info("Loading data is finished")
train_x_buckets, train_y_buckets, train_risk_buckets, predict_x_buckets, predict_y_buckets, predict_risk_buckets, predict_codes_bucket \
= batch_processing(x_values,
= batch_processing(names,
x_values,
y_values,
dates,
date_label,
......@@ -239,10 +242,11 @@ def fetch_data_package(engine: SqlEngine,
alpha_logger.info("Data processing is finished")
ret = dict()
ret['x_names'] = transformer.names
ret['x_names'] = names
ret['settlement'] = return_df
ret['train'] = {'x': train_x_buckets, 'y': train_y_buckets, 'risk': train_risk_buckets}
ret['predict'] = {'x': predict_x_buckets, 'y': predict_y_buckets, 'risk': predict_risk_buckets, 'code': predict_codes_bucket}
ret['predict'] = {'x': predict_x_buckets, 'y': predict_y_buckets, 'risk': predict_risk_buckets,
'code': predict_codes_bucket}
return ret
......@@ -314,7 +318,7 @@ def fetch_train_phase(engine,
ret = dict()
ret['x_names'] = transformer.names
ret['train'] = {'x': ne_x, 'y': ne_y, 'code': this_code}
ret['train'] = {'x': pd.DataFrame(ne_x, columns=transformer.names), 'y': ne_y, 'code': this_code}
return ret
......@@ -398,7 +402,7 @@ def fetch_predict_phase(engine,
ret = dict()
ret['x_names'] = transformer.names
ret['predict'] = {'x': ne_x, 'code': codes}
ret['predict'] = {'x': pd.DataFrame(ne_x, columns=transformer.names), 'code': codes}
return ret
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment