diff --git a/cflearn/models/transformer.py b/cflearn/models/transformer.py index 51ce7ad16..d2e483413 100644 --- a/cflearn/models/transformer.py +++ b/cflearn/models/transformer.py @@ -2,7 +2,7 @@ @ModelBase.register("transformer") -@ModelBase.register_pipe("transformer", head="linear") +@ModelBase.register_pipe("transformer", head="fcnn", head_config="highway") class Transformer(ModelBase): pass