Keras source code analysis: Dense

2021-09-26 03:42:59

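This article walks through the source code of the Dense layer, the familiar fully connected layer. Its implementation is quite simple: it mainly overrides the `build` and `call` methods, which makes it a good reference when you write your own custom layer.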

```python
# Imports needed when this excerpt is used outside keras/layers/core.py (Keras 2.x)
from keras import backend as K
from keras import activations, initializers, regularizers, constraints
from keras.engine import Layer, InputSpec


class Dense(Layer):
    def __init__(self, units,
                 activation=None,
                 use_bias=True,
                 kernel_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 **kwargs):
        # Legacy support: Dense(..., input_dim=n) means input_shape=(n,)
        if 'input_shape' not in kwargs and 'input_dim' in kwargs:
            kwargs['input_shape'] = (kwargs.pop('input_dim'),)
        super(Dense, self).__init__(**kwargs)
        self.units = units
        # The *.get() helpers turn string identifiers (or None) into objects
        self.activation = activations.get(activation)
        self.use_bias = use_bias
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)
        self.kernel_constraint = constraints.get(kernel_constraint)
        self.bias_constraint = constraints.get(bias_constraint)
        # Inputs must be at least 2-D: (batch, ..., input_dim)
        self.input_spec = InputSpec(min_ndim=2)
        self.supports_masking = True
```

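The constructor needs little explanation: it mostly just stores its arguments as attributes. Two details are worth noting: the `activations.get`, `initializers.get`, `regularizers.get` and `constraints.get` calls resolve string identifiers such as `'glorot_uniform'` into the corresponding objects, and a legacy `input_dim` argument is converted into `input_shape`.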
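To make the two constructor details concrete, here is a minimal pure-Python sketch. `get_activation` is a hypothetical, heavily simplified stand-in for `keras.activations.get` (only three activations, scalar-valued), and `normalize_input_kwargs` is a hypothetical helper that copies the `input_dim` to `input_shape` logic from `__init__`; neither is part of Keras.

```python
import math

# Hypothetical, simplified stand-in for keras.activations.get:
# resolves a string identifier (or None) to a callable, passes callables through.
_ACTIVATIONS = {
    'linear': lambda x: x,
    'relu': lambda x: max(0.0, x),
    'sigmoid': lambda x: 1.0 / (1.0 + math.exp(-x)),
}


def get_activation(identifier):
    if identifier is None:
        return _ACTIVATIONS['linear']  # Keras also maps None to 'linear'
    if isinstance(identifier, str):
        return _ACTIVATIONS[identifier]
    if callable(identifier):
        return identifier
    raise ValueError('Could not interpret activation: %r' % (identifier,))


def normalize_input_kwargs(kwargs):
    """Same legacy input_dim -> input_shape conversion as Dense.__init__."""
    kwargs = dict(kwargs)  # work on a copy instead of mutating the caller's dict
    if 'input_shape' not in kwargs and 'input_dim' in kwargs:
        kwargs['input_shape'] = (kwargs.pop('input_dim'),)
    return kwargs


print(normalize_input_kwargs({'input_dim': 784}))  # {'input_shape': (784,)}
print(get_activation('relu')(-3.0))                 # 0.0
```

Because `None` resolves to the identity function, `self.activation` is in practice never `None` in Keras 2, so the `is not None` check in `call` is purely defensive.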

```python
    def build(self, input_shape):
        assert len(input_shape) >= 2
        input_dim = input_shape[-1]

        # Weight matrix of shape (input_dim, units)
        self.kernel = self.add_weight(shape=(input_dim, self.units),
                                      initializer=self.kernel_initializer,
                                      name='kernel',
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)
        if self.use_bias:
            # Bias vector of shape (units,)
            self.bias = self.add_weight(shape=(self.units,),
                                        initializer=self.bias_initializer,
                                        name='bias',
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
        else:
            self.bias = None
        # From now on the last input axis must have exactly input_dim entries
        self.input_spec = InputSpec(min_ndim=2, axes={-1: input_dim})
        self.built = True
```

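`build` creates the two variables, i.e. the weights: `kernel` with shape `(input_dim, units)` and, if `use_bias` is set, `bias` with shape `(units,)`. It then tightens `input_spec` so that later inputs must match `input_dim` on the last axis, and finally sets `self.built` to `True`.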
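The shapes `build` allocates, and the resulting parameter count, can be reproduced without Keras. The sketch below is an illustration only: `build_dense_weights` is a hypothetical helper, and it re-implements the Glorot (Xavier) uniform rule behind the default `'glorot_uniform'` initializer with Python's `random` module (values drawn from U(-limit, limit) with limit = sqrt(6 / (fan_in + fan_out))) instead of calling Keras.

```python
import math
import random


def glorot_uniform(fan_in, fan_out, rng):
    # Glorot/Xavier uniform: U(-limit, limit), limit = sqrt(6 / (fan_in + fan_out))
    limit = math.sqrt(6.0 / (fan_in + fan_out))
    return [[rng.uniform(-limit, limit) for _ in range(fan_out)]
            for _ in range(fan_in)]


def build_dense_weights(input_shape, units, use_bias=True, seed=0):
    """Hypothetical helper mirroring the weights Dense.build creates."""
    assert len(input_shape) >= 2
    input_dim = input_shape[-1]
    rng = random.Random(seed)
    kernel = glorot_uniform(input_dim, units, rng)  # shape (input_dim, units)
    bias = [0.0] * units if use_bias else None      # 'zeros' initializer
    return kernel, bias


kernel, bias = build_dense_weights((None, 784), 64)
n_params = len(kernel) * len(kernel[0]) + len(bias)
print(len(kernel), len(kernel[0]), len(bias), n_params)  # 784 64 64 50240
```

The count matches the usual formula for a fully connected layer: input_dim × units weights plus units biases, here 784 × 64 + 64 = 50240.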

```python
    def call(self, inputs):
        # (batch, ..., input_dim) . (input_dim, units) -> (batch, ..., units)
        output = K.dot(inputs, self.kernel)
        if self.use_bias:
            output = K.bias_add(output, self.bias, data_format='channels_last')
        if self.activation is not None:
            output = self.activation(output)
        return output
```

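`call` takes the dot product of the input with the `kernel` created in `build`, adds the `bias` from `build` (when `use_bias` is set), and finally passes the result through the activation function to produce the output, i.e. `output = activation(dot(inputs, kernel) + bias)`.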
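As a worked example of `output = activation(dot(inputs, kernel) + bias)`, here is a pure-Python sketch of the forward pass. It assumes a 2-D input `(batch, input_dim)` only (Keras' `K.dot` also handles higher-rank inputs), and `dense_forward` and `relu` are illustrative names, not Keras functions.

```python
def relu(v):
    return max(0.0, v)


def dense_forward(inputs, kernel, bias=None, activation=None):
    """Compute activation(inputs . kernel + bias) for a 2-D batch of rows."""
    units = len(kernel[0])
    # Like K.dot: (batch, input_dim) x (input_dim, units) -> (batch, units)
    output = [[sum(row[i] * kernel[i][j] for i in range(len(row)))
               for j in range(units)]
              for row in inputs]
    if bias is not None:
        # Like K.bias_add with 'channels_last': add bias[j] to column j of every row
        output = [[v + b for v, b in zip(row, bias)] for row in output]
    if activation is not None:
        output = [[activation(v) for v in row] for row in output]
    return output


x = [[1.0, 2.0]]                     # batch of 1, input_dim = 2
W = [[1.0, 0.0, -1.0],               # kernel, shape (2, 3)
     [0.0, 1.0, 1.0]]
b = [0.5, -3.0, 0.0]                 # bias, shape (3,)
print(dense_forward(x, W, b, relu))  # [[1.5, 0.0, 1.0]]
```

Step by step: x · W = [1.0, 2.0, 1.0], adding the bias gives [1.5, -1.0, 1.0], and ReLU clips the negative entry to 0.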

```python
    def compute_output_shape(self, input_shape):
        assert input_shape and len(input_shape) >= 2
        assert input_shape[-1]  # the last dimension must be known
        output_shape = list(input_shape)
        output_shape[-1] = self.units  # only the last axis changes
        return tuple(output_shape)
```

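`compute_output_shape` computes the shape of the output tensor and returns it: the last dimension is replaced by `units`, and all other dimensions are kept unchanged.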
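A useful consequence of this rule is that Dense only transforms the last axis, so a 3-D input such as `(batch, timesteps, features)` keeps its middle dimension. The standalone function below copies the same logic under the hypothetical name `dense_output_shape`, so the behaviour can be checked without Keras.

```python
def dense_output_shape(input_shape, units):
    """Same rule as Dense.compute_output_shape."""
    assert input_shape and len(input_shape) >= 2
    assert input_shape[-1]  # last dim must be known (not None or 0)
    output_shape = list(input_shape)
    output_shape[-1] = units
    return tuple(output_shape)


print(dense_output_shape((None, 784), 64))     # (None, 64)
print(dense_output_shape((None, 10, 32), 16))  # (None, 10, 16)
```

The batch dimension is usually `None` (unknown), which is allowed; only an unknown last dimension fails the assertion, because `build` needs it to size the kernel.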

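```python
    def get_config(self):
        # Serialize every constructor argument so the layer can be rebuilt
        config = {
            'units': self.units,
            'activation': activations.serialize(self.activation),
            'use_bias': self.use_bias,
            'kernel_initializer': initializers.serialize(self.kernel_initializer),
            'bias_initializer': initializers.serialize(self.bias_initializer),
            'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
            'bias_regularizer': regularizers.serialize(self.bias_regularizer),
            'activity_regularizer': regularizers.serialize(self.activity_regularizer),
            'kernel_constraint': constraints.serialize(self.kernel_constraint),
            'bias_constraint': constraints.serialize(self.bias_constraint)
        }
        base_config = super(Dense, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
```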

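`get_config` collects the layer's configuration, i.e. all constructor arguments in serializable form, merges it with the base `Layer` config (which holds entries such as `name` and `trainable`), and returns the result as a dict.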
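The point of returning constructor arguments is that a layer can be recreated from its config; in Keras 2, `Layer.from_config` is essentially `cls(**config)`. The sketch below shows that round trip with hypothetical, heavily simplified classes (`BaseLayer` and `TinyDense` are illustrative names, not Keras classes) and the same dict-merge expression as `Dense.get_config`.

```python
class BaseLayer:
    """Hypothetical, heavily simplified stand-in for the Keras Layer base class."""

    def __init__(self, name=None, trainable=True):
        self.name = name or type(self).__name__.lower()
        self.trainable = trainable

    def get_config(self):
        return {'name': self.name, 'trainable': self.trainable}

    @classmethod
    def from_config(cls, config):
        # Rebuild the layer by passing the config back to the constructor
        return cls(**config)


class TinyDense(BaseLayer):
    def __init__(self, units, activation=None, use_bias=True, **kwargs):
        super(TinyDense, self).__init__(**kwargs)
        self.units = units
        self.activation = activation
        self.use_bias = use_bias

    def get_config(self):
        config = {'units': self.units,
                  'activation': self.activation,
                  'use_bias': self.use_bias}
        base_config = super(TinyDense, self).get_config()
        # Same merge as Dense.get_config: on a key clash, the layer's own entries win
        return dict(list(base_config.items()) + list(config.items()))


layer = TinyDense(64, activation='relu', name='fc1')
config = layer.get_config()
clone = TinyDense.from_config(config)
print(config)
# {'name': 'fc1', 'trainable': True, 'units': 64, 'activation': 'relu', 'use_bias': True}
```

This is also why a custom layer should list every extra constructor argument in its `get_config`: anything missing cannot be restored when a saved model is loaded.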
