I wrote code in python and used pytorch library to implement GAN which is an approach to generate pictures. Here is my code:
class Reshape(torch.nn.Module):
def __init__(self, *shape):
super().__init__()
self.shape = shape
def forward(self, x):
return x.reshape(x.size(0), *self.shape)
Here is Generator:
class Generator(torch.nn.Module):
def __init__(self, z_dim=64, num_channels=1):
super().__init__()
self.z_dim = z_dim
self.net = nn.Sequential(
nn.Linear(z_dim, 512),
nn.BatchNorm1d(512),
nn.ReLU(),
nn.Linear(512, 64 * 7 * 7),
nn.BatchNorm1d(64 * 7 * 7),
nn.ReLU(),
Reshape(64, 7, 7),
nn.PixelShuffle(2),
nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1),
nn.BatchNorm2d(32),
nn.ReLU(),
nn.PixelShuffle(2),
nn.Conv2d(in_channels=8, out_channels=1, kernel_size=3, padding=1)
)
def forward(self, z):
return self.net(z)
Here is Discriminator:
class Discriminator(torch.nn.Module):
def __init__(self, num_channels=1):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(in_channels=1, out_channels=32, kernel_size=4, padding=1, stride=2),
nn.ReLU(),
nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, padding=1, stride=2),
nn.ReLU(),
Reshape(64*7*7),
nn.Linear(64*7*7, 512),
nn.ReLU(),
nn.Linear(512, 1),
Reshape()
)
def forward(self, x):
return self.net(x)
Here is code to calculate loss:
def loss_nonsaturating(d, g, x_real, *, device):
'''
Input Arguments:
- x_real (torch.Tensor): training data samples (64, 1, 28, 28)
- device (torch.device): 'cpu' by default
Returns:
- d_loss (torch.Tensor): nonsaturating discriminator loss
- g_loss (torch.Tensor): nonsaturating generator loss
'''
z = torch.randn(x_real.shape[0], g.z_dim, device=device)
gz = g(z)
dgz = F.sigmoid(d(gz))
dx = d(x_real)
real_label = torch.ones(x_real.shape[0], device=device)
fake_label = torch.zeros(x_real.shape[0], device=device)
bce_loss = F.binary_cross_entropy_with_logits
g_loss = bce_loss(dgz, real_label).mean()
d_loss = bce_loss(dx, real_label).mean() + bce_loss(dgz, fake_label).mean()
return d_loss, g_loss
Here is code to train model:
def build_input(x, y, device):
x_real = x.to(device)
y_real = y.to(device)
return x_real, y_real
num_latents = 64
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
g = Generator(z_dim=64)
d = Discriminator()
g_optimizer = torch.optim.Adam(g.parameters(), lr=1e-3)
d_optimizer = torch.optim.Adam(d.parameters(), lr=1e-3)
iter_max = 1000
torch.autograd.set_detect_anomaly(True)
with tqdm(total=int(iter_max)) as pbar:
for idx, (x, y) in enumerate(train_loader):
x_real, y_real = build_input(x, y, device)
g_loss, d_loss = loss_nonsaturating(d, g, x_real, device=device)
d_optimizer.zero_grad()
d_loss.backward(retain_graph=True)
d_optimizer.step()
g_optimizer.zero_grad()
g_loss.backward()
g_optimizer.step()
I got this error:
0%| | 0/1000 [00:00<?, ?it/s]/usr/local/lib/python3.10/dist-packages/torch/autograd/__init__.py:251: UserWarning: Error detected in AddmmBackward0. Traceback of forward call that caused the error:
File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/usr/local/lib/python3.10/dist-packages/colab_kernel_launcher.py", line 37, in <module>
ColabKernelApp.launch_instance()
File "/usr/local/lib/python3.10/dist-packages/traitlets/config/application.py", line 992, in launch_instance
app.start()
File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelapp.py", line 619, in start
self.io_loop.start()
File "/usr/local/lib/python3.10/dist-packages/tornado/platform/asyncio.py", line 195, in start
self.asyncio_loop.run_forever()
File "/usr/lib/python3.10/asyncio/base_events.py", line 603, in run_forever
self._run_once()
File "/usr/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once
handle._run()
File "/usr/lib/python3.10/asyncio/events.py", line 80, in _run
self._context.run(self._callback, *self._args)
File "/usr/local/lib/python3.10/dist-packages/tornado/ioloop.py", line 685, in <lambda>
lambda f: self._run_callback(functools.partial(callback, future))
File "/usr/local/lib/python3.10/dist-packages/tornado/ioloop.py", line 738, in _run_callback
ret = callback()
File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 825, in inner
self.ctx_run(self.run)
File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 786, in run
yielded = self.gen.send(value)
File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py", line 377, in dispatch_queue
yield self.process_one()
File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 250, in wrapper
runner = Runner(ctx_run, result, future, yielded)
File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 748, in __init__
self.ctx_run(self.run)
File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 786, in run
yielded = self.gen.send(value)
File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py", line 361, in process_one
yield gen.maybe_future(dispatch(*args))
File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 234, in wrapper
yielded = ctx_run(next, result)
File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py", line 261, in dispatch_shell
yield gen.maybe_future(handler(stream, idents, msg))
File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 234, in wrapper
yielded = ctx_run(next, result)
File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py", line 539, in execute_request
self.do_execute(
File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 234, in wrapper
yielded = ctx_run(next, result)
File "/usr/local/lib/python3.10/dist-packages/ipykernel/ipkernel.py", line 302, in do_execute
res = shell.run_cell(code, store_history=store_history, silent=silent)
File "/usr/local/lib/python3.10/dist-packages/ipykernel/zmqshell.py", line 539, in run_cell
return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 2975, in run_cell
result = self._run_cell(
File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3030, in _run_cell
return runner(coro)
File "/usr/local/lib/python3.10/dist-packages/IPython/core/async_helpers.py", line 78, in _pseudo_sync_runner
coro.send(None)
File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3257, in run_cell_async
has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3473, in run_ast_nodes
if (await self.run_code(code, result, async_=asy)):
File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3553, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-20-10c36497e22a>", line 17, in <cell line: 13>
g_loss, d_loss = loss_nonsaturating(d, g, x_real, device=device)
File "<ipython-input-18-c563e2132852>", line 16, in loss_nonsaturating
dx = d(x_real)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "<ipython-input-7-e00f0cc91a93>", line 24, in forward
return self.net(x)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/container.py", line 215, in forward
input = module(input)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py", line 114, in forward
return F.linear(input, self.weight, self.bias)
(Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:114.)
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
0%| | 0/1000 [00:00<?, ?it/s]
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-20-10c36497e22a> in <cell line: 13>()
23
24 g_optimizer.zero_grad()
---> 25 g_loss.backward()
26 g_optimizer.step()
27
1 frames
/usr/local/lib/python3.10/dist-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
249 # some Python versions print out the first line of a multi-line function
250 # calls in the traceback and some print out the last line
--> 251 Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
252 tensors,
253 grad_tensors_,
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [512, 1]], which is output 0 of AsStridedBackward0, is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
I did anything to solve the problem but did not help. I replace generator and discriminator backpropagation code but I got error again in the second one. Do you know what should I do?
When training a GAN, you need to separate the loss computations for the generator and the discriminator. In particular, you don't want your discriminator loss to backprop into the generator. You want to do something like this:
Using
x_fake.detach()
is the important part. That prevents the discriminator loss from backproping into the generator. It should also clear up the need to retain the graph.