From 97a36298fa1e8b782d83eb947031ef58879c42b7 Mon Sep 17 00:00:00 2001 From: Isaac Marovitz Date: Fri, 21 Jun 2024 10:31:21 +0100 Subject: [PATCH] Instruction.Barrier Whoops Fix inline functions in compute stage Fix regression Declare SharedMemories + Only Declare Memories on Main Func Lowercase struct Avoid magic strings Make function signatures readable Change how unsized arrays are indexed Use string builder Fix shuffle instructions Cleanup NumberFormatter Bunch of Subgroup I/O Vars Will probably need further refinement Fix point_coord type Fix support buffer declaration Fix point_coord --- .../CodeGen/Msl/CodeGenContext.cs | 2 +- .../CodeGen/Msl/Declarations.cs | 41 +++++++++++----- .../CodeGen/Msl/DefaultNames.cs | 2 + .../CodeGen/Msl/Instructions/InstGen.cs | 25 ++++++---- .../CodeGen/Msl/Instructions/InstGenCall.cs | 17 +++++-- .../CodeGen/Msl/Instructions/InstGenHelper.cs | 8 ++-- .../CodeGen/Msl/Instructions/InstGenMemory.cs | 10 +--- .../CodeGen/Msl/Instructions/IoMap.cs | 10 +++- .../CodeGen/Msl/MslGenerator.cs | 24 +++++++--- .../CodeGen/Msl/NumberFormatter.cs | 48 ++++++++----------- 10 files changed, 110 insertions(+), 77 deletions(-) diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/CodeGenContext.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/CodeGenContext.cs index 79c13964c..0ae6313eb 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/CodeGenContext.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/CodeGenContext.cs @@ -9,7 +9,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl public const string Tab = " "; // The number of additional arguments that every function (except for the main one) must have (for instance support_buffer) - public const int AdditionalArgCount = 2; + public const int AdditionalArgCount = 1; public StructuredFunction CurrentFunction { get; set; } diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Declarations.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Declarations.cs index 8d4a9c877..60729ac60 100644 --- 
a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Declarations.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Declarations.cs @@ -64,9 +64,14 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl return ioDefinition.StorageKind == storageKind && ioDefinition.IoVariable == IoVariable.UserDefined; } - public static void DeclareLocals(CodeGenContext context, StructuredFunction function, ShaderStage stage) + public static void DeclareLocals(CodeGenContext context, StructuredFunction function, ShaderStage stage, bool isMainFunc = false) { - DeclareMemories(context, context.Properties.LocalMemories.Values, isShared: false); + if (isMainFunc) + { + DeclareMemories(context, context.Properties.LocalMemories.Values, isShared: false); + DeclareMemories(context, context.Properties.SharedMemories.Values, isShared: true); + } + switch (stage) { case ShaderStage.Vertex: @@ -112,6 +117,8 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl private static void DeclareMemories(CodeGenContext context, IEnumerable memories, bool isShared) { + string prefix = isShared ? 
"threadgroup " : string.Empty; + foreach (var memory in memories) { string arraySize = ""; @@ -120,7 +127,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl arraySize = $"[{memory.ArrayLength}]"; } var typeName = GetVarTypeName(context, memory.Type & ~AggregateType.Array); - context.AppendLine($"{typeName} {memory.Name}{arraySize};"); + context.AppendLine($"{prefix}{typeName} {memory.Name}{arraySize};"); } } @@ -128,23 +135,28 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl { foreach (BufferDefinition buffer in buffers) { - context.AppendLine($"struct Struct_{buffer.Name}"); + context.AppendLine($"struct {DefaultNames.StructPrefix}_{buffer.Name}"); context.EnterScope(); foreach (StructureField field in buffer.Type.Fields) { - if (field.Type.HasFlag(AggregateType.Array) && field.ArrayLength > 0) - { - string typeName = GetVarTypeName(context, field.Type & ~AggregateType.Array); + string typeName = GetVarTypeName(context, field.Type & ~AggregateType.Array); + string arraySuffix = ""; - context.AppendLine($"{typeName} {field.Name}[{field.ArrayLength}];"); - } - else + if (field.Type.HasFlag(AggregateType.Array)) { - string typeName = GetVarTypeName(context, field.Type & ~AggregateType.Array); - - context.AppendLine($"{typeName} {field.Name};"); + if (field.ArrayLength > 0) + { + arraySuffix = $"[{field.ArrayLength}]"; + } + else + { + // Probably UB, but this is the approach that MVK takes + arraySuffix = "[1]"; + } } + + context.AppendLine($"{typeName} {field.Name}{arraySuffix};"); } context.LeaveScope(";"); @@ -191,6 +203,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl IoVariable.GlobalId => "uint3", IoVariable.VertexId => "uint", IoVariable.VertexIndex => "uint", + IoVariable.PointCoord => "float2", _ => GetVarTypeName(context, context.Definitions.GetUserDefinedType(ioDefinition.Location, isOutput: false)) }; string name = ioDefinition.IoVariable switch @@ -199,6 +212,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl IoVariable.GlobalId => "global_id", 
IoVariable.VertexId => "vertex_id", IoVariable.VertexIndex => "vertex_index", + IoVariable.PointCoord => "point_coord", _ => $"{DefaultNames.IAttributePrefix}{ioDefinition.Location}" }; string suffix = ioDefinition.IoVariable switch @@ -208,6 +222,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl IoVariable.VertexId => "[[vertex_id]]", // TODO: Avoid potential redeclaration IoVariable.VertexIndex => "[[vertex_id]]", + IoVariable.PointCoord => "[[point_coord]]", IoVariable.UserDefined => context.Definitions.Stage == ShaderStage.Fragment ? $"[[user(loc{ioDefinition.Location})]]" : $"[[attribute({ioDefinition.Location})]]", _ => "" }; diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/DefaultNames.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/DefaultNames.cs index 8a468395e..0b946c3aa 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/DefaultNames.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/DefaultNames.cs @@ -8,6 +8,8 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl public const string IAttributePrefix = "inAttr"; public const string OAttributePrefix = "outAttr"; + public const string StructPrefix = "struct"; + public const string ArgumentNamePrefix = "a"; public const string UndefinedName = "0"; diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGen.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGen.cs index 8c101ad75..696564992 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGen.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGen.cs @@ -2,7 +2,7 @@ using Ryujinx.Graphics.Shader.IntermediateRepresentation; using Ryujinx.Graphics.Shader.StructuredIr; using Ryujinx.Graphics.Shader.Translation; using System; - +using System.Text; using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenCall; using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenHelper; using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenMemory; @@ -39,11 +39,20 @@ namespace 
Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions int arity = (int)(info.Type & InstType.ArityMask); - string args = string.Empty; + StringBuilder builder = new(); - if (atomic) + if (atomic && (operation.StorageKind == StorageKind.StorageBuffer || operation.StorageKind == StorageKind.SharedMemory)) { - // Hell + builder.Append(GenerateLoadOrStore(context, operation, isStore: false)); + + AggregateType dstType = operation.Inst == Instruction.AtomicMaxS32 || operation.Inst == Instruction.AtomicMinS32 + ? AggregateType.S32 + : AggregateType.U32; + + for (int argIndex = operation.SourcesCount - arity + 2; argIndex < operation.SourcesCount; argIndex++) + { + builder.Append($", {GetSourceExpr(context, operation.GetSource(argIndex), dstType)}"); + } } else { @@ -51,16 +60,16 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions { if (argIndex != 0) { - args += ", "; + builder.Append(", "); } AggregateType dstType = GetSrcVarType(inst, argIndex); - args += GetSourceExpr(context, operation.GetSource(argIndex), dstType); + builder.Append(GetSourceExpr(context, operation.GetSource(argIndex), dstType)); } } - return info.OpName + '(' + args + ')'; + return $"{info.OpName}({builder})"; } else if ((info.Type & InstType.Op) != 0) { @@ -110,7 +119,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions switch (inst & Instruction.Mask) { case Instruction.Barrier: - return "|| BARRIER ||"; + return "threadgroup_barrier(mem_flags::mem_threadgroup)"; case Instruction.Call: return Call(context, operation); case Instruction.FSIBegin: diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenCall.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenCall.cs index 5df3aa282..c063ff458 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenCall.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenCall.cs @@ -13,13 +13,22 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions var functon = 
context.GetFunction(funcId.Value); int argCount = operation.SourcesCount - 1; - string[] args = new string[argCount + CodeGenContext.AdditionalArgCount]; + int additionalArgCount = CodeGenContext.AdditionalArgCount + (context.Definitions.Stage != ShaderStage.Compute ? 1 : 0); + + string[] args = new string[argCount + additionalArgCount]; // Additional arguments - args[0] = "in"; - args[1] = "support_buffer"; + if (context.Definitions.Stage != ShaderStage.Compute) + { + args[0] = "in"; + args[1] = "support_buffer"; + } + else + { + args[0] = "support_buffer"; + } - int argIndex = CodeGenContext.AdditionalArgCount; + int argIndex = additionalArgCount; for (int i = 0; i < argCount; i++) { args[argIndex++] = GetSourceExpr(context, operation.GetSource(i + 1), functon.GetArgumentType(i)); diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenHelper.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenHelper.cs index 406fda11a..014d070ef 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenHelper.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenHelper.cs @@ -109,10 +109,10 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions Add(Instruction.ShiftLeft, InstType.OpBinary, "<<", 3); Add(Instruction.ShiftRightS32, InstType.OpBinary, ">>", 3); Add(Instruction.ShiftRightU32, InstType.OpBinary, ">>", 3); - Add(Instruction.Shuffle, InstType.CallQuaternary, "simd_shuffle"); - Add(Instruction.ShuffleDown, InstType.CallQuaternary, "simd_shuffle_down"); - Add(Instruction.ShuffleUp, InstType.CallQuaternary, "simd_shuffle_up"); - Add(Instruction.ShuffleXor, InstType.CallQuaternary, "simd_shuffle_xor"); + Add(Instruction.Shuffle, InstType.CallBinary, "simd_shuffle"); + Add(Instruction.ShuffleDown, InstType.CallBinary, "simd_shuffle_down"); + Add(Instruction.ShuffleUp, InstType.CallBinary, "simd_shuffle_up"); + Add(Instruction.ShuffleXor, InstType.CallBinary, "simd_shuffle_xor"); Add(Instruction.Sine, 
InstType.CallUnary, "sin"); Add(Instruction.SquareRoot, InstType.CallUnary, "sqrt"); Add(Instruction.Store, InstType.Special); diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenMemory.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenMemory.cs index 135cd80e0..bb1a69939 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenMemory.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenMemory.cs @@ -47,15 +47,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions StructureField field = buffer.Type.Fields[fieldIndex.Value]; varName = buffer.Name; - if ((field.Type & AggregateType.Array) != 0 && field.ArrayLength == 0) - { - // Unsized array, the buffer is indexed instead of the field - fieldName = "." + field.Name; - } - else - { - varName += "->" + field.Name; - } + varName += "->" + field.Name; varType = field.Type; break; diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/IoMap.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/IoMap.cs index 2e93310aa..1561271d0 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/IoMap.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/IoMap.cs @@ -27,13 +27,19 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions IoVariable.GlobalId => ("thread_position_in_grid", AggregateType.Vector3 | AggregateType.U32), IoVariable.InstanceId => ("instance_id", AggregateType.S32), IoVariable.InvocationId => ("INVOCATION_ID", AggregateType.S32), - IoVariable.PointCoord => ("point_coord", AggregateType.Vector2), + IoVariable.PointCoord => ("point_coord", AggregateType.Vector2 | AggregateType.FP32), IoVariable.PointSize => ("out.point_size", AggregateType.FP32), IoVariable.Position => ("out.position", AggregateType.Vector4 | AggregateType.FP32), IoVariable.PrimitiveId => ("primitive_id", AggregateType.S32), + IoVariable.SubgroupEqMask => ("thread_index_in_simdgroup >= 32 ? 
uint4(0, (1 << (thread_index_in_simdgroup - 32)), uint2(0)) : uint4(1 << thread_index_in_simdgroup, uint3(0))", AggregateType.Vector4 | AggregateType.U32), + IoVariable.SubgroupGeMask => ("uint4(insert_bits(0u, 0xFFFFFFFF, thread_index_in_simdgroup, 32 - thread_index_in_simdgroup), uint3(0)) & (uint4((uint)((simd_vote::vote_t)simd_ballot(true) & 0xFFFFFFFF), (uint)(((simd_vote::vote_t)simd_ballot(true) >> 32) & 0xFFFFFFFF), 0, 0))", AggregateType.Vector4 | AggregateType.U32), + IoVariable.SubgroupGtMask => ("uint4(insert_bits(0u, 0xFFFFFFFF, thread_index_in_simdgroup + 1, 32 - thread_index_in_simdgroup - 1), uint3(0)) & (uint4((uint)((simd_vote::vote_t)simd_ballot(true) & 0xFFFFFFFF), (uint)(((simd_vote::vote_t)simd_ballot(true) >> 32) & 0xFFFFFFFF), 0, 0))", AggregateType.Vector4 | AggregateType.U32), + IoVariable.SubgroupLaneId => ("thread_index_in_simdgroup", AggregateType.U32), + IoVariable.SubgroupLeMask => ("uint4(extract_bits(0xFFFFFFFF, 0, min(thread_index_in_simdgroup + 1, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)thread_index_in_simdgroup + 1 - 32, 0)), uint2(0))", AggregateType.Vector4 | AggregateType.U32), + IoVariable.SubgroupLtMask => ("uint4(extract_bits(0xFFFFFFFF, 0, min(thread_index_in_simdgroup, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)thread_index_in_simdgroup - 32, 0)), uint2(0))", AggregateType.Vector4 | AggregateType.U32), + IoVariable.ThreadKill => ("simd_is_helper_thread()", AggregateType.Bool), IoVariable.UserDefined => GetUserDefinedVariableName(definitions, location, component, isOutput, isPerPatch), IoVariable.ThreadId => ("thread_position_in_threadgroup", AggregateType.Vector3 | AggregateType.U32), - IoVariable.SubgroupLaneId => ("thread_index_in_simdgroup", AggregateType.U32), IoVariable.VertexId => ("vertex_id", AggregateType.S32), // gl_VertexIndex does not have a direct equivalent in MSL IoVariable.VertexIndex => ("vertex_id", AggregateType.U32), diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/MslGenerator.cs 
b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/MslGenerator.cs index 87512a961..bb5ea5f6f 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/MslGenerator.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/MslGenerator.cs @@ -44,7 +44,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl context.AppendLine(GetFunctionSignature(context, function, stage, isMainFunc)); context.EnterScope(); - Declarations.DeclareLocals(context, function, stage); + Declarations.DeclareLocals(context, function, stage, isMainFunc); PrintBlock(context, function.MainBlock, isMainFunc); @@ -63,15 +63,22 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl ShaderStage stage, bool isMainFunc = false) { - int additionalArgCount = isMainFunc ? 0 : CodeGenContext.AdditionalArgCount; + int additionalArgCount = isMainFunc ? 0 : CodeGenContext.AdditionalArgCount + (context.Definitions.Stage != ShaderStage.Compute ? 1 : 0); string[] args = new string[additionalArgCount + function.InArguments.Length + function.OutArguments.Length]; // All non-main functions need to be able to access the support_buffer as well if (!isMainFunc) { - args[0] = "FragmentIn in"; - args[1] = "constant Struct_support_buffer* support_buffer"; + if (stage != ShaderStage.Compute) + { + args[0] = stage == ShaderStage.Vertex ? 
"VertexIn in" : "FragmentIn in"; + args[1] = $"constant {DefaultNames.StructPrefix}_support_buffer* support_buffer"; + } + else + { + args[0] = $"constant {DefaultNames.StructPrefix}_support_buffer* support_buffer"; + } } int argIndex = additionalArgCount; @@ -141,13 +148,13 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl foreach (var constantBuffer in context.Properties.ConstantBuffers.Values) { - args = args.Append($"constant Struct_{constantBuffer.Name}* {constantBuffer.Name} [[buffer({constantBuffer.Binding})]]").ToArray(); + args = args.Append($"constant {DefaultNames.StructPrefix}_{constantBuffer.Name}* {constantBuffer.Name} [[buffer({constantBuffer.Binding})]]").ToArray(); } foreach (var storageBuffers in context.Properties.StorageBuffers.Values) { // Offset the binding by 15 to avoid clashing with the constant buffers - args = args.Append($"device Struct_{storageBuffers.Name}* {storageBuffers.Name} [[buffer({storageBuffers.Binding + 15})]]").ToArray(); + args = args.Append($"device {DefaultNames.StructPrefix}_{storageBuffers.Name}* {storageBuffers.Name} [[buffer({storageBuffers.Binding + 15})]]").ToArray(); } foreach (var texture in context.Properties.Textures.Values) @@ -162,7 +169,10 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl } } - return $"{funcKeyword} {returnType} {funcName ?? function.Name}({string.Join(", ", args)})"; + var funcPrefix = $"{funcKeyword} {returnType} {funcName ?? 
function.Name}("; + var indent = new string(' ', funcPrefix.Length); + + return $"{funcPrefix}{string.Join($", \n{indent}", args)})"; } private static void PrintBlock(CodeGenContext context, AstBlock block, bool isMainFunction) diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/NumberFormatter.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/NumberFormatter.cs index f086e7436..63ecbc0aa 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/NumberFormatter.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/NumberFormatter.cs @@ -10,25 +10,21 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl public static bool TryFormat(int value, AggregateType dstType, out string formatted) { - if (dstType == AggregateType.FP32) + switch (dstType) { - return TryFormatFloat(BitConverter.Int32BitsToSingle(value), out formatted); - } - else if (dstType == AggregateType.S32) - { - formatted = FormatInt(value); - } - else if (dstType == AggregateType.U32) - { - formatted = FormatUint((uint)value); - } - else if (dstType == AggregateType.Bool) - { - formatted = value != 0 ? "true" : "false"; - } - else - { - throw new ArgumentException($"Invalid variable type \"{dstType}\"."); + case AggregateType.FP32: + return TryFormatFloat(BitConverter.Int32BitsToSingle(value), out formatted); + case AggregateType.S32: + formatted = FormatInt(value); + break; + case AggregateType.U32: + formatted = FormatUint((uint)value); + break; + case AggregateType.Bool: + formatted = value != 0 ? 
"true" : "false"; + break; + default: + throw new ArgumentException($"Invalid variable type \"{dstType}\"."); } return true; @@ -65,18 +61,12 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl public static string FormatInt(int value, AggregateType dstType) { - if (dstType == AggregateType.S32) + return dstType switch { - return FormatInt(value); - } - else if (dstType == AggregateType.U32) - { - return FormatUint((uint)value); - } - else - { - throw new ArgumentException($"Invalid variable type \"{dstType}\"."); - } + AggregateType.S32 => FormatInt(value), + AggregateType.U32 => FormatUint((uint)value), + _ => throw new ArgumentException($"Invalid variable type \"{dstType}\".") + }; } public static string FormatInt(int value)