1
\$\begingroup\$

I am a PhD student working on a machine learning project with binary classification and a ResNet architecture in TensorFlow. I believe I have done everything correctly, but I am looking for some validation that the code is correct, as I have no one to check my work. I am working in Google Colab with two programs: one to split the dataset and one to run the model. The first is splitdataset.ipynb and the second is classifyimages.ipynb. I know there are easier and other ways to do this, but this is how I implemented it. Some of the more important things to check would be how I implemented the brightness augmentations and how I split the dataset, but ideally a whole-code validation would be nice.

splitdataset.ipynb

import cv2
import os
import shutil
import random
%cd /content/drive/MyDrive/static_CTC_classification

# Rebuild the output tree from scratch on every run.
#
# The copy loops further down and classifyimages.ipynb both use "split_ds",
# so that is the tree that has to be cleared and recreated here. Leaving the
# files of an earlier (differently shuffled) run in place would let the same
# image end up in more than one of training/validation/testing.
shutil.rmtree("split_ds", ignore_errors=True)  # fine if it does not exist yet
for split_name in ("training", "validation", "testing"):
  for class_name in ("DU145", "PC3"):
    os.makedirs(os.path.join("split_ds", split_name, class_name), exist_ok=True)

# Collect the image paths of each class. os.listdir returns entries in
# arbitrary order, so the names are sorted to give the shuffle below a fixed
# starting point.
du145 = ["full_ds/DU145/" + name for name in sorted(os.listdir("full_ds/DU145"))]
pc3 = ["full_ds/PC3/" + name for name in sorted(os.listdir("full_ds/PC3"))]

# Shuffle both classes together with a fixed seed so that re-running the
# notebook reproduces exactly the same train/validation/test assignment.
# A private Random instance is used so the global random state is untouched.
SPLIT_SEED = 1337
images = du145 + pc3
random.Random(SPLIT_SEED).shuffle(images)

# 80 % training, 10 % validation, the remainder (roughly 10 %) testing.
# NOTE(review): the split is not stratified, so the DU145/PC3 ratio can differ
# between the three subsets, especially for a small dataset.
num_images = len(images)
train_num = int(0.8 * num_images)
val_num = int(0.1 * num_images)
print("train_num: ",train_num)
print("val_num: ",val_num)

train = images[:train_num]
val = images[train_num:train_num + val_num]
test = images[train_num + val_num:]


def _copy_split(paths, split_name):
  """Copy every image in *paths* into split_ds/<split_name>/<class>/.

  The class is taken from the name of the image's parent directory (for
  example "DU145" for "full_ds/DU145/img.png") and the file keeps its
  original name. Using os.path instead of fixed character offsets means the
  code no longer depends on the exact length of the "full_ds/<class>/" prefix.

  Args:
    paths: iterable of source image paths of the form <root>/<class>/<file>.
    split_name: name of the destination subset, e.g. "training".
  """
  created = set()
  for path in paths:
    class_name = os.path.basename(os.path.dirname(path))
    dest_dir = os.path.join("split_ds", split_name, class_name)
    if dest_dir not in created:
      # Create the destination on first use so the copy cannot fail just
      # because the directory tree was not prepared beforehand.
      os.makedirs(dest_dir, exist_ok=True)
      created.add(dest_dir)
    shutil.copyfile(path, os.path.join(dest_dir, os.path.basename(path)))


_copy_split(train, "training")
_copy_split(val, "validation")
_copy_split(test, "testing")

classifyimages.ipynb

import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint
import tensorflow_datasets as tfds
import numpy as np
import matplotlib.pyplot as plt

# Image height/width requested from the loaders and the number of images per batch.
image_size = (224, 224)
batch_size = 32

# Keyword arguments shared by the three directory loaders below.
_loader_kwargs = dict(
    seed=1337,
    color_mode='rgb',
    image_size=image_size,
    batch_size=batch_size,
)

# One dataset per split directory written by splitdataset.ipynb.
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    "/content/drive/MyDrive/static_CTC_classification/split_ds/training",
    **_loader_kwargs
)

val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    "/content/drive/MyDrive/static_CTC_classification/split_ds/validation",
    **_loader_kwargs
)

test_ds = tf.keras.preprocessing.image_dataset_from_directory(
    "/content/drive/MyDrive/static_CTC_classification/split_ds/testing",
    **_loader_kwargs
)

# from website: https://towardsdatascience.com/writing-a-custom-data-augmentation-layer-in-keras-2b53e048a98

class RandomColorDistortion(tf.keras.layers.Layer):
  """Randomly change contrast and brightness of a batch while training.

  One contrast factor and one brightness offset are drawn per call and applied
  to the whole batch; the result is then clipped to [0, 1]. Because of that
  clip, the layer should be fed images that are already scaled to [0, 1].
  When ``training`` is falsy the input is returned unchanged.

  Args:
    contrast_range: (low, high) bounds of the contrast factor.
    brightness_delta: (low, high) bounds of the offset added to every pixel.
    **kwargs: forwarded to ``tf.keras.layers.Layer``.
  """

  def __init__(self, contrast_range=(0.5, 1.5),
               brightness_delta=(-0.2, 0.2), **kwargs):
    super().__init__(**kwargs)
    # Stored as tuples so the layer never shares a mutable list with the caller.
    self.contrast_range = tuple(contrast_range)
    self.brightness_delta = tuple(brightness_delta)

  def call(self, images, training=None):
    if not training:
      return images

    # Draw the factors with TensorFlow ops rather than NumPy. A NumPy draw is
    # executed once, while the model is traced into a graph, and the value is
    # then baked in as a constant, so every training step would reuse the
    # same "random" contrast and brightness.
    contrast = tf.random.uniform(
        [], minval=self.contrast_range[0], maxval=self.contrast_range[1])
    brightness = tf.random.uniform(
        [], minval=self.brightness_delta[0], maxval=self.brightness_delta[1])

    images = tf.image.adjust_contrast(images, contrast)
    images = tf.image.adjust_brightness(images, brightness)
    return tf.clip_by_value(images, 0, 1)

  def get_config(self):
    """Return the constructor arguments so the layer can be re-created."""
    config = super().get_config()
    config.update({
        "contrast_range": list(self.contrast_range),
        "brightness_delta": list(self.brightness_delta),
    })
    return config

# Preprocessing applied inside the model.
#
# Rescaling to [0, 1] must come FIRST: RandomColorDistortion clips its output
# to [0, 1] and its brightness delta is expressed on that scale. With the
# rescaling at the end, the 0-255 pixels were clipped to [0, 1] during
# training (almost every pixel saturating to 1) and only then divided by 255,
# while at inference time the distortion layer is skipped and the network saw
# ordinary [0, 1] images - a train/inference mismatch.
augment_and_normalize = tf.keras.Sequential([
  tf.keras.layers.Rescaling(1.0 / 255),
  RandomColorDistortion(contrast_range=[0.5,1.5], brightness_delta=[-0.15, 0.15]),
  tf.keras.layers.RandomFlip("horizontal"),
  tf.keras.layers.RandomRotation(0.1),
], name="augment_and_normalize")


def make_model(input_shape, num_classes):
    """Build an ImageNet-initialised ResNet50 classifier.

    The input is passed through ``augment_and_normalize``, then through
    ResNet50 without its top, global average pooling and two
    Dense + Dropout(0.5) stages (1024 and 512 units).

    Args:
        input_shape: (height, width, channels) of the input images. It is also
            used as the input shape of the ResNet50 backbone.
        num_classes: number of target classes. For 2 (or fewer) the head is a
            single sigmoid unit, to be trained with a binary loss; for more
            classes it is a softmax over ``num_classes`` units, which needs a
            categorical loss instead.

    Returns:
        An uncompiled ``tf.keras.Model``.

    NOTE(review): the ImageNet ResNet50 weights are normally paired with
    ``tf.keras.applications.resnet.preprocess_input``; here the backbone
    receives whatever ``augment_and_normalize`` outputs - confirm this is
    intended.
    """
    inputs = tf.keras.Input(shape=input_shape)

    # Entry block
    x = augment_and_normalize(inputs)
    backbone = tf.keras.applications.resnet.ResNet50(input_shape=input_shape,
                                                     include_top=False,
                                                     weights='imagenet')
    x = backbone(x)
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    # Pooling already yields a flat vector, so this Flatten does nothing; it is
    # kept so the layer structure matches checkpoints saved earlier.
    x = tf.keras.layers.Flatten()(x)
    x = tf.keras.layers.Dense(1024, activation="relu")(x)
    x = tf.keras.layers.Dropout(0.5)(x)
    x = tf.keras.layers.Dense(512, activation="relu")(x)
    x = tf.keras.layers.Dropout(0.5)(x)

    if num_classes <= 2:
        units, activation = 1, "sigmoid"
    else:
        units, activation = num_classes, "softmax"
    outputs = tf.keras.layers.Dense(units, activation=activation, name="classification")(x)
    return tf.keras.Model(inputs, outputs)

# Build the two-class model for three-channel images of size image_size,
# then draw its layer graph.
model = make_model(num_classes=2, input_shape=(*image_size, 3))
tf.keras.utils.plot_model(model)


# Number of epochs passed to model.fit below.
epochs = 250

model.compile(
    loss="binary_crossentropy",
    optimizer=tf.keras.optimizers.Adam(),
    metrics=["accuracy"],
)

# Weights-only checkpoint, written only when val_accuracy improves; the epoch
# number becomes part of the file name.
best_weights_checkpoint = ModelCheckpoint(
    filepath="weights.{epoch:02d}.ckpt",
    monitor='val_accuracy',
    mode='auto',
    save_best_only=True,
    save_weights_only=True,
    save_freq='epoch',
    verbose=0,
    options=None,
)

history = model.fit(
    x=train_ds,
    validation_data=val_ds,
    epochs=epochs,
    callbacks=[best_weights_checkpoint],
)

import glob
import os
import shutil

# The checkpoint callback only writes a file when val_accuracy improves, so
# the best weights are those of the FIRST epoch that reached the maximum
# val_accuracy. Work that epoch out from the training history instead of
# hard-coding it; checkpoint file names count epochs starting at 1.
val_accuracy = history.history["val_accuracy"]
best_epoch = val_accuracy.index(max(val_accuracy)) + 1
best_prefix = "weights.{:02d}.ckpt".format(best_epoch)

drive_dir = "/content/drive/MyDrive/static_CTC_classification/"

# A weights checkpoint consists of several files sharing the same prefix
# (".index" and ".data-*"); copy all of them.
checkpoint_files = glob.glob(os.path.join("/content", best_prefix + ".*"))
if not checkpoint_files:
    raise FileNotFoundError("no checkpoint files found for /content/" + best_prefix)
for checkpoint_file in checkpoint_files:
    shutil.copy(checkpoint_file, drive_dir)
print("best epoch:", best_epoch, "- copied", best_prefix, "to", drive_dir)


# Then after saving these weights to my google drive I can load them later with the following command
# (in a new session, replace best_prefix with the checkpoint name printed above)

model.load_weights(drive_dir + best_prefix)
\$\endgroup\$

0

You must log in to answer this question.

Start asking to get answers

Find the answer to your question by asking.

Ask question

Explore related questions

See similar questions with these tags.