diff --git a/src/cmd/compile/internal/ssa/rewritetern.go b/src/cmd/compile/internal/ssa/rewritetern.go
index 766b3f8..98f04f3 100644
--- a/src/cmd/compile/internal/ssa/rewritetern.go
+++ b/src/cmd/compile/internal/ssa/rewritetern.go
@@ -5,10 +5,20 @@
package ssa
import (
+ "fmt"
"internal/goarch"
"slices"
+ "sort"
)
+func (slo SIMDLogicalOP) isLeaf() bool {
+ return slo == sloInterior
+}
+
+func (slo SIMDLogicalOP) Strip() SIMDLogicalOP {
+ return slo &^ sloInterior
+}
+
var truthTableValues [3]uint8 = [3]uint8{0b1111_0000, 0b1100_1100, 0b1010_1010}
func (slop SIMDLogicalOP) String() string {
@@ -30,6 +40,8 @@
return "andNot" + interior
case sloNot:
return "not" + interior
+ case sloTernlog:
+ return "ternlog" + interior
}
return "wrong"
}
@@ -51,19 +63,20 @@
// interior nodes will be marked sloInterior,
// root nodes will not be marked sloInterior,
// leaf nodes are only marked sloInterior.
+
for _, b := range f.Blocks {
for _, v := range b.Values {
slo := classifyBooleanSIMD(v)
- switch slo {
- case sloOr,
- sloAndNot,
- sloXor,
- sloAnd:
- boolExprTrees[v.Args[1]] |= sloInterior
- fallthrough
- case sloNot:
- boolExprTrees[v.Args[0]] |= sloInterior
- boolExprTrees[v] |= slo
+ if slo != sloNone { // check if v is a boolean expression node
+ boolExprTrees[v] = slo // set v to its own boolean operation only if v is a boolean expression node
+
+ // marking every child of v as interior ensures that any node with a parent ends up flagged as an interior node
+ for _, arg := range v.Args {
+ if arg == nil {
+ continue
+ }
+ boolExprTrees[arg] |= sloInterior
+ }
}
}
}
@@ -79,30 +92,13 @@
roots = append(roots, v)
}
}
+
+ if len(roots) == 0 {
+ return
+ }
+
slices.SortFunc(roots, func(u, v *Value) int { return int(u.ID - v.ID) }) // IDs are small enough to not care about overflow.
- // This rewrite works by iterating over the root set.
- // For each boolean expression, it walks the expression
- // bottom up accumulating sets of variables mentioned in
- // subexpressions, lazy-greedily finding the largest subexpressions
- // of 3 inputs that can be rewritten to use ternary-truth-table instructions.
-
- // rewrite recursively attempts to replace v and v's subexpressions with
- // ternary-logic truth-table operations, returning a set of not more than 3
- // subexpressions within v that may be combined into a parent's replacement.
- // V need not have the CPU features that allow a ternary-logic operation;
- // in that case, v will not be rewritten. Replacements also require
- // exactly 3 different variable inputs to a boolean expression.
- //
- // Given the CPU feature and 3 inputs, v is replaced in the following
- // cases:
- //
- // 1) v is a root
- // 2) u = NOT(v) and u lacks the CPU feature
- // 3) u = OP(v, w) and u lacks the CPU feature
- // 4) u = OP(v, w) and u has more than 3 variable inputs. var rewrite func(v *Value) [3]*Value
- var rewrite func(v *Value) [3]*Value
-
// computeTT returns the truth table for a boolean expression
// over the variables in vars, where vars[0] varies slowest in
// the truth table and vars[2] varies fastest.
@@ -113,37 +109,18 @@
// z: 0 1 0 1 0 1 0 1
var computeTT func(v *Value, vars [3]*Value) uint8
- // combine two sets of variables into one, returning ok/not
- // if the two sets contained 3 or fewer elements. Combine
- // ensures that the sets of Values never contain duplicates.
- // (Duplicates would create less-efficient code, not incorrect code.)
- combine := func(a, b [3]*Value) ([3]*Value, bool) {
- var c [3]*Value
- i := 0
- for _, v := range a {
- if v == nil {
- break
- }
- c[i] = v
- i++
+ // simulateTernlog simulates the ternlog instruction by using the imm8 as a lookup table for the result of every combination of the 3 variables; this is used to compute the truth table for a given boolean expression tree.
+ simulateTernlog := func(a, b, c uint8, imm uint8) uint8 {
+ var res uint8 = 0
+ for i := 0; i < 8; i++ {
+ bitA := (a >> i) & 1
+ bitB := (b >> i) & 1
+ bitC := (c >> i) & 1
+ idx := (bitA << 2) | (bitB << 1) | bitC
+ resBit := (imm >> idx) & 1
+ res |= (resBit << i)
}
- bloop:
- for _, v := range b {
- if v == nil {
- break
- }
- for _, u := range a {
- if v == u {
- continue bloop
- }
- }
- if i == 3 {
- return [3]*Value{}, false
- }
- c[i] = v
- i++
- }
- return c, true
+ return res
}
computeTT = func(v *Value, vars [3]*Value) uint8 {
@@ -153,7 +130,10 @@
return truthTableValues[i]
}
}
- slo := boolExprTrees[v] &^ sloInterior
+ if len(v.Args) == 0 {
+ panic(fmt.Errorf("leaf node not found in vars, v is %s, vars are %v", v.LongString(), vars))
+ }
+ slo := boolExprTrees[v].Strip()
a := computeTT(v.Args[0], vars)
switch slo {
case sloNot:
@@ -165,130 +145,240 @@
case sloOr:
return a | computeTT(v.Args[1], vars)
case sloAndNot:
- return a & ^computeTT(v.Args[1], vars)
+ if v.Args[0].ID < v.Args[1].ID {
+ return a &^ computeTT(v.Args[1], vars)
+ }
+ return computeTT(v.Args[1], vars) &^ a
+ case sloTernlog:
+ b := a
+ c := a
+
+ if v.Args[1] != nil {
+ b = computeTT(v.Args[1], vars)
+ }
+ if v.Args[2] != nil {
+ c = computeTT(v.Args[2], vars)
+ }
+
+ return simulateTernlog(a, b, c, uint8(v.AuxInt))
}
panic("switch should have covered all cases, or unknown var in logical expression")
}
- replace := func(a0 *Value, vars0 [3]*Value) {
- imm := computeTT(a0, vars0)
+ // findParameters returns a map, keyed by ID, of all the unique leaves of a node's expression tree (already-formed ternlog nodes also count as leaves)
+ var findParameters func(v *Value) map[ID]*Value
+
+ var collectionParams func(v *Value, params map[ID]*Value)
+
+ collectionParams = func(v *Value, params map[ID]*Value) {
+ if boolExprTrees[v].isLeaf() || boolExprTrees[v].Strip() == sloTernlog {
+ params[v.ID] = v
+ return
+ }
+
+ for _, arg := range v.Args {
+ collectionParams(arg, params)
+ }
+
+ for _, arg := range v.Args {
+ for _, param := range findParameters(arg) {
+ params[param.ID] = param
+ }
+ }
+ }
+
+ findParameters = func(v *Value) map[ID]*Value {
+ params := make(map[ID]*Value)
+
+ collectionParams(v, params)
+
+ return params
+ }
+
+ // phase 1: expandTernlog turns every instruction into a ternlog instruction, while trying to merge nodes that have 3 (or 2) leaves into 1 ternlog instruction.
+ var expandTernlog func(a0 *Value)
+
+ expandTernlog = func(a0 *Value) {
+ parameters := findParameters(a0)
+
+ var vars [3]*Value
+
+ if len(parameters) > 3 {
+ for _, arg := range a0.Args {
+ if boolExprTrees[arg].isLeaf() || boolExprTrees[arg].Strip() == sloTernlog {
+ continue
+ }
+ expandTernlog(arg)
+ }
+ }
+
+ parameters = findParameters(a0)
+
+ if len(parameters) > 3 {
+ panic(fmt.Errorf("too many parameters to rewrite, a0 is %s, parameters are %v", a0.LongString(), parameters))
+ }
+
+ var sortedParams []*Value
+ for _, param := range parameters {
+ sortedParams = append(sortedParams, param)
+ }
+
+ sort.Slice(sortedParams, func(i, j int) bool {
+ return sortedParams[i].ID < sortedParams[j].ID
+ })
+
+ for i := range 3 {
+ if i < len(sortedParams) {
+ vars[i] = sortedParams[i]
+ } else {
+ vars[i] = sortedParams[0]
+ }
+ }
+
+ if boolExprTrees[a0].Strip() == sloTernlog {
+ return
+ }
+ imm := computeTT(a0, vars)
op := ternOpForLogical(a0.Op)
if op == a0.Op {
- if f.pass.debug > 0 {
- f.Warnl(a0.Pos, "Skipping rewrite for %s, op=%v", a0.LongString(), op)
- }
- return
+ panic(fmt.Errorf("should have mapped away from input op, a0 is %s", a0.LongString()))
}
if f.pass.debug > 0 {
f.Warnl(a0.Pos, "Rewriting %s into %v of 0b%b %v %v %v", a0.LongString(), op, imm,
- vars0[0], vars0[1], vars0[2])
+ vars[0], vars[1], vars[2])
}
+ boolExprTrees[a0] = sloTernlog | sloInterior // sloInterior can be removed but for simplicity it is left in for future changes to this optimization
+
a0.reset(op)
- a0.SetArgs3(vars0[0], vars0[1], vars0[2])
+ if vars[0] == nil || vars[1] == nil || vars[2] == nil {
+ panic(fmt.Errorf("vars should never be nil since computeTT should return different imm8 for different combinations of nil, a0 is %s, vars are %v", a0.LongString(), vars))
+ }
+ a0.SetArgs3(vars[0], vars[1], vars[2])
a0.AuxInt = int64(int8(imm))
}
- // addOne ensures the no-duplicates addition of a single value
- // to a set that is not full. It seems possible that a shared
- // subexpression in tricky combination with blocks lacking the
- // AVX512 feature might permit this.
- addOne := func(vars [3]*Value, v *Value) [3]*Value {
- if vars[2] != nil {
- panic("rewriteTern.addOne, vars[2] should be nil")
+ // mergeBranchLeaves tries to merge a ternlog node t that is a leaf of v into v,
+ // if the number of free slots in v is bigger or equal to the number of taken slots in t,
+ // free slots are the number of times the root of t appears in v, and taken slots are the number of arguments of t that are not its root
+ mergeBranchLeaves := func(v *Value, t *Value) {
+ rootFreeSlots := -1
+
+ for _, arg := range v.Args {
+ if arg == v.Args[0] || arg == t {
+ rootFreeSlots++
+ }
}
- if v == vars[0] || v == vars[1] {
- return vars
+
+ leafTakenSlots := 0
+
+ for i, arg := range t.Args {
+ if arg != t.Args[0] || i == 0 {
+ leafTakenSlots++
+ }
}
- if vars[1] == nil {
- vars[1] = v
- } else {
- vars[2] = v
+
+ if leafTakenSlots > rootFreeSlots {
+ return
}
- return vars
+
+ vars := [3]*Value{v.Args[0], v.Args[1], v.Args[2]}
+
+ if vars[0] == t { // option A: the first argument is the leaf
+ i := 0
+ for j, arg := range vars {
+ if arg == vars[0] {
+ vars[j] = t.Args[i]
+ i++
+ if i > leafTakenSlots {
+ break
+ }
+ }
+ }
+ } else { // option B: the second argument is the leaf, this is the only other option since if the leaf is the third argument then it can't have more than 1 taken slot and it would always be merged into the root
+ i := 0
+ for j, arg := range vars {
+ if arg == t || (arg == vars[0] && j != 0) {
+ vars[j] = t.Args[i]
+ i++
+ if i > leafTakenSlots {
+ break
+ }
+ }
+ }
+ }
+
+ imm := computeTT(v, [3]*Value{vars[0], vars[1], vars[2]})
+ v.reset(v.Op)
+ v.SetArgs3(vars[0], vars[1], vars[2])
+ v.AuxInt = int64(int8(imm))
}
- rewrite = func(v *Value) [3]*Value {
- slo := boolExprTrees[v]
- if slo == sloInterior { // leaf node, i.e., a "variable"
- return [3]*Value{v, nil, nil}
- }
- var vars [3]*Value
- hasFeature := v.Block.CPUfeatures.hasFeature(CPUavx512)
- if slo&sloNot == sloNot {
- vars = rewrite(v.Args[0])
- if !hasFeature {
- if vars[2] != nil {
- replace(v.Args[0], vars)
- return [3]*Value{v, nil, nil}
- }
- return vars
- }
- } else {
- var ok bool
- a0, a1 := v.Args[0], v.Args[1]
- vars0 := rewrite(a0)
- vars1 := rewrite(a1)
- vars, ok = combine(vars0, vars1)
+ // phase 2: mergeTernlog traverses the tree created earlier by the expandTernlog function and combines ternlog instructions.
+ var mergeTernlog func(v *Value)
- if f.pass.debug > 1 {
- f.Warnl(a0.Pos, "combine(%v, %v) -> %v, %v", vars0, vars1, vars, ok)
- }
-
- if !(ok && v.Block.CPUfeatures.hasFeature(CPUavx512)) {
- // too many variables, or cannot rewrite current values.
- // rewrite one or both subtrees if possible
- if vars0[2] != nil && a0.Block.CPUfeatures.hasFeature(CPUavx512) {
- replace(a0, vars0)
- }
- if vars1[2] != nil && a1.Block.CPUfeatures.hasFeature(CPUavx512) {
- replace(a1, vars1)
- }
-
- // 3-element var arrays are either rewritten, or unable to be rewritten
- // because of the features in effect in their block. Either way, they
- // are treated as a "new var" if 3 elements are present.
-
- if vars0[2] == nil {
- if vars1[2] == nil {
- // both subtrees are 2-element and were not rewritten.
- //
- // TODO a clever person would look at subtrees of inputs,
- // e.g. rewrite
- // ((a AND b) XOR b) XOR (d XOR (c AND d))
- // to (((a AND b) XOR b) XOR d) XOR (c AND d)
- // to v = TERNLOG(truthtable, a, b, d) XOR (c AND d)
- // and return the variable set {v, c, d}
- //
- // But for now, just restart with a0 and a1.
- return [3]*Value{a0, a1, nil}
- } else {
- // a1 (maybe) rewrote, a0 has room for another var
- vars = addOne(vars0, a1)
- }
- } else if vars1[2] == nil {
- // a0 (maybe) rewrote, a1 has room for another var
- vars = addOne(vars1, a0)
- } else if !ok {
- // both (maybe) rewrote
- // a0 and a1 are different because otherwise their variable
- // sets would have combined "ok".
- return [3]*Value{a0, a1, nil}
- }
- // continue with either the vars from "ok" or the updated set of vars.
+ mergeTernlog = func(v *Value) {
+ for _, arg := range v.Args {
+ if boolExprTrees[arg].Strip() == sloTernlog {
+ mergeTernlog(arg)
}
}
- // if root and 3 vars and hasFeature, rewrite.
- if slo&sloInterior == 0 && vars[2] != nil && hasFeature {
- replace(v, vars)
- return [3]*Value{v, nil, nil}
+
+ if boolExprTrees[v].Strip() != sloTernlog {
+ panic(fmt.Errorf("mergeTernlog should only be called on ternlog nodes, v is %s with SLO %v", v.LongString(), boolExprTrees[v]))
}
- return vars
+
+ if boolExprTrees[v.Args[0]].Strip() == sloTernlog {
+ mergeBranchLeaves(v, v.Args[0])
+ }
+
+ if boolExprTrees[v.Args[1]].Strip() == sloTernlog {
+ mergeBranchLeaves(v, v.Args[1])
+ }
}
for _, v := range roots {
if f.pass.debug > 1 {
- f.Warnl(v.Pos, "SLO root %s", v.LongString())
+ f.Warnl(v.Pos, "tern optimized root %s", v.LongString())
}
- rewrite(v)
+ expandTernlog(v) // phase 1: expand all the nodes into ternlog instructions
+ mergeTernlog(v) // phase 2: merge ternlog instructions together when possible
+ }
+
+ // phase 3: reverse ternlog instructions back into their original logical operations when possible; this is mostly to reduce binary size.
+ for _, b := range f.Blocks {
+ for _, v := range b.Values {
+ if boolExprTrees[v].Strip() == sloTernlog {
+ var slo SIMDLogicalOP
+ if v.Args[0] == v.Args[2] { // TODO: more variations of imm8 for different instructions (and even constants) could exist, as well as different argument orders that need to be taken into account
+ switch uint8(v.AuxInt) {
+ case 0b1100_0000:
+ slo = sloAnd
+ case 0b1111_1100:
+ slo = sloOr
+ case 0b0011_1100:
+ slo = sloXor
+ case 0b0000_1100:
+ slo = sloAndNot
+ case 0b0011_0000:
+ slo = sloAndNot
+ v.Args[0], v.Args[1] = v.Args[1], v.Args[0] // andnot with imm8 48 is the same as andnot with imm8 12 but with the arguments reversed
+ case 0b0000_1111:
+ slo = sloNot
+ default:
+ continue // Leave as ternlog
+ }
+ }
+ op := classifyInstructionForTernaryLogical(v, slo)
+ if op != v.Op {
+ vars := [2]*Value{v.Args[0], v.Args[1]}
+ if f.pass.debug > 0 {
+ f.Warnl(v.Pos, "Reversed %s into %s based on truth table 0b%b", v.LongString(), op, uint8(v.AuxInt))
+ }
+ v.reset(op)
+ v.SetArgs2(vars[0], vars[1])
+ v.AuxInt = 0
+ }
+ }
+ }
}
}
diff --git a/src/cmd/compile/internal/ssa/tern_helpers.go b/src/cmd/compile/internal/ssa/tern_helpers.go
index 923a9f5..1076864 100644
--- a/src/cmd/compile/internal/ssa/tern_helpers.go
+++ b/src/cmd/compile/internal/ssa/tern_helpers.go
@@ -15,85 +15,86 @@
sloAndNot
sloXor
sloNot
+ sloTernlog
)
func classifyBooleanSIMD(v *Value) SIMDLogicalOP {
switch v.Op {
- case OpAndInt8x16, OpAndInt16x8, OpAndInt32x4, OpAndInt64x2, OpAndInt8x32, OpAndInt16x16, OpAndInt32x8, OpAndInt64x4, OpAndInt8x64, OpAndInt16x32, OpAndInt32x16, OpAndInt64x8:
+ case OpAndInt8x16, OpAndUint8x16, OpAndInt16x8, OpAndUint16x8, OpAndInt32x4, OpAndUint32x4, OpAndInt64x2, OpAndUint64x2, OpAndInt8x32, OpAndUint8x32, OpAndInt16x16, OpAndUint16x16, OpAndInt32x8, OpAndUint32x8, OpAndInt64x4, OpAndUint64x4, OpAndInt8x64, OpAndUint8x64, OpAndInt16x32, OpAndUint16x32, OpAndInt32x16, OpAndUint32x16, OpAndInt64x8, OpAndUint64x8:
return sloAnd
- case OpOrInt8x16, OpOrInt16x8, OpOrInt32x4, OpOrInt64x2, OpOrInt8x32, OpOrInt16x16, OpOrInt32x8, OpOrInt64x4, OpOrInt8x64, OpOrInt16x32, OpOrInt32x16, OpOrInt64x8:
+ case OpOrInt8x16, OpOrUint8x16, OpOrInt16x8, OpOrUint16x8, OpOrInt32x4, OpOrUint32x4, OpOrInt64x2, OpOrUint64x2, OpOrInt8x32, OpOrUint8x32, OpOrInt16x16, OpOrUint16x16, OpOrInt32x8, OpOrUint32x8, OpOrInt64x4, OpOrUint64x4, OpOrInt8x64, OpOrUint8x64, OpOrInt16x32, OpOrUint16x32, OpOrInt32x16, OpOrUint32x16, OpOrInt64x8, OpOrUint64x8:
return sloOr
- case OpAndNotInt8x16, OpAndNotInt16x8, OpAndNotInt32x4, OpAndNotInt64x2, OpAndNotInt8x32, OpAndNotInt16x16, OpAndNotInt32x8, OpAndNotInt64x4, OpAndNotInt8x64, OpAndNotInt16x32, OpAndNotInt32x16, OpAndNotInt64x8:
+ case OpAndNotInt8x16, OpAndNotUint8x16, OpAndNotInt16x8, OpAndNotUint16x8, OpAndNotInt32x4, OpAndNotUint32x4, OpAndNotInt64x2, OpAndNotUint64x2, OpAndNotInt8x32, OpAndNotUint8x32, OpAndNotInt16x16, OpAndNotUint16x16, OpAndNotInt32x8, OpAndNotUint32x8, OpAndNotInt64x4, OpAndNotUint64x4, OpAndNotInt8x64, OpAndNotUint8x64, OpAndNotInt16x32, OpAndNotUint16x32, OpAndNotInt32x16, OpAndNotUint32x16, OpAndNotInt64x8, OpAndNotUint64x8:
return sloAndNot
- case OpXorInt8x16:
+ case OpXorInt8x16, OpXorUint8x16:
if y := v.Args[1]; y.Op == OpEqualInt8x16 &&
y.Args[0] == y.Args[1] {
return sloNot
}
return sloXor
- case OpXorInt16x8:
+ case OpXorInt16x8, OpXorUint16x8:
if y := v.Args[1]; y.Op == OpEqualInt16x8 &&
y.Args[0] == y.Args[1] {
return sloNot
}
return sloXor
- case OpXorInt32x4:
+ case OpXorInt32x4, OpXorUint32x4:
if y := v.Args[1]; y.Op == OpEqualInt32x4 &&
y.Args[0] == y.Args[1] {
return sloNot
}
return sloXor
- case OpXorInt64x2:
+ case OpXorInt64x2, OpXorUint64x2:
if y := v.Args[1]; y.Op == OpEqualInt64x2 &&
y.Args[0] == y.Args[1] {
return sloNot
}
return sloXor
- case OpXorInt8x32:
+ case OpXorInt8x32, OpXorUint8x32:
if y := v.Args[1]; y.Op == OpEqualInt8x32 &&
y.Args[0] == y.Args[1] {
return sloNot
}
return sloXor
- case OpXorInt16x16:
+ case OpXorInt16x16, OpXorUint16x16:
if y := v.Args[1]; y.Op == OpEqualInt16x16 &&
y.Args[0] == y.Args[1] {
return sloNot
}
return sloXor
- case OpXorInt32x8:
+ case OpXorInt32x8, OpXorUint32x8:
if y := v.Args[1]; y.Op == OpEqualInt32x8 &&
y.Args[0] == y.Args[1] {
return sloNot
}
return sloXor
- case OpXorInt64x4:
+ case OpXorInt64x4, OpXorUint64x4:
if y := v.Args[1]; y.Op == OpEqualInt64x4 &&
y.Args[0] == y.Args[1] {
return sloNot
}
return sloXor
- case OpXorInt8x64:
+ case OpXorInt8x64, OpXorUint8x64:
if y := v.Args[1]; y.Op == OpEqualInt8x64 &&
y.Args[0] == y.Args[1] {
return sloNot
}
return sloXor
- case OpXorInt16x32:
+ case OpXorInt16x32, OpXorUint16x32:
if y := v.Args[1]; y.Op == OpEqualInt16x32 &&
y.Args[0] == y.Args[1] {
return sloNot
}
return sloXor
- case OpXorInt32x16:
+ case OpXorInt32x16, OpXorUint32x16:
if y := v.Args[1]; y.Op == OpEqualInt32x16 &&
y.Args[0] == y.Args[1] {
return sloNot
}
return sloXor
- case OpXorInt64x8:
+ case OpXorInt64x8, OpXorUint64x8:
if y := v.Args[1]; y.Op == OpEqualInt64x8 &&
y.Args[0] == y.Args[1] {
return sloNot
@@ -158,3 +159,121 @@
}
return op
}
+
+func classifyInstructionForTernaryLogical(v *Value, slo SIMDLogicalOP) Op {
+ switch slo {
+ case sloAnd:
+ switch v.Op {
+ case OpternInt32x4:
+ return OpAndInt32x4
+ case OpternUint32x4:
+ return OpAndUint32x4
+ case OpternInt64x2:
+ return OpAndInt64x2
+ case OpternUint64x2:
+ return OpAndUint64x2
+ case OpternInt32x8:
+ return OpAndInt32x8
+ case OpternUint32x8:
+ return OpAndUint32x8
+ case OpternInt64x4:
+ return OpAndInt64x4
+ case OpternUint64x4:
+ return OpAndUint64x4
+ case OpternInt32x16:
+ return OpAndInt32x16
+ case OpternUint32x16:
+ return OpAndUint32x16
+ case OpternInt64x8:
+ return OpAndInt64x8
+ case OpternUint64x8:
+ return OpAndUint64x8
+
+ }
+ case sloOr:
+ switch v.Op {
+ case OpternInt32x4:
+ return OpOrInt32x4
+ case OpternUint32x4:
+ return OpOrUint32x4
+ case OpternInt64x2:
+ return OpOrInt64x2
+ case OpternUint64x2:
+ return OpOrUint64x2
+ case OpternInt32x8:
+ return OpOrInt32x8
+ case OpternUint32x8:
+ return OpOrUint32x8
+ case OpternInt64x4:
+ return OpOrInt64x4
+ case OpternUint64x4:
+ return OpOrUint64x4
+ case OpternInt32x16:
+ return OpOrInt32x16
+ case OpternUint32x16:
+ return OpOrUint32x16
+ case OpternInt64x8:
+ return OpOrInt64x8
+ case OpternUint64x8:
+ return OpOrUint64x8
+
+ }
+ case sloAndNot:
+ switch v.Op {
+ case OpternInt32x4:
+ return OpAndNotInt32x4
+ case OpternUint32x4:
+ return OpAndNotUint32x4
+ case OpternInt64x2:
+ return OpAndNotInt64x2
+ case OpternUint64x2:
+ return OpAndNotUint64x2
+ case OpternInt32x8:
+ return OpAndNotInt32x8
+ case OpternUint32x8:
+ return OpAndNotUint32x8
+ case OpternInt64x4:
+ return OpAndNotInt64x4
+ case OpternUint64x4:
+ return OpAndNotUint64x4
+ case OpternInt32x16:
+ return OpAndNotInt32x16
+ case OpternUint32x16:
+ return OpAndNotUint32x16
+ case OpternInt64x8:
+ return OpAndNotInt64x8
+ case OpternUint64x8:
+ return OpAndNotUint64x8
+
+ }
+ case sloXor, sloNot:
+ switch v.Op {
+ case OpternInt32x4:
+ return OpXorInt32x4
+ case OpternUint32x4:
+ return OpXorUint32x4
+ case OpternInt64x2:
+ return OpXorInt64x2
+ case OpternUint64x2:
+ return OpXorUint64x2
+ case OpternInt32x8:
+ return OpXorInt32x8
+ case OpternUint32x8:
+ return OpXorUint32x8
+ case OpternInt64x4:
+ return OpXorInt64x4
+ case OpternUint64x4:
+ return OpXorUint64x4
+ case OpternInt32x16:
+ return OpXorInt32x16
+ case OpternUint32x16:
+ return OpXorUint32x16
+ case OpternInt64x8:
+ return OpXorInt64x8
+ case OpternUint64x8:
+ return OpXorUint64x8
+
+ }
+ }
+ return v.Op
+}
diff --git a/src/simd/archsimd/_gen/tmplgen/main.go b/src/simd/archsimd/_gen/tmplgen/main.go
index da7eae9..f001d9e 100644
--- a/src/simd/archsimd/_gen/tmplgen/main.go
+++ b/src/simd/archsimd/_gen/tmplgen/main.go
@@ -976,7 +976,7 @@
one(*cmh, curryTestPrologue("simd methods that compare two operands under a mask"), compareMaskedTemplate)
}
- nonTemplateRewrites(SSA+"tern_helpers.go", ssaPrologue, classifyBooleanSIMD, ternOpForLogical)
+ nonTemplateRewrites(SSA+"tern_helpers.go", ssaPrologue, classifyBooleanSIMD, ternOpForLogical, classifyInstructionForTernaryLogical)
}
@@ -1004,6 +1004,75 @@
}
+func classifyInstructionForTernaryLogical(out io.Writer) {
+ fmt.Fprintf(out, `
+func classifyInstructionForTernaryLogical(v *Value, slo SIMDLogicalOP) Op {
+ switch slo {
+ case sloAnd:
+ switch v.Op {
+`)
+
+ intShapes.forAllShapes(func(seq int, t, upperT string, w, c int, out io.Writer) {
+ if w < 32 {
+ return
+ }
+ fmt.Fprintf(out, "case OpternInt%dx%d: return OpAndInt%dx%d\n", w, c, w, c)
+ fmt.Fprintf(out, "case OpternUint%dx%d: return OpAndUint%dx%d\n", w, c, w, c)
+ }, out)
+
+ fmt.Fprintf(out, `
+ }
+ case sloOr:
+ switch v.Op {
+`)
+
+ intShapes.forAllShapes(func(seq int, t, upperT string, w, c int, out io.Writer) {
+ if w < 32 {
+ return
+ }
+
+ fmt.Fprintf(out, "case OpternInt%dx%d: return OpOrInt%dx%d\n", w, c, w, c)
+ fmt.Fprintf(out, "case OpternUint%dx%d: return OpOrUint%dx%d\n", w, c, w, c)
+ }, out)
+
+ fmt.Fprintf(out, `
+ }
+ case sloAndNot:
+ switch v.Op {
+`)
+
+ intShapes.forAllShapes(func(seq int, t, upperT string, w, c int, out io.Writer) {
+ if w < 32 {
+ return
+ }
+
+ fmt.Fprintf(out, "case OpternInt%dx%d: return OpAndNotInt%dx%d\n", w, c, w, c)
+ fmt.Fprintf(out, "case OpternUint%dx%d: return OpAndNotUint%dx%d\n", w, c, w, c)
+ }, out)
+
+ fmt.Fprintf(out, `
+ }
+ case sloXor, sloNot:
+ switch v.Op {
+`)
+
+ intShapes.forAllShapes(func(seq int, t, upperT string, w, c int, out io.Writer) {
+ if w < 32 {
+ return
+ }
+ fmt.Fprintf(out, "case OpternInt%dx%d: return OpXorInt%dx%d\n", w, c, w, c)
+ fmt.Fprintf(out, "case OpternUint%dx%d: return OpXorUint%dx%d\n", w, c, w, c)
+ }, out)
+
+ fmt.Fprintf(out, `
+ }
+ }
+ return v.Op
+}
+`)
+
+}
+
func classifyBooleanSIMD(out io.Writer) {
fmt.Fprintf(out, `
type SIMDLogicalOP uint8
@@ -1018,17 +1087,19 @@
sloAndNot
sloXor
sloNot
+ sloTernlog
)
func classifyBooleanSIMD(v *Value) SIMDLogicalOP {
switch v.Op {
case `)
intShapes.forAllShapes(func(seq int, t, upperT string, w, c int, out io.Writer) {
op := "And"
+ pref := ""
if seq > 0 {
- fmt.Fprintf(out, ",Op%s%s%dx%d", op, upperT, w, c)
- } else {
- fmt.Fprintf(out, "Op%s%s%dx%d", op, upperT, w, c)
+ pref = ","
}
+ fmt.Fprintf(out, "%sOp%s%s%dx%d,", pref, op, "Int", w, c)
+ fmt.Fprintf(out, "Op%s%s%dx%d", op, "Uint", w, c)
seq++
}, out)
@@ -1038,11 +1109,12 @@
case `)
intShapes.forAllShapes(func(seq int, t, upperT string, w, c int, out io.Writer) {
op := "Or"
+ pref := ""
if seq > 0 {
- fmt.Fprintf(out, ",Op%s%s%dx%d", op, upperT, w, c)
- } else {
- fmt.Fprintf(out, "Op%s%s%dx%d", op, upperT, w, c)
+ pref = ","
}
+ fmt.Fprintf(out, "%sOp%s%s%dx%d,", pref, op, "Int", w, c)
+ fmt.Fprintf(out, "Op%s%s%dx%d", op, "Uint", w, c)
seq++
}, out)
@@ -1052,11 +1124,12 @@
case `)
intShapes.forAllShapes(func(seq int, t, upperT string, w, c int, out io.Writer) {
op := "AndNot"
+ pref := ""
if seq > 0 {
- fmt.Fprintf(out, ",Op%s%s%dx%d", op, upperT, w, c)
- } else {
- fmt.Fprintf(out, "Op%s%s%dx%d", op, upperT, w, c)
+ pref = ","
}
+ fmt.Fprintf(out, "%sOp%s%s%dx%d,", pref, op, "Int", w, c)
+ fmt.Fprintf(out, "Op%s%s%dx%d", op, "Uint", w, c)
seq++
}, out)
@@ -1070,7 +1143,7 @@
intShapes.forAllShapes(
func(seq int, t, upperT string, w, c int, out io.Writer) {
- fmt.Fprintf(out, "case OpXor%s%dx%d: ", upperT, w, c)
+ fmt.Fprintf(out, "case OpXor%s%dx%d, OpXor%s%dx%d: ", "Int", w, c, "Uint", w, c)
fmt.Fprintf(out, `
if y := v.Args[1]; y.Op == OpEqual%s%dx%d &&
y.Args[0] == y.Args[1] {
diff --git a/test/codegen/simd.go b/test/codegen/simd.go
index acec542..741f5dc 100644
--- a/test/codegen/simd.go
+++ b/test/codegen/simd.go
@@ -110,3 +110,20 @@
// amd64:`VPSRLD\s\$1,\s.*$`
return x.ShiftAllRight(1)
}
+
+// checks for src/cmd/compile/internal/ssa/rewritetern.go
+func simdVPTERNLOG_AndOr() archsimd.Uint32x8 {
+ var x, y, z archsimd.Uint32x8
+ // (x & y) | z
+ // This should collapse into VPTERNLOGD with imm8 0xEA
+ // amd64:`VPTERNLOGD\s\$234,`
+ return x.And(y).Or(z)
+}
+
+func simdVPTERNLOG_XorAnd() archsimd.Uint32x8 {
+ var x, y, z archsimd.Uint32x8
+ // (x ^ y) & z
+ // This should collapse into VPTERNLOGD with imm8 0x60
+ // amd64:`VPTERNLOGD\s\$96,`
+ return x.Xor(y).And(z)
+}