Instruction.Barrier

Whoops

Fix inline functions in compute stage

Fix regression

Declare SharedMemories + Only Declare Memories on Main Func

Lowecase struct

Avoid magic strings

Make function signatures readable

Change how unsized arrays are indexed

Use string builder

Fix shuffle instructions

Cleanup NumberFormater

Bunch of Subgroup I/O Vars

Will probably need further refinement

Fix point_coord type

Fix support buffer declaration

Fix point_coord
This commit is contained in:
Isaac Marovitz 2024-06-21 10:31:21 +01:00 committed by Isaac Marovitz
parent 03161d8048
commit 97a36298fa
10 changed files with 110 additions and 77 deletions

View file

@ -9,7 +9,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
public const string Tab = " "; public const string Tab = " ";
// The number of additional arguments that every function (except for the main one) must have (for instance support_buffer) // The number of additional arguments that every function (except for the main one) must have (for instance support_buffer)
public const int AdditionalArgCount = 2; public const int AdditionalArgCount = 1;
public StructuredFunction CurrentFunction { get; set; } public StructuredFunction CurrentFunction { get; set; }

View file

@ -64,9 +64,14 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
return ioDefinition.StorageKind == storageKind && ioDefinition.IoVariable == IoVariable.UserDefined; return ioDefinition.StorageKind == storageKind && ioDefinition.IoVariable == IoVariable.UserDefined;
} }
public static void DeclareLocals(CodeGenContext context, StructuredFunction function, ShaderStage stage) public static void DeclareLocals(CodeGenContext context, StructuredFunction function, ShaderStage stage, bool isMainFunc = false)
{ {
DeclareMemories(context, context.Properties.LocalMemories.Values, isShared: false); if (isMainFunc)
{
DeclareMemories(context, context.Properties.LocalMemories.Values, isShared: false);
DeclareMemories(context, context.Properties.SharedMemories.Values, isShared: true);
}
switch (stage) switch (stage)
{ {
case ShaderStage.Vertex: case ShaderStage.Vertex:
@ -112,6 +117,8 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
private static void DeclareMemories(CodeGenContext context, IEnumerable<MemoryDefinition> memories, bool isShared) private static void DeclareMemories(CodeGenContext context, IEnumerable<MemoryDefinition> memories, bool isShared)
{ {
string prefix = isShared ? "threadgroup " : string.Empty;
foreach (var memory in memories) foreach (var memory in memories)
{ {
string arraySize = ""; string arraySize = "";
@ -120,7 +127,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
arraySize = $"[{memory.ArrayLength}]"; arraySize = $"[{memory.ArrayLength}]";
} }
var typeName = GetVarTypeName(context, memory.Type & ~AggregateType.Array); var typeName = GetVarTypeName(context, memory.Type & ~AggregateType.Array);
context.AppendLine($"{typeName} {memory.Name}{arraySize};"); context.AppendLine($"{prefix}{typeName} {memory.Name}{arraySize};");
} }
} }
@ -128,23 +135,28 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
{ {
foreach (BufferDefinition buffer in buffers) foreach (BufferDefinition buffer in buffers)
{ {
context.AppendLine($"struct Struct_{buffer.Name}"); context.AppendLine($"struct {DefaultNames.StructPrefix}_{buffer.Name}");
context.EnterScope(); context.EnterScope();
foreach (StructureField field in buffer.Type.Fields) foreach (StructureField field in buffer.Type.Fields)
{ {
if (field.Type.HasFlag(AggregateType.Array) && field.ArrayLength > 0) string typeName = GetVarTypeName(context, field.Type & ~AggregateType.Array);
{ string arraySuffix = "";
string typeName = GetVarTypeName(context, field.Type & ~AggregateType.Array);
context.AppendLine($"{typeName} {field.Name}[{field.ArrayLength}];"); if (field.Type.HasFlag(AggregateType.Array))
}
else
{ {
string typeName = GetVarTypeName(context, field.Type & ~AggregateType.Array); if (field.ArrayLength > 0)
{
context.AppendLine($"{typeName} {field.Name};"); arraySuffix = $"[{field.ArrayLength}]";
}
else
{
// Probably UB, but this is the approach that MVK takes
arraySuffix = "[1]";
}
} }
context.AppendLine($"{typeName} {field.Name}{arraySuffix};");
} }
context.LeaveScope(";"); context.LeaveScope(";");
@ -191,6 +203,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
IoVariable.GlobalId => "uint3", IoVariable.GlobalId => "uint3",
IoVariable.VertexId => "uint", IoVariable.VertexId => "uint",
IoVariable.VertexIndex => "uint", IoVariable.VertexIndex => "uint",
IoVariable.PointCoord => "float2",
_ => GetVarTypeName(context, context.Definitions.GetUserDefinedType(ioDefinition.Location, isOutput: false)) _ => GetVarTypeName(context, context.Definitions.GetUserDefinedType(ioDefinition.Location, isOutput: false))
}; };
string name = ioDefinition.IoVariable switch string name = ioDefinition.IoVariable switch
@ -199,6 +212,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
IoVariable.GlobalId => "global_id", IoVariable.GlobalId => "global_id",
IoVariable.VertexId => "vertex_id", IoVariable.VertexId => "vertex_id",
IoVariable.VertexIndex => "vertex_index", IoVariable.VertexIndex => "vertex_index",
IoVariable.PointCoord => "point_coord",
_ => $"{DefaultNames.IAttributePrefix}{ioDefinition.Location}" _ => $"{DefaultNames.IAttributePrefix}{ioDefinition.Location}"
}; };
string suffix = ioDefinition.IoVariable switch string suffix = ioDefinition.IoVariable switch
@ -208,6 +222,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
IoVariable.VertexId => "[[vertex_id]]", IoVariable.VertexId => "[[vertex_id]]",
// TODO: Avoid potential redeclaration // TODO: Avoid potential redeclaration
IoVariable.VertexIndex => "[[vertex_id]]", IoVariable.VertexIndex => "[[vertex_id]]",
IoVariable.PointCoord => "[[point_coord]]",
IoVariable.UserDefined => context.Definitions.Stage == ShaderStage.Fragment ? $"[[user(loc{ioDefinition.Location})]]" : $"[[attribute({ioDefinition.Location})]]", IoVariable.UserDefined => context.Definitions.Stage == ShaderStage.Fragment ? $"[[user(loc{ioDefinition.Location})]]" : $"[[attribute({ioDefinition.Location})]]",
_ => "" _ => ""
}; };

View file

@ -8,6 +8,8 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
public const string IAttributePrefix = "inAttr"; public const string IAttributePrefix = "inAttr";
public const string OAttributePrefix = "outAttr"; public const string OAttributePrefix = "outAttr";
public const string StructPrefix = "struct";
public const string ArgumentNamePrefix = "a"; public const string ArgumentNamePrefix = "a";
public const string UndefinedName = "0"; public const string UndefinedName = "0";

View file

@ -2,7 +2,7 @@ using Ryujinx.Graphics.Shader.IntermediateRepresentation;
using Ryujinx.Graphics.Shader.StructuredIr; using Ryujinx.Graphics.Shader.StructuredIr;
using Ryujinx.Graphics.Shader.Translation; using Ryujinx.Graphics.Shader.Translation;
using System; using System;
using System.Text;
using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenCall; using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenCall;
using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenHelper; using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenHelper;
using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenMemory; using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenMemory;
@ -39,11 +39,20 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
int arity = (int)(info.Type & InstType.ArityMask); int arity = (int)(info.Type & InstType.ArityMask);
string args = string.Empty; StringBuilder builder = new();
if (atomic) if (atomic && (operation.StorageKind == StorageKind.StorageBuffer || operation.StorageKind == StorageKind.SharedMemory))
{ {
// Hell builder.Append(GenerateLoadOrStore(context, operation, isStore: false));
AggregateType dstType = operation.Inst == Instruction.AtomicMaxS32 || operation.Inst == Instruction.AtomicMinS32
? AggregateType.S32
: AggregateType.U32;
for (int argIndex = operation.SourcesCount - arity + 2; argIndex < operation.SourcesCount; argIndex++)
{
builder.Append($", {GetSourceExpr(context, operation.GetSource(argIndex), dstType)}");
}
} }
else else
{ {
@ -51,16 +60,16 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
{ {
if (argIndex != 0) if (argIndex != 0)
{ {
args += ", "; builder.Append(", ");
} }
AggregateType dstType = GetSrcVarType(inst, argIndex); AggregateType dstType = GetSrcVarType(inst, argIndex);
args += GetSourceExpr(context, operation.GetSource(argIndex), dstType); builder.Append(GetSourceExpr(context, operation.GetSource(argIndex), dstType));
} }
} }
return info.OpName + '(' + args + ')'; return $"{info.OpName}({builder})";
} }
else if ((info.Type & InstType.Op) != 0) else if ((info.Type & InstType.Op) != 0)
{ {
@ -110,7 +119,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
switch (inst & Instruction.Mask) switch (inst & Instruction.Mask)
{ {
case Instruction.Barrier: case Instruction.Barrier:
return "|| BARRIER ||"; return "threadgroup_barrier(mem_flags::mem_threadgroup)";
case Instruction.Call: case Instruction.Call:
return Call(context, operation); return Call(context, operation);
case Instruction.FSIBegin: case Instruction.FSIBegin:

View file

@ -13,13 +13,22 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
var functon = context.GetFunction(funcId.Value); var functon = context.GetFunction(funcId.Value);
int argCount = operation.SourcesCount - 1; int argCount = operation.SourcesCount - 1;
string[] args = new string[argCount + CodeGenContext.AdditionalArgCount]; int additionalArgCount = CodeGenContext.AdditionalArgCount + (context.Definitions.Stage != ShaderStage.Compute ? 1 : 0);
string[] args = new string[argCount + additionalArgCount];
// Additional arguments // Additional arguments
args[0] = "in"; if (context.Definitions.Stage != ShaderStage.Compute)
args[1] = "support_buffer"; {
args[0] = "in";
args[1] = "support_buffer";
}
else
{
args[0] = "support_buffer";
}
int argIndex = CodeGenContext.AdditionalArgCount; int argIndex = additionalArgCount;
for (int i = 0; i < argCount; i++) for (int i = 0; i < argCount; i++)
{ {
args[argIndex++] = GetSourceExpr(context, operation.GetSource(i + 1), functon.GetArgumentType(i)); args[argIndex++] = GetSourceExpr(context, operation.GetSource(i + 1), functon.GetArgumentType(i));

View file

@ -109,10 +109,10 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
Add(Instruction.ShiftLeft, InstType.OpBinary, "<<", 3); Add(Instruction.ShiftLeft, InstType.OpBinary, "<<", 3);
Add(Instruction.ShiftRightS32, InstType.OpBinary, ">>", 3); Add(Instruction.ShiftRightS32, InstType.OpBinary, ">>", 3);
Add(Instruction.ShiftRightU32, InstType.OpBinary, ">>", 3); Add(Instruction.ShiftRightU32, InstType.OpBinary, ">>", 3);
Add(Instruction.Shuffle, InstType.CallQuaternary, "simd_shuffle"); Add(Instruction.Shuffle, InstType.CallBinary, "simd_shuffle");
Add(Instruction.ShuffleDown, InstType.CallQuaternary, "simd_shuffle_down"); Add(Instruction.ShuffleDown, InstType.CallBinary, "simd_shuffle_down");
Add(Instruction.ShuffleUp, InstType.CallQuaternary, "simd_shuffle_up"); Add(Instruction.ShuffleUp, InstType.CallBinary, "simd_shuffle_up");
Add(Instruction.ShuffleXor, InstType.CallQuaternary, "simd_shuffle_xor"); Add(Instruction.ShuffleXor, InstType.CallBinary, "simd_shuffle_xor");
Add(Instruction.Sine, InstType.CallUnary, "sin"); Add(Instruction.Sine, InstType.CallUnary, "sin");
Add(Instruction.SquareRoot, InstType.CallUnary, "sqrt"); Add(Instruction.SquareRoot, InstType.CallUnary, "sqrt");
Add(Instruction.Store, InstType.Special); Add(Instruction.Store, InstType.Special);

View file

@ -47,15 +47,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
StructureField field = buffer.Type.Fields[fieldIndex.Value]; StructureField field = buffer.Type.Fields[fieldIndex.Value];
varName = buffer.Name; varName = buffer.Name;
if ((field.Type & AggregateType.Array) != 0 && field.ArrayLength == 0) varName += "->" + field.Name;
{
// Unsized array, the buffer is indexed instead of the field
fieldName = "." + field.Name;
}
else
{
varName += "->" + field.Name;
}
varType = field.Type; varType = field.Type;
break; break;

View file

@ -27,13 +27,19 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
IoVariable.GlobalId => ("thread_position_in_grid", AggregateType.Vector3 | AggregateType.U32), IoVariable.GlobalId => ("thread_position_in_grid", AggregateType.Vector3 | AggregateType.U32),
IoVariable.InstanceId => ("instance_id", AggregateType.S32), IoVariable.InstanceId => ("instance_id", AggregateType.S32),
IoVariable.InvocationId => ("INVOCATION_ID", AggregateType.S32), IoVariable.InvocationId => ("INVOCATION_ID", AggregateType.S32),
IoVariable.PointCoord => ("point_coord", AggregateType.Vector2), IoVariable.PointCoord => ("point_coord", AggregateType.Vector2 | AggregateType.FP32),
IoVariable.PointSize => ("out.point_size", AggregateType.FP32), IoVariable.PointSize => ("out.point_size", AggregateType.FP32),
IoVariable.Position => ("out.position", AggregateType.Vector4 | AggregateType.FP32), IoVariable.Position => ("out.position", AggregateType.Vector4 | AggregateType.FP32),
IoVariable.PrimitiveId => ("primitive_id", AggregateType.S32), IoVariable.PrimitiveId => ("primitive_id", AggregateType.S32),
IoVariable.SubgroupEqMask => ("thread_index_in_simdgroup >= 32 ? uint4(0, (1 << (thread_index_in_simdgroup - 32)), uint2(0)) : uint4(1 << thread_index_in_simdgroup, uint3(0))", AggregateType.Vector4 | AggregateType.U32),
IoVariable.SubgroupGeMask => ("uint4(insert_bits(0u, 0xFFFFFFFF, thread_index_in_simdgroup, 32 - thread_index_in_simdgroup), uint3(0)) & (uint4((uint)((simd_vote::vote_t)simd_ballot(true) & 0xFFFFFFFF), (uint)(((simd_vote::vote_t)simd_ballot(true) >> 32) & 0xFFFFFFFF), 0, 0))", AggregateType.Vector4 | AggregateType.U32),
IoVariable.SubgroupGtMask => ("uint4(insert_bits(0u, 0xFFFFFFFF, thread_index_in_simdgroup + 1, 32 - thread_index_in_simdgroup - 1), uint3(0)) & (uint4((uint)((simd_vote::vote_t)simd_ballot(true) & 0xFFFFFFFF), (uint)(((simd_vote::vote_t)simd_ballot(true) >> 32) & 0xFFFFFFFF), 0, 0))", AggregateType.Vector4 | AggregateType.U32),
IoVariable.SubgroupLaneId => ("thread_index_in_simdgroup", AggregateType.U32),
IoVariable.SubgroupLeMask => ("uint4(extract_bits(0xFFFFFFFF, 0, min(thread_index_in_simdgroup + 1, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)thread_index_in_simdgroup + 1 - 32, 0)), uint2(0))", AggregateType.Vector4 | AggregateType.U32),
IoVariable.SubgroupLtMask => ("uint4(extract_bits(0xFFFFFFFF, 0, min(thread_index_in_simdgroup, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)thread_index_in_simdgroup - 32, 0)), uint2(0))", AggregateType.Vector4 | AggregateType.U32),
IoVariable.ThreadKill => ("simd_is_helper_thread()", AggregateType.Bool),
IoVariable.UserDefined => GetUserDefinedVariableName(definitions, location, component, isOutput, isPerPatch), IoVariable.UserDefined => GetUserDefinedVariableName(definitions, location, component, isOutput, isPerPatch),
IoVariable.ThreadId => ("thread_position_in_threadgroup", AggregateType.Vector3 | AggregateType.U32), IoVariable.ThreadId => ("thread_position_in_threadgroup", AggregateType.Vector3 | AggregateType.U32),
IoVariable.SubgroupLaneId => ("thread_index_in_simdgroup", AggregateType.U32),
IoVariable.VertexId => ("vertex_id", AggregateType.S32), IoVariable.VertexId => ("vertex_id", AggregateType.S32),
// gl_VertexIndex does not have a direct equivalent in MSL // gl_VertexIndex does not have a direct equivalent in MSL
IoVariable.VertexIndex => ("vertex_id", AggregateType.U32), IoVariable.VertexIndex => ("vertex_id", AggregateType.U32),

View file

@ -44,7 +44,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
context.AppendLine(GetFunctionSignature(context, function, stage, isMainFunc)); context.AppendLine(GetFunctionSignature(context, function, stage, isMainFunc));
context.EnterScope(); context.EnterScope();
Declarations.DeclareLocals(context, function, stage); Declarations.DeclareLocals(context, function, stage, isMainFunc);
PrintBlock(context, function.MainBlock, isMainFunc); PrintBlock(context, function.MainBlock, isMainFunc);
@ -63,15 +63,22 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
ShaderStage stage, ShaderStage stage,
bool isMainFunc = false) bool isMainFunc = false)
{ {
int additionalArgCount = isMainFunc ? 0 : CodeGenContext.AdditionalArgCount; int additionalArgCount = isMainFunc ? 0 : CodeGenContext.AdditionalArgCount + (context.Definitions.Stage != ShaderStage.Compute ? 1 : 0);
string[] args = new string[additionalArgCount + function.InArguments.Length + function.OutArguments.Length]; string[] args = new string[additionalArgCount + function.InArguments.Length + function.OutArguments.Length];
// All non-main functions need to be able to access the support_buffer as well // All non-main functions need to be able to access the support_buffer as well
if (!isMainFunc) if (!isMainFunc)
{ {
args[0] = "FragmentIn in"; if (stage != ShaderStage.Compute)
args[1] = "constant Struct_support_buffer* support_buffer"; {
args[0] = stage == ShaderStage.Vertex ? "VertexIn in" : "FragmentIn in";
args[1] = $"constant {DefaultNames.StructPrefix}_support_buffer* support_buffer";
}
else
{
args[0] = $"constant {DefaultNames.StructPrefix}_support_buffer* support_buffer";
}
} }
int argIndex = additionalArgCount; int argIndex = additionalArgCount;
@ -141,13 +148,13 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
foreach (var constantBuffer in context.Properties.ConstantBuffers.Values) foreach (var constantBuffer in context.Properties.ConstantBuffers.Values)
{ {
args = args.Append($"constant Struct_{constantBuffer.Name}* {constantBuffer.Name} [[buffer({constantBuffer.Binding})]]").ToArray(); args = args.Append($"constant {DefaultNames.StructPrefix}_{constantBuffer.Name}* {constantBuffer.Name} [[buffer({constantBuffer.Binding})]]").ToArray();
} }
foreach (var storageBuffers in context.Properties.StorageBuffers.Values) foreach (var storageBuffers in context.Properties.StorageBuffers.Values)
{ {
// Offset the binding by 15 to avoid clashing with the constant buffers // Offset the binding by 15 to avoid clashing with the constant buffers
args = args.Append($"device Struct_{storageBuffers.Name}* {storageBuffers.Name} [[buffer({storageBuffers.Binding + 15})]]").ToArray(); args = args.Append($"device {DefaultNames.StructPrefix}_{storageBuffers.Name}* {storageBuffers.Name} [[buffer({storageBuffers.Binding + 15})]]").ToArray();
} }
foreach (var texture in context.Properties.Textures.Values) foreach (var texture in context.Properties.Textures.Values)
@ -162,7 +169,10 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
} }
} }
return $"{funcKeyword} {returnType} {funcName ?? function.Name}({string.Join(", ", args)})"; var funcPrefix = $"{funcKeyword} {returnType} {funcName ?? function.Name}(";
var indent = new string(' ', funcPrefix.Length);
return $"{funcPrefix}{string.Join($", \n{indent}", args)})";
} }
private static void PrintBlock(CodeGenContext context, AstBlock block, bool isMainFunction) private static void PrintBlock(CodeGenContext context, AstBlock block, bool isMainFunction)

View file

@ -10,25 +10,21 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
public static bool TryFormat(int value, AggregateType dstType, out string formatted) public static bool TryFormat(int value, AggregateType dstType, out string formatted)
{ {
if (dstType == AggregateType.FP32) switch (dstType)
{ {
return TryFormatFloat(BitConverter.Int32BitsToSingle(value), out formatted); case AggregateType.FP32:
} return TryFormatFloat(BitConverter.Int32BitsToSingle(value), out formatted);
else if (dstType == AggregateType.S32) case AggregateType.S32:
{ formatted = FormatInt(value);
formatted = FormatInt(value); break;
} case AggregateType.U32:
else if (dstType == AggregateType.U32) formatted = FormatUint((uint)value);
{ break;
formatted = FormatUint((uint)value); case AggregateType.Bool:
} formatted = value != 0 ? "true" : "false";
else if (dstType == AggregateType.Bool) break;
{ default:
formatted = value != 0 ? "true" : "false"; throw new ArgumentException($"Invalid variable type \"{dstType}\".");
}
else
{
throw new ArgumentException($"Invalid variable type \"{dstType}\".");
} }
return true; return true;
@ -65,18 +61,12 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
public static string FormatInt(int value, AggregateType dstType) public static string FormatInt(int value, AggregateType dstType)
{ {
if (dstType == AggregateType.S32) return dstType switch
{ {
return FormatInt(value); AggregateType.S32 => FormatInt(value),
} AggregateType.U32 => FormatUint((uint)value),
else if (dstType == AggregateType.U32) _ => throw new ArgumentException($"Invalid variable type \"{dstType}\".")
{ };
return FormatUint((uint)value);
}
else
{
throw new ArgumentException($"Invalid variable type \"{dstType}\".");
}
} }
public static string FormatInt(int value) public static string FormatInt(int value)