UnwrapVisitor.cs 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. namespace ExpressionKit.Unwrap
  2. {
  3. using System;
  4. using System.Collections.Generic;
  5. using System.Linq;
  6. using System.Linq.Expressions;
  7. using System.Reflection;
  /// <summary>
  /// Unwraps calls to <see cref="Wax.Expand{TParameter, TResult}(Expression{Func{TParameter, TResult}}, TParameter)"/>
  /// into the definition of the expression they call.
  /// </summary>
  /// <remarks>
  /// Public <see cref="Wax"/> methods marked with <see cref="UnwrappableMethodAttribute"/>
  /// are replaced by the body of the lambda held by their first argument, with each lambda
  /// parameter bound to the corresponding remaining argument. Public <see cref="Wax"/>
  /// methods marked with <see cref="ConstantValueMethodAttribute"/> are evaluated during
  /// the visit and replaced by a constant holding the result.
  /// </remarks>
  internal class UnwrapVisitor : ExpressionVisitor
  {
      private static readonly Type WaxType = typeof(Wax);
      private static readonly Type UnwrappableAttributeType = typeof(UnwrappableMethodAttribute);
      private static readonly Type ConstantValueAttributeType = typeof(ConstantValueMethodAttribute);

      // Lambda parameters to substitute while visiting, mapped to the (already
      // visited) argument expressions bound to them. Empty for the root visitor.
      private readonly Dictionary<ParameterExpression, Expression> replacements;

      /// <summary>
      /// Unwraps calls to unwrappable methods
      /// into the definition of the expression they call.
      /// </summary>
      internal UnwrapVisitor()
          : this(new Dictionary<ParameterExpression, Expression>())
      {
      }

      /// <summary>
      /// Creates a visitor that substitutes <paramref name="replacements"/> for the
      /// parameters it encounters; used to inline one lambda body.
      /// </summary>
      private UnwrapVisitor(Dictionary<ParameterExpression, Expression> replacements)
      {
          this.replacements = replacements;
      }

      /// <summary>
      /// True when <paramref name="method"/> is a public method declared on <see cref="Wax"/>
      /// that carries an attribute whose type is exactly <paramref name="attributeType"/>.
      /// </summary>
      private static bool IsWaxMethodWith(MethodInfo method, Type attributeType)
      {
          return method.DeclaringType == WaxType
              && method.IsPublic
              && method.CustomAttributes.Any(attr => attr.AttributeType == attributeType);
      }

      // A thunk call is evaluated eagerly and replaced by its value.
      private static bool MethodIsThunk(MethodInfo method)
      {
          return IsWaxMethodWith(method, ConstantValueAttributeType);
      }

      // An unwrappable call is replaced by the body of the lambda it receives.
      private static bool MethodIsUnwrappable(MethodInfo method)
      {
          return IsWaxMethodWith(method, UnwrappableAttributeType);
      }

      /// <summary>
      /// Replaces a parameter with its bound argument when one was registered;
      /// otherwise visits it normally.
      /// </summary>
      protected override Expression VisitParameter(ParameterExpression node)
      {
          Expression replacement;
          if (this.replacements.TryGetValue(node, out replacement))
          {
              // The replacement was visited when it was bound, so return it as-is.
              return replacement;
          }

          return base.VisitParameter(node);
      }

      /// <summary>
      /// Rewrites unwrappable and thunk calls; visits every other call normally.
      /// </summary>
      protected override Expression VisitMethodCall(MethodCallExpression node)
      {
          if (MethodIsUnwrappable(node.Method))
          {
              return this.Unwrap(node);
          }

          if (MethodIsThunk(node.Method))
          {
              return this.Unthunk(node);
          }

          return base.VisitMethodCall(node);
      }

      /// <summary>
      /// Inlines the lambda held by the first argument of <paramref name="node"/>,
      /// binding its parameters to the remaining arguments. Falls back to a normal
      /// visit when the lambda cannot be obtained.
      /// </summary>
      /// <exception cref="InvalidOperationException">
      /// The first argument is a member access chain that is not rooted in a
      /// constant or a static member.
      /// </exception>
      private Expression Unwrap(MethodCallExpression node)
      {
          var lambda = ResolveLambda(node.Arguments[0]);
          if (lambda == null)
          {
              return base.VisitMethodCall(node);
          }

          var bindings = new Dictionary<ParameterExpression, Expression>();
          for (var i = 0; i < lambda.Parameters.Count; i++)
          {
              // Argument 0 is the lambda itself, so parameter i corresponds to
              // argument i + 1. Visiting it first unwraps nested calls and applies
              // this visitor's own bindings (recursively unwrapping the whole tree).
              bindings.Add(lambda.Parameters[i], this.Visit(node.Arguments[i + 1]));
          }

          // A fresh visitor substitutes the bindings into the body and unwraps
          // any calls the body itself makes.
          return new UnwrapVisitor(bindings).Visit(lambda.Body);
      }

      /// <summary>
      /// Reads the lambda held by <paramref name="source"/>. Returns null when the
      /// expression is neither a constant nor a field/property access, or when its
      /// value is not a <see cref="LambdaExpression"/>.
      /// </summary>
      private static LambdaExpression ResolveLambda(Expression source)
      {
          if (!(source is MemberExpression) && !(source is ConstantExpression))
          {
              return null;
          }

          return Evaluate(source) as LambdaExpression;
      }

      /// <summary>
      /// Evaluates a chain of field/property accesses rooted in a constant (such as a
      /// compiler-generated closure) or in a static member, without compiling code.
      /// A null <paramref name="expression"/> stands for the missing receiver of a
      /// static member and evaluates to null.
      /// </summary>
      /// <exception cref="InvalidOperationException">
      /// The chain contains a node that is neither a constant nor a member access.
      /// </exception>
      private static object Evaluate(Expression expression)
      {
          if (expression == null)
          {
              return null;
          }

          var constant = expression as ConstantExpression;
          if (constant != null)
          {
              return constant.Value;
          }

          var access = expression as MemberExpression;
          if (access != null)
          {
              // Evaluate the receiver first, then read this member from it, so every
              // link of a chain like closure.locals.field is honored.
              return ReadMember(access.Member, Evaluate(access.Expression));
          }

          throw new InvalidOperationException(
              string.Format(
                  "Can't work with expression {0} of type {1}.",
                  expression,
                  expression.GetType()));
      }

      /// <summary>
      /// Reads field or property <paramref name="member"/> from <paramref name="owner"/>
      /// (null for static members).
      /// </summary>
      private static object ReadMember(MemberInfo member, object owner)
      {
          var field = member as FieldInfo;
          if (field != null)
          {
              return field.GetValue(owner);
          }

          var property = member as PropertyInfo;
          if (property != null)
          {
              return property.GetValue(owner);
          }

          throw new NotSupportedException(
              string.Format("Can't read member {0} of type {1}.", member, member.GetType()));
      }

      /// <summary>
      /// Evaluates the first argument of a thunk call now and replaces the call
      /// with a constant holding the result.
      /// </summary>
      private Expression Unthunk(MethodCallExpression node)
      {
          // Apply bindings and unwrap nested calls first, as Unwrap does for its
          // arguments, so the compiled thunk does not see unbound lambda parameters.
          var argument = this.Visit(node.Arguments[0]);
          var value = Expression.Lambda(argument).Compile().DynamicInvoke();

          // Type the constant as the call it replaces. Expression.Constant(value)
          // would use the runtime type (or object for null), which breaks parents
          // that expect e.g. a Nullable<T> or a base type.
          // NOTE(review): assumes thunk methods return the value of their argument.
          return Expression.Constant(value, node.Type);
      }
  }
  143. }