Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Sign in
Toggle navigation
A
alpha-mind
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Dr.李
alpha-mind
Commits
1e8f9e1b
Commit
1e8f9e1b
authored
Jan 05, 2018
by
Dr.李
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
update tree models
parent
88f50881
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
32 additions
and
13 deletions
+32
-13
treemodel.py
alphamind/model/treemodel.py
+32
-13
No files found.
alphamind/model/treemodel.py
View file @
1e8f9e1b
...
@@ -9,8 +9,9 @@ from typing import List
...
@@ -9,8 +9,9 @@ from typing import List
import
numpy
as
np
import
numpy
as
np
from
distutils.version
import
LooseVersion
from
distutils.version
import
LooseVersion
from
sklearn
import
__version__
as
sklearn_version
from
sklearn
import
__version__
as
sklearn_version
from
xgboost
import
__version__
as
xgbboot_version
from
sklearn.ensemble
import
RandomForestRegressor
as
RandomForestRegressorImpl
from
sklearn.ensemble
import
RandomForestRegressor
as
RandomForestRegressorImpl
#
from xgboost import XGBRegressor as XGBRegressorImpl
from
xgboost
import
XGBRegressor
as
XGBRegressorImpl
from
alphamind.model.modelbase
import
ModelBase
from
alphamind.model.modelbase
import
ModelBase
from
alphamind.utilities
import
alpha_logger
from
alphamind.utilities
import
alpha_logger
...
@@ -28,6 +29,7 @@ class RandomForestRegressor(ModelBase):
...
@@ -28,6 +29,7 @@ class RandomForestRegressor(ModelBase):
def
save
(
self
)
->
dict
:
def
save
(
self
)
->
dict
:
model_desc
=
super
()
.
save
()
model_desc
=
super
()
.
save
()
model_desc
[
'sklearn_version'
]
=
sklearn_version
model_desc
[
'sklearn_version'
]
=
sklearn_version
return
model_desc
@
classmethod
@
classmethod
def
load
(
cls
,
model_desc
:
dict
):
def
load
(
cls
,
model_desc
:
dict
):
...
@@ -40,17 +42,34 @@ class RandomForestRegressor(ModelBase):
...
@@ -40,17 +42,34 @@ class RandomForestRegressor(ModelBase):
return
obj_layout
return
obj_layout
# class XGBRegressor(ModelBase):
class
XGBRegressor
(
ModelBase
):
#
# def __init__(self,
def
__init__
(
self
,
# n_estimators: int=100,
n_estimators
:
int
=
100
,
# learning_rate: float=0.1,
learning_rate
:
float
=
0.1
,
# max_depth: int=3,
max_depth
:
int
=
3
,
# features: List=None, **kwargs):
features
:
List
=
None
,
**
kwargs
):
# super().__init__(features)
super
()
.
__init__
(
features
)
# self.impl = XGBRegressorImpl(n_estimators=n_estimators,
self
.
impl
=
XGBRegressorImpl
(
n_estimators
=
n_estimators
,
# learning_rate=learning_rate,
learning_rate
=
learning_rate
,
# max_depth=max_depth,
max_depth
=
max_depth
,
# **kwargs)
**
kwargs
)
def
save
(
self
)
->
dict
:
model_desc
=
super
()
.
save
()
model_desc
[
'xgbboot_version'
]
=
xgbboot_version
return
model_desc
@
classmethod
def
load
(
cls
,
model_desc
:
dict
):
obj_layout
=
super
()
.
load
(
model_desc
)
if
LooseVersion
(
sklearn_version
)
<
LooseVersion
(
model_desc
[
'xgbboot_version'
]):
alpha_logger
.
warning
(
'Current xgboost version {0} is lower than the model version {1}. '
'Loaded model may work incorrectly.'
.
format
(
xgbboot_version
,
model_desc
[
'xgbboot_version'
]))
return
obj_layout
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment