fbpx

Written by 1:01 pm Tutoriale

Wykrywanie COVID-19 na obrazach tomografii komputerowej

Standardową metodą wykrywania wirusa SARS-CoV-2 jest reakcja łańcuchowa polimerazy z odwrotną transkrypcją (RT-PCR). Jednak w zależności od fazy choroby może dawać wyniki fałszywie negatywne. Uzupełniającą metodą w diagnostyce koronawirusa jest tomografia komputerowa, która według niektórych badań ma wyższą czułość niż RT-PCR.

COVID-19 na obrazach tomografii komputerowej może ujawniać się w podobny sposób co zapalenie płuc wywołane innymi wirusami. Posiada jednak również cechy charakterystyczne, co pozwala na odróżnienie od niektórych innych chorób układu oddechowego, chociaż nie zawsze jest to łatwe, na przykład w przypadku grypy.

W tym wpisie zajmiemy się analizą obrazów tomografii komputerowej klatki piersiowej oraz ich klasyfikacją do jednej z trzech klas. Porównanych zostanie kilka architektur. Mimo że COVID-19 może powodować również inne zmiany widoczne na obrazach CT, np. zakrzepicę, skupimy się głównie na zmianach patologicznych w płucach, gdyż to one są podstawą do diagnozy w oparciu o obraz z tomografu komputerowego (wiele prac bazuje właśnie na wysegmentowanych płucach do klasyfikacji/segmentacji zmian patologicznych, o czym można przeczytać m.in. w tym artykule przeglądowym). Spróbujemy również odpowiedzieć na pytanie, czy transfer learning z użyciem modeli pretrenowanych na ImageNet ma sens w przypadku obrazów medycznych.

Projekt można znaleźć na GitHubie.

Analiza zbioru

Zbiór danych znajduje się na GitHubie i jest już podzielony na zbiór treningowy, walidacyjny i testowy. Do zbioru testowego na razie nie zaglądamy i go nie analizujemy, aby uniknąć obciążenia związanego z podglądaniem danych (data snooping bias) i nie próbować dostosowywać skuteczności modelu specjalnie pod ten zbiór (choćby nieświadomie), co utrudniłoby generalizację na obrazy spoza zbioru.

Niektóre obrazy są w formacie jpg, niektóre png. Obrazów w zbiorze treningowym jest 4404, natomiast walidacyjnym 1341. Większość obrazów ma rozdzielczość 512 x 512 px, jednak zdarzają się również obrazy większe.

Liczba obrazów o danej rozdzielczości.
512 x 512 Train: 3832
512 x 512 Val: 1257
768 x 768 Train: 391
768 x 768 Val: 24
1024 x 1024 Train: 181
1024 x 1024 Val: 60

Są trzy klasy:

  • COV – pacjent chory na COVID-19,
  • Normal – zdrowy pacjent,
  • OtherPneumonia – zapalenie płuc innego pochodzenia.

Zbiór nie jest zbalansowany – obrazów COV jest znacznie więcej niż pozostałych.

Liczebność zbiorów treningowego i walidacyjnego z podziałem na klasy.
COV Train: 2448
COV Val: 908
Normal Train: 1413
Normal Val: 388
OtherPneumonia Train: 543
OtherPneumonia Val: 45

Niestety nie mamy więcej informacji na temat zbioru, na przykład czy obrazy pochodzą z jednego skanera, ilu pacjentów jest w zbiorze, jacy są to pacjenci, jakie dobrano okno do prezentacji skali szarości (ang. windowing). Brak tych oraz innych szczegółów na temat zbioru wpływa na jakość analizy, a także na interpretowalność uzyskanych modeli. Należy to mieć na uwadze przy projektowaniu i tworzeniu rozwiązań z użyciem deep learningu dla medycyny (pomocna w tym jest ta checklista).

Wizualizacje

Zwizualizujmy po kilka przykładów z każdej klasy:

Przykłady klasy COV
COV
Przykłady klasy Normal
Normal
Przykłady klasy OtherPneumonia
OtherPneumonia

Na podstawie przykładowych obrazów widać, że w zbiorze może znajdować się więcej niż jeden przekrój od danego pacjenta. Widzimy również, że na zdjęciach oprócz pacjenta może znajdować się stół, a także że klatka piersiowa może zajmować mniej lub więcej powierzchni (tj. może być więcej lub mniej tła, które nie jest istotne przy klasyfikacji) i nie zawsze znajdować się dokładnie na środku obrazu.

Średni obraz

Obrazy w folderach Train oraz Val znajdują się w podkatalogach. Zobaczmy, czy poszczególne obrazy różnią się bardzo między sobą, wizualizując średni obraz z każdego podkatalogu. Poniżej znajdują się wizualizacje średnich obrazów zbioru treningowego i walidacyjnego. Poszczególne tytuły oprócz nazwy klasy zawierają nazwę katalogu oraz ile obrazów w danym katalogu się znajduje.

Średnie obrazy poszczególnych katalogów klasy COV - zbiór treningowy
Zbiór treningowy – obraz średni podkatalogów COV
Średnie obrazy poszczególnych katalogów klasy COV - zbiór walidacyjny
Zbiór walidaycjny – obraz średni podkatalogów COV
Średnie obrazy poszczególnych katalogów klasy Normal - zbiór treningowy
Zbiór treningowy – obraz średni podkatalogów Normal
Średnie obrazy poszczególnych katalogów klasy Normal - zbiór walidacyjny
Zbiór walidacyjny – obraz średni podkatalogów Normal
Średnie obrazy poszczególnych katalogów klasy OtherPneumonia - zbiór treningowy
Zbiór treningowy – obraz średni podkatalogów OtherPneumonia
Średnie obrazy poszczególnych katalogów klasy OtherPneumonia - zbiór walidacyjny
Zbiór walidacyjny – obraz średni katalogu OtherPneumonia

W obrębie danego podkatalogu obrazy nie zawsze są do siebie podobne pod względem wielkości klatki piersiowej oraz położenia na obrazie. Warto to mieć na uwadze podczas wstępnego przetwarzania obrazów. Przykładowo: jeśli chcielibyśmy usunąć z obrazu stół lub część tła, poszczególne obrazy wymagałyby różnych parametrów przycięcia (crop).

Zobaczmy również obraz średni ze wszystkich obrazów dla poszczególnych klas:

Średni obraz każdej klasy - zbiór treningowy
Zbiór treningowy
Średni obraz każdej klasy - zbiór walidacyjny
Zbiór walidacyjny

Przetwarzanie obrazów

Obrazy ze skanera CT zawierają szum, który nie jest istotny przy klasyfikacji, a wręcz może pogarszać wynik. Jak zobaczyliśmy na przykładowych obrazach takimi niepożądanymi elementami są tło oraz stół, na którym leży badany pacjent. W przypadku klasyfikacji na podstawie zmian w płucach właściwie wszystko, co nie jest płucami, należałoby usunąć z obrazu.

Wstępne przetwarzanie obrazów w tym projekcie obejmuje znalezienie konturów ciała pacjenta, usunięcie zaszumionego tła oraz stołu z obrazu, a następnie znalezienie na obrazie płuc. Przed podaniem do modelu rozmiar zdjęć jest ujednolicany, a wartości pikseli są normalizowane do zakresu 0-1.

W celu sprawdzenia, jak wstępne przetwarzanie obrazów wpływa na wyniki, modele trenowano na: obrazach oryginalnych (512 x 512 px), obrazach z wyciętym tłem i stołem (512 x 512 px), obrazach przyciętych do pacjenta (512 x 328 px), obrazach z wysegmentowanymi płucami (512 x 512 px) oraz obrazach przyciętych do płuc (330 x 256 px).

Usunięcie szumu

Do usunięcia szumów z obrazu została wykorzystana biblioteka OpenCV.

Pierwszym krokiem jest zbinaryzowanie obrazu:

retval, img_thresh = cv2.threshold(img, thresh=150, maxval=255, type=cv2.THRESH_BINARY)
Obraz po progowaniu
img_thresh

Do zbinaryzowanego obrazu dodawana jest ramka (ułatwi znalezienie największego obiektu – pacjenta – w przypadku, gdy pacjent będzie znajdował się przy krawędzi obrazu), kolory są odwracane, a następnie są znajdywane krawędzie:

img_border = cv2.copyMakeBorder(img_thresh, top, bottom, left, right, cv2.BORDER_CONSTANT, None, 0)
img_border_a = cv2.bitwise_not(img_border)
img_edges = cv2.Canny(img_border_a, 240, 255)
Obraz ze znalezionymi krawędziami
img_edges

Aby domknąć krawędzie, robimy dylatację, po czym znajdujemy najdłuższy kontur:

kernel = np.ones((3, 3), np.uint8)
img_dilated = cv2.dilate(img_edges, kernel, 1)
contours, hierarchy = cv2.findContours(img_dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
max_index, max_area = max(enumerate([cv2.arcLength(contour, True) for contour in contours]), key=lambda x: x[1])
max_contour = contours[max_index]
max_contour = max_contour - top
img_body_contour = cv2.drawContours(img.copy(), [max_contour], 0, (0, 255, 0), 2)
Obraz z zaznaczonym konturem ciała pacjenta
img_body_contour

Na podstawie znalezionego konturu, tworzymy maskę i zostawiamy na obrazie tylko pacjenta:

mask = np.zeros((img.shape[0], img.shape[1]), dtype=np.uint8)
mask = cv2.fillPoly(mask, [max_contour], (1.0, 1.0, 1.0))
img_masked = cv2.bitwise_and(img, img, mask=mask)
Obraz z usuniętym tłem
img_masked

Usunięcie szumu – napotkane problemy

Niestety przedstawione kroki nie są wystarczające w każdym przypadku. Zdarzają się sytuacje, w których najdłuższy kontur obejmuje również stół lub najdłuższy kontur ma płuco albo stół:

Przykładowe obrazy, przy których był problem z prawidłowym zaznaczeniem konturu

1. W przypadku, gdy najdłuższy kontur ma płuco, rozwiązaniem jest sprawdzenie, czy kontur wypukły (convex hull) jest dłuższy niż zwykły kontur:

if_contour = 1
max_contour = contours[0]
max_hull = cv2.convexHull(contours[0])
max_hull_perimeter = cv2.arcLength(max_hull, True)
    
for contour in contours:
    hull = cv2.convexHull(contour)
      
    if cv2.arcLength(contour, True) > cv2.arcLength(max_contour, True):
        max_contour = contour
        if_contour = 1
        
    elif cv2.arcLength(hull, True) > max_hull_perimeter:
        max_hull_perimeter = cv2.arcLength(hull, True)
        max_hull = hull
        max_contour = contour
        if_contour = 0

2. Z kolei pozbycie się stołu można uzyskać przez zastosowanie erozji na zbinaryzowanym obrazie z ramką:

kernel = np.ones((3, 3), np.uint8)
img_border_a = cv2.erode(img_border, kernel, 10)

Poprawione kontury wyglądają teraz odpowiednio:

Poprawione kontury obrazów, z którymi wcześniej był problem

3. Innym napotkanym problemem było to, że czasem pacjent znajdował się zbyt blisko górnej krawędzi, co powodowało, że płuca nie znajdowały się wewnątrz znalezionego konturu:

Przykład obrazu, w którym pacjent jest zbyt blisko krawędzi i w efekcie razem z tłem są usuwane również płuca

Udało się to naprawić, poprawiając maskę poprzez jej „zamknięcie” przy górnej krawędzi i użycie floodFill:

non_zero_indices = np.argwhere(mask[0,:])

if non_zero_indices.size:
    mask[0:5,int(non_zero_indices[0]):int(non_zero_indices[-1])] = 1
    mask_border = cv2.copyMakeBorder(mask, top, bottom, left, right, cv2.BORDER_CONSTANT, None, 0)

    mask_fill = np.zeros((mask_border.shape[0]+2, mask_border.shape[1]+2), dtype=np.uint8)
    cv2.floodFill(mask_border, mask_fill, (0,0), 255)
    mask_fill = cv2.bitwise_not(mask_fill)

    mask = mask_fill[top:mask_fill.shape[0]-bottom-2, left:mask_fill.shape[1]-right-2]
    mask = mask - np.min(mask)
Obraz z usuniętym tłem po poprawieniu maski

Segmentacja płuc

Kolejnym etapem przetwarzania wstępnego obrazów jest znalezienie płuc na obrazie przy użyciu modułu segmentation biblioteki scikit-image. Ponieważ mamy już przygotowane obrazy bez tła, wykorzystamy je w tym kroku.

Płuca znajdujemy, wykorzystując funkcję active_contour, która dopasowywuje początkowy kształt (snake) – tutaj okrąg – do cech obrazu. Na podstawie znalezionego konturu tworzymy maskę i zostawiamy na obrazie jedynie płuca.

def snake_contour(path, num_points=500, radius=250, alpha=0.1, beta=10, gamma=0.001, max_iterations=2500, w_line=1, w_edge=1, center=False, title=None, show_steps=False):
    img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
 
    non_zero_h = np.argwhere(img[:,int(img.shape[0]//2)])
    center_h = (non_zero_h[0] + non_zero_h[-1]) // 2
    non_zero_w = np.argwhere(img[int(img.shape[1]//2),:])
    center_w = (non_zero_w[0] + non_zero_w[-1]) // 2
    
    img = img_as_float(img)
    
    s = np.linspace(0, 2*np.pi, num_points)
    r = center_h + radius*np.sin(s)
    c = center_w + radius*np.cos(s)
    init = np.array([r, c]).T

    snake = active_contour(gaussian(img, 3), init, alpha=alpha, beta=beta, gamma=gamma, coordinates='rc', max_iterations=max_iterations, w_line=w_line, w_edge=w_edge)
    
    snake = np.asarray(snake, dtype=np.int32)
    snake_contour = np.dstack((snake[:,1], snake[:,0]))
    
    mask = np.zeros((img.shape[0], img.shape[1]), dtype=np.uint8)
    mask = cv2.fillPoly(mask, [snake_contour], (1.0, 1.0, 1.0))
    img_masked = cv2.bitwise_and(img, img, mask=mask)

    return img_masked * 255
Przedstawienie działania funkcji active contour i wynikowy obraz z wysegmentowanymi płucami

Segmentacja płuc – napotkane problemy

Zdarzyły się przypadki, w których snake nie od razu poradził sobie ze znalezieniem właściwego obszaru.

1. Gdy pacjent znajdował się zbyt blisko krawędzi, należało przesunąć początkowy okrąg tak, aby jego środek znajdował się na środku pacjenta:

Przykład obrazu, dla którego active contour nie obejmuje poprawnie płuc i poprawiony obraz po przesunięciu początkowego okręgu

2. Innym problemem były ręce znajdujące się na obrazie. W takim przypadku pomogło zwiększenie parametru alpha funkcji active_contour z 0.1 na 0.3:

Przykład obrazu, dla którego active contour nie obejmuje poprawnie płuc i poprawiony obraz po zwiększeniu parametru alpha

3. W niektórych przypadkach okazało się, że początkowa liczba punktów w okręgu jest za mała, aby objąć oba płuca, dlatego została zwiększona z 400 na 500.

Przykład obrazu, dla którego active contour nie obejmuje płuc i poprawiony obraz po zwiększeniu liczby punktów w początkowym okręgu

Niestety nadal może się zdarzyć, że kontur, poza płucami, otoczy również inne struktury tak jak w powyższym przykładzie lub inne, na przykład pętlę jelita. Jest to związane z wybraną metodą przetwarzania obrazów. Aby tego uniknąć, można by było zastosować model do segmentacji płuc, który powinien precyzyjniej poradzić sobie z tym zadaniem.

Data augmentation

Dla każdej grupy obrazów wymienionych na początku sekcji Przetwarzanie obrazów zastosowano te same techniki data augmentation: losowe odbicie poziome, losowy obrót, losowe przesunięcie oraz losowe przybliżenie, jednak w różnym zakresie. Przykładowo dla obrazów oryginalnych, na których płuca znajdują się na środku i stosunkowo daleko od krawędzi można zastosować większe przesunięcie niż w przypadku obrazów przyciętych do płuc.

ObrazyHorizontal FlipRotationTranslationZoom
Oryginalne0.10.1-0.2
Z usuniętym tłem0.10.1-0.2
Z usuniętym tłem – przycięte0.070.07-0.1
Płuca z usuniętym tłem0.10.15-0.2
Płuca z usuniętym tłem – przycięte 0.050.05-0.1
Data augmentation dla poszczególnych obrazów. Liczby odpowiadają wartościom podawanym do warstw tf.keras.experimenatl.preprocessing

W zależności od modelu data augmentation uzyskano poprzez dodanie do modelu warstwy data augmentatio składającej się z warstw tf.keras.layers.experimental.preprocessing:

data_augmentation = tf.keras.Sequential([layers.experimental.preprocessing.RandomFlip('horizontal', seed=seed, input_shape=(img_height, img_width, num_channels)),
                                         layers.experimental.preprocessing.RandomRotation(rotation, seed=seed, fill_mode='constant'),
                                         layers.experimental.preprocessing.RandomZoom(zoom, seed=seed, fill_mode='constant'),
                                         layers.experimental.preprocessing.RandomTranslation(translation, translation, seed=seed, fill_mode='constant')])

model = Sequential([data_augmentation,
                    model_base])

lub używając ImageDataGenerator z Kerasa:

datagen_train = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255,
                                                                horizontal_flip=True,
                                                                rotation_range=rotation,
                                                                width_shift_range=translation,
                                                                height_shift_range=translation,
                                                                fill_mode='constant',
                                                                zoom_range=zoom)

train_generator = datagen_train.flow_from_directory(PATH_TRAIN,
                                                    classes=class_names,
                                                    color_mode=color_mode,
                                                    target_size=(img_height, img_width),
                                                    seed=seed,
                                                    batch_size=batch_size,
                                                    class_mode='categorical')
6 przykładowych obrazów po data augmentation
Przykłady obrazów przyciętych do samych płuc po data augmentation

Modele

W ramach projektu sprawdzono kilka architektur. Część modeli była trenowana od zera, a w niektórych przypadkach zastosowano fine-tuning jedynie warstwy klasyfikującej. Jak wspomniano powyżej skuteczność modeli sprawdzono na pięciu rodzajach obrazów oraz z i bez data augmentation. Sprawdzono również, jak na wyniki wpływa ważenie klas (ang. class weight) – jest to jeden ze sposobów radzenia sobie z niezbalansowanym zbiorem danych.

Sprawdzone konfiguracje obejmują (w nawiasach są podane oznaczenia używane w dalszej części przy prezentacji wyników):

Modele:

  • Simple – prosty model
  • Tiny
  • Small
  • LargeW
  • LargeT
  • ResNet-50
  • EfficientNet B3 pretrenowany na ImageNet-1k
  • EfficientNet B3 trenowany od zera

Obrazy:

  • oryginalne (original)
  • z usuniętym tłem (nobackg)
  • z usuniętym tłem, przycięte do pacjenta (crop)
  • z wysegmentowanymi płucami (lungs-nocrop)
  • przycięte do wysegmentowanych płuc (lungs)
Przykłady obrazów na poszczególnych etapach preprocessingu

Inne parametry i techniki:

  • data augmentation (dataaug)
  • class weight (classw)
  • brak powyższych (baseline)

Prosty model

Dla sprawdzenia, jak bardzo poprawiają (lub pogarszają) wyniki zaawansowane architektury wytrenowany został również bardzo prosty model:

model = Sequential([normalization_layer,
                    Conv2D(16, 3, activation='relu', input_shape=(img_height, img_width, num_channels)),
                    MaxPooling2D(),

                    Conv2D(32, 3, activation='relu'),
                    MaxPooling2D(),

                    Conv2D(64, 3, activation='relu'),
                    MaxPooling2D(),

                    GlobalMaxPooling2D(),
                    Dense(32, activation='relu'),
                    Dense(num_classes)])

Poniżej znajdują się wyniki dla każdej z konfiguracji pogrupowane według obrazów, na których dany model był trenowany. Dla ułatwienia porównań przerywaną linią zaznaczony jest f1_score = 0.65.

Wykres f1 score dla wszystkich konfiguracji modelu Simple

Okazuje się, że zastosowane techniki data augmentation niekoniecznie poprawiają wyniki, a mogą je wręcz pogorszyć. Podobnie jest z ważeniem klas. Zerknijmy na krzywe ROC i Precision-Recall modelu, który osiągnął najwyższy f1 score:

Krzywe ROC i PR dla modelu Simple (konfiguracja original-classw-dataaug)

Prosty model ma dość wysokie AUC i Average Precision. Jedynie dla klasy OtherPneumonia Average Precision jest niskie; jest to klasa, dla której mamy najmniej przykładów w zbiorze, a jednocześnie może być podobna do COV. Zerknijmy jeszcze na znormalizowaną macierz pomyłek (normalizacja pomaga w interpretacji, szczególnie gdy mamy zbiór niezbalansowany):

Macierz pomyłek dla modelu Simple (konfiguracja original-classw-dataaug)

Model myli jeszcze trochę klasy COV i Normal, 31% klasy Normal zostało zaklasyfikowane jako COV. Jeśli chodzi o OtherPneumonia to widzimy, że większość przypadków zostało zaklasyfikowanych poprawnie, jednak musimy mieć na uwadze, że w zbiorze walidacyjnym jest ich jedynie 45, jak już zostało wspomniane powyżej.

Transfusion

Kolejnymi wykorzystanymi modelami są architektury Tiny, Small, LargeW, LargeT zaproponowane w artykule Transfusion: Understanding Transfer Learning for Medical Imaging. Modele były trenowane od zera.

Wykres f1 score dla wszystkich konfiguracji modelu Tiny
Wykres f1 score dla wszystkich konfiguracji modelu Small
Wykres f1 score dla wszystkich konfiguracji modelu LargeW
Wykres f1 score dla wszystkich konfiguracji modelu LargeT

Podobnie jak w przypadku modelu Simple okazuje się, że zastosowanie data augmentation może pogarszać wyniki. W takim wypadku możliwe jest, że wybrane metody były nieodpowiednie dla danych obrazów lub użyte parametry były zbyt duże. Nieco lepszy efekt w tej grupie modeli miało ważenie klas – czasem poprawiało wynik, chociaż czasem różnica między baseline a classw była nieznaczna.

Krzywe ROC i PR dla modelu Tiny (konfiguracja original-baseline)
Krzywe ROC i PR dla modelu Small (konfiguracja original-baseline)
Krzywe ROC i PR dla modelu LargeW (konfiguracja original-classw)
Krzywe ROC i PR dla modelu LargeT (konfiguracja original-baseline)

Najlepsze metryki z powyższych ma model LargeT, widać to również na poniższych macierzach pomyłek:

Macierz pomyłek dla modelu Tiny (konfiguracja original)
Macierz pomyłek dla modelu Small (konfiguracja original)
Macierz pomyłek dla modelu LargeW (konfiguracja original-classw)
Macierz pomyłek dla modelu LargeT (konfiguracja original)

Google Big Transfer

Architektura zaczerpnięta z Google Big Transfer (A. Kolesnikov, L. Beyer, X. Zhai, J. Puigcerver, J. Yung, S. Gelly and N. Houlsby: Big Transfer (BiT): General Visual Representation Learning) to ResNet-50 (BiT-S R50x1). Model jest pretrenowany na ImageNet-1k i w tym projekcie dotrenowana została jedynie warstwa klasyfikująca.

Wykres f1 score dla wszystkich konfiguracji modelu ResNet-50

ResNet-50 dla obrazów przyciętych, na których praktycznie nie było tła (crop oraz lungs) osiąga równomierne wyniki. Z kolei w przypadku pozostałych obrazów ponownie okazuje się, że data augmentation pogarsza wynik.

Krzywe ROC i PR dla modelu ResNet-50 (konfiguracja original-baseline)
Macierz pomyłek dla modelu ResNet-50 (konfiguracja original)

Podobnie jak w niektórych powyższych przypadkach część przypadków klasy Normal jest klasyfikowana jako COV, a COV jako Normal. OtherPneumonia jest klasyfikowane poprawnie w 96%.

EfficientNet B3

W przypadku EfficientNet B3 zastosowano dwa podejścia: trening całego modelu od początku oraz transfer learning. W drugim podejściu wykorzystano model pretrenowany na ImageNet dostępny w Keras Applications i dostrojono jedynie warstwę klasyfikującą.

Wykres f1 score dla wszystkich konfiguracji modelu EfficientNet B3 pretrenowanego na ImageNet
Wykres f1 score dla wszystkich konfiguracji modelu EfficientNet B3 trenowanego od zera

Widać pewne różnice pomiędzy wariantami EfficientNet B3. Są także pewne podobieństwa – w wielu przypadkach data augmentation wpłynęło niekorzystnie na wyniki. Również ważenie klas nie poprawiało znacząco wyników.

Uwzględniając konfiguracje bez data augmentation (które okazuje się pogarszać wyniki) można stwierdzić, że lepszym podejściem w przypadku tych danych jest transfer learning. Wpływa na to między innymi liczba obrazów – jest ich mało, co utrudnia skuteczne trenowanie dużych modeli.

Krzywe ROC i PR dla modelu EfficientNet B3 pretrained on ImageNet (konfiguracja lungs-baseline)
Krzywe ROC i PR dla modelu EfficientNet B3 trenowanego od zera (konfiguracja lungs-baseline)
Macierz pomyłek dla modelu EfficientNet B3 pretrained on ImageNet (konfiguracja lungs)
Macierz pomyłek dla modelu EfficientNet B3 trenowanego od zera (konfiguracja lungs-nocrop)

Porównanie modeli EfficientNet

Oprócz EfficientNet B3 na zdjęciach oryginalnych wytrenowane zostały jeszcze warianty B0 i B7 jako dodatkowy test sprawdzający, czy bardziej rozbudowany model poradzi sobie lepiej. Poniżej na wykresie znajdują się wyniki dla poszczególnych architektur.

Wykres f1 score dla modeli EfficientNet B0, B3 i B7 pretrenowanych oraz trenowanych od zera na zdjęciach oryginalnych

Widzimy, że – podobnie jak B3 – B0 oraz B7 trenowane od zera radzą sobie nieco gorzej niż wersje pretrenowane. W przypadku modeli pretrenowanych B7 osiągnął trochę lepszy wynik od pozostałych, z kolei w przypadku modeli trenowanych od zera wynik B7 różni się nieznacznie od B0. Wstępnie można stwierdzić, że faktycznie B7 daje dokładniejsze wyniki, przynajmniej dla transfer learningu, jednak do wysnucia pewniejszych wniosków należałoby przeprowadzić więcej eksperymentów.

Podsumowanie walidacji wszystkich modeli

Wszystkie modele trenowały się stabilnie, natomiast walidacja podczas treningu była dość niestabilna. Dla większych modeli (ResNet-50, EfficientNet B3) wahania metryk i lossu były znacznie mniejsze niż u innych modeli, natomiast największe wahania zaobserwowano u modeli Transfusion.

W przypadku wszystkich modeli data augmentation negatywnie wpływało na wyniki (oprócz ResNet-50 trenowanego na obrazach crop oraz lungs). Z kolei ważenie klas w niektórych przypadkach poprawiło wynik, w niektórych pogorszyło, a w jeszcze innych nie zmieniło znacząco.

Poniżej w tabeli znajduje się zestawienie najlepszych modeli z każdej grupy wraz z konfiguracją i metrykami. Najwyższe metryki osiągnął model LargeT na obrazach oryginalnych bez ważenia klas i data augmentation.

ModelKonfiguracjaF1 ScoreAUC
Simpleoriginal-classw-dataaug0.8100.925
Transfusion Tinyoriginal-baseline0.8610.954
Transfusion Smalloriginal-baseline0.8600.944
Transfusion LargeWoriginal-classw0.8530.944
Transfusion LargeToriginal-baseline0.8880.958
Google BiT ResNet-50original-baseline0.7660.893
EfficientNet B3 pretrained on ImageNetlungs-baseline0.7680.896
EfficientNet B3lungs-nocrop-baseline0.7610.866

Gradient-weighted Class Activation Mapping

W medycynie istotne jest, aby dane rozwiązanie było interpretowalne, na przykład dodatkową odpowiedzią modelu może być heatmapa pokazująca, na jakie obszary obrazu model zwraca uwagę przy podejmowaniu decyzji. Może się okazać, że są to inne fragmenty niż te, na podstawie których decyzje podejmuje lekarz.

Jedną z metod interpretacji wyników jest Gradient-weighted Class Activation Mapping (Grad-CAM). Jest to generalizacja metody Class Activation Maps (CAM) i w porównaniu do niej nie wymaga konkretnej architektury modelu. Grad-CAM, podobnie jak CAM, wykorzystuje informacje z ostatniej warstwy konwolucyjnej (feature maps). Jednak wagi do obliczenia wynikowej heatmapy nie są brane z ostatniej warstwy fully connected, ale są liczone na podstawie gradientów.

Zobaczmy wyniki GradCAM dla najlepszego modelu według metryk – LargeT – dla poszczególnych grup obrazów (przykłady klasy COV):

GradCAM dla modelu LargeT - obraz oryginalny, heatmapa oraz heatmapa nałożona na obraz
GradCAM dla modelu LargeT - obraz oryginalny bez tła, heatmapa oraz heatmapa nałożona na obraz
GradCAM dla modelu LargeT - obraz bez tła przycięty do pacjenta, heatmapa oraz heatmapa nałożona na obraz
GradCAM dla modelu LargeT - obraz z wysegmentowanymi płucami, heatmapa oraz heatmapa nałożona na obraz
GradCAM dla modelu LargeT - obraz z wysegmentowanymi płucami i przycięty, heatmapa oraz heatmapa nałożona na obraz

W pierwszym przykładzie widać, że model skupia się na płucach. Jednak poza tym zaznacza wyraźniej również tchawicę. Widać także, że istotne są dla modelu obszary wokół stołu, co potwierdza założenie, że stół należy usunąć z obrazu, gdyż nie wpływa on na diagnozę, a może wprowadzać model w błąd.

Dla jednego obrazu możemy wygenerować heatmapy odpowiadające każdej z klas. Poniżej znajdują się przykłady obrazu klasy COV oraz heatmapy uzyskane względem COV, Normal, OtherPneumonia (kolejno od góry):

GradCAM dla modelu LargeT - obraz oryginalny COV, heatmapa dla COV oraz heatmapa nałożona na obraz
GradCAM dla modelu LargeT - obraz oryginalny COV, heatmapa dla Normal oraz heatmapa nałożona na obraz
GradCAM dla modelu LargeT - obraz oryginalny COV, heatmapa dla OtherPneumonia oraz heatmapa nałożona na obraz

Widać, że reprezentacje z ostatniej warstwy konwolucyjnej są różne dla poszczególnych klas, przy czym w każdym przypadku jest zaznaczany niepożądany obszar wokół stołu.

Zerknijmy jeszcze na heatmapy dla poszczególnych klas wygenerowane na podstawie samych płuc, tym razem dla modelu EfficientNet B3 – ImageNet (lungs):

GradCAM dla modelu EfficientNet B3 - wagi ImageNet - obraz COV z wysegmentowanymi płucami i przycięty, heatmapa oraz heatmapa nałożona na obraz
COV
GradCAM dla modelu EfficientNet B3 - wagi ImageNet - obraz Normal z wysegmentowanymi płucami i przycięty, heatmapa oraz heatmapa nałożona na obraz
Normal
GradCAM dla modelu EfficientNet B3 - wagi ImageNet - obraz OtherPneumonia z wysegmentowanymi płucami i przycięty, heatmapa oraz heatmapa nałożona na obraz
OtherPneumonia

Widać, że dla modelu istotne są płuca, które na heatmapie można by nawet rozróżnić jako dwa. Jednak są one traktowanego raczej całościowo i model nie wyszukuje konkretnych zmian patologicznych.

Dla porównania spójrzmy jeszcze na heatmapy dla EfficientNet B3 trenowanego od zera:

GradCAM dla modelu EfficientNet B3 - wagi None - obraz COV z wysegmentowanymi płucami i przycięty, heatmapa oraz heatmapa nałożona na obraz
COV
GradCAM dla modelu EfficientNet B3 - wagi None - obraz Normal z wysegmentowanymi płucami i przycięty, heatmapa oraz heatmapa nałożona na obraz
Normal
GradCAM dla modelu EfficientNet B3 - wagi None - obraz OtherPneumonia z wysegmentowanymi płucami i przycięty, heatmapa oraz heatmapa nałożona na obraz
OtherPneumonia

Widać, że model trenowany od zera wyłapuje inne cechy obrazu niż dostrojony model z wagami ImageNet, przy czym część istotnych dla modelu fragmentów obrazu jest nieistotna z medycznego punktu widzenia (na przykład obszary poza płucami). Może to wynikać ze zbyt małego zbioru danych – model nie jest w stanie nauczyć się ważnych cech na takim małym zbiorze treningowym.

Zbiór testowy

Ostatnim krokiem w projekcie jest przetestowanie modeli na zbiorze testowym. Okazuje się, że zbiór ten ma podobne proporcje jak zbiór treningowy i walidacyjny, tj. klasy COV jest najwięcej, a OtherPneumonia najmniej.

Liczebność zbioru testowego z podziałem na klasy:
COV: 1130
Normal: 579
OtherPneumonia: 168

Sprawdzone zostały modele, które osiągnęły najwyższe wyniki na zbiorze walidacyjnym.

ModelKonfiguracjaF1 ScoreAUC
Simpleoriginal-classw-dataaug0.5820.810
Transfusion Tinyoriginal-baseline0.5800.811
Transfusion Smalloriginal-baseline0.6000.798
Transfusion LargeWoriginal-classw0.6340.792
Transfusion LargeToriginal-baseline0.5770.796
Google BiT ResNet-50original-baseline0.7110.886
EfficientNet B3 pretrained on ImageNetlungs-baseline0.6290.845
EfficientNet B3lungs-nocrop-baseline0.5700.760

Wszystkie modele osiągnęły na zbiorze testowym niższe niż na zbiorze walidacyjnym wartości f1 score oraz AUC. Najwyższe wyniki na zbiorze testowym osiągnął ResNet-50. Wyniki te nie różnią się znacznie od metryk tego modelu dla zbioru walidacyjnego, a więc ResNet-50 okazał się modelem o najlepszej generalizacji – najlepiej rozpoznaje nowe obrazy, spoza zbioru treningowego, a także walidacyjnego (do którego niejako dostrajaliśmy skuteczność modeli w trakcie ich rozwijania i poprawy wyników).

W tym miejscu warto by było przeprowadzić analizę gradCAM dla ResNet-50 i sprawdzić, dlaczego osiągnął lepsze wyniki niż pozostałe modele oraz czy podejmuje decyzje na podstawie właściwych cech. Niestety otrzymanie wyjścia z ostatniej warstwy konwolucyjnej okazało się problematyczne ze względu na format, w jakim są zapisane modele Google Big Transfer – TensorFlow SavedModel – a także ze względu na opakowanie zapisanego modelu w KerasLayer podobnie jak w przykładowym notebooku. Z tego powodu pomijamy tymczasem analizę gradCAM dla ResNet-50. Jeśli ktoś z czytelników ma pomysł na rozwiązanie tego problemu – sugestie mile widziane.

Podsumowanie

Analiza zbioru CT-COV19 okazała się niełatwym zadaniem, między innymi ze względu na niepodanie przez autora zbioru dokładnej charakterystyki obrazów. Brak informacji o zbiorze oraz/lub brak wiedzy domenowej może powodować, że podejmiemy błędne założenia w trakcie projektu, co doprowadzi do wadliwego rozwiązania. Dobrze jest konsultować założenia oraz wyniki z radiologami i innymi ekspertami. W tym kontekście warto zapoznać się z już wspomnianą wyżej checklistą, a także tym artykułem przeglądowym (autorzy artykułu konkludują, że żadne z rozwiązań machine learning powstałych w minionym roku do wykrywania i prognozowania COVID-19 na podstawie obrazów CT oraz RTG niestety nie nadaje się do użytku klinicznego). Oba opracowania listują błędy i nieścisłości rozwiązań uczenia maszynowego dla wykrywania i przewidywania COVID-19 oraz zawierają wytyczne, jak tych błędów unikać.

Preprocessing obrazów i ostatecznie segmentacja płuc nie przyniosły oczekiwanej poprawy wyników. Oprócz modelu EfficientNet B3 pozostałe modele podczas walidacji osiągnęły najwyższe wyniki na obrazach oryginalnych. Ten rezultat może być przypadkowy w tym kontekście, że walidacja podczas treningu była niestabilna, a checkpoint danego modelu był zapisywany na podstawie najwyższego f1 score walidacji po każdej epoce. W tej sytuacji na pewno mogłaby pomóc większa liczba obrazów w zbiorze zarówno treningowym, jak i walidacyjnym. Zwłaszcza, że analiza GradCAM pokazuje, że modele mogą podejmować decyzje na podstawie cech nieistotnych z medycznego punktu widzenia. Prawdopodobnie zbiór jest po prostu za mały, aby nauczyły się istotnych cech.

Ponieważ zbiór jest dość mały jak na wymagania sieci neuronowych, całkiem dobrze sprawdził się transfer learning (co widać m.in. na heatmapach gradCAM – model pretrenowany daje sensowniejsze wyniki w porównaniu do niepretrenowanego). Ogólnodostępne gotowe do użycia modele są najczęściej wytrenowane na zbiorze ImageNet. Pewnym problemem jest to, że zbiór ten zawiera jedynie obrazy naturalne (czyli na przykład zwierzęta, przedmioty codziennego użytku), które znacząco różnią się od obrazów medycznych. Najkorzystniejszym rozwiązaniem byłoby wykorzystanie modelu wytrenowanego na dużym zbiorze obrazów medycznych pochodzących z konkretnej modalności i dostrojenie go na posiadanym małym zbiorze. Niestety takie modele nie są powszechne dostępne (podobnie jak duże zbiory obrazów medycznych) i jeśli posiadamy mały zbiór nadal zasadne może być wykorzystanie modelu pretrenowanego na ImageNet (ciekawym podejściem do zoptymalizowania transferu wiedzy między obrazami naturalnymi a medycznymi jest pokolorowanie obrazów medycznych).

W jakim kierunku można dalej rozwijać projekt oraz jak poprawić wyniki? Przede wszystkim dobrze by było zdobyć większą ilość danych. Jak już było wspomniane obrazów jest dość mało, aby modele mogły nauczyć się poprawnie istotnych cech. Można by również wytrenować model na wszystkich obrazach naraz (oryginalnych, przyciętych, z usuniętym tłem). Kolejnym krokiem mogłoby być wypróbowanie innych technik data augmentation, np. zmiana kontrastu, CLAHE, czy zastosowanie filtru Frangiego.

Close