diff --git a/tests/python/unittest/test_auto_scheduler_feature.py b/tests/python/unittest/test_auto_scheduler_feature.py index dcfebe3d4f142..dae6a0499b991 100644 --- a/tests/python/unittest/test_auto_scheduler_feature.py +++ b/tests/python/unittest/test_auto_scheduler_feature.py @@ -116,14 +116,13 @@ def test_gpu_feature(): with tempfile.NamedTemporaryFile(mode='w') as f: f.write(json_records) f.flush() - inputs, results = auto_scheduler.LogReader(f.name).read_lines() + inputs, results = auto_scheduler.RecordReader(f.name).read_lines() inp = inputs[0] - dag = auto_scheduler.workload_key_to_dag(inp.task.workload_key) - task = auto_scheduler.SearchTask(dag, inp.task.workload_key, inp.task.target, None, auto_scheduler.HardwareParams(100000, 16, 64, 4, 64)) + dag = auto_scheduler.ComputeDAG(inp.task.workload_key) + task = auto_scheduler.SearchTask(dag, inp.task.workload_key, inp.task.target, None, auto_scheduler.HardwareParams(100000, 16, 64)) - state = auto_scheduler.serialization.get_states_from_measure_inputs(inputs, task)[0] - state = dag.infer_bound_from_state(state) + state = dag.infer_bound_from_state(inputs[0].state) fea = auto_scheduler.feature.get_per_store_features_from_states([state], task)[0] names = auto_scheduler.feature.get_per_store_feature_names()