# run-spark-ex.py
# PySpark logistic-regression example (64 lines, 1.91 KB).
# Add Spark Python Files to Python Path
# Bootstraps a locally-compiled Spark install before importing pyspark:
# points SPARK_HOME at the install, forces loopback networking, and puts
# Spark's bundled python/ directory on sys.path so `import pyspark` resolves.
import sys
import os
SPARK_HOME = "/opt/bitnami/spark" # Set this to wherever you have compiled Spark
os.environ["SPARK_HOME"] = SPARK_HOME # Add Spark path
os.environ["SPARK_LOCAL_IP"] = "127.0.0.1" # Set Local IP (avoids binding to an external interface)
sys.path.append( SPARK_HOME + "/python") # Add python files to Python Path
# NOTE: these imports only succeed after the sys.path line above runs.
import pyspark
from pyspark.mllib.classification import LogisticRegressionWithSGD
import numpy as np
from pyspark import SparkConf, SparkContext
from pyspark.mllib.regression import LabeledPoint # added
def getSparkContext():
    """Build and return a SparkContext configured for this example.

    Runs against the in-process local master with a 1 GB executor and a
    descriptive application name for the Spark UI.
    """
    config = SparkConf()
    config.setMaster("local")                  # run on local
    config.setAppName("Logistic Regression")   # Name of App
    config.set("spark.executor.memory", "1g")  # Set 1 gig of memory
    return SparkContext(conf=config)
def mapper(line):
    """Parse one CSV line into a LabeledPoint for MLlib training.

    The label is the LAST comma-separated field on the line; every
    preceding field is interpreted as a float feature.

    Args:
        line: one record, e.g. "0.5,1.2,3.4,1" (features..., label).

    Returns:
        LabeledPoint(label, features) suitable for
        LogisticRegressionWithSGD.train.
    """
    fields = line.strip().split(",")
    # Convert the label explicitly: the driver later does int(point.label),
    # so don't rely on pyspark's implicit coercion of a raw string.
    label = float(fields[-1])
    features = [float(value) for value in fields[:-1]]  # need floats
    return LabeledPoint(label, features)
# ---- Driver ---------------------------------------------------------------
# NOTE(review): getSparkContext() above is bypassed in favor of a default
# context; confirm which configuration is intended.
sc = pyspark.SparkContext()
try:
    # Load the CSV training data and parse each line into a LabeledPoint.
    data = sc.textFile("./data")
    parsedData = data.map(mapper)

    # Train a logistic-regression model via stochastic gradient descent.
    model = LogisticRegressionWithSGD.train(parsedData)

    # Pair each point's true label (first elem) with the model's
    # prediction (second elem).
    labelsAndPreds = parsedData.map(lambda point: (int(point.label),
                                                   model.predict(point.features)))

    # Training error = fraction of training points that are misclassified.
    trainErr = labelsAndPreds.filter(lambda v: v[0] != v[1]).count() / float(parsedData.count())
    print("Training Error = " + str(trainErr))
finally:
    # Always release the SparkContext, even if loading or training fails.
    sc.stop()