Упрощение модификации строк из динамического списка? - PullRequest
0 голосов
/ 26 апреля 2018

Я хочу изменить список строк из одного формата в другой.

Пример полного списка можно найти здесь: https://gist.github.com/ProGamerGov/1d728e7ca4cc52abf398277642e4ee78

Некоторые примеры того, что я пытаюсьделать:

Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

Кому:

(2): Conv2d(3 -> 64, 3x3, 1,1, 1,1)

И это:

Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

К этому:

(8): Conv2d(64 -> 128, 3x3, 1,1, 1,1)

Мой текущий кодНастройка выглядит следующим образом:

def modify_text(text, new): 
    return text.replace(", ","*", 1).replace(", ", new, 1)

for i, layer in enumerate(net): 
        if "Conv2d" in str(layer):
           layer = str(layer).replace(","," ->", 1)
           layer = modify_text(layer, "x").replace("kernel_size=(", "").replace("stride=(", "").replace("padding=(", "").replace(")","", 3)
           layer = modify_text(modify_text(layer, ","), ",").replace("*",", ")
           print("  (" + str(i+1) + "): " + layer)

Но я чувствую, что я мог бы найти лучший / более простой способ сделать это?

Редактировать, я упростил свою настройку до этого:

regx_map = r'(2d).*?(\d+).*?(\d+).*?(\d+).*?(\d+).*?(\d+).*?(\d+).*?(\d+).*?(\d+).*'
regx_pool = r'(2d).*?(\d+).*?(\d+).*?(\d+).*'
for i, layer in enumerate(net): 
     if "Conv2d" in str(layer):
          print("  (" + str(i+1) + "): " + re.sub(regx_map, r'\1(\2 -> \3, \4x\5, \6,\7, \8,\9)', str(layer)))
     elif "MaxPool2d" in str(layer) or "AvgPool2d" in str(layer):
          print("  (" + str(i+1) + "): " + re.sub(regx_pool, r'\1(\2x\2, \3,\3)', str(layer)))
     else:
          print("  (" + str(i+1) + "): " + "nn." + str(layer).split("(", 1)[0]) 

Ответы [ 2 ]

0 голосов
/ 26 апреля 2018

Вы можете использовать регулярное выражение для анализа сигнатур функций в строках, преобразования параметров в требуемый синтаксис, а затем и затем используйте re.sub для добавления новых сигнатур:

import re
import itertools
def signature(header):
  s = re.findall('(?<=\()[\w\W]+(?=\)$)', header)
  return re.split(',\s(?=\w\w)', s[0]) if s else ''

def combine_args(d):
   if '=' in d:
      return '{}x{}'.format(*re.findall('\d+', d)) if 'kernel_size' in d and len(re.findall('\d+', d)) == 2 else '{},{}'.format(*re.findall('\d+', d)) if len(re.findall('\d+', d)) == 2 else d
   return d

def combine_header(d):
  vals = [[a, list(b)] for a, b in itertools.groupby(d, key=lambda x:x.isdigit())]
  return list(itertools.chain(*[[' -> '.join(b)] if a else [combine_args(i) for i in b] for a, b in vals]))

lines = ['TVLoss()', 'Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))', 'ReLU(inplace)', 'StyleLoss((gram): GramMatrix()(crit): MSELoss())', 'Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))', 'ReLU(inplace)', 'MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)', 'Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))', 'ReLU(inplace)', 'StyleLoss((gram): GramMatrix()(crit): MSELoss())', 'Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))', 'ReLU(inplace)', 'MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)', 'Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))', 'ReLU(inplace)', 'StyleLoss((gram): GramMatrix()(crit): MSELoss())', 'Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))', 'ReLU(inplace)', 'Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))', 'ReLU(inplace)', 'Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))', 'ReLU(inplace)', 'MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)', 'Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))', 'ReLU(inplace)', 'StyleLoss((gram): GramMatrix()(crit): MSELoss())', 'Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))', 'ReLU(inplace)', 'ContentLoss((crit): MSELoss())', 'Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))', 'ReLU(inplace)', 'Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))', 'ReLU(inplace)', 'MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)', 'Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))', 'ReLU(inplace)', 'StyleLoss((gram): GramMatrix()(crit): MSELoss())']
final_lines = [re.sub('(?<=\()[\w\W]+(?=\)$)', ', '.join(combine_header(c)), a) for a, c in zip(lines, map(signature, lines))]

Вывод:

['TVLoss()', 'Conv2d(3 -> 64, 3x3, 1,1, 1,1)', 'ReLU(inplace)', 'StyleLoss((gram): GramMatrix()(crit): MSELoss())', 'Conv2d(64 -> 64, 3x3, 1,1, 1,1)', 'ReLU(inplace)', 'MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)', 'Conv2d(64 -> 128, 3x3, 1,1, 1,1)', 'ReLU(inplace)', 'StyleLoss((gram): GramMatrix()(crit): MSELoss())', 'Conv2d(128 -> 128, 3x3, 1,1, 1,1)', 'ReLU(inplace)', 'MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)', 'Conv2d(128 -> 256, 3x3, 1,1, 1,1)', 'ReLU(inplace)', 'StyleLoss((gram): GramMatrix()(crit): MSELoss())', 'Conv2d(256 -> 256, 3x3, 1,1, 1,1)', 'ReLU(inplace)', 'Conv2d(256 -> 256, 3x3, 1,1, 1,1)', 'ReLU(inplace)', 'Conv2d(256 -> 256, 3x3, 1,1, 1,1)', 'ReLU(inplace)', 'MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)', 'Conv2d(256 -> 512, 3x3, 1,1, 1,1)', 'ReLU(inplace)', 'StyleLoss((gram): GramMatrix()(crit): MSELoss())', 'Conv2d(512 -> 512, 3x3, 1,1, 1,1)', 'ReLU(inplace)', 'ContentLoss((crit): MSELoss())', 'Conv2d(512 -> 512, 3x3, 1,1, 1,1)', 'ReLU(inplace)', 'Conv2d(512 -> 512, 3x3, 1,1, 1,1)', 'ReLU(inplace)', 'MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)', 'Conv2d(512 -> 512, 3x3, 1,1, 1,1)', 'ReLU(inplace)', 'StyleLoss((gram): GramMatrix()(crit): MSELoss())']
0 голосов
/ 26 апреля 2018

Вы можете использовать регулярное выражение для сопоставления:

import re
s = 'Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))'
regx = r'(Conv2d).*?(\d+).*?(\d+).*?(\d+).*?(\d+).*?(\d+).*?(\d+).*?(\d+).*?(\d+).*'

print(re.sub(regx, r'\1(\2 -> \3, \4x\5, \6,\7,\8,\9)', s))

Вывод:

Conv2d(3 -> 64, 3x3, 1,1,1,1)

Если вы хотите сделать регулярное выражение немного более устойчивым, за счет того, что оно довольно длинное:

(Conv2d)\((\d+),\s(\d+),\skernel_size=\((\d+),\s(\d+)\),\sstride=\((\d+),\s(\d+)\),\spadding=\((\d+),\s(\d+)\)\)

Попробуйте здесь

...