You only need to run this section if you are running the notebook in Google Colab.
# Colab-only bootstrap: clone the repo, install deps, download checkpoints.
# NOTE: the "!" and "%" lines are IPython/Colab magics and only run inside a notebook.
colab = False # set to True if running on Google Colab
if colab:
!git clone https://github.com/diffinfinite/diffinfinite.git
if colab:
%cd diffinfinite/
!pip install -q -r requirements.txt
!python download.py --n_classes=5
import sys
sys.path.append('/content/diffinfinite')  # make the Colab-cloned repo importable
import os
import sys
import yaml
from PIL import Image
import numpy as np
import torch
import torchvision.transforms as T
# Project-local modules from the diffinfinite repository.
from dataset import set_global_seed
from dm import Unet, GaussianDiffusion, Trainer
from random_diffusion import RandomDiffusion
from random_diffusion_masks import RandomDiffusionMasks
import importlib
import dm_masks
importlib.reload(dm_masks)  # pick up local edits to dm_masks without restarting the kernel
from dm_masks import Unet as MaskUnet
from dm_masks import GaussianDiffusion as MaskGD
from dm_masks import Trainer as MaskTrainer
from utils.mask_modules import CleanMask
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
os.environ["CUDA_VISIBLE_DEVICES"]="0"  # expose only the first GPU to torch
@torch.no_grad()
def sample(masks, cond_scale=3.0):
    """Generate images conditioned on class masks via the latent diffusion model.

    Parameters
    ----------
    masks : torch.Tensor
        Integer class masks of shape (batch, 1, H, W). H and W must be
        multiples of 8 (the VAE downsampling factor). Must live on the same
        device as the diffusion model (CUDA in this notebook).
    cond_scale : float
        Guidance strength (omega); forwarded to the EMA model as
        ``cond_scale + 1``.

    Returns
    -------
    torch.Tensor
        Decoded images of shape (batch, 3, H, W), clipped to [0, 1].

    Notes
    -----
    Relies on the module-level ``trainer`` built later in this notebook.
    Generalized from the original hard-coded 512x512 / 'cuda:0': latent size
    and device are now derived from ``masks``, which is equivalent for every
    call in this file (all masks are 512x512 CUDA tensors).
    """
    b, _, h, w = masks.shape
    # Constant-ones prior in the 4-channel VAE latent space at 1/8 resolution.
    z = torch.ones((b, 4, h // 8, w // 8), device=masks.device)
    # *50 rescales latents back to the VAE's expected range before decoding.
    z = trainer.ema.ema_model.sample(z, masks, cond_scale=cond_scale + 1) * 50
    return torch.clip(trainer.vae.decode(z).sample, 0, 1)
'''Build and load the low-resolution mask-generation diffusion model.'''
milestone = 5
config_path = './config/mask_gen_sample5.yaml'
# Use a distinct name for the file handle: the original reused `config_file`
# for both the path string and the open file object, shadowing the path.
with open(config_path, 'r') as f:
    config = yaml.safe_load(f)
# Flatten every config section (dim, num_classes, mask_size, ...) into
# module-level globals. NOTE(review): globals().update is fragile but kept
# because the constructors below read these names as notebook globals.
for section in config.values():
    globals().update(section)

maskunet = MaskUnet(
    dim=dim,
    num_classes=num_classes,
    dim_mults=dim_mults,
    channels=channels,
    resnet_block_groups=resnet_block_groups,
    block_per_layer=block_per_layer,
)
maskmodel = MaskGD(
    maskunet,
    image_size=mask_size // 8,  # masks are diffused at 1/8 latent resolution
    timesteps=timesteps,
    sampling_timesteps=sampling_timesteps,
    loss_type='l2')
masktrainer = MaskTrainer(
    maskmodel,
    train_batch_size=batch_size,
    train_lr=lr,
    train_num_steps=train_num_steps,
    save_and_sample_every=save_sample_every,
    gradient_accumulate_every=gradient_accumulate_every,
    save_loss_every=save_loss_every,
    num_samples=num_samples,
    num_workers=num_workers,
    results_folder=results_folder)
masktrainer.load(milestone)  # restore the checkpoint saved at this milestone
masktrainer.ema.cuda()
masktrainer.ema = masktrainer.ema.eval()  # EMA weights in eval mode for sampling
Fetching 13 files: 0%| | 0/13 [00:00<?, ?it/s]
'''Build and load the image-generation diffusion model.'''
milestone = 10
config_path = './config/image_gen_sample5.yaml'
# Use a distinct name for the file handle: the original reused `config_file`
# for both the path string and the open file object, shadowing the path.
with open(config_path, 'r') as f:
    config = yaml.safe_load(f)
# Flatten every config section (dim, num_classes, image_size, ...) into
# module-level globals. NOTE(review): globals().update is fragile but kept
# because the constructors below read these names as notebook globals.
for section in config.values():
    globals().update(section)

unet = Unet(
    dim=dim,
    num_classes=num_classes,
    dim_mults=dim_mults,
    channels=channels,
    resnet_block_groups=resnet_block_groups,
    block_per_layer=block_per_layer,
)
model = GaussianDiffusion(
    unet,
    image_size=image_size // 8,  # images are diffused at 1/8 latent resolution
    timesteps=timesteps,
    sampling_timesteps=sampling_timesteps,
    loss_type='l2')
trainer = Trainer(
    model,
    train_batch_size=batch_size,
    train_lr=lr,
    train_num_steps=train_num_steps,
    save_and_sample_every=save_sample_every,
    gradient_accumulate_every=gradient_accumulate_every,
    save_loss_every=save_loss_every,
    num_samples=num_samples,
    num_workers=num_workers,
    results_folder=results_folder)
trainer.load(milestone)  # restore the checkpoint saved at this milestone
trainer.ema.cuda()
trainer.ema = trainer.ema.eval()  # EMA weights in eval mode for sampling
Fetching 13 files: 0%| | 0/13 [00:00<?, ?it/s]
Visualising class generation in sequence. This toy example demonstrates DiffInfinite's mask-to-image control. The output has no biological meaning and is not accurate, because the generated class depends on the mask's shape and its neighbours; the mask-generation step exists precisely to provide this spatial information.
'''Strides Generation'''
# One vertical stripe per class to visualise mask->image control.
classes_names = ['Carcinoma', 'Necrosis', 'Tumor Stroma', 'Others']
n_stripes = len(classes_names)
stripe = 512 // n_stripes  # stripe width in pixels
test_masks = torch.ones((1, 1, 512, 512), device='cuda:0').int()
# Paint stripe i with class label i+1 (label 0 is unused background).
for i in range(n_stripes):
    test_masks[:, :, :, i * stripe:(i + 1) * stripe] = i + 1
outputs = sample(test_masks, cond_scale=3.0)
# Centre one x-tick under each stripe (was a hard-coded range(4); now follows
# len(classes_names) so the two stay consistent).
xtick_positions = [stripe // 2 + stripe * i for i in range(n_stripes)]
fig, ax = plt.subplots(ncols=2)
ax[0].imshow(test_masks[0, 0].cpu())
ax[1].imshow(outputs[0].permute(1, 2, 0).cpu())
ax[1].set_xticks(xtick_positions, classes_names, rotation=45)
plt.show()
'''Squares Generation'''
# Split the mask into four quadrants, one class label per quadrant.
half = 512 // 2
test_masks = torch.ones((1, 1, 512, 512), device='cuda:0').int()
quadrants = [
    (slice(None, half), slice(None, half)),  # top-left     -> class 1
    (slice(None, half), slice(half, None)),  # top-right    -> class 2
    (slice(half, None), slice(None, half)),  # bottom-left  -> class 3
    (slice(half, None), slice(half, None)),  # bottom-right -> class 4
]
for label, (rows, cols) in enumerate(quadrants, start=1):
    test_masks[:, :, rows, cols] = label
outputs = sample(test_masks, cond_scale=3.0)
fig, ax = plt.subplots(ncols=2)
ax[0].imshow(test_masks[0, 0].cpu())
ax[1].imshow(outputs[0].permute(1, 2, 0).cpu())
plt.show()
We initially create a synthetic mask based on a contextual prompt, and then we use that mask as a guide to generate the corresponding image.
'''Mask Generation'''
# Sample synthetic masks from the mask diffusion model.
n_batches = 2
batch_size = 2
# 0 -> 'adenocarcinoma'; 1 -> 'squamouscarcinoma'; None -> use both labels
sample_label = 1
masks = masktrainer.sample_loop(
    n_batches, batch_size=batch_size, sample_label=sample_label, save_sample=False)
# Collapse any leading batch dimensions down to (N, 1, H, W).
h, w = masks.shape[-2], masks.shape[-1]
masks = masks.reshape(-1, 1, h, w)
generate 4 masks using device=cuda Sampling complete.
'''Image Generation'''
outputs = sample(masks.cuda(), cond_scale=3.0)

'''Visualise Synthetic Mask-Image pairs'''
# Fixed per-class colours; the colormap is rebuilt per mask so only the
# classes actually present are mapped.
colors = {0:'#000000', 1:'#3b4cc0', 2:'#518abf', 3:'#62b7bf', 4:'#71e3bd'}
for idx in range(len(outputs)):
    fig, ax = plt.subplots(ncols=2)
    present = torch.unique(masks[idx])
    cmap = ListedColormap([colors[c.cpu().item()] for c in present])
    ax[0].imshow(outputs[idx].permute(1, 2, 0).cpu().detach())
    im = ax[1].imshow(masks[idx, 0].cpu().detach(), cmap=cmap)
'''Image Generation for different omegas'''
# Regenerate the same mask with increasing guidance strength; reseeding
# before each call keeps the diffusion noise identical across runs.
cond_scales = [0.0, 3.0, 6.0]
generated = []
for omega in cond_scales:
    set_global_seed(10)
    generated.append(sample(masks[0][None].cuda(), cond_scale=omega))
outputs = torch.cat(generated, 0)
'''Visualise Synthetic Mask-Image pairs with different omegas'''
colors = {0:'#000000', 1:'#3b4cc0', 2:'#518abf', 3:'#62b7bf', 4:'#71e3bd'}
fig, ax = plt.subplots(ncols=len(cond_scales) + 1, figsize=(10, 10))
cmap = ListedColormap([colors[c.cpu().item()] for c in torch.unique(masks[0])])
# First panel: the conditioning mask.
ax[0].imshow(masks[0, 0].cpu().detach(), cmap=cmap)
ax[0].set_xticks([])
ax[0].set_yticks([])
ax[0].set_title('Mask')
# Remaining panels: one generated image per guidance scale.
for idx, omega in enumerate(cond_scales):
    panel = ax[idx + 1]
    panel.imshow(outputs[idx].permute(1, 2, 0).cpu().detach())
    panel.set_xticks([])
    panel.set_yticks([])
    panel.set_title(f'omega={omega}')