diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Declarations.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Declarations.cs index 60729ac60..0b6aadd03 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Declarations.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Declarations.cs @@ -70,16 +70,26 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl { DeclareMemories(context, context.Properties.LocalMemories.Values, isShared: false); DeclareMemories(context, context.Properties.SharedMemories.Values, isShared: true); - } - switch (stage) - { - case ShaderStage.Vertex: - context.AppendLine("VertexOut out;"); - break; - case ShaderStage.Fragment: - context.AppendLine("FragmentOut out;"); - break; + switch (stage) + { + case ShaderStage.Vertex: + context.AppendLine("VertexOut out;"); + // TODO: Only add if necessary + context.AppendLine("uint instance_index = instance_id + base_instance;"); + break; + case ShaderStage.Fragment: + context.AppendLine("FragmentOut out;"); + break; + } + + // TODO: Only add if necessary + if (stage != ShaderStage.Compute) + { + // MSL does not give us access to [[thread_index_in_simdgroup]] + // outside compute. But we may still need to provide this value in frag/vert. + context.AppendLine("uint thread_index_in_simdgroup = simd_prefix_exclusive_sum(1);"); + } } foreach (AstOperand decl in function.Locals) @@ -90,15 +100,18 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl } } - public static string GetVarTypeName(CodeGenContext context, AggregateType type) + public static string GetVarTypeName(CodeGenContext context, AggregateType type, bool atomic = false) { + var s32 = atomic ? "atomic_int" : "int"; + var u32 = atomic ? "atomic_uint" : "uint"; + return type switch { AggregateType.Void => "void", AggregateType.Bool => "bool", AggregateType.FP32 => "float", - AggregateType.S32 => "int", - AggregateType.U32 => "uint", + AggregateType.S32 => s32, + AggregateType.U32 => u32, AggregateType.Vector2 | AggregateType.Bool => "bool2", AggregateType.Vector2 | AggregateType.FP32 => "float2", AggregateType.Vector2 | AggregateType.S32 => "int2", diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGen.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGen.cs index 696564992..0bea4d1aa 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGen.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGen.cs @@ -3,6 +3,7 @@ using Ryujinx.Graphics.Shader.StructuredIr; using Ryujinx.Graphics.Shader.Translation; using System; using System.Text; +using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenBallot; using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenCall; using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenHelper; using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenMemory; @@ -43,15 +44,16 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions if (atomic && (operation.StorageKind == StorageKind.StorageBuffer || operation.StorageKind == StorageKind.SharedMemory)) { - builder.Append(GenerateLoadOrStore(context, operation, isStore: false)); - AggregateType dstType = operation.Inst == Instruction.AtomicMaxS32 || operation.Inst == Instruction.AtomicMinS32 ? AggregateType.S32 : AggregateType.U32; + builder.Append($"(device {Declarations.GetVarTypeName(context, dstType, true)}*)&{GenerateLoadOrStore(context, operation, isStore: false)}"); + + for (int argIndex = operation.SourcesCount - arity + 2; argIndex < operation.SourcesCount; argIndex++) { - builder.Append($", {GetSourceExpr(context, operation.GetSource(argIndex), dstType)}"); + builder.Append($", {GetSourceExpr(context, operation.GetSource(argIndex), dstType)}, memory_order_relaxed"); } } else @@ -118,6 +120,8 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions { switch (inst & Instruction.Mask) { + case Instruction.Ballot: + return Ballot(context, operation); case Instruction.Barrier: return "threadgroup_barrier(mem_flags::mem_threadgroup)"; case Instruction.Call: diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenBallot.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenBallot.cs new file mode 100644 index 000000000..1f53c74ed --- /dev/null +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenBallot.cs @@ -0,0 +1,21 @@ +using Ryujinx.Graphics.Shader.StructuredIr; +using Ryujinx.Graphics.Shader.Translation; + +using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenHelper; +using static Ryujinx.Graphics.Shader.StructuredIr.InstructionInfo; + +namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions +{ + static class InstGenBallot + { + public static string Ballot(CodeGenContext context, AstOperation operation) + { + AggregateType dstType = GetSrcVarType(operation.Inst, 0); + + string arg = GetSourceExpr(context, operation.GetSource(0), dstType); + char component = "xyzw"[operation.Index]; + + return $"uint4(as_type((simd_vote::vote_t)simd_ballot({arg})), 0, 0).{component}"; + } + } +} diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenHelper.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenHelper.cs index 014d070ef..d230e2ed4 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenHelper.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenHelper.cs @@ -15,17 +15,17 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions _infoTable = new InstInfo[(int)Instruction.Count]; #pragma warning disable IDE0055 // Disable formatting - Add(Instruction.AtomicAdd, InstType.AtomicBinary, "atomic_add_explicit"); - Add(Instruction.AtomicAnd, InstType.AtomicBinary, "atomic_and_explicit"); + Add(Instruction.AtomicAdd, InstType.AtomicBinary, "atomic_fetch_add_explicit"); + Add(Instruction.AtomicAnd, InstType.AtomicBinary, "atomic_fetch_and_explicit"); Add(Instruction.AtomicCompareAndSwap, InstType.AtomicBinary, "atomic_compare_exchange_weak_explicit"); - Add(Instruction.AtomicMaxU32, InstType.AtomicBinary, "atomic_max_explicit"); - Add(Instruction.AtomicMinU32, InstType.AtomicBinary, "atomic_min_explicit"); - Add(Instruction.AtomicOr, InstType.AtomicBinary, "atomic_or_explicit"); + Add(Instruction.AtomicMaxU32, InstType.AtomicBinary, "atomic_fetch_max_explicit"); + Add(Instruction.AtomicMinU32, InstType.AtomicBinary, "atomic_fetch_min_explicit"); + Add(Instruction.AtomicOr, InstType.AtomicBinary, "atomic_fetch_or_explicit"); Add(Instruction.AtomicSwap, InstType.AtomicBinary, "atomic_exchange_explicit"); - Add(Instruction.AtomicXor, InstType.AtomicBinary, "atomic_xor_explicit"); + Add(Instruction.AtomicXor, InstType.AtomicBinary, "atomic_fetch_xor_explicit"); Add(Instruction.Absolute, InstType.CallUnary, "abs"); Add(Instruction.Add, InstType.OpBinaryCom, "+", 2); - Add(Instruction.Ballot, InstType.CallUnary, "simd_ballot"); + Add(Instruction.Ballot, InstType.Special); Add(Instruction.Barrier, InstType.Special); Add(Instruction.BitCount, InstType.CallUnary, "popcount"); Add(Instruction.BitfieldExtractS32, InstType.CallTernary, "extract_bits"); diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/IoMap.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/IoMap.cs index 1561271d0..f9d0a96d9 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/IoMap.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/IoMap.cs @@ -17,15 +17,16 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions { var returnValue = ioVariable switch { - IoVariable.BaseInstance => ("base_instance", AggregateType.S32), - IoVariable.BaseVertex => ("base_vertex", AggregateType.S32), + IoVariable.BaseInstance => ("base_instance", AggregateType.U32), + IoVariable.BaseVertex => ("base_vertex", AggregateType.U32), IoVariable.CtaId => ("threadgroup_position_in_grid", AggregateType.Vector3 | AggregateType.U32), IoVariable.ClipDistance => ("clip_distance", AggregateType.Array | AggregateType.FP32), IoVariable.FragmentOutputColor => ($"out.color{location}", AggregateType.Vector4 | AggregateType.FP32), IoVariable.FragmentOutputDepth => ("out.depth", AggregateType.FP32), IoVariable.FrontFacing => ("in.front_facing", AggregateType.Bool), IoVariable.GlobalId => ("thread_position_in_grid", AggregateType.Vector3 | AggregateType.U32), - IoVariable.InstanceId => ("instance_id", AggregateType.S32), + IoVariable.InstanceId => ("instance_id", AggregateType.U32), + IoVariable.InstanceIndex => ("instance_index", AggregateType.U32), IoVariable.InvocationId => ("INVOCATION_ID", AggregateType.S32), IoVariable.PointCoord => ("point_coord", AggregateType.Vector2 | AggregateType.FP32), IoVariable.PointSize => ("out.point_size", AggregateType.FP32), diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/MslGenerator.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/MslGenerator.cs index bb5ea5f6f..a3e09d3cb 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/MslGenerator.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/MslGenerator.cs @@ -137,6 +137,8 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl { args = args.Append("uint vertex_id [[vertex_id]]").ToArray(); args = args.Append("uint instance_id [[instance_id]]").ToArray(); + args = args.Append("uint base_instance [[base_instance]]").ToArray(); + args = args.Append("uint base_vertex [[base_vertex]]").ToArray(); } else if (stage == ShaderStage.Compute) {