Skip to content
This repository was archived by the owner on Feb 2, 2024. It is now read-only.

Commit e0aa489

Browse files
author
Ehsan Totoni
committed
lr rand fix format, Y type
1 parent f8d9d3d commit e0aa489

1 file changed

Lines changed: 10 additions & 8 deletions

File tree

examples/logistic_regression_rand.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,19 @@
44
@hpat.jit
55
def logistic_regression(iterations):
66
print("generating random data...")
7-
N = 10**8
7+
N = 10**3
88
D = 10
9-
g = 2*np.random.ranf(D)-1
10-
X = 2*np.random.ranf((N,D))-1
11-
Y = (np.dot(X,g)>0.0)==(np.random.ranf(N)>.90)
9+
g = 2 * np.random.ranf(D) - 1
10+
X = 2 * np.random.ranf((N, D)) - 1
11+
Y = ((np.dot(X, g) > 0.0) == (np.random.ranf(N) > .90)) + .0
1212

13-
w = 2*np.random.ranf(D)-1
13+
w = 2 * np.random.ranf(D) - 1
1414
for i in range(iterations):
15-
w -= np.dot(((1.0 / (1.0 + np.exp(-Y * np.dot(X,w))) - 1.0) * Y), X)
16-
R = np.dot(X,w)>0.0
17-
accuracy = np.sum(R==Y)/N
15+
w -= np.dot(((1.0 / (1.0 + np.exp(-Y * np.dot(X, w))) - 1.0) * Y), X)
16+
R = np.dot(X, w) > 0.0
17+
accuracy = np.sum(R == Y) / N
1818
print(accuracy, w)
1919

20+
return w
21+
2022
w = logistic_regression(2000)

0 commit comments

Comments
 (0)