Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AutoScheduler] Separate shapes from DAG hash and enable schedule sharing #7317

Merged
merged 7 commits into from
Jan 25, 2021

Conversation

comaniac
Copy link
Contributor

@comaniac comaniac commented Jan 20, 2021

In this PR, we attempt to enable schedule sharing as a workaround before the dynamic shape support is fully landed. The idea is that if we have a schedule for batch size 1, then it is actually applicable to all other batch sizes (regardless the performance). This is useful when we only tune the workload with batch size 1 but wish to use it for all batch sizes to at least make the flow working.

To do so, we introduce "workload distance factor", which indicates the similarity of two workloads. Specifically, it is calculated by the following rules:

  • If two workloads are not for the same compute DAG or function, then inf.
  • If two workloads are for the same compute DAG/function, and
    • their non-zero integer arguments are dividable and their zero and non-integer arguments are the same, then factor=prod(a / b) for a, b in zip(wkl1.args, wkl2.args),
    • otherwise inf.

As a result, the distance factor ranges from 1 to inf. When the distance factor is not inf, meaning that it is safe to apply the schedule of workload 2 to workload 1.

The above mechanism works well for registered TE computes but not the ComputeDAG extracted from Relay programs. This is because currently when extracting tasks from Relay, we use MD5 to hash the ComputeDAG serialized string to be its key, which includes not only the DAG structure but the shapes, so it's impossible to calculate the distance factor. To make it work, this PR also improves the hashing mechanism of ComputeDAG by separating the input/output tensor shapes so that they can be accessed. For example, the workload key of a ComputeDAG was:

["8d5a93959138dc7b2ee1f1b3219dfa14"]

and it now becomes:

["ad6cecbf5d85cb1cda3c2bb7af170211", 1, 7, 7, 512, 4, 4, 512, 512, 1, 7, 7, 512, 1, 1, 1, 512, 1, 1, 1, 512, 1, 7, 7, 512]

Please note that since we change the workload key format of ComputeDAG, the tuning logs won't match anymore. To make it work again, we can use the following script to update the keys in existing log files. This is also the way I used to update the CI logs:

import json
import hashlib
import os
import sys

from tvm.te import ComputeOp, PlaceholderOp

from tvm.auto_scheduler import save_records
from tvm.auto_scheduler.measure import MeasureInput
from tvm.auto_scheduler.measure_record import load_records
from tvm.auto_scheduler.utils import get_const_tuple

tasks = [] # Extract tasks from a Relay program
log_file = "old-log-file"
new_log_file = "new-log-file"

def get_old_hash_key(dag):
    """Return the hash key of a compute DAG."""
    str_key = ""
    for op in dag.ops:
        t = op.output(0)
        if isinstance(op, PlaceholderOp):
            str_key += "placeholder,"
            str_key += str(get_const_tuple(t.shape)) + ","
            str_key += t.dtype + ";"
        elif isinstance(op, ComputeOp):
            str_key += str(t.op.body) + ","
            str_key += str(get_const_tuple(t.shape)) + ","
            str_key += t.dtype + ";"
        else:
            raise ValueError("Invalid op: " + op)

    str_key = str_key.encode(encoding="utf-8")
    return hashlib.md5(str_key).hexdigest()


# Establish the key mapping
old_key_to_task = {}
hit_count = {}
for idx, task in enumerate(tasks):
    old_key = json.dumps((get_old_hash_key(task.compute_dag),))
    old_key_to_task[old_key] = task
    hit_count[old_key] = 0
    print("Task %d %s -> %s" % (idx, old_key, task.workload_key))


# Update the workload key in an existing log file
new_inputs = []
new_results = []
for inp, res in load_records(log_file):
    if inp.task.workload_key not in old_key_to_task:
        print(
            "Ignore key %s in log file due to no corresponding task found" % inp.task.workload_key
        )
        continue
    hit_count[inp.task.workload_key] += 1
    new_inputs.append(MeasureInput(old_key_to_task[inp.task.workload_key], inp.state))
    new_results.append(res)

for key, cnt in hit_count.items():
    print("Old key %s hits %d times" % (key, cnt))

if os.path.exists(new_log_file):
    os.remove(new_log_file)
save_records(new_log_file, new_inputs, new_results)

cc @merrymercy @jcf94

@comaniac comaniac requested review from merrymercy and jcf94 January 20, 2021 23:37
@merrymercy
Copy link
Member

Because this PR breaks compatibility, should we add some warning messages and point users to this PR?

@comaniac
Copy link
Contributor Author

Because this PR breaks compatibility, should we add some warning messages and point users to this PR?

How about adding the message to the exsiting schedule not found warning and point to this PR?

@comaniac comaniac force-pushed the ansor_sep_dag_arg branch 2 times, most recently from 0d6c891 to 9352836 Compare January 23, 2021 02:05
@merrymercy
Copy link
Member

The added warning message is too general. There are a lot of other reasons to hit this warning message. So I prefer not having it.
I don't have a good place to put this warning message either. Let us delete it and merge this.

@comaniac
Copy link
Contributor Author

The added warning message is too general. There are a lot of other reasons to hit this warning message. So I prefer not having it.
I don't have a good place to put this warning message either. Let us delete it and merge this.

I agreed that the added warning message was too general so I reverted it. Meanwhile, I came up with another idea that adds the warning message when loading the records. Now users will see the following message when loading the logs with old format:

MeasureInput with old format workload key ["b32ed43fb351136894c322ee49097a1a"] should be updated using the script from https://github.com/apache/tvm/pull/7317.

@FrozenGene FrozenGene merged commit e6d5318 into apache:main Jan 25, 2021
@FrozenGene
Copy link
Member

Thanks @comaniac @merrymercy It is merged.

@comaniac comaniac deleted the ansor_sep_dag_arg branch February 5, 2021 22:41
alexwong pushed a commit to alexwong/tvm that referenced this pull request Feb 11, 2021
…ring (apache#7317)

* [AutoScheduler] Separate shapes from DAG hash and enable schedule sharing

* Update CI logs

* lint

* fix registry

* add message; fix layout rewrite mismatch

* update message

* support other formats
electriclilies pushed a commit to electriclilies/tvm that referenced this pull request Feb 18, 2021
…ring (apache#7317)

* [AutoScheduler] Separate shapes from DAG hash and enable schedule sharing

* Update CI logs

* lint

* fix registry

* add message; fix layout rewrite mismatch

* update message

* support other formats
Lokiiiiii pushed a commit to Lokiiiiii/tvm that referenced this pull request Mar 2, 2021
…ring (apache#7317)

* [AutoScheduler] Separate shapes from DAG hash and enable schedule sharing

* Update CI logs

* lint

* fix registry

* add message; fix layout rewrite mismatch

* update message

* support other formats
trevor-m pushed a commit to neo-ai/tvm that referenced this pull request Mar 2, 2021
…ring (apache#7317)

* [AutoScheduler] Separate shapes from DAG hash and enable schedule sharing

* Update CI logs

* lint

* fix registry

* add message; fix layout rewrite mismatch

* update message

* support other formats
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants