-
Notifications
You must be signed in to change notification settings - Fork 2
/
isel.cpp
1832 lines (1507 loc) · 101 KB
/
isel.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#include <span>
#include <thread>
#include <llvm/Support/Debug.h>
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Dialect/Arith/IR/Arith.h>
#include <mlir/Dialect/ControlFlow/IR/ControlFlow.h>
#include <mlir/Dialect/ControlFlow/IR/ControlFlowOps.h>
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include <mlir/Dialect/SCF/IR/SCF.h>
#include <mlir/Pass/PassManager.h>
#include <mlir/Rewrite/PatternApplicator.h>
#include <mlir/Transforms/GreedyPatternRewriteDriver.h>
#include <mlir/Transforms/DialectConversion.h>
#include <mlir/Pass/Pass.h>
#include <mlir/IR/PatternMatch.h>
// to let llvm handle translating globals
#include <llvm/IR/IRBuilder.h>
#include <mlir/Conversion/LLVMCommon/Pattern.h>
#include <mlir/Target/LLVMIR/ModuleTranslation.h>
#include <mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h>
#include <mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h>
#include "util.h"
#include "isel.h"
#include "AMD64/AMD64Dialect.h"
#include "AMD64/AMD64Ops.h"
using GlobalsInfo = amd64::GlobalsInfo;
using GlobalSymbolInfo = amd64::GlobalSymbolInfo;
// anonymous namespace to contain patterns
namespace {
#include "AMD64/Lowerings.cpp.inc"
using namespace mlir;
// Placeholder for an instruction-type template argument: use it to mark that a specific type (form) of instruction is
// not available to the lambda of a pattern. It is the default for all instruction slots (INST1..INST5) of `Match`.
using NA = void;
/// Satisfied when a `Derived` value (used as an lvalue) is convertible to `Base`.
/// std::derived_from cannot be used for MLIR interfaces: interfaces that "inherit" from each other don't actually
/// inherit anything in C++, they only provide conversion operators - so convertibility is the relation checked here.
template<typename Derived, typename Base>
concept MLIRInterfaceDerivedFrom = requires(Derived candidate){
	{ candidate } -> std::convertible_to<Base>;
};
template <MLIRInterfaceDerivedFrom<amd64::RegisterTypeInterface> RegisterTy, bool matchZeroResult = false>
auto defaultBitwidthMatchLambda = []<unsigned bitwidth>(auto thiis, auto op, typename mlir::OpConversionPattern<decltype(op)>::OpAdaptor adaptor, mlir::ConversionPatternRewriter& rewriter){
using OpTy = decltype(op);
auto matchZeroResultReturn = [](){
if constexpr(matchZeroResult)
return mlir::success();
else
return mlir::failure();
};
if constexpr (OpTy::template hasTrait<mlir::OpTrait::ZeroResults>()){
return matchZeroResultReturn();
}
if constexpr(OpTy::template hasTrait<mlir::OpTrait::VariadicResults>()){
if(op.getNumResults() == 0)
return matchZeroResultReturn();
}
mlir::Type opType;
if constexpr (OpTy::template hasTrait<mlir::OpTrait::OneResult>())
opType = op.getType();
else
opType = op->getResult(0).getType();
auto type = thiis->getTypeConverter()->convertType(opType);
if(!type)
return rewriter.notifyMatchFailure(op, "type conversion failed");
auto typeToMatch= type.template dyn_cast<RegisterTy>();
if(!typeToMatch)
return rewriter.notifyMatchFailure(op, "not the correct register type");
if(typeToMatch.getBitwidth() != bitwidth)
return rewriter.notifyMatchFailure(op, "bitwidth mismatch");
return mlir::success();
};
auto intOrFloatBitwidthMatchLambda = defaultBitwidthMatchLambda<amd64::RegisterTypeInterface>;
auto intBitwidthMatchLambda = defaultBitwidthMatchLambda<amd64::GPRegisterTypeInterface>;
auto floatBitwidthMatchLambda = defaultBitwidthMatchLambda<amd64::FPRegisterTypeInterface>;
auto matchAllLambda = []<unsigned>(auto, auto, auto, mlir::ConversionPatternRewriter&){
return mlir::success();
};
/// Shape of a "match and replace" lambda for `Match`: callable with an explicit bitwidth plus five instruction types
/// (checked here with 8 and all-`NA`), taking (op, adaptor, rewriter), and returning something convertible to
/// mlir::LogicalResult.
template<typename T, typename OpTy, typename AdaptorTy>
concept MatchReplaceLambda = requires(T fn, OpTy operation, AdaptorTy opAdaptor, mlir::ConversionPatternRewriter& rw){
	{ fn.template operator()<8, NA, NA, NA, NA, NA>(operation, opAdaptor, rw) } -> std::convertible_to<mlir::LogicalResult>;
};
/// Somewhat generic conversion pattern for `OpTy`.
/// matchAndRewrite first runs `bitwidthMatchLambda<bitwidth>` to decide whether this instantiation applies to the op;
/// on success it delegates to `lambda<bitwidth, INST1, ..., INST5>`, which performs the actual rewrite using the given
/// instruction types (unused slots are `NA`). `benefit` is forwarded to the OpConversionPattern.
template<
	typename OpTy,
	unsigned bitwidth,
	auto lambda,
	// default template parameters start
	// Specifying the matcher default inline instantly crashes clangd; a separate variable avoids that and reduces
	// code duplication anyway.
	auto bitwidthMatchLambda = intBitwidthMatchLambda,
	int benefit = 1,
	typename INST1 = NA, typename INST2 = NA, typename INST3 = NA, typename INST4 = NA, typename INST5 = NA
>
//requires MatchReplaceLambda<decltype(lambda), OpTy, typename mlir::OpConversionPattern<OpTy>::OpAdaptor>
struct Match : public mlir::OpConversionPattern<OpTy>{
	using OpAdaptor = typename mlir::OpConversionPattern<OpTy>::OpAdaptor;

	Match(mlir::TypeConverter& tc, mlir::MLIRContext* ctx)
		: mlir::OpConversionPattern<OpTy>(tc, ctx, benefit) {}

	mlir::LogicalResult matchAndRewrite(OpTy op, OpAdaptor adaptor, mlir::ConversionPatternRewriter& rewriter) const override {
		// gate: does this instantiation (bitwidth / register class) apply to the op at all?
		if(auto gate = bitwidthMatchLambda.template operator()<bitwidth>(this, op, adaptor, rewriter); gate.failed())
			return gate;
		// rewrite, handing the bitwidth and the instruction types to the lambda
		return lambda.template operator()<bitwidth, INST1, INST2, INST3, INST4, INST5>(op, adaptor, rewriter);
	}
};
// Defines the alias `patternName ## bitwidth` as a `Match` of `OpTy` at `bitwidth` that rewrites with `lambda`.
// The variadic arguments (bitwidth matcher and benefit, as passed by PATTERN_INT_1) are forwarded to `Match` in front
// of the instruction list; `, ## __VA_ARGS__` is the GNU comma-swallowing extension.
// The five instruction types are formed by token pasting onto `opPrefixToReplaceWith`, in the order
// <prefix><bitwidth>rr, ...ri, ...rm, ...mi, ...mr - so all five names must exist for every bitwidth used.
#define PATTERN_FOR_BITWIDTH_INT_MRI(bitwidth, patternName, OpTy, opPrefixToReplaceWith, lambda, ...) \
using patternName ## bitwidth = Match<OpTy, bitwidth, lambda, ## __VA_ARGS__, \
opPrefixToReplaceWith ## bitwidth ## rr, opPrefixToReplaceWith ## bitwidth ## ri, opPrefixToReplaceWith ## bitwidth ## rm, opPrefixToReplaceWith ## bitwidth ## mi, opPrefixToReplaceWith ## bitwidth ## mr>;
// Full form: defines patternName8/16/32/64 with an explicit bitwidth matcher and benefit.
#define PATTERN_INT_1(patternName, opTy, opPrefixToReplaceWith, lambda, matchLambda, benefit) \
PATTERN_FOR_BITWIDTH_INT_MRI(8, patternName, opTy, opPrefixToReplaceWith, lambda, matchLambda, benefit) \
PATTERN_FOR_BITWIDTH_INT_MRI(16, patternName, opTy, opPrefixToReplaceWith, lambda, matchLambda, benefit) \
PATTERN_FOR_BITWIDTH_INT_MRI(32, patternName, opTy, opPrefixToReplaceWith, lambda, matchLambda, benefit) \
PATTERN_FOR_BITWIDTH_INT_MRI(64, patternName, opTy, opPrefixToReplaceWith, lambda, matchLambda, benefit)
// Like PATTERN_INT_1, with benefit 1.
#define PATTERN_INT_2(patternName, opTy, opPrefixToReplaceWith, lambda, matchLambda) \
PATTERN_INT_1(patternName, opTy, opPrefixToReplaceWith, lambda, matchLambda, 1)
// Like PATTERN_INT_2, with intBitwidthMatchLambda as the bitwidth matcher.
#define PATTERN_INT_3(patternName, opTy, opPrefixToReplaceWith, lambda) \
PATTERN_INT_2(patternName, opTy, opPrefixToReplaceWith, lambda, intBitwidthMatchLambda)
// Expands to its 7th argument; used to dispatch on the number of arguments given to PATTERN_INT.
#define GET_MACRO(_1, _2, _3, _4, _5, _6, name, ...) name
// PATTERN_INT with default arguments for the macros: 6 args -> PATTERN_INT_1, 5 args (no benefit) -> PATTERN_INT_2,
// 4 args (no matcher, no benefit) -> PATTERN_INT_3.
#define PATTERN_INT(...) GET_MACRO(__VA_ARGS__, PATTERN_INT_1, PATTERN_INT_2, PATTERN_INT_3)(__VA_ARGS__)
/// Minimal conversion pattern for `OpTy`: matchAndRewrite simply forwards to the callable `matchAndRewriteInject`.
/// The pattern is registered with `benefit`.
template<typename OpTy, auto matchAndRewriteInject, unsigned benefit = 1>
struct SimplePat : public mlir::OpConversionPattern<OpTy>{
	SimplePat(mlir::TypeConverter& tc, mlir::MLIRContext* ctx)
		: mlir::OpConversionPattern<OpTy>(tc, ctx, benefit) {}

	using OpAdaptor = typename mlir::OpConversionPattern<OpTy>::OpAdaptor;

	mlir::LogicalResult matchAndRewrite(OpTy matched, OpAdaptor operands, mlir::ConversionPatternRewriter& rw) const override {
		return matchAndRewriteInject(matched, operands, rw);
	}
};
/// Conversion pattern that matches every `OpTy` (benefit 1) and erases it.
template<typename OpTy>
struct ErasePat : public mlir::OpConversionPattern<OpTy>{
	ErasePat(mlir::TypeConverter& tc, mlir::MLIRContext* ctx)
		: mlir::OpConversionPattern<OpTy>(tc, ctx, /* benefit */ 1) {}

	using OpAdaptor = typename mlir::OpConversionPattern<OpTy>::OpAdaptor;

	mlir::LogicalResult matchAndRewrite(OpTy toErase, OpAdaptor /* operands are irrelevant */, mlir::ConversionPatternRewriter& rw) const override {
		rw.eraseOp(toErase);
		return mlir::success();
	}
};
// TODO technically I could replace the OpAdaptor template arg everywhere with a simple `auto adaptor` parameter
// TODO even better: use a typename T for the op, and then use T::Adaptor for the adaptor
// TODO an alternative would be to generate custom builders for the RR versions, which check if their argument is a movxxri and then fold it into the RR, resulting in an RI version. That probably wouldn't work because the returned thing would of course expect an RR version, not an RI version
/// Replaces a binary op by the register-register instruction `INSTrr`, applied to the converted operands.
/// The remaining instruction forms are part of the uniform lambda interface but are not used.
auto binOpMatchReplace = []<unsigned actualBitwidth,
	typename INSTrr, typename INSTri, typename INSTrm, typename INSTmi, typename INSTmr
>(auto op, auto adaptor, mlir::ConversionPatternRewriter& rewriter) {
	auto lhs = adaptor.getLhs();
	auto rhs = adaptor.getRhs();
	rewriter.replaceOpWithNewOp<INSTrr>(op, lhs, rhs);
	return mlir::success();
};
// it would be nice to use folds for matching mov's and folding them into the add, but that's not possible right now, so we either have to match it here, or ignore it for now (binOpMatchReplace's approach)
// When to use the OpAdaptor and when not: the OpAdaptor gives access to the operands in their already converted form, whereas the op itself still has all operands in their original form.
// Here we need the original form of the operands to check whether one of them is a constant; what it got converted to doesn't matter.
/// Lowers a binary op, folding a constant operand into the register-immediate form `INSTri` when it fits into a 32 bit
/// immediate; otherwise (no constant, or a constant that doesn't fit) the register-register form `INSTrr` is used.
/// NOTE: the constant may come from either side and the operand order is dropped for the RI form, so this must only be
/// used for commutative ops (currently: add).
auto binOpAndImmMatchReplace = []<unsigned actualBitwidth,
	typename INSTrr, typename INSTri, typename INSTrm, typename INSTmi, typename INSTmr
>(auto op, auto adaptor, mlir::ConversionPatternRewriter& rewriter) {
	// TODO change this so it can also match constants from the llvm dialect
	auto constantOp = mlir::dyn_cast_or_null<mlir::arith::ConstantIntOp>(op.getLhs().getDefiningOp());
	auto other = adaptor.getRhs();
	if(!constantOp){
		constantOp = mlir::dyn_cast_or_null<mlir::arith::ConstantIntOp>(op.getRhs().getDefiningOp());
		other = adaptor.getLhs();
	}

	// The immediate is max 32 bit. Without a constant operand, or with a constant that doesn't fit, use the RR version;
	// the constant then keeps its own MOV.
	if(!constantOp || !fitsInto32BitImm(constantOp.value())){
		rewriter.replaceOpWithNewOp<INSTrr>(op, adaptor.getLhs(), adaptor.getRhs());
		return mlir::success();
	}

	// a constant operand that fits into the instruction: use the RI version to save the MOVxxri
	auto newOp = rewriter.replaceOpWithNewOp<INSTri>(op, other);
	newOp.instructionInfo().imm = constantOp.value();
	return mlir::success();
};
// TODO think about whether to define different patterns with this or not etc.
// TODO this binOpAndImmMatchReplace could be split up into multiple patterns, but that might be slower
// Integer binary op patterns for 8/16/32/64 bit. Add uses binOpAndImmMatchReplace (may fold a constant operand into
// the RI form) and is registered with benefit 2; sub/and/or/xor always use the RR form via binOpMatchReplace.
PATTERN_INT(AddIPat, arith::AddIOp, amd64::ADD, binOpAndImmMatchReplace, intBitwidthMatchLambda, 2);
PATTERN_INT(SubIPat, arith::SubIOp, amd64::SUB, binOpMatchReplace);
PATTERN_INT(AndIPat, arith::AndIOp, amd64::AND, binOpMatchReplace);
PATTERN_INT(OrIPat, arith::OrIOp, amd64::OR, binOpMatchReplace);
PATTERN_INT(XOrIPat, arith::XOrIOp, amd64::XOR, binOpMatchReplace);
auto cmpIBitwidthMatcher =
[]<unsigned innerBitwidth>(auto thiis, auto op, auto, mlir::ConversionPatternRewriter&){
// cmp always has i1 as a result type, so we need to match the arguments' bitwidths
auto tc = thiis->getTypeConverter();
auto lhsNewType = tc->convertType(op.getLhs().getType()).template dyn_cast<amd64::GPRegisterTypeInterface>(); // TODO maybe we can use the adaptor and just use .getType on that instead of bothering with the type converter. Not just here, but everywhere
auto rhsNewType = tc->convertType(op.getRhs().getType()).template dyn_cast<amd64::GPRegisterTypeInterface>();
assert(lhsNewType && rhsNewType && "cmp's operands should be convertible to register types");
bool success = lhsNewType.getBitwidth() == innerBitwidth && lhsNewType == rhsNewType;
return mlir::failure(!success);
};
// TODO specify using benefits, that this has to have lower priority than the version which matches a jump with a cmp as an argument (which doesn't exist yet).
/// Lowers an integer compare to `CMPrr` on the converted operands. If the compare's i1 result has uses, the compare is
/// replaced by the SETcc matching its predicate; otherwise the root op is erased (the CMP itself has no result).
template<typename CmpPredicate>
auto cmpIMatchReplace = []<unsigned actualBitwidth,
	typename CMPrr, typename CMPri, typename, typename, typename
>(auto op, auto adaptor, mlir::ConversionPatternRewriter& rewriter) {
	rewriter.create<CMPrr>(op->getLoc(), adaptor.getLhs(), adaptor.getRhs());

	if(op->use_empty()){
		// we need to replace the root op; the CMP has no result, so replace it with nothing
		rewriter.eraseOp(op);
		return mlir::success();
	}

	using P = CmpPredicate;
	switch(op.getPredicate()){
		// equality
		case P::eq:  rewriter.replaceOpWithNewOp<amd64::SETE8r>(op);  break;
		case P::ne:  rewriter.replaceOpWithNewOp<amd64::SETNE8r>(op); break;
		// signed
		case P::slt: rewriter.replaceOpWithNewOp<amd64::SETL8r>(op);  break;
		case P::sle: rewriter.replaceOpWithNewOp<amd64::SETLE8r>(op); break;
		case P::sgt: rewriter.replaceOpWithNewOp<amd64::SETG8r>(op);  break;
		case P::sge: rewriter.replaceOpWithNewOp<amd64::SETGE8r>(op); break;
		// unsigned
		case P::ult: rewriter.replaceOpWithNewOp<amd64::SETB8r>(op);  break;
		case P::ule: rewriter.replaceOpWithNewOp<amd64::SETBE8r>(op); break;
		case P::ugt: rewriter.replaceOpWithNewOp<amd64::SETA8r>(op);  break;
		case P::uge: rewriter.replaceOpWithNewOp<amd64::SETAE8r>(op); break;
	}
	return mlir::success();
};
PATTERN_INT(CmpIPat, mlir::arith::CmpIOp, amd64::CMP, cmpIMatchReplace<mlir::arith::CmpIPredicate>, cmpIBitwidthMatcher);
template <bool isRem>
auto matchDivRem = []<unsigned actualBitwidth,
typename DIVr, typename, typename, typename, typename
>(auto op, auto adaptor, mlir::ConversionPatternRewriter& rewriter) {
auto div = rewriter.create<DIVr>(op->getLoc(), adaptor.getLhs(), adaptor.getRhs());
if constexpr(isRem)
rewriter.replaceOp(op, div.getRemainder());
else
rewriter.replaceOp(op, div.getQuotient());
return mlir::success();
};
// Per-bitwidth patterns for mul, div and rem.
// MUL only has a single-register form (`MUL<bw>r`), passed in the first (rr) slot used by binOpMatchReplace.
// Quotient and remainder come from the same instruction - DIV for unsigned, IDIV for signed; matchDivRem<isRem>
// picks which of its results replaces the op.
#define MULI_DIVI_PAT(bitwidth) \
using MulIPat ## bitwidth = Match<mlir::arith::MulIOp, bitwidth, binOpMatchReplace, intBitwidthMatchLambda, 1, amd64::MUL ## bitwidth ## r>; \
using DivUIPat ## bitwidth = Match<mlir::arith::DivUIOp, bitwidth, \
matchDivRem<false>, intBitwidthMatchLambda, 1, \
amd64::DIV ## bitwidth ## r>; \
using DivSIPat ## bitwidth = Match<mlir::arith::DivSIOp, bitwidth, \
matchDivRem<false>,intBitwidthMatchLambda, 1, \
amd64::IDIV ## bitwidth ## r>; \
using RemUIPat ## bitwidth = Match<mlir::arith::RemUIOp, bitwidth, \
matchDivRem<true>, intBitwidthMatchLambda, 1, \
amd64::DIV ## bitwidth ## r>; \
using RemSIPat ## bitwidth = Match<mlir::arith::RemSIOp, bitwidth, \
matchDivRem<true>, intBitwidthMatchLambda, 1, \
amd64::IDIV ## bitwidth ## r>;
MULI_DIVI_PAT(8); MULI_DIVI_PAT(16); MULI_DIVI_PAT(32); MULI_DIVI_PAT(64);
#undef MULI_DIVI_PAT
// Per-bitwidth shift patterns: shl -> SHL, logical shr -> SHR, arithmetic shr -> SAR.
// Both the rr and ri forms are passed, but binOpMatchReplace only uses the rr form.
#define SHIFT_PAT(bitwidth) \
using ShlIPat ## bitwidth = Match<mlir::arith::ShLIOp, bitwidth, binOpMatchReplace, intBitwidthMatchLambda, 1, amd64::SHL ## bitwidth ## rr, amd64::SHL ## bitwidth ## ri>; \
using ShrUIPat ## bitwidth = Match<mlir::arith::ShRUIOp, bitwidth, binOpMatchReplace, intBitwidthMatchLambda, 1, amd64::SHR ## bitwidth ## rr, amd64::SHR ## bitwidth ## ri>; \
using ShrSIPat ## bitwidth = Match<mlir::arith::ShRSIOp, bitwidth, binOpMatchReplace, intBitwidthMatchLambda, 1, amd64::SAR ## bitwidth ## rr, amd64::SAR ## bitwidth ## ri>;
SHIFT_PAT(8); SHIFT_PAT(16); SHIFT_PAT(32); SHIFT_PAT(64);
#undef SHIFT_PAT
/// Materializes an integer constant with the register-immediate MOV (`INSTri`) of the matching bitwidth.
auto movMatchReplace = []<unsigned actualBitwidth,
	typename INSTrr, typename INSTri, typename INSTrm, typename INSTmi, typename INSTmr
>(mlir::arith::ConstantIntOp op, auto adaptor, mlir::ConversionPatternRewriter& rewriter) {
	auto immediate = op.value();
	rewriter.replaceOpWithNewOp<INSTri>(op, immediate);
	return mlir::success();
};
PATTERN_INT(ConstantIntPat, mlir::arith::ConstantIntOp, amd64::MOV, movMatchReplace);
// sign/zero extensions
/// ZExt from i1 pattern
//struct ExtUII1Pat : mlir::OpConversionPattern<mlir::arith::ExtUIOp> {
// ExtUII1Pat(mlir::TypeConverter& tc, mlir::MLIRContext* ctx) : mlir::OpConversionPattern<mlir::arith::ExtUIOp>(tc, ctx, 2){}
//
// mlir::LogicalResult matchAndRewrite(mlir::arith::ExtUIOp zextOp, OpAdaptor adaptor, mlir::ConversionPatternRewriter& rewriter) const override {
// // we're only matching i1s here
// if(!zextOp.getIn().getType().isInteger(1))
// return rewriter.notifyMatchFailure(zextOp, "this pattern only extends i1s");
//
// // to be precise, we're only matching cmps for the moment, although this might change later
// if(!zextOp.getIn().isa<mlir::OpResult>())
// return rewriter.notifyMatchFailure(zextOp, "i1 zext pattern only matches cmps, this seems to be a block arg");
//
// auto cmpi = mlir::dyn_cast<mlir::arith::CmpIOp>(zextOp.getIn().getDefiningOp());
// if(!cmpi)
// return rewriter.notifyMatchFailure(zextOp, "only cmps are supported for i1 extension for now");
//
// assert(zextOp.getOut().getType().isIntOrFloat() && "extui with non-int result type");
//
// // TODO this doesn't make any sense without a CMOVcc/SETcc for/after the cmp
// switch(zextOp.getOut().getType().getIntOrFloatBitWidth()){
// case 8: rewriter.replaceOp(zextOp, adaptor.getIn() [> the cmpi should be replaced by a SETcc in another pattern, as this sets an 8 bit register anyway, we don't need to do anything here <]); break;
// case 16: rewriter.replaceOpWithNewOp<amd64::MOVZXr16r8>(zextOp, adaptor.getIn()); break;
// case 32: rewriter.replaceOpWithNewOp<amd64::MOVZXr32r8>(zextOp, adaptor.getIn()); break;
// case 64: rewriter.replaceOpWithNewOp<amd64::MOVZXr64r8>(zextOp, adaptor.getIn()); break;
//
// default:
// return rewriter.notifyMatchFailure(zextOp, "unsupported bitwidth for i1 extension");
// }
// return mlir::success();
// }
//};
// TODO sign extension for i1
/// the inBitwidth matches the bitwidth of the input operand to the extui, which needs to be different per pattern, because the corresponding instruction differs.
template<unsigned inBitwidth, auto getIn, auto getOut>
/// the outBitwidth matches the bitwidth of the result of the extui, which also affects which instruction is used.
/// ---
/// works with both floats and ints
auto truncExtUiSiBitwidthMatcher = []<unsigned outBitwidth>(auto thiis, auto op, auto adaptor, mlir::ConversionPatternRewriter& rewriter){
// TODO shouldn't this use getOut(op) instead of op?
// out bitwidth
auto failure = intOrFloatBitwidthMatchLambda.operator()<outBitwidth>(thiis, op, adaptor, rewriter);
if(failure.failed())
return rewriter.notifyMatchFailure(op, "out bitwidth mismatch");
// in bitwidth
mlir::Type opType = getIn(adaptor).getType();
auto typeToMatch= thiis->getTypeConverter()->convertType(opType).template dyn_cast<amd64::RegisterTypeInterface>();
assert(typeToMatch && "expected register type");
if(typeToMatch.getBitwidth() != inBitwidth)
return rewriter.notifyMatchFailure(op, "in bitwidth mismatch");
return mlir::success();
};
/// Lowers a trunc / extui / extsi op to the single instruction `MOVSZX`, applied to the converted input
/// `getIn(adaptor)`: a MOVZX/MOVSX for extensions, a plain MOV<out>rr for truncations, MOV32rr for 32 -> 64 zero
/// extension.
/// - inBitwidth / outBitwidth: bitwidths of the converted input and result.
/// - kind: ZExt / SExt / Trunc; only used to choose how an i1 input is normalized before extending it.
/// - getOut and the `SR8ri` slot are currently unused.
template<unsigned inBitwidth, auto getIn, auto getOut, amd64::SizeChange::Kind kind>
auto truncExtUiSiMatchReplace = []<unsigned outBitwidth,
	typename MOVSZX, typename SR8ri, typename, typename, typename
>(auto szextOp, auto adaptor, mlir::ConversionPatternRewriter& rewriter) {
	// we need to take care to truncate an i1 to 0/1, before we 'actually' use it, i.e. do real computations with it on the other side of the MOVZX
	// i1's are represented as 8 bits currently, so we only need to check this in patterns which extend 8 bit
	if constexpr (inBitwidth == 8){
		if(getIn(szextOp).getType().isInteger(1)){
			// assert the 8 bits for the i1
			assert(mlir::dyn_cast<amd64::GPRegisterTypeInterface>(getIn(adaptor).getType()).getBitwidth() == 8);
			if constexpr(kind == amd64::SizeChange::Kind::SExt){
				static_assert(outBitwidth > 8, "This pattern can't sign extend i1 to i8");
				// shift bit 0 up to bit 7, then arithmetically shift it back down, to make a mask of all zeros or all ones
				auto SHL = rewriter.create<amd64::SHL8ri>(szextOp.getLoc(), getIn(adaptor));
				auto SAR = rewriter.create<amd64::SAR8ri>(szextOp.getLoc(), SHL);
				SHL.instructionInfo().imm = SAR.instructionInfo().imm = 0x7;
				// now it's sign extended to the 8 bits; the MOVSZX does the rest
				rewriter.replaceOpWithNewOp<MOVSZX>(szextOp, SAR);
				return mlir::success();
			}else{
				static_assert(kind == amd64::SizeChange::Kind::ZExt);
				// AND with 1, so only bit 0 survives
				auto AND = rewriter.create<amd64::AND8ri>(szextOp.getLoc(), getIn(adaptor));
				AND.instructionInfo().imm = 0x1;
				rewriter.replaceOpWithNewOp<MOVSZX>(szextOp, AND);
				return mlir::success();
			}
		}
	}

	if constexpr(std::is_same_v<MOVSZX, amd64::MOV32rr>){
		// MOV32rr is used to move between 32 and 64 bit, so its result register type is passed explicitly
		if constexpr(inBitwidth == 32 && outBitwidth == 64){
			rewriter.replaceOpWithNewOp<MOVSZX>(szextOp, amd64::gpr64Type::get(rewriter.getContext()), getIn(adaptor));
		}else if constexpr(inBitwidth == 64 && outBitwidth == 32){
			rewriter.replaceOpWithNewOp<MOVSZX>(szextOp, amd64::gpr32Type::get(rewriter.getContext()), getIn(adaptor));
		}else{
			// The condition must depend on a template parameter: a plain `static_assert(false)` is ill-formed before
			// C++23 (P2593) and rejected by older compilers even inside this discarded branch.
			static_assert(!std::is_same_v<MOVSZX, MOVSZX>, "this pattern can not be used to translate between the same bitwidth");
		}
	}else{
		rewriter.replaceOpWithNewOp<MOVSZX>(szextOp, getIn(adaptor));
	}
	return mlir::success();
};
/// Accessors for the input / output value of arith trunc and ext ops. They are generic, so they work on both the op
/// itself (original values) and its adaptor (converted values).
auto arithGetIn = [](auto adaptorOrOp){ return adaptorOrOp.getIn(); };
auto arithGetOut = [](auto adaptorOrOp){ return adaptorOrOp.getOut(); };
// arith versions of the generic trunc/ext lowering and matcher: each accessor slot gets its matching accessor
// (the output slot gets arithGetOut, not arithGetIn).
template<unsigned inBitwidth, amd64::SizeChange::Kind kind>
auto arithTruncExtUiSiMatchReplace = truncExtUiSiMatchReplace<inBitwidth, arithGetIn, arithGetOut, kind>;
template<unsigned inBitwidth>
auto arithTruncExtUiSiBitwidthMatcher = truncExtUiSiBitwidthMatcher<inBitwidth, arithGetIn, arithGetOut>;
/// Defines ExtUIPat<in>_to_<out> (arith.extui -> MOVZXr<out>r<in>) and ExtSIPat<in>_to_<out> (arith.extsi ->
/// MOVSXr<out>r<in>).
/// Only for 16-64 bit outBitwidth; for 8 we have a special pattern. There are more exceptions, because not all
/// versions of MOVZX exist: MOVZXr8r8 wouldn't make sense (also invalid in MLIR), MOVZXr64r32 is just a MOV, etc.
#define EXT_UI_SI_PAT(outBitwidth, inBitwidth) \
using ExtUIPat ## inBitwidth ## _to_ ## outBitwidth = Match<mlir::arith::ExtUIOp, outBitwidth, arithTruncExtUiSiMatchReplace<inBitwidth, amd64::SizeChange::ZExt>, arithTruncExtUiSiBitwidthMatcher<inBitwidth>, 1, amd64::MOVZX ## r ## outBitwidth ## r ## inBitwidth>; \
using ExtSIPat ## inBitwidth ## _to_ ## outBitwidth = Match<mlir::arith::ExtSIOp, outBitwidth, arithTruncExtUiSiMatchReplace<inBitwidth, amd64::SizeChange::SExt>, arithTruncExtUiSiBitwidthMatcher<inBitwidth>, 1, amd64::MOVSX ## r ## outBitwidth ## r ## inBitwidth, NA, NA, NA, NA>;
// generalizable cases:
EXT_UI_SI_PAT(16, 8);
EXT_UI_SI_PAT(32, 8); EXT_UI_SI_PAT(32, 16);
EXT_UI_SI_PAT(64, 8); EXT_UI_SI_PAT(64, 16);
#undef EXT_UI_SI_PAT
// Cases that are still valid in MLIR but not generated by EXT_UI_SI_PAT:
// - 32 -> 64: zero extension is just a MOV (MOV32rr with a 64 bit result type), so it is spelled out below
// - any weird integer types, but we ignore those anyway
using ExtUIPat32_to_64 = Match<mlir::arith::ExtUIOp, 64, arithTruncExtUiSiMatchReplace<32, amd64::SizeChange::ZExt>, arithTruncExtUiSiBitwidthMatcher<32>, 1, amd64::MOV32rr>;
// For sign extension, EXT_UI_SI_PAT(64, 32) would produce this same pattern (MOVSXr64r32). It is written out manually
// only because that macro invocation would also produce the MOVZX-based zero-extension pattern, which is replaced above.
using ExtSIPat32_to_64 = Match<mlir::arith::ExtSIOp, 64, arithTruncExtUiSiMatchReplace<32, amd64::SizeChange::SExt>, arithTruncExtUiSiBitwidthMatcher<32>, 1, amd64::MOVSXr64r32>;
// trunc
// Defines TruncPat<in>_to_<out>: a truncation is a register-to-register MOV at the output width (`MOV<out>rr`).
// For 64 -> 32 the MOV32rr is created with an explicit gpr32 result type by truncExtUiSiMatchReplace.
#define TRUNC_PAT(outBitwidth, inBitwidth) \
using TruncPat ## inBitwidth ## _to_ ## outBitwidth = Match<mlir::arith::TruncIOp, outBitwidth, arithTruncExtUiSiMatchReplace<inBitwidth, amd64::SizeChange::Trunc>, arithTruncExtUiSiBitwidthMatcher<inBitwidth>, 1, amd64::MOV ## outBitwidth ## rr>;
TRUNC_PAT(8, 16); TRUNC_PAT(8, 32); TRUNC_PAT(8, 64);
TRUNC_PAT(16, 32); TRUNC_PAT(16, 64);
TRUNC_PAT(32, 64);
#undef TRUNC_PAT
// branches
auto branchMatchReplace = []<unsigned actualBitwidth,
typename JMP, typename, typename, typename, typename
>(auto br, auto adaptor, mlir::ConversionPatternRewriter& rewriter) {
rewriter.replaceOpWithNewOp<JMP>(br, adaptor.getDestOperands(), br.getDest());
return mlir::success();
};
using BrPat = Match<mlir::cf::BranchOp, 64, branchMatchReplace, matchAllLambda, 1, amd64::JMP>;
/// Lowers a conditional branch (written generically for either the cf or the llvm cond branch op) to amd64 jumps.
/// `thiis` is the calling pattern (used for its type converter and context); `IntegerCmpOp` is the integer compare op
/// the original condition is expected to come from when it was lowered to a SETcc (arith.cmpi for CondBrPat).
/// Cases, checked in this order:
///  - the condition is a block argument: generic i1 lowering (AND8ri with 1, then JNZ)
///  - the condition was converted to a MOV8ri (constant i1): unconditional JMP to the statically taken successor
///  - the condition was converted to an op implementing PredicateOpInterface (the SETcc of a lowered cmp): conditional
///    jump on that predicate, re-emitting the CMP right before the branch unless it is already there
///  - anything else: generic i1 lowering
template<typename IntegerCmpOp>
auto condBrMatchReplace = [](auto thiis, auto /* some kind of cond branch, either cf or llvm */ op, auto adaptor, mlir::ConversionPatternRewriter& rewriter) {
auto ops1 = adaptor.getTrueDestOperands();
auto ops2 = adaptor.getFalseDestOperands();
auto* block1 = op.getTrueDest();
auto* block2 = op.getFalseDest();
// i1 is expected to be modeled as an 8 bit general purpose register
assert(thiis->getTypeConverter()->convertType(mlir::IntegerType::get(thiis->getContext(), 1)) == amd64::gpr8Type::get(thiis->getContext()));
auto generalI1Case = [&](){
// and the condition i1 with 1, then do JNZ
auto andri = rewriter.create<amd64::AND8ri>(op.getLoc(), adaptor.getCondition());
andri.instructionInfo().imm = 1;
rewriter.replaceOpWithNewOp<amd64::JNZ>(op, ops1, ops2, block1, block2);
return mlir::success();
};
// if it's a block argument, there is no defining op to inspect: use the generic i1 lowering (AND + JNZ)
if(adaptor.getCondition().template isa<mlir::BlockArgument>()) [[unlikely]]{
// need to do this in case of a block argument at the start, because otherwise calling getDefiningOp() will fail. This is also called if no other case matches
return generalI1Case();
} else if(auto constI1AsMov8 = mlir::dyn_cast<amd64::MOV8ri>(adaptor.getCondition().getDefiningOp())) [[unlikely]]{
// constant conditions, that can either occur naturally or through folding, will result in i1 constants that get matched by the constant int pattern, and thus converted to MOV8ri (because i1 is modeled as 8 bit)
// do an unconditional JMP, if it's a constant condition: a nonzero immediate takes the true successor, zero the false one
if(constI1AsMov8.instructionInfo().imm)
rewriter.replaceOpWithNewOp<amd64::JMP>(op, ops1, block1);
else
rewriter.replaceOpWithNewOp<amd64::JMP>(op, ops2, block2);
// TODO how do I remove this MOV8ri, if it's only used here? erasing it results in an error about failing to legalize the erased op
return mlir::success();
} else if(auto setccPredicate = mlir::dyn_cast<amd64::PredicateOpInterface>(adaptor.getCondition().getDefiningOp())){
// conditional branch
auto cmpi = mlir::dyn_cast<IntegerCmpOp>(op.getCondition().getDefiningOp());
// relies on the CMP having been emitted immediately before the SETcc (cmpIMatchReplace creates the CMP, then replaces the cmpi with the SETcc at the same position)
auto CMP = setccPredicate->getPrevNode();
assert(cmpi && CMP && "Conditional branch with SETcc, but without cmpi and CMP");
// we're using the SETcc here, because the cmpi might have been folded, so we need to get to the original CMP, which is the CMP before the SETcc, then the SETcc has the right predicate
// cmp should already have been replaced by the cmp pattern, so we don't need to do that here
// but the cmp can be arbitrarily far away, so we need to reinsert it here, except if it's immediately before our current op (cond.br), and it's the replacement for the SETcc that we're using
// i.e. nothing is re-emitted only if the order is exactly: SETcc, cmpi (still present, pending replacement), cond.br
if(!(cmpi->getNextNode() == op && setccPredicate->getNextNode() == cmpi)){
// TODO IR/Builders.cpp suggests that this gets inserted at the right point, but check again
// re-emit a copy of the already lowered CMP (not the cmpi) at the rewriter's insertion point, so the compare is redone right before the jump
rewriter.clone(*CMP);
}
using namespace amd64::conditional;
// map the SETcc's predicate to the corresponding conditional jump
switch(setccPredicate.getPredicate()){
case Z: rewriter.replaceOpWithNewOp<amd64::JE>(op, ops1, ops2, block1, block2); break;
case NZ: rewriter.replaceOpWithNewOp<amd64::JNE>(op, ops1, ops2, block1, block2); break;
case L: rewriter.replaceOpWithNewOp<amd64::JL>(op, ops1, ops2, block1, block2); break;
case LE: rewriter.replaceOpWithNewOp<amd64::JLE>(op, ops1, ops2, block1, block2); break;
case G: rewriter.replaceOpWithNewOp<amd64::JG>(op, ops1, ops2, block1, block2); break;
case GE: rewriter.replaceOpWithNewOp<amd64::JGE>(op, ops1, ops2, block1, block2); break;
case C: rewriter.replaceOpWithNewOp<amd64::JB>(op, ops1, ops2, block1, block2); break;
case BE: rewriter.replaceOpWithNewOp<amd64::JBE>(op, ops1, ops2, block1, block2); break;
case A: rewriter.replaceOpWithNewOp<amd64::JA>(op, ops1, ops2, block1, block2); break;
case NC: rewriter.replaceOpWithNewOp<amd64::JAE>(op, ops1, ops2, block1, block2); break;
default: llvm_unreachable("unknown predicate");
}
//if(setccPredicate->use_empty())
//rewriter.eraseOp(setccPredicate);
return mlir::success();
}else{
// general i1 arithmetic
return generalI1Case();
}
};
/// Lowers cf.cond_br through condBrMatchReplace, with arith.cmpi as the integer compare op. Registered with benefit 3.
struct CondBrPat : public mlir::OpConversionPattern<mlir::cf::CondBranchOp> {
	CondBrPat(mlir::TypeConverter& tc, mlir::MLIRContext* ctx)
		: mlir::OpConversionPattern<mlir::cf::CondBranchOp>(tc, ctx, /* benefit */ 3) {}

	mlir::LogicalResult matchAndRewrite(mlir::cf::CondBranchOp condBr, OpAdaptor operands, mlir::ConversionPatternRewriter& rw) const override {
		return condBrMatchReplace<mlir::arith::CmpIOp>(this, condBr, operands, rw);
	}
};
/// Lowers a call to amd64::CALL with the converted operands and the original callee attribute.
/// Calls with more than one result are rejected; a single result gets the register type `RegisterTy` (passed in the
/// first instruction slot of the pattern), a call without results gets no result types.
auto callMatchReplace = []<unsigned actualBitwidth,
	typename RegisterTy, typename, typename, typename, typename
>(auto callOp, auto adaptor, mlir::ConversionPatternRewriter& rewriter) {
	auto resultCount = callOp.getNumResults();
	if(resultCount > 1)
		return rewriter.notifyMatchFailure(callOp, "multiple return values not supported");

	// Direct and indirect calls are handled the same as in the llvm dialect (the first operand can be a ptr), so nothing
	// special happens here. Per the original note, encoding checks whether the callee attribute exists; if it doesn't,
	// the first operand is moved to AX and operand moves start one after it.
	auto callee = callOp.getCalleeAttr();
	if(resultCount == 1){
		rewriter.replaceOpWithNewOp<amd64::CALL>(callOp, RegisterTy::get(callOp->getContext()), callee, /* is guaranteed external */ false, adaptor.getOperands());
	}else{
		rewriter.replaceOpWithNewOp<amd64::CALL>(callOp, TypeRange(), callee, /* is guaranteed external */ false, adaptor.getOperands());
	}
	return mlir::success();
};
// func.call lowerings, one per result register class and width. Each instantiates Match with callMatchReplace,
// a bitwidth matcher for the register interface, benefit 1, and the concrete result register type.
using IntCallPat8 = Match<mlir::func::CallOp, 8, callMatchReplace, defaultBitwidthMatchLambda<amd64::GPRegisterTypeInterface, true>, 1, amd64::gpr8Type>;
using IntCallPat16 = Match<mlir::func::CallOp, 16, callMatchReplace, defaultBitwidthMatchLambda<amd64::GPRegisterTypeInterface, true>, 1, amd64::gpr16Type>;
using IntCallPat32 = Match<mlir::func::CallOp, 32, callMatchReplace, defaultBitwidthMatchLambda<amd64::GPRegisterTypeInterface, true>, 1, amd64::gpr32Type>;
using IntCallPat64 = Match<mlir::func::CallOp, 64, callMatchReplace, defaultBitwidthMatchLambda<amd64::GPRegisterTypeInterface, true>, 1, amd64::gpr64Type>;
using FloatCallPat32 = Match<mlir::func::CallOp, 32, callMatchReplace, defaultBitwidthMatchLambda<amd64::FPRegisterTypeInterface, true>, 1, amd64::fpr32Type>;
using FloatCallPat64 = Match<mlir::func::CallOp, 64, callMatchReplace, defaultBitwidthMatchLambda<amd64::FPRegisterTypeInterface, true>, 1, amd64::fpr64Type>;
// TODO maybe AND i1's before returning them
// returns
auto returnMatchReplace = []<unsigned actualBitwidth,
typename, typename, typename, typename, typename
>(auto returnOp, auto adaptor, mlir::ConversionPatternRewriter& rewriter) {
mlir::Value retOperand;
if(returnOp.getNumOperands() > 1)
return rewriter.notifyMatchFailure(returnOp, "multiple return values not supported");
else if(returnOp.getNumOperands() == 0)
// if there is a zero op return, the function does not return anythign ,so we can just mov 0 to rax and return that.
// TODO consider omitting the return
retOperand = rewriter.create<amd64::MOV64ri>(returnOp.getLoc(), 0);
else
retOperand = adaptor.getOperands().front();
rewriter.replaceOpWithNewOp<amd64::RET>(returnOp, retOperand);
return mlir::success();
};
using ReturnPat = Match<mlir::func::ReturnOp, 64, returnMatchReplace, matchAllLambda, 1>;
// see https://mlir.llvm.org/docs/DialectConversion/#type-converter and https://mlir.llvm.org/docs/DialectConversion/#region-signature-conversion
// -> this is necessary to convert the types of the region ops
/// similar to `RegionOpConversion` from OpenMPToLLVM.cpp
// TODO i found this other way with `populateAnyFunctionOpInterfaceTypeConversionPattern`, performance test the two against one another
//struct FuncPat : public mlir::OpConversionPattern<mlir::func::FuncOp>{
// FuncPat(mlir::TypeConverter& tc, mlir::MLIRContext* ctx) : mlir::OpConversionPattern<mlir::func::FuncOp>(tc, ctx, 1){}
//
// mlir::LogicalResult matchAndRewrite(mlir::func::FuncOp func, OpAdaptor adaptor, mlir::ConversionPatternRewriter& rewriter) const override {
// // TODO the whole signature conversion doesn't work, instead we use the typeconverter to convert the function type, which isn't pretty, but it works
// //mlir::TypeConverter::SignatureConversion signatureConversion(func.getNumArguments());
// //signatureConversion.getConvertedTypes()
// //
// //if(failed(getTypeConverter()->convertSignatureArgs(func.getFunctionType(), signatureConversion)))
// // return mlir::failure();
//
// auto newFunc = rewriter.create<mlir::func::FuncOp>(
// func.getLoc(), func.getName(),
// mlir::dyn_cast<mlir::FunctionType>(getTypeConverter()->convertType(adaptor.getFunctionType())) [> this is also probably wrong, type converter still says its illegal afterwards <],
// adaptor.getSymVisibilityAttr(),
// adaptor.getArgAttrsAttr(),
// adaptor.getResAttrsAttr());
// rewriter.inlineRegionBefore(func.getRegion(), newFunc.getRegion(), newFunc.getRegion().end());
// if (failed(rewriter.convertRegionTypes(&newFunc.getRegion(), *getTypeConverter())))
// return rewriter.notifyMatchFailure(func, "failed to convert region types");
//
// //convertFuncOpTypes(func, *getTypeConverter(), rewriter);
//
// rewriter.replaceOp(func, newFunc->getResults());
// return mlir::success();
// }
//};
// TODO remove this, this is a skeleton to show that the bare 'ConversionPattern' class also works
/// Skeleton pattern on the untyped `mlir::ConversionPattern` base, rooted at the op name "test.conv.pat" with benefit 1.
/// It never matches.
struct TestConvPatternWOOp : public mlir::ConversionPattern{
	TestConvPatternWOOp(mlir::TypeConverter& tc, mlir::MLIRContext* ctx) : mlir::ConversionPattern(tc, "test.conv.pat", /*benefit=*/1, ctx){}

	// all parameters are intentionally unused: this pattern unconditionally fails
	mlir::LogicalResult matchAndRewrite(mlir::Operation*, mlir::ArrayRef<mlir::Value>, mlir::ConversionPatternRewriter&) const override {
		return mlir::failure();
	}
};
// llvm test patterns
/// Lowers `llvm.getelementptr` to explicit 64-bit address arithmetic on the converted base pointer.
///
/// Indices are processed left to right while tracking the type currently being indexed (starting with the base's type):
///  - each constant index is folded into a byte offset and emitted as one ADD64ri; zero offsets emit nothing,
///  - each dynamic index is scaled with IMUL64rri by the element byte size and added to the address with ADD64rr.
/// Dynamic indices into structs are rejected by an assertion.
struct LLVMGEPPattern : public mlir::OpConversionPattern<LLVM::GEPOp>{
	LLVMGEPPattern(mlir::TypeConverter& tc, mlir::MLIRContext* ctx) : mlir::OpConversionPattern<LLVM::GEPOp>(tc, ctx, 1){}
	mlir::LogicalResult matchAndRewrite(LLVM::GEPOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter& rewriter) const override {
		auto dl = mlir::DataLayout::closest(op);
		// Returns {byte offset of element `elemNum` inside `type`, type that the next index indexes into}.
		// TODO this is wrong, we don't want to add the size of the whole type
		auto getBytesOffsetAndType = [&](auto type, int64_t elemNum) -> std::pair<int64_t, mlir::Type> {
			using namespace mlir::LLVM;
			if(LLVMStructType structType = type.template dyn_cast<LLVMStructType>()){
				assert(elemNum >= 0 && "struct element number cannot be negative");
				auto elemNumUnsigned = static_cast<uint64_t>(elemNum);
				auto structParts = structType.getBody();
				// modified version of LLVMStructType::getTypeSizeInBits that stops at the requested element.
				// Accumulate in 64 bits: the offset is returned as int64_t, and an `unsigned` accumulator
				// would silently truncate the offsets of very large structs.
				uint64_t structSizeUpToNow = 0;
				for (auto [i, element] : llvm::enumerate(structParts)){
					// stop if we've reached the element we want to address
					// (`element` is an mlir::Type, a value handle, so returning it by value cannot dangle)
					if(i == elemNumUnsigned)
						return {static_cast<int64_t>(structSizeUpToNow), element};
					unsigned elementAlignment = structType.isPacked() ? 1 : dl.getTypeABIAlignment(element);
					// Add padding to the struct size to align it to the abi alignment of the
					// element type before then adding the size of the element
					structSizeUpToNow = llvm::alignTo(structSizeUpToNow, elementAlignment);
					structSizeUpToNow += dl.getTypeSize(element);
				}
				llvm_unreachable("gep struct element index out of bounds");
			}else if(LLVMArrayType arrayType = type.template dyn_cast<LLVMArrayType>()){
				return {dl.getTypeSize(arrayType.getElementType()) * elemNum, arrayType.getElementType()};
			}else if(LLVMPointerType ptrType = type.template dyn_cast<LLVMPointerType>()){
				// in this case, we just always use the source element type, and then index into that
				// TODO this assertion is wrong, what we would actually want to assert, is that *this* index is the first one
				//assert(op.getIndices().size() <= 2 && "only up to two indices supported for pointer element types");
				return {dl.getTypeSize(op.getSourceElementType()) * elemNum, op.getSourceElementType()};
			}else if(IntegerType intType = type.template dyn_cast<IntegerType>()){
				// TODO this assertion probably wrong for the same reasons as above, just that this should always be the last index.
				//assert(op.getIndices().size() == 1 && "only one index is supported for int element types");
				// TODO rethink this, im not sure this makes sense. It seems to work atm, but just the fact that we're using the type size of the int type here, and the type size of the element source type above doesn't make sense. The two are '1 layer apart', so they shouldn't be used on the same layer, right?
				return {dl.getTypeSize(intType) * elemNum, mlir::Type()};
			}else{
				op.dump();
				type.dump();
				llvm_unreachable("unhandled gep base type");
			}
		};
		// there is no adaptor.getIndices(), the adaptor only gives access to the dynamic indices. so we iterate over all of the indices, and if we find a dynamic one, use the rewriter to remap it
		auto indices = op.getIndices();
		// TODO check for allocas to optimize if possible
		auto currentIndexComputationValue = adaptor.getBase();
		// TODO the other case is some weird vector thing, i'd rather have it fail for now, if that is encountered
		assert(op.getElemType().has_value());
		// we start by indexing into the base type
		mlir::Type currentlyIndexedType = op.getBase().getType();
		for(auto indexPtr_u : indices){
			assert(getTypeConverter()->convertType(currentIndexComputationValue.getType()) == amd64::gpr64Type::get(getContext()) && "only 64 bit pointers are supported");
			if(mlir::Value val = indexPtr_u.dyn_cast<mlir::Value>()){
				// no dynamic struct indices please
				assert(!mlir::isa<LLVM::LLVMStructType>(currentlyIndexedType) && "dynamic struct indices are not allowed, this should be fixed in the verification of GEP in the llvm dialect!");
				//llvm::errs() << "value to be scaled in gep: "; op.dump(); llvm::errs() << " original value: "; val.dump(); llvm::errs() << " remapped value:"; rewriter.getRemappedValue(val).dump();
				auto scaled = rewriter.create<amd64::IMUL64rri>(op.getLoc(), rewriter.getRemappedValue(val));
				// we perform the computation analogously, but just for ptr/array types, so use 1 as the index:
				// the resulting offset is the element size, which becomes the multiplier immediate
				std::tie(scaled.instructionInfo().imm, currentlyIndexedType) = getBytesOffsetAndType(currentlyIndexedType, 1);
				currentIndexComputationValue = rewriter.create<amd64::ADD64rr>(op.getLoc(), currentIndexComputationValue, scaled);
			}else{
				// has to be integer attr otherwise
				auto indexInt = indexPtr_u.get<mlir::IntegerAttr>().getValue().getSExtValue();
				int64_t byteOffset;
				std::tie(byteOffset, currentlyIndexedType) = getBytesOffsetAndType(currentlyIndexedType, indexInt);
				// if the offset is zero, we don't have to create an instruction, but we do need to change the indexed type
				if(byteOffset == 0)
					continue;
				auto addri = rewriter.create<amd64::ADD64ri>(op.getLoc(), currentIndexComputationValue);
				addri.instructionInfo().imm = byteOffset;
				currentIndexComputationValue = addri;
			}
		}
		rewriter.replaceOp(op, currentIndexComputationValue);
		return mlir::success();
	}
};
/// Shared lowering of `llvm.load` into the register-from-memory instruction `INSTrm`.
///
/// If the converted address comes from an `unrealized_conversion_cast` whose single source is an `amd64::memLocType`
/// value, that memory location is used directly. Every other address is wrapped in an `amd64::MemB` memory operand.
template<typename INSTrm>
auto llvmLoadMatchReplace = [](LLVM::LoadOp op, auto adaptor, mlir::ConversionPatternRewriter& rewriter) {
	// TODO this is an ugly hack, because this op gets unrealized conversion casts as args (ptr in this case), because the ptr type gets converted to an i64, instead of a memloc, so the alloca returning a memloc doesn't work
	// even if we perform the conversion casts ourselves and insert a 1-2 type conversion from ptr to memloc/i64, this still doesn't work
	auto ptr = adaptor.getAddr();
	if(auto castOp = mlir::dyn_cast_if_present<mlir::UnrealizedConversionCastOp>(ptr.getDefiningOp())){
		// Only short-circuit when the cast really wraps a memory location. A cast from anything else is an ordinary
		// pointer value and must take the MemB path below; handing it to INSTrm directly would be ill-typed.
		if(castOp->getNumOperands() == 1 && mlir::isa<amd64::memLocType>(castOp->getOperand(0).getType())){
			rewriter.replaceOpWithNewOp<INSTrm>(op, castOp->getOperand(0));
			return mlir::success();
		}
	}
	auto mem = rewriter.create<amd64::MemB>(op.getLoc(), ptr);
	rewriter.replaceOpWithNewOp<INSTrm>(op, mem);
	return mlir::success();
};
auto llvmIntLoadMatchReplace = []<unsigned actualBitwidth,
typename, typename, typename INSTrm, typename, typename
>(LLVM::LoadOp op, auto adaptor, mlir::ConversionPatternRewriter& rewriter) {
return llvmLoadMatchReplace<INSTrm>(op, adaptor, rewriter);
};
template<typename INSTrm>
auto llvmFloatLoadMatchReplace = []<unsigned actualBitwidth,
typename, typename, typename, typename, typename
>(LLVM::LoadOp op, auto adaptor, mlir::ConversionPatternRewriter& rewriter) {
return llvmLoadMatchReplace<INSTrm>(op, adaptor, rewriter);
};
// llvm.load patterns. Integer loads are generated by the PATTERN_INT macro (defined elsewhere; presumably one Match per
// integer width) from the amd64::MOV instruction family. Float loads use MOVSSrm (32 bit) / MOVSDrm (64 bit),
// matched by floatBitwidthMatchLambda.
PATTERN_INT(LLVMIntLoadPat, LLVM::LoadOp, amd64::MOV, llvmIntLoadMatchReplace);
using LLVMFloatLoadPat32 = Match<LLVM::LoadOp, 32, llvmFloatLoadMatchReplace<amd64::MOVSSrm>, floatBitwidthMatchLambda>;
using LLVMFloatLoadPat64 = Match<LLVM::LoadOp, 64, llvmFloatLoadMatchReplace<amd64::MOVSDrm>, floatBitwidthMatchLambda>;
/// Shared lowering of `llvm.store` into the memory-from-register instruction `INSTmr`, storing the converted value operand.
///
/// If the converted address comes from an `unrealized_conversion_cast` whose single source is an `amd64::memLocType`
/// value, that memory location is used directly. Every other address is wrapped in an `amd64::MemB` memory operand.
template<typename INSTmr>
auto llvmStoreMatchReplace = [](LLVM::StoreOp op, auto adaptor, mlir::ConversionPatternRewriter& rewriter) {
	// TODO this is an ugly hack, because this op gets unrealized conversion casts as args (ptr in this case), because the ptr type gets converted to an i64, instead of a memloc, so the alloca returning a memloc doesn't work
	// even if we perform the conversion casts ourselves and insert a 1-2 type conversion from ptr to memloc/i64, this still doesn't work
	auto ptr = adaptor.getAddr();
	auto val = adaptor.getValue();
	if(auto castOp = mlir::dyn_cast_if_present<mlir::UnrealizedConversionCastOp>(ptr.getDefiningOp())){
		// Only short-circuit when the cast really wraps a memory location. A cast from anything else is an ordinary
		// pointer value and must take the MemB path below; handing it to INSTmr directly would be ill-typed.
		if(castOp->getNumOperands() == 1 && mlir::isa<amd64::memLocType>(castOp->getOperand(0).getType())){
			rewriter.replaceOpWithNewOp<INSTmr>(op, castOp->getOperand(0), val);
			return mlir::success();
		}
	}
	auto mem = rewriter.create<amd64::MemB>(op.getLoc(), ptr);
	rewriter.replaceOpWithNewOp<INSTmr>(op, mem, val);
	return mlir::success();
};
auto llvmIntStoreMatchReplace = []<unsigned actualBitwidth,
typename, typename, typename, typename, typename INSTmr
>(LLVM::StoreOp op, auto adaptor, mlir::ConversionPatternRewriter& rewriter) {
return llvmStoreMatchReplace<INSTmr>(op, adaptor, rewriter);
};
template<typename INSTmr>
auto llvmFloatStoreMatchReplace = []<unsigned actualBitwidth,
typename, typename, typename, typename, typename
>(LLVM::StoreOp op, auto adaptor, mlir::ConversionPatternRewriter& rewriter) {
return llvmStoreMatchReplace<INSTmr>(op, adaptor, rewriter);
};
/// Match predicate for `llvm.store` patterns. It succeeds only if the type of the *stored value* converts to a
/// `RegisterTy` register whose bitwidth equals `bitwidth`.
/// A failed type conversion (null type) is reported as a match failure rather than reaching `dyn_cast` with a
/// null type, which asserts in debug builds and is undefined in release builds.
template <typename RegisterTy>
requires MLIRInterfaceDerivedFrom<RegisterTy, amd64::RegisterTypeInterface>
auto llvmStoreBitwidthMatcher = []<unsigned bitwidth>(auto thiis, LLVM::StoreOp op, auto, mlir::ConversionPatternRewriter& rewriter){
	auto convertedValueType = thiis->getTypeConverter()->convertType(op.getValue().getType());
	if(auto regType = mlir::dyn_cast_if_present<RegisterTy>(convertedValueType)){
		if(regType.getBitwidth() == bitwidth)
			return mlir::success();
		return rewriter.notifyMatchFailure(op, "bitwidth mismatch");
	}
	return rewriter.notifyMatchFailure(op, "expected other register type");
};
// llvm.store patterns. Integer stores are generated by the PATTERN_INT macro (defined elsewhere) from the amd64::MOV
// instruction family. Float stores use MOVSSmr (32 bit) / MOVSDmr (64 bit). All of them are matched on the
// converted type of the stored value via llvmStoreBitwidthMatcher.
PATTERN_INT(LLVMIntStorePat, LLVM::StoreOp, amd64::MOV, llvmIntStoreMatchReplace, llvmStoreBitwidthMatcher<amd64::GPRegisterTypeInterface>);
using LLVMFloatStorePat32 = Match<LLVM::StoreOp, 32, llvmFloatStoreMatchReplace<amd64::MOVSSmr>, llvmStoreBitwidthMatcher<amd64::FPRegisterTypeInterface>>;
using LLVMFloatStorePat64 = Match<LLVM::StoreOp, 64, llvmFloatStoreMatchReplace<amd64::MOVSDmr>, llvmStoreBitwidthMatcher<amd64::FPRegisterTypeInterface>>;
/// Lowers `llvm.alloca` with a constant, non-negative integer element count to `amd64::AllocaOp` of
/// `elemSize * numElems` bytes (element size taken from the closest data layout).
/// Other counts are rejected with a match failure: non-constant ones (including block arguments, which have no
/// defining op), non-integer ones and negative ones.
struct LLVMAllocaPat : public mlir::OpConversionPattern<mlir::LLVM::AllocaOp>{
	LLVMAllocaPat(mlir::TypeConverter& tc, mlir::MLIRContext* ctx) : mlir::OpConversionPattern<LLVM::AllocaOp>(tc, ctx, 1){}
	mlir::LogicalResult matchAndRewrite(LLVM::AllocaOp op, OpAdaptor, mlir::ConversionPatternRewriter& rewriter) const override {
		// TODO maybe this can be improved when considering that the alloca is only ever used as a ptr by GEP, load, store, and ptrtoint. In this case the lea is technically only needed for ptrtoint
		auto numElemsVal = op.getArraySize();
		// a block-argument array size has no defining op, so the null case must not reach dyn_cast
		auto constantElemNumOp = mlir::dyn_cast_if_present<LLVM::ConstantOp>(numElemsVal.getDefiningOp());
		if(!constantElemNumOp)
			return rewriter.notifyMatchFailure(op, "only constant allocas supported for now");
		auto numElemsAttr = mlir::dyn_cast<mlir::IntegerAttr>(constantElemNumOp.getValue());
		if(!numElemsAttr)
			return rewriter.notifyMatchFailure(op, "expected an integer alloca array size");
		auto numElems = numElemsAttr.getValue().getSExtValue();
		if(numElems < 0)
			return rewriter.notifyMatchFailure(op, "negative alloca array size");
		// TODO AllocaOp::print does this a bit differently -> use that?
		auto dl = mlir::DataLayout::closest(op);
		assert(op.getElemType().has_value());
		auto elemSize = dl.getTypeSize(*op.getElemType());
		// gets converted to i64 with target materialization from type converter
		rewriter.replaceOpWithNewOp<amd64::AllocaOp>(op, elemSize*numElems);
		return mlir::success();
	}
};
/// Lowers an integer `llvm.mlir.constant` to the register-immediate move `INSTri` of the matched width,
/// using the sign-extended constant value as the immediate.
/// A non-integer constant attribute is reported as a match failure, so lower-benefit ConstantOp patterns can still apply.
auto llvmMovMatchReplace = []<unsigned actualBitwidth,
	typename INSTrr, typename INSTri, typename INSTrm, typename INSTmi, typename INSTmr
>(LLVM::ConstantOp op, auto adaptor, mlir::ConversionPatternRewriter& rewriter) {
	// dyn_cast, not cast: a failing `cast` asserts, which would make the null check below unreachable
	auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>(adaptor.getValue());
	if(!intAttr)
		return rewriter.notifyMatchFailure(op, "expected integer constant");
	rewriter.replaceOpWithNewOp<INSTri>(op, intAttr.getValue().getSExtValue());
	return mlir::success();
};
PATTERN_INT(LLVMConstantIntPat, LLVM::ConstantOp, amd64::MOV, llvmMovMatchReplace, intBitwidthMatchLambda, 2);
// TODO this is obviously not finished
/// Placeholder lowering for `llvm.mlir.constant` (judging by its name, intended for string constants).
/// It only erases constants that have no uses; any constant that is still used makes the pattern fail.
struct LLVMConstantStringPat : public mlir::OpConversionPattern<LLVM::ConstantOp>{
	LLVMConstantStringPat(mlir::TypeConverter& tc, mlir::MLIRContext* ctx) : mlir::OpConversionPattern<LLVM::ConstantOp>(tc, ctx, 1){}
	mlir::LogicalResult matchAndRewrite(LLVM::ConstantOp op, OpAdaptor /*adaptor*/, mlir::ConversionPatternRewriter& rewriter) const override {
		// still referenced: nothing we can lower yet
		if(!op.use_empty())
			return mlir::failure();
		rewriter.eraseOp(op);
		return mlir::success();
	}
};
/// Lowers `llvm.mlir.addressof`. The result depends on what the symbol resolves to:
///  - an `LLVM::GlobalOp`: `amd64::AddrOfGlobal` (the global's address may only be known later),
///  - a function with a body: `amd64::AddrOfFunc` (resolved later),
///  - an external function: its address, resolved now via checked_dlsym and materialized with MOV64ri.
/// A symbol that cannot be found, or that is neither a global nor a function, is reported as a match failure.
struct LLVMAddrofPat : public mlir::OpConversionPattern<LLVM::AddressOfOp>{
	LLVMAddrofPat(mlir::TypeConverter& tc, mlir::MLIRContext* ctx) : mlir::OpConversionPattern<LLVM::AddressOfOp>(tc, ctx, 1){}
	mlir::LogicalResult matchAndRewrite(LLVM::AddressOfOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter& rewriter) const override {
		auto globalOrFunc = mlir::SymbolTable::lookupNearestSymbolFrom(op, adaptor.getGlobalNameAttr());
		// An assert alone is not enough: in NDEBUG builds the null op would be dereferenced by the casts below.
		if(!globalOrFunc)
			return rewriter.notifyMatchFailure(op, "global from addrof not found");
		if(mlir::isa<LLVM::GlobalOp>(globalOrFunc)){
			// we don't necessarily know the offset of the global yet, might need to be resolved later
			// TODO maybe pass globals and check if it's in there already to avoid making the extra op. Might also be slower, not sure
			rewriter.replaceOpWithNewOp<amd64::AddrOfGlobal>(op, adaptor.getGlobalNameAttr());
			return mlir::success();
		}
		// calling dlsym here is a mere optimization. It could also be left to AddrOfFunc, but that would require another symbol table look-up, and those are glacially slow already...
		auto funcOp = mlir::dyn_cast<mlir::FunctionOpInterface>(globalOrFunc);
		if(!funcOp)
			return rewriter.notifyMatchFailure(op, "addrof target is neither a global nor a function");
		if(!funcOp.isExternal())
			// unknown until later
			rewriter.replaceOpWithNewOp<amd64::AddrOfFunc>(op, adaptor.getGlobalNameAttr());
		else
			rewriter.replaceOpWithNewOp<amd64::MOV64ri>(op, (intptr_t) checked_dlsym(adaptor.getGlobalName())); // TODO only do this if args.jit
		return mlir::success();
	}
};
// `llvm.return` reuses returnMatchReplace. The trailing amd64::RET fills a Match type parameter that returnMatchReplace
// leaves unnamed and therefore ignores.
using LLVMReturnPat = Match<LLVM::ReturnOp, 64, returnMatchReplace, matchAllLambda, 1, amd64::RET>;
/// Converts an `llvm.func` into a `func.func` whose argument and result types are converted with the pattern's type converter.
/// The body region is moved into the new function, and its block signatures are converted with the same converter.
/// Variadic functions are only accepted when external (no body). The new function's visibility is always "private".
struct LLVMFuncPat : public mlir::OpConversionPattern<LLVM::LLVMFuncOp>{
	LLVMFuncPat(mlir::TypeConverter& tc, mlir::MLIRContext* ctx) : mlir::OpConversionPattern<LLVM::LLVMFuncOp>(tc, ctx, 1){}
	mlir::LogicalResult matchAndRewrite(LLVM::LLVMFuncOp func, OpAdaptor adaptor, mlir::ConversionPatternRewriter& rewriter) const override {
		if(func.isVarArg() && !func.isExternal())
			return rewriter.notifyMatchFailure(func, "vararg functions are not supported yet");

		auto* converter = getTypeConverter();
		auto llvmFnType = func.getFunctionType();

		llvm::SmallVector<mlir::Type> argTypes;
		if(mlir::failed(converter->convertTypes(llvmFnType.getParams(), argTypes)))
			return rewriter.notifyMatchFailure(func, "failed to convert arg types");

		// an LLVM void result becomes an empty result list
		llvm::SmallVector<mlir::Type, 1> resultTypes;
		auto llvmRetType = llvmFnType.getReturnType();
		if(llvmRetType != LLVM::LLVMVoidType::get(func.getContext())){
			mlir::Type converted = converter->convertType(llvmRetType);
			if(!converted)
				return rewriter.notifyMatchFailure(func, "failed to convert return type");
			resultTypes.push_back(converted);
		}

		auto newFunc = rewriter.create<mlir::func::FuncOp>(
			func.getLoc(), func.getName(),
			rewriter.getFunctionType(argTypes, resultTypes),
			rewriter.getStringAttr(/* this is apparently a different kind of visibility: LLVM::stringifyVisibility(adaptor.getVisibility_()) */ "private"),
			adaptor.getArgAttrsAttr(),
			adaptor.getResAttrsAttr());
		rewriter.inlineRegionBefore(func.getRegion(), newFunc.getRegion(), newFunc.getRegion().end());
		if(mlir::failed(rewriter.convertRegionTypes(&newFunc.getRegion(), *converter)))
			return rewriter.notifyMatchFailure(func, "failed to convert region types");
		rewriter.replaceOp(func, newFunc->getResults());
		return mlir::success();
	}
};
struct LLVMGlobalPat : public mlir::OpConversionPattern<LLVM::GlobalOp>{
GlobalsInfo& globals;
LLVMGlobalPat(mlir::TypeConverter& tc, mlir::MLIRContext* ctx, GlobalsInfo& globals) : mlir::OpConversionPattern<LLVM::GlobalOp>(tc, ctx, 1), globals(globals){ }
mlir::LogicalResult matchAndRewrite(LLVM::GlobalOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter& rewriter) const override {
// if its a declaration: handle specially, we can already look up the address of the symbol and write it there, or fail immediately
//llvm::errs() << "handling global: "; op.dump();
GlobalSymbolInfo unfinishedGlobal;
intptr_t& addr = unfinishedGlobal.addrInDataSection = (intptr_t) nullptr;
auto& bytes = unfinishedGlobal.bytes = {};
auto symbol = op.getSymName();
if(op.isDeclaration()){
// this address is allowed to be 0, checked_dlsym handles an actual error through dlerror()
DEBUGLOG("external symbol: " << symbol << ", getting address from environment");
addr = (intptr_t) checked_dlsym(symbol);
}
auto insertGlobalAndEraseOp = [&](){
globals.insert({symbol, std::move(unfinishedGlobal)});
rewriter.eraseOp(op);
return mlir::success();
};
auto mlirGetTypeSize = [&](auto type){
return mlir::DataLayout::closest(op).getTypeSize(adaptor.getGlobalType());
};
auto fail = [&](StringRef msg = "globals with complex initializers are not supported yet"){
return rewriter.notifyMatchFailure(op, msg);
};
auto letLLVMDoIt = [&](){
// do the same thing that LLVM does itself (ModuleTranslation::convertGlobals), this is quite ugly, but can handle the most amount of cases
// TODO performance test if creating the module/context globally (or with `static` here)once is faster than locally every time
// TODO because globals (that might need to be translated by llvm) can cross reference each other, we can't create isolated modules only containing one global, we need to create one big llvm module containing all globals, to allow for that. But only do it if it actually happens
// TODO calling functions in global initializers would be even more difficult.
llvm::LLVMContext llvmCtx;
MLIRContext& mlirCtx = *op.getContext();
auto miniModule = mlir::OwningOpRef<ModuleOp>(ModuleOp::create(UnknownLoc::get(&mlirCtx)));
miniModule->getBody()->push_back(op.clone());
auto llvmModule = translateModuleToLLVMIR(miniModule.get(), llvmCtx);
if(!llvmModule)
return fail("failed to translate global to llvm ir");
assert(!llvmModule->empty() && llvmModule->global_size() == 1 && "expected exactly one global in the module");
auto llvmGetTypeSize = [&](auto type){
return llvmModule->getDataLayout().getTypeSizeInBits(type) / 8;
};
auto llvmConstant = llvmModule->globals().begin()->getInitializer();
if(!llvmConstant)
return fail("failed to get initializer of global");
unfinishedGlobal.alignment = op.getAlignment().value_or(0);
assert(llvm::isa<llvm::ConstantData>(llvmConstant));
assert(mlirGetTypeSize(adaptor.getGlobalType()) == llvmGetTypeSize(llvmConstant->getType()));
// get bytes from the constant
if(auto sequential = llvm::dyn_cast<llvm::ConstantDataSequential>(llvmConstant)){
DEBUGLOG("constant data sequential");
auto bytes = sequential->getRawDataValues();
assert(bytes.size() == mlirGetTypeSize(adaptor.getGlobalType()) && "global size mismatch");
unfinishedGlobal.bytes.resize(bytes.size());
memcpy(unfinishedGlobal.bytes.data(), bytes.data(), bytes.size());
}else if(auto zero = llvm::dyn_cast<llvm::ConstantAggregateZero>(llvmConstant)){
DEBUGLOG("constant aggregate zero");
unfinishedGlobal.bytes.resize(mlirGetTypeSize(adaptor.getGlobalType()));
// TODO use `zero` to validate the size
#ifndef NDEBUG