我是靠谱客的博主 文静小蝴蝶,这篇文章主要介绍keras 实现多任务学习,现在分享给大家,希望可以做个参考。

复制代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
def deep_multi_model(feature_dim, cvr_label_dim, profit_label_dim): inputs = Input(shape=(feature_dim,)) dense_1 = Dense(512, activation='relu')(inputs) dense_2 = Dense(384, activation='relu')(dense_1) dense_3 = Dense(256, activation='relu')(dense_2) drop_1 = Dropout(0.2)(dense_3) dense_4 = Dense(128, activation='relu')(drop_1) dense_5 = Dense(64, activation='relu')(dense_4) output_1 = Dense(32, activation='relu')(dense_5) output_cvr = Dense(cvr_label_dim, activation='softmax', name='output_cvr')(output_1) output_2 = Dense(16, activation='relu')(dense_5) output_profit = Dense(profit_label_dim, activation='softmax', name='output_profit')(output_2) # 模型有两个输出 output_cvr, output_profit model = Model(inputs=inputs, outputs=[output_cvr, output_profit]) model.summary() # 模型有两个 loss, 都是 categorical_crossentropy # loss 的 key 需要和模型的 output 层的 name 保持一致 model.compile(optimizer='adam', loss={'output_cvr': 'categorical_crossentropy', 'output_profit': 'categorical_crossentropy'}, loss_weights={'output_cvr':1, 'output_profit': 0.3}, metrics=[categorical_accuracy]) return model # 产生训练数据的生成器 # 模型只有一个 input 有两个 output,所以 yield 格式为如下 def generate_arrays(X_train, y_train_cvr_label, y_train_profit_label): while True: for x, y_cvr, y_profit in zip(X_train, y_train_cvr_label, y_train_profit_label): yield (x[np.newaxis, :], {'output_cvr': y_cvr[np.newaxis, :], 'output_profit': y_profit[np.newaxis, :]}) # fit_generator 进行 fit 训练 def train_multi(X_train, y_train_cvr_label, y_train_profit_label, X_test, y_test_cvr_label, y_test_profit_label): feature_dim = X_train.shape[1] cvr_label_dim = y_train_cvr_label.shape[1] profit_label_dim = y_train_profit_label.shape[1] model = deep_multi_model(feature_dim, cvr_label_dim, profit_label_dim) model.summary() early_stopping = EarlyStopping(monitor='val_loss', patience=15, verbose=0) model.fit_generator(generate_arrays(X_train, y_train_cvr_label, y_train_profit_label), steps_per_epoch=1024, epochs=100, validation_data=generate_arrays(X_test, y_test_cvr_label, y_test_profit_label), validation_steps=1024, callbacks=[early_stopping]) return model

最后

以上就是文静小蝴蝶最近收集整理的关于keras 实现多任务学习的全部内容,更多相关keras内容请搜索靠谱客的其他文章。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(87)

评论列表共有 0 条评论

立即
投稿
返回
顶部