forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
tree_views.h
977 lines (921 loc) · 27.9 KB
/
tree_views.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
#pragma once
#include <c10/util/string_utils.h>
#include <torch/csrc/jit/script/error_report.h>
#include <torch/csrc/jit/script/strtod.h>
#include <torch/csrc/jit/script/tree.h>
#include <cstddef>
#include <functional>
#include <iostream>
#include <stdexcept>
#include <string>
namespace torch {
namespace jit {
namespace script {
// clang-format off
// TreeView provides a statically-typed way to traverse the tree, which should
// be formed according to the grammar below.
//
// A few notes on types and their aliases:
// - List<T> is really a Tree with kind TK_LIST and elements as subtrees
// - Maybe<T> is really a Tree with kind TK_OPTION that has 0 or 1 subtree of type T
// - Builtin types are: Ident (TK_IDENT), String (TK_STRING)
//
// Param = Param(Ident name, Expr type, Maybe<Expr> def, (TK_TRUE|TK_FALSE) kwarg_only) TK_PARAM
//
// Decl = Decl(List<Param> params, Maybe<Expr> return_type) TK_DECL
// Def = Def(Ident name, Decl decl, List<Stmt> body) TK_DEF
// ClassDef = ClassDef(Ident name, List<Def> body) TK_CLASS_DEF
//
// Stmt = If(Expr cond, List<Stmt> true_body, List<Stmt> false_body) TK_IF
// | For(List<Expr> targets, List<Expr> iters, List<Stmt> body) TK_FOR
// | While(Expr cond, List<Stmt> body) TK_WHILE
// | Global(List<Ident> idents) TK_GLOBAL
// -- NB: the only Exprs allowed on the lhs are a Var,
// or a tuple of Vars with an optional terminating Starred
// | Assign(Expr lhs, Expr rhs) TK_ASSIGN
// | AugAssign(Expr lhs, AugAssignKind aug_op, Expr rhs) TK_AUG_ASSIGN
// | Return(Expr value) TK_RETURN
// | ExprStmt(Expr expr) TK_EXPR_STMT
// | Raise(Maybe<Expr> expr) TK_RAISE
// | Assert(Expr test, Maybe<Expr> msg) TK_ASSERT
// | Pass() TK_PASS
// | Def TK_DEF
//
// Expr = TernaryIf(Expr cond, Expr true_expr, Expr false_expr) TK_IF_EXPR
// | BinOp(Expr lhs, Expr rhs)
// | And TK_AND
// | Or TK_OR
// | Lt '<'
// | Gt '>'
// | Eq TK_EQ
// | Le TK_LE
// | Ge TK_GE
// | Ne TK_NE
// | Is TK_IS
// | IsNot TK_ISNOT
// | Add '+'
// | Sub '-'
// | Mul '*'
// | Div '/'
// | Mod '%'
// | MatMult '@'
// | Pow TK_POW
// | UnaryOp(Expr expr)
// | Not TK_NOT
// | USub '-'
// | Const(String value) TK_CONST
// -- NB: x.name(y) is desugared into name(x, y)
// | Apply(Expr callee, List<Expr> inputs, List<Attribute> attributes) TK_APPLY
// | Select(Expr value, Ident selector) '.'
// | Subscript(Expr value, List<Expr> subscript_exprs) TK_SUBSCRIPT
// | SliceExpr(Maybe<Expr> start, Maybe<Expr> end) TK_SLICE_EXPR
// | Var(Ident name) TK_VAR
// | ListLiteral(List<Expr> inputs) TK_LIST_LITERAL
// | TupleLiteral(List<Expr> inputs) TK_TUPLE_LITERAL
// | Starred(Expr expr) TK_STARRED
//
// -- NB: only allowed expressions are Const or List(Const)
// (List as a value, not type constructor)
// Attribute = Attribute(Ident name, Expr value) TK_ATTRIBUTE
//
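// AugAssignKind = (source token listed; the node's own kind is the bare
//                  operator character shown in parentheses)
// | Add() TK_PLUS_EQ ('+')
// | Sub() TK_MINUS_EQ ('-')
// | Mul() TK_TIMES_EQ ('*')
// | Div() TK_DIV_EQ ('/')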
//
// Each subclass of TreeView should provide:
// 1. Constructor that takes a TreeRef, and checks that it's of the right type.
// 2. Accessors that get underlying information out of the object. If they
// return subtrees, they should wrap them in appropriate views too.
// 3. Static method 'create' that creates the underlying TreeRef object.
// For every TreeRef kind that has a TreeView, the parser always uses
// (e.g.) Ident::create rather than Compound::create. This means that
// changes to the structure of Ident are always made right here rather
// than both in the parser and in this code.
// XXX: these structs should have no fields to prevent slicing when passing by value
// clang-format on
// Common base of every typed view. A view holds a single reference-counted
// handle to the underlying tree and forwards all queries to it.
struct TreeView {
  explicit TreeView(TreeRef tree) : tree_(std::move(tree)) {}

  // Kind of the wrapped tree (a TK_* value or a single character such as '+').
  int kind() const {
    return tree_->kind();
  }
  // Source range of the wrapped tree.
  const SourceRange& range() const {
    return tree_->range();
  }
  // Handle to the wrapped tree, returned by value.
  TreeRef tree() const {
    return tree_;
  }
  // Handle to the wrapped tree, returned by reference.
  const TreeRef& get() const {
    return tree_;
  }
  // Views decay to TreeRef implicitly so they can be used as children when
  // building new trees.
  operator TreeRef() const {
    return tree_;
  }
  // Writes the wrapped tree to std::cout using TreeRef's stream operator.
  void dump() const {
    std::cout << tree_;
  }

 protected:
  // i-th child of the wrapped tree, fetched with at().
  const TreeRef& subtree(size_t i) const {
    const auto& children = tree_->trees();
    return children.at(i);
  }
  TreeRef tree_;
};
// Iterator over a TreeList that wraps each element in the view type T when
// dereferenced. operator* builds a fresh T, so it returns by value.
//
// The std::iterator_traits specialization at the bottom of this header
// forwards every trait to TreeList::const_iterator, so standard algorithms
// see that iterator's category. They may then use `last - first`, `it + n` or
// post-increment, so those operations are provided here as thin forwards to
// the wrapped iterator.
template <typename T>
struct ListIterator {
  ListIterator(TreeList::const_iterator it) : it(it) {}
  bool operator!=(const ListIterator& rhs) const {
    return it != rhs.it;
  }
  bool operator==(const ListIterator& rhs) const {
    return it == rhs.it;
  }
  T operator*() const {
    return T(*it);
  }
  ListIterator& operator+=(std::ptrdiff_t n) {
    it += n;
    return *this;
  }
  // Iterator n positions further along; *this is not modified.
  ListIterator operator+(std::ptrdiff_t n) const {
    ListIterator result = *this;
    result += n;
    return result;
  }
  // Number of elements between rhs and *this (both must iterate the same
  // list).
  std::ptrdiff_t operator-(const ListIterator& rhs) const {
    return it - rhs.it;
  }
  ListIterator& operator++() {
    ++it;
    return *this;
  }
  ListIterator operator++(int) {
    ListIterator previous = *this;
    ++it;
    return previous;
  }
  ListIterator& operator--() {
    --it;
    return *this;
  }
  ListIterator operator--(int) {
    ListIterator previous = *this;
    --it;
    return previous;
  }

 private:
  TreeList::const_iterator it;
};
template <typename T>
struct List : public TreeView {
using iterator = ListIterator<T>;
using const_iterator = ListIterator<T>;
List(const TreeRef& tree) : TreeView(tree) {
tree->match(TK_LIST);
// Iterate over list to temporarily instantiate Ts that will check the type
for (const T& elem : *this) {
(void)elem; // silence unused warning
}
}
iterator begin() const {
return iterator(tree_->trees().begin());
}
iterator end() const {
return iterator(tree_->trees().end());
}
bool empty() const {
return tree_->trees().begin() == tree_->trees().end();
}
T operator[](size_t i) const {
return T(subtree(i));
}
TreeRef map(const std::function<TreeRef(const T&)>& fn) {
return tree_->map([&](TreeRef v) { return fn(T(v)); });
}
static List create(const SourceRange& range, const std::vector<T>& subtrees) {
TreeList type_erased_sub{subtrees.begin(), subtrees.end()};
return List(Compound::create(TK_LIST, range, std::move(type_erased_sub)));
}
static List unsafeCreate(const SourceRange& range, TreeList&& subtrees) {
return List(Compound::create(TK_LIST, range, std::move(subtrees)));
}
size_t size() const {
return tree_->trees().size();
}
};
template <typename T>
struct Maybe : public TreeView {
explicit Maybe(const TreeRef& tree) : TreeView(tree) {
tree_->match(TK_OPTION);
if (tree_->trees().size() > 1)
throw ErrorReport(tree) << "Maybe trees can have at most one subtree";
}
/* implicit */ Maybe(const T& tree) : TreeView(tree) {}
bool present() const {
return tree_->trees().size() > 0;
}
T get() const {
return T(tree_->trees().at(0));
}
TreeRef map(const std::function<TreeRef(const T&)>& fn) {
return tree_->map([&](TreeRef v) { return fn(T(v)); });
}
static Maybe<T> create(const SourceRange& range) {
return Maybe<T>(Compound::create(TK_OPTION, range, {}));
}
static Maybe<T> create(const SourceRange& range, const T& value) {
return Maybe<T>(Compound::create(TK_OPTION, range, {value}));
}
};
struct Ident : public TreeView {
explicit Ident(const TreeRef& tree) : TreeView(tree) {
tree_->match(TK_IDENT);
}
const std::string& name() const {
return subtree(0)->stringValue();
}
static Ident create(const SourceRange& range, const std::string& name) {
return Ident(Compound::create(TK_IDENT, range, {String::create(name)}));
}
};
////////////////////////////////////////////////////////////////////////////////
// Base types (production LHS)
////////////////////////////////////////////////////////////////////////////////
// Base view for statements. It only checks that the tree's kind is one of the
// statement kinds below; concrete statement views add their own checks.
struct Stmt : public TreeView {
  explicit Stmt(const TreeRef& tree) : TreeView(tree) {
    bool is_statement = false;
    switch (tree->kind()) {
      case TK_IF:
      case TK_WHILE:
      case TK_FOR:
      case TK_GLOBAL:
      case TK_ASSIGN:
      case TK_AUG_ASSIGN:
      case TK_EXPR_STMT:
      case TK_RETURN:
      case TK_RAISE:
      case TK_ASSERT:
      case TK_PASS:
      case TK_DEF:
        is_statement = true;
        break;
      default:
        break;
    }
    if (!is_statement) {
      throw ErrorReport(tree)
          << kindToString(tree->kind()) << " is not a valid Stmt";
    }
  }
};
// Base view for expressions. It only checks that the tree's kind is one of the
// expression kinds below; concrete expression views add their own checks.
struct Expr : public TreeView {
  explicit Expr(const TreeRef& tree) : TreeView(tree) {
    switch (tree->kind()) {
      // conditional and boolean operators
      case TK_IF_EXPR:
      case TK_AND:
      case TK_OR:
      case TK_NOT:
      // comparison and identity operators
      case '<':
      case '>':
      case TK_EQ:
      case TK_NE:
      case TK_LE:
      case TK_GE:
      case TK_IS:
      case TK_ISNOT:
      // arithmetic and bitwise operators
      case '+':
      case '-':
      case '*':
      case '/':
      case '%':
      case '@':
      case TK_POW:
      case TK_FLOOR_DIV:
      case TK_UNARY_MINUS:
      case '&':
      case '^':
      case '|':
      // literals and displays
      case TK_CONST:
      case TK_STRINGLITERAL:
      case TK_TRUE:
      case TK_FALSE:
      case TK_NONE:
      case TK_LIST_LITERAL:
      case TK_TUPLE_LITERAL:
      case TK_DICT_LITERAL:
      case TK_LIST_COMP:
      case TK_DOTS:
      // names, calls, indexing and the rest
      case TK_VAR:
      case '.':
      case TK_APPLY:
      case TK_SUBSCRIPT:
      case TK_SLICE_EXPR:
      case TK_CAST:
      case TK_STARRED:
        break;
      default:
        throw ErrorReport(tree)
            << kindToString(tree->kind()) << " is not a valid Expr";
    }
  }
};
////////////////////////////////////////////////////////////////////////////////
// Helper nodes (mostly for function arguments)
////////////////////////////////////////////////////////////////////////////////
struct Attribute : public TreeView {
explicit Attribute(const TreeRef& tree) : TreeView(tree) {
tree_->match(TK_ATTRIBUTE);
}
Ident name() const {
return Ident(subtree(0));
}
Expr value() const {
return Expr(subtree(1));
}
static Attribute create(
const SourceRange& range,
const Ident& name,
const TreeRef& value) {
return Attribute(Compound::create(TK_ATTRIBUTE, range, {name, value}));
}
};
struct Param : public TreeView {
explicit Param(const TreeRef& tree) : TreeView(tree) {
tree_->match(TK_PARAM);
}
static Param create(
const SourceRange& range,
const Ident& ident,
const Expr& type,
const Maybe<Expr>& def,
bool kwarg_only) {
TreeRef kwarg_only_tree =
Compound::create(kwarg_only ? TK_TRUE : TK_FALSE, range, {});
return Param(
Compound::create(TK_PARAM, range, {ident, type, def, kwarg_only_tree}));
}
Ident ident() const {
return Ident(subtree(0));
}
Expr type() const {
return Expr(subtree(1));
}
Maybe<Expr> defaultValue() const {
return Maybe<Expr>(subtree(2));
}
bool kwarg_only() const {
return TK_TRUE == subtree(3)->kind();
}
Param withType(const Expr& typ) const {
return Param::create(range(), ident(), typ, defaultValue(), kwarg_only());
}
};
////////////////////////////////////////////////////////////////////////////////
// Top level definitions
////////////////////////////////////////////////////////////////////////////////
struct Decl : public TreeView {
explicit Decl(const TreeRef& tree) : TreeView(tree) {
tree->match(TK_DECL);
}
List<Param> params() const {
return List<Param>(subtree(0));
}
Maybe<Expr> return_type() const {
return Maybe<Expr>(subtree(1));
}
static Decl create(
const SourceRange& range,
const List<Param>& params,
const Maybe<Expr>& return_type) {
return Decl(Compound::create(TK_DECL, range, {params, return_type}));
}
};
struct Def : public TreeView {
explicit Def(const TreeRef& tree) : TreeView(tree) {
tree->match(TK_DEF);
}
Def withName(std::string new_name) const {
auto new_ident = Ident::create(name().range(), std::move(new_name));
return create(range(), new_ident, decl(), statements());
}
Ident name() const {
return Ident(subtree(0));
}
Decl decl() const {
return Decl(subtree(1));
}
List<Stmt> statements() const {
return List<Stmt>(subtree(2));
}
static Def create(
const SourceRange& range,
const Ident& name,
const Decl& decl,
const List<Stmt>& stmts) {
return Def(Compound::create(TK_DEF, range, {name, decl, stmts}));
}
};
struct ClassDef : public TreeView {
explicit ClassDef(const TreeRef& tree) : TreeView(tree) {
tree->match(TK_CLASS_DEF);
}
ClassDef withName(std::string new_name) const {
auto new_ident = Ident::create(name().range(), std::move(new_name));
return create(range(), new_ident, defs());
}
Ident name() const {
return Ident(subtree(0));
}
List<Def> defs() const {
return List<Def>(subtree(1));
}
static ClassDef create(
const SourceRange& range,
const Ident& name,
const List<Def>& defs) {
return ClassDef(Compound::create(TK_CLASS_DEF, range, {name, defs}));
}
};
////////////////////////////////////////////////////////////////////////////////
// Statements
////////////////////////////////////////////////////////////////////////////////
struct If : public Stmt {
explicit If(const TreeRef& tree) : Stmt(tree) {
tree_->match(TK_IF);
}
Expr cond() const {
return Expr(subtree(0));
}
List<Stmt> trueBranch() const {
return List<Stmt>(subtree(1));
}
List<Stmt> falseBranch() const {
return List<Stmt>(subtree(2));
}
If withNewBranches(
const List<Stmt>& true_branch,
const List<Stmt>& false_branch) const {
return create(range(), cond(), true_branch, false_branch);
}
static If create(
const SourceRange& range,
const Expr& cond,
const List<Stmt>& true_branch,
const List<Stmt>& false_branch) {
return If(
Compound::create(TK_IF, range, {cond, true_branch, false_branch}));
}
};
struct While : public Stmt {
explicit While(const TreeRef& tree) : Stmt(tree) {
tree_->match(TK_WHILE);
}
Expr cond() const {
return Expr(subtree(0));
}
List<Stmt> body() const {
return List<Stmt>(subtree(1));
}
static While create(
const SourceRange& range,
const Expr& cond,
const List<Stmt>& body) {
return While(Compound::create(TK_WHILE, range, {cond, body}));
}
};
struct For : public Stmt {
explicit For(const TreeRef& tree) : Stmt(tree) {
tree->match(TK_FOR);
}
List<Expr> targets() const {
return List<Expr>(subtree(0));
}
List<Expr> itrs() const {
return List<Expr>(subtree(1));
}
List<Stmt> body() const {
return List<Stmt>(subtree(2));
}
static For create(
const SourceRange& range,
const List<Expr>& targets,
const List<Expr>& itrs,
const List<Stmt>& body) {
return For(Compound::create(TK_FOR, range, {targets, itrs, body}));
}
};
// TODO: supports only single comprehension for now
struct ListComp : public Expr {
explicit ListComp(const TreeRef& tree) : Expr(tree) {
tree->match(TK_LIST_COMP);
}
Expr elt() const {
return Expr(subtree(0));
}
Expr target() const {
return Expr(subtree(1));
}
Expr iter() const {
return Expr(subtree(2));
}
// TODO: no ifs for now
static ListComp create(
const SourceRange& range,
const Expr& elt,
const Expr& target,
const Expr& iter) {
return ListComp(Compound::create(TK_LIST_COMP, range, {elt, target, iter}));
}
};
// `global` statement (TK_GLOBAL) holding the list of declared names.
struct Global : public Stmt {
  explicit Global(const TreeRef& tree) : Stmt(tree) {
    tree_->match(TK_GLOBAL);
  }
  // const, like every other accessor in this header. It only reads the view,
  // so it can also be called through a const Global&.
  List<Ident> names() const {
    return List<Ident>(subtree(0));
  }
  static Global create(const SourceRange& range, const List<Ident>& names) {
    return Global(Compound::create(TK_GLOBAL, range, {names}));
  }
};
struct AugAssignKind : public TreeView {
explicit AugAssignKind(const TreeRef& tree) : TreeView(tree) {
switch (tree->kind()) {
case '+':
case '-':
case '*':
case '/':
return;
default:
throw ErrorReport(tree) << "is not a valid AugAssignKind";
}
}
};
// Augmented assignment, like "foo += bar"
struct AugAssign : public Stmt {
explicit AugAssign(const TreeRef& tree) : Stmt(tree) {
tree_->match(TK_AUG_ASSIGN);
}
static AugAssign create(
const SourceRange& range,
const Expr& lhs,
const AugAssignKind& aug_op,
const Expr& rhs) {
return AugAssign(
Compound::create(TK_AUG_ASSIGN, range, {lhs, aug_op, rhs}));
}
Expr lhs() const {
return Expr(subtree(0));
}
int aug_op() const {
return subtree(1)->kind();
}
Expr rhs() const {
return Expr(subtree(2));
}
};
struct Assign : public Stmt {
explicit Assign(const TreeRef& tree) : Stmt(tree) {
tree_->match(TK_ASSIGN);
}
static Assign create(
const SourceRange& range,
const Expr& lhs,
const Expr& rhs) {
return Assign(Compound::create(TK_ASSIGN, range, {lhs, rhs}));
}
Expr lhs() const {
return Expr(subtree(0));
}
Expr rhs() const {
return Expr(subtree(1));
}
};
struct Return : public Stmt {
explicit Return(const TreeRef& tree) : Stmt(tree) {
tree_->match(TK_RETURN);
}
Expr expr() const {
return Expr(subtree(0));
}
static Return create(const SourceRange& range, const Expr& value) {
return Return(Compound::create(TK_RETURN, range, {value}));
}
};
struct Raise : public Stmt {
explicit Raise(const TreeRef& tree) : Stmt(tree) {
tree_->match(TK_RAISE);
}
Maybe<Expr> expr() const {
return Maybe<Expr>(subtree(0));
}
static Raise create(const SourceRange& range, const Maybe<Expr>& expr) {
return Raise(Compound::create(TK_RAISE, range, {expr}));
}
};
struct Assert : public Stmt {
explicit Assert(const TreeRef& tree) : Stmt(tree) {
tree_->match(TK_ASSERT);
}
Expr test() const {
return Expr(subtree(0));
}
Maybe<Expr> msg() const {
return Maybe<Expr>(subtree(1));
}
static Assert create(
const SourceRange& range,
const Expr& test,
const Maybe<Expr>& msg) {
return Assert(Compound::create(TK_ASSERT, range, {test, msg}));
}
};
struct Pass : public Stmt {
explicit Pass(const TreeRef& tree) : Stmt(tree) {
tree_->match(TK_PASS);
}
static Pass create(const SourceRange& range) {
return Pass(Compound::create(TK_PASS, range, {}));
}
};
struct Dots : public Expr {
explicit Dots(const TreeRef& tree) : Expr(tree) {
tree_->match(TK_DOTS);
}
static Dots create(const SourceRange& range) {
return Dots(Compound::create(TK_DOTS, range, {}));
}
};
// Expression used as a statement (TK_EXPR_STMT); its single child is the
// expression.
struct ExprStmt : public Stmt {
  explicit ExprStmt(const TreeRef& tree) : Stmt(tree) {
    tree_->match(TK_EXPR_STMT);
  }
  // const, like the accessors of every other view; reading the child does
  // not modify the statement.
  Expr expr() const {
    return Expr(subtree(0));
  }
  static ExprStmt create(const SourceRange& range, const Expr& list) {
    return ExprStmt(Compound::create(TK_EXPR_STMT, range, {list}));
  }
};
////////////////////////////////////////////////////////////////////////////////
// Expressions
////////////////////////////////////////////////////////////////////////////////
struct BinOp : public Expr {
explicit BinOp(const TreeRef& tree) : Expr(tree) {
switch (tree->kind()) {
case TK_AND:
case TK_OR:
case '<':
case '>':
case TK_IS:
case TK_ISNOT:
case TK_EQ:
case TK_LE:
case TK_GE:
case TK_NE:
case '+':
case '*':
case '/':
case '-':
case '@':
case TK_POW:
case '%':
case '&':
case '^':
case '|':
case TK_FLOOR_DIV:
if (tree->trees().size() != 2)
throw ErrorReport(tree)
<< "BinOp expected 2 subtrees, found " << tree->trees().size();
return;
default:
throw ErrorReport(tree)
<< kindToString(tree->kind()) << " is not a valid BinOp";
}
}
Expr lhs() const {
return Expr(subtree(0));
}
Expr rhs() const {
return Expr(subtree(1));
}
static BinOp create(
const SourceRange& range,
int kind,
const Expr& lhs,
const Expr& rhs) {
return BinOp(Compound::create(kind, range, {lhs, rhs}));
}
};
struct UnaryOp : public Expr {
explicit UnaryOp(const TreeRef& tree) : Expr(tree) {
switch (tree->kind()) {
case TK_UNARY_MINUS:
case TK_NOT:
if (tree->trees().size() != 1)
throw ErrorReport(tree)
<< "UnaryOp expected 1 subtree, found " << tree->trees().size();
return;
default:
throw ErrorReport(tree)
<< kindToString(tree->kind()) << " is not a valid UnaryOp";
}
}
static UnaryOp create(const SourceRange& range, int kind, const Expr& expr) {
return UnaryOp(Compound::create(kind, range, {expr}));
}
};
struct Const : public Expr {
explicit Const(const TreeRef& tree) : Expr(tree) {
tree_->matchNumSubtrees(TK_CONST, 1);
}
bool isFloatingPoint() const {
return subtree(0)->stringValue().find_first_of(".eE") != std::string::npos;
}
bool isIntegral() const {
return !isFloatingPoint();
}
int64_t asIntegral() const {
return c10::stoll(subtree(0)->stringValue());
}
double asFloatingPoint() const {
char* dummy;
return torch::jit::script::strtod_c(
subtree(0)->stringValue().c_str(), &dummy);
}
const std::string& text() const {
return subtree(0)->stringValue();
}
static Const create(const SourceRange& range, const std::string& value) {
return Const(Compound::create(TK_CONST, range, {String::create(value)}));
}
};
struct StringLiteral : public Expr {
explicit StringLiteral(const TreeRef& tree) : Expr(tree) {
tree_->matchNumSubtrees(TK_STRINGLITERAL, 1);
}
const std::string& text() const {
return subtree(0)->stringValue();
}
static StringLiteral create(
const SourceRange& range,
const std::string& value) {
return StringLiteral(
Compound::create(TK_STRINGLITERAL, range, {String::create(value)}));
}
};
struct Apply : public Expr {
explicit Apply(const TreeRef& tree) : Expr(tree) {
tree_->match(TK_APPLY);
}
Expr callee() const {
return Expr(subtree(0));
}
List<Expr> inputs() const {
return List<Expr>(subtree(1));
}
List<Attribute> attributes() const {
return List<Attribute>(subtree(2));
}
static Apply create(
const SourceRange& range,
const Expr& callee,
const List<Expr>& inputs,
const List<Attribute>& attributes) {
return Apply(
Compound::create(TK_APPLY, range, {callee, inputs, attributes}));
}
};
struct Select : public Expr {
explicit Select(const TreeRef& tree) : Expr(tree) {
tree_->match('.');
}
Expr value() const {
return Expr(subtree(0));
}
Ident selector() const {
return Ident(subtree(1));
}
static Select create(
const SourceRange& range,
const Expr& value,
const Ident& selector) {
return Select(Compound::create('.', range, {value, selector}));
}
};
struct SliceExpr : public Expr {
explicit SliceExpr(const TreeRef& tree) : Expr(tree) {
tree_->match(TK_SLICE_EXPR);
}
Maybe<Expr> start() const {
return Maybe<Expr>(subtree(0));
}
Maybe<Expr> end() const {
return Maybe<Expr>(subtree(1));
}
Expr startOr(int alternative) const {
const auto startOption = start();
return startOption.present() ? startOption.get() : createInt(alternative);
}
Expr endOr(int alternative) const {
const auto endOption = end();
return endOption.present() ? endOption.get() : createInt(alternative);
}
static SliceExpr create(
const SourceRange& range,
const Maybe<Expr>& start,
const Maybe<Expr>& end) {
return SliceExpr(Compound::create(TK_SLICE_EXPR, range, {start, end}));
}
private:
Expr createInt(int value) const {
return Expr(Const::create(range(), std::to_string(value)));
}
};
struct Subscript : public Expr {
explicit Subscript(const TreeRef& tree) : Expr(tree) {
tree_->match(TK_SUBSCRIPT);
}
Expr value() const {
return Expr(subtree(0));
}
List<Expr> subscript_exprs() const {
return List<Expr>(subtree(1));
}
static Subscript create(
const SourceRange& range,
const Expr& value,
const List<Expr>& subscript_exprs) {
return Subscript(
Compound::create(TK_SUBSCRIPT, range, {value, subscript_exprs}));
}
};
// Variable reference (TK_VAR) whose single child is the Ident being named.
struct Var : public Expr {
  // The stray ';' that followed this constructor's body (an empty member
  // declaration, flagged by -Wextra-semi) has been removed.
  explicit Var(const TreeRef& tree) : Expr(tree) {
    tree_->match(TK_VAR);
  }
  Ident name() const {
    return Ident(subtree(0));
  }
  static Var create(const SourceRange& range, const Ident& name) {
    return Var(Compound::create(TK_VAR, range, {name}));
  }
};
// Conditional expression `true_expr if cond else false_expr` (TK_IF_EXPR),
// stored with exactly three children: cond, true_expr, false_expr.
// (Stray ';' after the constructor and after create(), which -Wextra-semi
// flags as empty member declarations, have been removed.)
struct TernaryIf : public Expr {
  explicit TernaryIf(const TreeRef& tree) : Expr(tree) {
    tree_->matchNumSubtrees(TK_IF_EXPR, 3);
  }
  Expr cond() const {
    return Expr(subtree(0));
  }
  Expr true_expr() const {
    return Expr(subtree(1));
  }
  Expr false_expr() const {
    return Expr(subtree(2));
  }
  static TernaryIf create(
      const SourceRange& range,
      const Expr& cond,
      const Expr& true_expr,
      const Expr& false_expr) {
    return TernaryIf(
        Compound::create(TK_IF_EXPR, range, {cond, true_expr, false_expr}));
  }
};
struct ListLiteral : public Expr {
explicit ListLiteral(const TreeRef& tree) : Expr(tree) {
tree_->match(TK_LIST_LITERAL);
}
List<Expr> inputs() const {
return subtree(0);
}
static ListLiteral create(
const SourceRange& range,
const List<Expr>& inputs) {
return ListLiteral(Compound::create(TK_LIST_LITERAL, range, {inputs}));
}
};
struct TupleLiteral : public Expr {
explicit TupleLiteral(const TreeRef& tree) : Expr(tree) {
tree_->match(TK_TUPLE_LITERAL);
}
List<Expr> inputs() const {
return subtree(0);
}
static TupleLiteral create(
const SourceRange& range,
const List<Expr>& inputs) {
return TupleLiteral(Compound::create(TK_TUPLE_LITERAL, range, {inputs}));
}
};
struct DictLiteral : public Expr {
explicit DictLiteral(const TreeRef& tree) : Expr(tree) {
tree_->match(TK_DICT_LITERAL);
}
List<Expr> key_inputs() const {
return subtree(0);
}
List<Expr> value_inputs() const {
return subtree(1);
}
static DictLiteral create(
const SourceRange& range,
const List<Expr>& keys,
const List<Expr>& values) {
return DictLiteral(
Compound::create(TK_DICT_LITERAL, range, {keys, values}));
}
};
struct Starred : public Expr {
explicit Starred(const TreeRef& tree) : Expr(tree) {
tree_->match(TK_STARRED);
}
Expr expr() const {
return Expr(subtree(0));
}
static Starred create(const SourceRange& range, const Expr& expr) {
return Starred(Compound::create(TK_STARRED, range, {expr}));
}
};
} // namespace script
} // namespace jit
} // namespace torch
namespace std {
template <typename T>
struct iterator_traits<torch::jit::script::ListIterator<T>>
: std::iterator_traits<torch::jit::script::TreeList::const_iterator> {};
} // namespace std