[go] cmd/compile: complete sroa pass and add thorough tests

Junyang Shao (Gerrit)

4:47 AM (10 hours ago)
to goph...@pubsubhelper.golang.org, golang-co...@googlegroups.com

Junyang Shao has uploaded the change for review

Commit message

cmd/compile: complete sroa pass and add thorough tests

The SROA pass now decomposes arrays as well as structs, breaks down
OpZero and OpMove on aggregates into per-field scalar operations,
rewrites VarDef/VarLive annotations for the new scalars, and caps
decomposition at 64 leaf fields to bound compile time and memory use.
The new tests cover arrays, zeroing, moves, shadowed variables,
array-bearing structs, string-bearing structs, and oversized aggregates.

This CL is written by Gemini CLI.
Change-Id: I3391101291f95ec53c0171401dde21389caf9e5f
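
For context, the effect of the pass can be sketched at the Go source
level (a hand-written illustration mirroring the simpleStruct test in
test/sroa.go below; the pass itself rewrites SSA values such as
OpLocalAddr, OpOffPtr, OpLoad, and OpStore, not Go source, and the
package and function names here are illustrative only):

package sroasketch

type Point struct{ x, y int }

// before: p occupies a stack slot; every field access compiles to a
// load or store through an OffPtr of p's LocalAddr.
func before(c bool) int {
	p := Point{10, 20}
	ptr := &p
	if c {
		ptr.x += 1
	} else {
		ptr.y += 1
	}
	return ptr.x + ptr.y
}

// after: SROA splits p into one variable per leaf field (p.sroa.0 and
// p.sroa.8 in the pass's naming), which mem2reg can then promote to
// SSA values, eliminating the stack slot entirely.
func after(c bool) int {
	pX, pY := 10, 20
	if c {
		pX += 1
	} else {
		pY += 1
	}
	return pX + pY
}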

Change diff

diff --git a/src/cmd/compile/internal/base/flag.go b/src/cmd/compile/internal/base/flag.go
index 766b0a7..f9899b4 100644
--- a/src/cmd/compile/internal/base/flag.go
+++ b/src/cmd/compile/internal/base/flag.go
@@ -177,7 +177,6 @@
Flag.LinkShared = &Ctxt.Flag_linkshared
Flag.Shared = &Ctxt.Flag_shared
Flag.WB = true
- Flag.NoMem2Reg = true

Debug.ConcurrentOk = true
Debug.CompressInstructions = 1
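
(Dropping this default means mem2reg, and with it the SROA pass, now
runs by default: sroa bails out only when -N or -nomem2reg is set. The
two codegen test updates below opt back out via -gcflags=-nomem2reg.)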
diff --git a/src/cmd/compile/internal/ssa/sroa.go b/src/cmd/compile/internal/ssa/sroa.go
index ee6a53c..be1bf60 100644
--- a/src/cmd/compile/internal/ssa/sroa.go
+++ b/src/cmd/compile/internal/ssa/sroa.go
@@ -4,6 +4,30 @@

package ssa

+/*
+SROA (Scalar Replacement of Aggregates) decomposes aggregate variables (structs and arrays)
+into separate scalar variables. This allows the subsequent mem2reg pass to promote
+these scalars to registers (SSA values), significantly improving performance by
+reducing memory traffic.
+
+Algorithm overview:
+1. Candidate Identification: Find local variables that don't escape and aren't special.
+ Special variables include shadowed names, compiler-generated temps,
+ names marked NonMergeable, and open-coded defer slots (OpenDeferSlot).
+2. Breakdown of Block Operations: Replace OpZero and OpMove on aggregates with
+ field-by-field scalar operations.
+3. Partitioning: Determine if the aggregate's fields are accessed in a way that
+ permits decomposition (e.g., no overlapping accesses).
+4. Rewriting: Create new individual local variables for each leaf field and
+ rewrite all loads, stores, and VarDefs/VarLives to use the new variables.
+
+Stability Constraints:
+- OpVarDef is only emitted for new scalars if they satisfy SSA checker requirements
+ (contain pointers or are merge candidates).
+- Aggregates with more than 64 leaf fields or arrays with more than 64 elements
+ are skipped to prevent compiler hangs or excessive memory usage.
+*/
+
import (
"cmd/compile/internal/base"
"cmd/compile/internal/ir"
@@ -11,313 +35,497 @@
"fmt"
"os"
"slices"
+ "strings"
+ "sync"
)

+var sroaMu sync.Mutex
+
+type sroaLeafField struct {
+ off int64
+ typ *types.Type
+}
+
+type sroaCandidate struct {
+ n *ir.Name
+ leaves []sroaLeafField
+ vag []*Value
+ unSROAable bool
+}
+
// sroa performs scalar replacement of aggregates.
-// It decomposes aggregate variables (structs) accessed via OffPtr
+// It decomposes aggregate variables (structs and arrays) accessed via OffPtr
// into separate scalar variables, allowing mem2reg to promote them.
func sroa(f *Func) {
if base.Flag.N != 0 || base.Flag.NoMem2Reg {
return
}

- // If GOSSAFUNC is set, only optimize this function for less debugging noises.
if ssaFunc := os.Getenv("GOSSAFUNC"); ssaFunc != "" {
if base.Ctxt == nil {
return
}
curFunc := fmt.Sprintf("%s.%s", base.Ctxt.Pkgpath, f.Name)
- if ssaFunc != curFunc {
+ if ssaFunc != curFunc && ssaFunc != f.Name {
return
}
}
st := f.NewStats("sroa")

- // We use the same threshold as mem2reg for now, though SROA might be more expensive.
- // Reuse pointerDemographics to refine our candidates.
+ // maxPartitions limits the number of fields an aggregate can be decomposed into.
+ // This prevents OOM and hangs on extremely large arrays or deeply nested structs.
+ const maxPartitions = 64
+
+ // getLeafFields recursively finds all scalar fields of a type.
+ var getLeafFields func(t *types.Type, base int64, leaves *[]sroaLeafField) bool
+ getLeafFields = func(t *types.Type, base int64, leaves *[]sroaLeafField) bool {
+ if t.IsStruct() {
+ for i := 0; i < t.NumFields(); i++ {
+ fld := t.Field(i)
+ if !getLeafFields(fld.Type, base+fld.Offset, leaves) {
+ return false
+ }
+ }
+ } else if t.IsArray() {
+ // Limit decomposition of large arrays.
+ if t.NumElem() > maxPartitions {
+ return false
+ }
+ for i := int64(0); i < t.NumElem(); i++ {
+ if !getLeafFields(t.Elem(), base+i*t.Elem().Size(), leaves) {
+ return false
+ }
+ }
+ } else {
+ *leaves = append(*leaves, sroaLeafField{base, t})
+ }
+ return len(*leaves) <= maxPartitions
+ }
+
escapes, _, varDefs, vars, uses := pointerDemographics(f)

- // Group LocalAddrs by variable name
- varGrouped := make(map[*ir.Name][]*Value)
+ // Phase 1: Collect initial candidates.
+ candidates := make(map[*ir.Name]*sroaCandidate)
+ nameCounts := make(map[string]int)
+ seenNames := make(map[*ir.Name]bool)
for _, v := range vars {
- if escapes[v.Aux] {
- continue
- }
n := v.Aux.(*ir.Name)
- if n.Class != ir.PAUTO {
+ if !seenNames[n] {
+ nameCounts[n.Sym().Name]++
+ seenNames[n] = true
+ }
+ }
+
+ for _, v := range vars {
+ n := v.Aux.(*ir.Name)
+ if candidates[n] != nil {
continue
}
- // SROA only cares about structs for now.
- if !n.Type().IsStruct() {
+ // Skip compiler-generated temporaries (names starting with "~"), such as result temps.
+ if n.Sym() != nil && strings.HasPrefix(n.Sym().Name, "~") {
continue
}
- varGrouped[n] = append(varGrouped[n], v)
- }
- if len(varGrouped) == 0 {
- return
- }
- namesOrdered := []*ir.Name{}
- for n, vag := range varGrouped {
- namesOrdered = append(namesOrdered, n)
- slices.SortFunc(vag, func(a, b *Value) int { return int(a.ID - b.ID) })
- }
- slices.SortFunc(namesOrdered, func(a, b *ir.Name) int {
- return int(varGrouped[a][0].ID - varGrouped[b][0].ID)
- })
-
- // Candidates for splitting
- type partition struct {
- offset int64
- typ *types.Type
- newAddr *Value
- }
-
- changed := false
- type memAtOff struct {
- offptr *Value
- ls *Value
- }
- // For an OpMove, it might have been broken down to multiple Load/Stores with an OffPtr,
- // this map records those broken-down Moves.
- moveRewritten := make(map[*Value][]*memAtOff)
-
- removeStore := func(v *Value) {
- v.SetArgs1(v.MemoryArg())
- v.Aux = nil
- v.AuxInt = 0
- v.Op = OpCopy
- }
-
- type access struct {
- a *Value
- off int64
- }
-
- // Iterate over each variable
- for _, n := range namesOrdered {
- accesses := make([]access, 0)
-
- // Helper to recursively collect accesses
- var collectUses func(v *Value, offset int64)
- collectUses = func(v *Value, offset int64) {
- for _, use := range uses[v] {
- switch use.Op {
- case OpOffPtr:
- collectUses(use, offset+use.AuxInt)
- case OpLoad, OpStore, OpZero:
- accesses = append(accesses, access{use, offset})
- case OpMove:
- if lss, ok := moveRewritten[use]; !ok {
- accesses = append(accesses, access{use, offset})
- } else {
- for _, ls := range lss {
- collectUses(ls.ls, offset+ls.offptr.AuxInt)
- }
- }
- default:
- panic("unexpected op")
- }
- }
+ // Skip shadowed variables to avoid name collisions and potential issues in liveness.
+ if nameCounts[n.Sym().Name] > 1 {
+ continue
}
-
- // Collect all accesses
- vag := varGrouped[n]
- for _, v := range vag {
- collectUses(v, 0)
+ // Skip variables that escape, are parameters, or are special runtime-tracked vars.
+ if escapes[n] || n.Class != ir.PAUTO || (!n.Type().IsStruct() && !n.Type().IsArray()) || n.NonMergeable() || n.OpenDeferSlot() {
+ continue
}
-
- if len(accesses) == 0 {
+ leaves := make([]sroaLeafField, 0)
+ if !getLeafFields(n.Type(), 0, &leaves) {
+ continue
+ }
+ if len(leaves) > maxPartitions {
+ st.Record("too many partitions", 1)
+ continue
+ }
+ if len(leaves) == 0 {
continue
}

- partitions := make(map[int64]*partition) // offset -> partitions
- partitionsValid := true
-
- checkSanity := func(off, size int64) {
- if off < 0 {
- panic("access underflow")
- }
- if off+size > n.Type().Size() {
- panic("access overflow")
- }
- }
-
- var hasZero bool
- var hasMove bool
- for _, a := range accesses {
- use := a.a
- off := a.off
- var t *types.Type
- switch use.Op {
- case OpStore:
- t = auxToType(use.Aux)
- checkSanity(off, t.Size())
- case OpLoad:
- t = use.Type
- checkSanity(off, t.Size())
- case OpZero:
- hasZero = true
- checkSanity(off, auxIntToInt64(use.AuxInt))
- continue
- case OpMove:
- hasMove = true
- checkSanity(off, auxIntToInt64(use.AuxInt))
- continue
- default:
- panic("unexpected op")
- }
-
- if p, ok := partitions[off]; ok {
- partitions[off] = &partition{
- offset: off,
- typ: t,
- }
- } else if p.typ != t {
- // Don't handle any incompatible type cases for now.
- st.Record("incompatible types", 1)
- partitionsValid = false
+ ok := true
+ for _, lf := range leaves {
+ t := lf.typ
+ // Only allow true scalars that mem2reg can handle.
+ if !(t.IsInteger() || t.IsFloat() || t.IsPtr() || t.IsUnsafePtr() || t.IsMap() || t.IsChan() || t.IsBoolean()) {
+ ok = false
break
}
}
-
- // TODO: support Moves, might need a filter like the one for Zeros below.
- if !partitionsValid || hasMove {
- if hasMove {
- st.Record("moving names", 1)
- }
+ if !ok {
continue
}

- // For Zeros we have to break them down to storing constants, and there are ssa ops
- // for them. This filters out the ones that SROA knows how to construct.
- accessesFromZero := map[int64][]*Value{}
- if hasZero {
- sroaKnow := true
- for _, p := range partitions {
- if !p.typ.IsStruct() && !p.typ.IsArray() &&
- ((p.typ.IsScalar() && !p.typ.IsComplex()) || // What does complex lower to?
- p.typ.IsString() ||
- p.typ.IsSlice() ||
- p.typ.IsPtr() || p.typ.IsMap() || p.typ.IsChan()) { // All pointers
- //TODO: support SIMD types!
- sroaKnow = false
- st.Record("unknown types", 1)
+ vag := make([]*Value, 0)
+ for _, v2 := range vars {
+ if v2.Aux == n {
+ vag = append(vag, v2)
+ }
+ }
+ candidates[n] = &sroaCandidate{n: n, leaves: leaves, vag: vag}
+ }
+
+ if len(candidates) == 0 {
+ return
+ }
+
+ // Find entry memory and SP
+ var entryMem, entrySP *Value
+ for _, v := range f.Entry.Values {
+ if v.Op == OpSP {
+ entrySP = v
+ }
+ if v.Type.IsMemory() && entryMem == nil {
+ entryMem = v
+ }
+ }
+ if entryMem == nil || entrySP == nil {
+ return
+ }
+
+ isDerivedFrom := func(v *Value, n *ir.Name) bool {
+ for {
+ if v.Op == OpLocalAddr && v.Aux == n {
+ return true
+ }
+ if v.Op != OpOffPtr && v.Op != OpCopy {
+ return false
+ }
+ if len(v.Args) == 0 {
+ return false
+ }
+ v = v.Args[0]
+ }
+ }
+
+ var findOffset func(ptr *Value) int64
+ findOffset = func(ptr *Value) int64 {
+ if ptr.Op == OpLocalAddr {
+ return 0
+ }
+ if ptr.Op == OpOffPtr {
+ return ptr.AuxInt + findOffset(ptr.Args[0])
+ }
+ if ptr.Op == OpCopy {
+ return findOffset(ptr.Args[0])
+ }
+ return 0
+ }
+
+ // Phase 2: Breakdown OpZero and OpMove.
+ // We convert aggregate-level zeros and moves into field-level stores and loads/stores.
+ // This exposes the underlying scalar access patterns to the promotion logic.
+ brokenDown := make(map[*Value]bool)
+ changed := false
+ for _, n := range seenSROANamesOrdered(candidates) {
+ c := candidates[n]
+ for _, vAddr := range c.vag {
+ for _, use := range uses[vAddr] {
+ if brokenDown[use] {
+ continue
+ }
+ switch use.Op {
+ case OpOffPtr:
+ // Handled by recursion in demographics.
+ case OpZero:
+ if isDerivedFrom(use.Args[0], n) {
+ off := findOffset(use.Args[0])
+ size := use.AuxInt
+ end := off + size
+
+ // Check if all fields within the zero are fully covered.
+ // Partial zeros (where a leaf field is only partly zeroed)
+ // are not currently supported for SROA.
+ for _, lf := range c.leaves {
+ lend := lf.off + lf.typ.Size()
+ if lf.off < end && lend > off { // touches
+ if lf.off < off || lend > end {
+ c.unSROAable = true
+ break
+ }
+ }
+ }
+ if c.unSROAable {
+ break
+ }
+
+ brokenDown[use] = true
+ changed = true
+ memIn := use.MemoryArg()
+ curM := memIn
+ for _, lf := range c.leaves {
+ lend := lf.off + lf.typ.Size()
+ if lf.off >= off && lend <= end {
+ var zeroV *Value
+ t := lf.typ
+ // Construct zero constant of appropriate type.
+ if t.IsFloat() {
+ if t.Size() == 4 {
+ zeroV = f.ConstFloat32(t, 0)
+ } else {
+ zeroV = f.ConstFloat64(t, 0)
+ }
+ } else if t.IsBoolean() {
+ zeroV = f.ConstBool(t, false)
+ } else if t.IsInteger() {
+ switch t.Size() {
+ case 1:
+ zeroV = f.ConstInt8(t, 0)
+ case 2:
+ zeroV = f.ConstInt16(t, 0)
+ case 4:
+ zeroV = f.ConstInt32(t, 0)
+ case 8:
+ zeroV = f.ConstInt64(t, 0)
+ default:
+ panic("unexpected size")
+ }
+ } else {
+ zeroV = f.ConstNil(t)
+ }
+ newPtr := use.Block.NewValue1I(use.Pos, OpOffPtr, t.PtrTo(), lf.off-off, use.Args[0])
+ curM = use.Block.NewValue3A(use.Pos, OpStore, types.TypeMem, t, newPtr, zeroV, curM)
+ }
+ }
+ use.reset(OpCopy)
+ use.Aux = nil
+ use.SetArgs1(curM)
+ }
+ case OpMove:
+ if isDerivedFrom(use.Args[0], n) || isDerivedFrom(use.Args[1], n) {
+ dstOff := findOffset(use.Args[0])
+ srcOff := findOffset(use.Args[1])
+ size := use.AuxInt
+ isDst := isDerivedFrom(use.Args[0], n)
+ isSrc := isDerivedFrom(use.Args[1], n)
+
+ // Ensure every field touched by the move is fully covered by it.
+ for _, lf := range c.leaves {
+ offInMove := int64(0)
+ if isDst {
+ offInMove = lf.off - dstOff
+ } else if isSrc {
+ offInMove = lf.off - srcOff
+ } else {
+ continue
+ }
+ if offInMove >= 0 && offInMove < size {
+ if offInMove+lf.typ.Size() > size {
+ c.unSROAable = true
+ break
+ }
+ }
+ }
+ if c.unSROAable {
+ break
+ }
+
+ brokenDown[use] = true
+ changed = true
+ memIn := use.MemoryArg()
+ curM := memIn
+ for _, lf := range c.leaves {
+ offInMove := int64(0)
+ if isDst {
+ offInMove = lf.off - dstOff
+ } else if isSrc {
+ offInMove = lf.off - srcOff
+ } else {
+ continue
+ }
+ if offInMove < 0 || offInMove+lf.typ.Size() > size {
+ continue
+ }
+ // Emit scalar load/store for each field.
+ newSrcPtr := use.Block.NewValue1I(use.Pos, OpOffPtr, lf.typ.PtrTo(), offInMove, use.Args[1])
+ newDstPtr := use.Block.NewValue1I(use.Pos, OpOffPtr, lf.typ.PtrTo(), offInMove, use.Args[0])
+ val := use.Block.NewValue2(use.Pos, OpLoad, lf.typ, newSrcPtr, memIn)
+ curM = use.Block.NewValue3A(use.Pos, OpStore, types.TypeMem, lf.typ, newDstPtr, val, curM)
+ }
+ use.reset(OpCopy)
+ use.Aux = nil
+ use.SetArgs1(curM)
+ }
+ }
+ if c.unSROAable {
break
}
}
- if !sroaKnow {
- continue
- }
- // Remove the zeros and replace them with the respective stores
- for _, a := range accesses {
- z := a.a
- off := a.off
- if z.Op == OpZero {
- // Check its range, and break it to multiple stores and distribute these stores to
- // the right partitions
- end := off + z.AuxInt
- curM := z
- curB := z.Block
- curPos := z.Pos
- curPtr := z.Args[0]
- for poff, p := range partitions {
- pend := poff + p.typ.Size()
- // A zero must not touch a partition partially, otherwise it's illegal
- if (off <= poff && end > poff && end < pend) || (off > poff && off < end && end > pend) {
- panic("zero touches partition partially")
- }
- // If a partition is within the zero, a store op should be created
- if poff >= off && pend <= end {
- var zeroV *Value
- if p.typ.IsScalar() && !p.typ.IsComplex() {
- if p.typ.IsFloat() {
- switch p.typ.Size() {
- case 4:
- zeroV = f.ConstFloat32(p.typ, 0)
- case 8:
- zeroV = f.ConstFloat64(p.typ, 0)
- default:
- panic("unexpected type")
- }
- } else {
- switch p.typ.Size() {
- case 1:
- zeroV = f.ConstInt8(p.typ, 0)
- case 2:
- zeroV = f.ConstInt16(p.typ, 0)
- case 4:
- zeroV = f.ConstInt32(p.typ, 0)
- case 8:
- zeroV = f.ConstInt64(p.typ, 0)
- default:
- panic("unexpected type")
- }
- }
- } else if p.typ.IsString() {
- zeroV = f.ConstEmptyString(p.typ)
- } else if p.typ.IsSlice() {
- zeroV = f.ConstSlice(p.typ)
- } else if p.typ.IsPtr() || p.typ.IsMap() || p.typ.IsChan() {
- zeroV = f.ConstNil(p.typ)
- } else {
- panic("unexpected type")
- }
- newPtr := curB.NewValue1I(curPos, OpOffPtr, p.typ.PtrTo(), poff-off, curPtr)
- newS := curB.NewValue3A(curPos, OpStore, types.TypeMem, p.typ, newPtr, zeroV, curM)
- accessesFromZero[poff] = append(accessesFromZero[poff], newS)
- }
- }
- removeStore(z)
- }
- }
- }
-
- // Start the scalar replacement.
- changed = true
- st.Record("promoted variable", 1)
-
- // Create new names
- // TODO: Can we instead modify the mapping of the stack location for the old name and break it
- // to be multiple new names?
- for _, p := range partitions {
- // Create a new Auto variable
- newSymName := fmt.Sprintf("%s.sroa.%d", n.Sym().Name, p.offset)
- newSym := n.Sym().Pkg.Lookup(newSymName)
- newName := n.Curfn.NewLocal(n.Pos(), newSym, p.typ)
- newName.SetUsed(true)
- sp, _ := f.spSb()
- p.newAddr = f.Entry.NewValue1A(n.Pos(), OpLocalAddr, types.NewPtr(p.typ), newName, sp)
- }
-
- // Remove VarDefs
- for _, v := range varDefs[n] {
- removeStore(v)
- }
-
- for _, a := range accesses {
- use := a.a
- off := a.off
- switch use.Op {
- case OpLoad:
- p := partitions[off]
- if p != nil {
- use.SetArgs2(p.newAddr, use.MemoryArg())
- } else {
- panic("partition not found")
- }
- case OpStore:
- p := partitions[off]
- if p != nil {
- use.SetArgs3(p.newAddr, use.Args[1], use.Args[2])
- } else {
- panic("partition not found")
- }
+ if c.unSROAable {
+ break
}
}
}

if changed {
- // Run deadcode to remove the old OffPtrs and LocalAddrs, might not need since mem2reg runs one...
+ escapes, _, varDefs, vars, uses = pointerDemographics(f)
+ }
+
+ // Phase 3: Final Promotion.
+ // For each valid candidate, create new individual variables for each field
+ // and update all usages.
+ for _, n := range seenSROANamesOrdered(candidates) {
+ c := candidates[n]
+ if c.unSROAable {
+ continue
+ }
+ if escapes[n] { // escape status might have changed after breakdown
+ continue
+ }
+
+ type pAccess struct {
+ a *Value
+ off int64
+ }
+ var accesses []pAccess
+ hasUnhandledUses := false
+ var collectSimpleUses func(v *Value, offset int64)
+ collectSimpleUses = func(v *Value, offset int64) {
+ for _, use := range uses[v] {
+ switch use.Op {
+ case OpOffPtr:
+ collectSimpleUses(use, offset+use.AuxInt)
+ case OpLoad, OpStore:
+ accesses = append(accesses, pAccess{use, offset})
+ case OpCopy:
+ collectSimpleUses(use, offset)
+ case OpVarDef, OpVarLive:
+ // These are fine; they are rewritten separately below.
+ default:
+ hasUnhandledUses = true
+ }
+ }
+ }
+ for _, vAddr := range c.vag {
+ collectSimpleUses(vAddr, 0)
+ }
+ // If the variable is used in ways other than simple OffPtr/Load/Store,
+ // we cannot promote it.
+ if hasUnhandledUses || len(accesses) == 0 {
+ continue
+ }
+
+ type partition struct {
+ offset int64
+ typ *types.Type
+ newAddr *Value
+ newName *ir.Name
+ }
+ partitions := make(map[int64]*partition)
+ partitionsValid := true
+ for _, a := range accesses {
+ var sz int64
+ if a.a.Op == OpLoad {
+ sz = a.a.Type.Size()
+ } else {
+ sz = auxToType(a.a.Aux).Size()
+ }
+ found := false
+ for _, lf := range c.leaves {
+ // Access must exactly match a leaf field's offset and size.
+ if lf.off == a.off && lf.typ.Size() == sz {
+ partitions[a.off] = &partition{offset: a.off, typ: lf.typ}
+ found = true
+ break
+ }
+ }
+ if !found {
+ partitionsValid = false
+ break
+ }
+ }
+ if !partitionsValid || len(partitions) == 0 {
+ continue
+ }
+
+ if f.pass.debug > 0 {
+ f.Warnl(n.Pos(), "promoted %v", n)
+ }
+ st.Record("promoted variable", 1)
+
+ // Deterministically create new variables for each used partition.
+ var partOffsets []int64
+ for off := range partitions {
+ partOffsets = append(partOffsets, off)
+ }
+ slices.Sort(partOffsets)
+
+ for _, off := range partOffsets {
+ p := partitions[off]
+ newSymName := fmt.Sprintf("%s.sroa.%d", n.Sym().Name, p.offset)
+ newSym := &types.Sym{Name: newSymName, Pkg: n.Sym().Pkg}
+ newName := n.Curfn.NewLocal(n.Pos(), newSym, p.typ)
+ newName.SetUsed(true)
+ p.newName = newName
+ // New LocalAddr at the entry block for the new scalar variable.
+ p.newAddr = f.Entry.NewValue2A(n.Pos(), OpLocalAddr, types.NewPtr(p.typ), newName, entrySP, entryMem)
+ }
+
+ // Rewrite loads and stores to use the new scalar addresses.
+ for _, a := range accesses {
+ p := partitions[a.off]
+ if a.a.Op == OpLoad {
+ a.a.SetArgs2(p.newAddr, a.a.Args[1])
+ } else {
+ a.a.SetArgs3(p.newAddr, a.a.Args[1], a.a.Args[2])
+ }
+ }
+
+ // Rewrite VarLive annotations.
+ for _, b := range f.Blocks {
+ for _, v := range b.Values {
+ if v.Aux == n && v.Op == OpVarLive {
+ curM := v.MemoryArg()
+ for _, off := range partOffsets {
+ p := partitions[off]
+ curM = v.Block.NewValue1A(v.Pos, v.Op, types.TypeMem, p.newName, curM)
+ }
+ v.reset(OpCopy)
+ v.Aux = nil
+ v.SetArgs1(curM)
+ }
+ }
+ }
+
+ // Rewrite VarDef annotations.
+ // Note: We only emit VarDef if the new variable contains pointers
+ // or is large enough to be a merge candidate, satisfying the SSA checker.
+ for _, v := range varDefs[n] {
+ curM := v.MemoryArg()
+ for _, off := range partOffsets {
+ p := partitions[off]
+ if p.newName.Type().HasPointers() || IsMergeCandidate(p.newName) {
+ curM = v.Block.NewValue1A(v.Pos, OpVarDef, types.TypeMem, p.newName, curM)
+ }
+ }
+ v.reset(OpCopy)
+ v.Aux = nil
+ v.SetArgs1(curM)
+ }
+ changed = true
+ }
+
+ if changed {
+ copyelim(f)
deadcode(f)
}
}
+
+func seenSROANamesOrdered(candidates map[*ir.Name]*sroaCandidate) []*ir.Name {
+ names := make([]*ir.Name, 0, len(candidates))
+ for n := range candidates {
+ names = append(names, n)
+ }
+ slices.SortFunc(names, func(a, b *ir.Name) int {
+ if a.Pos() != b.Pos() {
+ if a.Pos().Before(b.Pos()) {
+ return -1
+ }
+ return 1
+ }
+ return strings.Compare(a.Sym().Name, b.Sym().Name)
+ })
+ return names
+}
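
As a standalone illustration for reviewers, the flattening that
getLeafFields performs can be reproduced outside the compiler with
reflect (a sketch using reflect.Type in place of the compiler's
*types.Type; the printed offsets assume a 64-bit target):

package main

import (
	"fmt"
	"reflect"
)

// leaves flattens structs and arrays recursively into (offset, type)
// pairs, the same shape as sroa's []sroaLeafField.
func leaves(t reflect.Type, base uintptr, out *[]string) {
	switch t.Kind() {
	case reflect.Struct:
		for i := 0; i < t.NumField(); i++ {
			f := t.Field(i)
			leaves(f.Type, base+f.Offset, out)
		}
	case reflect.Array:
		for i := 0; i < t.Len(); i++ {
			leaves(t.Elem(), base+uintptr(i)*t.Elem().Size(), out)
		}
	default:
		*out = append(*out, fmt.Sprintf("off=%d %v", base, t))
	}
}

type ArrayStruct struct {
	a [2]int
	b int
}

func main() {
	var out []string
	leaves(reflect.TypeOf(ArrayStruct{}), 0, &out)
	fmt.Println(out) // 64-bit: [off=0 int off=8 int off=16 int]
}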
diff --git a/src/cmd/compile/internal/ssa/sroa_test.go b/src/cmd/compile/internal/ssa/sroa_test.go
index fc7bbaf..40e496e 100644
--- a/src/cmd/compile/internal/ssa/sroa_test.go
+++ b/src/cmd/compile/internal/ssa/sroa_test.go
@@ -94,19 +94,185 @@
return float64(a) + b
}

-func TestSROA(t *testing.T) {
- if got, want := simpleStructAddr(true), simpleStructReg(true); got != want {
- t.Errorf("simpleStruct(true): got %d, want %d", got, want)
+//go:noinline
+func arrayAddr(c bool) int {
+ var a [3]int
+ a[0] = 10
+ a[1] = 20
+ a[2] = 30
+ if c {
+ a[1] += 5
}
- if got, want := simpleStructAddr(false), simpleStructReg(false); got != want {
- t.Errorf("simpleStruct(false): got %d, want %d", got, want)
+ return a[0] + a[1] + a[2]
+}
+
+//go:noinline
+func arrayReg(c bool) int {
+ a0, a1, a2 := 10, 20, 30
+ if c {
+ a1 += 5
+ }
+ return a0 + a1 + a2
+}
+
+//go:noinline
+func zeroStructAddr(c bool) int {
+ var p Point
+ ptr := &p
+ if c {
+ ptr.x = 10
+ } else {
+ ptr.y = 20
+ }
+ return ptr.x + ptr.y
+}
+
+//go:noinline
+func zeroStructReg(c bool) int {
+ var x, y int
+ if c {
+ x = 10
+ } else {
+ y = 20
+ }
+ return x + y
+}
+
+//go:noinline
+func moveStructAddr(c bool) int {
+ p1 := Point{10, 20}
+ var p2 Point
+ if c {
+ p2 = p1
+ } else {
+ p2 = Point{30, 40}
+ }
+ return p2.x + p2.y
+}
+
+//go:noinline
+func moveStructReg(c bool) int {
+ p1x, p1y := 10, 20
+ var p2x, p2y int
+ if c {
+ p2x, p2y = p1x, p1y
+ } else {
+ p2x, p2y = 30, 40
+ }
+ return p2x + p2y
+}
+
+//go:noinline
+func shadowedStructAddr(c bool) int {
+ p := Point{10, 20}
+ if c {
+ p := Point{30, 40}
+ return p.x + p.y
+ }
+ return p.x + p.y
+}
+
+//go:noinline
+func shadowedStructReg(c bool) int {
+ px, py := 10, 20
+ if c {
+ px2, py2 := 30, 40
+ return px2 + py2
+ }
+ return px + py
+}
+
+type Big struct {
+ a [100]int
+}
+
+//go:noinline
+func tooBigStructAddr() int {
+ var b Big
+ b.a[0] = 10
+ return b.a[0]
+}
+
+//go:noinline
+func tooBigStructReg() int {
+ var a0 int = 10
+ return a0
+}
+
+type ArrayStruct struct {
+ a [2]int
+ b int
+}
+
+//go:noinline
+func arrayStructAddr(c bool) int {
+ s := ArrayStruct{[2]int{1, 2}, 3}
+ if c {
+ s.a[0] += 10
+ } else {
+ s.b += 20
+ }
+ return s.a[0] + s.a[1] + s.b
+}
+
+//go:noinline
+func arrayStructReg(c bool) int {
+ a0, a1, b := 1, 2, 3
+ if c {
+ a0 += 10
+ } else {
+ b += 20
+ }
+ return a0 + a1 + b
+}
+
+type StringStruct struct {
+ s string
+ i int
+}
+
+//go:noinline
+func stringStructAddr(c bool) int {
+ // string has non-scalar layout (ptr + len), so the current
+ // implementation, which only allows true scalars, will not SROA it.
+ ss := StringStruct{"hello", 10}
+ if c {
+ ss.i += 5
+ }
+ return len(ss.s) + ss.i
+}
+
+//go:noinline
+func stringStructReg(c bool) int {
+ s, i := "hello", 10
+ if c {
+ i += 5
+ }
+ return len(s) + i
+}
+
+func TestSROA(t *testing.T) {
+ tests := []struct {
+ name string
+ f1 func(bool) int
+ f2 func(bool) int
+ }{
+ {"simpleStruct", simpleStructAddr, simpleStructReg},
+ {"nestedStruct", nestedStructAddr, nestedStructReg},
+ {"array", arrayAddr, arrayReg},
+ {"zeroStruct", zeroStructAddr, zeroStructReg},
+ {"moveStruct", moveStructAddr, moveStructReg},
+ {"shadowedStruct", shadowedStructAddr, shadowedStructReg},
+ {"arrayStruct", arrayStructAddr, arrayStructReg},
+ {"stringStruct", stringStructAddr, stringStructReg},
}

- if got, want := nestedStructAddr(true), nestedStructReg(true); got != want {
- t.Errorf("nestedStruct(true): got %d, want %d", got, want)
- }
- if got, want := nestedStructAddr(false), nestedStructReg(false); got != want {
- t.Errorf("nestedStruct(false): got %d, want %d", got, want)
+ for _, tc := range tests {
+ for _, c := range []bool{true, false} {
+ if got, want := tc.f1(c), tc.f2(c); got != want {
+ t.Errorf("%s(%v): got %d, want %d", tc.name, c, got, want)
+ }
+ }
}

if got, want := mixedStructAddr(true), mixedStructReg(true); got != want {
@@ -115,4 +281,8 @@
if got, want := mixedStructAddr(false), mixedStructReg(false); got != want {
t.Errorf("mixedStruct(false): got %f, want %f", got, want)
}
+
+ if got, want := tooBigStructAddr(), tooBigStructReg(); got != want {
+ t.Errorf("tooBigStruct: got %d, want %d", got, want)
+ }
}
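
The zeroStruct and moveStruct cases above exercise the Phase 2
breakdown; at the source level the rewrite amounts to roughly the
following (a sketch of the intent, not compiler output):

package sroasketch

type Point struct{ x, y int }

// breakdown shows what Phase 2 leaves behind: no aggregate-sized memory
// op touches p2 anymore, so Phase 3 can redirect each remaining
// per-field load/store to a fresh scalar variable.
func breakdown(c bool, p1 Point) int {
	var p2 Point // OpZero(16) -> store 0 to p2.x; store 0 to p2.y
	if c {
		p2 = p1 // OpMove(16) -> p2.x = p1.x; p2.y = p1.y
	}
	return p2.x + p2.y
}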
diff --git a/test/codegen/clobberdead.go b/test/codegen/clobberdead.go
index df44705..b434e79 100644
--- a/test/codegen/clobberdead.go
+++ b/test/codegen/clobberdead.go
@@ -1,4 +1,4 @@
-// asmcheck -gcflags=-clobberdead
+// asmcheck -gcflags=-clobberdead -gcflags=-nomem2reg

//go:build amd64 || arm64

diff --git a/test/codegen/floats.go b/test/codegen/floats.go
index 343f8fa..4442d1b 100644
--- a/test/codegen/floats.go
+++ b/test/codegen/floats.go
@@ -1,4 +1,4 @@
-// asmcheck
+// asmcheck -gcflags=-nomem2reg

// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
diff --git a/test/sroa.go b/test/sroa.go
new file mode 100644
index 0000000..272bf46
--- /dev/null
+++ b/test/sroa.go
@@ -0,0 +1,104 @@
+// errorcheck -0 -d=ssa/sroa/debug=1
+
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Test for variables that can be decomposed by the SROA pass.
+
+package main
+
+type Point struct {
+ x, y int
+}
+
+//go:noinline
+func simpleStruct(c bool) int {
+ p := Point{10, 20} // ERROR "promoted p"
+ ptr := &p
+ if c {
+ ptr.x += 1
+ } else {
+ ptr.y += 1
+ }
+ return ptr.x + ptr.y
+}
+
+type Inner struct {
+ a, b int
+}
+
+type Outer struct {
+ in Inner
+ c int
+}
+
+//go:noinline
+func nestedStruct(c bool) int {
+ o := Outer{Inner{1, 2}, 3} // ERROR "promoted o"
+ ptr := &o
+ if c {
+ ptr.in.a += 10
+ } else {
+ ptr.in.b += 20
+ }
+ return ptr.in.a + ptr.in.b + ptr.c
+}
+
+//go:noinline
+func array(c bool) int {
+ var a [3]int // ERROR "promoted a"
+ pa := &a
+ pa[0] = 10
+ pa[1] = 20
+ pa[2] = 30
+ if c {
+ pa[1] += 5
+ }
+ return pa[0] + pa[1] + pa[2]
+}
+
+//go:noinline
+func zeroStruct(c bool) int {
+ var p Point // ERROR "promoted p"
+ ptr := &p
+ if c {
+ ptr.x = 10
+ } else {
+ ptr.y = 20
+ }
+ return ptr.x + ptr.y
+}
+
+//go:noinline
+func moveStruct(x, y int, c bool) int {
+ p1 := Point{x, y} // ERROR "promoted p1"
+ var p2 Point // ERROR "promoted p2"
+ ptr1 := &p1
+ ptr2 := &p2
+ if c {
+ *ptr2 = *ptr1
+ } else {
+ *ptr2 = Point{30, 40}
+ }
+ return ptr2.x + ptr2.y + ptr1.x
+}
+
+type ArrayStruct struct {
+ a [2]int
+ b int
+}
+
+//go:noinline
+func arrayStruct(c bool) int {
+ s := ArrayStruct{[2]int{1, 2}, 3} // ERROR "promoted s"
+ ps := &s
+ if c {
+ ps.a[0] += 10
+ } else {
+ ps.b += 20
+ }
+ return ps.a[0] + ps.a[1] + ps.b
+}
+
+func main() {}
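
A note on the harness: each ERROR annotation above matches the
"promoted %v" diagnostic that the pass emits via f.Warnl when
ssa/sroa/debug is nonzero, so the test verifies that every annotated
variable is actually promoted. The same diagnostics can be produced on
ordinary code with, for example:

	go build -gcflags='-d=ssa/sroa/debug=1' ./...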

Change information

Files:
  • M src/cmd/compile/internal/base/flag.go
  • M src/cmd/compile/internal/ssa/sroa.go
  • M src/cmd/compile/internal/ssa/sroa_test.go
  • M test/codegen/clobberdead.go
  • M test/codegen/floats.go
  • A test/sroa.go
Change size: XL
Delta: 6 files changed, 755 insertions(+), 274 deletions(-)
Open in Gerrit

Related details

Attention set is empty
Submit Requirements:
  • Code-Review: requirement is not satisfied
  • No-Unresolved-Comments: requirement satisfied
  • Review-Enforcement: requirement is not satisfied
  • TryBots-Pass: requirement is not satisfied
Gerrit-MessageType: newchange
Gerrit-Project: go
Gerrit-Branch: master
Gerrit-Change-Id: I3391101291f95ec53c0171401dde21389caf9e5f
Gerrit-Change-Number: 742701
Gerrit-PatchSet: 1
Gerrit-Owner: Junyang Shao <shaoj...@google.com>