|
| 1 | +use std::io::ByRefReader; |
| 2 | +use std::io::util::LimitReader; |
| 3 | + |
| 4 | +use time::Timespec; |
| 5 | +use postgres::Type; |
| 6 | +use postgres::types::{RawFromSql, RawToSql}; |
| 7 | + |
| 8 | +use {postgres, Range, RangeBound, BoundType, BoundSided, Normalizable}; |
| 9 | + |
| 10 | +macro_rules! check_types { |
| 11 | + ($($expected:pat)|+, $actual:ident) => ( |
| 12 | + match $actual { |
| 13 | + $(&$expected)|+ => {} |
| 14 | + actual => return Err(::postgres::Error::WrongType(actual.clone())) |
| 15 | + } |
| 16 | + ) |
| 17 | +} |
| 18 | + |
| 19 | +macro_rules! from_sql_impl { |
| 20 | + ($($oid:pat)|+, $t:ty) => { |
| 21 | + impl ::postgres::FromSql for Option<::Range<$t>> { |
| 22 | + fn from_sql(ty: &::postgres::Type, raw: Option<&[u8]>) -> ::postgres::Result<Self> { |
| 23 | + check_types!($($oid)|+, ty); |
| 24 | + |
| 25 | + match raw { |
| 26 | + Some(mut raw) => ::postgres::types::RawFromSql::raw_from_sql(&mut raw).map(Some), |
| 27 | + None => Ok(None), |
| 28 | + } |
| 29 | + } |
| 30 | + } |
| 31 | + |
| 32 | + impl ::postgres::FromSql for ::Range<$t> { |
| 33 | + fn from_sql(ty: &::postgres::Type, raw: Option<&[u8]>) -> ::postgres::Result<Self> { |
| 34 | + let v: ::postgres::Result<Option<Self>> = ::postgres::FromSql::from_sql(ty, raw); |
| 35 | + match v { |
| 36 | + Ok(None) => Err(::postgres::Error::WasNull), |
| 37 | + Ok(Some(v)) => Ok(v), |
| 38 | + Err(err) => Err(err), |
| 39 | + } |
| 40 | + } |
| 41 | + } |
| 42 | + } |
| 43 | +} |
| 44 | + |
| 45 | +macro_rules! to_sql_impl { |
| 46 | + ($($oid:pat)|+, $t:ty) => { |
| 47 | + impl ::postgres::ToSql for ::Range<$t> { |
| 48 | + fn to_sql(&self, ty: &::postgres::Type) -> ::postgres::Result<Option<Vec<u8>>> { |
| 49 | + check_types!($($oid)|+, ty); |
| 50 | + |
| 51 | + let mut writer = vec![]; |
| 52 | + try!(self.raw_to_sql(&mut writer)); |
| 53 | + Ok(Some(writer)) |
| 54 | + } |
| 55 | + } |
| 56 | + |
| 57 | + impl ::postgres::ToSql for Option<::Range<$t>> { |
| 58 | + fn to_sql(&self, ty: &::postgres::Type) -> ::postgres::Result<Option<Vec<u8>>> { |
| 59 | + check_types!($($oid)|+, ty); |
| 60 | + match *self { |
| 61 | + Some(ref arr) => arr.to_sql(ty), |
| 62 | + None => Ok(None) |
| 63 | + } |
| 64 | + } |
| 65 | + } |
| 66 | + } |
| 67 | +} |
| 68 | + |
| 69 | +const RANGE_UPPER_UNBOUNDED: i8 = 0b0001_0000; |
| 70 | +const RANGE_LOWER_UNBOUNDED: i8 = 0b0000_1000; |
| 71 | +const RANGE_UPPER_INCLUSIVE: i8 = 0b0000_0100; |
| 72 | +const RANGE_LOWER_INCLUSIVE: i8 = 0b0000_0010; |
| 73 | +const RANGE_EMPTY: i8 = 0b0000_0001; |
| 74 | + |
| 75 | +impl<T> RawFromSql for Range<T> where T: PartialOrd+Normalizable+RawFromSql { |
| 76 | + fn raw_from_sql<R: Reader>(rdr: &mut R) -> postgres::Result<Range<T>> { |
| 77 | + let t = try!(rdr.read_i8()); |
| 78 | + |
| 79 | + if t & RANGE_EMPTY != 0 { |
| 80 | + return Ok(Range::empty()); |
| 81 | + } |
| 82 | + |
| 83 | + fn make_bound<S, T, R>(rdr: &mut R, tag: i8, bound_flag: i8, inclusive_flag: i8) |
| 84 | + -> postgres::Result<Option<RangeBound<S, T>>> |
| 85 | + where S: BoundSided, T: PartialOrd+Normalizable+RawFromSql, R: Reader { |
| 86 | + match tag & bound_flag { |
| 87 | + 0 => { |
| 88 | + let type_ = match tag & inclusive_flag { |
| 89 | + 0 => BoundType::Exclusive, |
| 90 | + _ => BoundType::Inclusive, |
| 91 | + }; |
| 92 | + let len = try!(rdr.read_be_i32()) as uint; |
| 93 | + let mut limit = LimitReader::new(rdr.by_ref(), len); |
| 94 | + let bound = try!(RawFromSql::raw_from_sql(&mut limit)); |
| 95 | + if limit.limit() != 0 { |
| 96 | + return Err(postgres::Error::BadData); |
| 97 | + } |
| 98 | + Ok(Some(RangeBound::new(bound, type_))) |
| 99 | + } |
| 100 | + _ => Ok(None) |
| 101 | + } |
| 102 | + } |
| 103 | + |
| 104 | + let lower = try!(make_bound(rdr, t, RANGE_LOWER_UNBOUNDED, RANGE_LOWER_INCLUSIVE)); |
| 105 | + let upper = try!(make_bound(rdr, t, RANGE_UPPER_UNBOUNDED, RANGE_UPPER_INCLUSIVE)); |
| 106 | + Ok(Range::new(lower, upper)) |
| 107 | + } |
| 108 | +} |
| 109 | + |
| 110 | +from_sql_impl!(Type::Int4Range, i32); |
| 111 | +from_sql_impl!(Type::Int8Range, i64); |
| 112 | +from_sql_impl!(Type::TsRange | Type::TstzRange, Timespec); |
| 113 | + |
| 114 | +impl<T> RawToSql for Range<T> where T: PartialOrd+Normalizable+RawToSql { |
| 115 | + fn raw_to_sql<W: Writer>(&self, buf: &mut W) -> postgres::Result<()> { |
| 116 | + let mut tag = 0; |
| 117 | + if self.is_empty() { |
| 118 | + tag |= RANGE_EMPTY; |
| 119 | + } else { |
| 120 | + fn make_tag<S, T>(bound: Option<&RangeBound<S, T>>, unbounded_tag: i8, |
| 121 | + inclusive_tag: i8) -> i8 where S: BoundSided { |
| 122 | + match bound { |
| 123 | + None => unbounded_tag, |
| 124 | + Some(&RangeBound { type_: BoundType::Inclusive, .. }) => inclusive_tag, |
| 125 | + _ => 0 |
| 126 | + } |
| 127 | + } |
| 128 | + tag |= make_tag(self.lower(), RANGE_LOWER_UNBOUNDED, RANGE_LOWER_INCLUSIVE); |
| 129 | + tag |= make_tag(self.upper(), RANGE_UPPER_UNBOUNDED, RANGE_UPPER_INCLUSIVE); |
| 130 | + } |
| 131 | + |
| 132 | + try!(buf.write_i8(tag)); |
| 133 | + |
| 134 | + fn write_value<S, T, W>(buf: &mut W, v: Option<&RangeBound<S, T>>) -> postgres::Result<()> |
| 135 | + where S: BoundSided, T: RawToSql, W: Writer { |
| 136 | + if let Some(bound) = v { |
| 137 | + let mut inner_buf = vec![]; |
| 138 | + try!(bound.value.raw_to_sql(&mut inner_buf)); |
| 139 | + try!(buf.write_be_u32(inner_buf.len() as u32)); |
| 140 | + try!(buf.write(&*inner_buf)); |
| 141 | + } |
| 142 | + Ok(()) |
| 143 | + } |
| 144 | + |
| 145 | + try!(write_value(buf, self.lower())); |
| 146 | + try!(write_value(buf, self.upper())); |
| 147 | + |
| 148 | + Ok(()) |
| 149 | + } |
| 150 | +} |
| 151 | + |
| 152 | +to_sql_impl!(Type::Int4Range, i32); |
| 153 | +to_sql_impl!(Type::Int8Range, i64); |
| 154 | +to_sql_impl!(Type::TsRange | Type::TstzRange, Timespec); |
| 155 | + |
| 156 | +#[cfg(test)] |
| 157 | +mod test { |
| 158 | + use std::fmt; |
| 159 | + |
| 160 | + use postgres::{Connection, FromSql, ToSql, SslMode}; |
| 161 | + use time::{mod, Timespec}; |
| 162 | + |
| 163 | + macro_rules! test_range { |
| 164 | + ($name:expr, $t:ty, $low:expr, $low_str:expr, $high:expr, $high_str:expr) => ({ |
| 165 | + let tests = &[(Some(range!('(', ')')), "'(,)'".to_string()), |
| 166 | + (Some(range!('[' $low, ')')), format!("'[{},)'", $low_str)), |
| 167 | + (Some(range!('(' $low, ')')), format!("'({},)'", $low_str)), |
| 168 | + (Some(range!('(', $high ']')), format!("'(,{}]'", $high_str)), |
| 169 | + (Some(range!('(', $high ')')), format!("'(,{})'", $high_str)), |
| 170 | + (Some(range!('[' $low, $high ']')), |
| 171 | + format!("'[{},{}]'", $low_str, $high_str)), |
| 172 | + (Some(range!('[' $low, $high ')')), |
| 173 | + format!("'[{},{})'", $low_str, $high_str)), |
| 174 | + (Some(range!('(' $low, $high ']')), |
| 175 | + format!("'({},{}]'", $low_str, $high_str)), |
| 176 | + (Some(range!('(' $low, $high ')')), |
| 177 | + format!("'({},{})'", $low_str, $high_str)), |
| 178 | + (Some(range!(empty)), "'empty'".to_string()), |
| 179 | + (None, "NULL".to_string())]; |
| 180 | + test_type($name, tests); |
| 181 | + }) |
| 182 | + } |
| 183 | + |
| 184 | + fn test_type<T: PartialEq+FromSql+ToSql, S: fmt::Show>(sql_type: &str, checks: &[(T, S)]) { |
| 185 | + let conn = Connection::connect("postgres://postgres@localhost", &SslMode::None).unwrap(); |
| 186 | + for &(ref val, ref repr) in checks.iter() { |
| 187 | + let stmt = conn.prepare(&*format!("SELECT {}::{}", *repr, sql_type)).unwrap(); |
| 188 | + let result = stmt.query(&[]).unwrap().next().unwrap().get(0u); |
| 189 | + assert!(val == &result); |
| 190 | + |
| 191 | + let stmt = conn.prepare(&*format!("SELECT $1::{}", sql_type)).unwrap(); |
| 192 | + let result = stmt.query(&[val]).unwrap().next().unwrap().get(0u); |
| 193 | + assert!(val == &result); |
| 194 | + } |
| 195 | + } |
| 196 | + |
| 197 | + #[test] |
| 198 | + fn test_int4range_params() { |
| 199 | + test_range!("INT4RANGE", i32, 100i32, "100", 200i32, "200") |
| 200 | + } |
| 201 | + |
| 202 | + #[test] |
| 203 | + fn test_int8range_params() { |
| 204 | + test_range!("INT8RANGE", i64, 100i64, "100", 200i64, "200") |
| 205 | + } |
| 206 | + |
| 207 | + fn test_timespec_range_params(sql_type: &str) { |
| 208 | + fn t(time: &str) -> Timespec { |
| 209 | + time::strptime(time, "%Y-%m-%d").unwrap().to_timespec() |
| 210 | + } |
| 211 | + let low = "1970-01-01"; |
| 212 | + let high = "1980-01-01"; |
| 213 | + test_range!(sql_type, Timespec, t(low), low, t(high), high); |
| 214 | + } |
| 215 | + |
| 216 | + #[test] |
| 217 | + fn test_tsrange_params() { |
| 218 | + test_timespec_range_params("TSRANGE"); |
| 219 | + } |
| 220 | + |
| 221 | + #[test] |
| 222 | + fn test_tstzrange_params() { |
| 223 | + test_timespec_range_params("TSTZRANGE"); |
| 224 | + } |
| 225 | +} |
0 commit comments