Commit cc4f39b3 authored by Dr.李's avatar Dr.李

added test for grouped neutralize

parent 140b4f7c
......@@ -27,6 +27,21 @@ class TestNeutralize(unittest.TestCase):
np.testing.assert_array_almost_equal(calc_res, exp_res)
def test_neutralize_with_group(self):
y = np.random.randn(3000, 4)
x = np.random.randn(3000, 10)
groups = np.random.randint(30, size=3000)
calc_res = neutralize(x, y, groups)
model = LinearRegression(fit_intercept=False)
for i in range(30):
curr_x = x[groups == i]
curr_y = y[groups == i]
model.fit(curr_x, curr_y)
exp_res = curr_y - curr_x @ model.coef_.T
np.testing.assert_array_almost_equal(calc_res[groups ==i ], exp_res)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment