diff --git a/src/sf_serialize.c b/src/sf_serialize.c index 95539dc..cabc917 100644 --- a/src/sf_serialize.c +++ b/src/sf_serialize.c @@ -27,6 +27,24 @@ #include "fastcommon/logger.h" #include "sf_serialize.h" +typedef struct { + int min_size; + int elt_size; +} SFSerializeTypeConfig; + +static SFSerializeTypeConfig value_type_configs[SF_SERIALIZE_VALUE_TYPE_COUNT] = +{ + {sizeof(SFSerializePackFieldInt8), 0}, + {sizeof(SFSerializePackFieldInt16), 0}, + {sizeof(SFSerializePackFieldInt32), 0}, + {sizeof(SFSerializePackFieldInt64), 0}, + {sizeof(SFSerializePackStringValue), 0}, + {sizeof(SFSerializePackFieldArray), 4}, + {sizeof(SFSerializePackFieldArray), 8}, + {sizeof(SFSerializePackFieldArray), 2 * + sizeof(SFSerializePackStringValue)} +}; + int sf_serialize_unpack(SFSerializeIterator *it, const string_t *content) { SFSerializePackHeader *header; @@ -67,39 +85,16 @@ int sf_serialize_unpack(SFSerializeIterator *it, const string_t *content) static int check_field_type(SFSerializeIterator *it, const int remain_len, const SFSerializeValueType type) { - int min_size; - - switch (type) { - case sf_serialize_value_type_int8: - min_size = sizeof(SFSerializePackFieldInt8); - break; - case sf_serialize_value_type_int16: - min_size = sizeof(SFSerializePackFieldInt16); - break; - case sf_serialize_value_type_int32: - min_size = sizeof(SFSerializePackFieldInt32); - break; - case sf_serialize_value_type_int64: - min_size = sizeof(SFSerializePackFieldInt64); - break; - case sf_serialize_value_type_string: - min_size = sizeof(SFSerializePackFieldString); - break; - case sf_serialize_value_type_int32_array: - case sf_serialize_value_type_int64_array: - case sf_serialize_value_type_map: - min_size = sizeof(SFSerializePackFieldArray); - break; - default: - snprintf(it->error_info, sizeof(it->error_info), - "unknown type: %d", type); - return EINVAL; + if (!(type >= 0 && type < SF_SERIALIZE_VALUE_TYPE_COUNT)) { + snprintf(it->error_info, sizeof(it->error_info), + "unknown type: %d", type); + return EINVAL; } - if (remain_len < min_size) { + if (remain_len < value_type_configs[type].min_size) { snprintf(it->error_info, sizeof(it->error_info), "remain length: %d is too small which < %d", - remain_len, min_size); + remain_len, value_type_configs[type].min_size); return EINVAL; } return 0; @@ -124,6 +119,149 @@ static inline int check_string_value(SFSerializeIterator *it, return 0; } +static inline int unpack_array_count(SFSerializeIterator *it, + const int remain_len, int *count) +{ + int min_size; + + *count = buff2int(((SFSerializePackFieldArray *)it->p)->value.count); + if (*count < 0) { + snprintf(it->error_info, sizeof(it->error_info), + "invalid array count: %d < 0", *count); + return EINVAL; + } + + min_size = value_type_configs[it->field.type].elt_size * (*count); + if (min_size > remain_len) { + snprintf(it->error_info, sizeof(it->error_info), + "array min bytes: %d is too large > remain: %d", + min_size, remain_len); + return EINVAL; + } + + return 0; +} + +static int array_expand(void_array_t *array, const int elt_size, + const int target_count, int *alloc_size) +{ + int new_alloc; + void *new_elts; + + if (*alloc_size == 0) { + new_alloc = 256; + } else { + new_alloc = (*alloc_size) * 2; + } + while (new_alloc < target_count) { + new_alloc *= 2; + } + + new_elts = fc_malloc(elt_size * new_alloc); + if (new_elts == NULL) { + return ENOMEM; + } + + if (array->elts != NULL) { + free(array->elts); + } + array->elts = new_elts; + *alloc_size = new_alloc; + return 0; +} + +static inline int unpack_string(SFSerializeIterator *it, const int remain_len, + SFSerializePackStringValue *input, string_t *output) +{ + if (remain_len < sizeof(SFSerializePackStringValue)) { + snprintf(it->error_info, sizeof(it->error_info), + "remain length: %d is too small < %d", + remain_len, (int)sizeof(SFSerializePackStringValue)); + return EINVAL; + } + + output->len = buff2int(input->len); + output->str = input->str; + it->p += sizeof(SFSerializePackStringValue) + output->len; + return check_string_value(it, remain_len - + sizeof(SFSerializePackStringValue), output); +} + +static int unpack_array(SFSerializeIterator *it, const int remain_len) +{ + int result; + int count; + int64_t *pn; + int64_t *end; + + if ((result=unpack_array_count(it, remain_len, &count)) != 0) { + return result; + } + + if (count > it->int_array_alloc) { + if ((result=array_expand((void_array_t *)&it->int_array, + sizeof(int64_t), count, &it->int_array_alloc)) != 0) + { + return result; + } + } + + it->p += sizeof(SFSerializePackFieldArray); + end = it->int_array.elts + count; + for (pn=it->int_array.elts; pnfield.type == sf_serialize_value_type_int32_array) { + *pn = buff2int(it->p); + } else { + *pn = buff2long(it->p); + } + it->p += value_type_configs[it->field.type].elt_size; + } + it->int_array.count = count; + + return 0; +} + +static int unpack_map(SFSerializeIterator *it, const int remain_len) +{ + int result; + int count; + key_value_pair_t *pair; + key_value_pair_t *end; + + if ((result=unpack_array_count(it, remain_len, &count)) != 0) { + return result; + } + + if (count > it->kv_array_alloc) { + if ((result=array_expand((void_array_t *)&it->kv_array, + sizeof(key_value_pair_t), count, + &it->kv_array_alloc)) != 0) + { + return result; + } + } + + it->p += sizeof(SFSerializePackFieldArray); + end = it->kv_array.kv_pairs + count; + for (pair=it->kv_array.kv_pairs; pairend - it->p, + (SFSerializePackStringValue *)it->p, + &pair->key)) != 0) + { + return result; + } + if ((result=unpack_string(it, it->end - it->p, + (SFSerializePackStringValue *)it->p, + &pair->value)) != 0) + { + return result; + } + } + it->kv_array.count = count; + + return 0; +} + const SFSerializeFieldValue *sf_serialize_next(SFSerializeIterator *it) { int remain_len; @@ -176,25 +314,29 @@ const SFSerializeFieldValue *sf_serialize_next(SFSerializeIterator *it) break; case sf_serialize_value_type_string: fs = (SFSerializePackFieldString *)it->p; - it->field.value.s.len = buff2int(fs->value.len); - it->field.value.s.str = fs->value.str; - if ((it->error_no=check_string_value(it, remain_len - - sizeof(SFSerializePackFieldString), - &it->field.value.s)) != 0) + it->p += sizeof(SFSerializePackFieldInfo); + if ((it->error_no=unpack_string(it, remain_len - + sizeof(SFSerializePackFieldInfo), + &fs->value, &it->field.value.s)) != 0) { return NULL; } - it->p += sizeof(SFSerializePackFieldString) + - it->field.value.s.len; break; case sf_serialize_value_type_int32_array: case sf_serialize_value_type_int64_array: + if ((it->error_no=unpack_array(it, remain_len - sizeof( + SFSerializePackFieldArray))) != 0) + { + return NULL; + } + break; case sf_serialize_value_type_map: - default: - snprintf(it->error_info, sizeof(it->error_info), - "unknown type: %d", field->type); - it->error_no = EINVAL; - return NULL; + if ((it->error_no=unpack_map(it, remain_len - sizeof( + SFSerializePackFieldArray))) != 0) + { + return NULL; + } + break; } return &it->field; diff --git a/src/sf_serialize.h b/src/sf_serialize.h index c37679c..5bf957b 100644 --- a/src/sf_serialize.h +++ b/src/sf_serialize.h @@ -22,8 +22,10 @@ #include "fastcommon/fast_buffer.h" #include "fastcommon/hash.h" +#define SF_SERIALIZE_VALUE_TYPE_COUNT 8 + typedef enum { - sf_serialize_value_type_int8 = 1, + sf_serialize_value_type_int8 = 0, sf_serialize_value_type_int16, sf_serialize_value_type_int32, sf_serialize_value_type_int64, @@ -334,8 +336,8 @@ static inline void sf_serialize_iterator_init(SFSerializeIterator *it) static inline void sf_serialize_iterator_destroy(SFSerializeIterator *it) { - if (it->int_array.values != NULL) { - free(it->int_array.values); + if (it->int_array.elts != NULL) { + free(it->int_array.elts); it->int_array_alloc = 0; }