diff --git a/Packages/com.vrchat.UdonSharp/Editor/Compiler/Binder/BoundNodes/BoundSwitchStatement.cs b/Packages/com.vrchat.UdonSharp/Editor/Compiler/Binder/BoundNodes/BoundSwitchStatement.cs index ffd1d463..c0f20869 100644 --- a/Packages/com.vrchat.UdonSharp/Editor/Compiler/Binder/BoundNodes/BoundSwitchStatement.cs +++ b/Packages/com.vrchat.UdonSharp/Editor/Compiler/Binder/BoundNodes/BoundSwitchStatement.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Linq; using Microsoft.CodeAnalysis; using UdonSharp.Compiler.Assembly; using UdonSharp.Compiler.Emit; @@ -128,51 +129,40 @@ private void EmitJumpTableSwitchStatement(EmitContext context) JumpLabel exitLabel = context.PushBreakLabel(); JumpLabel defaultJump = context.Module.CreateLabel(); - int maxValue = 0; - foreach (var switchSection in SwitchSections) - { - foreach (var expression in switchSection.Item1) - maxValue = Math.Max(maxValue, Convert.ToInt32(expression.ConstantValue.Value)); - } + Value labelTable = context.CreateGlobalInternalValue(expressionValue.UdonType.MakeArrayType(context)); - Value greaterThanZeroCondition = context.EmitValue(BoundInvocationExpression.CreateBoundInvocation( + Value jumpAddressIndex = context.EmitValue(BoundInvocationExpression.CreateBoundInvocation( context, SyntaxNode, - new ExternSynthesizedOperatorSymbol(BuiltinOperatorType.GreaterThanOrEqual, - expressionValue.UdonType, context), null, - new[] - { - BoundAccessExpression.BindAccess(expressionValue), - BoundAccessExpression.BindAccess(context.GetConstantValue(expressionValue.UdonType, - Convert.ChangeType(0, expressionValue.UdonType.SystemType))) - })); - - context.Module.AddJumpIfFalse(defaultJump, greaterThanZeroCondition); - - Value lessThanMaxCondition = context.EmitValue(BoundInvocationExpression.CreateBoundInvocation( + context.GetTypeSymbol(typeof(Array)).GetMember(nameof(Array.IndexOf), context), null, + new[] { BoundAccessExpression.BindAccess(labelTable), BoundAccessExpression.BindAccess(expressionValue) })); + + Value condition = context.EmitValue(BoundInvocationExpression.CreateBoundInvocation( context, SyntaxNode, - new ExternSynthesizedOperatorSymbol(BuiltinOperatorType.LessThanOrEqual, - expressionValue.UdonType, context), null, + new ExternSynthesizedOperatorSymbol(BuiltinOperatorType.Inequality, + context.GetTypeSymbol(SpecialType.System_Int32), context), null, new[] { - BoundAccessExpression.BindAccess(expressionValue), - BoundAccessExpression.BindAccess(context.GetConstantValue(expressionValue.UdonType, - Convert.ChangeType(maxValue, expressionValue.UdonType.SystemType))) + BoundAccessExpression.BindAccess(jumpAddressIndex), + BoundAccessExpression.BindAccess(context.GetConstantValue(context.GetTypeSymbol(SpecialType.System_Int32), -1)) })); - - context.Module.AddJumpIfFalse(defaultJump, lessThanMaxCondition); - Value convertedValue = context.CastValue(expressionValue, context.GetTypeSymbol(SpecialType.System_Int32), true); + context.Module.AddJumpIfFalse(defaultJump, condition); + Value jumpTable = context.CreateGlobalInternalValue(context.GetTypeSymbol(SpecialType.System_UInt32).MakeArrayType(context)); Value jumpAddress = context.EmitValue(BoundAccessExpression.BindElementAccess(context, SyntaxNode, BoundAccessExpression.BindAccess(jumpTable), - new BoundExpression[] { BoundAccessExpression.BindAccess(convertedValue) })); + new BoundExpression[] { BoundAccessExpression.BindAccess(jumpAddressIndex) })); context.Module.AddJumpIndrect(jumpAddress); - uint[] jumpTableArr = new uint[maxValue + 1]; + int labelCount = SwitchSections.Select(x => x.Item1.Count).Sum(); + Array labelTableArr = Array.CreateInstance(expressionValue.UdonType.SystemType, labelCount); + uint[] jumpTableArr = new uint[labelCount]; using (context.OpenBlockScope()) { + int labelIdx = 0; + for (int i = 0; i < SwitchSections.Count; ++i) { var switchSection = SwitchSections[i]; @@ -184,8 +174,9 @@ private void EmitJumpTableSwitchStatement(EmitContext context) foreach (BoundExpression labelExpression in switchSection.Item1) { - int labelIdx = Convert.ToInt32(labelExpression.ConstantValue.Value); + labelTableArr.SetValue(labelExpression.ConstantValue.Value, labelIdx); jumpTableArr[labelIdx] = currentPos.Address; + labelIdx++; } foreach (BoundStatement statement in switchSection.Item2) @@ -199,42 +190,16 @@ private void EmitJumpTableSwitchStatement(EmitContext context) context.Module.LabelJump(defaultJump); context.Module.LabelJump(exitLabel); - - context.PopBreakLabel(); - for (int i = 0; i < jumpTableArr.Length; ++i) - { - if (jumpTableArr[i] == 0) - jumpTableArr[i] = defaultJump.Address; - } + context.PopBreakLabel(); + labelTable.DefaultValue = labelTableArr; jumpTable.DefaultValue = jumpTableArr; } - private const int JUMP_TABLE_MAX = 256; - private bool IsJumpTableCandidate() { - if (!UdonSharpUtils.IsIntegerType(SwitchExpression.ValueType.UdonType.SystemType)) - return false; - - int labelCount = 0; - - foreach (var switchSection in SwitchSections) - { - foreach (var expression in switchSection.Item1) - { - labelCount++; - - if (expression.ConstantValue.Value is ulong ulongVal && (ulongVal > JUMP_TABLE_MAX)) - return false; - - long intVal = Convert.ToInt64(expression.ConstantValue.Value); - - if (intVal > JUMP_TABLE_MAX || intVal < 0) - return false; - } - } + int labelCount = SwitchSections.Select(x => x.Item1.Count).Sum(); if (labelCount < 4) return false;