From f703b539b6e5e606737ebe8d635d468d88674d16 Mon Sep 17 00:00:00 2001 From: Zhenghao Zhang Date: Wed, 9 Nov 2022 22:00:55 +0800 Subject: [PATCH] fix panic if zero IDF --- master/tasks.go | 24 ++++++++++-- master/tasks_test.go | 88 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 108 insertions(+), 4 deletions(-) diff --git a/master/tasks.go b/master/tasks.go index 488cf61ae..abfbaaeaa 100644 --- a/master/tasks.go +++ b/master/tasks.go @@ -286,7 +286,11 @@ func (t *FindItemNeighborsTask) run(j *task.JobsAllocator) error { t.taskMonitor.Add(TaskFindItemNeighbors, len(dataset.ItemFeedback)) // inverse document frequency of users for i := range dataset.UserFeedback { - userIDF[i] = math32.Log(float32(dataset.ItemCount()) / float32(len(dataset.UserFeedback[i]))) + if dataset.ItemCount() == len(dataset.UserFeedback[i]) { + userIDF[i] = 1 + } else { + userIDF[i] = math32.Log(float32(dataset.ItemCount()) / float32(len(dataset.UserFeedback[i]))) + } } t.taskMonitor.Add(TaskFindItemNeighbors, len(dataset.UserFeedback)) } @@ -303,7 +307,11 @@ func (t *FindItemNeighborsTask) run(j *task.JobsAllocator) error { t.taskMonitor.Add(TaskFindItemNeighbors, len(dataset.ItemLabels)) // inverse document frequency of labels for i := range labeledItems { - labelIDF[i] = math32.Log(float32(dataset.ItemCount()) / float32(len(labeledItems[i]))) + if dataset.ItemCount() == len(labeledItems[i]) { + labelIDF[i] = 1 + } else { + labelIDF[i] = math32.Log(float32(dataset.ItemCount()) / float32(len(labeledItems[i]))) + } } t.taskMonitor.Add(TaskFindItemNeighbors, len(labeledItems)) } @@ -597,7 +605,11 @@ func (t *FindUserNeighborsTask) run(j *task.JobsAllocator) error { t.taskMonitor.Add(TaskFindUserNeighbors, len(dataset.UserFeedback)) // inverse document frequency of items for i := range dataset.ItemFeedback { - itemIDF[i] = math32.Log(float32(dataset.UserCount()) / float32(len(dataset.ItemFeedback[i]))) + if dataset.UserCount() == len(dataset.ItemFeedback[i]) { + itemIDF[i] = 1 + } else { + itemIDF[i] = math32.Log(float32(dataset.UserCount()) / float32(len(dataset.ItemFeedback[i]))) + } } t.taskMonitor.Add(TaskFindUserNeighbors, len(dataset.ItemFeedback)) } @@ -614,7 +626,11 @@ func (t *FindUserNeighborsTask) run(j *task.JobsAllocator) error { t.taskMonitor.Add(TaskFindUserNeighbors, len(dataset.UserLabels)) // inverse document frequency of labels for i := range labeledUsers { - labelIDF[i] = math32.Log(float32(dataset.UserCount()) / float32(len(labeledUsers[i]))) + if dataset.UserCount() == len(labeledUsers[i]) { + labelIDF[i] = 1 + } else { + labelIDF[i] = math32.Log(float32(dataset.UserCount()) / float32(len(labeledUsers[i]))) + } } t.taskMonitor.Add(TaskFindUserNeighbors, len(labeledUsers)) } diff --git a/master/tasks_test.go b/master/tasks_test.go index 5951500a6..44e6d4c10 100644 --- a/master/tasks_test.go +++ b/master/tasks_test.go @@ -248,6 +248,50 @@ func TestMaster_FindItemNeighborsIVF(t *testing.T) { assert.Equal(t, task.StatusComplete, m.taskMonitor.Tasks[TaskFindItemNeighbors].Status) } +func TestMaster_FindItemNeighborsIVF_ZeroIDF(t *testing.T) { + // create mock master + m := newMockMaster(t) + defer m.Close() + // create config + m.Config = &config.Config{} + m.Config.Recommend.CacheSize = 3 + m.Config.Master.NumJobs = 4 + m.Config.Recommend.ItemNeighbors.EnableIndex = true + m.Config.Recommend.ItemNeighbors.IndexRecall = 1 + m.Config.Recommend.ItemNeighbors.IndexFitEpoch = 10 + + // create dataset + err := m.DataClient.BatchInsertItems([]data.Item{ + {"0", false, []string{"*"}, time.Now(), []string{"a"}, ""}, + {"1", false, []string{"*"}, time.Now(), []string{"a"}, ""}, + }) + assert.NoError(t, err) + err = m.DataClient.BatchInsertFeedback([]data.Feedback{ + {FeedbackKey: data.FeedbackKey{FeedbackType: "FeedbackType", UserId: "0", ItemId: "0"}}, + {FeedbackKey: data.FeedbackKey{FeedbackType: "FeedbackType", UserId: "0", ItemId: "1"}}, + }, true, true, true) + assert.NoError(t, err) + dataset, _, _, _, err := m.LoadDataFromDatabase(m.DataClient, []string{"FeedbackType"}, nil, 0, 0, NewOnlineEvaluator()) + assert.NoError(t, err) + m.rankingTrainSet = dataset + + // similar items (common users) + m.Config.Recommend.ItemNeighbors.NeighborType = config.NeighborTypeRelated + neighborTask := NewFindItemNeighborsTask(&m.Master) + assert.NoError(t, neighborTask.run(nil)) + similar, err := m.CacheClient.GetSorted(cache.Key(cache.ItemNeighbors, "0"), 0, 100) + assert.NoError(t, err) + assert.Equal(t, []string{"1"}, cache.RemoveScores(similar)) + + // similar items (common labels) + m.Config.Recommend.ItemNeighbors.NeighborType = config.NeighborTypeSimilar + neighborTask = NewFindItemNeighborsTask(&m.Master) + assert.NoError(t, neighborTask.run(nil)) + similar, err = m.CacheClient.GetSorted(cache.Key(cache.ItemNeighbors, "0"), 0, 100) + assert.NoError(t, err) + assert.Equal(t, []string{"1"}, cache.RemoveScores(similar)) +} + func TestMaster_FindUserNeighborsBruteForce(t *testing.T) { // create mock master m := newMockMaster(t) @@ -421,6 +465,50 @@ func TestMaster_FindUserNeighborsIVF(t *testing.T) { assert.Equal(t, task.StatusComplete, m.taskMonitor.Tasks[TaskFindUserNeighbors].Status) } +func TestMaster_FindUserNeighborsIVF_ZeroIDF(t *testing.T) { + // create mock master + m := newMockMaster(t) + defer m.Close() + // create config + m.Config = &config.Config{} + m.Config.Recommend.CacheSize = 3 + m.Config.Master.NumJobs = 4 + m.Config.Recommend.UserNeighbors.EnableIndex = true + m.Config.Recommend.UserNeighbors.IndexRecall = 1 + m.Config.Recommend.UserNeighbors.IndexFitEpoch = 10 + + // create dataset + err := m.DataClient.BatchInsertUsers([]data.User{ + {"0", []string{"a"}, nil, ""}, + {"1", []string{"a"}, nil, ""}, + }) + assert.NoError(t, err) + err = m.DataClient.BatchInsertFeedback([]data.Feedback{ + {FeedbackKey: data.FeedbackKey{FeedbackType: "FeedbackType", UserId: "0", ItemId: "0"}}, + {FeedbackKey: data.FeedbackKey{FeedbackType: "FeedbackType", UserId: "1", ItemId: "0"}}, + }, true, true, true) + assert.NoError(t, err) + dataset, _, _, _, err := m.LoadDataFromDatabase(m.DataClient, []string{"FeedbackType"}, nil, 0, 0, NewOnlineEvaluator()) + assert.NoError(t, err) + m.rankingTrainSet = dataset + + // similar users (common items) + m.Config.Recommend.UserNeighbors.NeighborType = config.NeighborTypeRelated + neighborTask := NewFindUserNeighborsTask(&m.Master) + assert.NoError(t, neighborTask.run(nil)) + similar, err := m.CacheClient.GetSorted(cache.Key(cache.UserNeighbors, "0"), 0, 100) + assert.NoError(t, err) + assert.Equal(t, []string{"1"}, cache.RemoveScores(similar)) + + // similar users (common labels) + m.Config.Recommend.UserNeighbors.NeighborType = config.NeighborTypeSimilar + neighborTask = NewFindUserNeighborsTask(&m.Master) + assert.NoError(t, neighborTask.run(nil)) + similar, err = m.CacheClient.GetSorted(cache.Key(cache.UserNeighbors, "0"), 0, 100) + assert.NoError(t, err) + assert.Equal(t, []string{"1"}, cache.RemoveScores(similar)) +} + func TestMaster_LoadDataFromDatabase(t *testing.T) { // create mock master m := newMockMaster(t)