Datawhale第13期組隊學習筆記Task5

2021-10-06 18:39:44 字數 1525 閱讀 8600

tta及baseline改進思考

測試時增強(test time augmentation, tta),可將準確率提高若干個百分點,。這裡會為原始影象造出多個不同版本,包括不同區域裁剪和更改縮放程度等,並將它們輸入到模型中;然後對多個版本進行計算得到平均輸出,作為影象的最終輸出分數。

有作弊的嫌疑。

這種技術很有效,因為原始影象顯示的區域可能會缺少一些重要特徵,在模型中輸入影象的多個版本並取平均值,能解決上述問題。

具體**如下:

def

predict

(test_loader, model, tta=10)

: model.

eval()

test_pred_tta =

none

# tta 次數

for _ in

range

(tta)

: test_pred =

with torch.no_grad():

for i,

(input

, target)

inenumerate

(test_loader)

: c0, c1, c2, c3, c4, c5 = model(data[0]

) output = np.concatenate(

[c0.data.numpy(

), c1.data.numpy(),

c2.data.numpy(

), c3.data.numpy(),

c4.data.numpy(

), c5.data.numpy()]

, axis=1)

test_pred = np.vstack(test_pred)

if test_pred_tta is

none

: test_pred_tta = test_pred

else

: test_pred_tta += test_pred

return test_pred_tta

即在**時,多次**,取乙個最大值進行綜合,最後輸出結果。

而baseline最後提出的乙個模型的改進方法,則是10折交叉驗證

下面假設構建了10折交叉驗證,訓練得到10個cnn模型。

那麼在10個cnn模型可以使用如下方式進行整合:

十折交叉驗證

(1)英文名叫做10-fold cross-validation,用來測試演算法準確性,是常用的測試方法。

(2)將資料集分成十份,輪流將其中9份作為訓練資料,1份作為測試資料,進行試驗。每次試驗都會得出相應的正確率(或差錯率)。

(3)10次的結果的正確率(或差錯率)的平均值作為對演算法精度的估計,一般還需要進行多次10折交叉驗證(例如10次10折交叉驗證),再求其均值,作為對演算法準確性的估計。

這部分**稍後實現了再新增上來。

Datawhale第13期組隊學習筆記Task2

task2主要內容為資料讀取 資料擴增方法和pytorch讀取賽題資料三個部分組成。用pil裡面的image方法讀取比較簡單,直接 im image.open 名.jpg 即可。而其中包含了一些別的讀取方式,如應用模糊濾鏡 im image.open test1.png im.filter imag...

Datawhale組隊學習Pandas

下面直接展示內聯 片。備註內容為學習後的感想與總結 author xuxt time 2020 12 14l def my func x return 2 x for i in range 5 l.my func i print l 定義 我的函式 輸入x,返回,2x,即輸入1,2,3,4,5可以得...

DataWhale 21期資料分析組隊學習

總結day1 今天是參加datawhale 21期資料分析組隊學習的第一天,在參加這個組隊學習之前我心裡是很猶豫的,因為我的python基礎不是很好,雖然我自學了python基礎,但是我還沒有真正用到案例上的經歷,並且很多東西因為用的少,學習的時間長了有點淡忘了,如果我以後要從事資料分析類的工作的話...