Skip to content

Commit

Permalink
fix session recommender (intel-analytics#1558)
Browse files Browse the repository at this point in the history
  • Loading branch information
songhappy authored Aug 8, 2019
1 parent f1a7940 commit 54943c0
Showing 1 changed file with 16 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class SessionRecommenderSpec extends ZooSpecHelper {

val itemCount = 100
val sessionLength = 10
val model = SessionRecommender[Float](itemCount, sessionLength, includeHistory = false)
val model = SessionRecommender[Float](itemCount, sessionLength = sessionLength)
val ran = new Random(42L)
val data = (1 to 100).map { x =>
val items: Seq[Float] = for (i <- 1 to sessionLength) yield
Expand All @@ -69,7 +69,7 @@ class SessionRecommenderSpec extends ZooSpecHelper {
val itemCount = 100
val sessionLength = 10
val historyLength = 5
val model = SessionRecommender[Float](itemCount, sessionLength,
val model = SessionRecommender[Float](itemCount, sessionLength = sessionLength,
includeHistory = true, historyLength = historyLength)
val ran = new Random(42L)
val data = (1 to 100).map { x =>
Expand All @@ -91,26 +91,10 @@ class SessionRecommenderSpec extends ZooSpecHelper {
val itemCount = 100
val sessionLength = 10
val historyLength = 5
val model = SessionRecommender[Float](itemCount, sessionLength,
val model = SessionRecommender[Float](itemCount, sessionLength = sessionLength,
includeHistory = true, historyLength = historyLength)
val ran = new Random(42L)
val data1: RDD[Sample[Float]] = sc.parallelize(1 to 100)
.map { x =>
val items1: Seq[Float] = for (i <- 1 to sessionLength) yield ran.nextInt(itemCount).toFloat
val items2: Seq[Float] = for (i <- 1 to historyLength) yield ran.nextInt(itemCount).toFloat
val input1 = Tensor(items1.toArray, Array(sessionLength))
val input2 = Tensor(items2.toArray, Array(historyLength))
Sample[Float](Array(input1, input2))
}

val recommedations1 = model.recommendForSession(data1, 3, zeroBasedLabel = false)
recommedations1.take(10)
.map { x =>
assert(x.size == 3)
assert(x(0)._2 >= x(1)._2)
}

val data2: Array[Sample[Float]] = (1 to 10)
val data1: Array[Sample[Float]] = (1 to 10)
.map { x =>
val items1: Seq[Float] = for (i <- 1 to sessionLength) yield ran.nextInt(itemCount).toFloat
val items2: Seq[Float] = for (i <- 1 to historyLength) yield ran.nextInt(itemCount).toFloat
Expand All @@ -119,19 +103,26 @@ class SessionRecommenderSpec extends ZooSpecHelper {
Sample[Float](Array(input1, input2))
}.toArray

val recommedations2 = model.recommendForSession(data2, 4, zeroBasedLabel = false)
recommedations2.map { x =>
val recommedations1 = model.recommendForSession(data1, 4, zeroBasedLabel = false)
recommedations1.map { x =>
assert(x.size == 4)
assert(x(0)._2 >= x(1)._2)
}

val data2: RDD[Sample[Float]] = sc.parallelize(data1)
val recommedations2 = model.recommendForSession(data2, 3, zeroBasedLabel = false)
recommedations2.take(10).map { x =>
assert(x.size == 3)
assert(x(0)._2 >= x(1)._2)
}
}

"SessionRecommender compile and fit" should "work properly" in {

val itemCount = 100
val sessionLength = 10
val historyLength = 5
val model = SessionRecommender[Float](itemCount, sessionLength,
val model = SessionRecommender[Float](itemCount, 10, sessionLength = sessionLength,
includeHistory = true, historyLength = historyLength)
val ran = new Random(42L)
val data1 = sc.parallelize(1 to 100)
Expand All @@ -144,6 +135,7 @@ class SessionRecommenderSpec extends ZooSpecHelper {
Sample(Array(input1, input2), Array(label))
}
model.compile(optimizer = "rmsprop", loss = "sparse_categorical_crossentropy")
model.summary()
model.fit(data1, nbEpoch = 1)
}
}
Expand All @@ -153,7 +145,7 @@ class SessionRecommenderSerialTest extends ModuleSerializationTest {
val ran = new Random(42L)
val itemCount = 100
val sessionLength = 10
val model = SessionRecommender[Float](100, 10, includeHistory = false)
val model = SessionRecommender[Float](100, sessionLength = 10)
val items: Seq[Float] = for (i <- 1 to sessionLength) yield
ran.nextInt(itemCount - 1).toFloat + 1
val data = Tensor(items.toArray, Array(sessionLength)).resize(1, sessionLength)
Expand Down

0 comments on commit 54943c0

Please sign in to comment.