I wrote code in python and used pytorch library to implement GAN which is an approach to generate pictures. Here is my code:

class Reshape(torch.nn.Module):
    def __init__(self, *shape):
        super().__init__()
        self.shape = shape
    def forward(self, x):
        return x.reshape(x.size(0), *self.shape)

Here is Generator:

class Generator(torch.nn.Module):
    def __init__(self, z_dim=64, num_channels=1):
        super().__init__()
        self.z_dim = z_dim

        self.net = nn.Sequential(
            nn.Linear(z_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, 64 * 7 * 7),
            nn.BatchNorm1d(64 * 7 * 7),
            nn.ReLU(),
            Reshape(64, 7, 7),
            nn.PixelShuffle(2),

            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.PixelShuffle(2),
            nn.Conv2d(in_channels=8, out_channels=1, kernel_size=3, padding=1)
        )
        
    def forward(self, z):
        return self.net(z)

Here is Discriminator:

class Discriminator(torch.nn.Module):
    def __init__(self, num_channels=1):
        super().__init__()

        self.net = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=4, padding=1, stride=2),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, padding=1, stride=2),
            nn.ReLU(),
            Reshape(64*7*7),
            nn.Linear(64*7*7, 512),
            nn.ReLU(),
            nn.Linear(512, 1),
            Reshape()
        )

    def forward(self, x):
        return self.net(x)

Here is code to calculate loss:

def loss_nonsaturating(d, g, x_real, *, device):
      '''
      Input Arguments:

      - x_real (torch.Tensor): training data samples (64, 1, 28, 28)
      - device (torch.device): 'cpu' by default

      Returns:
      - d_loss (torch.Tensor): nonsaturating discriminator loss
      - g_loss (torch.Tensor): nonsaturating generator loss
      '''

      z = torch.randn(x_real.shape[0], g.z_dim, device=device)
      gz = g(z)
      dgz = F.sigmoid(d(gz))
      dx = d(x_real)

      real_label = torch.ones(x_real.shape[0], device=device)
      fake_label = torch.zeros(x_real.shape[0], device=device)
      
      bce_loss = F.binary_cross_entropy_with_logits
      g_loss = bce_loss(dgz, real_label).mean()
      d_loss = bce_loss(dx, real_label).mean() + bce_loss(dgz, fake_label).mean()


      return d_loss, g_loss

Here is code to train model:

def build_input(x, y, device):
    x_real = x.to(device)
    y_real = y.to(device)
    return x_real, y_real

num_latents = 64
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
g = Generator(z_dim=64)
d = Discriminator()

g_optimizer = torch.optim.Adam(g.parameters(), lr=1e-3)
d_optimizer = torch.optim.Adam(d.parameters(), lr=1e-3)

iter_max = 1000

torch.autograd.set_detect_anomaly(True)

with tqdm(total=int(iter_max)) as pbar:
      for idx, (x, y) in enumerate(train_loader):
        x_real, y_real = build_input(x, y, device)

        g_loss, d_loss = loss_nonsaturating(d, g, x_real, device=device)


        d_optimizer.zero_grad()
        d_loss.backward(retain_graph=True)
        d_optimizer.step()

        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

I got this error:

0%|          | 0/1000 [00:00<?, ?it/s]/usr/local/lib/python3.10/dist-packages/torch/autograd/__init__.py:251: UserWarning: Error detected in AddmmBackward0. Traceback of forward call that caused the error:
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/usr/local/lib/python3.10/dist-packages/colab_kernel_launcher.py", line 37, in <module>
    ColabKernelApp.launch_instance()
  File "/usr/local/lib/python3.10/dist-packages/traitlets/config/application.py", line 992, in launch_instance
    app.start()
  File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelapp.py", line 619, in start
    self.io_loop.start()
  File "/usr/local/lib/python3.10/dist-packages/tornado/platform/asyncio.py", line 195, in start
    self.asyncio_loop.run_forever()
  File "/usr/lib/python3.10/asyncio/base_events.py", line 603, in run_forever
    self._run_once()
  File "/usr/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once
    handle._run()
  File "/usr/lib/python3.10/asyncio/events.py", line 80, in _run
    self._context.run(self._callback, *self._args)
  File "/usr/local/lib/python3.10/dist-packages/tornado/ioloop.py", line 685, in <lambda>
    lambda f: self._run_callback(functools.partial(callback, future))
  File "/usr/local/lib/python3.10/dist-packages/tornado/ioloop.py", line 738, in _run_callback
    ret = callback()
  File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 825, in inner
    self.ctx_run(self.run)
  File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 786, in run
    yielded = self.gen.send(value)
  File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py", line 377, in dispatch_queue
    yield self.process_one()
  File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 250, in wrapper
    runner = Runner(ctx_run, result, future, yielded)
  File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 748, in __init__
    self.ctx_run(self.run)
  File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 786, in run
    yielded = self.gen.send(value)
  File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py", line 361, in process_one
    yield gen.maybe_future(dispatch(*args))
  File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 234, in wrapper
    yielded = ctx_run(next, result)
  File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py", line 261, in dispatch_shell
    yield gen.maybe_future(handler(stream, idents, msg))
  File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 234, in wrapper
    yielded = ctx_run(next, result)
  File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py", line 539, in execute_request
    self.do_execute(
  File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 234, in wrapper
    yielded = ctx_run(next, result)
  File "/usr/local/lib/python3.10/dist-packages/ipykernel/ipkernel.py", line 302, in do_execute
    res = shell.run_cell(code, store_history=store_history, silent=silent)
  File "/usr/local/lib/python3.10/dist-packages/ipykernel/zmqshell.py", line 539, in run_cell
    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 2975, in run_cell
    result = self._run_cell(
  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3030, in _run_cell
    return runner(coro)
  File "/usr/local/lib/python3.10/dist-packages/IPython/core/async_helpers.py", line 78, in _pseudo_sync_runner
    coro.send(None)
  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3257, in run_cell_async
    has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3473, in run_ast_nodes
    if (await self.run_code(code, result,  async_=asy)):
  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3553, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-20-10c36497e22a>", line 17, in <cell line: 13>
    g_loss, d_loss = loss_nonsaturating(d, g, x_real, device=device)
  File "<ipython-input-18-c563e2132852>", line 16, in loss_nonsaturating
    dx = d(x_real)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "<ipython-input-7-e00f0cc91a93>", line 24, in forward
    return self.net(x)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/container.py", line 215, in forward
    input = module(input)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
 (Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:114.)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  0%|          | 0/1000 [00:00<?, ?it/s]
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-20-10c36497e22a> in <cell line: 13>()
     23 
     24         g_optimizer.zero_grad()
---> 25         g_loss.backward()
     26         g_optimizer.step()
     27 

1 frames
/usr/local/lib/python3.10/dist-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    249     # some Python versions print out the first line of a multi-line function
    250     # calls in the traceback and some print out the last line
--> 251     Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    252         tensors,
    253         grad_tensors_,

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [512, 1]], which is output 0 of AsStridedBackward0, is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

I did anything to solve the problem but did not help. I replace generator and discriminator backpropagation code but I got error again in the second one. Do you know what should I do?

1

There are 1 answers

0
Karl On BEST ANSWER

When training a GAN, you need to separate the loss computations for the generator and the discriminator. In particular, you don't want your discriminator loss to backprop into the generator. You want to do something like this:

with tqdm(total=int(iter_max)) as pbar:
      for idx, (x, y) in enumerate(train_loader):
        x_real, y_real = build_input(x, y, device)

        x_fake, y_fake = generator(...)

        g_loss = g_loss_function(descriminator(x_fake), y_real)

        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        d_loss = d_loss_function(descriminator(x_fake.detach(), y_fake)

        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

Using x_fake.detach() is the important part. That prevents the discriminator loss from backproping into the generator. It should also clear up the need to retain the graph.