Instruction.Barrier
Whoops Fix inline functions in compute stage Fix regression Declare SharedMemories + Only Declare Memories on Main Func Lowecase struct Avoid magic strings Make function signatures readable Change how unsized arrays are indexed Use string builder Fix shuffle instructions Cleanup NumberFormater Bunch of Subgroup I/O Vars Will probably need further refinement Fix point_coord type Fix support buffer declaration Fix point_coord
This commit is contained in:
parent
03161d8048
commit
97a36298fa
10 changed files with 110 additions and 77 deletions
|
@ -9,7 +9,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
|
|||
public const string Tab = " ";
|
||||
|
||||
// The number of additional arguments that every function (except for the main one) must have (for instance support_buffer)
|
||||
public const int AdditionalArgCount = 2;
|
||||
public const int AdditionalArgCount = 1;
|
||||
|
||||
public StructuredFunction CurrentFunction { get; set; }
|
||||
|
||||
|
|
|
@ -64,9 +64,14 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
|
|||
return ioDefinition.StorageKind == storageKind && ioDefinition.IoVariable == IoVariable.UserDefined;
|
||||
}
|
||||
|
||||
public static void DeclareLocals(CodeGenContext context, StructuredFunction function, ShaderStage stage)
|
||||
public static void DeclareLocals(CodeGenContext context, StructuredFunction function, ShaderStage stage, bool isMainFunc = false)
|
||||
{
|
||||
DeclareMemories(context, context.Properties.LocalMemories.Values, isShared: false);
|
||||
if (isMainFunc)
|
||||
{
|
||||
DeclareMemories(context, context.Properties.LocalMemories.Values, isShared: false);
|
||||
DeclareMemories(context, context.Properties.SharedMemories.Values, isShared: true);
|
||||
}
|
||||
|
||||
switch (stage)
|
||||
{
|
||||
case ShaderStage.Vertex:
|
||||
|
@ -112,6 +117,8 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
|
|||
|
||||
private static void DeclareMemories(CodeGenContext context, IEnumerable<MemoryDefinition> memories, bool isShared)
|
||||
{
|
||||
string prefix = isShared ? "threadgroup " : string.Empty;
|
||||
|
||||
foreach (var memory in memories)
|
||||
{
|
||||
string arraySize = "";
|
||||
|
@ -120,7 +127,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
|
|||
arraySize = $"[{memory.ArrayLength}]";
|
||||
}
|
||||
var typeName = GetVarTypeName(context, memory.Type & ~AggregateType.Array);
|
||||
context.AppendLine($"{typeName} {memory.Name}{arraySize};");
|
||||
context.AppendLine($"{prefix}{typeName} {memory.Name}{arraySize};");
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -128,23 +135,28 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
|
|||
{
|
||||
foreach (BufferDefinition buffer in buffers)
|
||||
{
|
||||
context.AppendLine($"struct Struct_{buffer.Name}");
|
||||
context.AppendLine($"struct {DefaultNames.StructPrefix}_{buffer.Name}");
|
||||
context.EnterScope();
|
||||
|
||||
foreach (StructureField field in buffer.Type.Fields)
|
||||
{
|
||||
if (field.Type.HasFlag(AggregateType.Array) && field.ArrayLength > 0)
|
||||
{
|
||||
string typeName = GetVarTypeName(context, field.Type & ~AggregateType.Array);
|
||||
string typeName = GetVarTypeName(context, field.Type & ~AggregateType.Array);
|
||||
string arraySuffix = "";
|
||||
|
||||
context.AppendLine($"{typeName} {field.Name}[{field.ArrayLength}];");
|
||||
}
|
||||
else
|
||||
if (field.Type.HasFlag(AggregateType.Array))
|
||||
{
|
||||
string typeName = GetVarTypeName(context, field.Type & ~AggregateType.Array);
|
||||
|
||||
context.AppendLine($"{typeName} {field.Name};");
|
||||
if (field.ArrayLength > 0)
|
||||
{
|
||||
arraySuffix = $"[{field.ArrayLength}]";
|
||||
}
|
||||
else
|
||||
{
|
||||
// Probably UB, but this is the approach that MVK takes
|
||||
arraySuffix = "[1]";
|
||||
}
|
||||
}
|
||||
|
||||
context.AppendLine($"{typeName} {field.Name}{arraySuffix};");
|
||||
}
|
||||
|
||||
context.LeaveScope(";");
|
||||
|
@ -191,6 +203,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
|
|||
IoVariable.GlobalId => "uint3",
|
||||
IoVariable.VertexId => "uint",
|
||||
IoVariable.VertexIndex => "uint",
|
||||
IoVariable.PointCoord => "float2",
|
||||
_ => GetVarTypeName(context, context.Definitions.GetUserDefinedType(ioDefinition.Location, isOutput: false))
|
||||
};
|
||||
string name = ioDefinition.IoVariable switch
|
||||
|
@ -199,6 +212,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
|
|||
IoVariable.GlobalId => "global_id",
|
||||
IoVariable.VertexId => "vertex_id",
|
||||
IoVariable.VertexIndex => "vertex_index",
|
||||
IoVariable.PointCoord => "point_coord",
|
||||
_ => $"{DefaultNames.IAttributePrefix}{ioDefinition.Location}"
|
||||
};
|
||||
string suffix = ioDefinition.IoVariable switch
|
||||
|
@ -208,6 +222,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
|
|||
IoVariable.VertexId => "[[vertex_id]]",
|
||||
// TODO: Avoid potential redeclaration
|
||||
IoVariable.VertexIndex => "[[vertex_id]]",
|
||||
IoVariable.PointCoord => "[[point_coord]]",
|
||||
IoVariable.UserDefined => context.Definitions.Stage == ShaderStage.Fragment ? $"[[user(loc{ioDefinition.Location})]]" : $"[[attribute({ioDefinition.Location})]]",
|
||||
_ => ""
|
||||
};
|
||||
|
|
|
@ -8,6 +8,8 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
|
|||
public const string IAttributePrefix = "inAttr";
|
||||
public const string OAttributePrefix = "outAttr";
|
||||
|
||||
public const string StructPrefix = "struct";
|
||||
|
||||
public const string ArgumentNamePrefix = "a";
|
||||
|
||||
public const string UndefinedName = "0";
|
||||
|
|
|
@ -2,7 +2,7 @@ using Ryujinx.Graphics.Shader.IntermediateRepresentation;
|
|||
using Ryujinx.Graphics.Shader.StructuredIr;
|
||||
using Ryujinx.Graphics.Shader.Translation;
|
||||
using System;
|
||||
|
||||
using System.Text;
|
||||
using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenCall;
|
||||
using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenHelper;
|
||||
using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenMemory;
|
||||
|
@ -39,11 +39,20 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
|
|||
|
||||
int arity = (int)(info.Type & InstType.ArityMask);
|
||||
|
||||
string args = string.Empty;
|
||||
StringBuilder builder = new();
|
||||
|
||||
if (atomic)
|
||||
if (atomic && (operation.StorageKind == StorageKind.StorageBuffer || operation.StorageKind == StorageKind.SharedMemory))
|
||||
{
|
||||
// Hell
|
||||
builder.Append(GenerateLoadOrStore(context, operation, isStore: false));
|
||||
|
||||
AggregateType dstType = operation.Inst == Instruction.AtomicMaxS32 || operation.Inst == Instruction.AtomicMinS32
|
||||
? AggregateType.S32
|
||||
: AggregateType.U32;
|
||||
|
||||
for (int argIndex = operation.SourcesCount - arity + 2; argIndex < operation.SourcesCount; argIndex++)
|
||||
{
|
||||
builder.Append($", {GetSourceExpr(context, operation.GetSource(argIndex), dstType)}");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
|
@ -51,16 +60,16 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
|
|||
{
|
||||
if (argIndex != 0)
|
||||
{
|
||||
args += ", ";
|
||||
builder.Append(", ");
|
||||
}
|
||||
|
||||
AggregateType dstType = GetSrcVarType(inst, argIndex);
|
||||
|
||||
args += GetSourceExpr(context, operation.GetSource(argIndex), dstType);
|
||||
builder.Append(GetSourceExpr(context, operation.GetSource(argIndex), dstType));
|
||||
}
|
||||
}
|
||||
|
||||
return info.OpName + '(' + args + ')';
|
||||
return $"{info.OpName}({builder})";
|
||||
}
|
||||
else if ((info.Type & InstType.Op) != 0)
|
||||
{
|
||||
|
@ -110,7 +119,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
|
|||
switch (inst & Instruction.Mask)
|
||||
{
|
||||
case Instruction.Barrier:
|
||||
return "|| BARRIER ||";
|
||||
return "threadgroup_barrier(mem_flags::mem_threadgroup)";
|
||||
case Instruction.Call:
|
||||
return Call(context, operation);
|
||||
case Instruction.FSIBegin:
|
||||
|
|
|
@ -13,13 +13,22 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
|
|||
var functon = context.GetFunction(funcId.Value);
|
||||
|
||||
int argCount = operation.SourcesCount - 1;
|
||||
string[] args = new string[argCount + CodeGenContext.AdditionalArgCount];
|
||||
int additionalArgCount = CodeGenContext.AdditionalArgCount + (context.Definitions.Stage != ShaderStage.Compute ? 1 : 0);
|
||||
|
||||
string[] args = new string[argCount + additionalArgCount];
|
||||
|
||||
// Additional arguments
|
||||
args[0] = "in";
|
||||
args[1] = "support_buffer";
|
||||
if (context.Definitions.Stage != ShaderStage.Compute)
|
||||
{
|
||||
args[0] = "in";
|
||||
args[1] = "support_buffer";
|
||||
}
|
||||
else
|
||||
{
|
||||
args[0] = "support_buffer";
|
||||
}
|
||||
|
||||
int argIndex = CodeGenContext.AdditionalArgCount;
|
||||
int argIndex = additionalArgCount;
|
||||
for (int i = 0; i < argCount; i++)
|
||||
{
|
||||
args[argIndex++] = GetSourceExpr(context, operation.GetSource(i + 1), functon.GetArgumentType(i));
|
||||
|
|
|
@ -109,10 +109,10 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
|
|||
Add(Instruction.ShiftLeft, InstType.OpBinary, "<<", 3);
|
||||
Add(Instruction.ShiftRightS32, InstType.OpBinary, ">>", 3);
|
||||
Add(Instruction.ShiftRightU32, InstType.OpBinary, ">>", 3);
|
||||
Add(Instruction.Shuffle, InstType.CallQuaternary, "simd_shuffle");
|
||||
Add(Instruction.ShuffleDown, InstType.CallQuaternary, "simd_shuffle_down");
|
||||
Add(Instruction.ShuffleUp, InstType.CallQuaternary, "simd_shuffle_up");
|
||||
Add(Instruction.ShuffleXor, InstType.CallQuaternary, "simd_shuffle_xor");
|
||||
Add(Instruction.Shuffle, InstType.CallBinary, "simd_shuffle");
|
||||
Add(Instruction.ShuffleDown, InstType.CallBinary, "simd_shuffle_down");
|
||||
Add(Instruction.ShuffleUp, InstType.CallBinary, "simd_shuffle_up");
|
||||
Add(Instruction.ShuffleXor, InstType.CallBinary, "simd_shuffle_xor");
|
||||
Add(Instruction.Sine, InstType.CallUnary, "sin");
|
||||
Add(Instruction.SquareRoot, InstType.CallUnary, "sqrt");
|
||||
Add(Instruction.Store, InstType.Special);
|
||||
|
|
|
@ -47,15 +47,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
|
|||
|
||||
StructureField field = buffer.Type.Fields[fieldIndex.Value];
|
||||
varName = buffer.Name;
|
||||
if ((field.Type & AggregateType.Array) != 0 && field.ArrayLength == 0)
|
||||
{
|
||||
// Unsized array, the buffer is indexed instead of the field
|
||||
fieldName = "." + field.Name;
|
||||
}
|
||||
else
|
||||
{
|
||||
varName += "->" + field.Name;
|
||||
}
|
||||
varName += "->" + field.Name;
|
||||
varType = field.Type;
|
||||
break;
|
||||
|
||||
|
|
|
@ -27,13 +27,19 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
|
|||
IoVariable.GlobalId => ("thread_position_in_grid", AggregateType.Vector3 | AggregateType.U32),
|
||||
IoVariable.InstanceId => ("instance_id", AggregateType.S32),
|
||||
IoVariable.InvocationId => ("INVOCATION_ID", AggregateType.S32),
|
||||
IoVariable.PointCoord => ("point_coord", AggregateType.Vector2),
|
||||
IoVariable.PointCoord => ("point_coord", AggregateType.Vector2 | AggregateType.FP32),
|
||||
IoVariable.PointSize => ("out.point_size", AggregateType.FP32),
|
||||
IoVariable.Position => ("out.position", AggregateType.Vector4 | AggregateType.FP32),
|
||||
IoVariable.PrimitiveId => ("primitive_id", AggregateType.S32),
|
||||
IoVariable.SubgroupEqMask => ("thread_index_in_simdgroup >= 32 ? uint4(0, (1 << (thread_index_in_simdgroup - 32)), uint2(0)) : uint4(1 << thread_index_in_simdgroup, uint3(0))", AggregateType.Vector4 | AggregateType.U32),
|
||||
IoVariable.SubgroupGeMask => ("uint4(insert_bits(0u, 0xFFFFFFFF, thread_index_in_simdgroup, 32 - thread_index_in_simdgroup), uint3(0)) & (uint4((uint)((simd_vote::vote_t)simd_ballot(true) & 0xFFFFFFFF), (uint)(((simd_vote::vote_t)simd_ballot(true) >> 32) & 0xFFFFFFFF), 0, 0))", AggregateType.Vector4 | AggregateType.U32),
|
||||
IoVariable.SubgroupGtMask => ("uint4(insert_bits(0u, 0xFFFFFFFF, thread_index_in_simdgroup + 1, 32 - thread_index_in_simdgroup - 1), uint3(0)) & (uint4((uint)((simd_vote::vote_t)simd_ballot(true) & 0xFFFFFFFF), (uint)(((simd_vote::vote_t)simd_ballot(true) >> 32) & 0xFFFFFFFF), 0, 0))", AggregateType.Vector4 | AggregateType.U32),
|
||||
IoVariable.SubgroupLaneId => ("thread_index_in_simdgroup", AggregateType.U32),
|
||||
IoVariable.SubgroupLeMask => ("uint4(extract_bits(0xFFFFFFFF, 0, min(thread_index_in_simdgroup + 1, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)thread_index_in_simdgroup + 1 - 32, 0)), uint2(0))", AggregateType.Vector4 | AggregateType.U32),
|
||||
IoVariable.SubgroupLtMask => ("uint4(extract_bits(0xFFFFFFFF, 0, min(thread_index_in_simdgroup, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)thread_index_in_simdgroup - 32, 0)), uint2(0))", AggregateType.Vector4 | AggregateType.U32),
|
||||
IoVariable.ThreadKill => ("simd_is_helper_thread()", AggregateType.Bool),
|
||||
IoVariable.UserDefined => GetUserDefinedVariableName(definitions, location, component, isOutput, isPerPatch),
|
||||
IoVariable.ThreadId => ("thread_position_in_threadgroup", AggregateType.Vector3 | AggregateType.U32),
|
||||
IoVariable.SubgroupLaneId => ("thread_index_in_simdgroup", AggregateType.U32),
|
||||
IoVariable.VertexId => ("vertex_id", AggregateType.S32),
|
||||
// gl_VertexIndex does not have a direct equivalent in MSL
|
||||
IoVariable.VertexIndex => ("vertex_id", AggregateType.U32),
|
||||
|
|
|
@ -44,7 +44,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
|
|||
context.AppendLine(GetFunctionSignature(context, function, stage, isMainFunc));
|
||||
context.EnterScope();
|
||||
|
||||
Declarations.DeclareLocals(context, function, stage);
|
||||
Declarations.DeclareLocals(context, function, stage, isMainFunc);
|
||||
|
||||
PrintBlock(context, function.MainBlock, isMainFunc);
|
||||
|
||||
|
@ -63,15 +63,22 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
|
|||
ShaderStage stage,
|
||||
bool isMainFunc = false)
|
||||
{
|
||||
int additionalArgCount = isMainFunc ? 0 : CodeGenContext.AdditionalArgCount;
|
||||
int additionalArgCount = isMainFunc ? 0 : CodeGenContext.AdditionalArgCount + (context.Definitions.Stage != ShaderStage.Compute ? 1 : 0);
|
||||
|
||||
string[] args = new string[additionalArgCount + function.InArguments.Length + function.OutArguments.Length];
|
||||
|
||||
// All non-main functions need to be able to access the support_buffer as well
|
||||
if (!isMainFunc)
|
||||
{
|
||||
args[0] = "FragmentIn in";
|
||||
args[1] = "constant Struct_support_buffer* support_buffer";
|
||||
if (stage != ShaderStage.Compute)
|
||||
{
|
||||
args[0] = stage == ShaderStage.Vertex ? "VertexIn in" : "FragmentIn in";
|
||||
args[1] = $"constant {DefaultNames.StructPrefix}_support_buffer* support_buffer";
|
||||
}
|
||||
else
|
||||
{
|
||||
args[0] = $"constant {DefaultNames.StructPrefix}_support_buffer* support_buffer";
|
||||
}
|
||||
}
|
||||
|
||||
int argIndex = additionalArgCount;
|
||||
|
@ -141,13 +148,13 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
|
|||
|
||||
foreach (var constantBuffer in context.Properties.ConstantBuffers.Values)
|
||||
{
|
||||
args = args.Append($"constant Struct_{constantBuffer.Name}* {constantBuffer.Name} [[buffer({constantBuffer.Binding})]]").ToArray();
|
||||
args = args.Append($"constant {DefaultNames.StructPrefix}_{constantBuffer.Name}* {constantBuffer.Name} [[buffer({constantBuffer.Binding})]]").ToArray();
|
||||
}
|
||||
|
||||
foreach (var storageBuffers in context.Properties.StorageBuffers.Values)
|
||||
{
|
||||
// Offset the binding by 15 to avoid clashing with the constant buffers
|
||||
args = args.Append($"device Struct_{storageBuffers.Name}* {storageBuffers.Name} [[buffer({storageBuffers.Binding + 15})]]").ToArray();
|
||||
args = args.Append($"device {DefaultNames.StructPrefix}_{storageBuffers.Name}* {storageBuffers.Name} [[buffer({storageBuffers.Binding + 15})]]").ToArray();
|
||||
}
|
||||
|
||||
foreach (var texture in context.Properties.Textures.Values)
|
||||
|
@ -162,7 +169,10 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
|
|||
}
|
||||
}
|
||||
|
||||
return $"{funcKeyword} {returnType} {funcName ?? function.Name}({string.Join(", ", args)})";
|
||||
var funcPrefix = $"{funcKeyword} {returnType} {funcName ?? function.Name}(";
|
||||
var indent = new string(' ', funcPrefix.Length);
|
||||
|
||||
return $"{funcPrefix}{string.Join($", \n{indent}", args)})";
|
||||
}
|
||||
|
||||
private static void PrintBlock(CodeGenContext context, AstBlock block, bool isMainFunction)
|
||||
|
|
|
@ -10,25 +10,21 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
|
|||
|
||||
public static bool TryFormat(int value, AggregateType dstType, out string formatted)
|
||||
{
|
||||
if (dstType == AggregateType.FP32)
|
||||
switch (dstType)
|
||||
{
|
||||
return TryFormatFloat(BitConverter.Int32BitsToSingle(value), out formatted);
|
||||
}
|
||||
else if (dstType == AggregateType.S32)
|
||||
{
|
||||
formatted = FormatInt(value);
|
||||
}
|
||||
else if (dstType == AggregateType.U32)
|
||||
{
|
||||
formatted = FormatUint((uint)value);
|
||||
}
|
||||
else if (dstType == AggregateType.Bool)
|
||||
{
|
||||
formatted = value != 0 ? "true" : "false";
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new ArgumentException($"Invalid variable type \"{dstType}\".");
|
||||
case AggregateType.FP32:
|
||||
return TryFormatFloat(BitConverter.Int32BitsToSingle(value), out formatted);
|
||||
case AggregateType.S32:
|
||||
formatted = FormatInt(value);
|
||||
break;
|
||||
case AggregateType.U32:
|
||||
formatted = FormatUint((uint)value);
|
||||
break;
|
||||
case AggregateType.Bool:
|
||||
formatted = value != 0 ? "true" : "false";
|
||||
break;
|
||||
default:
|
||||
throw new ArgumentException($"Invalid variable type \"{dstType}\".");
|
||||
}
|
||||
|
||||
return true;
|
||||
|
@ -65,18 +61,12 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
|
|||
|
||||
public static string FormatInt(int value, AggregateType dstType)
|
||||
{
|
||||
if (dstType == AggregateType.S32)
|
||||
return dstType switch
|
||||
{
|
||||
return FormatInt(value);
|
||||
}
|
||||
else if (dstType == AggregateType.U32)
|
||||
{
|
||||
return FormatUint((uint)value);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new ArgumentException($"Invalid variable type \"{dstType}\".");
|
||||
}
|
||||
AggregateType.S32 => FormatInt(value),
|
||||
AggregateType.U32 => FormatUint((uint)value),
|
||||
_ => throw new ArgumentException($"Invalid variable type \"{dstType}\".")
|
||||
};
|
||||
}
|
||||
|
||||
public static string FormatInt(int value)
|
||||
|
|
Loading…
Reference in a new issue