Machine Learning을 쉽게 사용하도록 구성되어 있는 클래스에 대한 튜토리얼
iris 데이터셋을 이용하여 학습 3층짜리 DNN 모델을 학습한다.
r0.11 튜토리얼에서는 load_csv 함수를 사용하는 코드가 기록 되어 있는데
2016년 09월 15일 자로 해당 함수가 deprecated 되어서 다른 함수를 사용하여야 한다.
아래 정의에서 볼 수 있는 load_csv_with_header 함수가 그것이다.
-def load_csv(filename, target_dtype, target_column=-1, has_header=True):
- """Load dataset from CSV file."""
- if has_header:
- return load_csv_with_header(filename=filename,
- target_dtype=target_dtype,
- features_dtype=np.float64,
- target_column=target_column)
- else:
- return load_csv_without_header(filename=filename,
- target_dtype=target_dtype,
- features_dtype=np.float64,
- target_column=target_column)
수정된 코드를 반영한 예제 코드
#-*- coding: utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import numpy as np
# 데이터셋
IRIS_TRAINING = "iris_training.csv"
IRIS_TEST = "iris_test.csv"
# 데이터셋을 불러옵니다.
# load_csv 함수는 사용되지 않음
# load_csv_with_header 함수를 써야하며 세번째 파라미터 필요
training_set = tf.contrib.learn.datasets.base.load_csv_with_header(filename=IRIS_TRAINING, \
target_dtype=np.int, \
features_dtype=np.float32)
test_set = tf.contrib.learn.datasets.base.load_csv_with_header(filename=IRIS_TEST, \
target_dtype=np.int, \
features_dtype=np.float32)
# Specify that all features have real-value data
feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]
# Build 3 layer DNN with 10, 20, 10 units respectively.
# feature_column 과 hidden_unit은 필수 요소.
classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,
hidden_units=[10, 20, 10],
n_classes=3,
model_dir="/tmp/iris_model")
# Fit model.
classifier.fit(x=training_set.data,
y=training_set.target,
steps=2000)
# Evaluate accuracy.
accuracy_score = classifier.evaluate(x=test_set.data,
y=test_set.target)["accuracy"]
print('Accuracy: {0:f}'.format(accuracy_score))
# Classify two new flower samples.
new_samples = np.array(
[[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=float)
y = classifier.predict(new_samples)
print('Predictions: {}'.format(str(y)))
'전공관련 > Deep Learning' 카테고리의 다른 글
[TensorFlow] Tutorial 5. Large-scale Linear Models with TensorFlow (0) | 2016.11.09 |
---|---|
[Caffe] windows 환경에서 caffe를 설치하자 (161102 기준) (11) | 2016.11.03 |
[TensorFlow] Tutorial 3. TensorFlow Mechanics 101 (0) | 2016.11.02 |
[TensorFlow] Tutorial 2. Deep MNIST for Experts (0) | 2016.11.02 |
[TensorFlow] Tutorial 1. MNIST For ML Beginners (0) | 2016.10.20 |