# unet.py -- UNet class definition, used by the UNet tests.
import torch
import torch.nn as nn


class UNet(nn.Module):
    """Small three-level U-Net-style encoder/decoder with two skip connections.

    Args:
        in_channels: number of channels in the input tensor.
        out_channels: number of channels produced by the final 1x1 conv.

    Input is expected as ``(N, in_channels, H, W)``; skip features are
    concatenated along dim 1 (the channel axis of a batched 4-D tensor).

    Every encoder stage (``encoder``, ``middle``, ``bottleneck``) ends in
    ``MaxPool2d(2, 2)``, and the skip tensors are taken *after* that pooling,
    so the output has shape ``(N, out_channels, H // 2, W // 2)``.
    NOTE(review): a reference U-Net returns H x W (skips taken before pooling,
    no pooling after the bottleneck) -- confirm the halved output is intended.

    Because three 2x poolings are applied, H and W must each be at least 8.
    Sizes that are not multiples of 8 are supported: the upsampled tensors
    are zero-padded to match their skip tensors before concatenation.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        # Stage 1: in_channels -> 64 channels, then halve H and W.
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
        )
        # Stage 2: 64 -> 128 channels, then halve H and W.
        self.middle = nn.Sequential(
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
        )
        # Stage 3: 128 -> 256 channels, then halve H and W.
        self.bottleneck = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
        )
        # Upsample 2x (256 -> 128); concatenated with the 128-channel `middle`
        # output, giving the 256 input channels decoder1 expects.
        self.upconv1 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.decoder1 = nn.Sequential(
            nn.Conv2d(256, 128, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.ReLU(inplace=True),
        )
        # Upsample 2x (128 -> 64); concatenated with the 64-channel `encoder`
        # output, giving the 128 input channels decoder2 expects.
        self.upconv2 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.decoder2 = nn.Sequential(
            nn.Conv2d(128, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(inplace=True),
        )
        # 1x1 conv projecting 64 features to the requested output channels.
        self.output_conv = nn.Conv2d(64, out_channels, 1)

    @staticmethod
    def _match_size(up: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        """Zero-pad ``up`` on its last two dims so they equal ``skip``'s.

        MaxPool2d floors odd sizes, so a 2x transposed conv can come back one
        pixel short of the skip tensor; without this, torch.cat would raise.
        """
        dh = skip.size(-2) - up.size(-2)
        dw = skip.size(-1) - up.size(-1)
        if dh == 0 and dw == 0:
            return up
        # pad order for the last two dims is (left, right, top, bottom).
        return nn.functional.pad(
            up, [dw // 2, dw - dw // 2, dh // 2, dh - dh // 2]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the encoder/decoder.

        Args:
            x: input tensor, assumed ``(N, in_channels, H, W)`` with H, W >= 8.

        Returns:
            Tensor of shape ``(N, out_channels, H // 2, W // 2)``.
        """
        enc1 = self.encoder(x)  # 64 ch, H/2
        enc2 = self.middle(enc1)  # 128 ch, H/4
        bottleneck = self.bottleneck(enc2)  # 256 ch, H/8

        up1 = self._match_size(self.upconv1(bottleneck), enc2)
        dec1 = self.decoder1(torch.cat([enc2, up1], dim=1))

        up2 = self._match_size(self.upconv2(dec1), enc1)
        dec2 = self.decoder2(torch.cat([enc1, up2], dim=1))

        return self.output_conv(dec2)