初识 Fashion MNIST

前言

从本章开始,将记录对官方文档中基本分类篇的剖析和理解。

基本分类文档主要对计算机视觉的基础进行演示和简要讲解,其实质就是第一章的新手代码。就是对 70000 张服装图像数据集进行训练和测试的过程

导入库


# 引入未来版本新特性,作用是在使用 Python2 时能够兼容 Python3 的语法
from __future__ import absolute_import, division, print_function, unicode_literals

# 导入 TensorFlow 和 tf.keras
import tensorflow as tf
from tensorflow import keras

# 导入辅助库
import numpy as np
import matplotlib.pyplot as plt

# 输出 TensorFlow 的版本号
print(tf.__version__)

注意:刚复制这段代码到 IDE 中时,可能提示 matplotlib 包未安装。所以需要安装一下 matplotlib 包:

pip install matplotlib

关于 __future__ 官方地址 --> 传送门

train_images、train_labels、test_images、test_labels 简介

  • train_images:用来训练的 60000 张图片

    图片

  • train_labels:用来训练的 60000 个图片分类(0-9),共 10 种

    图片

  • test_images:用来测试的 10000 张图片

  • test_labels:用来测试的 10000 个图片分类

训练数据和测试数据的获取

# 从 tf.keras 中获取 fashion_mnist 对象
fashion_mnist = keras.datasets.fashion_mnist

# 调用 fashion_mnist 对象的 load_data() 方法,获取训练数据和测试数据
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

下面我们来详细查看 load_data() 方法的源码

# Python2 兼容 python3 的语法
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# 导入 gzip 和 os 标准库
import gzip
import os

# 导入 numpy
import numpy as np

# 导入工具类
from tensorflow.python.keras.utils.data_utils import get_file
from tensorflow.python.util.tf_export import keras_export

@keras_export('keras.datasets.fashion_mnist.load_data')
def load_data():
  # 根据系统生成合理相对路径 datasets\\fashion-mnist
  dirname = os.path.join('datasets', 'fashion-mnist')

  # 数据包网络位置
  base = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/'
  # 数据包名称
  files = [
      'train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz',
      't10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz'
  ]

  # 下载网络数据到本地,并记录数据在本地的绝对地址
  paths = []
  for fname in files:

    # 重点就是这个 get_file 方法,实现了网络路径到本地路径的转换(俗称下载)
    paths.append(get_file(fname, origin=base + fname, cache_subdir=dirname))

  with gzip.open(paths[0], 'rb') as lbpath:
    y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)

  with gzip.open(paths[1], 'rb') as imgpath:
    x_train = np.frombuffer(
        imgpath.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28)

  with gzip.open(paths[2], 'rb') as lbpath:
    y_test = np.frombuffer(lbpath.read(), np.uint8, offset=8)

  with gzip.open(paths[3], 'rb') as imgpath:
    x_test = np.frombuffer(
        imgpath.read(), np.uint8, offset=16).reshape(len(y_test), 28, 28)

  return (x_train, y_train), (x_test, y_test)

通过上面,我们知道了从网络数据到本地数据的核心方法:get_file

## 传进来的参数如下:
# fname: 文件名称
# origin: 远程文件地址
# cache_subdir:本地相对地址

@keras_export('keras.utils.get_file')
def get_file(fname,
             origin,
             untar=False,
             md5_hash=None,
             file_hash=None,
             cache_subdir='datasets',
             hash_algorithm='auto',
             extract=False,
             archive_format='auto',
             cache_dir=None):
  # cache_dir 如何是 None 则,生成 家目录 + .keras。如 window 就是 C:\Users\Administrator\.keras
  if cache_dir is None:
    cache_dir = os.path.join(os.path.expanduser('~'), '.keras')

  # 如果传入 md5_hash 则优先使用
  if md5_hash is not None and file_hash is None:
    file_hash = md5_hash
    hash_algorithm = 'md5'
  datadir_base = os.path.expanduser(cache_dir)

  # 如果家目录没有写权限,那就到在根目录下建立 tmp 目录
  if not os.access(datadir_base, os.W_OK):
    datadir_base = os.path.join('/tmp', '.keras')

  # 拼接 datadir_base 和 cache_subdir 生成完整本地绝对路径
  datadir = os.path.join(datadir_base, cache_subdir)

  # 检测本地绝对路径目录存在吗,不存在就创建
  if not os.path.exists(datadir):
    os.makedirs(datadir)

  # 文件名是否启用扩展优化
  if untar:
    untar_fpath = os.path.join(datadir, fname)
    fpath = untar_fpath + '.tar.gz'
  else:
    fpath = os.path.join(datadir, fname)

  download = False

  # 检测本地现存数据的 hash 与远程 hash 是否一致,不一致则重新下载
  if os.path.exists(fpath):
    # File found; verify integrity if a hash was provided.
    if file_hash is not None:
      if not validate_file(fpath, file_hash, algorithm=hash_algorithm):
        print('A local file was found, but it seems to be '
              'incomplete or outdated because the ' + hash_algorithm +
              ' file hash does not match the original value of ' + file_hash +
              ' so we will re-download the data.')
        download = True
  else:
    download = True

  # 启用下载,并声明进度条
  if download:
    print('Downloading data from', origin)

    class ProgressTracker(object):
      # Maintain progbar for the lifetime of download.
      # This design was chosen for Python 2.7 compatibility.
      progbar = None

    def dl_progress(count, block_size, total_size):
      if ProgressTracker.progbar is None:
        if total_size == -1:
          total_size = None
        ProgressTracker.progbar = Progbar(total_size)
      else:
        ProgressTracker.progbar.update(count * block_size)

    error_msg = 'URL fetch failure on {}: {} -- {}'
    try:
      try:
        # 进行文件下载,主要运用 Python 的 urllib.request 库,进行下载和本地写入
        urlretrieve(origin, fpath, dl_progress)
      except HTTPError as e:
        raise Exception(error_msg.format(origin, e.code, e.msg))
      except URLError as e:
        raise Exception(error_msg.format(origin, e.errno, e.reason))
    except (Exception, KeyboardInterrupt) as e:
      if os.path.exists(fpath):
        os.remove(fpath)
      raise
    ProgressTracker.progbar = None

  if untar:
    if not os.path.exists(untar_fpath):
      _extract_archive(fpath, datadir, archive_format='tar')
    return untar_fpath

  if extract:
    _extract_archive(fpath, datadir, archive_format)

  # 最后返回本地数据的绝对路径
  return fpath

关于 urlretrieve 方法,我们看一下

if sys.version_info[0] == 2:

  def urlretrieve(url, filename, reporthook=None, data=None):
    def chunk_read(response, chunk_size=8192, reporthook=None):
      content_type = response.info().get('Content-Length')
      total_size = -1
      if content_type is not None:
        total_size = int(content_type.strip())
      count = 0
      while True:
        chunk = response.read(chunk_size)
        count += 1
        if reporthook is not None:
          reporthook(count, chunk_size, total_size)
        if chunk:
          yield chunk
        else:
          break

    response = urlopen(url, data)
    with open(filename, 'wb') as fd:
      for chunk in chunk_read(response, reporthook=reporthook):
        fd.write(chunk)
else:
  from six.moves.urllib.request import urlretrieve

注意,大家都看到了第一行 if sys.version_info[0] == 2:。这是什么意思呢,简单讲:就是判断 python 版本是不是 2.X ,如果是,直接定义 urlretrieve 函数。否则从 request 库中导入 urlretrieve。

关于 sys.version_info --> 传送门

由于我的 python 是 3.7 ,所以 urlretrieve 方法实际在 request 库中:

def urlretrieve(url, filename=None, reporthook=None, data=None):
    # 从 URL 中分离协议和路由:'https://www.test.com/test' --> ('https', '//www.test.com/test') 
    url_type, path = splittype(url)

    # urlopen 从 URL 获取远程资源
    with contextlib.closing(urlopen(url, data)) as fp:
        headers = fp.info()

        if url_type == "file" and not filename:
            return os.path.normpath(path), headers

        if filename:
            tfp = open(filename, 'wb')
        else:
            tfp = tempfile.NamedTemporaryFile(delete=False)
            filename = tfp.name
            _url_tempfiles.append(filename)

        with tfp:
            result = filename, headers
            bs = 1024*8
            size = -1
            read = 0
            blocknum = 0
            if "content-length" in headers:
                size = int(headers["Content-Length"])

            if reporthook:
                reporthook(blocknum, bs, size)

            while True:
                block = fp.read(bs)
                if not block:
                    break
                read += len(block)
                tfp.write(block)
                blocknum += 1
                if reporthook:
                    reporthook(blocknum, bs, size)

    if size >= 0 and read < size:
        raise ContentTooShortError(
            "retrieval incomplete: got only %i out of %i bytes"
            % (read, size), result)

    return result

关于 urlopen 方法 --> 传送门

结果

  • 命令行运行

    Python

  • 下载到本地文件的位置

    Python

本文章首发在 LearnKu.com 网站上。
上一篇 下一篇
讨论数量: 0
发起讨论 只看当前版本


暂无话题~