Как динамически добавить тело в метод, помеченный как «static extern» в C #? - PullRequest
1 голос
/ 11 июня 2019

Резюме

Я пытаюсь написать небольшой API для моего проекта , который обеспечил бы почти полную замену System.DllImportAttribute.Как мне динамически заменить или добавить тело в метод, помеченный static extern?

Справочная информация

Я видел несколько ответов здесь ( здесь и здесь ), которые показывают, как перехватывать и динамически заменять методы, но они не extern, и я не могу заставить ни одного из них работать с методами, которые этого не делают.

Мой текущий API выполняет следующие действия:

  1. Находит все методы, отмеченные NativeCallAttribute, и возвращает MethodInfo[].
  2. Для каждого MethodInfo ввозвращаемое MethodInfo[]:
    1. Загружает указанную библиотеку (если она еще не загружена), используя LoadLibrary или dlopen.
    2. Получите IntPtr, представляющий указатель на функцию для указанногометод из загруженной библиотеки с использованием GetProcAddress или dlsym.
    3. Создает делегат для указателя на встроенную функцию на основе текущего MethodInfo.
    4. Получает MethodInfo изсгенерированный делегат для замены существующего.
    5. Заменяет старый метод на новый.

Код

Моя текущая реализация требуетподход, аналогичный этому проекту с точки зрения получения атрибута и сбора информации о присоединенном методе, и этого ответа stackoverflow с точки зрения замены тела функции собственным делегатом.

Текущий APIчто мне нужно загрузить библиотеку и заставить указатель функции работать как следует (как они работают в других ситуациях, кроме этого), и они были исключены из кода ниже.

Текущий API

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Reflection.Emit;
using System.Runtime.CompilerServices;
using System.Security;
using TCDFx.ComponentModel;

namespace TCDFx.InteropServices
{
    // Indicates that the attributed method is exposed by an native assembly as a static entry point.
    [CLSCompliant(false)]
    [SuppressUnmanagedCodeSecurity]
    [AttributeUsage(AttributeTargets.Method, AllowMultiple = false, Inherited = false)]
    public sealed class NativeCallAttribute : Attribute
    {
        // The name of the native method.
        public string EntryPoint;

        // Initializes a new instance of the NativeCallAttribute.
        public NativeCallAttribute(params string[] assemblyNames)
        {
            if (assemblyNames == null || assemblyNames.Length == 0)
                throw new NativeCallException("No assembly specified.");

            string[] names = new string[] { };
            int i = 0;
            foreach (string name in assemblyNames)
            {
                if (!string.IsNullOrWhiteSpace(name))
                {
                    names[i] = name;
                    i++;
                }
            }

            AssemblyNames = names;
        }

        // An ordered list of assembly names.
        public string[] AssemblyNames { get; }
    }

    [SuppressUnmanagedCodeSecurity]
    public static class NativeCalls
    {
        private static readonly object sync = new object();

        // Replaces all defined functions with the 'NativeCallAttribute' that are 'static' and 'extern' with their native call.
        public static void Load()
        {
            lock (sync)
            {
                MethodInfo[] funcInfo = GetNativeCalls();

                for (int i = 0; i < funcInfo.Length; i++)
                {
                    NativeCallAttribute attribute = funcInfo[i].GetCustomAttribute<NativeCallAttribute>(false);
                    NativeAssemblyBase nativeAssembly;

                    if (IsAssemblyCached(attribute.AssemblyNames, out NativeAssemblyBase cachedAssembly))
                        nativeAssembly = cachedAssembly;
                    else
                    {
                        if (TryLoadAssembly(attribute.AssemblyNames, out NativeAssemblyBase loadedAssembly, out Exception loadingEx))
                            nativeAssembly = loadedAssembly;
                        else
                            throw loadingEx;
                    }

                    string funcName = attribute.EntryPoint ?? funcInfo[i].Name;
                    IntPtr funcPtr = nativeAssembly.LoadFunctionPointer(funcName);

                    Delegate funcDelegate = GenerateNativeDelegate(funcName, nativeAssembly.Name, funcInfo[i], funcPtr);
                    MethodInfo funcInfoNew = funcDelegate.GetMethodInfo();

                    MethodReplacementState state = ReplaceMethod(funcInfo[i], funcInfoNew);
                    replacements.Add(state);
                }
            }
        }

        // Gets all methods marked with a 'NativeCallAttribute'.
        private static MethodInfo[] GetNativeCalls()
        {
            List<MethodInfo> result = new List<MethodInfo>();
            Assembly[] assemblies = AppDomain.CurrentDomain.GetAssemblies();
            for (int i = 0; i < assemblies.Length; i++)
            {
                Type[] types = assemblies[i].GetTypes();
                for (int ii = 0; ii < types.Length; ii++)
                {
                    MethodInfo[] methods = types[ii].GetMethods(BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic);
                    for (int iii = 0; iii < methods.Length; iii++)
                    {
                        Attribute attr = methods[iii].GetCustomAttribute<NativeCallAttribute>(false);
                        if (attr != null)
                            result.Add(methods[iii]);
                    }
                }
            }
            return result.ToArray();
        }

        // Gets a 'Delegate' for a native function pointer with information provided from the method to replace.
        private static Delegate GenerateNativeDelegate(string funcName, string assemblyName, MethodInfo funcInfo, IntPtr funcPtr)
        {
            Type returnType = funcInfo.ReturnType;
            ParameterInfo[] @params = funcInfo.GetParameters();
            Type[] paramTypes = new Type[] { };

            for (int i = 0; i < @params.Length; i++)
                paramTypes[i] = @params[i].ParameterType;

            DynamicMethod nativeMethod = new DynamicMethod($"{assemblyName}_{funcName}", returnType, paramTypes, funcInfo.Module);
            ILGenerator ilGenerator = nativeMethod.GetILGenerator();

            // Generate the arguments
            for (int i = 0; i < @params.Length; i++)
            {
                //TODO: See if I need separate out code for this...
                if (@params[i].ParameterType.IsByRef || @params[i].IsOut)
                {
                    ilGenerator.Emit(OpCodes.Ldarg, i);
                    ilGenerator.Emit(OpCodes.Ldnull);
                    ilGenerator.Emit(OpCodes.Stind_Ref);
                }
                else
                {
                    ilGenerator.Emit(OpCodes.Ldarg, i);
                }
            }

            // Push the funcPtr to the stack
            if (IntPtr.Size == 4)
                ilGenerator.Emit(OpCodes.Ldc_I4, funcPtr.ToInt32());
            else if (IntPtr.Size == 8)
                ilGenerator.Emit(OpCodes.Ldc_I8, funcPtr.ToInt64());
            else throw new PlatformNotSupportedException();

            // Call it and return;
            ilGenerator.EmitCall(OpCodes.Call, funcInfo, null);
            ilGenerator.Emit(OpCodes.Ret);

            Type delegateType = Expression.GetDelegateType((from param in @params select param.ParameterType).Concat(new[] { returnType }).ToArray());
            return nativeMethod.CreateDelegate(delegateType);
        }

        private static bool IsAssemblyCached(string[] assemblyNames, out NativeAssemblyBase cachedAssembly)
        {
            bool result = false;
            cachedAssembly = null;
            foreach (string name in assemblyNames)
            {
                if (!Component.Cache.ContainsKey(Path.GetFileNameWithoutExtension(name)))
                {
                    Type asmType = Component.Cache[Path.GetFileNameWithoutExtension(name)].Value1;
                    if (asmType == typeof(NativeAssembly))
                        cachedAssembly = (NativeAssembly)Component.Cache[Path.GetFileNameWithoutExtension(name)].Value2;
                    else if (asmType == typeof(NativeAssembly))
                        cachedAssembly = (DependencyNativeAssembly)Component.Cache[Path.GetFileNameWithoutExtension(name)].Value2;
                    else if (asmType == typeof(NativeAssembly))
                        cachedAssembly = (EmbeddedNativeAssembly)Component.Cache[Path.GetFileNameWithoutExtension(name)].Value2;
                    result = true;
                    break;
                }
            }
            return result;
        }

        private static bool TryLoadAssembly(string[] assemblyNames, out NativeAssemblyBase loadedAssembly, out Exception exception)
        {
            bool result = false;
            exception = null;
            try
            {
                loadedAssembly = new NativeAssembly(assemblyNames);
            }
            catch (Exception ex)
            {
                exception = ex;
                loadedAssembly = null;
            }
            try
            {
                if (exception != null)
                    loadedAssembly = new DependencyNativeAssembly(assemblyNames);
            }
            catch (Exception ex)
            {
                exception = ex;
                loadedAssembly = null;
            }
            try
            {
                if (exception == null)
                    loadedAssembly = new EmbeddedNativeAssembly(assemblyNames);
            }
            catch (Exception ex)
            {
                exception = ex;
                loadedAssembly = null;
            }
            return result;
        }

        private static unsafe MethodReplacementState ReplaceMethod(MethodInfo targetMethod, MethodInfo replacementMethod)
        {
            if (!(targetMethod.GetMethodBody() == null && targetMethod.IsStatic))
                throw new NativeCallException($"Only the replacement of methods marked 'static extern' is supported.");

#if DEBUG
            RuntimeHelpers.PrepareMethod(targetMethod.MethodHandle);
            RuntimeHelpers.PrepareMethod(replacementMethod.MethodHandle);
#endif
            IntPtr target = targetMethod.MethodHandle.Value;
            IntPtr replacement = replacementMethod.MethodHandle.Value + 8;
            if (!targetMethod.IsVirtual)
                target += 8;
            else
            {
                int i = (int)(((*(long*)target) >> 32) & 0xFF);
                IntPtr classStart = *(IntPtr*)(targetMethod.DeclaringType.TypeHandle.Value + (IntPtr.Size == 4 ? 40 : 64));
                target = classStart + (IntPtr.Size * i);
            }

#if DEBUG
            target = *(IntPtr*)target + 1;
            replacement = *(IntPtr*)replacement + 1;

            MethodReplacementState state = new MethodReplacementState(target, new IntPtr(*(int*)target));
            *(int*)target = *(int*)replacement + (int)(long)replacement - (int)(long)target;
            return state;
#else
            MethodReplacementState state = new MethodReplacementState(target, *(IntPtr*)target);
            * (IntPtr*)target = *(IntPtr*)replacement;
            return state;
#endif
        }

        private readonly struct MethodReplacementState : IDisposable
        {
            private readonly IntPtr Location;
            private readonly IntPtr OriginalValue;

            public MethodReplacementState(IntPtr location, IntPtr origValue)
            {
                Location = location;
                OriginalValue = origValue;
            }
            public void Dispose() => Restore();

            private unsafe void Restore() =>
#if DEBUG
        *(int*)Location = (int)OriginalValue;
#else
        *(IntPtr*)Location = OriginalValue;
#endif
        }
    }
}

Тестовый код

using TCDFx.InteropServices;

namespace NativeCallExample
{
    internal class Program
    {
        internal static void Main()
        {
            NativeCalls.Load();
            Beep(2000, 400);
        }

        [NativeCall("kernel32.dll")]
        private static extern bool Beep(uint frequency, uint duration);
    }
}

Ожидаемые / фактические результаты

Я ожидал, что он будет работать так, как должен (как указано выше в описании), но он падает dotnet.exe скод ошибки -532462766.Никакие точки останова не встречаются нигде в коде (в тестовом приложении или API-библиотека), и исключение не выдается.Я полагаю, что проблема возникает между шагами 2.3 и 2.5 выше, но в данный момент я застрял.Буду признателен за любую помощь!

Дополнительная информация

Если вы хотите увидеть ссылочный код, который не включен, и полную копию того, что у меня есть для этого, вы можете найти его в это ветка моего проекта .

...