Du kan inte välja fler än 25 ämnen. Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

common.py 19KB

4 år sedan
4 år sedan
4 år sedan
4 år sedan
4 år sedan
4 år sedan
4 år sedan
4 år sedan
4 år sedan
4 år sedan
4 år sedan
4 år sedan
4 år sedan
4 år sedan
4 år sedan
4 år sedan
4 år sedan
4 år sedan
4 år sedan
4 år sedan
4 år sedan
4 år sedan
4 år sedan
4 år sedan
4 år sedan
4 år sedan
4 år sedan
4 år sedan
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460
  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. Common modules
  4. """
  5. import logging
  6. import math
  7. import warnings
  8. from copy import copy
  9. from pathlib import Path
  10. import numpy as np
  11. import pandas as pd
  12. import requests
  13. import torch
  14. import torch.nn as nn
  15. from PIL import Image
  16. from torch.cuda import amp
  17. from utils.datasets import exif_transpose, letterbox
  18. from utils.general import colorstr, increment_path, is_ascii, make_divisible, non_max_suppression, save_one_box, \
  19. scale_coords, xyxy2xywh
  20. from utils.plots import Annotator, colors
  21. from utils.torch_utils import time_sync
  22. LOGGER = logging.getLogger(__name__)
  23. def autopad(k, p=None): # kernel, padding
  24. # Pad to 'same'
  25. if p is None:
  26. p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
  27. return p
  28. class Conv(nn.Module):
  29. # Standard convolution
  30. def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
  31. super().__init__()
  32. self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
  33. self.bn = nn.BatchNorm2d(c2)
  34. self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
  35. def forward(self, x):
  36. return self.act(self.bn(self.conv(x)))
  37. def forward_fuse(self, x):
  38. return self.act(self.conv(x))
  39. class DWConv(Conv):
  40. # Depth-wise convolution class
  41. def __init__(self, c1, c2, k=1, s=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
  42. super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), act=act)
  43. class TransformerLayer(nn.Module):
  44. # Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance)
  45. def __init__(self, c, num_heads):
  46. super().__init__()
  47. self.q = nn.Linear(c, c, bias=False)
  48. self.k = nn.Linear(c, c, bias=False)
  49. self.v = nn.Linear(c, c, bias=False)
  50. self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads)
  51. self.fc1 = nn.Linear(c, c, bias=False)
  52. self.fc2 = nn.Linear(c, c, bias=False)
  53. def forward(self, x):
  54. x = self.ma(self.q(x), self.k(x), self.v(x))[0] + x
  55. x = self.fc2(self.fc1(x)) + x
  56. return x
  57. class TransformerBlock(nn.Module):
  58. # Vision Transformer https://arxiv.org/abs/2010.11929
  59. def __init__(self, c1, c2, num_heads, num_layers):
  60. super().__init__()
  61. self.conv = None
  62. if c1 != c2:
  63. self.conv = Conv(c1, c2)
  64. self.linear = nn.Linear(c2, c2) # learnable position embedding
  65. self.tr = nn.Sequential(*[TransformerLayer(c2, num_heads) for _ in range(num_layers)])
  66. self.c2 = c2
  67. def forward(self, x):
  68. if self.conv is not None:
  69. x = self.conv(x)
  70. b, _, w, h = x.shape
  71. p = x.flatten(2).unsqueeze(0).transpose(0, 3).squeeze(3)
  72. return self.tr(p + self.linear(p)).unsqueeze(3).transpose(0, 3).reshape(b, self.c2, w, h)
  73. class Bottleneck(nn.Module):
  74. # Standard bottleneck
  75. def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): # ch_in, ch_out, shortcut, groups, expansion
  76. super().__init__()
  77. c_ = int(c2 * e) # hidden channels
  78. self.cv1 = Conv(c1, c_, 1, 1)
  79. self.cv2 = Conv(c_, c2, 3, 1, g=g)
  80. self.add = shortcut and c1 == c2
  81. def forward(self, x):
  82. return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
  83. class BottleneckCSP(nn.Module):
  84. # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
  85. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
  86. super().__init__()
  87. c_ = int(c2 * e) # hidden channels
  88. self.cv1 = Conv(c1, c_, 1, 1)
  89. self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False)
  90. self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False)
  91. self.cv4 = Conv(2 * c_, c2, 1, 1)
  92. self.bn = nn.BatchNorm2d(2 * c_) # applied to cat(cv2, cv3)
  93. self.act = nn.LeakyReLU(0.1, inplace=True)
  94. self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
  95. def forward(self, x):
  96. y1 = self.cv3(self.m(self.cv1(x)))
  97. y2 = self.cv2(x)
  98. return self.cv4(self.act(self.bn(torch.cat((y1, y2), dim=1))))
  99. class C3(nn.Module):
  100. # CSP Bottleneck with 3 convolutions
  101. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
  102. super().__init__()
  103. c_ = int(c2 * e) # hidden channels
  104. self.cv1 = Conv(c1, c_, 1, 1)
  105. self.cv2 = Conv(c1, c_, 1, 1)
  106. self.cv3 = Conv(2 * c_, c2, 1) # act=FReLU(c2)
  107. self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
  108. # self.m = nn.Sequential(*[CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n)])
  109. def forward(self, x):
  110. return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))
  111. class C3TR(C3):
  112. # C3 module with TransformerBlock()
  113. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
  114. super().__init__(c1, c2, n, shortcut, g, e)
  115. c_ = int(c2 * e)
  116. self.m = TransformerBlock(c_, c_, 4, n)
  117. class C3SPP(C3):
  118. # C3 module with SPP()
  119. def __init__(self, c1, c2, k=(5, 9, 13), n=1, shortcut=True, g=1, e=0.5):
  120. super().__init__(c1, c2, n, shortcut, g, e)
  121. c_ = int(c2 * e)
  122. self.m = SPP(c_, c_, k)
  123. class C3Ghost(C3):
  124. # C3 module with GhostBottleneck()
  125. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
  126. super().__init__(c1, c2, n, shortcut, g, e)
  127. c_ = int(c2 * e) # hidden channels
  128. self.m = nn.Sequential(*[GhostBottleneck(c_, c_) for _ in range(n)])
  129. class SPP(nn.Module):
  130. # Spatial Pyramid Pooling (SPP) layer https://arxiv.org/abs/1406.4729
  131. def __init__(self, c1, c2, k=(5, 9, 13)):
  132. super().__init__()
  133. c_ = c1 // 2 # hidden channels
  134. self.cv1 = Conv(c1, c_, 1, 1)
  135. self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1)
  136. self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])
  137. def forward(self, x):
  138. x = self.cv1(x)
  139. with warnings.catch_warnings():
  140. warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning
  141. return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))
  142. class SPPF(nn.Module):
  143. # Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher
  144. def __init__(self, c1, c2, k=5): # equivalent to SPP(k=(5, 9, 13))
  145. super().__init__()
  146. c_ = c1 // 2 # hidden channels
  147. self.cv1 = Conv(c1, c_, 1, 1)
  148. self.cv2 = Conv(c_ * 4, c2, 1, 1)
  149. self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)
  150. def forward(self, x):
  151. x = self.cv1(x)
  152. with warnings.catch_warnings():
  153. warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning
  154. y1 = self.m(x)
  155. y2 = self.m(y1)
  156. return self.cv2(torch.cat([x, y1, y2, self.m(y2)], 1))
  157. class Focus(nn.Module):
  158. # Focus wh information into c-space
  159. def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
  160. super().__init__()
  161. self.conv = Conv(c1 * 4, c2, k, s, p, g, act)
  162. # self.contract = Contract(gain=2)
  163. def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2)
  164. return self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1))
  165. # return self.conv(self.contract(x))
  166. class GhostConv(nn.Module):
  167. # Ghost Convolution https://github.com/huawei-noah/ghostnet
  168. def __init__(self, c1, c2, k=1, s=1, g=1, act=True): # ch_in, ch_out, kernel, stride, groups
  169. super().__init__()
  170. c_ = c2 // 2 # hidden channels
  171. self.cv1 = Conv(c1, c_, k, s, None, g, act)
  172. self.cv2 = Conv(c_, c_, 5, 1, None, c_, act)
  173. def forward(self, x):
  174. y = self.cv1(x)
  175. return torch.cat([y, self.cv2(y)], 1)
  176. class GhostBottleneck(nn.Module):
  177. # Ghost Bottleneck https://github.com/huawei-noah/ghostnet
  178. def __init__(self, c1, c2, k=3, s=1): # ch_in, ch_out, kernel, stride
  179. super().__init__()
  180. c_ = c2 // 2
  181. self.conv = nn.Sequential(GhostConv(c1, c_, 1, 1), # pw
  182. DWConv(c_, c_, k, s, act=False) if s == 2 else nn.Identity(), # dw
  183. GhostConv(c_, c2, 1, 1, act=False)) # pw-linear
  184. self.shortcut = nn.Sequential(DWConv(c1, c1, k, s, act=False),
  185. Conv(c1, c2, 1, 1, act=False)) if s == 2 else nn.Identity()
  186. def forward(self, x):
  187. return self.conv(x) + self.shortcut(x)
  188. class Contract(nn.Module):
  189. # Contract width-height into channels, i.e. x(1,64,80,80) to x(1,256,40,40)
  190. def __init__(self, gain=2):
  191. super().__init__()
  192. self.gain = gain
  193. def forward(self, x):
  194. b, c, h, w = x.size() # assert (h / s == 0) and (W / s == 0), 'Indivisible gain'
  195. s = self.gain
  196. x = x.view(b, c, h // s, s, w // s, s) # x(1,64,40,2,40,2)
  197. x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # x(1,2,2,64,40,40)
  198. return x.view(b, c * s * s, h // s, w // s) # x(1,256,40,40)
  199. class Expand(nn.Module):
  200. # Expand channels into width-height, i.e. x(1,64,80,80) to x(1,16,160,160)
  201. def __init__(self, gain=2):
  202. super().__init__()
  203. self.gain = gain
  204. def forward(self, x):
  205. b, c, h, w = x.size() # assert C / s ** 2 == 0, 'Indivisible gain'
  206. s = self.gain
  207. x = x.view(b, s, s, c // s ** 2, h, w) # x(1,2,2,16,80,80)
  208. x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # x(1,16,80,2,80,2)
  209. return x.view(b, c // s ** 2, h * s, w * s) # x(1,16,160,160)
  210. class Concat(nn.Module):
  211. # Concatenate a list of tensors along dimension
  212. def __init__(self, dimension=1):
  213. super().__init__()
  214. self.d = dimension
  215. def forward(self, x):
  216. return torch.cat(x, self.d)
  217. class AutoShape(nn.Module):
  218. # YOLOv5 input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
  219. conf = 0.25 # NMS confidence threshold
  220. iou = 0.45 # NMS IoU threshold
  221. classes = None # (optional list) filter by class
  222. multi_label = False # NMS multiple labels per box
  223. max_det = 1000 # maximum number of detections per image
  224. def __init__(self, model):
  225. super().__init__()
  226. self.model = model.eval()
  227. def autoshape(self):
  228. LOGGER.info('AutoShape already enabled, skipping... ') # model already converted to model.autoshape()
  229. return self
  230. @torch.no_grad()
  231. def forward(self, imgs, size=640, augment=False, profile=False):
  232. # Inference from various sources. For height=640, width=1280, RGB images example inputs are:
  233. # file: imgs = 'data/images/zidane.jpg' # str or PosixPath
  234. # URI: = 'https://ultralytics.com/images/zidane.jpg'
  235. # OpenCV: = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(640,1280,3)
  236. # PIL: = Image.open('image.jpg') or ImageGrab.grab() # HWC x(640,1280,3)
  237. # numpy: = np.zeros((640,1280,3)) # HWC
  238. # torch: = torch.zeros(16,3,320,640) # BCHW (scaled to size=640, 0-1 values)
  239. # multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images
  240. t = [time_sync()]
  241. p = next(self.model.parameters()) # for device and type
  242. if isinstance(imgs, torch.Tensor): # torch
  243. with amp.autocast(enabled=p.device.type != 'cpu'):
  244. return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference
  245. # Pre-process
  246. n, imgs = (len(imgs), imgs) if isinstance(imgs, list) else (1, [imgs]) # number of images, list of images
  247. shape0, shape1, files = [], [], [] # image and inference shapes, filenames
  248. for i, im in enumerate(imgs):
  249. f = f'image{i}' # filename
  250. if isinstance(im, (str, Path)): # filename or uri
  251. im, f = Image.open(requests.get(im, stream=True).raw if str(im).startswith('http') else im), im
  252. im = np.asarray(exif_transpose(im))
  253. elif isinstance(im, Image.Image): # PIL Image
  254. im, f = np.asarray(exif_transpose(im)), getattr(im, 'filename', f) or f
  255. files.append(Path(f).with_suffix('.jpg').name)
  256. if im.shape[0] < 5: # image in CHW
  257. im = im.transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1)
  258. im = im[..., :3] if im.ndim == 3 else np.tile(im[..., None], 3) # enforce 3ch input
  259. s = im.shape[:2] # HWC
  260. shape0.append(s) # image shape
  261. g = (size / max(s)) # gain
  262. shape1.append([y * g for y in s])
  263. imgs[i] = im if im.data.contiguous else np.ascontiguousarray(im) # update
  264. shape1 = [make_divisible(x, int(self.stride.max())) for x in np.stack(shape1, 0).max(0)] # inference shape
  265. x = [letterbox(im, new_shape=shape1, auto=False)[0] for im in imgs] # pad
  266. x = np.stack(x, 0) if n > 1 else x[0][None] # stack
  267. x = np.ascontiguousarray(x.transpose((0, 3, 1, 2))) # BHWC to BCHW
  268. x = torch.from_numpy(x).to(p.device).type_as(p) / 255. # uint8 to fp16/32
  269. t.append(time_sync())
  270. with amp.autocast(enabled=p.device.type != 'cpu'):
  271. # Inference
  272. y = self.model(x, augment, profile)[0] # forward
  273. t.append(time_sync())
  274. # Post-process
  275. y = non_max_suppression(y, self.conf, iou_thres=self.iou, classes=self.classes,
  276. multi_label=self.multi_label, max_det=self.max_det) # NMS
  277. for i in range(n):
  278. scale_coords(shape1, y[i][:, :4], shape0[i])
  279. t.append(time_sync())
  280. return Detections(imgs, y, files, t, self.names, x.shape)
  281. class Detections:
  282. # YOLOv5 detections class for inference results
  283. def __init__(self, imgs, pred, files, times=None, names=None, shape=None):
  284. super().__init__()
  285. d = pred[0].device # device
  286. gn = [torch.tensor([*[im.shape[i] for i in [1, 0, 1, 0]], 1., 1.], device=d) for im in imgs] # normalizations
  287. self.imgs = imgs # list of images as numpy arrays
  288. self.pred = pred # list of tensors pred[0] = (xyxy, conf, cls)
  289. self.names = names # class names
  290. self.ascii = is_ascii(names) # names are ascii (use PIL for UTF-8)
  291. self.files = files # image filenames
  292. self.xyxy = pred # xyxy pixels
  293. self.xywh = [xyxy2xywh(x) for x in pred] # xywh pixels
  294. self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized
  295. self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized
  296. self.n = len(self.pred) # number of images (batch size)
  297. self.t = tuple((times[i + 1] - times[i]) * 1000 / self.n for i in range(3)) # timestamps (ms)
  298. self.s = shape # inference BCHW shape
  299. def display(self, pprint=False, show=False, save=False, crop=False, render=False, save_dir=Path('')):
  300. crops = []
  301. for i, (im, pred) in enumerate(zip(self.imgs, self.pred)):
  302. str = f'image {i + 1}/{len(self.pred)}: {im.shape[0]}x{im.shape[1]} '
  303. if pred.shape[0]:
  304. for c in pred[:, -1].unique():
  305. n = (pred[:, -1] == c).sum() # detections per class
  306. str += f"{n} {self.names[int(c)]}{'s' * (n > 1)}, " # add to string
  307. if show or save or render or crop:
  308. annotator = Annotator(im, pil=not self.ascii)
  309. for *box, conf, cls in reversed(pred): # xyxy, confidence, class
  310. label = f'{self.names[int(cls)]} {conf:.2f}'
  311. if crop:
  312. file = save_dir / 'crops' / self.names[int(cls)] / self.files[i] if save else None
  313. crops.append({'box': box, 'conf': conf, 'cls': cls, 'label': label,
  314. 'im': save_one_box(box, im, file=file, save=save)})
  315. else: # all others
  316. annotator.box_label(box, label, color=colors(cls))
  317. im = annotator.im
  318. else:
  319. str += '(no detections)'
  320. im = Image.fromarray(im.astype(np.uint8)) if isinstance(im, np.ndarray) else im # from np
  321. if pprint:
  322. LOGGER.info(str.rstrip(', '))
  323. if show:
  324. im.show(self.files[i]) # show
  325. if save:
  326. f = self.files[i]
  327. im.save(save_dir / f) # save
  328. if i == self.n - 1:
  329. LOGGER.info(f"Saved {self.n} image{'s' * (self.n > 1)} to {colorstr('bold', save_dir)}")
  330. if render:
  331. self.imgs[i] = np.asarray(im)
  332. if crop:
  333. if save:
  334. LOGGER.info(f'Saved results to {save_dir}\n')
  335. return crops
  336. def print(self):
  337. self.display(pprint=True) # print results
  338. LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {tuple(self.s)}' %
  339. self.t)
  340. def show(self):
  341. self.display(show=True) # show results
  342. def save(self, save_dir='runs/detect/exp'):
  343. save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/detect/exp', mkdir=True) # increment save_dir
  344. self.display(save=True, save_dir=save_dir) # save results
  345. def crop(self, save=True, save_dir='runs/detect/exp'):
  346. save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/detect/exp', mkdir=True) if save else None
  347. return self.display(crop=True, save=save, save_dir=save_dir) # crop results
  348. def render(self):
  349. self.display(render=True) # render results
  350. return self.imgs
  351. def pandas(self):
  352. # return detections as pandas DataFrames, i.e. print(results.pandas().xyxy[0])
  353. new = copy(self) # return copy
  354. ca = 'xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class', 'name' # xyxy columns
  355. cb = 'xcenter', 'ycenter', 'width', 'height', 'confidence', 'class', 'name' # xywh columns
  356. for k, c in zip(['xyxy', 'xyxyn', 'xywh', 'xywhn'], [ca, ca, cb, cb]):
  357. a = [[x[:5] + [int(x[5]), self.names[int(x[5])]] for x in x.tolist()] for x in getattr(self, k)] # update
  358. setattr(new, k, [pd.DataFrame(x, columns=c) for x in a])
  359. return new
  360. def tolist(self):
  361. # return a list of Detections objects, i.e. 'for result in results.tolist():'
  362. x = [Detections([self.imgs[i]], [self.pred[i]], self.names, self.s) for i in range(self.n)]
  363. for d in x:
  364. for k in ['imgs', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']:
  365. setattr(d, k, getattr(d, k)[0]) # pop out of list
  366. return x
  367. def __len__(self):
  368. return self.n
  369. class Classify(nn.Module):
  370. # Classification head, i.e. x(b,c1,20,20) to x(b,c2)
  371. def __init__(self, c1, c2, k=1, s=1, p=None, g=1): # ch_in, ch_out, kernel, stride, padding, groups
  372. super().__init__()
  373. self.aap = nn.AdaptiveAvgPool2d(1) # to x(b,c1,1,1)
  374. self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g) # to x(b,c2,1,1)
  375. self.flat = nn.Flatten()
  376. def forward(self, x):
  377. z = torch.cat([self.aap(y) for y in (x if isinstance(x, list) else [x])], 1) # cat if list
  378. return self.flat(self.conv(z)) # flatten to x(b,c2)