diff --git a/src/cmd/compile/internal/ssa/_gen/RISCV64.rules b/src/cmd/compile/internal/ssa/_gen/RISCV64.rules
index ed7e142..4504890 100644
--- a/src/cmd/compile/internal/ssa/_gen/RISCV64.rules
+++ b/src/cmd/compile/internal/ssa/_gen/RISCV64.rules
@@ -883,3 +883,24 @@
 (OR (CZEROEQZ x:(ANDI <t> [c] z) cond) y:(CZERONEZ z cond)) => (OR x y)
 (OR (CZEROEQZ <t> x cond) (CZERONEZ <t> ((ADDI|ORI|XORI) [c] x) cond)) => ((ADD|OR|XOR) x (CZERONEZ <t> (MOVDconst [c]) cond))
 (OR (CZEROEQZ <t> ((ADDI|ORI|XORI) [c] x) cond) (CZERONEZ <t> x cond)) => ((ADD|OR|XOR) x (CZEROEQZ <t> (MOVDconst [c]) cond))
+
+// Strength reduction for multiplication by a constant.
+(MUL x (MOVDconst [c])) && c == 3 && buildcfg.GORISCV64 >= 22 => (SH1ADD <x.Type> x x)
+(MUL x (MOVDconst [c])) && c == 5 && buildcfg.GORISCV64 >= 22 => (SH2ADD <x.Type> x x)
+(MUL x (MOVDconst [c])) && c == 9 && buildcfg.GORISCV64 >= 22 => (SH3ADD <x.Type> x x)
+(MUL x (MOVDconst [c])) && c%3 == 0 && isPowerOfTwo(c/3) && c != 3 && buildcfg.GORISCV64 >= 22 => (SLLI [log64(c/3)] (SH1ADD <x.Type> x x))
+(MUL x (MOVDconst [c])) && c%5 == 0 && isPowerOfTwo(c/5) && c != 5 && buildcfg.GORISCV64 >= 22 => (SLLI [log64(c/5)] (SH2ADD <x.Type> x x))
+(MUL x (MOVDconst [c])) && c%9 == 0 && isPowerOfTwo(c/9) && c != 9 && buildcfg.GORISCV64 >= 22 => (SLLI [log64(c/9)] (SH3ADD <x.Type> x x))
+
+// 32-bit multiply by constant (int32/uint32).
+// MULW sign-extends the low 32 bits of the product into the full register,
+// but SHxADD does not, so a lone SHxADD must be followed by MOVWreg; SLLIW
+// already sign-extends its 32-bit result. SLLIW only encodes shift amounts
+// 0-31, so is32Bit(c) is required: it bounds the shift amount at 29.
+(MULW x (MOVDconst [c])) && c == 3 && buildcfg.GORISCV64 >= 22 => (MOVWreg (SH1ADD <x.Type> x x))
+(MULW x (MOVDconst [c])) && c == 5 && buildcfg.GORISCV64 >= 22 => (MOVWreg (SH2ADD <x.Type> x x))
+(MULW x (MOVDconst [c])) && c == 9 && buildcfg.GORISCV64 >= 22 => (MOVWreg (SH3ADD <x.Type> x x))
+(MULW x (MOVDconst [c])) && is32Bit(c) && c%3 == 0 && isPowerOfTwo(c/3) && c != 3 && buildcfg.GORISCV64 >= 22 => (SLLIW [log64(c/3)] (SH1ADD <x.Type> x x))
+(MULW x (MOVDconst [c])) && is32Bit(c) && c%5 == 0 && isPowerOfTwo(c/5) && c != 5 && buildcfg.GORISCV64 >= 22 => (SLLIW [log64(c/5)] (SH2ADD <x.Type> x x))
+(MULW x (MOVDconst [c])) && is32Bit(c) && c%9 == 0 && isPowerOfTwo(c/9) && c != 9 && buildcfg.GORISCV64 >= 22 => (SLLIW [log64(c/9)] (SH3ADD <x.Type> x x))
+
diff --git a/src/cmd/compile/internal/ssa/rewriteRISCV64.go b/src/cmd/compile/internal/ssa/rewriteRISCV64.go
index 43df9db..f5ccace 100644
--- a/src/cmd/compile/internal/ssa/rewriteRISCV64.go
+++ b/src/cmd/compile/internal/ssa/rewriteRISCV64.go
@@ -614,6 +614,10 @@
 		return rewriteValueRISCV64_OpRISCV64MOVWstore(v)
 	case OpRISCV64MOVWstorezero:
 		return rewriteValueRISCV64_OpRISCV64MOVWstorezero(v)
+	case OpRISCV64MUL:
+		return rewriteValueRISCV64_OpRISCV64MUL(v)
+	case OpRISCV64MULW:
+		return rewriteValueRISCV64_OpRISCV64MULW(v)
 	case OpRISCV64NEG:
 		return rewriteValueRISCV64_OpRISCV64NEG(v)
 	case OpRISCV64NEGW:
@@ -7013,6 +7017,273 @@
 	}
 	return false
 }
+func rewriteValueRISCV64_OpRISCV64MUL(v *Value) bool {
+	v_1 := v.Args[1]
+	v_0 := v.Args[0]
+	b := v.Block
+	// match: (MUL x (MOVDconst [c]))
+	// cond: c == 3 && buildcfg.GORISCV64 >= 22
+	// result: (SH1ADD <x.Type> x x)
+	for {
+		for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+			x := v_0
+			if v_1.Op != OpRISCV64MOVDconst {
+				continue
+			}
+			c := auxIntToInt64(v_1.AuxInt)
+			if !(c == 3 && buildcfg.GORISCV64 >= 22) {
+				continue
+			}
+			v.reset(OpRISCV64SH1ADD)
+			v.Type = x.Type
+			v.AddArg2(x, x)
+			return true
+		}
+		break
+	}
+	// match: (MUL x (MOVDconst [c]))
+	// cond: c == 5 && buildcfg.GORISCV64 >= 22
+	// result: (SH2ADD <x.Type> x x)
+	for {
+		for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+			x := v_0
+			if v_1.Op != OpRISCV64MOVDconst {
+				continue
+			}
+			c := auxIntToInt64(v_1.AuxInt)
+			if !(c == 5 && buildcfg.GORISCV64 >= 22) {
+				continue
+			}
+			v.reset(OpRISCV64SH2ADD)
+			v.Type = x.Type
+			v.AddArg2(x, x)
+			return true
+		}
+		break
+	}
+	// match: (MUL x (MOVDconst [c]))
+	// cond: c == 9 && buildcfg.GORISCV64 >= 22
+	// result: (SH3ADD <x.Type> x x)
+	for {
+		for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+			x := v_0
+			if v_1.Op != OpRISCV64MOVDconst {
+				continue
+			}
+			c := auxIntToInt64(v_1.AuxInt)
+			if !(c == 9 && buildcfg.GORISCV64 >= 22) {
+				continue
+			}
+			v.reset(OpRISCV64SH3ADD)
+			v.Type = x.Type
+			v.AddArg2(x, x)
+			return true
+		}
+		break
+	}
+	// match: (MUL x (MOVDconst [c]))
+	// cond: c%3 == 0 && isPowerOfTwo(c/3) && c != 3 && buildcfg.GORISCV64 >= 22
+	// result: (SLLI [log64(c/3)] (SH1ADD <x.Type> x x))
+	for {
+		for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+			x := v_0
+			if v_1.Op != OpRISCV64MOVDconst {
+				continue
+			}
+			c := auxIntToInt64(v_1.AuxInt)
+			if !(c%3 == 0 && isPowerOfTwo(c/3) && c != 3 && buildcfg.GORISCV64 >= 22) {
+				continue
+			}
+			v.reset(OpRISCV64SLLI)
+			v.AuxInt = int64ToAuxInt(log64(c / 3))
+			v0 := b.NewValue0(v.Pos, OpRISCV64SH1ADD, x.Type)
+			v0.AddArg2(x, x)
+			v.AddArg(v0)
+			return true
+		}
+		break
+	}
+	// match: (MUL x (MOVDconst [c]))
+	// cond: c%5 == 0 && isPowerOfTwo(c/5) && c != 5 && buildcfg.GORISCV64 >= 22
+	// result: (SLLI [log64(c/5)] (SH2ADD <x.Type> x x))
+	for {
+		for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+			x := v_0
+			if v_1.Op != OpRISCV64MOVDconst {
+				continue
+			}
+			c := auxIntToInt64(v_1.AuxInt)
+			if !(c%5 == 0 && isPowerOfTwo(c/5) && c != 5 && buildcfg.GORISCV64 >= 22) {
+				continue
+			}
+			v.reset(OpRISCV64SLLI)
+			v.AuxInt = int64ToAuxInt(log64(c / 5))
+			v0 := b.NewValue0(v.Pos, OpRISCV64SH2ADD, x.Type)
+			v0.AddArg2(x, x)
+			v.AddArg(v0)
+			return true
+		}
+		break
+	}
+	// match: (MUL x (MOVDconst [c]))
+	// cond: c%9 == 0 && isPowerOfTwo(c/9) && c != 9 && buildcfg.GORISCV64 >= 22
+	// result: (SLLI [log64(c/9)] (SH3ADD <x.Type> x x))
+	for {
+		for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+			x := v_0
+			if v_1.Op != OpRISCV64MOVDconst {
+				continue
+			}
+			c := auxIntToInt64(v_1.AuxInt)
+			if !(c%9 == 0 && isPowerOfTwo(c/9) && c != 9 && buildcfg.GORISCV64 >= 22) {
+				continue
+			}
+			v.reset(OpRISCV64SLLI)
+			v.AuxInt = int64ToAuxInt(log64(c / 9))
+			v0 := b.NewValue0(v.Pos, OpRISCV64SH3ADD, x.Type)
+			v0.AddArg2(x, x)
+			v.AddArg(v0)
+			return true
+		}
+		break
+	}
+	return false
+}
+func rewriteValueRISCV64_OpRISCV64MULW(v *Value) bool {
+	v_1 := v.Args[1]
+	v_0 := v.Args[0]
+	b := v.Block
+	// match: (MULW x (MOVDconst [c]))
+	// cond: c == 3 && buildcfg.GORISCV64 >= 22
+	// result: (MOVWreg (SH1ADD <x.Type> x x))
+	for {
+		for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+			x := v_0
+			if v_1.Op != OpRISCV64MOVDconst {
+				continue
+			}
+			c := auxIntToInt64(v_1.AuxInt)
+			if !(c == 3 && buildcfg.GORISCV64 >= 22) {
+				continue
+			}
+			v.reset(OpRISCV64MOVWreg)
+			v0 := b.NewValue0(v.Pos, OpRISCV64SH1ADD, x.Type)
+			v0.AddArg2(x, x)
+			v.AddArg(v0)
+			return true
+		}
+		break
+	}
+	// match: (MULW x (MOVDconst [c]))
+	// cond: c == 5 && buildcfg.GORISCV64 >= 22
+	// result: (MOVWreg (SH2ADD <x.Type> x x))
+	for {
+		for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+			x := v_0
+			if v_1.Op != OpRISCV64MOVDconst {
+				continue
+			}
+			c := auxIntToInt64(v_1.AuxInt)
+			if !(c == 5 && buildcfg.GORISCV64 >= 22) {
+				continue
+			}
+			v.reset(OpRISCV64MOVWreg)
+			v0 := b.NewValue0(v.Pos, OpRISCV64SH2ADD, x.Type)
+			v0.AddArg2(x, x)
+			v.AddArg(v0)
+			return true
+		}
+		break
+	}
+	// match: (MULW x (MOVDconst [c]))
+	// cond: c == 9 && buildcfg.GORISCV64 >= 22
+	// result: (MOVWreg (SH3ADD <x.Type> x x))
+	for {
+		for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+			x := v_0
+			if v_1.Op != OpRISCV64MOVDconst {
+				continue
+			}
+			c := auxIntToInt64(v_1.AuxInt)
+			if !(c == 9 && buildcfg.GORISCV64 >= 22) {
+				continue
+			}
+			v.reset(OpRISCV64MOVWreg)
+			v0 := b.NewValue0(v.Pos, OpRISCV64SH3ADD, x.Type)
+			v0.AddArg2(x, x)
+			v.AddArg(v0)
+			return true
+		}
+		break
+	}
+	// match: (MULW x (MOVDconst [c]))
+	// cond: is32Bit(c) && c%3 == 0 && isPowerOfTwo(c/3) && c != 3 && buildcfg.GORISCV64 >= 22
+	// result: (SLLIW [log64(c/3)] (SH1ADD <x.Type> x x))
+	for {
+		for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+			x := v_0
+			if v_1.Op != OpRISCV64MOVDconst {
+				continue
+			}
+			c := auxIntToInt64(v_1.AuxInt)
+			if !(is32Bit(c) && c%3 == 0 && isPowerOfTwo(c/3) && c != 3 && buildcfg.GORISCV64 >= 22) {
+				continue
+			}
+			v.reset(OpRISCV64SLLIW)
+			v.AuxInt = int64ToAuxInt(log64(c / 3))
+			v0 := b.NewValue0(v.Pos, OpRISCV64SH1ADD, x.Type)
+			v0.AddArg2(x, x)
+			v.AddArg(v0)
+			return true
+		}
+		break
+	}
+	// match: (MULW x (MOVDconst [c]))
+	// cond: is32Bit(c) && c%5 == 0 && isPowerOfTwo(c/5) && c != 5 && buildcfg.GORISCV64 >= 22
+	// result: (SLLIW [log64(c/5)] (SH2ADD <x.Type> x x))
+	for {
+		for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+			x := v_0
+			if v_1.Op != OpRISCV64MOVDconst {
+				continue
+			}
+			c := auxIntToInt64(v_1.AuxInt)
+			if !(is32Bit(c) && c%5 == 0 && isPowerOfTwo(c/5) && c != 5 && buildcfg.GORISCV64 >= 22) {
+				continue
+			}
+			v.reset(OpRISCV64SLLIW)
+			v.AuxInt = int64ToAuxInt(log64(c / 5))
+			v0 := b.NewValue0(v.Pos, OpRISCV64SH2ADD, x.Type)
+			v0.AddArg2(x, x)
+			v.AddArg(v0)
+			return true
+		}
+		break
+	}
+	// match: (MULW x (MOVDconst [c]))
+	// cond: is32Bit(c) && c%9 == 0 && isPowerOfTwo(c/9) && c != 9 && buildcfg.GORISCV64 >= 22
+	// result: (SLLIW [log64(c/9)] (SH3ADD <x.Type> x x))
+	for {
+		for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+			x := v_0
+			if v_1.Op != OpRISCV64MOVDconst {
+				continue
+			}
+			c := auxIntToInt64(v_1.AuxInt)
+			if !(is32Bit(c) && c%9 == 0 && isPowerOfTwo(c/9) && c != 9 && buildcfg.GORISCV64 >= 22) {
+				continue
+			}
+			v.reset(OpRISCV64SLLIW)
+			v.AuxInt = int64ToAuxInt(log64(c / 9))
+			v0 := b.NewValue0(v.Pos, OpRISCV64SH3ADD, x.Type)
+			v0.AddArg2(x, x)
+			v.AddArg(v0)
+			return true
+		}
+		break
+	}
+	return false
+}
 func rewriteValueRISCV64_OpRISCV64NEG(v *Value) bool {
 	v_0 := v.Args[0]
 	b := v.Block
diff --git a/test/codegen/multiply.go b/test/codegen/multiply.go
index 8c76fd9..1077b74 100644
--- a/test/codegen/multiply.go
+++ b/test/codegen/multiply.go
@@ -13,36 +13,42 @@
 	// amd64: "XORL"
 	// arm64: "MOVD ZR"
 	// loong64: "MOVV R0"
+	// riscv64: "MOV [$]0"
 	return x * 0
 }
 func m2(x int64) int64 {
 	// amd64: "ADDQ"
 	// arm64: "ADD"
 	// loong64: "ADDVU"
+	// riscv64: "SLLI"
 	return x * 2
 }
 func m3(x int64) int64 {
 	// amd64: "LEAQ .*[*]2"
 	// arm64: "ADD R[0-9]+<<1,"
 	// loong64: "ALSLV [$]1,"
+	// riscv64/rva22u64: "SH1ADD"
 	return x * 3
 }
 func m4(x int64) int64 {
 	// amd64: "SHLQ [$]2,"
 	// arm64: "LSL [$]2,"
 	// loong64: "SLLV [$]2,"
+	// riscv64: "SLLI [$]2,"
 	return x * 4
 }
 func m5(x int64) int64 {
 	// amd64: "LEAQ .*[*]4"
 	// arm64: "ADD R[0-9]+<<2,"
 	// loong64: "ALSLV [$]2,"
+	// riscv64/rva22u64: "SH2ADD"
 	return x * 5
 }
 func m6(x int64) int64 {
 	// amd64: "LEAQ .*[*]1" "LEAQ .*[*]2"
 	// arm64: "ADD R[0-9]+," "ADD R[0-9]+<<1,"
 	// loong64: "ADDVU" "ADDVU" "ADDVU"
+	// riscv64/rva22u64: "SH1ADD" "SLLI [$]1,"
 	return x * 6
 }
 func m7(x int64) int64 {
@@ -55,18 +61,21 @@
 	// amd64: "SHLQ [$]3,"
 	// arm64: "LSL [$]3,"
 	// loong64: "SLLV [$]3,"
+	// riscv64: "SLLI [$]3,"
 	return x * 8
 }
 func m9(x int64) int64 {
 	// amd64: "LEAQ .*[*]8"
 	// arm64: "ADD R[0-9]+<<3,"
 	// loong64: "ALSLV [$]3,"
+	// riscv64/rva22u64: "SH3ADD"
 	return x * 9
 }
 func m10(x int64) int64 {
 	// amd64: "LEAQ .*[*]1" "LEAQ .*[*]4"
 	// arm64: "ADD R[0-9]+," "ADD R[0-9]+<<2,"
 	// loong64: "ADDVU" "ALSLV [$]2,"
+	// riscv64/rva22u64: "SH2ADD" "SLLI [$]1,"
 	return x * 10
 }
 func m11(x int64) int64 {
@@ -79,6 +88,7 @@
 	// amd64: "LEAQ .*[*]2" "SHLQ [$]2,"
 	// arm64: "LSL [$]2," "ADD R[0-9]+<<1,"
 	// loong64: "SLLV" "ALSLV [$]1,"
+	// riscv64/rva22u64: "SH1ADD" "SLLI [$]2,"
 	return x * 12
 }
 func m13(x int64) int64 {
@@ -103,6 +113,7 @@
 	// amd64: "SHLQ [$]4,"
 	// arm64: "LSL [$]4,"
 	// loong64: "SLLV [$]4,"
+	// riscv64: "SLLI [$]4,"
 	return x * 16
 }
 func m17(x int64) int64 {
@@ -115,6 +126,7 @@
 	// amd64: "LEAQ .*[*]1" "LEAQ .*[*]8"
 	// arm64: "ADD R[0-9]+," "ADD R[0-9]+<<3,"
 	// loong64: "ADDVU" "ALSLV [$]3,"
+	// riscv64/rva22u64: "SH3ADD" "SLLI [$]1,"
 	return x * 18
 }
 func m19(x int64) int64 {
@@ -127,6 +139,7 @@
 	// amd64: "LEAQ .*[*]4" "SHLQ [$]2,"
 	// arm64: "LSL [$]2," "ADD R[0-9]+<<2,"
 	// loong64: "SLLV [$]2," "ALSLV [$]2,"
+	// riscv64/rva22u64: "SH2ADD" "SLLI [$]2,"
 	return x * 20
 }
 func m21(x int64) int64 {
@@ -151,6 +164,7 @@
 	// amd64: "LEAQ .*[*]2" "SHLQ [$]3,"
 	// arm64: "LSL [$]3," "ADD R[0-9]+<<1,"
 	// loong64: "SLLV [$]3" "ALSLV [$]1,"
+	// riscv64/rva22u64: "SH1ADD" "SLLI [$]3,"
 	return x * 24
 }
 func m25(x int64) int64 {
@@ -199,6 +213,7 @@
 	// amd64: "SHLQ [$]5,"
 	// arm64: "LSL [$]5,"
 	// loong64: "SLLV [$]5,"
+	// riscv64: "SLLI [$]5,"
 	return x * 32
 }
 func m33(x int64) int64 {
@@ -223,6 +238,7 @@
 	// amd64: "LEAQ .*[*]8" "SHLQ [$]2,"
 	// arm64: "LSL [$]2," "ADD R[0-9]+<<3,"
 	// loong64: "SLLV [$]2," "ALSLV [$]3,"
+	// riscv64/rva22u64: "SH3ADD" "SLLI [$]2,"
 	return x * 36
 }
 func m37(x int64) int64 {
@@ -247,6 +263,7 @@
 	// amd64: "LEAQ .*[*]4" "SHLQ [$]3,"
 	// arm64: "LSL [$]3," "ADD R[0-9]+<<2,"
 	// loong64: "SLLV [$]3," "ALSLV [$]2,"
+	// riscv64/rva22u64: "SH2ADD" "SLLI [$]3,"
 	return x * 40
 }
 
@@ -254,12 +271,14 @@
 	// amd64: "NEGQ "
 	// arm64: "NEG R[0-9]+,"
 	// loong64: "SUBVU R[0-9], R0,"
+	// riscv64: "NEG"
 	return x * -1
 }
 func mn2(x int64) int64 {
 	// amd64: "NEGQ" "ADDQ"
 	// arm64: "NEG R[0-9]+<<1,"
 	// loong64: "ADDVU" "SUBVU R[0-9], R0,"
+	// riscv64: "NEG"
 	return x * -2
 }
 func mn3(x int64) int64 {
@@ -272,6 +291,7 @@
 	// amd64: "NEGQ" "SHLQ [$]2,"
 	// arm64: "NEG R[0-9]+<<2,"
 	// loong64: "SLLV [$]2," "SUBVU R[0-9], R0,"
+	// riscv64: "SLLI" "NEG"
 	return x * -4
 }
 func mn5(x int64) int64 {
@@ -296,6 +316,7 @@
 	// amd64: "NEGQ" "SHLQ [$]3,"
 	// arm64: "NEG R[0-9]+<<3,"
 	// loong64: "SLLV [$]3" "SUBVU R[0-9], R0,"
+	// riscv64: "SLLI" "NEG"
 	return x * -8
 }
 func mn9(x int64) int64 {
@@ -344,6 +365,7 @@
 	// amd64: "NEGQ" "SHLQ [$]4,"
 	// arm64: "NEG R[0-9]+<<4,"
 	// loong64: "SLLV [$]4," "SUBVU R[0-9], R0,"
+	// riscv64: "SLLI" "NEG"
 	return x * -16
 }
 func mn17(x int64) int64 {