12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879 |
//! Parser for the IDX file format used by the MNIST handwritten-digit data set.
//!
//! See <http://yann.lecun.com/exdb/mnist/> for a description of the format.
- extern crate byteorder;
- use std::io::Result;
- use std::io::Read;
- use byteorder::ReadBytesExt;
/// Element type of an IDX file, selected by the third byte of its magic number.
///
/// The byte codes listed on each variant are the ones `header` maps from.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DataType {
    /// Header code `0x08`.
    UnsignedByte,
    /// Header code `0x09`.
    SignedByte,
    /// Header code `0x0b`.
    Short,
    /// Header code `0x0c`.
    Int,
    /// Header code `0x0d`.
    Float,
    /// Header code `0x0f`.
    Double,
}

/// Parsed header of an IDX file: its element type and dimension sizes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Idx {
    /// Element type declared by the header's magic number.
    data_type: DataType,
    /// Size of each dimension, in the order the header lists them.
    pub dimensions: Vec<u32>,
}
/// Reads and validates the 4-byte IDX magic number from `src`.
///
/// The bytes are laid out as `[0, 0, type_code, dimension_count]`. They are
/// returned verbatim, and exactly four bytes are consumed on success.
///
/// # Errors
///
/// * `ErrorKind::UnexpectedEof` if `src` ends before four bytes were read.
/// * Any other I/O error reported by `src`.
/// * `ErrorKind::InvalidData` if either of the first two bytes is non-zero.
fn magic<T: Read>(src: &mut T) -> Result<[u8; 4]> {
    let mut header = [0; 4];
    // `read_exact` keeps reading until the buffer is full. A single `read`
    // may legally return fewer than four bytes even when more data follows.
    src.read_exact(&mut header)?;
    // Malformed input is a recoverable condition for a parser: report it
    // through the `io::Result` instead of panicking.
    if header[0] != 0 || header[1] != 0 {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            "IDX magic number must start with two zero bytes",
        ));
    }
    Ok(header)
}
- pub fn header<T: Read>(src: &mut T) -> Result<Idx> {
- let header = try!(magic(src));
- let data_type = match header[2] {
- 0x08 => DataType::UnsignedByte,
- 0x09 => DataType::SignedByte,
- 0x0b => DataType::Short,
- 0x0c => DataType::Int,
- 0x0d => DataType::Float,
- 0x0f => DataType::Double,
- v => panic!(format!("Unknown data type {} in header!", v)),
- };
- let mut dim = Vec::new();
- for _ in 0..header[3] {
- let size = try!(src.read_u32::<byteorder::BigEndian>());
- dim.push(size);
- }
- Ok(Idx {
- data_type: data_type,
- dimensions: dim,
- })
- }
/// A well-formed magic number is returned byte-for-byte.
#[test]
fn valid_magic() {
    let bytes: Vec<u8> = vec![0, 0, 8, 3];
    let mut reader = std::io::Cursor::new(bytes);
    let parsed = magic(&mut reader).unwrap();
    assert_eq!(parsed.len(), 4);
    assert_eq!(parsed, [0, 0, 8, 3]);
}
/// A header with type code `0x0c` and two dimensions (11 and 0x00010001)
/// parses into `DataType::Int` with both sizes decoded big-endian.
#[test]
fn valid_header() {
    let bytes: Vec<u8> = vec![0, 0, 0x0c, 2, 0, 0, 0, 11, 0, 1, 0, 1];
    let mut reader = std::io::Cursor::new(bytes);
    let idx = header(&mut reader).unwrap();
    assert_eq!(idx.data_type, DataType::Int);
    assert_eq!(idx.dimensions, vec![11, 65537]);
}
|