Диагностика заболеваний по ЭКГ с помощью сверточных нейронных сетей

Характеристика аритмии как типичного типа сердечно-сосудистых заболеваний, который относится к любым изменениям нормальных ритмов сердца. Анализ особенностей применения сверточных нейронных сетей к задаче классификации данных электрокардиографии.

Рубрика Медицина
Вид дипломная работа
Язык русский
Дата добавления 18.07.2020
Размер файла 1,9 M

Отправить свою хорошую работу в базу знаний просто. Используйте форму, расположенную ниже

Студенты, аспиранты, молодые ученые, использующие базу знаний в своей учебе и работе, будут вам очень благодарны.

[15] Lee, J., Oh, K., Kim, B., Yoo, S.K., 2019a. Synthesis of electrocardiogram v lead signals from limb lead measurement using r peak aligned generative adversarial network. IEEE journal of biomedical and health informatics

[16] Karpov N., Lyashuk A., Vizgunov A. Sentiment Analysis Using Deep Learning //International Conference on Network Analysis. - Springer, Cham, 2016. - С. 281-288.

[17] Saadatnejad, S., Oveisi, M., Hashemi, M., 2019. Lstm-based ecg classification for continuous monitoring on personal wearable devices. IEEE journal of biomedical and health informatics.

[18] Li, R., Zhang, X., Dai, H., Zhou, B., Wang, Z., 2019e. Interpretability analysis of heartbeat classification based on heartbeat global sequence features and bilstm-attention neural network. IEEE Access 7, 109870-109883.

[19] Krizhevsky A., Sutskever I., Hinton G. E. Imagenet classification with deep convolutional neural networks //Advances in neural information processing systems. - 2012. - С. 1097-1105.

[20] Zihlmann M., Perekrestenko D., Tschannen M. Convolutional recurrent neural networks for electrocardiogram classification //2017 Computing in Cardiology (CinC). - IEEE, 2017. - С. 1-4.

[21] Jun T. J. et al. ECG arrhythmia classification using a 2-D convolutional neural network //arXiv preprint arXiv:1804.06812. - 2018.

[22] Deng, J., Dong, W., Socher, R., Li, L.J., Li, K., Fei-Fei, L., 2009. Imagenet: A large-scale hierarchical image database, in: 2009 IEEE conference on computer vision and pattern recognition, Ieee. pp. 248-255.

[23] He K. et al. Deep residual learning for image recognition //Proceedings of the IEEE conference on computer vision and pattern recognition. - 2016. - С. 770-778.

[24] Simonyan K., Zisserman A. Very deep convolutional networks for large-scale image recognition //arXiv preprint arXiv:1409.1556. - 2014.

[25] Rajpurkar P. et al. Cardiologist-level arrhythmia detection with convolutional neural networks //arXiv preprint arXiv:1707.01836. - 2017.

[26] Xu X., Liu H. ECG Heartbeat Classification Using Convolutional Neural Networks //IEEE Access. - 2020. - Т. 8. - С. 8614-8619.

Приложения

Приложение 1

1. import argparse

2. import os

3. import os.path as osp

4.  

5. import cv2

6. import matplotlib.pyplot as plt

7. import numpy as np

8. import wfdb

9. from sklearn.preprocessing import scale

10. from wfdb import rdrecord

11.  

12. # Choose from peak to peak or centered

13. # mode = [20, 20]

14. mode = 128

15.  

16. image_size = 128

17. output_dir = '../data'

18.  

19. # dpi fix

20. fig = plt.figure(frameon=False)

21. dpi = fig.dpi

22.  

23. # fig size / image size

24. figsize = (image_size / dpi, image_size / dpi)

25. image_size = (image_size, image_size)

26.  

27.  

28. def plot(signal, filename):

29.     plt.figure(figsize=figsize, frameon=False)

30.     plt.axis('off')

31.     plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)

32.     # plt.margins(0, 0) # use for generation images with no margin

33.     plt.plot(signal)

34.     plt.savefig(filename)

35.  

36.     plt.close()

37.  

38.     im_gray = cv2.imread(filename, cv2.IMREAD_GRAYSCALE)

39.     im_gray = cv2.resize(im_gray, image_size, interpolation=cv2.INTER_LANCZOS4)

40.     cv2.imwrite(filename, im_gray)

41.  

42.  

43. if __name__ == '__main__':

44.  

45.     parser = argparse.ArgumentParser()

46.     parser.add_argument('--file', required=True)

47.     args = parser.parse_args()

48.  

49.  

50.     ecg = args.file

51.     name = osp.basename(ecg)

52.     record = rdrecord(ecg)

53.     ann = wfdb.rdann(ecg, extension='atr')

54.     for sig_name, signal in zip(record.sig_name, record.p_signal.T):

55.         if not np.all(np.isfinite(signal)):

56.             continue

57.         signal = scale(signal)

58.         for i, (label, peak) in enumerate(zip(ann.symbol, ann.sample)):

59.             if label == '/': label = "\\"

60.             print('\r{} [{}/{}]'.format(sig_name, i + 1, len(ann.symbol)), end="")

61.             if isinstance(mode, list):

62.                 if np.all([i > 0, i + 1 < len(ann.sample)]):

63.                     left = ann.sample[i - 1] + mode[0]

64.                     right = ann.sample[i + 1] - mode[1]

65.                 else:

66.                     continue

67.             elif isinstance(mode, int):

68.                 left, right = peak - mode // 2, peak + mode // 2

69.             else:

70.                 raise Exception("Wrong mode in script beginning")

71.  

72.             if np.all([left > 0, right < len(signal)]):

73.                 one_dim_data_dir = osp.join(output_dir, '1D', name, sig_name, label)

74.                 two_dim_data_dir = osp.join(output_dir, '2D', name, sig_name, label)

75.                 os.makedirs(one_dim_data_dir, exist_ok=True)

76.                 os.makedirs(two_dim_data_dir, exist_ok=True)

77.  

78.                 filename = osp.join(one_dim_data_dir, '{}.npy'.format(peak))

79.                 np.save(filename, signal[left:right])

80.                 filename = osp.join(two_dim_data_dir, '{}.png'.format(peak))

81.  

82.                 plot(signal[left:right], filename)

Приложение 2

1. import subprocess

2. import os.path as osp

3. import multiprocessing as mp

4. from glob import glob

5. from tqdm import tqdm

6.  

7. input_dir = '../mit-bih/*.atr'

8. ecg_data = sorted([osp.splitext(i)[0] for i in glob(input_dir)])

9. pbar = tqdm(total=len(ecg_data))

10.  

11.  

12. def run(file):

13.     params = ['python3', 'dataset-generation.py', '--file', file]

14.     subprocess.check_call(params)

15.     pbar.update(1)

16.  

17.  

18. if __name__ == '__main__':

19.     p = mp.Pool(processes=mp.cpu_count())

20.     p.map(run, ecg_data)

Приложение 3

1. import json

2. import os.path as osp

3. from glob import glob

4.  

5. import pandas as pd

6.  

7. # 1. N - Normal

8. # 2. V - PVC (Premature ventricular contraction)

9. # 3. \ - PAB (Paced beat)

10. # 4. R - RBB (Right bundle branch)

11. # 5. L - LBB (Left bundle branch)

12. # 6. A - APB (Atrial premature beat)

13. # 7. ! - AFW (Ventricular flutter wave)

14. # 8. E - VEB (Ventricular escape beat)

15.  

16. classes = ['N', 'V', '\\', 'R', 'L', 'A', '!', 'E']

17. lead = 'MLII'

18. extension = 'png'  # or `npy` for 1D

19. data_path = osp.abspath('../data/*/*/*/*/*.{}'.format(extension))

20. val_size = 0.1  # [0, 1]

21.  

22. output_path = '/'.join(data_path.split('/')[:-5])

23. random_state = 7

24.  

25. if __name__ == '__main__':

26.     dataset = []

27.     files = glob(data_path)

28.  

29.     for file in glob(data_path):

30.         *_, name, lead, label, filename = file.split('/')

31.         dataset.append({

32.             "name": name,

33.             "lead": lead,

34.             "label": label,

35.             "filename": osp.splitext(filename)[0],

36.             "path": file

37.         })

38.  

39.     data = pd.DataFrame(dataset)

40.     data = data[data['lead'] == lead]

41.     data = data[data['label'].isin(classes)]

42.     data = data.sample(frac=1, random_state=random_state)

43.  

44.     val_ids = []

45.     for cl in classes:

46.         val_ids.extend(data[data['label'] == cl].sample(frac=val_size, random_state=random_state).index)

47.  

48.     val = data.loc[val_ids, :]

49.     train = data[~data.index.isin(val.index)]

50.  

51.     train.to_json(osp.join(output_path, 'train.json'), orient='records')

52.     val.to_json(osp.join(output_path, 'val.json'), orient='records')

53.  

54.     d = {}

55.     for label in train.label.unique():

56.         d[label] = len(d)

57.  

58.     with open(osp.join(output_path, 'class-mapper.json'), 'w') as file:

59.         file.write(json.dumps(d, indent=1))

Приложение 4

1. import torch.nn as nn

2. import torch.nn.functional as F

3.  

4.  

5. def conv_block(in_planes, out_planes, stride=1, groups=1, dilation=1):

6.     return nn.Conv1d(in_planes, out_planes, kernel_size=17, stride=stride,

7.                      padding=8, groups=groups, bias=False, dilation=dilation)

8.  

9.  

10. def conv_subsumpling(in_planes, out_planes, stride=1):

11.     return nn.Conv1d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)

12.  

13.  

14. class BasicBlockHeartNet(nn.Module):

15.     expansion = 1

16.  

17.     def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,

18.                  base_width=64, dilation=1, norm_layer=None):

19.         super(BasicBlockHeartNet, self).__init__()

20.         if norm_layer is None:

21.             norm_layer = nn.BatchNorm1d

22.         if groups != 1 or base_width != 64:

23.             raise ValueError('BasicBlock only supports groups=1 and base_width=64')

24.         if dilation > 1:

25.             raise NotImplementedError("Dilation > 1 not supported in BasicBlock")

26.         # Both self.conv1 and self.downsample layers downsample the input when stride != 1

27.         self.conv1 = conv_block(inplanes, planes, stride)

28.         self.bn1 = norm_layer(inplanes)

29.         self.relu = nn.ReLU(inplace=True)

30.         self.conv2 = conv_block(planes, planes)

31.         self.bn2 = norm_layer(planes)

32.         self.downsample = downsample

33.         self.stride = stride

34.  

35.     def forward(self, x):

36.         identity = x

37.  

38.         out = self.bn1(x)

39.         out = self.relu(out)

40.         out = self.conv1(out)

41.  

42.         out = self.bn2(out)

43.         out = self.relu(out)

44.         out = self.conv2(out)

45.  

46.         if self.downsample is not None:

47.             identity = self.downsample(x)

48.         if self.stride != 1:

49.             identity = F.max_pool1d(identity, self.stride)

50.  

51.         out += identity

52.  

53.         return out

54.  

55.  

56. class BasicBlock(nn.Module):

57.     expansion = 1

58.  

59.     def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,

60.                  base_width=64, dilation=1, norm_layer=None):

61.         super(BasicBlock, self).__init__()

62.         if norm_layer is None:

63.             norm_layer = nn.BatchNorm1d

64.         if groups != 1 or base_width != 64:

65.             raise ValueError('BasicBlock only supports groups=1 and base_width=64')

66.         if dilation > 1:

67.             raise NotImplementedError("Dilation > 1 not supported in BasicBlock")

68.         # Both self.conv1 and self.downsample layers downsample the input when stride != 1

69.         self.conv1 = conv_block(inplanes, planes, stride)

70.         self.bn1 = norm_layer(inplanes)

71.         self.relu = nn.ReLU(inplace=True)

72.         self.conv2 = conv_block(planes, planes)

73.         self.bn2 = norm_layer(planes)

74.         self.dropout = nn.Dropout()

75.         self.downsample = downsample

76.         self.stride = stride

77.  

78.     def forward(self, x):

79.         identity = x

80.  

81.         out = self.bn1(x)

82.         out = self.relu(out)

83.         out = self.dropout(out)

84.         out = self.conv1(out)

85.  

86.         out = self.bn2(out)

87.         out = self.relu(out)

88.         out = self.dropout(out)

89.         out = self.conv2(out)

90.  

91.         if self.downsample is not None:

92.             identity = self.downsample(x)

93.  

94.         out += identity

95.  

96.         return out

97.  

98.  

99. class HeartNet(nn.Module):

100.  

101.     def __init__(self, layers=(1, 2, 2, 2, 2, 2, 2, 2, 1), num_classes=1000, zero_init_residual=False,

102.                  groups=1, width_per_group=64, replace_stride_with_dilation=None,

103.                  norm_layer=None, block=BasicBlockHeartNet):

104.  

105.         super(HeartNet, self).__init__()

106.         if norm_layer is None:

107.             norm_layer = nn.BatchNorm1d

108.         self._norm_layer = norm_layer

109.  

110.         self.inplanes = 32

111.         self.dilation = 1

112.         if replace_stride_with_dilation is None:

113.             # each element in the tuple indicates if we should replace

114.             # the 2x2 stride with a dilated convolution instead

115.             replace_stride_with_dilation = [False, False, False]

116.         if len(replace_stride_with_dilation) != 3:

117.             raise ValueError("replace_stride_with_dilation should be None "

118.                              "or a 3-element tuple, got {}".format(replace_stride_with_dilation))

119.         self.groups = groups

120.         self.base_width = width_per_group

121.         self.conv1 = conv_block(1, self.inplanes, stride=1,)

122.         self.bn1 = norm_layer(self.inplanes)

123.         self.relu = nn.ReLU(inplace=True)

124.         self.layer0 = self._make_layer(block, 64, layers[0])

125.         self.layer1 = self._make_layer(block, 64, layers[1], stride=2,

126.                                        dilate=replace_stride_with_dilation[0])

127.         self.layer2 = self._make_layer(block, 128, layers[2], stride=2,

128.                                        dilate=replace_stride_with_dilation[0])

129.         self.layer2_ = self._make_layer(block, 128, layers[3], stride=2,

130.                                        dilate=replace_stride_with_dilation[0])

131.         self.layer3 = self._make_layer(block, 256, layers[4], stride=2,

132.                                        dilate=replace_stride_with_dilation[1])

133.         self.layer3_ = self._make_layer(block, 256, layers[5], stride=2,

134.                                        dilate=replace_stride_with_dilation[1])

135.         self.layer4 = self._make_layer(block, 512, layers[6], stride=2,

136.                                        dilate=replace_stride_with_dilation[2])

137.         self.layer4_ = self._make_layer(block, 512, layers[7], stride=2,

138.                                        dilate=replace_stride_with_dilation[2])

139.         self.layer5 = self._make_layer(block, 1024, layers[8], stride=2,

140.                                        dilate=replace_stride_with_dilation[2])

141.         self.avgpool = nn.AdaptiveAvgPool1d(1)

142.         self.fc = nn.Linear(1024 * block.expansion, num_classes)

143.  

144.         for m in self.modules():

145.             if isinstance(m, nn.Conv1d):

146.                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

147.             elif isinstance(m, (nn.BatchNorm1d, nn.GroupNorm)):

148.                 nn.init.constant_(m.weight, 1)

149.                 nn.init.constant_(m.bias, 0)

150.  

151.         # Zero-initialize the last BN in each residual branch,

152.         # so that the residual branch starts with zeros, and each residual block behaves like an identity.

153.         # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677

154.         if zero_init_residual:

155.             for m in self.modules():

156.                 if isinstance(m, BasicBlockHeartNet):

157.                     nn.init.constant_(m.bn2.weight, 0)

158.  

159.     def _make_layer(self, block, planes, blocks, stride=1, dilate=False):

160.         norm_layer = self._norm_layer

161.         downsample = None

162.         previous_dilation = self.dilation

163.         self.stride = stride

164.         if dilate:

165.             self.dilation *= stride

166.             stride = 1

167.         if stride != 1 or self.inplanes != planes * block.expansion:

168.             downsample = nn.Sequential(

169.                 conv_subsumpling(self.inplanes, planes * block.expansion)

170.             )

171.  

172.         layers = []

173.         layers.append(block(self.inplanes, planes, stride, downsample, self.groups,

174.                             self.base_width, previous_dilation, norm_layer))

175.         self.inplanes = planes * block.expansion

176.         for _ in range(1, blocks):

177.             layers.append(block(self.inplanes, planes, groups=self.groups,

178.                                 base_width=self.base_width, dilation=self.dilation,

179.                                 norm_layer=norm_layer))

180.  

181.         return nn.Sequential(*layers)

182.  

183.     def forward(self, x):

184.         x = self.conv1(x)

185.  

186.         x = self.layer0(x)

187.         x = self.layer1(x)

188.         x = self.layer2(x)

189.         x = self.layer2_(x)

190.         x = self.layer3(x)

191.         x = self.layer3_(x)

192.         x = self.layer4(x)

193.         x = self.layer4_(x)

194.         x = self.layer5(x)

195.  

196.         x = self.avgpool(x)

197.         x = x.reshape(x.size(0), -1)

198.         x = self.fc(x)

199.  

200.         return x

201.  

202.  

203. class EcgResNet34(nn.Module):

204.  

205.     def __init__(self, layers=(1, 5, 5, 5), num_classes=1000, zero_init_residual=False,

206.                  groups=1, width_per_group=64, replace_stride_with_dilation=None,

207.                  norm_layer=None, block=BasicBlock):

208.  

209.         super(EcgResNet34, self).__init__()

210.         if norm_layer is None:

211.             norm_layer = nn.BatchNorm1d

212.         self._norm_layer = norm_layer

213.  

214.         self.inplanes = 32

215.         self.dilation = 1

216.         if replace_stride_with_dilation is None:

217.             # each element in the tuple indicates if we should replace

218.             # the 2x2 stride with a dilated convolution instead

219.             replace_stride_with_dilation = [False, False, False]

220.         if len(replace_stride_with_dilation) != 3:

221.             raise ValueError("replace_stride_with_dilation should be None "

222.                              "or a 3-element tuple, got {}".format(replace_stride_with_dilation))

223.         self.groups = groups

224.         self.base_width = width_per_group

225.         self.conv1 = conv_block(1, self.inplanes, stride=1,)

226.         self.bn1 = norm_layer(self.inplanes)

227.         self.relu = nn.ReLU(inplace=True)

228.         self.layer1 = self._make_layer(block, 64, layers[0])

229.         self.layer2 = self._make_layer(block, 128, layers[1], stride=2,

230.                                        dilate=replace_stride_with_dilation[0])

231.         self.layer3 = self._make_layer(block, 256, layers[2], stride=2,

232.                                        dilate=replace_stride_with_dilation[1])

233.         self.layer4 = self._make_layer(block, 512, layers[3], stride=2,

234.                                        dilate=replace_stride_with_dilation[2])

235.         self.avgpool = nn.AdaptiveAvgPool1d(1)

236.         self.fc = nn.Linear(512 * block.expansion, num_classes)

237.  

238.         for m in self.modules():

239.             if isinstance(m, nn.Conv1d):

240.                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

241.             elif isinstance(m, (nn.BatchNorm1d, nn.GroupNorm)):

242.                 nn.init.constant_(m.weight, 1)

243.                 nn.init.constant_(m.bias, 0)

244.  

245.         # Zero-initialize the last BN in each residual branch,

246.         # so that the residual branch starts with zeros, and each residual block behaves like an identity.

247.         # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677

248.         if zero_init_residual:

249.             for m in self.modules():

250.                 if isinstance(m, BasicBlock):

251.                     nn.init.constant_(m.bn2.weight, 0)

252.  

253.     def _make_layer(self, block, planes, blocks, stride=1, dilate=False):

254.         norm_layer = self._norm_layer

255.         downsample = None

256.         previous_dilation = self.dilation

257.         if dilate:

258.             self.dilation *= stride

259.             stride = 1

260.         if stride != 1 or self.inplanes != planes * block.expansion:

261.             downsample = nn.Sequential(

262.                 conv_subsumpling(self.inplanes, planes * block.expansion, stride),

263.                 norm_layer(planes * block.expansion),

264.             )

265.  

266.         layers = []

267.         layers.append(block(self.inplanes, planes, stride, downsample, self.groups,

268.                             self.base_width, previous_dilation, norm_layer))

269.         self.inplanes = planes * block.expansion

270.         for _ in range(1, blocks):

271.             layers.append(block(self.inplanes, planes, groups=self.groups,

272.                                 base_width=self.base_width, dilation=self.dilation,

273.                                 norm_layer=norm_layer))

274.  

275.         return nn.Sequential(*layers)

276.  

277.     def forward(self, x):

278.         x = self.conv1(x)

279.  

280.         x = self.layer1(x)

281.         x = self.layer2(x)

282.         x = self.layer3(x)

283.         x = self.layer4(x)

284.  

285.         x = self.avgpool(x)

286.         x = x.reshape(x.size(0), -1)

287.         x = self.fc(x)

288.  

289.         return x

290.  

291.  

292. class HeartNetIEEE(nn.Module):

293.     def __init__(self, num_classes=8):

294.         super().__init__()

295.  

296.         self.features = nn.Sequential(

297.             nn.Conv1d(1, 64, kernel_size=5),

298.             nn.ReLU(inplace=True),

299.             nn.Conv1d(64, 64, kernel_size=5),

300.             nn.ReLU(inplace=True),

301.             nn.MaxPool1d(2),

302.             nn.Conv1d(64, 128, kernel_size=3),

303.             nn.ReLU(inplace=True),

304.             nn.Conv1d(128, 128, kernel_size=3),

305.             nn.ReLU(inplace=True),

306.             nn.MaxPool1d(2)

307.         )

308.  

309.         self.classifier = nn.Sequential(

310.             nn.Linear(128 * 28, 256),

311.             nn.Linear(256, 128),

312.             nn.Linear(128, num_classes)

313.         )

314.  

315.     def forward(self, x):

316.         x = self.features(x)

317.         x = x.view(x.size(0), 128 * 28)

318.         x = self.classifier(x)

319.         return x

320.  

321.  

322.  

323. class Flatten(nn.Module):

324.     def forward(self, input):

325.         return input.view(input.size(0), -1)

326.  

327.  

328. class ZolotyhNet(nn.Module):

329.     def __init__(self, num_classes=8):

330.         super().__init__()

331.  

332.         self.features_up = nn.Sequential(

333.             nn.Conv1d(1, 8, kernel_size=3, padding=1),

334.             nn.BatchNorm1d(8),

335.             nn.ReLU(inplace=True),

336.             nn.MaxPool1d(2),

337.  

338.             nn.Conv1d(8, 16, kernel_size=3, padding=1),

339.             nn.BatchNorm1d(16),

340.             nn.ReLU(inplace=True),

341.             nn.MaxPool1d(2),

342.  

343.             nn.Conv1d(16, 32, kernel_size=3, padding=1),

344.             nn.BatchNorm1d(32),

345.             nn.ReLU(inplace=True),

346.             nn.MaxPool1d(2),

347.  

348.             nn.Conv1d(32, 32, kernel_size=3, padding=1),

349.             nn.BatchNorm1d(32),

350.             nn.ReLU(inplace=True),

351.             nn.MaxPool1d(2),

352.  

353.             nn.Conv1d(32, 1, kernel_size=3, padding=1),

354.             Flatten(),

355.         )

356.  

357.         self.features_down = nn.Sequential(

358.             Flatten(),

359.             nn.Linear(128,64),

360.             nn.BatchNorm1d(64),

361.             nn.ReLU(inplace=True),

362.  

363.             nn.Linear(64, 16),

364.             nn.BatchNorm1d(16),

365.             nn.ReLU(inplace=True),

366.  

367.             nn.Linear(16, 8)

368.         )

369.  

370.         self.classifier = nn.Linear(8, num_classes)

371.  

372.     def forward(self, x):

373.         out_up = self.features_up(x)

374.         out_down = self.features_down(x)

375.         out_middle = out_up + out_down

376.  

377.         out = self.classifier(out_middle)

378.  

379.         return out

Приложение 5

1. import json

2. import os

3. import os.path as osp

4. import numpy as np

5. from datetime import datetime

6.  

7. import torch

8. import wfdb

9. from tqdm import tqdm

10. import plotly.graph_objects as go

11.  

12. from utils.network_utils import load_checkpoint

13.  

14.  

15. class BasePipeline:

16.     def __init__(self, config):

17.         self.config = config

18.         self.exp_name = self.config.get('exp_name', None)

19.         if self.exp_name is None:

20.             self.exp_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

21.  

22.         self.res_dir = osp.join(self.config['exp_dir'], self.exp_name, 'results')

23.         os.makedirs(self.res_dir, exist_ok=True)

24.  

25.         self.model = self._init_net()

26.  

27.         self.pipeline_loader = self._init_dataloader()

28.  

29.         self.mapper = json.load(open(config['mapping_json']))

30.         self.mapper = {j: i for i, j in self.mapper.items()}

31.  

32.         pretrained_path = self.config.get('model_path', False)

33.         if pretrained_path:

34.             load_checkpoint(pretrained_path, self.model)

35.         else:

36.             raise Exception("model_path doesnt't exist in config. Please specify checkpoint path")

37.  

38.     def _init_net(self):

39.         raise NotImplemented

40.  

41.     def _init_dataloader(self):

42.         raise NotImplemented

43.  

44.     def run_pipeline(self):

45.         self.model.eval()

46.         pd_class = np.empty(0)

47.         pd_peaks = np.empty(0)

48.  

49.         with torch.no_grad():

50.             for i, batch in tqdm(enumerate(self.pipeline_loader)):

51.                 inputs = batch['image'].to(self.config['device'])

52.  

53.                 predictions = self.model(inputs)

54.  

55.                 classes = predictions.topk(k=1)[1].view(-1).cpu().numpy()

56.  

57.                 pd_class = np.concatenate((pd_class, classes))

58.                 pd_peaks = np.concatenate((pd_peaks, batch['peak']))

59.  

60.         pd_class = pd_class.astype(int)

61.         pd_peaks = pd_peaks.astype(int)

62.  

63.         annotations = []

64.         for label, peak in zip(pd_class, pd_peaks):

65.             if peak < len(self.pipeline_loader.dataset.signal) and self.mapper[label] != 'N':

66.                 annotations.append({

67.                     "x": peak,

68.                     "y": self.pipeline_loader.dataset.signal[peak],

69.                     "text": self.mapper[label],

70.                     "xref": "x",

71.                     "yref": "y",

72.                     "showarrow": True,

73.                     "arrowcolor": "black",

74.                     "arrowhead": 1,

75.                     "arrowsize": 2

76.                 })

77.  

78.         if osp.exists(self.config['ecg_data'] + '.atr'):

79.             ann = wfdb.rdann(self.config['ecg_data'], extension='atr')

80.             for label, peak in zip(ann.symbol, ann.sample):

81.                 if peak < len(self.pipeline_loader.dataset.signal) and label != 'N':

82.                     annotations.append({

83.                         "x": peak,

84.                         "y": self.pipeline_loader.dataset.signal[peak] - 0.1,

85.                         "text": label,

86.                         "xref": "x",

87.                         "yref": "y",

88.                         "showarrow": False,

89.                         "bordercolor": "#c7c7c7",

90.                         "borderwidth": 1,

91.                         "borderpad": 4,

92.                         "bgcolor": "#ffffff",

93.                         "opacity": 1

94.                     })

95.  

96.         fig = go.Figure(data=go.Scatter(x=list(range(len(self.pipeline_loader.dataset.signal))), y=self.pipeline_loader.dataset.signal))

97.         fig.update_layout(title='ECG',

98.                           xaxis_title='Time',

99.                           yaxis_title='ECG Output Value',

100.                           title_x=0.5,

101.                           annotations=annotations,

102.                           autosize=True)

103.  

104.         fig.write_html(osp.join(self.res_dir, osp.basename(self.config['ecg_data'] + '.html')))

Приложение 6

1. import json

2.  

3. import wfdb

4. from scipy.signal import find_peaks

5. from sklearn.preprocessing import scale

6. from torch.utils.data import Dataset, DataLoader

7. import numpy as np

8.  

9.  

10. class EcgDataset1D(Dataset):

11.     def __init__(self, ann_path, mapping_path):

12.         super().__init__()

13.         self.data = json.load(open(ann_path))

14.         self.mapper = json.load(open(mapping_path))

15.  

16.     def __getitem__(self, index):

17.         img = np.load(self.data[index]['path']).astype('float32')

18.         img = img.reshape(1, img.shape[0])

19.  

20.         return {

21.             "image": img,

22.             "class": self.mapper[self.data[index]['label']]

23.         }

24.  

25.     def get_dataloader(self, num_workers=4, batch_size=16, shuffle=True):


Подобные документы

Работы в архивах красиво оформлены согласно требованиям ВУЗов и содержат рисунки, диаграммы, формулы и т.д.
PPT, PPTX и PDF-файлы представлены только в архивах.
Рекомендуем скачать работу.